03-flux-ring2
Basic Introduction to Machine Learning: 03-flux-ring2
Illustrate basic artificial NN training using Julia's Flux library
- Jeff Fessler, University of Michigan
- 2018-10-18 Julia 1.0.1 original
- 2023-01-29 Julia 1.8.5 update
This page was generated from a single Julia file: 03-flux-ring2.jl.
In any such Julia documentation, you can access the source code using the "Edit on GitHub" link in the top right.
The corresponding notebook can be viewed in nbviewer here: 03-flux-ring2.ipynb, and opened in binder here: 03-flux-ring2.ipynb.
Setup
Packages needed here.
using LinearAlgebra: norm
using Random: seed!
using LaTeXStrings # pretty plot labels
import Flux # Julia package for deep learning
using Flux: Dense, Chain, relu, params, Adam, throttle, mse
using Plots: Plot, plot, plot!, scatter!, default, gui
using MIRTjim: jim, prompt
using InteractiveUtils: versioninfo
default(markersize=5, markerstrokecolor=:auto, label="")
default(legendfontsize=10, labelfontsize=12, tickfontsize=10)
The following line is helpful when running this file as a script; this way it will prompt user to hit a key after each figure is displayed.
isinteractive() ? jim(:prompt, true) : prompt(:draw);
Generate (synthetic) data
Function to simulate data that cannot be linearly separated
"""
    (X, Y) = simdata(; n1 = 40, n2 = 120, σ1 = 0.8, σ2 = 2, r2 = 3)

Simulate 2D data that cannot be linearly separated:
class 1 is an isotropic Gaussian blob (std `σ1`) at the origin,
class 2 is a ring with radii in `[r2, r2+σ2]`.

Returns `X` (2 × (n1+n2), `Float32`) and labels `Y` (1 × (n1+n2), `Float32`)
with values -1 (class 1) and +1 (class 2).
The `Float32` output matches the default element type of Flux `Dense` layers,
avoiding the "Float32 parameters got Float64 input" warning and slowdown.
"""
function simdata(; n1 = 40, n2 = 120, σ1 = 0.8, σ2 = 2, r2 = 3)
data1 = σ1 * randn(2,n1) # class 1: Gaussian blob
rad2 = r2 .+ σ2*rand(1,n2) # class 2: random radii in [r2, r2+σ2]
ang2 = rand(1,n2) * 2π # class 2: uniform random angles
data2 = [rad2 .* cos.(ang2); rad2 .* sin.(ang2)] # polar → Cartesian
X = Float32.([data1 data2]) # 2 × N = n1+n2; Float32 for Flux layers
Y = Float32.([-ones(1,n1) ones(1,n2)]) # 1 × N labels in {-1,+1}
size(X,2) == size(Y,2) || throw(DimensionMismatch("X and Y column counts differ"))
return (X,Y)
end;
Scatter plot routine
"""
    (data1, data2) = datasplit(X, Y)

Split the columns of `X` into two matrices according to the labels in `Y`:
columns with label -1 first, columns with label +1 second.
"""
function datasplit(X,Y)
    labels = vec(Y)
    class1 = X[:, labels .== -1]
    class2 = X[:, labels .== 1]
    return (class1, class2)
end;
"""
    plot_data(X, Y; kwargs...)

Scatter plot of the two classes in `X` (split by labels `Y`):
class 1 in blue, class 2 in red, on a fixed ±6 square with equal aspect.
Extra `kwargs` are forwarded to the initial `plot` call.
"""
function plot_data(X,Y; kwargs...)
    class1, class2 = datasplit(X, Y)
    plot(xlabel = L"x_1", ylabel = L"x_2"; kwargs...)
    scatter!(class1[1,:], class1[2,:], color = :blue, label = "class1")
    scatter!(class2[1,:], class2[2,:], color = :red, label = "class2")
    plot!(xlim = (-6, 6), ylim = (-6, 6))
    plot!(aspect_ratio = 1, xtick = -6:6:6, ytick = -6:6:6)
end;
Training data
seed!(0) # seed the global RNG so the simulated data are reproducible
(Xtrain, Ytrain) = simdata() # simulated training data
plot_data(Xtrain,Ytrain)
prompt()
Validation and testing data
# Independent draws from the same distribution for validation and testing
(Xvalid, Yvalid) = simdata()
(Xtest, Ytest) = simdata()
p1 = plot_data(Xvalid, Yvalid; title="Validation")
p2 = plot_data(Xtest, Ytest; title="Test")
plot(p1, p2) # side-by-side comparison
prompt()
Train simple MLP model
A multilayer perceptron model (MLP) consists of multiple fully connected layers.
Train a basic NN model with 1 hidden layer
# Guard so that re-including this script does not retrain from scratch:
# `state` is only defined after the first successful training run.
if !@isdefined(state)
nhidden = 10 # neurons in hidden layer
model = Chain(Dense(2,nhidden,relu), Dense(nhidden,1)) # 2 → nhidden → 1 MLP
loss3(model, x, y) = mse(model(x), y) # admittedly silly choice
iters = 10000
# Repeating the full (X,Y) pair `iters` times makes train! run `iters`
# full-batch gradient steps (i.e., `iters` epochs) in a single call.
dataset = Base.Iterators.repeated((Xtrain, Ytrain), iters)
state = Flux.setup(Adam(), model) # explicit optimizer state (Flux ≥ 0.14 API)
Flux.train!(loss3, model, dataset, state)
end;
┌ Warning: Layer with Float32 parameters got Float64 input.
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(2 => 10, relu) # 30 parameters
│ summary(x) = "2×160 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/uRn8o/src/layers/stateless.jl:60
Plot results after training
"""
    display_decision_boundaries(X, Y, model; x1range, x2range, τ = 0.0, kwargs...)

Display the sign of `model(x) - τ` over a grid (grayscale image via `jim`),
overlaid with scatter plots of the two data classes in `X`/`Y`.
Extra `kwargs` are forwarded to `jim`.
"""
function display_decision_boundaries(
X, Y, model;
x1range = range(-1,1,101)*6, x2range = x1range, τ = 0.0,
kwargs...,
)
    class1, class2 = datasplit(X, Y)
    # Evaluate the (scalar) model output at every grid point.
    values = map(Iterators.product(x1range, x2range)) do (x1, x2)
        model([x1; x2])[1]
    end
    jim(x1range, x2range, sign.(values .- τ); color=:grays, kwargs...)
    scatter!(class1[1,:], class1[2,:], color = :blue, label = "Class 1")
    scatter!(class2[1,:], class2[2,:], color = :red, label = "Class 2")
    plot!(xlabel=L"x_1", ylabel=L"x_2")
    plot!(xlim=(-6,6), ylim=(-6,6))
    plot!(aspect_ratio=1, xtick=-6:6:6, ytick=-6:6:6)
end;
Examine classification accuracy
# One sample is classified correctly iff the sign of the scalar model
# output matches the ±1 label.
classacc(model, x, y::Number) = sign(model(x)[1]) == y
classacc(model, x, y::AbstractArray) = classacc(model, x, y[1]) # unwrap 1-element label

"""
    classacc(X, Y; model = model)

Percent (rounded to 3 digits) of the columns of `X` for which
`sign(model(x)[1])` matches the corresponding ±1 label in `Y`.

The `model` keyword defaults to the global `model`, preserving the original
script usage `classacc(Xtrain, Ytrain)`, but may be passed explicitly to
avoid depending on (non-const) global state.
"""
function classacc(X, Y; model = model)
    samples = zip(eachcol(X), eachcol(Y))
    ncorrect = count(xy -> classacc(model, xy...), samples)
    return round(100 * ncorrect / size(Y,2), digits=3)
end
# Report final training loss and overlay the learned decision boundary
lossXY = loss3(model, Xtrain, Ytrain)
display_decision_boundaries(Xtrain, Ytrain, model)
plot!(title = "Train: MSE Loss = $(round(lossXY,digits=4)), " *
"Class=$(classacc(Xtrain, Ytrain)) %")
prompt()
Train while validating
Create a basic NN model with 1 hidden layer. This version evaluates performance every epoch for both the training data and validation data.
# Fresh model (replaces the earlier one) so we can track loss per outer
# iteration on both training and validation data.
nhidden = 10 # neurons in hidden layer
layer2 = Dense(2, nhidden, relu)
layer3 = Dense(nhidden, 1)
model = Chain(layer2, layer3)
loss3(model, x, y) = mse(model(x), y)
nouter = 80 # of outer iterations, for showing loss
losstrain = zeros(nouter+1)
lossvalid = zeros(nouter+1)
iters = 100 # full-batch epochs per outer iteration
losstrain[1] = loss3(model, Xtrain, Ytrain) # loss before any training
lossvalid[1] = loss3(model, Xvalid, Yvalid)
for io in 1:nouter
# @show io
    # Each outer iteration runs `iters` full-batch gradient steps.
    idataset = Base.Iterators.repeated((Xtrain, Ytrain), iters)
    # NOTE(review): calling Flux.setup inside the loop re-initializes the
    # Adam optimizer state (moment estimates) every outer iteration;
    # consider hoisting it above the loop — confirm intended behavior.
    istate = Flux.setup(Adam(), model)
    Flux.train!(loss3, model, idataset, istate)
    losstrain[io+1] = loss3(model, Xtrain, Ytrain)
    lossvalid[io+1] = loss3(model, Xvalid, Yvalid)
    if (io ≤ 6) && false # set to true to make images
        display_decision_boundaries(Xtrain, Ytrain, model)
        plot!(title="$(io*iters) epochs")
#       savefig("ml-flux-$(io*iters).pdf")
    end
end
# Final losses and decision boundaries for training vs validation data
loss_train = loss3(model, Xtrain, Ytrain)
loss_valid = loss3(model, Xvalid, Yvalid)
p1 = display_decision_boundaries(Xtrain, Ytrain, model;
title="Train:\nMSE Loss = $(round(loss_train,digits=4))\n" *
"Class=$(classacc(Xtrain, Ytrain)) %",
)
p2 = display_decision_boundaries(Xvalid, Yvalid, model;
title="Valid:\nMSE Loss = $(round(loss_valid,digits=4))\n" *
"Class=$(classacc(Xvalid, Yvalid)) %",
)
p12 = plot(p1, p2)
prompt()
Show MSE loss vs epoch
# First outer iteration where the validation loss increased
# (a possible early-stopping point).
# NOTE(review): findfirst returns `nothing` if the validation loss never
# increases, which would make the `xticks` call below error — confirm.
ivalid = findfirst(>(0), diff(lossvalid))
plot(xlabel="epoch/$(iters)", ylabel="RMSE loss", ylim=[0,1.05*maximum(losstrain)])
plot!(0:nouter, sqrt.(losstrain), label="training loss", marker=:o, color=:green)
plot!(0:nouter, sqrt.(lossvalid), label="validation loss", marker=:+, color=:violet)
plot!(xticks = [0, ivalid, nouter])
prompt()
Show response of (trained) first hidden layer
# Evaluate each of the nhidden first-layer (trained) neurons over a 2D grid
# and display each response as an image.
x1range = range(-1,1,31) * 6
x2range = range(-1,1,33) * 6
layer2data = [layer2([x1;x2])[n] for x1 = x1range, x2 = x2range, n in 1:nhidden]
pl = Array{Plot}(undef, nhidden)
for n in 1:nhidden
    ptmp = jim(x1range, x2range, layer2data[:,:,n], color=:cividis,
        xtick=-6:6:6, ytick=-6:6:6,
    )
    if n == 7 # label axes on just one subplot to reduce clutter
        plot!(ptmp, xlabel=L"x_1", ylabel=L"x_2")
    end
    pl[n] = ptmp
end
plot(pl[1:9]...) # show the first 9 neuron responses in a 3×3 grid
prompt()
Reproducibility
This page was generated with the following version of Julia:
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
11-element Vector{SubString{String}}:
"Julia Version 1.11.6"
"Commit 9615af0f269 (2025-07-09 12:58 UTC)"
"Build Info:"
" Official https://julialang.org/ release"
"Platform Info:"
" OS: Linux (x86_64-linux-gnu)"
" CPU: 4 × AMD EPYC 7763 64-Core Processor"
" WORD_SIZE: 64"
" LLVM: libLLVM-16.0.6 (ORCJIT, znver3)"
"Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
""
And with the following package versions
import Pkg; Pkg.status()
Status `~/work/ismrm_ml2/ismrm_ml2/docs/Project.toml`
[31c24e10] Distributions v0.25.120
[e30172f5] Documenter v1.14.1
[587475ba] Flux v0.16.5
[b964fa9f] LaTeXStrings v1.4.0
[98b081ad] Literate v2.20.1
[170b2178] MIRTjim v0.25.0
[eb30cadb] MLDatasets v0.7.18
[91a5bcdd] Plots v1.40.18
[2913bbd2] StatsBase v0.34.6
[1986cc42] Unitful v1.24.0
[ef84fa70] ismrm_ml2 v0.0.1 `~/work/ismrm_ml2/ismrm_ml2`
[b77e0a4c] InteractiveUtils v1.11.0
[37e2e46d] LinearAlgebra v1.11.0
[9a3f8284] Random v1.11.0
This page was generated using Literate.jl.