Neural Network Emission Model Example
This is a pedagogical example that serves as a proof of concept. We will build a simple neural network emission model to be ray traced, and optimize it with the ADAM algorithm.
using Lux
using Krang
using Random
# Seed the default random number generator so that the parameter initialization is reproducible.
Random.seed!(123)
# Handle to the global RNG; it is passed to `Lux.setup` when the network parameters are drawn.
rng = Random.GLOBAL_RNG
Random._GLOBAL_RNG()
Our model will take in spacetime coordinates and return the observed intensity value for a given pixel.
We will first define a simple emission model that will be ray traced. The emission model is taken to be a simple fully connected neural network with 2 hidden layers.
# Fully connected network mapping a 2-component coordinate to a single emission value.
# Every layer, including the output layer, uses a sigmoid activation.
emission_model = let width = 20, activation = Lux.sigmoid
    Chain(
        Dense(2 => width, activation),
        Dense(width => width, activation),
        Dense(width => 1, activation),
    )
end
# Draw the initial parameters (`ps`) and layer states (`st`) using `rng`.
ps, st = Lux.setup(rng, emission_model);
We will use a Kerr metric with a spin of 0.99 and an observer sitting at 20 degrees inclination with respect to the spin axis in this example. These parameters could be made to float in the optimization process. We will do this by defining an ImageModel
composed of an emission layer and a raytracing layer. We will create a struct for our image model and store our emission model as a layer to be raytraced.
"""
    ImageModel{L<:Chain}

Image model that stores a Lux `Chain` emission network in its `emission_layer` field so
that the network can be evaluated on ray-traced coordinates.
"""
struct ImageModel{L<:Chain}
    emission_layer::L
end
The models in Lux are functors that take in features, parameters and model state, and return the output and model state.
# Evaluate the image model: ray trace every screen pixel in `x` to a point (xs, ys) with Krang,
# evaluate the emission layer at those points, and sum the result over n = 0 and n = 1.
#
# Arguments
# - `x`:  matrix with one column per pixel; rows 1 and 2 are passed to `Krang.IntensityPixel`
#         as the two screen coordinates. With N columns, only the first `sze*sze` are used,
#         where `sze` is the truncated square root of N.
#         NOTE(review): presumably N is always a perfect square — confirm against callers.
# - `ps`: parameters forwarded to the emission layer.
# - `st`: state forwarded to the emission layer.
#
# Returns the tuple `(emission_vals, st)`: `emission_vals` is a 1 × (sze*sze) Float64 matrix and
# `st` is the input state, returned unchanged (the state returned by the emission layer is dropped).
function (m::ImageModel)(x, ps, st)
# Kerr metric with spin 0.99.
metric = Krang.Kerr(0.99e0)
# Observer inclination of 20 degrees, expressed in radians.
θo = Float64(20/180*π)
# One pixel object per column of `x`; the metric is wrapped in `Ref` so it is not broadcast over.
pixels = Krang.IntensityPixel.(Ref(metric), x[1,:], x[2,:], θo)
# Side length of the (assumed square) image.
sze = unsafe_trunc(Int, sqrt(size(x)[2]))
# Buffers for the ray-traced coordinates and the accumulated emission.
coords = zeros(Float64, 2,sze*sze)
emission_vals = zeros(Float64, 1, sze*sze)
# `n` is forwarded to `emission_coordinates_fast_light`.
# NOTE(review): presumably the sub-image index (direct and first lensed image) — confirm in the Krang docs.
for n in 0:1
for i in 1:sze
for j in 1:sze
# Linear pixel index with `i` varying fastest.
pix = pixels[i+(j-1)*sze]
α, β = Krang.screen_coordinate(pix)
T = typeof(α)
# NOTE(review): `π/2` is presumably the polar angle of the emitting (equatorial) plane and
# `β > 0` a flag selecting the side of the screen — confirm in the Krang docs.
rs, _, ϕs = Krang.emission_coordinates_fast_light(pix, Float64(π/2), β > 0, n)[1:3]
# Convert the emission radius and azimuth to Cartesian coordinates.
xs = rs * cos(ϕs)
ys = rs * sin(ϕs)
# Points at or inside the horizon radius are mapped to the origin.
if hypot(xs, ys) ≤ Krang.horizon(metric)
coords[1,i+(j-1)*sze] = zero(T)
coords[2,i+(j-1)*sze] = zero(T)
else
coords[1,i+(j-1)*sze] = xs
coords[2,i+(j-1)*sze] = ys
end
end
end
# Add this value of n's contribution; `coords` is overwritten in full on the next pass.
emission_vals .+= m.emission_layer(coords, ps, st)[1]
end
emission_vals,st
end
Let's create a 20x20 pixel image with a field of view spanning screen coordinates from -10 to 10 (`ρmax = 10`) along each axis.
# Wrap the emission network in the ray-tracing image model.
image_model = ImageModel(emission_model)
# Side length of the square image, in pixels.
sze = 20
# Half-width of the screen: coordinates run from -ρmax to ρmax on both axes.
ρmax = 10e0
# Column k = i + (j - 1) * sze of `pixels` holds the screen coordinates of grid point (i, j),
# so the first coordinate varies fastest. `Iterators.product` iterates in that same order.
screen_axis = range(-ρmax, ρmax, sze)
screen_grid = Iterators.product(screen_axis, screen_axis)
pixels = zeros(Float64, 2, sze * sze)
for (column, (first_coord, second_coord)) in enumerate(screen_grid)
    pixels[1, column] = Float64(first_coord)
    pixels[2, column] = Float64(second_coord)
end
Let's see what our emission model looks like before and after raytracing.
using CairoMakie
# Plot theme: hide axis ticks and tick labels, and use the `afmhot` colormap for heatmaps.
curr_theme = Theme(;
    Axis=(;
        xticksvisible=false,
        xticklabelsvisible=false,
        yticksvisible=false,
        yticklabelsvisible=false,
    ),
    Heatmap=(; colormap=:afmhot),
)
set_theme!(merge(curr_theme, theme_latexfonts()))

# The emission network evaluated directly on the screen grid, and the ray-traced image of it.
emitted_intensity = reshape(first(emission_model(pixels, ps, st)), sze, sze)
received_intensity = reshape(first(image_model(pixels, ps, st)), sze, sze)

# Show both side by side and write the figure to disk.
fig = Figure();
emission_axis = Axis(fig[1, 1]; aspect=1, title="Emission Model")
heatmap!(emission_axis, emitted_intensity)
image_axis = Axis(fig[1, 2]; aspect=1, title="Image Model (Lensed Emission Model)")
heatmap!(image_axis, received_intensity)
save("emission_model_and_target_model.png", fig)
This will be the image we will try to fit our model to.
# Flatten the ray-traced image into a single row (one column per pixel) to use as the fit target.
target_img = reshape(received_intensity, 1, :);
Fitting the model
Let's fit our model using the mean squared error between the sum-normalized images as our loss function.
using Enzyme
using Optimization
using OptimizationOptimisers
using StatsBase
using ComponentArrays
# NOTE(review): switches off Enzyme's internal `RunAttributor` compiler setting — presumably a
# workaround needed to differentiate this model; confirm it is still required for the Enzyme version in use.
Enzyme.Compiler.RunAttributor[] = false
"""
    mse(img1, img2)

Return the mean squared difference between `img1` and `img2` after each image has been
divided by the sum of its own entries, which makes the result independent of the overall
scale of either image.

# Arguments
- `img1`, `img2`: arrays with matching shapes. Their element types may differ.

# Returns
The mean of `((img1 ./ sum(img1)) .- (img2 ./ sum(img2))) .^ 2`.
"""
function mse(img1::AbstractArray, img2::AbstractArray)
    # Accepting any array type (rather than two `Matrix{T}` with one shared `T`) lets a target
    # and a prediction with different element types or array wrappers be compared.
    return mean(((img1 ./ sum(img1)) .- (img2 ./ sum(img2))) .^ 2)
end
"""
    loss_function(pixels, y, ps, st)

Evaluate the global `image_model` on `pixels` with parameters `ps` and state `st`, and score
the prediction against the target `y` with `mse`.

Returns the tuple `(loss, st)`, where `st` is the state returned by `image_model`.
"""
function loss_function(pixels, y, ps, st)
    prediction, updated_state = image_model(pixels, ps, st)
    return mse(y, prediction), updated_state
end
# Sanity check: the loss of the target image against itself.
mse(target_img, target_img)
# Draw a new set of parameters from `rng` to serve as the starting point of the fit.
ps, st = Lux.setup(rng, emission_model);
image_model = ImageModel(emission_model);
# Emission network and its ray-traced image for the starting parameters.
emitted_intensity = reshape(emission_model(pixels, ps, st)[1], sze, sze)
received_intensity = reshape(image_model(pixels, ps, st)[1], sze, sze)
# Loss of the starting state against the target image.
loss_function(pixels, target_img, ps, st)
fig = Figure();
heatmap!(Axis(fig[1,1], aspect=1, title="Emission Model"), emitted_intensity, colormap=:afmhot)
# Title typo fixed ("Imgage" -> "Image") so it matches the title used in the first figure.
heatmap!(Axis(fig[1,2], aspect=1, title="Image Model (Lensed Emission Model)"), received_intensity, colormap=:afmhot)
save("emission_model_and_image_model.png", fig)
WARNING: using ComponentArrays.Axis in module Main conflicts with an existing identifier.
"""
    Callback(stride, f)
    Callback(counter, stride, f)

Optimizer callback that counts how many times it has been called and logs the current loss
on every `stride`-th call.

# Fields
- `counter`: number of calls made so far (0 when built with the two-argument constructor).
- `stride`: number of calls between two log messages; must be nonzero.
- `f`: user-supplied function. It is stored but never invoked by the callback.

# Throws
- `ArgumentError` if `stride` is zero. Without this check the callback would fail with a
  `DivideError` on its first call, i.e. in the middle of an optimization run.
"""
mutable struct Callback
    counter::Int
    stride::Int
    const f::Function

    # Untyped arguments keep the converting behaviour of the default constructor;
    # `new` converts each value to the declared field type.
    function Callback(counter, stride, f)
        iszero(stride) && throw(ArgumentError("stride must be nonzero, got $stride"))
        return new(counter, stride, f)
    end
end

Callback(stride, f) = Callback(0, stride, f)

# Call method: `state` and any `others...` are accepted but ignored; `loss` is only logged.
# Increments the counter, logs on every `stride`-th call, and always returns `false`.
# NOTE(review): the optimizer presumably reads `false` as "do not stop" — confirm in the
# Optimization.jl callback documentation.
function (c::Callback)(state, loss, others...)
    c.counter += 1
    if c.counter % c.stride == 0
        @info "On step $(c.counter) loss = $(loss)"
    end
    return false
end
# Fit the emission-network parameters to `target_img` with Adam, differentiating the loss with Enzyme.
# The `let` block keeps the model state in a `Ref` so that the objective closure can write back
# the state returned by `loss_function` after every evaluation.
ps_trained, st_trained = let st=Ref(st), x=pixels, y=reshape(target_img, 1, sze*sze)
optprob = Optimization.OptimizationProblem(
Optimization.OptimizationFunction(
# Objective: `ps` are the parameters being optimized; `constants` is not used.
function(ps, constants)
loss, st[] = loss_function(x, y, ps, st[])
loss
end,
Optimization.AutoEnzyme()
),
# Starting parameters, converted to a Float64 `ComponentVector`.
ComponentArrays.ComponentVector{Float64}(ps)
)
solution = Optimization.solve(
optprob,
OptimizationOptimisers.Adam(),
maxiters = 5_000,
# Log the loss every 100 calls; the function argument is stored by `Callback` but not called.
callback=Callback(100,()->nothing)
)
# Fitted parameters and the model state after the last objective evaluation.
solution.u, st[]
end
# Ray trace the fitted parameters, then the starting parameters, and reshape each
# prediction into an sze × sze image.
fitted_prediction, st = image_model(pixels, ps_trained, st_trained)
received_intensity = reshape(fitted_prediction, sze, sze)
initial_prediction, st = image_model(pixels, ps, st)
acc_intensity = reshape(initial_prediction, sze, sze)
# Loss before and after the fit.
loss_function(pixels, target_img, ps, st)
loss_function(pixels, target_img, ps_trained, st_trained)
using Printf
begin
    # Format both losses up front so the axis titles below stay readable.
    start_loss = @sprintf("%0.2e", loss_function(pixels, target_img, ps, st)[1])
    fitted_loss = @sprintf("%0.2e", loss_function(pixels, target_img, ps_trained, st_trained)[1])

    # Three panels: the target, the starting state and the fitted state.
    fig = Figure(; size=(700, 300))
    target_axis = Axis(fig[1, 1]; aspect=1, title="Target Image")
    heatmap!(target_axis, reshape(target_img, sze, sze))
    start_axis = Axis(fig[1, 2]; aspect=1, title="Starting State (loss=$(start_loss))")
    heatmap!(start_axis, acc_intensity)
    fitted_axis = Axis(fig[1, 3]; aspect=1, title="Fitted State (loss=$(fitted_loss))")
    heatmap!(fitted_axis, received_intensity)
    save("neural_net_results.png", fig)
end
[ Info: On step 100 loss = 1.5882207599055176e-9
[ Info: On step 200 loss = 1.2093943598385204e-9
[ Info: On step 300 loss = 9.611184402425459e-10
[ Info: On step 400 loss = 7.738415581577792e-10
[ Info: On step 500 loss = 6.292009063377417e-10
[ Info: On step 600 loss = 5.157069395915786e-10
[ Info: On step 700 loss = 4.24593991261194e-10
[ Info: On step 800 loss = 3.5151101349676817e-10
[ Info: On step 900 loss = 2.9473785989042027e-10
[ Info: On step 1000 loss = 2.4997508765914876e-10
[ Info: On step 1100 loss = 2.1456974654309221e-10
[ Info: On step 1200 loss = 1.8654323205926308e-10
[ Info: On step 1300 loss = 1.6433911146982609e-10
[ Info: On step 1400 loss = 1.4671894640662876e-10
[ Info: On step 1500 loss = 1.3269596654754365e-10
[ Info: On step 1600 loss = 1.2148607326280858e-10
[ Info: On step 1700 loss = 1.1246918495087748e-10
[ Info: On step 1800 loss = 1.0515774180917044e-10
[ Info: On step 1900 loss = 9.917068885764458e-11
[ Info: On step 2000 loss = 9.42119382737621e-11
[ Info: On step 2100 loss = 9.005262965300225e-11
[ Info: On step 2200 loss = 8.651665065264868e-11
[ Info: On step 2300 loss = 8.346894736449101e-11
[ Info: On step 2400 loss = 8.080619372777296e-11
[ Info: On step 2500 loss = 7.844942530883769e-11
[ Info: On step 2600 loss = 7.633828258090623e-11
[ Info: On step 2700 loss = 7.442655271042437e-11
[ Info: On step 2800 loss = 7.267874427834422e-11
[ Info: On step 2900 loss = 7.10674734642396e-11
[ Info: On step 3000 loss = 6.957148079302454e-11
[ Info: On step 3100 loss = 6.817413329418442e-11
[ Info: On step 3200 loss = 6.686229736010723e-11
[ Info: On step 3300 loss = 6.562549279887417e-11
[ Info: On step 3400 loss = 6.445525899621512e-11
[ Info: On step 3500 loss = 6.334468034443793e-11
[ Info: On step 3600 loss = 6.22880308266488e-11
[ Info: On step 3700 loss = 6.128050750098734e-11
[ Info: On step 3800 loss = 6.031803018382387e-11
[ Info: On step 3900 loss = 5.939709037232203e-11
[ Info: On step 4000 loss = 5.851463678011365e-11
[ Info: On step 4100 loss = 5.766798811150064e-11
[ Info: On step 4200 loss = 5.685476612798187e-11
[ Info: On step 4300 loss = 5.607284386718261e-11
[ Info: On step 4400 loss = 5.532030521357729e-11
[ Info: On step 4500 loss = 5.4595413010803e-11
[ Info: On step 4600 loss = 5.389658363629323e-11
[ Info: On step 4700 loss = 5.322236649758969e-11
[ Info: On step 4800 loss = 5.2571427306180395e-11
[ Info: On step 4900 loss = 5.194253427644049e-11
[ Info: On step 5000 loss = 5.133454661188927e-11
This page was generated using Literate.jl.