Spectral clustering

This example illustrates spectral clustering, based on the normalized graph Laplacian, applied to handwritten digits.

This page comes from a single Julia file: spectral-cluster.jl.


Setup

Add the Julia packages used in this demo. Change false to true in the following code block if you are using any of the following packages for the first time.

if false
    import Pkg
    Pkg.add([
        "Clustering"
        "InteractiveUtils"
        "LaTeXStrings"
        "LinearAlgebra"
        "MIRTjim"
        "MLDatasets"
        "Plots"
        "Random"
        "StatsBase"
    ])
end

Tell Julia to use the following packages. Run Pkg.add() in the preceding code block first, if needed.

using Clustering: kmeans
using InteractiveUtils: versioninfo
using LaTeXStrings # pretty plot labels
using LinearAlgebra: I, norm, Diagonal, eigen
using MIRTjim: jim, prompt
using MLDatasets: MNIST
using Plots: default, gui, plot, scatter, plot!, scatter!
using Random: seed!, randperm
using StatsBase: mean
default(); default(markersize=5, markerstrokecolor=:auto, label="")

The following line is helpful when running this file as a script: it makes the script prompt the user to hit a key after each figure is displayed.

isinteractive() ? jim(:prompt, true) : prompt(:draw);

Load data

Read the MNIST data for some handwritten digits. This code will automatically download the data from the web if needed and put it in a folder like ~/.julia/datadeps/MNIST/.

if !@isdefined(data)
    digitn = (0, 1, 3) # which digits to use
    isinteractive() || (ENV["DATADEPS_ALWAYS_ACCEPT"] = true) # avoid prompt
    dataset = MNIST(Float32, :train)
    nrep = 30
    # function to extract the first nrep examples of digit n:
    getdigit = n -> dataset.features[:,:,findall(==(n), dataset.targets)[1:nrep]]
    data = cat(dims=4, getdigit.(digitn)...)
    labels = vcat([fill(d, nrep) for d in digitn]...) # to check later
    nx, ny, nrep, ndigit = size(data)
    data = data[:,2:ny,:,:] # make images non-square (28×27) to help catch nx/ny mix-ups
    ny = size(data,2)
    data = reshape(data, nx, ny, :)
    seed!(0)
    tmp = randperm(nrep * ndigit) # shuffle images and labels with the same permutation
    data = data[:,:,tmp]
    labels = labels[tmp]
    size(data) # (nx, ny, nrep*ndigit)
end
(28, 27, 90)

Look at "unlabeled" image data for unsupervised clustering

jim(data)
# savefig("spectral-cluster-data.pdf")

Choose similarity function

σ = 2^-2 # tuning parameter
sfun(x,z) = exp(-norm(x-z)^2/nx/ny/σ^2)
sfun (generic function with 1 method)
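To get a feel for the tuning parameter σ, the following minimal sketch evaluates the same formula as sfun on two toy 4-pixel vectors. The helper name gauss_sim and its keyword N (the pixel count, nx*ny in the demo) are illustrative assumptions rather than part of the original code.

```julia
using LinearAlgebra: norm

# Same formula as sfun, but with σ and the pixel count N passed explicitly
# (hypothetical helper; in the demo N = nx*ny and σ = 2^-2 = 0.25).
gauss_sim(x, z; σ, N = length(x)) = exp(-norm(x - z)^2 / N / σ^2)

x = zeros(4)                          # toy 4-pixel "image"
z = fill(0.5, 4)                      # differs by 0.5 in every pixel
s_self  = gauss_sim(x, x; σ = 0.25)   # identical inputs: exp(0) = 1
s_small = gauss_sim(x, z; σ = 0.25)   # exp(-0.25 / 0.25^2) = exp(-4)
s_large = gauss_sim(x, z; σ = 1.0)    # exp(-0.25 / 1^2)    = exp(-0.25)
```

Because the squared distance is divided by the number of pixels, σ is effectively compared with the root-mean-square per-pixel difference between two images: a smaller σ makes the similarity decay faster, pushing W toward the identity matrix, while a larger σ pushes all weights toward 1.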

Weight matrix

slices = eachslice(data, dims=3)
W = [sfun(x,z) for x in slices, z in slices]
pw = jim(W, "weight matrix W")
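The comprehension above calls sfun once per pair of images, which is fine for 90 images. As a sketch of an equivalent vectorized alternative (not the demo's method), one can expand ‖x − z‖² = ‖x‖² + ‖z‖² − 2xᵀz and work with a single Gram matrix. The helper gauss_weights and the random toy matrix X_toy are hypothetical; each column of X_toy plays the role of one vectorized image.

```julia
using LinearAlgebra: norm, diag

# Vectorized Gaussian weights for the columns of X (one column per image),
# using ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x'z. Hypothetical helper.
function gauss_weights(X::AbstractMatrix, σ::Real)
    N = size(X, 1)                       # pixels per image
    G = X' * X                           # Gram matrix of inner products
    sq = diag(G)                         # squared norms ||x||^2
    D2 = max.(sq .+ sq' .- 2 .* G, 0)    # squared distances, clamped at 0 against rounding
    return exp.(-D2 ./ (N * σ^2))
end

X_toy = rand(12, 5)                      # toy data: 5 "images" with 12 pixels each
# pairwise loop, mirroring the comprehension used for W above
W_loop = [exp(-norm(x - z)^2 / size(X_toy, 1) / 0.25^2) for x in eachcol(X_toy), z in eachcol(X_toy)]
W_vec = gauss_weights(X_toy, 0.25)       # same matrix, computed from one Gram matrix
```

For the demo's images, X = reshape(data, nx * ny, :) has one column per image, and gauss_weights(X, σ) should then match W up to floating-point rounding. The clamp at zero guards against tiny negative squared distances caused by cancellation.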

Degree matrix

D = Diagonal(vec(sum(W; dims=2)))
90×90 LinearAlgebra.Diagonal{Float64, Vector{Float64}}:
 17.9375   ⋅         ⋅        ⋅      …    ⋅        ⋅        ⋅        ⋅ 
   ⋅      7.44538    ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅       11.2192    ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅      13.4635       ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅      …    ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
  ⋮                                  ⋱                             
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅      …    ⋅        ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅         17.5636    ⋅        ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅      11.1902    ⋅        ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅      20.0574    ⋅ 
   ⋅       ⋅         ⋅        ⋅           ⋅        ⋅        ⋅      13.0212
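In formulas, with x_i denoting the i-th image and n = 90 the number of images, the weights and degrees computed above are as follows. Since sfun(x, x) = exp(0) = 1, every diagonal entry of W is 1 and every other entry is positive, so each degree is at least 1 and D is invertible, which the Laplacian step relies on.

```latex
w_{ij} = \exp\!\left( -\frac{\| x_i - x_j \|_2^2}{n_x n_y \sigma^2} \right),
\qquad
d_i = \sum_{j=1}^{n} w_{ij} \;\ge\; w_{ii} = 1,
\qquad
D = \operatorname{diag}(d_1, \ldots, d_n)
```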

Normalized graph Laplacian

L = I - inv(D) * W
jim(L, "Normalized graph Laplacian L")
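The matrix L = I - inv(D) * W used here is the random-walk form of the normalized graph Laplacian; another common choice is the symmetric form I - D^(-1/2) W D^(-1/2). The sketch below, on a hypothetical 3-node weight matrix W_toy (names suffixed so they do not overwrite the demo's W, D, and L), checks two properties of the random-walk form: its rows sum to zero, so the constant vector is an eigenvector with eigenvalue 0, and it is similar to the symmetric form, so its eigenvalues are real even though L itself is not symmetric.

```julia
using LinearAlgebra

W_toy = [1.0 0.5 0.0; 0.5 1.0 0.2; 0.0 0.2 1.0]  # toy symmetric weights, unit diagonal
D_toy = Diagonal(vec(sum(W_toy; dims=2)))         # degrees: 1.5, 1.7, 1.2
L_toy = I - inv(D_toy) * W_toy                    # random-walk normalized Laplacian

# Rows of inv(D)*W sum to 1, so L_toy has zero row sums: L_toy * ones(3) ≈ 0.
row_sums = vec(sum(L_toy; dims=2))

# D^(1/2) L D^(-1/2) = I - D^(-1/2) W D^(-1/2) is symmetric,
# so L_toy has the same (real) eigenvalues as this symmetric matrix.
Dh = sqrt(D_toy)
Lsym = Symmetric(I - inv(Dh) * W_toy * inv(Dh))
λsym = eigvals(Lsym)                 # ascending order; the smallest is ≈ 0
λL = sort(real(eigvals(L_toy)))      # same spectrum, computed from L_toy directly
```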

Eigendecomposition and eigenvalues

eig = eigen(L)
pe = scatter(eig.values, xlabel = L"k", ylabel="Eigenvalues")
prompt()
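Why look at the eigenvectors for the smallest eigenvalues? In the idealized case where the graph splits into groups with no edges between them, L has one zero eigenvalue per group, and the corresponding eigenvectors are constant within each group. The sketch below checks this on a hypothetical 6-node graph made of two fully connected groups of 3 nodes. For real data such as these digits, W only approximately has this block structure, so at best one can hope for a few eigenvalues near zero rather than exact zeros.

```julia
using LinearAlgebra

# Hypothetical 6-node graph: two groups of 3 nodes, fully connected within
# each group (unit weights) and with no edges between the groups.
J3 = ones(3, 3)
W_toy = [J3 zeros(3, 3); zeros(3, 3) J3]
D_toy = Diagonal(vec(sum(W_toy; dims=2)))    # every degree is 3 here
L_toy = I - inv(D_toy) * W_toy               # random-walk normalized Laplacian

# All degrees are equal, so L_toy is exactly symmetric and Symmetric() loses nothing.
F_toy = eigen(Symmetric(L_toy))              # eigenvalues in ascending order
nzero = count(<(1e-10), abs.(F_toy.values))  # number of numerically zero eigenvalues
V = F_toy.vectors[:, 1:nzero]                # their eigenvectors, one per column
```

This piecewise-constant structure is what the k-means step exploits: images in the same well-separated group get nearly identical coordinates in the leading eigenvectors.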

Apply k-means++ to eigenvectors

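K = length(digitn) # try using the known number of digits
Y = eig.vectors[:,1:K]' # eigenvectors for the K smallest eigenvalues, transposed: one column per image
r3 = kmeans(Y, K) # k-means++ initialization is the Clustering.jl default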
Clustering.KmeansResult{Matrix{Float64}, Float64, Int64}([-0.10540925533894602 -0.1054092553389459 -0.10540925533894589; -0.05965510469889494 0.1729140876850089 0.09040966863417839; -0.017795041027864646 -0.20774078611379543 0.08713850947870803], [1, 2, 1, 3, 1, 2, 3, 3, 1, 3  …  3, 3, 3, 1, 1, 1, 1, 3, 1, 3], [0.0007758797387535099, 0.003023539930957192, 0.0010693793212867252, 0.007498893983837671, 0.005581144560563932, 0.014482220355212155, 0.002521383348555749, 0.01888474732946329, 0.00130221978281745, 0.009481468843814317  …  0.00324644904122659, 0.004346361864787832, 0.009033330271250035, 0.0012124926826523816, 0.0009452103438290921, 0.019635139375460846, 0.000686945615129525, 0.002536207561285468, 0.0013602675923303012, 0.016262944799841257], [45, 12, 33], [45, 12, 33], 0.42862910914519925, 5, true)
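Clustering.kmeans treats each column of its input as one data point, which is why Y is the transpose of the first K eigenvectors: column i of Y holds the spectral coordinates of image i. To show what the clustering step does on such an embedding, here is a minimal sketch of Lloyd's k-means iteration in plain Julia. It is not Clustering.jl's implementation: the helper lloyd_kmeans is hypothetical and, to keep the example deterministic, it uses farthest-first initialization instead of the randomized k-means++ seeding that Clustering.kmeans uses by default.

```julia
using LinearAlgebra: norm

# Minimal Lloyd's k-means on the columns of Y (one column per point).
# Hypothetical helper: deterministic farthest-first initialization,
# NOT the randomized k-means++ seeding of Clustering.kmeans.
function lloyd_kmeans(Y::AbstractMatrix, K::Integer; maxiter::Integer = 100)
    n = size(Y, 2)
    # farthest-first init: start at column 1, then repeatedly add the column
    # whose distance to its nearest already-chosen center is largest
    idx = [1]
    while length(idx) < K
        dmin = [minimum(norm(Y[:, i] - Y[:, j]) for j in idx) for i in 1:n]
        push!(idx, argmax(dmin))
    end
    centers = Y[:, idx]
    assign = zeros(Int, n)
    for _ in 1:maxiter
        # assignment step: index of the nearest center for each column
        new_assign = [argmin([norm(Y[:, i] - centers[:, k]) for k in 1:K]) for i in 1:n]
        new_assign == assign && break          # no change: converged
        assign = new_assign
        # update step: move each nonempty cluster's center to the mean of its points
        for k in 1:K
            members = findall(==(k), assign)
            if !isempty(members)
                centers[:, k] = vec(sum(Y[:, members]; dims=2)) ./ length(members)
            end
        end
    end
    return assign, centers
end

# toy 2-D "embedding": columns are points, forming two obvious groups
Y_toy = [0.0 0.1 0.0 5.0 5.1 5.0;
         0.0 0.0 0.1 5.0 5.0 5.1]
assign_toy, centers_toy = lloyd_kmeans(Y_toy, 2)
```

Both initializations aim to spread the starting centers apart; k-means++ does so randomly, so repeated runs of kmeans can return different clusterings and different cluster numbering.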

Confusion matrix using class assignments from kmeans++

label_list = unique(labels) # distinct digits, in order of first appearance (columns of result)

result = zeros(Int, K, length(label_list))
for k in 1:K # each cluster
    rck = r3.assignments .== k
    for (j,l) in enumerate(label_list)
        result[k,j] = count(rck .& (l .== labels))
    end
end
result
3×3 Matrix{Int64}:
 30   3  12
  0  12   0
  0  15  18
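One way to put a number on how good this clustering is: cluster purity, the fraction of images whose cluster's most common digit matches their own digit. The helper purity below is an illustrative sketch (one summary among several) that assumes, as in result, that rows are clusters and columns are digits.

```julia
# Cluster purity: sum over clusters (rows) of the largest count in that row,
# divided by the total number of points. Hypothetical helper.
purity(C::AbstractMatrix{<:Integer}) = sum(maximum(C; dims=2)) / sum(C)

C3 = [30 3 12; 0 12 0; 0 15 18]   # the 3×3 confusion matrix printed above
purity3 = purity(C3)              # (30 + 12 + 18) / 90 = 2/3
```

In the demo session, purity(result) should give the same value. Purity tends to increase with the number of clusters (with one cluster per image it is trivially 1), so it is most useful for comparing runs with the same K.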

Visualize the clustered digits

p3 = jim(
 [jim(data[:,:,r3.assignments .== k], "Class $k"; prompt=false) for k in 1:K]...
)

The clustering here seems only so-so, at least from the digit-classification point of view: as the confusion matrix above shows, most of the clusters mix images of different digits. Each of these digits lives reasonably close to a manifold, and apparently the simple Gaussian similarity function used here does not adequately capture within-manifold similarities.

However, there is no reason to think that using the same number of classes as digits is optimal. Let's try again using more classes (larger K).

K = 9
Y = eig.vectors[:,1:K]'
r9 = kmeans(Y, K)
p9 = jim(
 [jim(data[:,:,r9.assignments .== k], "Class $k"; prompt=false) for k in 1:K]...
)

Now there is somewhat more consistency between images in the same class.
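With K = 9 clusters but only three digits, each cluster can be assigned the digit that occurs most often in it, and the resulting predictions compared with the true labels. The helper majority_map and the toy vectors below are hypothetical and illustrate this majority-vote mapping only; they are not part of the original demo.

```julia
# Map each cluster to the label that occurs most often in it (majority vote).
# Hypothetical helper for illustration.
function majority_map(assignments::AbstractVector{<:Integer}, labels::AbstractVector, K::Integer)
    mapping = Dict{Int, eltype(labels)}()
    for k in 1:K
        members = labels[assignments .== k]
        isempty(members) && continue          # skip empty clusters
        u = unique(members)
        mapping[k] = u[argmax([count(==(v), members) for v in u])]
    end
    return mapping
end

assign_toy = [1, 1, 2, 2, 3, 3, 3]   # toy cluster indices
labels_toy = [0, 0, 1, 1, 1, 3, 3]   # toy true digits
mapping = majority_map(assign_toy, labels_toy, 3)          # Dict(1 => 0, 2 => 1, 3 => 3)
predicted = [mapping[a] for a in assign_toy]
accuracy = count(predicted .== labels_toy) / length(labels_toy)  # 6/7
```

In the demo session, majority_map(r9.assignments, labels, 9) would apply the same mapping to the K = 9 result. Keep in mind that this score, like purity, tends to improve as K grows, so a higher value alone does not show that K = 9 captures the digits better.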

Reproducibility

This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
 "Julia Version 1.10.1"
 "Commit 7790d6f0641 (2024-02-13 20:41 UTC)"
 "Build Info:"
 "  Official https://julialang.org/ release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LIBM: libopenlibm"
 "  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)"
 "Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
 ""

And with the following package versions:

import Pkg; Pkg.status()
Status `~/work/book-la-demo/book-la-demo/docs/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [aaaa29a8] Clustering v0.15.7
  [35d6a980] ColorSchemes v3.24.0
  [3da002f7] ColorTypes v0.11.4
⌅ [c3611d14] ColorVectorSpace v0.9.10
  [717857b8] DSP v0.7.9
  [72c85766] Demos v0.1.0 `~/work/book-la-demo/book-la-demo`
  [e30172f5] Documenter v1.2.1
  [4f61f5a4] FFTViews v0.3.2
  [7a1cc6ca] FFTW v1.8.0
  [587475ba] Flux v0.14.12
⌅ [a09fc81d] ImageCore v0.9.4
  [71a99df6] ImagePhantoms v0.7.2
  [b964fa9f] LaTeXStrings v1.3.1
  [7031d0ef] LazyGrids v0.5.0
  [599c1a8e] LinearMapsAA v0.11.0
  [98b081ad] Literate v2.16.1
  [7035ae7a] MIRT v0.17.0
  [170b2178] MIRTjim v0.23.0
  [eb30cadb] MLDatasets v0.7.14
  [efe261a4] NFFT v0.13.3
  [6ef6ca0d] NMF v1.0.2
  [15e1cf62] NPZ v0.4.3
  [0b1bfda6] OneHotArrays v0.2.5
  [429524aa] Optim v1.9.2
  [91a5bcdd] Plots v1.40.1
  [f27b6e38] Polynomials v4.0.6
  [2913bbd2] StatsBase v0.34.2
  [d6d074c3] VideoIO v1.0.9
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg v1.10.0
  [9a3f8284] Random
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`

This page was generated using Literate.jl.