Low-Rank SURE

This example illustrates Stein's unbiased risk estimate (SURE) for parameter selection in low-rank matrix approximation, using the Julia language.

This page comes from a single Julia file: lr-sure.jl.


Setup

Add the Julia packages used in this demo. Change false to true in the following code block if you are using any of the following packages for the first time.

if false
    import Pkg
    Pkg.add([
        "InteractiveUtils"
        "LaTeXStrings"
        "LinearAlgebra"
        "MIRTjim"
        "Plots"
        "Random"
    ])
end

Tell Julia to use the following packages. Run Pkg.add() in the preceding code block first, if needed.

using InteractiveUtils: versioninfo
using LaTeXStrings
using LinearAlgebra: svd, svdvals, Diagonal, norm
using MIRTjim: prompt
using Plots: default, gui, plot, plot!, scatter!, savefig
using Random: seed!
default(); default(label="", markerstrokecolor=:auto, markersize=7,
    labelfontsize=20, tickfontsize=16, legendfontsize=17, widen=true)

The following line is helpful when running this jl-file as a script; it prompts the user to hit a key after each image is displayed.

isinteractive() && prompt(:prompt);

Generate data

Noiseless low-rank matrix and noisy data matrix

M, N = 100, 50 # problem size
seed!(0)
Ktrue = 5 # true rank (planted model)
X = svd(randn(M,Ktrue)).U * Diagonal(1:Ktrue) * svd(randn(Ktrue,N)).Vt
sig0 = 0.03 # noise standard deviation
Y = X + sig0 * randn(size(X)) # noisy
sy = svdvals(Y)
sx = svdvals(X)
sx[1:Ktrue]
5-element Vector{Float64}:
 5.0
 4.000000000000003
 3.0
 2.0
 0.9999999999999997
sy[1:Ktrue]
5-element Vector{Float64}:
 5.023759688280403
 4.051459702189543
 3.0061483656272916
 1.9979253120833067
 1.1124738202874183

Plot singular values

ps = plot(xaxis = (L"k", (1,N), [1, Ktrue, N]), yaxis = (L"σ", (0,5.5), 0:5))
scatter!(1:N, sy, color=:red, marker=:hexagon,
 label=L"\sigma_k(Y) \ \mathrm{noisy}")
scatter!(1:N, sx, color=:blue, label=L"\sigma_k(X) \ \mathrm{noiseless}")
[Figure: singular values σ_k(Y) (noisy) and σ_k(X) (noiseless) versus index k]
prompt()

# savefig(ps, "lr_sure1s.pdf")
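Why do the noisy singular values beyond k = Ktrue stay small? A standard random-matrix rule of thumb (background knowledge, not stated in this demo) is that the largest singular value of an M × N matrix with i.i.d. N(0, σ_0²) entries is roughly σ_0(√M + √N) when M and N are large. The sketch below checks that approximation numerically on a pure-noise matrix; the helper name `noise_floor` is hypothetical.

```julia
using LinearAlgebra: svdvals
using Random: MersenneTwister

# Hypothetical helper (not part of the demo): asymptotic approximation of the
# largest singular value of an M×N matrix with i.i.d. N(0, σ0^2) entries.
noise_floor(M, N, σ0) = σ0 * (sqrt(M) + sqrt(N))

M, N, σ0 = 100, 50, 0.03 # same sizes and noise level as in this demo
rng = MersenneTwister(0)
Z = σ0 * randn(rng, M, N) # pure-noise matrix
smax = maximum(svdvals(Z)) # actual largest noise singular value
ratio = smax / noise_floor(M, N, σ0) # typically a bit below 1
```

For the sizes used here, σ_0(√M + √N) = 0.03·(10 + √50) ≈ 0.51, which gives a rough sense of where the noise-dominated singular values of Y sit; treat it only as an approximation, since the planted signal also perturbs them.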

Low-rank approximation with various ranks

(U, sy, V) = svd(Y)
nrmse_K = zeros(N)
nrmsd_K = zeros(N)
nrmsd = (D) -> norm(D) / norm(Y) * 100
nrmse = (D) -> norm(D) / norm(X) * 100
for K in 1:N
    Xh = U[:,1:K] * Diagonal(sy[1:K]) * V[:,1:K]'
    nrmsd_K[K] = nrmsd(Xh - Y)
    nrmse_K[K] = nrmse(Xh - X)
end
nrmsd_K = [nrmsd(0 .- Y); nrmsd_K]
nrmse_K = [nrmse(0 .- X); nrmse_K]
klist = 0:N;
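Because the rank-K approximation keeps the K leading SVD components of Y, the residual Y − X̂_K consists exactly of the discarded components, so ‖X̂_K − Y‖_F² = Σ_{k>K} σ_k(Y)². Hence the NRMSD curve can be computed from the singular values alone, and it can only decrease as K grows. A minimal sketch checking this identity on a small random matrix (the helper `nrmsd_from_svals` is hypothetical, not part of this demo):

```julia
using LinearAlgebra: svd, norm, Diagonal

# Hypothetical helper: NRMSD (in %) of the rank-K truncated SVD of Y,
# computed only from the singular values sy of Y (norm(sy) == norm(Y)).
nrmsd_from_svals(sy, K) = sqrt(sum(abs2, sy[K+1:end])) / norm(sy) * 100

Y = randn(8, 5)
U, sy, V = svd(Y) # SVD objects can be destructured into U, S, V
K = 2
Xh = U[:, 1:K] * Diagonal(sy[1:K]) * V[:, 1:K]' # rank-K approximation
direct = norm(Xh - Y) / norm(Y) * 100 # NRMSD computed from the matrices
fast = nrmsd_from_svals(sy, K) # same quantity from the singular values
```

This is why minimizing the data-fit term alone would always favor the largest rank, which motivates a risk estimate such as SURE.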

Plot the normalized root mean-squared error (NRMSE) and difference (NRMSD) versus rank K

pk = plot( # legend=:outertop,
    xaxis = (L"K", (1,N), [0, 2, Ktrue, N]),
    yaxis = ("'Error' [%]", (0, 100), 0:20:100),
)
scatter!(klist, nrmse_K, color=:blue,
    label=L"\mathrm{NRMSE\ } ‖ \! \hat{X}_K - X \ ‖_{\mathrm{F}} / ‖X \ ‖_{\mathrm{F}} \cdot 100\%",
)
scatter!(klist, nrmsd_K, color=:red, marker=:diamond,
    label=L"\mathrm{NRMSD\ } ‖ \! \hat{X}_K - Y \ ‖_{\mathrm{F}} / ‖Y \ ‖_{\mathrm{F}} \cdot 100\%",
)
[Figure: NRMSE and NRMSD (%) versus rank K]
prompt()

# savefig(pk, "lr_sure1a.pdf")

Explore the nuclear-norm regularized version, which soft-thresholds the singular values of Y

soft = (s,β) -> max.(s-β,0) # soft threshold function
dsoft = (s,β) -> Float32.(s .> β) # "derivative" thereof
reglist = [range(0, 0.5, 11); 0.75:0.25:6]
Nr = length(reglist)
nrmse_reg = zeros(Nr)
nrmsd_reg = zeros(Nr)
for ir in 1:Nr
    reg = reglist[ir]
    Xh = U * Diagonal(soft.(sy,reg)) * V'
    nrmsd_reg[ir] = nrmsd(Xh - Y)
    nrmse_reg[ir] = nrmse(Xh - X)
end;
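The link to the nuclear norm is a standard result (the proximal operator of the nuclear norm): soft-thresholding the singular values of Y by β gives the unique minimizer of ½‖X − Y‖_F² + β‖X‖_*, where ‖X‖_* is the sum of the singular values of X. The sketch below checks this numerically by comparing the cost at the thresholded matrix with the cost at randomly perturbed matrices; `svt`, `nucnorm`, and `cost` are hypothetical helper names, and the ½ scaling of the data-fit term is an assumption of this sketch.

```julia
using LinearAlgebra: svd, svdvals, Diagonal, norm
using Random: MersenneTwister

# Hypothetical helper: singular value soft thresholding of Y by β
function svt(Y, β)
    F = svd(Y)
    return F.U * Diagonal(max.(F.S .- β, 0)) * F.Vt
end

nucnorm(X) = sum(svdvals(X)) # nuclear norm: sum of singular values
cost(X, Y, β) = 0.5 * norm(X - Y)^2 + β * nucnorm(X) # assumed cost function

rng = MersenneTwister(1)
Y = randn(rng, 6, 4)
β = 0.8
Xh = svt(Y, β)
c0 = cost(Xh, Y, β)
# cost at 200 small random perturbations of the thresholded solution
cperturbed = [cost(Xh + 1e-2 * randn(rng, 6, 4), Y, β) for _ in 1:200]
```

Because the cost is strongly convex, every perturbation should increase it; a numerical check like this is a sanity test, not a proof.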

Plot NRMSE and NRMSD versus the regularization parameter β

pb = plot(legend=:topleft, xaxis = (L"β", (0,6), 0:6),
    yaxis = ("'Error' [%]", (0, 100), 0:20:100))
scatter!(reglist, nrmse_reg, color=:blue,
    label=L"\mathrm{NRMSE\ } ‖ \! \hat{X}_{\beta} - X \ ‖_{\mathrm{F}} / ‖X \ ‖_{\mathrm{F}} \cdot 100\%",
#  label=L"\mathrm{NRMSE}", # book
)
scatter!(reglist, nrmsd_reg, color=:red, marker=:diamond,
    label=L"\mathrm{NRMSD\ } ‖ \! \hat{X}_{\beta} - Y \ ‖_{\mathrm{F}} / ‖Y \ ‖_{\mathrm{F}} \cdot 100\%",
#  label=L"\mathrm{NRMSD}", # book
)
[Figure: NRMSE and NRMSD (%) versus regularization parameter β]
prompt()

# savefig(pb, "lr_sure1b.pdf")

Explore SURE for selecting β

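\[\mathrm{SURE}(β) = ‖ \hat{X} - Y ‖_{\mathrm{F}}^2 - MN σ_0^2 + 2 σ_0^2 \left( |M - N| \sum_{i=1}^{\min(M,N)} \frac{h_i(σ_i;β)}{σ_i} + \sum_{i=1}^{\min(M,N)} \dot{h}_i(σ_i;β) + 2 \sum_{i \neq j}^{\min(M,N)} \frac{σ_i h_i(σ_i;β)}{σ_i^2 - σ_j^2} \right)\]

where σ_i = σ_i(Y), h_i(σ;β) = max(σ − β, 0) is the soft-thresholding function, ḣ_i is its derivative with respect to σ, and σ_0² is the noise variance.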
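Where does this formula come from? A brief sketch of the standard argument (background, not derived in this demo): write Y = X + Z, where Z has i.i.d. N(0, σ_0²) entries, and X̂ = f(Y). Expanding the squared error and taking expectations, using E‖Z‖_F² = MNσ_0², E⟨Z, X⟩ = 0, and Stein's lemma E⟨Z, f(Y)⟩ = σ_0² E[div f(Y)] (valid for weakly differentiable f, which covers soft thresholding):

```latex
\begin{aligned}
\|\hat{X} - X\|_{\mathrm{F}}^2
  &= \|\hat{X} - Y\|_{\mathrm{F}}^2 - \|Z\|_{\mathrm{F}}^2
     + 2\langle Z, \hat{X} \rangle - 2\langle Z, X \rangle, \\
\mathbb{E}\,\|\hat{X} - X\|_{\mathrm{F}}^2
  &= \mathbb{E}\!\left[ \|\hat{X} - Y\|_{\mathrm{F}}^2 - MN\sigma_0^2
     + 2\sigma_0^2 \, \operatorname{div} f(Y) \right].
\end{aligned}
```

The quantity inside the expectation is therefore an unbiased estimate of the risk. For a spectral estimator X̂ = U diag(h_i(σ_i;β)) V', a known divergence formula (stated here without proof) expresses div f(Y) in terms of the singular values of Y; that expression is the parenthesized term in SURE(β).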

  • sy: singular values of Y
  • reg: regularization parameter β
  • v0 = σ_0²: noise variance
  • M, N: dimensions of Y
function sure(sy, reg, v0, M, N)
    sh = soft.(sy, reg) # estimated (shrunk) singular values
    den = sy.^2 .- (sy.^2)' # den[i,j] = sy[i]^2 - sy[j]^2
    den[den .== 0] .= Inf # trick to drop the i == j terms (avoids divide by 0)
    cross = sum((sy .* sh) ./ den) # sum over i ≠ j of sy[i] * sh[i] / den[i,j]
    dvg = abs(M-N) * sum(sh ./ sy) + sum(dsoft.(sy, reg)) + 2 * cross # divergence term
    return norm(sh - sy)^2 - M*N*v0 + 2*v0*dvg
end
sure (generic function with 1 method)
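A quick sanity check of `sure`: at β = 0 the estimate is X̂ = Y, so the data-fit term vanishes and ḣ_i = 1. With n = min(M,N), the i ≠ j terms pair up as σ_i²/(σ_i² − σ_j²) + σ_j²/(σ_j² − σ_i²) = 1, so that sum equals n(n − 1)/2, and the divergence is |M − N|n + n + n(n − 1) = MN. Hence SURE(0) = MNσ_0², which is exactly the risk E‖Y − X‖_F² of the unregularized estimate. The sketch below re-defines `soft`, `dsoft`, and `sure` with formulas equivalent to the demo's so that it runs on its own; the sizes and noise variance are arbitrary choices.

```julia
using LinearAlgebra: svdvals, norm
using Random: MersenneTwister

# Equivalent to the demo's definitions (re-defined so this block is standalone)
soft = (s, β) -> max(s - β, 0) # soft threshold (scalar)
dsoft = (s, β) -> Float64(s > β) # its derivative (almost everywhere)

function sure(sy, reg, v0, M, N)
    sh = soft.(sy, reg)
    den = sy.^2 .- (sy.^2)'
    den[den .== 0] .= Inf # drop the i == j terms
    cross = sum((sy .* sh) ./ den)
    return norm(sh - sy)^2 - M*N*v0 +
        2*v0*(abs(M-N)*sum(sh ./ sy) + sum(dsoft.(sy, reg)) + 2*cross)
end

M, N, v0 = 12, 7, 0.04
sy = svdvals(randn(MersenneTwister(2), M, N))
s0 = sure(sy, 0.0, v0, M, N) # should equal M*N*v0 up to round-off
```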

Evaluate SURE for each candidate regularization parameter

sure_reg = [sure(sy, reglist[ir], sig0^2, M, N) for ir in 1:Nr]
reg_best = reglist[argmin(sure_reg)] # SURE pick for β
0.3
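SURE is unbiased in the sense that its expected value equals the expected squared error E‖X̂_β − X‖_F². A small Monte Carlo experiment can illustrate this for a fixed β: average SURE and the oracle squared error (which needs the true X) over many noise realizations. This is a sketch with arbitrary sizes, singular values, noise level, β, and trial count (all assumptions, not values from this demo), using formulas equivalent to the demo's `soft`, `dsoft`, and `sure`.

```julia
using LinearAlgebra: svd, Diagonal, norm
using Random: MersenneTwister

soft = (s, β) -> max(s - β, 0) # soft threshold
dsoft = (s, β) -> Float64(s > β) # its derivative (a.e.)
function sure(sy, reg, v0, M, N) # same formula as the demo
    sh = soft.(sy, reg)
    den = sy.^2 .- (sy.^2)'
    den[den .== 0] .= Inf
    cross = sum((sy .* sh) ./ den)
    return norm(sh - sy)^2 - M*N*v0 +
        2*v0*(abs(M-N)*sum(sh ./ sy) + sum(dsoft.(sy, reg)) + 2*cross)
end

rng = MersenneTwister(3)
M, N, K, σ0, β = 20, 10, 3, 0.1, 0.5 # assumed experiment settings
X = svd(randn(rng, M, K)).U * Diagonal([3.0, 2.0, 1.0]) * svd(randn(rng, K, N)).Vt
ntrial = 2000
s_sure = zeros(ntrial)
s_err = zeros(ntrial)
for t in 1:ntrial
    Y = X + σ0 * randn(rng, M, N) # new noise realization
    F = svd(Y)
    Xh = F.U * Diagonal(soft.(F.S, β)) * F.Vt # soft-thresholded estimate
    s_sure[t] = sure(F.S, β, σ0^2, M, N) # data-only risk estimate
    s_err[t] = norm(Xh - X)^2 # oracle squared error
end
mean_sure = sum(s_sure) / ntrial
mean_err = sum(s_err) / ntrial
```

The two averages should agree up to Monte Carlo error; a single SURE value can still differ noticeably from the squared error of that one realization.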

Plot NRMSE, NRMSD, and normalized SURE versus the regularization parameter β

psb = plot(legend=:bottomright, widen=true,
    xaxis = (L"β", (0,6), [reg_best, 5, 6]),
    yaxis = ("'Error' [%]", (0,100), 0:20:100),
)
scatter!(reglist, nrmse_reg, color=:blue,
    label=L"\mathrm{NRMSE\ } ‖ \! \hat{X}_\beta - X \ ‖_{\mathrm{F}} / ‖X \ ‖_{\mathrm{F}} \cdot 100\%",
#  label=L"\mathrm{NRMSE}", # book
)
scatter!(reglist, nrmsd_reg, color=:red, marker=:diamond,
    label=L"\mathrm{NRMSD\ } ‖ \! \hat{X}_\beta - Y \ ‖_{\mathrm{F}} / ‖Y \ ‖_{\mathrm{F}} \cdot 100\%",
#  label=L"\mathrm{NRMSD}", # book
)
scatter!(reglist, sqrt.(sure_reg)/norm(Y)*100, color=:green, marker=:star,
    label=L"(\mathrm{SURE}(\beta))^{1/2} / ‖Y \ ‖_{\mathrm{F}} \cdot 100\%")
[Figure: NRMSE, NRMSD, and (SURE(β))^{1/2}/‖Y‖_F (%) versus β]
prompt()

# savefig(psb, "lr_sure1c.pdf")

Examine the shrunk singular values for the SURE-selected regularization parameter

sh = soft.(sy,reg_best)
psk = plot(
    xaxis = (L"k", (1, N), [1, Ktrue, sum(sh .!= 0), N]),
    yaxis = (L"σ", (0, 5.5), 0:6),
    legendfontsize = 20,
)
scatter!(1:N, sy, color=:red, marker=:hexagon, label=L"\sigma(Y) \ \mathrm{noisy}")
scatter!(1:N, sx, color=:blue, label=L"\sigma(X) \ \mathrm{noiseless}")
scatter!(1:N, sh, color=:green, marker=:star, label=L"\hat{\sigma} \ \ \mathrm{SURE} \ \hat{\beta}")
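[Figure: singular values of Y (noisy), of X (noiseless), and shrunk singular values for the SURE-selected β, versus index k]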
prompt()

# savefig(psk, "lr_sure1t.pdf")
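Since soft thresholding zeroes every singular value at or below β, the rank of X̂_β is simply the number of singular values of Y that exceed β; this is the quantity `sum(sh .!= 0)` used as a tick on the horizontal axis above. A short sketch (the helper name `svt_rank` is hypothetical) showing that this rank can only decrease as β grows:

```julia
using LinearAlgebra: svdvals

# Hypothetical helper: rank of the soft-thresholded estimate for threshold β
svt_rank(sy, β) = count(>(β), sy)

sy = svdvals(randn(30, 15))
βs = 0:0.25:8
ranks = [svt_rank(sy, β) for β in βs] # nonincreasing in β
```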

Reproducibility

This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
 "Julia Version 1.10.1"
 "Commit 7790d6f0641 (2024-02-13 20:41 UTC)"
 "Build Info:"
 "  Official https://julialang.org/ release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LIBM: libopenlibm"
 "  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)"
 "Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
 ""

And with the following package versions:

import Pkg; Pkg.status()
Status `~/work/book-la-demo/book-la-demo/docs/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [aaaa29a8] Clustering v0.15.7
  [35d6a980] ColorSchemes v3.24.0
  [3da002f7] ColorTypes v0.11.4
⌅ [c3611d14] ColorVectorSpace v0.9.10
  [717857b8] DSP v0.7.9
  [72c85766] Demos v0.1.0 `~/work/book-la-demo/book-la-demo`
  [e30172f5] Documenter v1.2.1
  [4f61f5a4] FFTViews v0.3.2
  [7a1cc6ca] FFTW v1.8.0
  [587475ba] Flux v0.14.12
⌅ [a09fc81d] ImageCore v0.9.4
  [71a99df6] ImagePhantoms v0.7.2
  [b964fa9f] LaTeXStrings v1.3.1
  [7031d0ef] LazyGrids v0.5.0
  [599c1a8e] LinearMapsAA v0.11.0
  [98b081ad] Literate v2.16.1
  [7035ae7a] MIRT v0.17.0
  [170b2178] MIRTjim v0.23.0
  [eb30cadb] MLDatasets v0.7.14
  [efe261a4] NFFT v0.13.3
  [6ef6ca0d] NMF v1.0.2
  [15e1cf62] NPZ v0.4.3
  [0b1bfda6] OneHotArrays v0.2.5
  [429524aa] Optim v1.9.2
  [91a5bcdd] Plots v1.40.1
  [f27b6e38] Polynomials v4.0.6
  [2913bbd2] StatsBase v0.34.2
  [d6d074c3] VideoIO v1.0.9
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg v1.10.0
  [9a3f8284] Random
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`

This page was generated using Literate.jl.