Low-rank matrix completion: ADMM

This example illustrates low-rank matrix completion using the Julia language.

History:

  • 2017-11-07 Greg Ongie, University of Michigan, original version
  • 2017-11-12 Jeff Fessler, minor modifications
  • 2021-08-23 Julia 1.6.2
  • 2023-04-11 Literate version for Julia 1.8

This page comes from a single Julia file: lrmc3.jl.


Setup

Add the Julia packages used in this demo. Change false to true in the following code block if you are using any of the following packages for the first time.

if false
    import Pkg
    Pkg.add([
        "DelimitedFiles"
        "Downloads"
        "InteractiveUtils"
        "LaTeXStrings"
        "LinearAlgebra"
        "MIRT"
        "MIRTjim"
        "Plots"
    ])
end

Tell Julia to use the following packages. Run Pkg.add() in the preceding code block first, if needed.

using DelimitedFiles: readdlm
using Downloads: download
using InteractiveUtils: versioninfo
using LaTeXStrings
using LinearAlgebra: svd, svdvals, Diagonal, norm
using MIRT: pogm_restart
using MIRTjim: jim, prompt
using Plots: plot, scatter, scatter!, savefig, default
default(); default(markerstrokecolor=:auto, label = "")

The following line is helpful when running this file as a script: it prompts the user to press a key after each figure is displayed.

isinteractive() && prompt(:prompt);

TOP SECRET

On 2017-11-06 10:34:12 GMT, Agent 556 obtained a photo of the illegal arms dealer, code name Professor X-Ray. However, Agent 556 spilled soup on their lapel-camera, shorting out several CCD sensors. The image matrix has many missing entries; we were able to recover only 25% of the data!

Agent 551: Your mission, should you choose to accept it, is to recover the missing entries and uncover the true identity of Professor X-Ray.

Read in data with missing pixels set to zero

if !@isdefined(Y)
    tmp = homedir() * "/web/course/551/julia/demo/professorxray.txt" # jf
    if !isfile(tmp)
        url = "https://web.eecs.umich.edu/~fessler/course/551/julia/demo/professorxray.txt"
        tmp = download(url)
    end
    Y = collect(readdlm(tmp)')
    py = jim(Y, "Y: Corrupted image matrix of Professor X-Ray\n (missing pixels set to 0)")
end

Create binary mask $Ω$ (true=observed, false=unobserved)

Ω = Y .!= 0
percent_nonzero = sum(Ω) / length(Ω) # fraction of observed (nonzero) entries
0.25

Show mask

pm = jim(Ω, "Ω: Locations of observed entries")

Low-rank approximation

A simple low-rank approximation (a truncated SVD of the zero-filled matrix) works poorly when this much data is missing.

r = 20
U,s,V = svd(Y)
Xr = U[:,1:r] * Diagonal(s[1:r]) * V[:,1:r]'
pr = jim(Xr, "Low-rank approximation for r=$r")

Low-rank matrix completion

Instead, we will try to uncover the identity of Professor X-Ray using low-rank matrix completion.

The optimization problem we will solve is:

\[\min_{\mathbf X} \frac{1}{2} ‖ P_Ω(\mathbf X) - P_Ω(\mathbf Y) ‖_2^2 + β ‖ \mathbf X ‖_* \quad\quad\text{(NN-min)}\]

where $\mathbf Y$ is the zero-filled input data matrix, and $P_Ω$ is the operator that extracts a vector of entries belonging to the index set $Ω$.
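In Julia, $P_Ω$ can be realized by logical indexing with a Boolean mask. The following is a minimal sketch on a made-up 3×3 matrix; the names `Xtoy`, `Ωtoy`, `Ytoy`, and `PΩ` are hypothetical and not part of the demo.

```julia
using LinearAlgebra: norm, svdvals

# Hypothetical toy data (not the Professor X-Ray image)
Xtoy = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
Ωtoy = [true false true; false true false; false false true] # observed entries

# P_Ω(X): logical indexing returns the observed entries as a vector,
# taken in column-major order
PΩ = X -> X[Ωtoy]
PΩ(Xtoy) # [1.0, 5.0, 3.0, 9.0]

# Zero-filled measurements, analogous to Y in this demo
Ytoy = Ωtoy .* Xtoy

# Both terms of (NN-min), evaluated at X = Ytoy with an arbitrary β
βtoy = 0.1
datafit = 0.5 * norm(PΩ(Ytoy) - PΩ(Xtoy))^2 # 0: Ytoy agrees with Xtoy on Ω
penalty = βtoy * sum(svdvals(Ytoy))          # β times the nuclear norm
```

The data-fit term only ever looks at entries in $Ω$, so any matrix that agrees with the data there has zero data-fit cost; the nuclear norm term is what selects among those candidates.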

Define cost function for optimization problem:

nucnorm = (X) -> sum(svdvals(X))
costfun = (X, beta) -> 0.5 * norm(X[Ω] - Y[Ω])^2 + beta * nucnorm(X);

Define singular value soft thresholding (SVST) function

function SVST(X, beta)
    U,s,V = svd(X)
    sthresh = @. max(s - beta, 0)
    return U * Diagonal(sthresh) * V'
end;
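As a quick sanity check of what SVST does, the sketch below applies the same operation to a small matrix whose singular values are known. The helper `svst_demo` is a hypothetical stand-alone copy of the function above, repeated only so the snippet runs on its own.

```julia
using LinearAlgebra: svd, svdvals, Diagonal

# Stand-alone copy of the soft-thresholding step (hypothetical name)
function svst_demo(X, beta)
    F = svd(X)
    return F.U * Diagonal(max.(F.S .- beta, 0)) * F.Vt
end

A = [3.0 0.0; 0.0 1.0; 0.0 0.0] # singular values are 3 and 1
B = svst_demo(A, 2.0)           # shrink each singular value by 2, floor at 0
svdvals(B)                      # ≈ [1.0, 0.0]
```

Singular values below the threshold are set to zero, which is how the thresholding step lowers the rank.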

Iterative Soft-Thresholding Algorithm (ISTA)

ISTA extends gradient descent to convex cost functions of the form $\min_x f(x) + g(x)$, where $f(x)$ is smooth and $g(x)$ is non-smooth. It is also known as the proximal gradient method.

ISTA algorithm for solving (NN-min):

  • initialize $\mathbf X_0 = \mathbf Y$ (zero-fill missing entries)

  • for $k=0,1,2,…$

    • $[\hat{\mathbf X}_k]_{i,j} = \begin{cases}[\mathbf X_k]_{i,j} & (i,j) ∉ Ω \\ [\mathbf Y]_{i,j} & (i,j) ∈ Ω \end{cases}$ (Put back in known entries)

    • $\mathbf X_{k+1} = \text{SVST}(\hat{\mathbf X}_k,β)$ (Singular value soft-thresholding)

  • end
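The two steps above are one proximal gradient iteration with step size 1. A short derivation, writing $\mathbf M$ for the binary mask of $Ω$ and $⊙$ for the element-wise product (notation introduced here for illustration):

\[f(\mathbf X) = \frac{1}{2} ‖ P_Ω(\mathbf X) - P_Ω(\mathbf Y) ‖_2^2 \quad\Longrightarrow\quad ∇ f(\mathbf X) = \mathbf M ⊙ (\mathbf X - \mathbf Y)\]

\[[\mathbf X_k - ∇ f(\mathbf X_k)]_{i,j} = \begin{cases}[\mathbf X_k]_{i,j} & (i,j) ∉ Ω \\ [\mathbf Y]_{i,j} & (i,j) ∈ Ω \end{cases} = [\hat{\mathbf X}_k]_{i,j}\]

The gradient has Lipschitz constant 1, because masking cannot increase the norm of a difference, so a unit step size is admissible. The proximal operator of $β ‖ \cdot ‖_*$ is SVST with threshold $β$, which gives the second step.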

Apply ISTA:

niter = 400
beta = 0.01 # chosen by trial-and-error here
function lrmc_ista(Y)
    X = copy(Y) # initialize with the zero-filled data
    cost_ista = zeros(niter+1)
    cost_ista[1] = costfun(X,beta)
    for k in 1:niter
        X[Ω] = Y[Ω] # put back known entries
        X = SVST(X,beta) # singular value soft-thresholding
        cost_ista[k+1] = costfun(X,beta)
    end
    return X, cost_ista
end;

if !@isdefined(Xista)
    Xista, cost_ista = lrmc_ista(Y)
    pj_ista = jim(Xista, "ISTA result at $niter iterations")
end

What went wrong? Let's investigate. First, let's see if the above solution is actually low-rank.

s_ista = svdvals(Xista)
s0 = svdvals(Y)
plot(title = "singular values",
    xtick = [1, count(>(20*eps()), s_ista), minimum(size(Y))])
scatter!(s0, color=:black, label="Y (initialization)")
scatter!(s_ista, color=:red, label="X (ISTA)")
prompt()

Now let's check the cost function descent:

scatter(cost_ista, color=:red,
    title = "cost vs. iteration",
    xlabel = "iteration",
    ylabel = "cost function value",
    label = "ISTA",
)
prompt()

Fast Iterative Soft-Thresholding Algorithm (FISTA)

FISTA is a modification of ISTA that adds Nesterov acceleration (momentum) for faster convergence.

Reference:

FISTA algorithm for solving (NN-min)

  • initialize matrices $\mathbf Z_0 = \mathbf X_0 = \mathbf Y$ and scalar $t_0 = 1$

  • for $k=0,1,2,…$

    • $[\hat{\mathbf Z}_k]_{i,j} = \begin{cases}[\mathbf Z_k]_{i,j} & (i,j) ∉ Ω \\ [\mathbf Y]_{i,j} & (i,j) ∈ Ω \end{cases}$ (Put back in known entries)

    • $\mathbf X_{k+1} = \text{SVST}(\hat{\mathbf Z}_k,\beta)$

    • $t_{k+1} = \frac{1 + \sqrt{1+4t_k^2}}{2}$ (Nesterov momentum parameter)

    • $\mathbf Z_{k+1} = \mathbf X_{k+1} + \frac{t_k-1}{t_{k+1}}(\mathbf X_{k+1}-\mathbf X_{k})$ (Momentum update)

  • end
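To see how much momentum the recursion above applies, the sketch below tabulates the coefficient $(t_k-1)/t_{k+1}$ for the first iterations, starting from $t_0 = 1$. The function name `fista_momentum` is hypothetical and used only for illustration.

```julia
# Momentum coefficients (t_k - 1) / t_{k+1} from the FISTA recursion
function fista_momentum(niter::Int)
    t = 1.0 # t_0
    coef = zeros(niter)
    for k in 1:niter
        tnew = (1 + sqrt(1 + 4t^2)) / 2
        coef[k] = (t - 1) / tnew # weight on (X_{k+1} - X_k)
        t = tnew
    end
    return coef
end

coef = fista_momentum(50)
coef[1] # 0.0: the first iteration is a plain ISTA step
```

The coefficient starts at zero and increases toward (but never reaches) 1, so later iterations extrapolate more aggressively along the direction of the most recent change.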

Run FISTA:

niter = 200
function lrmc_fista(Y)
    X = copy(Y)
    Z = copy(X)
    Xold = copy(X)
    told = 1
    cost_fista = zeros(niter+1)
    cost_fista[1] = costfun(X,beta)
    for k in 1:niter
        Z[Ω] = Y[Ω]
        X = SVST(Z,beta)
        t = (1 + sqrt(1+4*told^2))/2
        Z = X + ((told-1)/t)*(X-Xold)
        Xold = X
        told = t
        cost_fista[k+1] = costfun(X,beta) # comment out to speed-up
    end
    return X, cost_fista
end;

if !@isdefined(Xfista)
    Xfista, cost_fista = lrmc_fista(Y)
    pj_fista = jim(Xfista, "FISTA result at $niter iterations")
end
plot(title = "cost vs. iteration",
    xlabel="iteration", ylabel = "cost function value")
scatter!(cost_ista, label="ISTA", color=:red)
scatter!(cost_fista, label="FISTA", color=:blue)
prompt()

See if the FISTA result is "low rank"

s_fista = svdvals(Xfista)
effective_rank = count(>(0.01*s_fista[1]), s_fista)
19
ps = plot(title="singular values",
    xtick = [1, effective_rank, count(>(20*eps()), s_fista), minimum(size(Y))])
scatter!(s0, label="Y (initial)", color=:black)
scatter!(s_fista, label="X (FISTA)", color=:blue)
prompt()

Exercise: think about why $σ_1(\mathbf X) > σ_1(\mathbf Y)$!

Alternating directions method of multipliers (ADMM)

ADMM is another approach that uses SVST as a sub-routine, closely related to proximal gradient descent.

With a suitable choice of its tuning parameter $μ$, it converges faster than FISTA. (Here we use $μ = β$.)

References:

Run alternating directions method of multipliers (ADMM) algorithm:

niter = 50
50

Choice of parameter $μ$ can greatly affect convergence rate
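The updates in the code that follows can be derived from a variable split. This is a sketch in scaled-dual form, writing $\mathbf M$ for the binary mask of $Ω$, $⊘$ for element-wise division, and $\mathbf L$ for the scaled dual variable (notation introduced here for illustration):

\[\min_{\mathbf X, \mathbf Z} \frac{1}{2} ‖ P_Ω(\mathbf X) - P_Ω(\mathbf Y) ‖_2^2 + β ‖ \mathbf Z ‖_* \quad\text{subject to}\quad \mathbf X = \mathbf Z\]

\[\begin{aligned}
\mathbf Z_{k+1} &= \arg\min_{\mathbf Z} β ‖ \mathbf Z ‖_* + \frac{μ}{2} ‖ \mathbf X_k - \mathbf Z + \mathbf L_k ‖_F^2 = \text{SVST}(\mathbf X_k + \mathbf L_k, β/μ) \\
\mathbf X_{k+1} &= \arg\min_{\mathbf X} \frac{1}{2} ‖ P_Ω(\mathbf X) - P_Ω(\mathbf Y) ‖_2^2 + \frac{μ}{2} ‖ \mathbf X - \mathbf Z_{k+1} + \mathbf L_k ‖_F^2 = (\mathbf Y + μ (\mathbf Z_{k+1} - \mathbf L_k)) ⊘ (μ + \mathbf M) \\
\mathbf L_{k+1} &= \mathbf L_k + \mathbf X_{k+1} - \mathbf Z_{k+1}
\end{aligned}\]

The $\mathbf X$ update is an element-wise quadratic minimization; its closed form uses the fact that $\mathbf Y$ is zero outside $Ω$, so the masked data equals $\mathbf Y$ itself.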

function lrmc_admm(Y; mu::Real = beta)
    X = copy(Y)
    Z = zeros(size(X))
    L = zeros(size(X))
    cost_admm = zeros(niter+1)
    cost_admm[1] = costfun(X,beta)
    for k in 1:niter
        Z = SVST(X + L, beta / mu)
        X = (Y + mu * (Z - L)) ./ (mu .+ Ω)
        L = L + X - Z
        cost_admm[k+1] = costfun(X,beta) # comment out to speed-up
    end
    return X, cost_admm
end;

if !@isdefined(Xadmm)
    Xadmm, cost_admm = lrmc_admm(Y)
    pj_admm = jim(Xadmm, "ADMM result at $niter iterations")
end
pc = plot(title = "cost vs. iteration",
    xtick = [0, 50, 200, 400],
    xlabel = "iteration", ylabel = "cost function value")
scatter!(0:400, cost_ista, label="ISTA", color=:red)
scatter!(0:200, cost_fista, label="FISTA", color=:blue)
scatter!(0:niter, cost_admm, label="ADMM", color=:magenta)
prompt()

All singular values

s_admm = svdvals(Xadmm)
scatter!(ps, s_admm, label="X (ADMM)", color=:magenta, marker=:square)
prompt()

For a suitable choice of $μ$, ADMM converges faster than FISTA.

Proximal optimized gradient method (POGM)

The proximal optimized gradient method (POGM) with adaptive restart is faster than FISTA with very similar computation per iteration. Unlike ADMM, POGM does not require any algorithm tuning parameter $μ$, making it easier to use in many practical composite optimization problems.
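The POGM call that follows needs the gradient of the smooth data-fit term and a Lipschitz constant for that gradient. As a hedged sanity check, the sketch below compares the masked-residual gradient against a central finite difference on a made-up 3×2 problem; all names (`Ωs`, `Ys`, `fs`, `gs`, `X1`, `D`) are hypothetical and independent of the demo variables.

```julia
using LinearAlgebra: norm, dot

# Hypothetical small problem (not the demo data)
Ωs = [true false; true true; false true]
Ys = Ωs .* [1.0 2.0; 3.0 4.0; 5.0 6.0] # zero-filled data
fs = X -> 0.5 * norm(X[Ωs] - Ys[Ωs])^2 # smooth data-fit term
gs = X -> Ωs .* (X - Ys)               # candidate gradient

X1 = [0.5 -1.0; 2.0 0.0; 1.0 3.0]
D = [1.0 0.0; -2.0 0.5; 0.0 1.0]       # arbitrary direction
h = 1e-6
fd = (fs(X1 + h*D) - fs(X1 - h*D)) / (2h) # central difference along D
dd = dot(gs(X1), D)                       # directional derivative from gs

# Masking cannot increase a norm, so the gradient is 1-Lipschitz
X2 = X1 + D
lipschitz_ok = norm(gs(X2) - gs(X1)) <= norm(X2 - X1)
```

Here `fd` and `dd` should agree to within rounding error, because the data-fit term is quadratic and the central difference is then exact in exact arithmetic.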

if !@isdefined(Xpogm)
    Fcost = X -> costfun(X, beta)
    f_grad = X -> Ω .* (X - Y) # gradient of smooth term
    f_L = 1 # Lipschitz constant of f_grad
    g_prox = (X, c) -> SVST(X, c * beta)
    fun = (iter, xk, yk, is_restart) -> (xk, Fcost(xk), is_restart)
    niter = 150
    Xpogm, out = pogm_restart(Y, Fcost, f_grad, f_L; g_prox, fun, niter)
    cost_pogm = [o[2] for o in out]
    pj_pogm = jim(Xpogm, "POGM result at $niter iterations")
end
scatter!(pc, 0:niter, cost_pogm, label="POGM", color=:green)
prompt()

Reproducibility

This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
 "Julia Version 1.10.1"
 "Commit 7790d6f0641 (2024-02-13 20:41 UTC)"
 "Build Info:"
 "  Official https://julialang.org/ release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LIBM: libopenlibm"
 "  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)"
 "Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
 ""

And with the following package versions

import Pkg; Pkg.status()
Status `~/work/book-la-demo/book-la-demo/docs/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [aaaa29a8] Clustering v0.15.7
  [35d6a980] ColorSchemes v3.24.0
  [3da002f7] ColorTypes v0.11.4
⌅ [c3611d14] ColorVectorSpace v0.9.10
  [717857b8] DSP v0.7.9
  [72c85766] Demos v0.1.0 `~/work/book-la-demo/book-la-demo`
  [e30172f5] Documenter v1.2.1
  [4f61f5a4] FFTViews v0.3.2
  [7a1cc6ca] FFTW v1.8.0
  [587475ba] Flux v0.14.12
⌅ [a09fc81d] ImageCore v0.9.4
  [71a99df6] ImagePhantoms v0.7.2
  [b964fa9f] LaTeXStrings v1.3.1
  [7031d0ef] LazyGrids v0.5.0
  [599c1a8e] LinearMapsAA v0.11.0
  [98b081ad] Literate v2.16.1
  [7035ae7a] MIRT v0.17.0
  [170b2178] MIRTjim v0.23.0
  [eb30cadb] MLDatasets v0.7.14
  [efe261a4] NFFT v0.13.3
  [6ef6ca0d] NMF v1.0.2
  [15e1cf62] NPZ v0.4.3
  [0b1bfda6] OneHotArrays v0.2.5
  [429524aa] Optim v1.9.2
  [91a5bcdd] Plots v1.40.1
  [f27b6e38] Polynomials v4.0.6
  [2913bbd2] StatsBase v0.34.2
  [d6d074c3] VideoIO v1.0.9
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg v1.10.0
  [9a3f8284] Random
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`

This page was generated using Literate.jl.