Neural Network Emission Model Example
This is a pedagogical example that serves as a proof of concept. We will build a simple General Relativistic Neural Radiance Field (NeRF) model that is ray traced and then optimized with the Adam algorithm.
## Setup
We will first import the necessary packages and set the random seed. Our emission model will be a neural network built with Lux.jl.
using Lux
using Krang
using Random
Random.seed!(123)
rng = Random.GLOBAL_RNG
Our model will take in screen coordinates (α, β) and return the observed intensity at each pixel. We will use a Kerr metric with spin a = 0.99 and an observer at an inclination of 20° with respect to the spin axis. These parameters are fixed in this example, but they could be allowed to vary during optimization.
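For orientation, the sketch below computes two quantities implied by these choices: the outer event-horizon radius of a Kerr black hole, r₊ = M + √(M² − a²) (here in units G = c = M = 1), and the observer inclination in radians. The image model below compares traced radii against `Krang.horizon(metric)`; we assume that function returns this same outer-horizon radius. The snippet uses only Base Julia and the helper name `kerr_outer_horizon` is hypothetical.

```julia
# Hypothetical helper: outer event-horizon radius of a Kerr black hole in
# units where G = c = M = 1, i.e. r₊ = 1 + sqrt(1 - a²).
kerr_outer_horizon(a) = 1 + sqrt(1 - a^2)

a = 0.99
r_plus = kerr_outer_horizon(a)   # ≈ 1.141, close to the extremal limit r₊ = 1
θo = 20 / 180 * π                # observer inclination of 20°, in radians (≈ 0.349)
```

Near-extremal spin pulls the horizon in from r₊ = 2 (Schwarzschild) to just above r₊ = 1, so more of the equatorial disk close to the hole can contribute emission.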
Let's define an `ImageModel`, which will be composed of an emission layer that we will ray trace. We will do this by first creating a struct that represents our image model and stores our emission model as a layer.
struct ImageModel{T <: Chain}
emission_layer::T
end
Models in Lux are functors that take in features, parameters and model state, and return the output and the model state. Let's define the call method for our `ImageModel` type. We will assume that the emission originates in the equatorial plane.
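function (m::ImageModel)(x, ps, st)
    metric = Krang.Kerr(0.99e0)       # Kerr spacetime with spin a = 0.99
    θo = Float64(20/180*π)            # observer inclination: 20° in radians
    # One pixel per screen coordinate pair (α, β) stored in the columns of x
    pixels = Krang.IntensityPixel.(Ref(metric), x[1,:], x[2,:], θo)
    sze = unsafe_trunc(Int, sqrt(size(x)[2]))   # assumes a square sze × sze image
    coords = zeros(Float64, 2, sze*sze)         # equatorial (x, y) reached by each pixel's ray
    emission_vals = zeros(Float64, 1, sze*sze)
    for n in 0:1   # accumulate the n = 0 and n = 1 sub-images
        for i in 1:sze
            for j in 1:sze
                pix = pixels[i+(j-1)*sze]
                α, β = Krang.screen_coordinate(pix)
                T = typeof(α)
                # Trace the ray back to the equatorial plane θs = π/2
                rs, ϕs = Krang.emission_coordinates_fast_light(pix, Float64(π/2), β > 0, n)[1:3]
                xs = rs * cos(ϕs)
                ys = rs * sin(ϕs)
                # Points at or inside the horizon are mapped to the disk origin
                if hypot(xs, ys) ≤ Krang.horizon(metric)
                    coords[1,i+(j-1)*sze] = zero(T)
                    coords[2,i+(j-1)*sze] = zero(T)
                else
                    coords[1,i+(j-1)*sze] = xs
                    coords[2,i+(j-1)*sze] = ys
                end
            end
        end
        # Evaluate the emission layer at the traced coordinates and add its output
        emission_vals .+= m.emission_layer(coords, ps, st)[1]
    end
    emission_vals, st
end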
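The calling convention used above, `model(x, ps, st) -> (y, st)`, can be illustrated without Lux. In the sketch below, `ToyAffine` and `ToyImageModel` are hypothetical stand-ins (not part of Lux or Krang): the wrapper plays the role of `ImageModel`, evaluating its inner layer once per coordinate set (one per sub-image) and summing the outputs, while parameters and state are passed in explicitly rather than stored in the struct.

```julia
# Hypothetical toy layer following the Lux calling convention
# layer(x, ps, st) -> (y, st); it is not part of Lux.
struct ToyAffine end
(l::ToyAffine)(x, ps, st) = (ps.w .* x .+ ps.b, st)

# Hypothetical wrapper mirroring ImageModel: evaluate the inner layer on one
# coordinate set per sub-image and sum the outputs.
struct ToyImageModel{L}
    layer::L
end

function (m::ToyImageModel)(coord_sets, ps, st)
    total = zero(first(coord_sets))
    for coords in coord_sets
        y, st = m.layer(coords, ps, st)   # thread the state through each call
        total = total .+ y
    end
    return total, st
end

ps = (w = 2.0, b = 1.0)   # parameters live outside the model, as in Lux
st = NamedTuple()         # stateless layer: empty state
model = ToyImageModel(ToyAffine())
y, st_out = model(([1.0 2.0], [3.0 4.0]), ps, st)   # y == [10.0 14.0]
```

Keeping parameters outside the model is what lets an optimizer later hand the model a flattened parameter vector without rebuilding the struct.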
Let's define an emission layer for our model as a simple fully connected neural network with two hidden layers of 20 neurons each and sigmoid activations. The emission layer takes 2D coordinates on an equatorial disk in the bulk spacetime and returns a scalar intensity value.
emission_model = Chain(
Dense(2 => 20, Lux.sigmoid),
Dense(20 => 20, Lux.sigmoid),
Dense(20 => 1, Lux.sigmoid)
)
ps, st = Lux.setup(rng, emission_model); # Get the emission model parameters and state
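As a rough picture of what this `Chain` computes, the sketch below writes the same 2 → 20 → 20 → 1 sigmoid network in plain Julia, with each dense layer computing σ.(W x .+ b). The random weights are for illustration only and do not reproduce Lux's initializers; `dense` and `mlp` are hypothetical helper names. It also counts the trainable parameters.

```julia
using Random

# Plain-Julia sketch of Chain(Dense(2 => 20, σ), Dense(20 => 20, σ), Dense(20 => 1, σ)).
sigmoid(z) = 1 / (1 + exp(-z))
dense(W, b, x) = sigmoid.(W * x .+ b)   # one dense layer: σ.(W x .+ b)

layer_sizes = [(2, 20), (20, 20), (20, 1)]   # (inputs, outputs) per layer
rng = MersenneTwister(123)
params = [(randn(rng, nout, nin), randn(rng, nout)) for (nin, nout) in layer_sizes]

function mlp(params, x)
    for (W, b) in params
        x = dense(W, b, x)
    end
    return x
end

coords = randn(rng, 2, 5)         # five (x, y) points on the equatorial disk, one per column
intensity = mlp(params, coords)   # 1×5 matrix; the final sigmoid keeps entries in (0, 1)

# Weights plus biases per layer: 2*20+20, 20*20+20, 20*1+1
n_params = sum(nin * nout + nout for (nin, nout) in layer_sizes)   # 501
```

The final sigmoid bounds the emitted intensity, so the network cannot produce negative brightness.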
We can now create an image model with our emission layer.
image_model = ImageModel(emission_model)
## Plotting the model
Let's create a 20×20 pixel image of the `image_model`, with a field of view spanning [-ρmax, ρmax] in both screen coordinates, where ρmax = 10.
sze = 20
ρmax = 10e0
pixels = zeros(Float64, 2, sze*sze)
for (iiter, i) in enumerate(range(-ρmax, ρmax, sze))
for (jiter, j) in enumerate(range(-ρmax, ρmax, sze))
pixels[1,iiter+(jiter-1)*sze] = Float64(i)
pixels[2,iiter+(jiter-1)*sze] = Float64(j)
end
end
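The flat index `i + (j-1)*sze` used here (and inside the image model) is Julia's column-major linear index, so column `k` of `pixels` holds (α_i, β_j) with α varying fastest. The sketch below rebuilds the grid with the same loop, checks it against a vectorized construction based on `repeat`, and shows that reshaping a row back to `sze × sze` recovers α along the first dimension. Names such as `ticks` and `pixels_vec` are illustrative.

```julia
sze = 20
ρmax = 10e0
ticks = range(-ρmax, ρmax; length = sze)

# Loop construction, as in the text: column k = i + (j-1)*sze holds (α_i, β_j)
pixels = zeros(Float64, 2, sze * sze)
for (i, α) in enumerate(ticks), (j, β) in enumerate(ticks)
    pixels[1, i + (j - 1) * sze] = α
    pixels[2, i + (j - 1) * sze] = β
end

# Equivalent vectorized construction: α varies fastest along the flat index
αs = repeat(collect(ticks), outer = sze)
βs = repeat(collect(ticks), inner = sze)
pixels_vec = permutedims(hcat(αs, βs))   # 2 × sze² matrix, like `pixels`

# Reshaping a row restores the image layout: first dimension ↔ α, second ↔ β
α_grid = reshape(pixels[1, :], sze, sze)
β_grid = reshape(pixels[2, :], sze, sze)
```

This is why `reshape(..., sze, sze)` on the model output later produces an image with α along the first array dimension.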
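We can see the effects of ray tracing by plotting the emission model evaluated directly on the screen grid (that is, treating screen coordinates as disk coordinates, with no lensing) next to the image model, which evaluates the same emission model at the ray-traced equatorial coordinates.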
using CairoMakie
curr_theme = Theme(
Axis = (xticksvisible=false, xticklabelsvisible=false, yticksvisible=false, yticklabelsvisible=false,),
Heatmap =(colormap=:afmhot, ),
)
set_theme!(merge(curr_theme, theme_latexfonts()))
emitted_intensity = reshape(emission_model(pixels, ps, st)[1], sze, sze)
received_intensity = reshape(image_model(pixels, ps, st)[1], sze, sze)
fig = Figure();
heatmap!(Axis(fig[1,1], aspect=1, title="Emission Model"), emitted_intensity)
heatmap!(Axis(fig[1,2], aspect=1, title="Image Model (Lensed Emission Model)"), received_intensity)
save("emission_model_and_target_model.png", fig)
## Fitting the NeRF model
This is a toy example showing the mechanics of fitting our `ImageModel` to a target image. The target is the lensed image produced by the randomly initialized network plotted above:
target_img = reshape(received_intensity, 1, sze*sze);
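Let's fit our model by minimizing the mean squared error between the flux-normalized target and model images. We then draw a fresh set of network parameters from `rng` as the starting point, so that the starting image differs from the target.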
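Written out, the loss computed by the `mse` helper is $\mathcal{L}(I, \hat I) = \frac{1}{N}\sum_{k=1}^{N}\left(\frac{I_k}{\sum_j I_j} - \frac{\hat I_k}{\sum_j \hat I_j}\right)^2$, where $I$ and $\hat I$ are the two images with $N$ pixels each. Because each image is divided by its total flux, the loss compares only the shape of the brightness distribution. The sketch below, using a hypothetical copy named `normalized_mse`, checks this on a four-pixel image.

```julia
using Statistics: mean

# Same formula as the `mse` helper in the text: mean squared error between
# images that have each been divided by their total flux.
normalized_mse(a, b) = mean(((a ./ sum(a)) .- (b ./ sum(b))) .^ 2)

img = [0.1, 0.4, 0.2, 0.3]
same      = normalized_mse(img, img)            # exactly 0.0
rescaled  = normalized_mse(5 .* img, img)       # ≈ 0: total brightness is ignored
reordered = normalized_mse(reverse(img), img)   # ≈ 0.04: the structure differs
```

A consequence of this choice is that the fit cannot constrain the overall brightness of the emission model, only its relative structure.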
using Enzyme
using Optimization
using OptimizationOptimisers
using StatsBase
import ComponentArrays  # `import` avoids clashing ComponentArrays.Axis with CairoMakie's Axis
Enzyme.Compiler.RunAttributor[] = false
# Mean squared error between flux-normalized images
function mse(img1::Matrix{T}, img2::Matrix{T}) where T
mean(((img1 ./ sum(img1)) .- (img2 ./ sum(img2))) .^ 2)
end
function loss_function(pixels, y, ps, st)
y_pred, st = image_model(pixels, ps, st)
mse(y, y_pred), st
end
mse(target_img, target_img)   # sanity check: identical images give zero loss
ps, st = Lux.setup(rng, emission_model);   # fresh random parameters as the starting point
image_model = ImageModel(emission_model);
emitted_intensity = reshape(emission_model(pixels, ps, st)[1], sze, sze)
received_intensity = reshape(image_model(pixels, ps, st)[1], sze, sze)
loss_function(pixels, target_img, ps, st)
fig = Figure();
heatmap!(Axis(fig[1,1], aspect=1, title="Emission Model"), emitted_intensity, colormap=:afmhot)
heatmap!(Axis(fig[1,2], aspect=1, title="Image Model (Lensed Emission Model)"), received_intensity, colormap=:afmhot)
save("emission_model_and_image_model.png", fig)
Let's define a callback function that logs the loss every `stride` steps while we optimize our model.
mutable struct Callback
    counter::Int
    stride::Int
    const f::Function   # extra user function (unused here); `const` fields need Julia ≥ 1.8
end
Callback(stride, f) = Callback(0, stride, f)
function (c::Callback)(state, loss, others...)
    c.counter += 1
    if c.counter % c.stride == 0
        @info "On step $(c.counter) loss = $(loss)"
    end
    return false   # returning true would stop the optimization early
end
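The `Callback` above always returns `false`, so it only logs. Optimization.jl halts a solve when the callback returns `true`, which makes early stopping a small change. The sketch below illustrates that contract in plain Julia with a hypothetical `EarlyStop` callback and a hypothetical mock loop `run_mock_optimization`; neither is part of Optimization.jl, and the "loss" is just a geometric decay standing in for optimizer steps.

```julia
# Hypothetical mock of an optimizer loop: it calls `callback(state, loss)`
# once per iteration and stops as soon as the callback returns `true`.
function run_mock_optimization(callback; maxiters = 1_000)
    loss = 1.0
    for iter in 1:maxiters
        loss *= 0.99                          # stand-in for one optimizer step
        callback(nothing, loss) && return iter
    end
    return maxiters
end

# Hypothetical early-stopping callback: halt once the loss drops below `tol`.
mutable struct EarlyStop
    counter::Int
    tol::Float64
end
EarlyStop(tol) = EarlyStop(0, tol)

function (c::EarlyStop)(state, loss, others...)
    c.counter += 1
    return loss < c.tol
end

cb = EarlyStop(0.5)
stopped_at = run_mock_optimization(cb)   # 69, the first step with 0.99^k < 0.5
```

In the real solve, such a callback would be passed through the same `callback` keyword used below.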
We can now optimize our model using the Adam optimizer.
ps_trained, st_trained = let st=Ref(st), x=pixels, y=reshape(target_img, 1, sze*sze)
    optprob = Optimization.OptimizationProblem(
        Optimization.OptimizationFunction(
            # Objective as a function of the parameters; the Lux state is
            # threaded through the Ref `st`
            function(ps, constants)
                loss, st[] = loss_function(x, y, ps, st[])
                loss
            end,
            Optimization.AutoEnzyme()   # differentiate the objective with Enzyme
        ),
        # Flatten the parameters into a Float64 ComponentVector
        ComponentArrays.ComponentVector{Float64}(ps)
    )
    solution = Optimization.solve(
        optprob,
        OptimizationOptimisers.Adam(),
        maxiters = 5_000,
        callback=Callback(100,()->nothing)   # log the loss every 100 steps
    )
    solution.u, st[]
end
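For readers unfamiliar with Adam, the sketch below writes out the textbook update rule by hand and applies it to a toy quadratic. `OptimizationOptimisers.Adam()` is called above with its default hyperparameters; to our knowledge these follow the usual choice η = 0.001, β = (0.9, 0.999), but check the Optimisers.jl documentation. Here η is raised to 0.1 so the toy problem converges quickly, and `adam_minimize` is a hypothetical helper, not the library implementation.

```julia
# Hand-written Adam update on a toy problem (hypothetical helper, not Optimisers.jl).
function adam_minimize(grad, x0; η = 0.1, β1 = 0.9, β2 = 0.999, ϵ = 1e-8, iters = 2_000)
    x = copy(x0)
    m = zero(x)   # running mean of the gradient (first moment)
    v = zero(x)   # running mean of the squared gradient (second moment)
    for t in 1:iters
        g = grad(x)
        m .= β1 .* m .+ (1 - β1) .* g
        v .= β2 .* v .+ (1 - β2) .* g .^ 2
        m_hat = m ./ (1 - β1^t)   # bias corrections for the zero initialization
        v_hat = v ./ (1 - β2^t)
        x .-= η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    end
    return x
end

target = [3.0, -1.0]
grad(x) = 2 .* (x .- target)          # gradient of sum((x .- target) .^ 2)
x_opt = adam_minimize(grad, zeros(2))  # converges to ≈ target
```

Because each coordinate's step is scaled by its own gradient history, Adam is fairly insensitive to the very different scales of the network's weights and biases.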
Let's plot the results of our optimization and compare them to the target image.
fitted_img, st_trained = image_model(pixels, ps_trained, st_trained)
received_intensity = reshape(fitted_img, sze, sze)   # fitted image
initial_img, st = image_model(pixels, ps, st)
acc_intensity = reshape(initial_img, sze, sze)       # starting image, before training
loss_function(pixels, target_img, ps, st)
loss_function(pixels, target_img, ps_trained, st_trained)
using Printf
begin
fig = Figure(size=(700, 300));
heatmap!(Axis(fig[1,1], aspect=1, title="Target Image"), reshape(target_img, sze, sze))
heatmap!(Axis(fig[1,2], aspect=1, title="Starting State (loss=$(@sprintf("%0.2e", loss_function(pixels, target_img, ps, st)[1])))"), acc_intensity)
heatmap!(Axis(fig[1,3], aspect=1, title="Fitted State (loss=$(@sprintf("%0.2e", loss_function(pixels, target_img, ps_trained, st_trained)[1])))"), received_intensity)
save("neural_net_results.png", fig)
end
This page was generated using Literate.jl.