Principal component analysis (PCA) illustration

This example illustrates PCA of hand-written digit data.

Add the Julia packages used in this demo. Change false to true in the following code block if you are using any of the following packages for the first time.

if false
    import Pkg
    Pkg.add(["LaTeXStrings", "MIRTjim", "MLDatasets", "Plots", "StatsBase"])
end

Tell Julia to use the following packages. Run Pkg.add() in the preceding code block first, if needed.

using InteractiveUtils: versioninfo
using LaTeXStrings # nice plot labels
using LinearAlgebra: svd
using MIRTjim: jim, prompt
using MLDatasets: MNIST
using Plots: default, gui, plot, savefig, scatter, scatter!
using Plots.PlotMeasures: px
using Random: seed!, randperm
using StatsBase: mean
default(); default(markersize=5, markerstrokecolor=:auto, label="",
 tickfontsize=14, labelfontsize=18, legendfontsize=18, titlefontsize=18)

The following line is helpful when running this file as a script: it prompts the user to hit a key after each figure is displayed.

isinteractive() ? jim(:prompt, true) : prompt(:draw);

Load data

Read the MNIST data for some handwritten digits. This code automatically downloads the data from the web if needed and puts it in a folder like ~/.julia/datadeps/MNIST/.

if !@isdefined(data)
    digitn = (0, 1, 4) # which digits to use
    isinteractive() || (ENV["DATADEPS_ALWAYS_ACCEPT"] = true) # avoid download prompt
    dataset = MNIST(Float32, :train)
    nrep = 60 # how many of each digit
    # function to extract the 1st `nrep` examples of digit n:
    data = n -> dataset.features[:,:,findall(==(n), dataset.targets)[1:nrep]]
    data = cat(dims=4, data.(digitn)...) # (nx, ny, nrep, ndigit)
    labels = vcat([fill(d, nrep) for d in digitn]...) # to check later
    nx, ny, nrep, ndigit = size(data)
    data = data[:,2:ny,:,:] # drop one pixel along y so images are non-square (catches nx/ny mix-ups)
    ny = size(data,2)
    data = reshape(data, nx, ny, :)
    tmp = randperm(nrep * ndigit) # shuffle images so the digits are mixed
    data = data[:,:,tmp]
    labels = labels[tmp] # keep labels aligned with the shuffled images
end
size(data) # (nx, ny, nrep*ndigit)
(28, 27, 180)
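The selection and shuffling logic above can be checked on a tiny synthetic label vector, without downloading MNIST. This is a minimal sketch: `targets`, `features`, and `pick` are hypothetical stand-ins for `dataset.targets`, `dataset.features`, and the anonymous function in the loading block, and each toy image stores its own label in pixel (1,1) so that alignment after shuffling can be verified.

```julia
using Random: seed!, randperm

# hypothetical stand-ins for dataset.targets and dataset.features
targets = [0, 1, 4, 7, 0, 4, 1, 1, 0, 4, 7, 0]
features = reshape(Float32.(1:2*2*length(targets)), 2, 2, :) # twelve 2×2 "images"
for j in eachindex(targets)
    features[1, 1, j] = targets[j] # tag each image with its label
end

digitn = (0, 1, 4)
nrep = 2
# first `nrep` images of digit n, mirroring the loading block
pick = n -> features[:, :, findall(==(n), targets)[1:nrep]]
data = cat(dims=4, pick.(digitn)...)              # (2, 2, nrep, ndigit)
labels = vcat([fill(d, nrep) for d in digitn]...) # [0, 0, 1, 1, 4, 4]
data = reshape(data, 2, 2, :)                     # (2, 2, nrep*ndigit)

seed!(0) # fixed seed, only to make this sketch reproducible
perm = randperm(length(labels))
data = data[:, :, perm]
labels = labels[perm]

# after shuffling, each image should still carry its own label in pixel (1,1)
aligned = all(data[1, 1, :] .== labels)
```

Because `reshape` is column-major, the `nrep` images of each digit stay contiguous before the shuffle, which is why the `vcat(fill(...))` label vector matches the image order.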

Look at "unlabeled" image data prior to unsupervised dimensionality reduction

pd = jim(data, "Data"; size=(600,300), cticks=0:1,
# xticks = false, yticks = false, tickfontsize=12, right_margin=-5px, # book
)
# savefig(pd, "pca-data.pdf")

Compute sample average of data

μ = mean(data, dims=3)
pm = jim(μ, "Mean")
# savefig(pm, "pca-mean.pdf")
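Because `mean(data, dims=3)` keeps a singleton third dimension, `data .- μ` broadcasts the mean image across every image; the scree-plot code relies on this. A small sketch with random stand-in data (the sizes are arbitrary assumptions, and `mean` is taken from the `Statistics` standard library, whose `mean` is the same function `StatsBase` re-exports):

```julia
using Random: seed!
using Statistics: mean

seed!(1)
X = rand(Float32, 4, 3, 10) # stand-in for `data`: ten 4×3 images
μ = mean(X, dims=3)         # size (4, 3, 1): the mean image
Xc = X .- μ                 # broadcasting subtracts μ from every image
# every pixel of the centered images should now average to ≈ 0
maxdev = maximum(abs, mean(Xc, dims=3))
```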

Scree plot

Show singular values.

data2 = reshape(data .- μ, :, nrep*ndigit) # centered data, one image per column: (nx*ny, nrep*ndigit)
f = svd(data2)
ps = scatter(f.S; title="Scree plot", widen=true,
 xaxis = (L"k", (1,ndigit*nrep), [1, 6, ndigit*nrep]),
 yaxis = (L"σ_k", (0,48), [0, 0, 47]),
)
# savefig(ps, "pca-scree.pdf")
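The squared singular values of the centered data matrix measure how much variance each principal direction captures: with N images, the sample-covariance eigenvalues are σ_k²/(N−1). A sketch of the usual "fraction of variance explained" computation from the singular values, using random rank-3 stand-in data instead of MNIST; the helper name `explained_variance` is hypothetical.

```julia
using LinearAlgebra: svd
using Random: seed!
using Statistics: mean

# fraction of the total variance captured by each singular component
explained_variance(S) = S .^ 2 ./ sum(abs2, S)

seed!(2)
M, N = 20, 50
# rank-3 signal plus small noise, as a stand-in for `data2`
A = randn(M, 3) * randn(3, N) + 0.01 * randn(M, N)
A = A .- mean(A, dims=2) # center each row (pixel) across the N samples
f = svd(A)
fracs = explained_variance(f.S) # descending, since svd sorts f.S
cumfrac = cumsum(fracs)
# cumfrac[3] should be close to 1 because the signal has rank 3
```

On the MNIST data, the same `cumsum` of `explained_variance(f.S)` is one way to judge how many components the scree plot suggests keeping.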

Principal components

The first 6 or so singular values are notably larger than the rest, but for simplicity of visualization here we just use the first two components.

K = 2
Q = f.U[:,1:K]
pq = jim(reshape(Q, nx,ny,:), "First $K singular components"; size=(600,300))
# savefig(pq, "pca-q.pdf")
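The columns of `Q = f.U[:,1:K]` are orthonormal, and they coincide, up to sign, with the leading eigenvectors of the sample covariance of the centered data; that is why the left singular vectors serve as principal components. A small numerical check on random stand-in data (sizes and row spreads are arbitrary assumptions):

```julia
using LinearAlgebra: svd, eigen, Symmetric, I, norm
using Random: seed!
using Statistics: mean

seed!(3)
M, N, K = 6, 40, 2
X = randn(M, N) .* [5, 3, 1, 1, 1, 1] # rows with different spreads
Xc = X .- mean(X, dims=2)             # center across samples
f = svd(Xc)
Q = f.U[:, 1:K]

C = Xc * Xc' / (N - 1)           # sample covariance (M×M)
e = eigen(Symmetric(C))          # eigenvalues in ascending order
V = e.vectors[:, end:-1:end-K+1] # top-K eigenvectors, largest first

orth_err = norm(Q' * Q - I)                # ≈ 0: Q has orthonormal columns
sign_err = maximum(abs.(abs.(Q' * V) - I)) # ≈ 0: same directions up to sign
eig_err = maximum(abs.(f.S[1:K] .^ 2 / (N - 1) - e.values[end:-1:end-K+1]))
```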

Now use the learned subspace basis Q to perform dimensionality reduction. The resulting coefficients are called "factors" in factor analysis and "scores" in PCA.

z = Q' * data2 # (K, nrep*ndigit)
2×180 Matrix{Float32}:
 -3.68769  6.25571  -3.73453  -1.93295  …  -3.69552  6.81066   0.995399
  3.45898  1.48057   2.90521  -3.54616      2.52211  2.24464  -3.96443
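From the scores, each image is approximated by the mean image plus a combination of the K basis images, x̂ = μ + Q z. For the truncated SVD, the squared Frobenius error of this rank-K approximation equals the sum of the discarded squared singular values (the Eckart–Young theorem says no rank-K approximation does better). A sketch on random stand-in data, with images stored as columns as in `data2`:

```julia
using LinearAlgebra: svd, norm
using Random: seed!
using Statistics: mean

seed!(4)
M, N, K = 10, 30, 2
X = randn(M, N)     # stand-in for the reshaped data (pixels × images)
μ = mean(X, dims=2) # mean image, as an M×1 column
Xc = X .- μ
f = svd(Xc)
Q = f.U[:, 1:K]
z = Q' * Xc         # K×N scores, as in `z = Q' * data2`
Xhat = μ .+ Q * z   # rank-K reconstruction of every image

err2 = norm(X - Xhat)^2          # squared Frobenius error
tail2 = sum(abs2, f.S[K+1:end])  # energy in the discarded components
# err2 ≈ tail2
```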

PCA scores

The three digits are remarkably well separated even in just two dimensions.

pz = plot(title = "Score plot for $ndigit digits",
 xaxis=("Score 1", (-5,8), -3:3:6),
 yaxis=("Score 2", (-6,4), -4:4:4),
)
markers = (:circle, :diamond, :square)
for (i,d) in enumerate(digitn) # one marker shape per digit
    scatter!(pz, z[1,labels .== d], z[2,labels .== d], label="Digit $d", marker=markers[i])
end
# savefig(pz, "pca-score.pdf")
pz


This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
11-element Vector{SubString{String}}:
 "Julia Version 1.11.2"
 "Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)"
 "Build Info:"
 "  Official https://julialang.org/ release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)"
 "Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"

And with the following package versions

This page was generated using Literate.jl.