# Implementing a diffusion model in Julia/Lux.jl

As a recent hobby project, I implemented a diffusion model using Lux.jl, a neural network framework written in Julia.

For training data, I used the Oxford 102 Flowers dataset. When trained well, the model can generate images like the following.

Because my computing resources are limited, I could not generate very high resolution images, but the model can still produce images that look like flowers.

I posted the results on Twitter, shared them on the Julia Discourse (at @MathSorcerer’s suggestion), and finally contributed the code to Lux.jl’s official examples. It was merged successfully the other day.

I have always thought Julia is a good language, but in the machine learning field it has been overshadowed by the overwhelming size of the Python ecosystem. Nevertheless, a variety of packages now seem to be available that can even be used for deep learning on GPUs.

In this article, I would like to show how to implement a deep learning model with Lux.jl and its surrounding ecosystem, using a diffusion model as the subject. There seem to be few articles on building a reasonably complex network that requires writing custom layers yourself, so I hope this one is helpful as an example of implementing a practical deep learning model.

Disclaimer: My experience with Julia is limited to using it in my doctoral research and occasionally as a hobby, so please comment if you find any mistakes. Also, although the title says “diffusion model”, there is almost no discussion of diffusion models themselves below.

# Diffusion Models

A diffusion model is a deep generative model that is currently very popular; it produces high quality images by gradually removing noise from a random input. Many articles on the web explain it, so I will not go into detail here. There is also an article I wrote in Japanese.

This time, I reimplemented the following Keras example in Julia.

The model used here is Denoising Diffusion Implicit Models (DDIM), which improves on earlier diffusion models such as DDPM to make generation faster.

I also referred to a JAX implementation (daigo0927/jax-ddim).

# Deep Learning packages in Julia

## Flux and Lux

Julia has several packages for neural networks, the most famous of which is Flux.jl. It occupies a position similar to TensorFlow and PyTorch. Flux defines an NN layer as a struct that holds its trainable parameters, much as those frameworks define models as classes. It also provides data loaders and activation functions (actually re-exported from MLUtils.jl and NNlib.jl).

Lux.jl, on the other hand, is similar to Python frameworks such as JAX/Haiku. Lux provides only layer management, and other functionality has to come from external packages. Also, a layer struct in Lux holds only the information that determines the structure of the model (e.g., the number of input and output dimensions) and has no trainable parameters inside. The Lux documentation states:

• Functional Design - Pure Functions and Deterministic Function Calls.
• No more implicit parameterization.
• Compiler and AD-friendly Neural Networks.

The design of the framework is oriented toward pure functions, similar to the frameworks around JAX.

The reason for adopting such a design is well described in the document “Why use Lux over Flux?”. To summarize, it seems that explicit parameter handling is more compatible with Neural ODEs (which are popular in Julia’s machine learning community).
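
As a minimal sketch of what explicit parameter handling looks like in practice, the snippet below builds a single `Dense` layer, creates its parameters and state with `Lux.setup`, and calls the layer as a pure function. The layer sizes and the input are arbitrary values chosen for illustration.

```julia
using Lux, Random

rng = Random.MersenneTwister(0)

# The layer struct only describes the shape: 2 inputs, 3 outputs.
model = Dense(2 => 3)

# Parameters and state live outside the layer and are created by Lux.setup.
ps, st = Lux.setup(rng, model)

x = ones(Float32, 2, 5)          # 2 features, batch of 5

# A call takes (input, parameters, state) and returns (output, state).
y, st_new = model(x, ps, st)

# Calling again with the same arguments gives the same result,
# because nothing is hidden inside `model`.
y2, _ = model(x, ps, st)
```

Because `ps` is an ordinary nested structure of arrays, it can be handed directly to an automatic differentiation package or an optimizer.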

## Deep Learning Ecosystem

Julia’s ecosystem has a design philosophy called composability 1, which encourages developing small independent packages that can be freely combined. As mentioned above, Lux.jl cannot do deep learning by itself: data loaders, optimizers, and functions for image or language processing have to come from other packages. This is in contrast to TensorFlow and PyTorch.

So, here is a list of packages that are often used when implementing neural nets in Julia:

• Zygote.jl: the de facto standard for reverse-mode automatic differentiation. While TensorFlow and PyTorch define their own tensor types and differentiate through them, Zygote can differentiate code that uses Julia’s standard arrays and other ordinary types.
• NNlib.jl: implements activation functions and other common operations. As mentioned above, Zygote lets you differentiate functions you define yourself. However, the numerical stability of the same mathematical function can differ depending on how it is implemented, so prefer the versions in NNlib where they exist.
• Optimisers.jl: as the name suggests, implements optimizers.
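
The following is a small sketch of how these three packages fit together without any neural network framework. The toy model, the data, and the learning rate are made up for illustration; only `Zygote.gradient`, `NNlib.sigmoid`, `Optimisers.setup`, and `Optimisers.update` are library calls.

```julia
using Zygote, NNlib, Optimisers

# Parameters are an ordinary NamedTuple of standard Julia arrays.
ps = (w = Float32[0.1, -0.2, 0.3], b = Float32[0.0])

# Plain Julia functions; the activation comes from NNlib.
predict(p, x) = NNlib.sigmoid.(sum(p.w .* x) .+ p.b)
loss(p, x, y) = sum(abs2, predict(p, x) .- y)

x, y = Float32[1.0, 2.0, 3.0], Float32[1.0]

# Zygote differentiates the function directly; the gradient has the same fields as `ps`.
grads = Zygote.gradient(p -> loss(p, x, y), ps)[1]

# Optimisers.jl keeps the optimizer state explicit and returns updated copies.
opt_state = Optimisers.setup(Optimisers.Adam(0.1f0), ps)
opt_state, ps_new = Optimisers.update(opt_state, ps, grads)

loss_before = loss(ps, x, y)
loss_after = loss(ps_new, x, y)
```

Each package only sees plain arrays and NamedTuples, which is what lets them be combined freely.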

# Implementing a Neural Network with Lux.jl

Let’s take a look at how to use Lux through the implementation of DDIM. Since DDIM and UNet, the backbone of DDIM, cannot be fully explained here, please refer to the Keras example and the repository implementations mentioned above. Here, we will focus on the key points of using Lux.

## Define the model

### Using preimplemented layers

Let’s look at the definition of residual_block, one of the components of the UNet architecture. This is a function that returns a layer based on the number of channels in the input and output images.

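    function residual_block(in_channels::Int, out_channels::Int)
        if in_channels == out_channels
            first_layer = NoOpLayer()
        else
            # Assumed: a 1x1 convolution that matches the channel count,
            # as in the Keras example.
            first_layer = Conv((1, 1), in_channels => out_channels)
        end

        # Assumed: pad=1 and the swish activation. The padding keeps the
        # spatial size, so that x + f(x) is well defined.
        return Chain(first_layer,
                     SkipConnection(Chain(BatchNorm(out_channels; affine=false, momentum=0.99),
                                          Conv((3, 3), out_channels => out_channels; stride=1,
                                               pad=1), swish,
                                          Conv((3, 3), out_channels => out_channels; stride=1,
                                               pad=1)), +))
    end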


Frequently used layers such as convolution and batch normalization are already implemented in Lux as Conv and BatchNorm. See the documentation for the currently implemented layers.

Lux uses Chain to connect layers in series, like Sequential in Keras. You can also define a skip connection of type $\boldsymbol x + f(\boldsymbol x)$ with SkipConnection(some_layer, +).

In this case, we add a convolutional layer to align the number of channels only when the number of input channels in_channels and the number of output channels out_channels differ. Such a branch can be written with an ordinary if, using NoOpLayer as the do-nothing case inside Chain.

To define a model in the style of the Keras functional API or PyTorch, you need to implement a custom layer as described below.
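
To make the behaviour of `SkipConnection(some_layer, +)` concrete, here is a small sketch that uses a parameter-free inner layer, so the result is easy to check by hand. The doubling function is a stand-in chosen for illustration and is not part of the DDIM model.

```julia
using Lux, Random

# A parameter-free inner layer: f(x) = 2x.
double = WrappedFunction(x -> 2 .* x)

# NoOpLayer passes its input through; SkipConnection(f, +) computes f(x) + x.
block = Chain(NoOpLayer(), SkipConnection(double, +))

rng = Random.MersenneTwister(0)
ps, st = Lux.setup(rng, block)

x = Float32[1 2; 3 4]
y, _ = block(x, ps, st)          # 2x + x
```

In the residual block above, the inner layer is the BatchNorm and Conv chain instead of the doubling function, but the wiring is the same.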

### Implementing a custom layer

While a model with a simple structure can be implemented as above, the UNet used here has skip connections between its down-sampling and up-sampling blocks, which makes it difficult to build by combining only Chain and SkipConnection. DDIM also needs extra processing, such as adding noise to the input image before calling UNet. In such cases, it is better to define your own layers and implement the processing flexibly. The DenoisingDiffusionImplicitModel implemented this time is also a custom layer.

    struct DenoisingDiffusionImplicitModel{T <: AbstractFloat} <:
           Lux.AbstractExplicitContainerLayer{(:unet, :batchnorm)}
        unet::UNet
        batchnorm::BatchNorm
        min_signal_rate::T
        max_signal_rate::T
    end

    # constructor
    function DenoisingDiffusionImplicitModel(image_size::Tuple{Int, Int};
                                             channels=[32, 64, 96, 128], block_depth=2,
                                             min_freq=1.0f0, max_freq=1000.0f0,
                                             embedding_dims=32, min_signal_rate=0.02f0,
                                             max_signal_rate=0.95f0)
        unet = UNet(...)
        batchnorm = BatchNorm(...)

        return DenoisingDiffusionImplicitModel(unet, batchnorm, min_signal_rate,
                                               max_signal_rate)
    end

    # calling the model
    function (ddim::DenoisingDiffusionImplicitModel{T})(x::Tuple{AbstractArray{T, 4},
                                                                AbstractRNG}, ps,
                                                        st::NamedTuple) where {T <: AbstractFloat}
        images, rng = x
        # generate noises, then get pred_noises and pred_images from images + noises
        ...
        return (noises, images, pred_noises, pred_images), st
    end



All layers in Lux are subtypes of Lux.AbstractExplicitLayer. In particular, a layer that contains other layers can be made a subtype of Lux.AbstractExplicitContainerLayer and given the names of the fields that hold them, such as (:unet, :batchnorm), so that the parameters of the inner layers are tracked. This may be a little cumbersome to write. See the documentation for more information on custom layers.

A Lux layer takes a triplet (data, model parameters, model state) as input and returns (result, updated state). Initialization of ps and st is explained later.
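
The calling convention can be illustrated without Lux at all. The sketch below is plain Julia: `ScaleLayer`, `init_parameters`, and `init_state` are hypothetical names that only mimic the roles Lux assigns to a layer struct, its parameters, and its state. It is not how Lux itself is implemented.

```julia
using Random

# Hypothetical layer: it stores only configuration, never trainable arrays.
struct ScaleLayer
    dims::Int
end

# Hypothetical helpers that play the role of parameter and state initialization.
init_parameters(rng::AbstractRNG, l::ScaleLayer) = (scale = randn(rng, Float32, l.dims),)
init_state(::ScaleLayer) = (calls = 0,)

# The call takes (data, parameters, state) and returns (result, updated state).
function (l::ScaleLayer)(x::AbstractArray, ps, st::NamedTuple)
    y = ps.scale .* x
    return y, (calls = st.calls + 1,)
end

rng = Random.MersenneTwister(0)
layer = ScaleLayer(3)
ps, st = init_parameters(rng, layer), init_state(layer)

x = ones(Float32, 3, 2)
y, st = layer(x, ps, st)
```

The state is never modified in place. The caller receives a new state and decides whether to keep it, which is the pattern the training loop later relies on.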

## Loss functions

As mentioned above, Zygote will differentiate any Julia function you like, so you can use the output of the model to define the loss as you like. Here we defined the following function that explicitly takes the parameters ps and state st along with the model ddim.

    function compute_loss(ddim::DenoisingDiffusionImplicitModel{T}, images::AbstractArray{T, 4},
                          rng::AbstractRNG, ps, st::NamedTuple) where {T <: AbstractFloat}
        (noises, images, pred_noises, pred_images), st = ddim((images, rng), ps, st)
        noise_loss = mean(abs.(pred_noises - noises))
        image_loss = mean(abs.(pred_images - images))
        loss = noise_loss + image_loss
        return loss, st
    end
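
As a small worked example of the arithmetic in this loss, the sketch below evaluates the same mean absolute error on tiny hand-made arrays. The numbers and the helper name `mae` are made up for illustration.

```julia
using Statistics

# Mean absolute error, the same form as noise_loss and image_loss.
mae(pred, target) = mean(abs.(pred .- target))

pred_noises = Float32[0.0 1.0; 2.0 3.0]
noises      = Float32[0.5 1.0; 2.0 2.0]
pred_images = Float32[1.0 1.0; 1.0 1.0]
images      = Float32[1.0 0.0; 1.0 1.0]

noise_loss = mae(pred_noises, noises)    # (0.5 + 0 + 0 + 1) / 4
image_loss = mae(pred_images, images)    # (0 + 1 + 0 + 0) / 4
loss = noise_loss + image_loss
```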


## Training

For data loading, it is convenient to use DataLoader from MLUtils.jl. Pass a dataset and batching options as follows.

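    ds = ImageDataset(dataset_dir, x -> preprocess_image(x, image_size), true)
    # data_loader and batchsize are assumed names; collate=true stacks items into batches
    data_loader = DataLoader(ds; batchsize=batchsize, collate=true,
                             parallel=true, rng=rng, shuffle=true)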


The dataset passed to DataLoader can be any Julia struct that implements the following two methods:

• length(ds::ImageDataset): size of the dataset.
• getindex(ds::ImageDataset, i::Int): get the i-th data.

This allows for situations where data that does not fit in RAM can be read from disk. I think it was designed with PyTorch’s data loader in mind.
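
Here is a minimal sketch of such a dataset. `ToyDataset` is a hypothetical type whose i-th observation is generated on demand, standing in for an image that would be read from disk. The options `collate=true` and `partial=false` are chosen for this example.

```julia
using MLUtils

# Hypothetical dataset: observation i is a length-3 vector filled with i.
struct ToyDataset
    n::Int
end

Base.length(ds::ToyDataset) = ds.n
Base.getindex(ds::ToyDataset, i::Int) = fill(Float32(i), 3)

ds = ToyDataset(10)

# collate=true fetches observations one index at a time and stacks them into one array;
# partial=false drops the final incomplete batch.
loader = DataLoader(ds; batchsize=4, collate=true, partial=false)

batch_sizes = [size(batch) for batch in loader]
```

Only the requested observations are materialized for each batch, so the full dataset never has to sit in memory.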

### Initialization

Initialize the model with Lux.setup and the optimizers with Optimisers.setup.

    rng = Random.MersenneTwister()

    ddim = DenoisingDiffusionImplicitModel((image_size, image_size); channels=channels,
                                           block_depth=block_depth, min_freq=min_freq,
                                           max_freq=max_freq, embedding_dims=embedding_dims,
                                           min_signal_rate=min_signal_rate,
                                           max_signal_rate=max_signal_rate)
    ps, st = Lux.setup(rng, ddim) .|> gpu

    opt = AdamW(learning_rate, (9.0f-1, 9.99f-1), weight_decay)
    opt_st = Optimisers.setup(opt, ps) |> gpu


|> gpu transfers the parameters to the GPU. If no GPU is available, gpu does nothing and returns its input as is, so the same code runs on either device without rewriting.

### Training Loop

The training loop is outlined below.

    function train_step(ddim::DenoisingDiffusionImplicitModel{T}, images::AbstractArray{T, 4},
                        rng::AbstractRNG, ps, st::NamedTuple,
                        opt_st::NamedTuple) where {T <: AbstractFloat}
        (loss, st), back = Zygote.pullback(p -> compute_loss(ddim, images, rng, p, st), ps)
        gs = back((one(loss), nothing))[1]
        opt_st, ps = Optimisers.update(opt_st, ps, gs)
        return loss, ps, st, opt_st
    end
    ...
    rng_gen = Random.MersenneTwister()
    Random.seed!(rng_gen, 0)
    ...
    for epoch in 1:epochs
        losses = []

        st = Lux.trainmode(st)
        for images in iter
            images = images |> gpu
            loss, ps, st, opt_st = train_step(ddim, images, rng, ps, st, opt_st)
            push!(losses, loss)
            set_description(iter, "Epoch: $(epoch) Loss:$(mean(losses))")
        end

        # generate and save inference at each epoch
        st = Lux.testmode(st)
        generated_images, _ = generate(ddim, Lux.replicate(rng_gen), # to get inference on the same noises
                                       (image_size, image_size, 3, 10), val_diffusion_steps,
                                       ps, st)
        ...
    end


Zygote.pullback is the automatic differentiation part. The key is to wrap the loss in an anonymous function of ps, so that the gradient is computed with respect to the model parameters ps only.
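
The same pattern can be tried on a toy function. In the sketch below, `loss_and_state` is a made-up stand-in for compute_loss: it returns a scalar loss together with an auxiliary state, and the closure fixes everything except the parameters.

```julia
using Zygote

# Returns (loss, auxiliary state), like compute_loss.
function loss_and_state(p, x, st)
    loss = sum(abs2, p.w .* x)
    return loss, (calls = st.calls + 1,)
end

ps = (w = Float32[1.0, 2.0],)
x = Float32[3.0, 4.0]
st = (calls = 0,)

# The closure captures x and st, so only ps is differentiated.
(loss, st_new), back = Zygote.pullback(p -> loss_and_state(p, x, st), ps)

# Seed the loss with 1 and the auxiliary output with nothing,
# since no gradient should flow through the state.
gs = back((one(loss), nothing))[1]
```

Here the loss is (1·3)² + (2·4)² and its gradient with respect to `w` is `2 .* w .* x .^ 2`, so the result is easy to verify by hand.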

The model uses batch normalization, a layer that behaves differently in training and inference. The switch is made with Lux.trainmode/Lux.testmode.
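
The difference between the two modes can be observed on a single BatchNorm layer. This is a sketch with an arbitrary input: the feature count, batch size, and the offset of 5 are chosen only so that the batch mean is clearly non-zero.

```julia
using Lux, Random

rng = Random.MersenneTwister(0)
bn = BatchNorm(2)
ps, st = Lux.setup(rng, bn)

x = randn(rng, Float32, 2, 16) .+ 5.0f0   # 2 features, batch of 16

# Test mode: the stored running statistics are used and returned unchanged.
y_test, st_after_test = bn(x, ps, Lux.testmode(st))

# Training mode: batch statistics are used and the returned state
# carries updated running statistics.
y_train, st_after_train = bn(x, ps, Lux.trainmode(st))
```

Since the mode is part of `st`, switching modes is just a matter of passing a different state to the same layer.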

A note on random numbers. By its nature, this model makes heavy use of random number generation. Both training and inference call randn(rng, ...), and randn mutates the state of the random number generator rng. For inference in particular, we want to start from the same noise every time, so here we pass a copy of the original generator made with Lux.replicate.

In Lux, a training snapshot is fully determined by ps, st, and opt_st. Since these are ordinary Julia objects, they can be serialized and saved with an appropriate library. For example, using BSON.jl:
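
The effect of copying the generator can be checked with the standard library alone. In this sketch, `copy(rng)` stands in for Lux.replicate, which is assumed here to serve the same purpose for a random number generator.

```julia
using Random

rng = Random.MersenneTwister(0)

# Sampling from copies leaves the original generator untouched,
# so every copy starts from the same internal state.
a = randn(copy(rng), 3)
b = randn(copy(rng), 3)

# Sampling from the generator itself advances its state.
c = randn(rng, 3)
d = randn(rng, 3)
```

`a`, `b`, and `c` are identical because they all start from the same state, while `d` differs because `rng` has already advanced.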

    using BSON, Printf

    function save_checkpoint(ps, st, opt_st, output_dir, epoch)
        path = joinpath(output_dir, @sprintf("checkpoint_%.4d.bson", epoch))
        # move everything to the CPU before serializing
        return bson(path, Dict(:ps => cpu(ps), :st => cpu(st), :opt_st => cpu(opt_st)))
    end

    function load_checkpoint(path)
        d = BSON.load(path)
        return d[:ps], d[:st], d[:opt_st]
    end


I feel that this ease of use is an advantage of explicitly handling parameters and state.
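
The same round trip can be sketched with only the standard library. Note that this uses Julia’s Serialization module instead of BSON.jl, and the parameter and state values are small made-up stand-ins for ps and st.

```julia
using Serialization

# Stand-ins for model parameters and state: ordinary nested NamedTuples of arrays.
ps = (weight = Float32[1 2; 3 4], bias = Float32[0, 0])
st = (training = true,)

path = joinpath(mktempdir(), "checkpoint.jls")
serialize(path, Dict(:ps => ps, :st => st))

restored = deserialize(path)
```

No framework-specific save or load hooks are involved, because the snapshot is nothing more than plain data.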

# Summary

Through this implementation, I found that even a reasonably complex neural network such as a diffusion model can be implemented in Julia without much difficulty. Since Lux handles parameters explicitly, it is easy to see what the code is doing. This is similar to JAX.

Many people probably wonder whether Julia and Lux have an advantage over Python. Personally, I think the advantages of using Julia for deep learning are:

• JIT and the ability to write for loops easily, which makes custom-written loss and data preprocessing faster. In particular, processes that are difficult to vectorize with numpy should be faster with Julia.
• Since code can be type-annotated, type checking can reduce bugs.

For these reasons, I feel that Julia is a good choice.

On the other hand, TensorFlow and PyTorch are sufficient for models that can be written easily with the APIs those frameworks provide, and Python is by far the strongest in terms of production support and integration with cloud services. 2 In terms of speed, I don’t think GPU performance differs much between the two, so I don’t expect Julia to be faster unless CPU processing is the bottleneck. In addition, I think Python offers a better experience for the cycle of writing code, running scripts, and debugging, because Julia has the JIT compilation latency commonly called TTFX (time to first X).

Personally, I am interested in combining neural networks and differential equations, such as those being worked on in SciML, and since Julia is said to be superior in that field, I would like to dig a little deeper.

1

On the other hand, combining various packages can lead to bugs that the respective package developers did not anticipate. Some have been critical enough to spark heated discussion threads. For example, this.

2

I think there are times, especially in industry, when you need large scale distributed training, but I haven’t looked into whether Julia can handle that.