loglog (Zola atom feed: https://yng87.page/en/atom.xml, updated 2023-06-30)

Find the minimum number of attempts to select a usable battery (2023-06-30, https://yng87.page/en/blog/2023/graphspuzzle/)
<p>Recently I came across an interesting puzzle at <a href="https://bkamins.github.io/julialang/2023/06/23/graphspuzzle.html">The return of the graphs (and an interesting puzzle)</a>.</p>
<blockquote>
<p>You are given eight batteries, out of which four are good and four are depleted, but visually they do not differ. You have a flashlight that needs two good batteries to work. You want your flashlight to work. What is the minimum number of times you need to put two batteries into your flashlight to be sure it works in the worst case.</p>
</blockquote>
<p>There are 8 batteries in total, so there are $\binom{8}{2} = 28$ ways to choose a pair, and $\binom{4}{2} = 6$ of those pairs consist of two good batteries. So even if you put pairs into the flashlight blindly, you are guaranteed a working pair by the $28 - 6 + 1 = 23$rd attempt.</p>
<p>But of course it is possible to get it to work in fewer attempts, and the correct answer is 7.</p>
<p>It is easy to see that 7 attempts are enough: if we divide the batteries, numbered 1 to 8, into three groups <code>{1, 2, 3}</code>, <code>{4, 5, 6}</code>, <code>{7, 8}</code>, at least one group must contain two or more good batteries (otherwise there would be at most 3 good batteries in total). So if we try every pair within each group, one of those pairs must work, and the number of attempts is $3+3+1=7$.</p>
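<p>This grouping strategy can also be checked exhaustively in a few lines. The following is a standalone sketch (not from the original blog) that tries every possible placement of the four good batteries:</p>

```python
from itertools import combinations

# The 7 pairs tried: all pairs within the groups {1,2,3}, {4,5,6}, {7,8}
pairs = [p for group in ([1, 2, 3], [4, 5, 6], [7, 8])
         for p in combinations(group, 2)]

# For every possible placement of the 4 good batteries among the 8,
# at least one tried pair must consist of two good batteries
always_works = all(
    any(set(pair) <= set(good) for pair in pairs)
    for good in combinations(range(1, 9), 4)
)
print(len(pairs), always_works)  # 7 True
```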
<p>Proving that 6 attempts are not enough is a bit more involved, so please see the blog post linked at the beginning of this article.</p>
<p>Interestingly, this problem can also be approached as a graph problem. Prepare a graph with 8 nodes, and let an edge between nodes $i$ and $j$ mean that batteries $i$ and $j$ are put into the flashlight together.</p>
<p>For example, the 7-attempt strategy above can be represented as the following graph.</p>
<pre style="background-color:#282c34;color:#dcdfe4;"><code><span> 1 4 7
</span><span> / \ / \ |
</span><span>2---3 5---6 8
</span></code></pre>
<p>Then the question to consider is: “In an undirected graph with 8 nodes and $E$ edges, is the maximum size of a subset of nodes with no edges between them (called an independent set) at most 3?” If there is an independent set of size 4 or larger, then no good pair is guaranteed to be found, because all four good batteries could lie inside it. Conversely, if every independent set has size at most 3, the condition is satisfied.</p>
<p>For example, even if there are 7 edges, if we create a graph like</p>
<pre style="background-color:#282c34;color:#dcdfe4;"><code><span> 1 4 7
</span><span> / \ / \
</span><span>2---3--5---6 8
</span></code></pre>
<p>then <code>{1, 4, 7, 8}</code> is an independent set of size 4, and if the four good batteries all happen to be these, no working pair will ever be tried.</p>
<p>Once the problem is reformulated this way, the rest can be solved by brute force with the help of a computer: enumerate all possible graphs for each number of edges and check whether the condition is satisfied. Specifically, starting from the graph with zero edges, repeat the following steps:</p>
<ol>
<li>Enumerate all graphs with $E+1$ edges by adding one edge, in every possible way, to each graph with $E$ edges. Merge isomorphic graphs into one.</li>
<li>Find the maximum size of the independent set of nodes for each graph. </li>
<li>Find the minimum of this maximum size over all graphs with $E+1$ edges.</li>
<li>If this minimum is less than or equal to 3, you are done.</li>
</ol>
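<p>The key subroutine, computing the maximum independent set of an 8-node graph, can simply be brute-forced over all node subsets. The following is a minimal sketch (the function name and edge-list encoding are my own, not from the referenced blog):</p>

```python
from itertools import combinations

def max_independent_set_size(n, edges):
    """Largest subset of nodes 1..n with no edge between any two members,
    found by checking subsets from largest to smallest."""
    edge_set = {frozenset(e) for e in edges}
    for size in range(n, 0, -1):
        for subset in combinations(range(1, n + 1), size):
            if all(frozenset(p) not in edge_set
                   for p in combinations(subset, 2)):
                return size
    return 0

# Two triangles plus an edge: every independent set has size <= 3
print(max_independent_set_size(8, [(1,2),(1,3),(2,3),(4,5),(4,6),(5,6),(7,8)]))  # 3
# The second 7-edge graph above: {1, 4, 7, 8} is independent
print(max_independent_set_size(8, [(1,2),(1,3),(2,3),(3,5),(4,5),(4,6),(5,6)]))  # 4
```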
<p>The Julia code in the original blog did not work in my environment, so I rewrote it in Rust:</p>
<figure class="blogcard">
<a href="https://gist.github.com/yng87/8fe7bff1725a28ffbc332f3632185e61" target="_blank" rel="noopener noreferrer" aria-label="Open article details (new window)">
<div class="blogcard-content">
<div class="blogcard-image">
<div class="blogcard-image-wrapper">
<img src="/gist_logo.png" alt="Graphs puzzle" loading="lazy">
</div>
</div>
<div class="blogcard-text">
<div class="blogcard-title">Graphs puzzle</div>
<div class="blogcard-footer">Gist</div>
</div>
</div>
</a>
</figure>
What happens when AI-generated data is used repeatedly for training (2023-06-28, https://yng87.page/en/blog/2023/gaussian-mixture-collapse/)
<p>I read an interesting paper, <a href="https://arxiv.org/abs/2305.17493v2">The Curse of Recursion: Training on Generated Data Makes Models Forget</a>, introduced in <a href="https://importai.substack.com/p/import-ai-333-synthetic-data-makes">Import AI</a>.</p>
<p>The paper points out that if data generated by a model is repeatedly used to train subsequent models, information in the tails of the distribution is gradually lost and the range of outputs becomes narrower. This seems especially important now that the output of generative AI is beginning to flood the Internet.</p>
<p>In the paper, the authors use theoretical and numerical analyses of Gaussian mixtures as examples and call this phenomenon <em>model collapse</em>. The numerical analysis looked easy to reproduce, so I tried it myself.</p>
<h1 id="python-implementation">Python implementation</h1>
<p>What I want to do is to repeatedly fit the parameters of a one-dimensional Gaussian mixture, generate new data from the fitted distribution, and fit the distribution again to the new data.</p>
<p>Specifically, given the training data for the $i$-th step, mix the following three sources to form the training data for the $(i+1)$-th step:</p>
<ul>
<li>The training data for the $i$-th step</li>
<li>A sample from the Gaussian mixture fitted to the $i$-th step's training data</li>
<li>A sample from the true distribution</li>
</ul>
<p>For $i=0$, the samples are drawn from the true Gaussian mixture, and the ratio of the three data sources can be chosen freely.</p>
<pre data-lang="python" style="background-color:#282c34;color:#dcdfe4;" class="language-python "><code class="language-python" data-lang="python"><span style="color:#c678dd;">from </span><span>collections </span><span style="color:#c678dd;">import </span><span>namedtuple
</span><span style="color:#c678dd;">import </span><span>numpy </span><span style="color:#c678dd;">as </span><span>np
</span><span style="color:#c678dd;">import </span><span>matplotlib.pyplot </span><span style="color:#c678dd;">as </span><span>plt
</span><span style="color:#c678dd;">from </span><span>sklearn.mixture </span><span style="color:#c678dd;">import </span><span>GaussianMixture
</span><span style="color:#c678dd;">from </span><span>tqdm </span><span style="color:#c678dd;">import </span><span>tqdm
</span><span>
</span><span>GaussianMixtureParams </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">namedtuple</span><span>(</span><span style="color:#98c379;">'GaussianMixtureParams'</span><span>, (</span><span style="color:#98c379;">'mus'</span><span>, </span><span style="color:#98c379;">'sigmas'</span><span>, </span><span style="color:#98c379;">'mixture_ratio'</span><span>))
</span><span>
</span><span style="color:#c678dd;">def </span><span style="color:#61afef;">sample_from_gaussian_mixture</span><span>(</span><span style="color:#e06c75;">gm_params</span><span>, </span><span style="color:#e06c75;">num_samples</span><span>):
</span><span> mus </span><span style="color:#c678dd;">= </span><span>np.</span><span style="color:#e06c75;">array</span><span>(gm_params.mus)
</span><span> sigmas </span><span style="color:#c678dd;">= </span><span>np.</span><span style="color:#e06c75;">array</span><span>(gm_params.sigmas)
</span><span>
</span><span> num_mixture </span><span style="color:#c678dd;">= </span><span style="color:#61afef;">len</span><span>(gm_params.mixture_ratio)
</span><span> mixture_idx </span><span style="color:#c678dd;">= </span><span>np.random.</span><span style="color:#e06c75;">choice</span><span>(num_mixture, </span><span style="color:#e06c75;">size</span><span style="color:#c678dd;">=</span><span>num_samples, </span><span style="color:#e06c75;">p</span><span style="color:#c678dd;">=</span><span>gm_params.mixture_ratio)
</span><span> standard_normal_noises </span><span style="color:#c678dd;">= </span><span>np.random.</span><span style="color:#e06c75;">normal</span><span>(</span><span style="color:#e06c75;">size</span><span style="color:#c678dd;">=</span><span>num_samples)
</span><span> samples </span><span style="color:#c678dd;">= </span><span>mus[mixture_idx] </span><span style="color:#c678dd;">+ </span><span>standard_normal_noises </span><span style="color:#c678dd;">* </span><span>sigmas[mixture_idx]
</span><span>
</span><span> </span><span style="color:#c678dd;">return </span><span>samples
</span><span>
</span><span style="color:#c678dd;">def </span><span style="color:#61afef;">fit_gaussian_mixture</span><span>(</span><span style="color:#e06c75;">samples</span><span>, </span><span style="color:#e06c75;">n_components</span><span>, </span><span style="color:#e06c75;">random_state</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">0</span><span>):
</span><span> gm </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">GaussianMixture</span><span>(</span><span style="color:#e06c75;">n_components</span><span style="color:#c678dd;">=</span><span>n_components, </span><span style="color:#e06c75;">random_state</span><span style="color:#c678dd;">=</span><span>random_state).</span><span style="color:#e06c75;">fit</span><span>(samples.</span><span style="color:#e06c75;">reshape</span><span>(</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">1</span><span>,</span><span style="color:#e5c07b;">1</span><span>))
</span><span>
</span><span> mus </span><span style="color:#c678dd;">= </span><span>gm.means_.</span><span style="color:#e06c75;">reshape</span><span>(</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">1</span><span>)
</span><span> sorted_idx </span><span style="color:#c678dd;">= </span><span>np.</span><span style="color:#e06c75;">argsort</span><span>(mus)
</span><span>
</span><span> gm_params </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">GaussianMixtureParams</span><span>(
</span><span> gm.means_.</span><span style="color:#e06c75;">reshape</span><span>(</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">1</span><span>)[sorted_idx],
</span><span> np.</span><span style="color:#e06c75;">sqrt</span><span>(gm.covariances_).</span><span style="color:#e06c75;">reshape</span><span>(</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">1</span><span>)[sorted_idx],
</span><span> gm.weights_[sorted_idx]
</span><span> )
</span><span>
</span><span> </span><span style="color:#c678dd;">return </span><span>gm_params
</span><span>
</span><span>
</span><span style="color:#c678dd;">def </span><span style="color:#61afef;">sample_next_gen</span><span>(</span><span style="color:#e06c75;">gm_params_org</span><span>, </span><span style="color:#e06c75;">gm_params_fitted</span><span>, </span><span style="color:#e06c75;">samples_current_gen</span><span>, </span><span style="color:#e06c75;">ratio_org</span><span>, </span><span style="color:#e06c75;">ratio_fitted</span><span>, </span><span style="color:#e06c75;">num_samples</span><span>):
</span><span> num_samples_org </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">int</span><span>(num_samples </span><span style="color:#c678dd;">* </span><span>ratio_org)
</span><span> num_samples_fitted </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">int</span><span>(num_samples </span><span style="color:#c678dd;">* </span><span>ratio_fitted)
</span><span> num_samples_current </span><span style="color:#c678dd;">= </span><span>num_samples </span><span style="color:#c678dd;">- </span><span>num_samples_org </span><span style="color:#c678dd;">- </span><span>num_samples_fitted
</span><span> </span><span style="color:#c678dd;">if </span><span>num_samples_current </span><span style="color:#c678dd;">< </span><span style="color:#e5c07b;">0</span><span>:
</span><span> </span><span style="color:#c678dd;">raise </span><span style="color:#e06c75;">ValueError</span><span>(</span><span style="color:#98c379;">"Invalid ratio"</span><span>)
</span><span>
</span><span> samples_org </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">sample_from_gaussian_mixture</span><span>(gm_params_org, num_samples_org)
</span><span> samples_fitted </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">sample_from_gaussian_mixture</span><span>(gm_params_fitted, num_samples_fitted)
</span><span> samples_current </span><span style="color:#c678dd;">= </span><span>np.random.</span><span style="color:#e06c75;">choice</span><span>(samples_current_gen, </span><span style="color:#e06c75;">size</span><span style="color:#c678dd;">=</span><span>num_samples_current, </span><span style="color:#e06c75;">replace</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">False</span><span>)
</span><span>
</span><span> </span><span style="color:#c678dd;">return </span><span>np.</span><span style="color:#e06c75;">concatenate</span><span>([samples_org, samples_fitted, samples_current])
</span></code></pre>
<h1 id="results">Results</h1>
<p>I ran a total of 2000 steps, sampling $N=10000$ each time. The true Gaussian mixture has two components and the parameters are</p>
<ul>
<li>$\boldsymbol \mu = (-2, 2)$</li>
<li>$\boldsymbol \sigma = (1, 1)$</li>
<li>mixture ratio = 1:1</li>
</ul>
<p>For simplicity, only the data generated from the distribution fitted in the previous step is used for training:</p>
<pre data-lang="python" style="background-color:#282c34;color:#dcdfe4;" class="language-python "><code class="language-python" data-lang="python"><span>num_samples </span><span style="color:#c678dd;">= </span><span style="color:#e5c07b;">10000
</span><span>num_gen </span><span style="color:#c678dd;">= </span><span style="color:#e5c07b;">2000
</span><span>
</span><span>gm_params_org </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">GaussianMixtureParams</span><span>([</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">2</span><span>, </span><span style="color:#e5c07b;">2</span><span>], [</span><span style="color:#e5c07b;">1</span><span>, </span><span style="color:#e5c07b;">1</span><span>], [</span><span style="color:#e5c07b;">0.5</span><span>, </span><span style="color:#e5c07b;">0.5</span><span>])
</span><span>samples </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">sample_from_gaussian_mixture</span><span>(gm_params_org, num_samples)
</span><span>
</span><span style="color:#c678dd;">for </span><span style="color:#e06c75;">_ </span><span style="color:#c678dd;">in </span><span style="color:#e06c75;">tqdm</span><span>(</span><span style="color:#61afef;">range</span><span>(num_gen)):
</span><span> gm_params_fitted </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">fit_gaussian_mixture</span><span>(samples, </span><span style="color:#e5c07b;">2</span><span>)
</span><span> samples </span><span style="color:#c678dd;">= </span><span style="color:#e06c75;">sample_next_gen</span><span>(gm_params_org, gm_params_fitted, samples, </span><span style="color:#e5c07b;">0.0</span><span>, </span><span style="color:#e5c07b;">1.0</span><span>, num_samples)
</span></code></pre>
<p>Here is a plot of how the fitted $\sigma$ values evolve:
<img src="/blog/2023/gm_sigmas.png" alt="" />
The distribution gradually becomes narrower.</p>
<p>Comparing the initial distribution with that at the 2000th step, the mixture ratio also seems to have become biased toward one of the two modes.
<img src="/blog/2023/gm_hist.png" alt="" /></p>
<p>However, when I repeated the process several times with different random seeds, the fitted distribution sometimes became wider instead, and the speed of model collapse also changed depending on the mixing ratio of the three data sources. I am not sure how much of this carries over to real LLMs and the like.</p>
<p>The link to the notebook is <a href="https://gist.github.com/yng87/b69d8e93f73c42c84d84d1b7840326d7">here</a>. Please let me know if there are any mistakes.</p>
Bayes factor for hypothesis testing (2023-05-03, https://yng87.page/en/blog/2023/hypothesis-testing-freq-and-bayes/)
<p>In hypothesis testing based on Bayesian statistics, there is a method that uses a quantity called the Bayes factor. I believe that when Bayesian methods are mentioned as a way to overcome the problems of frequentist testing, it often refers to techniques like the Bayes factor. In my work, I have mostly encountered situations where frequentist testing has been sufficient, so I have not dealt with this much. However, I recently studied it and would like to provide a brief summary.</p>
<h1 id="hypothesis-testing-methods">Hypothesis Testing Methods</h1>
<p>To understand hypothesis testing using Bayes Factors, let’s compare it with classical frequentist methods and those using Bayesian credible intervals. I think it is easier to understand by bringing up an analytically calculable example, so let’s consider a simple coin toss experiment. Let’s say we toss a coin $N$ times and obtain heads $m$ times. We want to determine whether this coin is unbiased or not. Let the probability of getting heads be $\mu$. We are then faced with the problem of choosing between:</p>
<ul>
<li>Null hypothesis ($H_0$): $\mu = 0.5$</li>
<li>Alternative hypothesis ($H_1$): $\mu \neq 0.5$</li>
</ul>
<h2 id="frequentist-hypothesis-testing">Frequentist Hypothesis Testing</h2>
<p>First, let’s briefly summarize hypothesis testing using frequentist methods. In frequentist statistics, we find the probability distribution of the test statistic under $H_0$ and calculate the probability of obtaining the observed value (<em>p-value</em>). For the coin toss experiment, we can perform a <a href="https://en.wikipedia.org/wiki/Binomial_test">binomial test</a> or a <a href="https://en.wikipedia.org/wiki/Chi-squared_test">chi-squared test</a>. We calculate the p-value and reject $H_0$ by contradiction if it is smaller than a predetermined threshold value.</p>
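<p>As a concrete illustration with made-up numbers (60 heads out of 100 tosses, not taken from any study), the exact two-sided binomial test sums the probabilities of all outcomes no more likely than the observed one:</p>

```python
from math import comb

def binom_two_sided_p(k, n, p=0.5):
    """Exact two-sided binomial test p-value: sum of probabilities of
    all outcomes at most as likely as the observed count k."""
    pmf = [comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(n + 1)]
    # small tolerance guards against floating-point ties
    return sum(q for q in pmf if q <= pmf[k] * (1 + 1e-12))

print(binom_two_sided_p(60, 100))  # ~0.057: not significant at the 5% level
```

The same number can be obtained with <code>scipy.stats.binomtest</code>; the point is only that the p-value comes from the sampling distribution under $H_0$.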
<p>Some characteristics of frequentist hypothesis testing are:</p>
<ul>
<li>Typically, it deals with cases where the distribution of the statistic can be analytically calculated. Even when this is not the case, if there are a large number of samples, it can sometimes be applied through the central limit theorem. As a result, analyses can often be performed with light numerical calculations.</li>
<li>While it is possible to reject the null hypothesis, it is not possible to accept it. Frequentist testing has a structure similar to proof by contradiction, so we can conclude that “since there is a contradiction, the assumption is false,” but we cannot actively claim that the null hypothesis is true.</li>
<li>(We won’t go into detail here, but) it violates <a href="https://en.wikipedia.org/wiki/Likelihood_principle">the likelihood principle</a>.</li>
</ul>
<p>Next, let’s look at hypothesis testing using Bayesian credible intervals, which is one of the methods using Bayesian statistics.</p>
<h2 id="bayesian-credible-intervals">Bayesian Credible Intervals</h2>
<p>One of the hypothesis testing methods using Bayesian statistics is the use of credible intervals. This method applies the results of parameter estimation in a statistical model to hypothesis testing.</p>
<p>Here, let’s model the coin toss experiment using a binomial distribution. We set the prior of the probability of getting heads, $\mu$, to a beta distribution:
$$
\begin{aligned}
m &\sim \mathrm{B}(N, \mu), \\
\mu &\sim \mathrm{Beta}(a, b).
\end{aligned}
$$</p>
<p>Here, $a$ and $b$ are hyperparameters. As is well-known, the beta distribution is a conjugate prior for the binomial distribution, and the posterior can be analytically calculated:
$$
\mu \sim \mathrm{Beta}(m+a, N-m+b).
$$</p>
<p>When using this for hypothesis testing, we calculate the interval with a high probability of containing $\mu$ in the posterior distribution (a <a href="https://en.wikipedia.org/wiki/Credible_interval">credible interval</a>), and check whether the null hypothesis is included in that interval. In this example, the null hypothesis is $\mu=0.5$, and if this value is outside the credible interval, we can reject the null hypothesis. Conversely, if it is inside the interval, we cannot say that the data actively supports the alternative hypothesis.</p>
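<p>For example, with hypothetical data of $m=60$ heads in $N=100$ tosses and a flat prior $a=b=1$, the equal-tailed 95% credible interval of the posterior $\mathrm{Beta}(m+a, N-m+b)$ can be read off from the Beta quantile function:</p>

```python
from scipy.stats import beta

N, m, a, b = 100, 60, 1, 1  # hypothetical data and flat prior
# 2.5% and 97.5% quantiles of the posterior Beta(m+a, N-m+b)
lo, hi = beta.ppf([0.025, 0.975], m + a, N - m + b)
# The interval sits roughly around 0.5-0.7; mu = 0.5 lies near its
# lower edge, so the evidence against the null is borderline here
print(lo, hi)
```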
<p>The method using credible intervals has the following characteristics:</p>
<ul>
<li>We can calculate the probability distribution of $\mu$. Thus, the obtained credible interval can be straightforwardly interpreted as the “interval containing the parameter with a certain probability”. In frequentist statistics, $\mu$ was not a random variable, so its distribution could not be calculated. Although there is a similar concept of confidence intervals in frequentist statistics, caution is required in interpreting them since $\mu$ is not a random variable.</li>
<li>Generally, as the sample size increases, the influence of the prior distribution decreases. In the example above, if $m, N-m \gg a, b$, the result becomes insensitive to the choice of $a$ and $b$.</li>
<li>To obtain the posterior distribution, heavy numerical computations such as MCMC are generally required.</li>
<li>It does not treat the null hypothesis and alternative hypothesis equally. In the example above, the null hypothesis is a single point, $\mu=0.5$, but the probability of $\mu$ taking this value is always zero under the posterior distribution $\mathrm{Beta}(m+a, N-m+b)$. Therefore, it seems unsuitable for cases where the focus is on whether the null hypothesis can be accepted or not. This point is discussed in detail in <a href="https://link.springer.com/article/10.3758/s13423-017-1420-7">Rouder, Haaf, and Vandekerckhove (2018)</a>.</li>
</ul>
<h2 id="bayes-factors">Bayes Factors</h2>
<h3 id="overview">Overview</h3>
<p>Although the method using credible intervals involves Bayesian statistics, it essentially involves parameter estimation of probability distributions. While there are some advantages, such as the interval estimates obtained being easier to interpret compared to frequentist approaches, the method still isn’t suitable for actively adopting null hypotheses.</p>
<p>On the other hand, the Bayes factor is a method that utilizes the model comparison concept of Bayesian statistics and is capable of overcoming this difficulty.</p>
<p>What we truly want to evaluate is the posterior probability of hypothesis $H$ given data $\mathcal{D}$, represented as $p(H|\mathcal{D})$. In Bayesian statistics, hypotheses must be represented as statistical models. Let’s denote this as $\mathcal{M}$. Using Bayes’ theorem, the posterior probability of model $\mathcal{M}$ can be written as:
$$
p(\mathcal{M}|\mathcal{D}) = \frac{p(\mathcal{D}|\mathcal{M})p(\mathcal{M})}{p(\mathcal{D})}
$$
Taking the ratio of this quantity between the null hypothesis and the alternative hypothesis gives us:</p>
<p>$$
\frac{p(\mathcal{M}_1|\mathcal{D})}{p(\mathcal{M}_0|\mathcal{D})}
= \frac{p(\mathcal{D}|\mathcal{M}_1)}{p(\mathcal{D}|\mathcal{M}_0)}
\times \frac{p(\mathcal{M}_1)}{p(\mathcal{M}_0)}
$$
The first term on the right-hand side:
$$
BF_{10} = \frac{p(\mathcal{D}|\mathcal{M}_1)}{p(\mathcal{D}|\mathcal{M}_0)}
$$
is called the Bayes factor and is the ratio of marginal likelihoods $p(\mathcal{D}|\mathcal{M})$. $p(\mathcal{M}_1)/p(\mathcal{M}_0)$ is the ratio of prior beliefs for each model. If the Bayes factor can be calculated, this means that we can update these odds based on the data.</p>
<p>In particular, setting the prior odds to 1 allows us to test hypotheses based on how far the Bayes factor deviates from 1. If the Bayes factor is significantly larger than 1, the alternative hypothesis is accepted, and if it is significantly smaller than 1, the null hypothesis is accepted. While previous methods have been somewhat heuristic, the Bayes factor naturally evaluates the posterior probability of a model in accordance with the laws of probability, making it theoretically straightforward.</p>
<p>As the Bayes factor is a test, it is necessary to determine a threshold for accepting or rejecting hypotheses based on its value depending on each industry or situation. For example, some commonly used criteria are summarized in <a href="https://www.jstor.org/stable/2291091">Kass and Raftery (1995)</a>.</p>
<h3 id="marginal-likelihood">Marginal Likelihood</h3>
<p>The crux of the Bayes factor lies in the marginal likelihood. Let’s use the coin toss example to calculate it. Just like in the case of credible intervals, if we model the coin toss with a binomial distribution, under $\mathcal{M}_0$ with $\mu=1/2$, we get:
$$
p(\mathcal{D}|\mathcal{M}_0) = \binom{N}{m}2^{-N}
$$</p>
<p>On the other hand, the case of $\mathcal{M}_1$ gets a bit more complicated. Since $\mu \neq 1/2$ is a continuous quantity, we must perform the following integration:
$$
p(\mathcal{D}|\mathcal{M}_1) = \int_0^1d\mu\ p(\mathcal{D}|\mathcal{M}_1, \mu)p(\mu|\mathcal{M}_1).
$$</p>
<p>Let’s adopt the same binomial distribution as before for the likelihood $p(\mathcal{D}|\mathcal{M}_1, \mu)$:
$$
p(\mathcal{D}|\mathcal{M}_1, \mu) = \binom{N}{m}\mu^m(1-\mu)^{N-m},
$$
and assume a Beta distribution $\mathrm{Beta}(a, b)$ for the prior distribution of $\mu$, $p(\mu|\mathcal{M}_1)$. The density function of the Beta distribution is:
$$
p(\mu|\mathcal{M}_1) = \frac{\Gamma(a+b)}{\Gamma(a)\Gamma(b)}\mu^{a-1}(1-\mu)^{b-1}
$$
With this, the marginal likelihood can be analytically calculated as:
$$
p(\mathcal{D}|\mathcal{M}_1) = \frac{N!}{m!(N-m)!}\cdot\frac{\Gamma(a+b)}{\Gamma(a)\Gamma(b)}\cdot\frac{\Gamma(m+a)\Gamma(N-m+b)}{\Gamma(a+b+N)}
$$
In particular, when $m, N-m \gg a, b$, we can obtain:
$$
\ln p(\mathcal{D}|\mathcal{M}_1) \simeq \ln\left(\frac{\Gamma(a+b)}{\Gamma(a)\Gamma(b)}\right) +(a-1)\ln m+(b-1)\ln(N-m)-(a+b-1)\ln N.
$$</p>
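<p>Putting the two marginal likelihoods together, the Bayes factor is conveniently evaluated in log space with the log-gamma function. A small sketch with hypothetical numbers (60 heads in 100 tosses, flat prior $a=b=1$):</p>

```python
from math import comb, lgamma, log, exp

def log_marginal_h0(N, m):
    # p(D|M0) = C(N, m) * 2^{-N}
    return log(comb(N, m)) - N * log(2)

def log_marginal_h1(N, m, a, b):
    # C(N, m) * Gamma(a+b)/(Gamma(a)Gamma(b))
    #         * Gamma(m+a)Gamma(N-m+b)/Gamma(a+b+N), in log form
    return (log(comb(N, m))
            + lgamma(a + b) - lgamma(a) - lgamma(b)
            + lgamma(m + a) + lgamma(N - m + b) - lgamma(a + b + N))

bf10 = exp(log_marginal_h1(100, 60, 1, 1) - log_marginal_h0(100, 60))
print(bf10)  # ~0.91: close to 1, so the data barely favor either hypothesis
```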
<h3 id="dependence-on-prior-distribution">Dependence on Prior Distribution</h3>
<p>As can be seen from this expression, the marginal likelihood depends on the prior distribution of $\mu$, $p(\mu|\mathcal{M}_1)$. In particular, <a href="https://www.jstor.org/stable/2291091">the influence of the prior distribution persists even as the sample size increases</a>. Of course, we are free to choose a prior distribution other than the Beta distribution, which would provide even more degrees of freedom than what this expression represents.</p>
<p>In the case of Bayesian credible intervals, the influence of the prior distribution tended to diminish as the sample size increased. On the other hand, the marginal likelihood is the result of multiplying the likelihood by the prior distribution and integrating it, so the value will be small if the prior distribution is either too narrow or too wide. <sup class="footnote-reference"><a href="#prml">1</a></sup> Therefore, for example, using an overly uninformative distribution can result in a smaller marginal likelihood for the alternative hypothesis and greater support for the null hypothesis.</p>
<p>Because of this property, when using Bayes factors for hypothesis testing, it is important to utilize domain knowledge in the design of prior distributions or to use distributions that have been agreed upon within the community. Additionally, sensitivity analysis regarding the choice of prior distributions is necessary. While there are default priors designed to alleviate this burden to some extent, there is <a href="https://vasishth.github.io/bayescogsci/book/ch-bf.html">criticism against blindly using them</a>. There are various discussions about prior distributions on <a href="https://www.reddit.com/r/statistics/comments/10qpo9v/q_how_can_bayes_factor_be_compatible_with_the/">Reddit</a>, and in a more practical setting, the importance of constructing prior distributions using past experimental results has been advocated in <a href="https://dl.acm.org/doi/10.1145/2740908.2742563">Bing’s A/B testing case study</a>.</p>
<h3 id="other-features">Other Features</h3>
<p>Apart from sensitivity to prior distributions, let’s summarize some features of Bayes factors:</p>
<ul>
<li>With Bayes factors, it is possible to accept the null hypothesis. This is because both the null hypothesis and the alternative hypothesis are treated equally.</li>
<li>Similar to the case with credible intervals, in general, it is not possible to perform the integrations analytically, so some numerical calculations are required. In R, there is a package called <a href="https://cran.r-project.org/web/packages/BayesFactor/vignettes/manual.html">BayesFactor</a> for this purpose.</li>
<li>It allows for the comparison of any two hypotheses. In the case of frequentist methods or Bayesian credible intervals, the null hypothesis and the alternative hypothesis had the same binomial distribution for coin tosses, and the value of the parameter $\mu$ was examined. With Bayes factors, it is possible to compare models with completely different distributions.</li>
</ul>
<h1 id="summary">Summary</h1>
<p>The above content is summarized in the table below:</p>
<table><thead><tr><th></th><th>Accepting Null Hypothesis</th><th>Comparison of Any Hypotheses</th><th>Dependence on Prior</th><th>Computational Cost</th></tr></thead><tbody>
<tr><td>Frequentist</td><td>Not possible</td><td>Not possible</td><td>-</td><td>Low</td></tr>
<tr><td>Bayesian Credible Intervals</td><td>Not possible</td><td>Not possible</td><td>Low</td><td>Generally high</td></tr>
<tr><td>Bayes Factors</td><td>Possible</td><td>Possible</td><td>High</td><td>Generally high</td></tr>
</tbody></table>
<p>Personally, considering the difficulty of calculations and sensitivity analysis, my impression is that the use of Bayes factors in practice seems to be limited to cases where analysis is not possible with other methods, such as wanting to accept the null hypothesis.</p>
<p>In the case of common RCT effectiveness evaluations, I think it is easier to use frequentist hypothesis testing after properly setting the effect size and sample size design.</p>
<h1 id="references">References</h1>
<ul>
<li><a href="https://doi.org/10.3758/s13423-017-1420-7">Bayesian inference for psychology, part IV parameter estimation and Bayes factors</a></li>
<li><a href="https://vasishth.github.io/bayescogsci/book/ch-bf.html">Chapter 15 Bayes factors | An Introduction to Bayesian Data Analysis for Cognitive Science</a></li>
<li><a href="https://www.jstor.org/stable/2291091">Bayes factor</a></li>
<li><a href="https://blog.logicoffee.tech/posts/math/different-types-of-hypothesis-testing.html">How much do conclusions differ between hypothesis-testing methods? | Blogicoffee</a> (Japanese)</li>
<li><a href="https://qiita.com/hokudai_meiyo/items/038b55e0380c3f653640">Hypothesis testing in Bayesian statistics - Qiita</a> (Japanese)</li>
<li><a href="https://www.jstage.jst.go.jp/article/sjpr/61/1/61_101/_pdf">Evaluating psychological hypotheses and models with Bayes factors</a> (Japanese)</li>
</ul>
<hr />
<div class="footnote-definition" id="prml"><sup class="footnote-definition-label">1</sup>
<p>For example, there is an intuitive explanation in Chapter 3 of PRML.</p>
</div>
Implementing a diffusion model in Julia/Lux.jl (2022-12-14, https://yng87.page/en/blog/2022/lux-ddim/)
<p>This article is from <a href="https://qiita.com/advent-calendar/2022/julia">JuliaLang Advent Calendar 2022 (Japanese)</a>, Day 14.</p>
<p>As a hobby recently, I was implementing a diffusion model using <a href="https://github.com/avik-pal/Lux.jl">Lux.jl</a>, a neural net framework made by Julia.</p>
<figure class="blogcard">
<a href="https://github.com/yng87/DDIM.jl" target="_blank" rel="noopener noreferrer" aria-label="Open article details (new window)">
<div class="blogcard-content">
<div class="blogcard-image">
<div class="blogcard-image-wrapper">
<img src="https://opengraph.githubassets.com/0/yng87/DDIM.jl" alt="yng87/DDIM.jl" loading="lazy">
</div>
</div>
<div class="blogcard-text">
<div class="blogcard-title">GitHub - yng87/DDIM.jl</div>
<div class="blogcard-footer">GitHub</div>
</div>
</div>
</a>
</figure>
<p>For training data, I used the <a href="https://www.robots.ox.ac.uk/~vgg/data/flowers/102/">Oxford 102 flowers dataset</a>. When trained well, the model can generate images like the following:</p>
<p><img src="/blog/2022/ddim_img_005.gif" alt="" /></p>
<p>Due to the limited computing resources available, I could not generate very high-resolution images, but the model can still produce images that look like flowers.</p>
<p>I posted the results on Twitter, <a href="https://discourse.julialang.org/t/lux-jl-implementation-of-denoising-diffusion-model/90275">posted them to the Julia Discourse</a> (at @MathSorcerer’s suggestion), and finally contributed the model to <a href="https://github.com/avik-pal/Lux.jl/tree/main/examples/DDIM">Lux.jl’s official examples</a>. It was merged the other day.</p>
<p>I have always thought Julia is a good language, but in the machine-learning field it has been somewhat invisible next to the overwhelming volume of the Python ecosystem. Recently, however, a solid variety of packages has become available, to the point that even deep learning on GPUs is practical.</p>
<p>In this article, I introduce how to implement a deep learning model using Lux.jl and its surrounding ecosystem, with the diffusion model as the subject. There seem to be few articles about building a reasonably complex network that requires writing custom layers yourself, so I hope this serves as a helpful example of implementing a practical deep learning model.</p>
<p>Disclaimer: my experience with Julia is limited to my doctoral research and occasional hobby use, so please comment if there are any mistakes. Also, although I put “diffusion model” in the title, there is almost no discussion of diffusion models themselves in what follows.</p>
<!-- toc -->
<h1 id="diffusion-models">Diffusion Models</h1>
<p>The diffusion model is a deep generative model, currently very popular, that produces high-quality images by gradually removing noise from a random input. Many articles on the web explain it, so I will not go into detail here. There is also <a href="/blog/2022/06/diffusion_model_derivation/">an article I wrote in Japanese</a>.</p>
<p>This time, I reimplemented <a href="https://keras.io/examples/generative/ddim/">the Keras example</a> in Julia.</p>
<p>The model used here is <a href="https://arxiv.org/abs/2010.02502">Denoising Diffusion Implicit Models (DDIM)</a>, an improvement over DDPM and other earlier diffusion models that makes generation faster.</p>
<p>There is also a Jax implementation (<a href="https://github.com/daigo0927/jax-ddim">daigo0927/jax-ddim</a>) that I refer to here.</p>
<h1 id="deep-learning-packages-in-julia">Deep Learning packages in Julia</h1>
<h2 id="flux-and-lux">Flux and Lux</h2>
<p>Julia has several packages for neural networks, the most famous of which is <a href="https://github.com/FluxML/Flux.jl">Flux.jl</a>. Its position is similar to TensorFlow and PyTorch: Flux defines an NN layer as a struct with trainable parameters, much as PyTorch defines models as classes. It also includes a data loader and activation functions (actually re-exports of NNlib.jl and MLUtils.jl).</p>
<p>On the other hand, Lux.jl is similar to Python frameworks such as <a href="https://github.com/deepmind/dm-haiku">Jax/Haiku</a>. Lux provides only layer management; other functionality must be integrated from external packages. Also, a layer struct in Lux holds only the information that determines the structure of the model (e.g., the numbers of input and output dimensions) and has no trainable parameters inside. The <a href="http://lux.csail.mit.edu/dev/">Lux documentation</a> states</p>
<blockquote>
<ul>
<li>Functional Design - Pure Functions and Deterministic Function Calls.</li>
<li>No more implicit parameterization.</li>
<li>Compiler and AD-friendly Neural Networks.</li>
</ul>
</blockquote>
<p>The design of the framework is oriented toward pure functions, similar to the frameworks built around Jax.</p>
<p>The reason for adopting such a design is well described in the documentation section <a href="http://lux.csail.mit.edu/dev/introduction/overview/#why-use-lux-over-flux">Why use Lux over Flux?</a>. To summarize, explicit parameter handling is more compatible with Neural ODEs, which are popular in Julia’s machine-learning community.</p>
<h2 id="deep-learning-ecosystem">Deep Learning Ecosystem</h2>
<p>Julia’s ecosystem has a design philosophy called composability<sup class="footnote-reference"><a href="#compbug">1</a></sup>, which encourages users to develop small independent packages that can be freely combined. As mentioned above, Lux.jl by itself cannot perform deep learning: data loaders, optimizers, and image- or language-processing utilities must come from other packages. This is in contrast to TensorFlow and PyTorch.</p>
<p>So, here is a list of packages that are often used when implementing neural networks in Julia:</p>
<ul>
<li><a href="https://github.com/FluxML/Zygote.jl">Zygote.jl</a>: the de facto standard for reverse-mode automatic differentiation. While TensorFlow and PyTorch define their own tensor types and differentiate those, Zygote can differentiate functions on Julia’s standard arrays.</li>
<li><a href="https://github.com/FluxML/NNlib.jl">NNlib.jl</a>: implements activation functions and other NN primitives. As mentioned above, Zygote lets you differentiate any function you define yourself; however, since the numerical stability of the same mathematical function can differ depending on how it is implemented, you should prefer the versions provided by NNlib.</li>
<li><a href="https://github.com/JuliaML/MLUtils.jl">MLUtils.jl</a>: implements utilities such as a data loader that reads data in parallel threads, like PyTorch’s data loader.</li>
<li><a href="https://github.com/FluxML/Optimisers.jl">Optimisers.jl</a>: As the name suggests, an optimizer is implemented.</li>
</ul>
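<p>As a toy illustration of the Zygote point above (my own minimal example, not from the DDIM code), Zygote differentiates an ordinary Julia function operating on a standard array:</p>

```julia
using Zygote

# A plain Julia function; no special tensor type is needed.
f(x) = sum(x .^ 2)

# gradient returns a tuple with one entry per argument.
g = Zygote.gradient(f, [1.0, 2.0, 3.0])[1]
# g is the element-wise derivative 2x, i.e. [2.0, 4.0, 6.0]
```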
<h1 id="implementing-a-neural-network-with-lux-jl">Implementing a Neural Network with Lux.jl</h1>
<p>Let’s take a look at how to use Lux through the implementation of DDIM. Since I cannot fully explain DDIM and its UNet backbone here, please refer to the Keras example and the repository implementations linked above. We will focus on the key points of using Lux.</p>
<h2 id="define-the-model">Define the model</h2>
<h3 id="using-preimplemented-layers">Using preimplemented layers</h3>
<p>Let’s look at the definition of <code>residual_block</code>, one of the components of the UNet architecture. This is a function that returns a layer based on the number of channels in the input and output images.</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span style="color:#c678dd;">function </span><span style="color:#61afef;">residual_block</span><span>(in_channels::Int, out_channels::Int)
</span><span> </span><span style="color:#c678dd;">if</span><span> in_channels </span><span style="color:#c678dd;">==</span><span> out_channels
</span><span> first_layer </span><span style="color:#c678dd;">=</span><span> NoOpLayer()
</span><span> </span><span style="color:#c678dd;">else
</span><span> first_layer </span><span style="color:#c678dd;">=</span><span> Conv((</span><span style="color:#e5c07b;">3</span><span>, </span><span style="color:#e5c07b;">3</span><span>), in_channels </span><span style="color:#c678dd;">=></span><span> out_channels; pad</span><span style="color:#c678dd;">=</span><span>SamePad())
</span><span> </span><span style="color:#c678dd;">end
</span><span>
</span><span> </span><span style="color:#c678dd;">return</span><span> Chain(first_layer,
</span><span> SkipConnection(Chain(BatchNorm(out_channels; affine</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">false</span><span>, momentum</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">0.99</span><span>),
</span><span> Conv((</span><span style="color:#e5c07b;">3</span><span>, </span><span style="color:#e5c07b;">3</span><span>), out_channels </span><span style="color:#c678dd;">=></span><span> out_channels; stride</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">1</span><span>,
</span><span> pad</span><span style="color:#c678dd;">=</span><span>(</span><span style="color:#e5c07b;">1</span><span>, </span><span style="color:#e5c07b;">1</span><span>)), swish,
</span><span> Conv((</span><span style="color:#e5c07b;">3</span><span>, </span><span style="color:#e5c07b;">3</span><span>), out_channels </span><span style="color:#c678dd;">=></span><span> out_channels; stride</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">1</span><span>,
</span><span> pad</span><span style="color:#c678dd;">=</span><span>(</span><span style="color:#e5c07b;">1</span><span>, </span><span style="color:#e5c07b;">1</span><span>))), </span><span style="color:#c678dd;">+</span><span>))
</span><span style="color:#c678dd;">end
</span></code></pre>
<p>Frequently used layers such as convolution and batch normalization are already implemented in Lux as <code>Conv</code> and <code>BatchNorm</code>. See <a href="http://lux.csail.mit.edu/stable/api/layers/">Documentation</a> for the currently implemented layers.</p>
<p>Lux uses <code>Chain</code> to connect layers in series, like <code>Sequential</code> in Keras. You can also define a skip connection of type $\boldsymbol x + f(\boldsymbol x)$ with <code>SkipConnection(some_layer, +)</code>.</p>
<p>In this case, we add a convolutional layer to align the channel counts only when the number of input channels <code>in_channels</code> differs from the number of output channels <code>out_channels</code>; such a branch can be realized by combining <code>if</code> and <code>NoOpLayer</code> with <code>Chain</code>.</p>
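<p>The same branching pattern can be sketched in isolation. The following is a hypothetical toy model (my own sketch, not from the DDIM code) combining <code>NoOpLayer</code>, <code>Chain</code>, and a $\boldsymbol x + f(\boldsymbol x)$ skip connection, with <code>Dense</code> standing in for the convolutions:</p>

```julia
using Lux, Random

# Insert a projection only when the feature counts differ,
# mirroring the branch in residual_block.
in_dims, out_dims = 4, 4
first_layer = in_dims == out_dims ? NoOpLayer() : Dense(in_dims => out_dims)
model = Chain(first_layer, SkipConnection(Dense(out_dims => out_dims), +))

rng = Random.MersenneTwister(0)
ps, st = Lux.setup(rng, model)
y, _ = model(ones(Float32, in_dims, 1), ps, st)  # output keeps shape (4, 1)
```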
<p>To define a model in the style of Keras’s functional API or PyTorch, you need to implement a custom layer as described below.</p>
<h3 id="implementing-a-custom-layer">Implementing a custom layer</h3>
<p>While a model with a simple structure can be implemented as above, the UNet used here has skip connections between its down-sampling and up-sampling blocks, which makes it difficult to implement just by combining <code>Chain</code> and <code>SkipConnection</code>. DDIM also requires extra processing, such as adding noise to the input image before calling the UNet. In such cases, it is better to define your own layers to implement the processing flexibly. The <code>DenoisingDiffusionImplicitModel</code> implemented here is also a custom layer.</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span style="color:#c678dd;">struct </span><span>DenoisingDiffusionImplicitModel{T </span><span style="color:#c678dd;"><:</span><span> AbstractFloat} </span><span style="color:#c678dd;"><:
</span><span> Lux.AbstractExplicitContainerLayer{(</span><span style="color:#e5c07b;">:unet</span><span>, </span><span style="color:#e5c07b;">:batchnorm</span><span>)}
</span><span> unet</span><span style="color:#c678dd;">::</span><span>UNet
</span><span> batchnorm</span><span style="color:#c678dd;">::</span><span>BatchNorm
</span><span> min_signal_rate</span><span style="color:#c678dd;">::</span><span>T
</span><span> max_signal_rate</span><span style="color:#c678dd;">::</span><span>T
</span><span style="color:#c678dd;">end
</span><span>
</span><span style="color:#5c6370;"># constructor
</span><span style="color:#c678dd;">function </span><span style="color:#61afef;">DenoisingDiffusionImplicitModel</span><span>(image_size::Tuple{Int, Int};
</span><span> channels=[32, 64, 96, 128], block_depth=2,
</span><span> min_freq=1.0f0, max_freq=1000.0f0,
</span><span> embedding_dims=32, min_signal_rate=0.02f0,
</span><span> max_signal_rate=0.95f0)
</span><span> unet </span><span style="color:#c678dd;">=</span><span> UNet(...)
</span><span> batchnorm </span><span style="color:#c678dd;">=</span><span> BatchNorm(...)
</span><span>
</span><span> </span><span style="color:#c678dd;">return</span><span> DenoisingDiffusionImplicitModel(unet, batchnorm, min_signal_rate,
</span><span> max_signal_rate)
</span><span style="color:#c678dd;">end
</span><span>
</span><span style="color:#5c6370;"># calling the model
</span><span style="color:#c678dd;">function </span><span>(ddim</span><span style="color:#c678dd;">::</span><span>DenoisingDiffusionImplicitModel{T})(x</span><span style="color:#c678dd;">::</span><span>Tuple{AbstractArray{T, </span><span style="color:#e5c07b;">4</span><span>},
</span><span> AbstractRNG}, ps,
</span><span> st</span><span style="color:#c678dd;">::</span><span>NamedTuple) </span><span style="color:#c678dd;">where </span><span>{
</span><span> T </span><span style="color:#c678dd;"><:
</span><span> AbstractFloat}
</span><span> images, rng </span><span style="color:#c678dd;">=</span><span> x
</span><span> </span><span style="color:#5c6370;"># generates noises and get pred_noises and pred_image from images + noises
</span><span> ...
</span><span> </span><span style="color:#c678dd;">return </span><span>(noises, images, pred_noises, pred_images), st
</span><span style="color:#c678dd;">end
</span><span>
</span></code></pre>
<p>All layers in Lux are subtypes of <code>Lux.AbstractExplicitLayer</code>. In particular, if a layer contains other layers, you can make it a subtype of <code>Lux.AbstractExplicitContainerLayer</code> and declare the layer fields, like <code>(:unet, :batchnorm)</code>, to make the model parameters easy to track. This is a little cumbersome to write. See the <a href="http://lux.csail.mit.edu/stable/manual/interface/#lux-interface">documentation</a> for more information on custom layers.</p>
<p>A Lux layer takes as input a triplet <code>(data, model parameters, model state)</code> and returns <code>(result, updated state)</code>. Initialization of <code>ps</code> and <code>st</code> is explained later.</p>
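<p>A hypothetical minimal example of this calling convention (my own sketch, not part of the DDIM code):</p>

```julia
using Lux, Random

rng = Random.MersenneTwister(0)
layer = Dense(2 => 3, tanh)

# setup returns the trainable parameters and the non-trainable state.
ps, st = Lux.setup(rng, layer)

# Calling the layer takes (data, parameters, state)
# and returns (output, updated state).
x = rand(rng, Float32, 2, 5)  # (features, batch)
y, st = layer(x, ps, st)      # size(y) == (3, 5)
```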
<h2 id="loss-functions">Loss functions</h2>
<p>As mentioned above, Zygote can differentiate any Julia function, so you can define the loss from the model output however you like. Here we define the following function, which explicitly takes the parameters <code>ps</code> and state <code>st</code> along with the model <code>ddim</code>.</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span style="color:#c678dd;">function </span><span style="color:#61afef;">compute_loss</span><span>(ddim::DenoisingDiffusionImplicitModel{T}, images::AbstractArray{T, 4},
</span><span> rng::AbstractRNG, ps, st::NamedTuple) </span><span style="color:#c678dd;">where </span><span>{T </span><span style="color:#c678dd;"><:</span><span> AbstractFloat}
</span><span> (noises, images, pred_noises, pred_images), st </span><span style="color:#c678dd;">=</span><span> ddim((images, rng), ps, st)
</span><span> noise_loss </span><span style="color:#c678dd;">=</span><span> mean(abs.(pred_noises </span><span style="color:#c678dd;">-</span><span> noises))
</span><span> image_loss </span><span style="color:#c678dd;">=</span><span> mean(abs.(pred_images </span><span style="color:#c678dd;">-</span><span> images))
</span><span> loss </span><span style="color:#c678dd;">=</span><span> noise_loss </span><span style="color:#c678dd;">+</span><span> image_loss
</span><span> </span><span style="color:#c678dd;">return</span><span> loss, st
</span><span style="color:#c678dd;">end
</span></code></pre>
<h2 id="training">Training</h2>
<h3 id="dataloader">DataLoader</h3>
<p>For data loading, it is convenient to use <code>DataLoader</code> from MLUtils.jl. Pass it a dataset and batching options as follows.</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span>ds </span><span style="color:#c678dd;">=</span><span> ImageDataset(dataset_dir, x </span><span style="color:#c678dd;">-></span><span> preprocess_image(x, image_size), </span><span style="color:#e5c07b;">true</span><span>)
</span><span>data_loader </span><span style="color:#c678dd;">=</span><span> DataLoader(ds; batchsize</span><span style="color:#c678dd;">=</span><span>batchsize, partial</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">false</span><span>, collate</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">true</span><span>,
</span><span> parallel</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">true</span><span>, rng</span><span style="color:#c678dd;">=</span><span>rng, shuffle</span><span style="color:#c678dd;">=</span><span style="color:#e5c07b;">true</span><span>)
</span></code></pre>
<p>The dataset passed to <code>DataLoader</code> can be any Julia struct that implements the following two methods:</p>
<ul>
<li><code>length(ds::ImageDataset)</code>: size of the dataset.</li>
<li><code>getindex(ds::ImageDataset, i::Int)</code>: get the <code>i</code>-th data.</li>
</ul>
<p>This allows data that does not fit in RAM to be read from disk on demand. It seems to have been designed with PyTorch’s data loader in mind.</p>
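<p>A hypothetical in-memory dataset satisfying this two-method contract might look like the following (my own toy sketch; a real <code>ImageDataset</code> would read image files from disk in <code>getindex</code> instead):</p>

```julia
using MLUtils

struct ToyDataset
    n::Int
end
# The two methods DataLoader requires:
Base.length(ds::ToyDataset) = ds.n
Base.getindex(ds::ToyDataset, i::Int) = fill(Float32(i), 4, 4, 3)  # fake "image"

loader = DataLoader(ToyDataset(8); batchsize=4, collate=true, shuffle=false)
batch = first(loader)  # samples are stacked along a new trailing batch dimension
```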
<h3 id="initialization">Initialization</h3>
<p>Initialize the model with <code>Lux.setup</code> and the optimizers with <code>Optimisers.setup</code>.</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span>rng </span><span style="color:#c678dd;">=</span><span> Random.MersenneTwister()
</span><span>
</span><span>ddim </span><span style="color:#c678dd;">=</span><span> DenoisingDiffusionImplicitModel((image_size, image_size); channels</span><span style="color:#c678dd;">=</span><span>channels,
</span><span> block_depth</span><span style="color:#c678dd;">=</span><span>block_depth, min_freq</span><span style="color:#c678dd;">=</span><span>min_freq,
</span><span> max_freq</span><span style="color:#c678dd;">=</span><span>max_freq, embedding_dims</span><span style="color:#c678dd;">=</span><span>embedding_dims,
</span><span> min_signal_rate</span><span style="color:#c678dd;">=</span><span>min_signal_rate,
</span><span> max_signal_rate</span><span style="color:#c678dd;">=</span><span>max_signal_rate)
</span><span>ps, st </span><span style="color:#c678dd;">=</span><span> Lux.setup(rng, ddim) </span><span style="color:#c678dd;">.|></span><span> gpu
</span><span>
</span><span>opt </span><span style="color:#c678dd;">=</span><span> AdamW(learning_rate, (</span><span style="color:#e5c07b;">9.0</span><span>f</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">1</span><span>, </span><span style="color:#e5c07b;">9.99</span><span>f</span><span style="color:#c678dd;">-</span><span style="color:#e5c07b;">1</span><span>), weight_decay)
</span><span>opt_st </span><span style="color:#c678dd;">=</span><span> Optimisers.setup(opt, ps) </span><span style="color:#c678dd;">|></span><span> gpu
</span></code></pre>
<p><code>|> gpu</code> transfers the parameters to the GPU. If no GPU is available, <code>gpu</code> does nothing and returns its input unchanged, so there is no need to rewrite code switching between <code>gpu</code> and <code>cpu</code> depending on the device.</p>
<h3 id="training-loop">Training Loop</h3>
<p>The training loop is outlined below.</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span style="color:#c678dd;">function </span><span style="color:#61afef;">train_step</span><span>(ddim::DenoisingDiffusionImplicitModel{T}, images::AbstractArray{T, 4},
</span><span> rng::AbstractRNG, ps, st::NamedTuple,
</span><span> opt_st::NamedTuple) </span><span style="color:#c678dd;">where </span><span>{T </span><span style="color:#c678dd;"><:</span><span> AbstractFloat}
</span><span> (loss, st), back </span><span style="color:#c678dd;">=</span><span> Zygote.pullback(p </span><span style="color:#c678dd;">-></span><span> compute_loss(ddim, images, rng, p, st), ps)
</span><span> gs </span><span style="color:#c678dd;">=</span><span> back((one(loss), </span><span style="color:#e5c07b;">nothing</span><span>))[</span><span style="color:#e5c07b;">1</span><span>]
</span><span> opt_st, ps </span><span style="color:#c678dd;">=</span><span> Optimisers.update(opt_st, ps, gs)
</span><span> </span><span style="color:#c678dd;">return</span><span> loss, ps, st, opt_st
</span><span style="color:#c678dd;">end
</span><span>...
</span><span>rng_gen </span><span style="color:#c678dd;">=</span><span> Random.MersenneTwister()
</span><span>Random.seed</span><span style="color:#c678dd;">!</span><span>(rng_gen, </span><span style="color:#e5c07b;">0</span><span>)
</span><span>...
</span><span style="color:#c678dd;">for</span><span> epoch </span><span style="color:#c678dd;">in </span><span style="color:#e5c07b;">1</span><span style="color:#c678dd;">:</span><span>epochs
</span><span> losses </span><span style="color:#c678dd;">= </span><span>[]
</span><span> iter </span><span style="color:#c678dd;">=</span><span> ProgressBar(data_loader)
</span><span>
</span><span> st </span><span style="color:#c678dd;">=</span><span> Lux.trainmode(st)
</span><span> </span><span style="color:#c678dd;">for</span><span> images </span><span style="color:#c678dd;">in</span><span> iter
</span><span> images </span><span style="color:#c678dd;">=</span><span> images </span><span style="color:#c678dd;">|></span><span> gpu
</span><span> loss, ps, st, opt_st </span><span style="color:#c678dd;">=</span><span> train_step(ddim, images, rng, ps, st, opt_st)
</span><span> push</span><span style="color:#c678dd;">!</span><span>(losses, loss)
</span><span> set_description(iter, </span><span style="color:#98c379;">"Epoch: </span><span style="color:#c678dd;">$</span><span>(epoch)</span><span style="color:#98c379;"> Loss: </span><span style="color:#c678dd;">$</span><span>(mean(losses))</span><span style="color:#98c379;">"</span><span>)
</span><span> </span><span style="color:#c678dd;">end
</span><span>
</span><span> </span><span style="color:#5c6370;"># generate and save inference at each epoch
</span><span> st </span><span style="color:#c678dd;">=</span><span> Lux.testmode(st)
</span><span> generated_images, _ </span><span style="color:#c678dd;">=</span><span> generate(ddim, Lux.replicate(rng_gen), </span><span style="color:#5c6370;"># to get inference on the same noises
</span><span> (image_size, image_size, </span><span style="color:#e5c07b;">3</span><span>, </span><span style="color:#e5c07b;">10</span><span>), val_diffusion_steps,
</span><span> ps, st)
</span><span> ...
</span><span style="color:#c678dd;">end
</span></code></pre>
<p><code>Zygote.pullback</code> performs the automatic differentiation; the key is to use a lambda so that the gradient is computed only with respect to the model parameters <code>ps</code>.</p>
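<p>The pattern can be sketched with a toy loss (my own example, not from the DDIM code): the closure fixes everything except the parameters, so the pullback returns a gradient only for them.</p>

```julia
using Zygote

x = [1.0, 2.0, 3.0]   # data, treated as a constant
ps = [0.5, 0.5, 0.5]  # "parameters" we differentiate with respect to

# The lambda closes over x, so only ps is an argument to differentiate.
loss, back = Zygote.pullback(p -> sum(abs2, p .* x), ps)
gs = back(one(loss))[1]  # gradient w.r.t. ps only
# gs == 2 .* ps .* x .^ 2 == [1.0, 4.0, 9.0]
```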
<p>This model uses batch normalization, a layer that behaves differently during training and inference. The switch is made with <code>Lux.trainmode</code>/<code>Lux.testmode</code>.</p>
<p>A note on random numbers: by its nature, this model makes heavy use of random number generation. In both training and inference we call <code>randn(rng, ...)</code>, which mutates the state of the RNG <code>rng</code>. Since we want to generate from the same noise at every inference, we pass a copy of the original RNG obtained with <code>Lux.replicate</code>.</p>
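<p>A small sketch of the effect (my own example): replicating the RNG before each draw makes the draws identical, while the original RNG is left untouched.</p>

```julia
using Lux, Random

rng = Random.MersenneTwister(42)

# Each call draws from an identical copy of rng's state,
# so the generated noise is the same both times.
a = randn(Lux.replicate(rng), Float32, 3)
b = randn(Lux.replicate(rng), Float32, 3)
# a == b
```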
<h3 id="saving-and-loading-models">Saving and Loading Models</h3>
<p>In Lux, a training snapshot is fully determined by <code>ps</code>, <code>st</code>, and <code>opt_st</code>. Since these are ordinary Julia objects, they can be serialized and saved with any suitable library. For example, using <a href="https://github.com/JuliaIO/BSON.jl">BSON.jl</a>:</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span style="color:#c678dd;">function </span><span style="color:#61afef;">save_checkpoint</span><span>(ps, st, opt_st, output_dir, epoch)
</span><span> path </span><span style="color:#c678dd;">=</span><span> joinpath(output_dir, @</span><span style="color:#e06c75;">sprintf</span><span>(</span><span style="color:#98c379;">"checkpoint_%.4d.bson"</span><span>, epoch))
</span><span> </span><span style="color:#c678dd;">return</span><span> bson(path, Dict(</span><span style="color:#e5c07b;">:ps </span><span style="color:#c678dd;">=></span><span> cpu(ps), </span><span style="color:#e5c07b;">:st </span><span style="color:#c678dd;">=></span><span> cpu(st), </span><span style="color:#e5c07b;">:opt_st </span><span style="color:#c678dd;">=></span><span> cpu(opt_st)))
</span><span style="color:#c678dd;">end
</span></code></pre>
<p>Loading is done with</p>
<pre data-lang="julia" style="background-color:#282c34;color:#dcdfe4;" class="language-julia "><code class="language-julia" data-lang="julia"><span style="color:#c678dd;">function </span><span style="color:#61afef;">load_checkpoint</span><span>(path)
</span><span> d </span><span style="color:#c678dd;">=</span><span> BSON.load(path)
</span><span> </span><span style="color:#c678dd;">return</span><span> d[</span><span style="color:#e5c07b;">:ps</span><span>], d[</span><span style="color:#e5c07b;">:st</span><span>], d[</span><span style="color:#e5c07b;">:opt_st</span><span>]
</span><span style="color:#c678dd;">end
</span></code></pre>
<p>I feel this ease of use is an advantage of handling parameters and state explicitly.</p>
<p>For the full implementation, please see the repository linked at the beginning of this article.</p>
<h1 id="summary">Summary</h1>
<p>Through this implementation, I found that even a reasonably complex neural network like a diffusion model can be implemented in Julia without much difficulty. Since Lux handles parameters explicitly, it is always clear what is going on, which gives good visibility. In this respect it is similar to Jax.</p>
<p>Many people wonder whether Lux has an advantage over Python. Personally, I think the advantages of using Julia for deep learning are:</p>
<ul>
<li>JIT compilation and the ability to write plain <code>for</code> loops, which make hand-written losses and data preprocessing fast. In particular, operations that are hard to vectorize with numpy should be faster in Julia.</li>
<li>Since code can be type-annotated, type checking can reduce bugs.</li>
</ul>
<p>I feel that Julia is a good choice.</p>
<p>On the other hand, TensorFlow and PyTorch are sufficient for models that can be written easily with the APIs the frameworks provide, and Python is by far the strongest in production support and integration with cloud services.<sup class="footnote-reference"><a href="#large">2</a></sup> In terms of speed, I don’t think Julia is superior unless CPU processing is the bottleneck, since GPU speed is probably not very different between the two. In addition, Python is a better experience for the cycle of writing code, running scripts, and debugging, because Julia has the JIT overhead commonly called <a href="https://discourse.julialang.org/t/taking-ttfx-seriously-can-we-make-common-packages-faster-to-load-and-use/74949">TTFX</a> (time to first X).</p>
<p>Personally, I am interested in combining neural networks and differential equations, such as those being worked on in <a href="https://sciml.ai/">SciML</a>, and since Julia is said to be superior in that field, I would like to dig a little deeper.</p>
<hr />
<div class="footnote-definition" id="compbug"><sup class="footnote-definition-label">1</sup>
<p>On the other hand, combining various packages can trigger bugs that the individual package developers did not anticipate, sometimes spawning critical discussion threads that become hot topics. For example, <a href="https://news.ycombinator.com/item?id=31396861">this one</a>.</p>
</div>
<div class="footnote-definition" id="large"><sup class="footnote-definition-label">2</sup>
<p>I think there are cases, especially in industry, where large-scale distributed training is needed, but I haven’t looked into how well Julia handles that.</p>
</div>