Basic Introduction to Machine Learning: 03-flux-ring2

Illustrate basic artificial neural network (NN) training using Julia's Flux library.

  • Jeff Fessler, University of Michigan
  • 2018-10-18 Julia 1.0.1 original
  • 2023-01-29 Julia 1.8.5 update

This page was generated from a single Julia file: 03-flux-ring2.jl.

The corresponding notebook can be viewed in nbviewer here: 03-flux-ring2.ipynb, and opened in binder here: 03-flux-ring2.ipynb.

Setup

Packages needed here.

using LinearAlgebra: norm
using Random: seed!
using LaTeXStrings # pretty plot labels
import Flux # Julia package for deep learning
using Flux: Dense, Chain, relu, params, Adam, throttle, mse
using Plots: Plot, plot, plot!, scatter!, default, gui
using MIRTjim: jim, prompt
using InteractiveUtils: versioninfo

default(markersize=5, markerstrokecolor=:auto, label="")
default(legendfontsize=10, labelfontsize=12, tickfontsize=10)

The following line is helpful when running this file as a script; this way it will prompt the user to hit a key after each figure is displayed.

isinteractive() ? jim(:prompt, true) : prompt(:draw);

Generate (synthetic) data

Function to simulate data that cannot be linearly separated

function simdata(; n1 = 40, n2 = 120, σ1 = 0.8, σ2 = 2, r2 = 3)
    data1 = σ1 * randn(2,n1) # class 1: gaussian cluster at the origin
    rad2 = r2 .+ σ2*rand(1,n2) # class 2: radii in [r2, r2+σ2]
    ang2 = rand(1,n2) * 2π # class 2: uniform random angles
    data2 = [rad2 .* cos.(ang2); rad2 .* sin.(ang2)] # ring surrounding class 1
    X = [data1 data2] # 2 × N feature matrix, N = n1+n2
    Y = [-ones(1,n1) ones(1,n2)] # 1 × N labels: -1 for class 1, +1 for class 2
    @assert size(X,2) == size(Y,2)
    return (X,Y)
end;
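
Such data defeats any linear classifier. As a quick baseline check (a sketch, not part of the original demo; Xtmp, Ytmp, A, w, acc_linear are hypothetical names), a least-squares affine fit can do little better than always predicting the majority class, roughly 75% here:

(Xtmp, Ytmp) = simdata()
A = [Xtmp' ones(size(Xtmp,2))] # N × 3 design matrix [x1 x2 1]
w = A \ vec(Ytmp) # least-squares weights of an affine classifier
acc_linear = count(sign.(A * w) .== vec(Ytmp)) / length(Ytmp) * 100 # expect roughly the 75% majority-class rate: no line separates a cluster from a surrounding ring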

Scatter plot routine

function datasplit(X,Y)
    data1 = X[:,findall(==(-1), vec(Y))] # columns labeled -1
    data2 = X[:,findall(==(1), vec(Y))] # columns labeled +1
    return (data1, data2)
end;

function plot_data(X,Y; kwargs...)
    data1, data2 = datasplit(X,Y)
    plot(xlabel=L"x_1", ylabel=L"x_2"; kwargs...)
    scatter!(data1[1,:], data1[2,:], color=:blue, label="class1")
    scatter!(data2[1,:], data2[2,:], color=:red, label="class2")
    plot!(xlim=[-1,1]*6, ylim=[-1,1]*6)
    plot!(aspect_ratio=1, xtick=-6:6:6, ytick=-6:6:6)
end;

Training data

seed!(0)
(Xtrain, Ytrain) = simdata()
plot_data(Xtrain,Ytrain)
[Figure: scatter plot of the training data, class1 (blue) and class2 (red)]
prompt()

Validation and testing data

(Xvalid, Yvalid) = simdata()
(Xtest, Ytest) = simdata()

p1 = plot_data(Xvalid, Yvalid; title="Validation")
p2 = plot_data(Xtest, Ytest; title="Test")
plot(p1, p2)
[Figure: validation (left) and test (right) data scatter plots]
prompt()

Train simple MLP model

A multilayer perceptron (MLP) model consists of multiple fully connected (dense) layers, with a nonlinear activation such as relu applied between them.
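
Each Dense layer in Flux computes an affine map followed by an element-wise activation, i.e., σ.(W*x .+ b). Here is a minimal sketch (with hypothetical weights W1, b1, W2, b2) verifying that a 1-hidden-layer Chain is just the composition of two such maps:

using Flux: Dense, Chain, relu
W1 = randn(Float32, 10, 2); b1 = zeros(Float32, 10) # hypothetical hidden-layer weights
W2 = randn(Float32, 1, 10); b2 = zeros(Float32, 1) # hypothetical output-layer weights
mlp = Chain(Dense(W1, b1, relu), Dense(W2, b2))
x = randn(Float32, 2)
@assert mlp(x) ≈ W2 * relu.(W1 * x .+ b1) .+ b2 # manual forward pass matches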

Train a basic NN model with 1 hidden layer. The @isdefined guard avoids retraining when re-running this script.

if !@isdefined(state)
    nhidden = 10 # neurons in hidden layer
    model = Chain(Dense(2,nhidden,relu), Dense(nhidden,1))
    loss3(model, x, y) = mse(model(x), y) # MSE on ±1 labels: an admittedly silly choice for classification
    iters = 10000
    dataset = Base.Iterators.repeated((Xtrain, Ytrain), iters) # iters full-batch passes
    state = Flux.setup(Adam(), model)
    Flux.train!(loss3, model, dataset, state)
end;
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10, relu)  # 30 parameters
│   summary(x) = "2×160 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/uRn8o/src/layers/stateless.jl:60
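
This warning arises because Flux initializes layer parameters as Float32, whereas the simulated data above are Float64. A minimal sketch of one way to avoid it (hypothetical names; the demo above simply accepts the conversion):

Xtrain32 = Float32.(Xtrain) # convert the data once to match the Float32 parameters
Ytrain32 = Float32.(Ytrain)
# then train on (Xtrain32, Ytrain32); alternatively, Flux.f64(model) converts the model to Float64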

Plot results after training

function display_decision_boundaries(
    X, Y, model;
    x1range = range(-1,1,101)*6, x2range = x1range, τ = 0.0,
    kwargs...,
)
    data1, data2 = datasplit(X, Y)
    D = [model([x1; x2])[1] for x1 in x1range, x2 in x2range] # model output over a grid
    jim(x1range, x2range, sign.(D .- τ); color=:grays, kwargs...) # threshold at τ to show decision regions
    scatter!(data1[1,:], data1[2,:], color = :blue, label = "Class 1")
    scatter!(data2[1,:], data2[2,:], color = :red, label = "Class 2")
    plot!(xlabel=L"x_1", ylabel=L"x_2")
    plot!(xlim=[-1,1]*6, ylim=[-1,1]*6)
    plot!(aspect_ratio=1, xtick=-6:6:6, ytick=-6:6:6)
end;
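
The keyword τ shifts the decision threshold away from zero; for example (a sketch), a positive τ requires a larger model output before a point is assigned to class 2:

display_decision_boundaries(Xtrain, Ytrain, model; τ = 0.5) # shrinks the region classified as class 2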

Examine classification accuracy

classacc(model, x, y::Number) = sign(model(x)[1]) == y # correct iff the output sign matches the ±1 label
classacc(model, x, y::AbstractArray) = classacc(model, x, y[1])
function classacc(X, Y) # percent correct over a dataset (uses the global `model`)
    tmp = zip(eachcol(X), eachcol(Y))
    tmp = count(xy -> classacc(model, xy...), tmp)
    tmp = tmp / size(Y,2) * 100
    return round(tmp, digits=3)
end;
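
Because model applied to a 2 × N matrix returns a 1 × N matrix of outputs, the same accuracy can be computed in one vectorized pass (a sketch; accuracy is a hypothetical name, not part of the original code). It should agree with classacc above, up to rounding:

accuracy(model, X, Y) = 100 * count(sign.(vec(model(X))) .== vec(Y)) / size(Y,2) # percent of sign agreements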

lossXY = loss3(model, Xtrain, Ytrain)
display_decision_boundaries(Xtrain, Ytrain, model)
plot!(title = "Train: MSE Loss = $(round(lossXY,digits=4)), " *
    "Class=$(classacc(Xtrain, Ytrain)) %")
[Figure: decision boundary with training data; title shows MSE loss and classification accuracy]
prompt()

Train while validating

Create a basic NN model with 1 hidden layer. This version records the loss for both the training data and the validation data after each outer iteration, i.e., every iters epochs.

nhidden = 10 # neurons in hidden layer
layer2 = Dense(2, nhidden, relu)
layer3 = Dense(nhidden, 1)
model = Chain(layer2, layer3)
loss3(model, x, y) = mse(model(x), y)

nouter = 80 # number of outer iterations, for recording the loss
losstrain = zeros(nouter+1)
lossvalid = zeros(nouter+1)

iters = 100 # inner iterations: full-batch passes per outer iteration
losstrain[1] = loss3(model, Xtrain, Ytrain)
lossvalid[1] = loss3(model, Xvalid, Yvalid)

for io in 1:nouter
    # @show io
    idataset = Base.Iterators.repeated((Xtrain, Ytrain), iters) # iters full-batch passes
    istate = Flux.setup(Adam(), model) # note: recreating the state here resets Adam's moment estimates every outer iteration
    Flux.train!(loss3, model, idataset, istate)
    losstrain[io+1] = loss3(model, Xtrain, Ytrain)
    lossvalid[io+1] = loss3(model, Xvalid, Yvalid)
    if (io ≤ 6) && false # set to true to make images
        display_decision_boundaries(Xtrain, Ytrain, model)
        plot!(title="$(io*iters) epochs")
        # savefig("ml-flux-$(io*iters).pdf")
    end
end

loss_train = loss3(model, Xtrain, Ytrain)
loss_valid = loss3(model, Xvalid, Yvalid)
p1 = display_decision_boundaries(Xtrain, Ytrain, model;
 title="Train:\nMSE Loss = $(round(loss_train,digits=4))\n" *
    "Class=$(classacc(Xtrain, Ytrain)) %",
)
p2 = display_decision_boundaries(Xvalid, Yvalid, model;
 title="Valid:\nMSE Loss = $(round(loss_valid,digits=4))\n" *
    "Class=$(classacc(Xvalid, Yvalid)) %",
)
p12 = plot(p1, p2)
[Figure: decision boundaries for training (left) and validation (right) data]
prompt()

Show RMSE loss vs epoch

ivalid = findfirst(>(0), diff(lossvalid)) # first outer iteration where the validation loss increased
plot(xlabel="epoch/$(iters)", ylabel="RMSE loss", ylim=[0,1.05*maximum(losstrain)])
plot!(0:nouter, sqrt.(losstrain), label="training loss", marker=:o, color=:green)
plot!(0:nouter, sqrt.(lossvalid), label="validation loss", marker=:+, color=:violet)
plot!(xticks = [0, ivalid, nouter])
[Figure: training and validation RMSE loss vs epoch]
prompt()
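
The increase in validation loss past ivalid suggests overfitting; a common remedy is early stopping. Here is a sketch (train_early_stop! is a hypothetical helper, not part of the original demo) that keeps the model with the best validation loss and stops after the loss fails to improve patience times in a row:

function train_early_stop!(model, loss, Xt, Yt, Xv, Yv;
    maxouter = 80, iters = 100, patience = 5,
)
    state = Flux.setup(Adam(), model)
    best = (loss = Inf, model = deepcopy(model))
    bad = 0 # consecutive outer iterations without improvement
    for io in 1:maxouter
        Flux.train!(loss, model, Base.Iterators.repeated((Xt, Yt), iters), state)
        lv = loss(model, Xv, Yv)
        if lv < best.loss
            best = (loss = lv, model = deepcopy(model)) # new best on validation data
            bad = 0
        else
            bad += 1
            bad ≥ patience && break # validation loss stopped improving
        end
    end
    return best.model
end;
# usage sketch: bestmodel = train_early_stop!(Chain(Dense(2,10,relu), Dense(10,1)), loss3, Xtrain, Ytrain, Xvalid, Yvalid)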

Show response of (trained) first hidden layer

x1range = range(-1,1,31) * 6
x2range = range(-1,1,33) * 6
layer2data = [layer2([x1;x2])[n] for x1 in x1range, x2 in x2range, n in 1:nhidden] # response of each hidden neuron over a coarse grid
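
Equivalently (a sketch using stack from Julia ≥ 1.9; tmp and layer2data_alt are hypothetical names), one forward pass per grid point suffices, avoiding nhidden redundant evaluations per point:

tmp = stack(layer2([x1; x2]) for x1 in x1range, x2 in x2range) # nhidden × 31 × 33
layer2data_alt = permutedims(tmp, (2, 3, 1)) # 31 × 33 × nhidden; ≈ layer2data above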

pl = Array{Plot}(undef, nhidden)
for n in 1:nhidden
    ptmp = jim(x1range, x2range, layer2data[:,:,n], color=:cividis,
        xtick=-6:6:6, ytick=-6:6:6,
    )
    if n == 7 # label the axes on just one subplot
        plot!(ptmp, xlabel=L"x_1", ylabel=L"x_2")
    end
    pl[n] = ptmp
end
plot(pl[1:9]...) # show the first 9 of the nhidden=10 hidden neuron responses
[Figure: responses of the first 9 hidden-layer neurons]
prompt()

Reproducibility

This page was generated with the following version of Julia:

io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
11-element Vector{SubString{String}}:
 "Julia Version 1.11.6"
 "Commit 9615af0f269 (2025-07-09 12:58 UTC)"
 "Build Info:"
 "  Official https://julialang.org/ release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)"
 "Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
 ""

And with the following package versions:

import Pkg; Pkg.status()
Status `~/work/ismrm_ml2/ismrm_ml2/docs/Project.toml`
  [31c24e10] Distributions v0.25.120
  [e30172f5] Documenter v1.14.1
  [587475ba] Flux v0.16.5
  [b964fa9f] LaTeXStrings v1.4.0
  [98b081ad] Literate v2.20.1
  [170b2178] MIRTjim v0.25.0
  [eb30cadb] MLDatasets v0.7.18
  [91a5bcdd] Plots v1.40.18
  [2913bbd2] StatsBase v0.34.6
  [1986cc42] Unitful v1.24.0
  [ef84fa70] ismrm_ml2 v0.0.1 `~/work/ismrm_ml2/ismrm_ml2`
  [b77e0a4c] InteractiveUtils v1.11.0
  [37e2e46d] LinearAlgebra v1.11.0
  [9a3f8284] Random v1.11.0

This page was generated using Literate.jl.