This article is from JuliaLang Advent Calendar 2022 (Japanese) Day 14.
As a hobby, I have recently been implementing a diffusion model using Lux.jl, a neural net framework written in Julia.
For training data, I used the Oxford 102 flowers dataset; when the model is trained well, it can generate images like the following.
Due to the limited computing resources available, I could not generate very high-resolution images, but the results still look convincingly like flowers.
I posted the results on Twitter, shared them on the Julia Discourse (at @MathSorcerer's suggestion), and finally contributed the code as an official Lux.jl example, which was merged the other day.
I have always thought Julia is a good language, but in the machine learning field it has been somewhat invisible next to the overwhelming amount of tooling in the Python ecosystem. Nevertheless, there now seems to be a good variety of packages available, even for deep learning on GPUs.
In this article, I would like to introduce how to implement a deep learning model with Lux.jl and its surrounding ecosystem, using the diffusion model as the subject. There seem to be few articles about building a reasonably complex network that requires writing custom layers, so I hope this serves as a helpful example of implementing a practical deep learning model.
Disclaimer: My experience with Julia is limited to using it in my doctoral research and occasionally as a hobby, so please comment if there are any mistakes in the content. Also, please forgive that, although "diffusion model" appears in the title, there is almost no discussion of diffusion models in what follows.
The diffusion model is a deep generative model that is currently very popular; it obtains high-quality images by gradually removing noise from random input. You can find many articles on the web if you search for it, so I will not explain it in detail here. There is also an article I wrote in Japanese.
This time, I reimplemented the Keras example in Julia.
The model used here is Denoising Diffusion Implicit Models (DDIM), an improved version of DDPM and other earlier diffusion models that makes generation faster.
There is also a Jax implementation (daigo0927/jax-ddim) that I refer to here.
Julia has several packages for neural networks, the most famous of which is Flux.jl. In terms of positioning it is similar to TensorFlow and PyTorch: Flux defines an NN layer as a struct holding trainable parameters, much as those frameworks define models as classes. It also includes a data loader and activation functions (actually re-exported from MLUtils.jl and NNlib.jl, respectively).
On the other hand, Lux.jl is similar to Python frameworks such as Jax/Haiku. Lux provides only layer management; other functionality needs to be brought in from external packages. Also, a layer struct in Lux only holds the information needed to determine the structure of the model (e.g., the number of input and output dimensions) and has no trainable parameters inside. The Lux documentation states:
- Functional Design - Pure Functions and Deterministic Function Calls.
- No more implicit parameterization.
- Compiler and AD-friendly Neural Networks.
The design of the framework is oriented toward pure functions, similar to the frameworks built around Jax.
The reason for adopting such a design is well described in the document Why use Lux over Flux?. To summarize, it seems that explicit parameter handling is more compatible with Neural ODEs (which are popular in Julia’s machine learning community).
Julia's ecosystem has a design philosophy called composability 1, which encourages users to develop small independent packages that can be freely combined. As mentioned above, Lux.jl by itself cannot perform deep learning end to end; a data loader, an optimizer, and other functionality for image or language processing have to be pulled in from other packages. This is in contrast to TensorFlow and PyTorch.
So, here is a list of packages that are often used when implementing neural nets in Julia; the ones that appear in this article are:
- Zygote.jl: automatic differentiation
- Optimisers.jl: optimizers such as Adam/AdamW
- MLUtils.jl: data loaders and other ML utilities
- NNlib.jl: activation functions and other low-level NN primitives (re-exported by Lux)
- BSON.jl: serialization, used here for checkpoints
Let's take a look at how to use Lux through the implementation of DDIM. Since I cannot fully explain the details of DDIM and UNet, its backbone, please refer to the Keras example and the repository implementations linked above. Here, we will focus on the key points of using Lux.
Let's look at the definition of residual_block, one of the components of the UNet architecture. This is a function that returns a layer based on the number of channels in the input and output images.
function residual_block(in_channels::Int, out_channels::Int)
    if in_channels == out_channels
        first_layer = NoOpLayer()
    else
        first_layer = Conv((3, 3), in_channels => out_channels; pad=SamePad())
    end
    return Chain(first_layer,
                 SkipConnection(Chain(BatchNorm(out_channels; affine=false, momentum=0.99),
                                      Conv((3, 3), out_channels => out_channels; stride=1, pad=(1, 1)),
                                      swish,
                                      Conv((3, 3), out_channels => out_channels; stride=1, pad=(1, 1))),
                                +))
end
Frequently used layers such as convolution and batch normalization are already implemented in Lux as Conv and BatchNorm. See the documentation for the currently implemented layers.

Lux uses Chain to connect layers in series, like Sequential in Keras. You can also define a skip connection of the form $\boldsymbol x + f(\boldsymbol x)$ with SkipConnection(some_layer, +).
In this case, we add a convolutional layer to align the number of channels only when the number of input channels in_channels and the number of output channels out_channels differ; such a branch can be realized with an if statement and NoOpLayer in combination with Chain.
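As a usage sketch (not from the original code, and jumping ahead slightly to Lux.setup, which is covered below), the block returned here behaves like any other Lux layer:

using Lux, Random

rng = Random.MersenneTwister(0)
block = residual_block(3, 32)              # 3 input channels, 32 output channels
ps, st = Lux.setup(rng, block)             # initialize parameters and state

x = randn(rng, Float32, 64, 64, 3, 4)      # WHCN layout: 64x64 RGB images, batch of 4
y, st = block(x, ps, st)                   # size(y) == (64, 64, 32, 4)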
To define a model in the style of Keras's functional API or a PyTorch-like model, you need to implement a custom layer as described below.
While a model with a simple structure can be implemented as above, the UNet used here has skip connections between down-sampling blocks and up-sampling blocks, which makes it difficult to implement just by combining Chain and SkipConnection. Also, DDIM requires processing such as adding noise to the input image before calling the UNet. In such cases, it is better to define your own layers and implement the processing flexibly. The DenoisingDiffusionImplicitModel implemented this time is also a custom layer.
struct DenoisingDiffusionImplicitModel{T <: AbstractFloat} <:
       Lux.AbstractExplicitContainerLayer{(:unet, :batchnorm)}
    unet::UNet
    batchnorm::BatchNorm
    min_signal_rate::T
    max_signal_rate::T
end

# constructor
function DenoisingDiffusionImplicitModel(image_size::Tuple{Int, Int};
                                         channels=[32, 64, 96, 128], block_depth=2,
                                         min_freq=1.0f0, max_freq=1000.0f0,
                                         embedding_dims=32, min_signal_rate=0.02f0,
                                         max_signal_rate=0.95f0)
    unet = UNet(...)
    batchnorm = BatchNorm(...)
    return DenoisingDiffusionImplicitModel(unet, batchnorm, min_signal_rate, max_signal_rate)
end

# calling the model
function (ddim::DenoisingDiffusionImplicitModel{T})(x::Tuple{AbstractArray{T, 4}, AbstractRNG},
                                                    ps, st::NamedTuple) where {T <: AbstractFloat}
    images, rng = x
    # generate noises and get pred_noises and pred_images from images + noises
    ...
    return (noises, images, pred_noises, pred_images), st
end
All layers in Lux are subtypes of Lux.AbstractExplicitLayer. In particular, if a layer contains other layers inside, it can be made a subtype of Lux.AbstractExplicitContainerLayer, with the field names of the contained layers given as in (:unet, :batchnorm), to facilitate tracking of model parameters. This may be a little cumbersome to write. See the documentation for more information on custom layers.

Lux layers take as input a triplet (data, model parameters, model state) and return (result, updated state). Initialization of ps and st is explained later.
As mentioned above, Zygote will differentiate any Julia function you like, so you can use the output of the model to define the loss as you like. Here we defined the following function, which explicitly takes the parameters ps and state st along with the model ddim.
function compute_loss(ddim::DenoisingDiffusionImplicitModel{T}, images::AbstractArray{T, 4},
                      rng::AbstractRNG, ps, st::NamedTuple) where {T <: AbstractFloat}
    (noises, images, pred_noises, pred_images), st = ddim((images, rng), ps, st)
    noise_loss = mean(abs.(pred_noises - noises))
    image_loss = mean(abs.(pred_images - images))
    loss = noise_loss + image_loss
    return loss, st
end
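As a quick illustration (my own sketch, not in the original code), Zygote can differentiate this loss with respect to ps directly; the training loop below uses Zygote.pullback instead so that the updated state st is also recovered.

using Zygote

# first(...) keeps only the scalar loss; the updated state returned by compute_loss is discarded here.
result = Zygote.withgradient(p -> first(compute_loss(ddim, images, rng, p, st)), ps)
loss = result.val      # scalar loss value
gs   = result.grad[1]  # gradients with the same nested structure as ps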
It is convenient to use DataLoader in MLUtils.jl as the data loader. Pass a dataset and batching options as follows.
ds = ImageDataset(dataset_dir, x -> preprocess_image(x, image_size), true)
data_loader = DataLoader(ds; batchsize=batchsize, partial=false, collate=true,
                         parallel=true, rng=rng, shuffle=true)
The dataset passed to DataLoader can be any Julia structure that implements the following two methods:
- length(ds::ImageDataset): the size of the dataset.
- getindex(ds::ImageDataset, i::Int): get the i-th sample.

This allows data that does not fit in RAM to be read from disk on demand. I think it was designed with PyTorch's data loader in mind.
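The ImageDataset used above is defined in the repository. As a minimal sketch (hypothetical names, assuming Images.jl for image loading), a dataset type satisfying this interface could look like the following:

using Images

# Hypothetical minimal dataset: stores file paths and a preprocessing function,
# and loads images from disk lazily in getindex.
struct FlowerDataset
    paths::Vector{String}
    preprocess::Function
end

function FlowerDataset(dir::String, preprocess)
    paths = [joinpath(dir, f) for f in readdir(dir) if endswith(f, ".jpg")]
    return FlowerDataset(paths, preprocess)
end

Base.length(ds::FlowerDataset) = length(ds.paths)
Base.getindex(ds::FlowerDataset, i::Int) = ds.preprocess(Images.load(ds.paths[i]))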
Initialize the model parameters and state with Lux.setup and the optimizer state with Optimisers.setup.
rng = Random.MersenneTwister()
ddim = DenoisingDiffusionImplicitModel((image_size, image_size); channels=channels,
                                       block_depth=block_depth, min_freq=min_freq,
                                       max_freq=max_freq, embedding_dims=embedding_dims,
                                       min_signal_rate=min_signal_rate,
                                       max_signal_rate=max_signal_rate)
ps, st = Lux.setup(rng, ddim) .|> gpu
opt = AdamW(learning_rate, (9.0f-1, 9.99f-1), weight_decay)
opt_st = Optimisers.setup(opt, ps) |> gpu
|> gpu transfers the parameters to the GPU. If no GPU is available, gpu does nothing and returns its input as is, so the same code runs on either device without swapping gpu and cpu calls by hand.
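For instance (a trivial sketch using the same gpu/cpu helpers as above):

x = rand(Float32, 32, 32, 3, 4)   # a dummy WHCN batch
x_dev = x |> gpu                  # CuArray when a GPU is available, otherwise the same Array
x_back = x_dev |> cpu             # always back to a plain Array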
The training loop is outlined below.
function train_step(ddim::DenoisingDiffusionImplicitModel{T}, images::AbstractArray{T, 4},
                    rng::AbstractRNG, ps, st::NamedTuple,
                    opt_st::NamedTuple) where {T <: AbstractFloat}
    (loss, st), back = Zygote.pullback(p -> compute_loss(ddim, images, rng, p, st), ps)
    gs = back((one(loss), nothing))[1]
    opt_st, ps = Optimisers.update(opt_st, ps, gs)
    return loss, ps, st, opt_st
end
...
rng_gen = Random.MersenneTwister()
Random.seed!(rng_gen, 0)
...
for epoch in 1:epochs
    losses = []
    iter = ProgressBar(data_loader)

    st = Lux.trainmode(st)
    for images in iter
        images = images |> gpu
        loss, ps, st, opt_st = train_step(ddim, images, rng, ps, st, opt_st)
        push!(losses, loss)
        set_description(iter, "Epoch: $(epoch) Loss: $(mean(losses))")
    end

    # generate and save inference results at each epoch
    st = Lux.testmode(st)
    generated_images, _ = generate(ddim, Lux.replicate(rng_gen), # to get inference on the same noises
                                   (image_size, image_size, 3, 10), val_diffusion_steps,
                                   ps, st)
    ...
end
Zygote.pullback is the automatic-differentiation part, and the key is to use an anonymous function so that the derivative is taken only with respect to the model parameters ps.

This model uses batch normalization, which behaves differently during training and inference. The switching is done with Lux.trainmode/Lux.testmode.
A note on random numbers. By its nature, this model makes heavy use of random number generation: both training and inference call randn(rng, ...), and randn mutates the state of the generator rng. For inference, however, we want to evaluate on the same noise every time, so we pass a copy of the original generator obtained with Lux.replicate.
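As a small check of what replicate does (my own sketch): it returns a copy of the generator, so repeated inference draws the same noise while the original generator stays untouched.

rng_gen = Random.MersenneTwister(0)
x1 = randn(Lux.replicate(rng_gen), Float32, 3)
x2 = randn(Lux.replicate(rng_gen), Float32, 3)
x1 == x2    # true: each call works on a fresh copy, so rng_gen itself is never advanced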
In Lux, a training snapshot is fully determined by ps, st, and opt_st. Since these are ordinary Julia objects, they can be serialized and saved with an appropriate library. For example, using BSON.jl:
function save_checkpoint(ps, st, opt_st, output_dir, epoch)
    path = joinpath(output_dir, @sprintf("checkpoint_%.4d.bson", epoch))
    return bson(path, Dict(:ps => cpu(ps), :st => cpu(st), :opt_st => cpu(opt_st)))
end
Loading is done with
function load_checkpoint(path)
    d = BSON.load(path)
    return d[:ps], d[:st], d[:opt_st]
end
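A possible round trip (a sketch assuming the helpers above and the output_dir/epoch variables from the training loop):

using Printf

save_checkpoint(ps, st, opt_st, output_dir, epoch)                  # stored as CPU arrays
path = joinpath(output_dir, @sprintf("checkpoint_%.4d.bson", epoch))
ps, st, opt_st = load_checkpoint(path)
ps, st, opt_st = ps |> gpu, st |> gpu, opt_st |> gpu                 # move back to the device before resuming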
I feel that this ease of use is an advantage of explicitly handling parameters and state.
For the full implementation, please see the repository linked at the beginning of this article.
Through this implementation, I found that even a reasonably complex neural network model like a diffusion model can be implemented in Julia without much difficulty. Since Lux handles parameters explicitly, it is easy to see what is going on, which gives good visibility into the model. This is similar to Jax.
Many people wonder whether Lux has an advantage over Python. Personally, I think the advantages of using Julia for deep learning include:
- You can write for loops freely, which makes custom-written losses and data preprocessing fast. In particular, processing that is difficult to vectorize with numpy should be faster in Julia.

Overall, I feel that Julia is a good choice.
On the other hand, TensorFlow and PyTorch are sufficient for models that can easily be written with the APIs those frameworks provide, and Python is by far the strongest in terms of production support and integration with cloud services. 2 In terms of speed, I don't think Julia is superior unless CPU processing is the bottleneck, since I don't think GPU speed differs much between the two. In addition, I think Python offers a better experience for writing code, running scripts, and debugging, because Julia has the JIT compilation overhead commonly called TTFX (time to first X).
Personally, I am interested in combining neural networks and differential equations, such as those being worked on in SciML, and since Julia is said to be superior in that field, I would like to dig a little deeper.
On the other hand, combining various packages can lead to bugs that the individual package developers did not anticipate, sometimes resulting in critical discussion threads that become a hot topic. For example, this.
I also think there are cases, especially in industry, that call for large-scale distributed training, but I have not looked into how well Julia handles that.