
# Neural Network Emission Model Example

This is a pedagogical example that serves as a proof of concept. We will build a simple general relativistic Neural Radiance Field (NeRF) emission model, ray trace it with Krang, and optimize it with the ADAM algorithm.

## Setup

We will first import the necessary packages and set the random seed. Our emission model will be a neural network built with Lux.jl.

```julia
using Lux
using Krang
using Random
Random.seed!(123)
rng = Random.GLOBAL_RNG
```

Our model will take in spacetime coordinates and return the observed intensity value for a given pixel. In this example we will use a Kerr metric with spin 0.99 and an observer sitting at 20 degrees inclination with respect to the spin axis. These parameters are fixed here, but they could be made to vary in the optimization process.
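These fixed quantities will reappear inside the model's forward pass below; as a minimal standalone sketch:

```julia
metric = Krang.Kerr(0.99e0)   # Kerr metric with spin a = 0.99
θo = Float64(20 / 180 * π)    # observer inclination of 20 degrees, in radians
```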

Let's define an `ImageModel` comprised of an emission layer that we will ray trace. We will do this by first creating a struct to represent our image model, which stores our emission model as a layer.

```julia
struct ImageModel{T <: Chain}
    emission_layer::T
end
```

Models in Lux are functors that take in features, parameters, and model state, and return the output along with the updated model state.
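For example, here is that calling convention with a standalone toy layer (`layer`, `ps_toy`, and `st_toy` are illustrative names, not part of our model):

```julia
layer = Dense(2 => 3)                    # a toy layer, unrelated to the example
ps_toy, st_toy = Lux.setup(rng, layer)   # initialize parameters and state
y, st_toy = layer(rand(rng, Float64, 2, 5), ps_toy, st_toy)  # output and updated state
```

Let's now define the function associated with our `ImageModel` type. We will assume that the emission originates in the equatorial plane.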

```julia
function (m::ImageModel)(x, ps, st)
    metric = Krang.Kerr(0.99e0)
    θo = Float64(20 / 180 * π)
    pixels = Krang.IntensityPixel.(Ref(metric), x[1, :], x[2, :], θo)

    sze = unsafe_trunc(Int, sqrt(size(x)[2]))
    coords = zeros(Float64, 2, sze * sze)
    emission_vals = zeros(Float64, 1, sze * sze)
    for n in 0:1 # sum the direct (n = 0) and first lensed (n = 1) sub-images
        for i in 1:sze
            for j in 1:sze
                pix = pixels[i+(j-1)*sze]
                α, β = Krang.screen_coordinate(pix)
                T = typeof(α)
                # Equatorial (θs = π/2) emission coordinates for this pixel and sub-image
                rs, ϕs = Krang.emission_coordinates_fast_light(pix, Float64(π / 2), β > 0, n)[1:3]
                xs = rs * cos(ϕs)
                ys = rs * sin(ϕs)
                # Zero out rays that fall inside the event horizon
                if hypot(xs, ys) ≤ Krang.horizon(metric)
                    coords[1, i+(j-1)*sze] = zero(T)
                    coords[2, i+(j-1)*sze] = zero(T)
                else
                    coords[1, i+(j-1)*sze] = xs
                    coords[2, i+(j-1)*sze] = ys
                end
            end
        end
        emission_vals .+= m.emission_layer(coords, ps, st)[1]
    end
    emission_vals, st
end
```

Let's define an emission layer for our model as a simple fully connected neural network with 2 hidden layers. The emission layer will take in 2D coordinates on an equatorial disk in the bulk spacetime and return a scalar intensity value.

```julia
emission_model = Chain(
    Dense(2 => 20, Lux.sigmoid),
    Dense(20 => 20, Lux.sigmoid),
    Dense(20 => 1, Lux.sigmoid),
)

ps, st = Lux.setup(rng, emission_model); # Get the emission model parameters and state
```

We can now create an image model with our emission layer.

```julia
image_model = ImageModel(emission_model)
```

## Plotting the model

Let's create a 20×20 pixel image of the `image_model` with a field of view of 10 MG/c².

```julia
sze = 20
ρmax = 10e0
pixels = zeros(Float64, 2, sze*sze)
for (iiter, i) in enumerate(range(-ρmax, ρmax, sze))
    for (jiter, j) in enumerate(range(-ρmax, ρmax, sze))
        pixels[1, iiter+(jiter-1)*sze] = Float64(i)
        pixels[2, iiter+(jiter-1)*sze] = Float64(j)
    end
end
```
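The same flattened grid can be built without explicit loops; a minimal equivalent sketch (`grid` and `pixels_alt` are illustrative names):

```julia
grid = range(-ρmax, ρmax, sze)
# Columns are (α, β) screen coordinates, flattened column-major so that the
# first index varies fastest, matching the iiter + (jiter - 1) * sze indexing above.
pixels_alt = reduce(hcat, [Float64[i, j] for i in grid, j in grid][:])
pixels_alt == pixels  # true
```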

We can see the effects of ray tracing on the emission in the bulk spacetime by plotting images of both the emission model and the image model.

```julia
using CairoMakie
curr_theme = Theme(
    Axis = (xticksvisible=false, xticklabelsvisible=false, yticksvisible=false, yticklabelsvisible=false,),
    Heatmap = (colormap=:afmhot,),
)
set_theme!(merge(curr_theme, theme_latexfonts()))

emitted_intensity = reshape(emission_model(pixels, ps, st)[1], sze, sze)
received_intensity = reshape(image_model(pixels, ps, st)[1], sze, sze)

fig = Figure();
heatmap!(Axis(fig[1,1], aspect=1, title="Emission Model"), emitted_intensity)
heatmap!(Axis(fig[1,2], aspect=1, title="Image Model (Lensed Emission Model)"), received_intensity)
save("emission_model_and_target_model.png", fig)
```

## Fitting the NeRF model

This will be a toy example showing the mechanics of fitting our `ImageModel` to a target image. Our loss function will be the mean squared error between the sum-normalized model and target images. The image we just ray traced will serve as the target.

```julia
target_img = reshape(received_intensity, 1, sze*sze);
```

Let's define this loss function and fit our model.
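Concretely, for flattened images $I$ and $J$ with $N$ pixels each, the loss implemented below is

$$\mathrm{mse}(I, J) = \frac{1}{N} \sum_{k=1}^{N} \left( \frac{I_k}{\sum_m I_m} - \frac{J_k}{\sum_m J_m} \right)^2 .$$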

```julia
using Enzyme
using Optimization
using OptimizationOptimisers
using StatsBase
using ComponentArrays
Enzyme.Compiler.RunAttributor[] = false

# Mean squared error between the sum-normalized images
function mse(img1::Matrix{T}, img2::Matrix{T}) where {T}
    mean(((img1 ./ sum(img1)) .- (img2 ./ sum(img2))) .^ 2)
end

function loss_function(pixels, y, ps, st)
    y_pred, st = image_model(pixels, ps, st)
    mse(y, y_pred), st
end

mse(target_img, target_img)

ps, st = Lux.setup(rng, emission_model);
image_model = ImageModel(emission_model);

emitted_intensity = reshape(emission_model(pixels, ps, st)[1], sze, sze)
received_intensity = reshape(image_model(pixels, ps, st)[1], sze, sze)
loss_function(pixels, target_img, ps, st)

fig = Figure();
heatmap!(Axis(fig[1,1], aspect=1, title="Emission Model"), emitted_intensity, colormap=:afmhot)
heatmap!(Axis(fig[1,2], aspect=1, title="Image Model (Lensed Emission Model)"), received_intensity, colormap=:afmhot)
save("emission_model_and_image_model.png", fig)
```

Let's define a callback function to print the loss as we optimize our model.

```julia
mutable struct Callback
    counter::Int
    stride::Int
    const f::Function
end
Callback(stride, f) = Callback(0, stride, f)
function (c::Callback)(state, loss, others...)
    c.counter += 1
    if c.counter % c.stride == 0
        @info "On step $(c.counter) loss = $(loss)"
    end
    return false # returning true would halt the optimization early
end
```
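A quick standalone check of the callback's behavior (hypothetical, not part of the fit):

```julia
cb = Callback(100, () -> nothing)
for step in 1:200
    cb(nothing, 1.0 / step) # logs the loss at steps 100 and 200, always returns false
end
```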

We can now optimize our model using the ADAM optimizer.

```julia
ps_trained, st_trained = let st = Ref(st), x = pixels, y = reshape(target_img, 1, sze*sze)

    optprob = Optimization.OptimizationProblem(
        Optimization.OptimizationFunction(
            function (ps, constants)
                loss, st[] = loss_function(x, y, ps, st[])
                loss
            end,
            Optimization.AutoEnzyme()
        ),
        ComponentArrays.ComponentVector{Float64}(ps)
    )

    solution = Optimization.solve(
        optprob,
        OptimizationOptimisers.Adam(),
        maxiters = 5_000,
        callback = Callback(100, () -> nothing)
    )

    solution.u, st[]
end
```

Let's plot the results of our optimization and compare them to the target image.

```julia
received_intensity, st = ((x) -> (reshape(x[1], sze, sze), x[2]))(image_model(pixels, ps_trained, st_trained))
acc_intensity, st = ((x) -> (reshape(x[1], sze, sze), x[2]))(image_model(pixels, ps, st))
loss_function(pixels, target_img, ps, st)
loss_function(pixels, target_img, ps_trained, st_trained)
using Printf
begin
    fig = Figure(size=(700, 300));
    heatmap!(Axis(fig[1,1], aspect=1, title="Target Image"), reshape(target_img, sze, sze))
    heatmap!(Axis(fig[1,2], aspect=1, title="Starting State (loss=$(@sprintf("%0.2e", loss_function(pixels, target_img, ps, st)[1])))"), acc_intensity)
    heatmap!(Axis(fig[1,3], aspect=1, title="Fitted State (loss=$(@sprintf("%0.2e", loss_function(pixels, target_img, ps_trained, st_trained)[1])))"), received_intensity)
    save("neural_net_results.png", fig)
end
```


This page was generated using Literate.jl.