Low-rank matrix completion: ADMM
This example illustrates low-rank matrix completion using the Julia language.
History:
- 2017-11-07 Greg Ongie, University of Michigan, original version
- 2017-11-12 Jeff Fessler, minor modifications
- 2021-08-23 Julia 1.6.2
- 2023-04-11 Literate version for Julia 1.8
This page comes from a single Julia file: lrmc3.jl.
You can access the source code for such Julia documentation using the 'Edit on GitHub' link in the top right. You can view the corresponding notebook in nbviewer here: lrmc3.ipynb, or open it in binder here: lrmc3.ipynb.
Setup
Add the Julia packages used in this demo. Change false to true in the following code block if you are using any of the following packages for the first time.
if false
import Pkg
Pkg.add([
"DelimitedFiles"
"Downloads"
"InteractiveUtils"
"LaTeXStrings"
"LinearAlgebra"
"MIRT"
"MIRTjim"
"Plots"
])
end
Tell Julia to use the following packages. Run Pkg.add() in the preceding code block first, if needed.
using DelimitedFiles: readdlm
using Downloads: download
using InteractiveUtils: versioninfo
using LaTeXStrings
using LinearAlgebra: svd, svdvals, Diagonal, norm
using MIRT: pogm_restart
using MIRTjim: jim, prompt
using Plots: plot, scatter, scatter!, savefig, default
default(); default(markerstrokecolor=:auto, label = "")
The following line is helpful when running this jl-file as a script; it prompts the user to hit a key after each image is displayed.
isinteractive() && prompt(:prompt);
TOP SECRET
On 2017-11-06 10:34:12 GMT, Agent 556 obtained a photo of the illegal arms dealer, code name Professor X-Ray. However, Agent 556 spilled soup on their lapel-camera, shorting out several CCD sensors. The image matrix has several missing entries; we were able to recover only 25% of the data!
Agent 551: Your mission, should you choose to accept it, is to recover the missing entries and uncover the true identity of Professor X-Ray.
Read in data with missing pixels set to zero
if !@isdefined(Y)
tmp = homedir() * "/web/course/551/julia/demo/professorxray.txt" # jf
if !isfile(tmp)
url = "https://web.eecs.umich.edu/~fessler/course/551/julia/demo/professorxray.txt"
tmp = download(url)
end
Y = collect(readdlm(tmp)')
py = jim(Y, "Y: Corrupted image matrix of Professor X-Ray\n (missing pixels set to 0)")
end
Create binary mask $Ω$ (true=observed, false=unobserved)
Ω = Y .!= 0
percent_nonzero = sum(Ω) / length(Ω) # proportion of observed (nonzero) entries
0.25
Show mask
pm = jim(Ω, "Ω: Locations of observed entries")
Low-rank approximation
A simple low-rank approximation works poorly for this much missing data
r = 20
U,s,V = svd(Y)
Xr = U[:,1:r] * Diagonal(s[1:r]) * V[:,1:r]'
pr = jim(Xr, "Low-rank approximation for r=$r")
Low-rank matrix completion
Instead, we will try to uncover the identity of Professor X-Ray using low-rank matrix completion.
The optimization problem we will solve is:
\[\min_{\mathbf X} \frac{1}{2} ‖ P_Ω(\mathbf X) - P_Ω(\mathbf Y) ‖_2^2 + β ‖ \mathbf X ‖_* \quad\quad\text{(NN-min)}\]
where $\mathbf Y$ is the zero-filled input data matrix, and $P_Ω$ is the operator that extracts a vector of entries belonging to the index set $Ω$.
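In the code that follows, $P_Ω$ is implemented simply by logical indexing with the boolean mask $Ω$. As a small illustration (the name y_obs below is used only for this check), Y[Ω] extracts the vector of observed entries:
y_obs = Y[Ω] # P_Ω(Y): vector of observed entries
length(y_obs) == sum(Ω) # one entry per observed pixel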
Define cost function for optimization problem:
nucnorm = (X) -> sum(svdvals(X))
costfun = (X, beta) -> 0.5 * norm(X[Ω] - Y[Ω])^2 + beta * nucnorm(X);
Define singular value soft thresholding (SVST) function
function SVST(X, beta)
U,s,V = svd(X)
sthresh = @. max(s - beta, 0)
return U * Diagonal(sthresh) * V'
end;
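As a quick sanity check (an illustrative example with an arbitrary small matrix Atest, not part of the demo data), SVST should shrink every singular value by the threshold and clip at zero:
Atest = [1.0 2; 3 4; 5 6] # arbitrary small test matrix
σ_before = svdvals(Atest)
σ_after = svdvals(SVST(Atest, 1.0))
σ_after ≈ max.(σ_before .- 1.0, 0) # expect: true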
Iterative Soft-Thresholding Algorithm (ISTA)
ISTA is an extension of gradient descent to convex cost functions of the form $\min_x f(x) + g(x)$, where $f(x)$ is smooth and $g(x)$ is non-smooth; it is also known as a proximal gradient method.
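For reference, the generic proximal gradient update for $\min_{\mathbf x} f(\mathbf x) + g(\mathbf x)$, when $∇ f$ is $L$-Lipschitz, is
\[\mathbf x_{k+1} = \text{prox}_{g/L}\big( \mathbf x_k - \tfrac{1}{L} ∇ f(\mathbf x_k) \big),
\qquad
\text{prox}_{h}(\mathbf v) = \arg\min_{\mathbf x} \frac{1}{2} ‖ \mathbf x - \mathbf v ‖_2^2 + h(\mathbf x),\]
and the proximal operator of the nuclear norm is exactly the SVST operation defined above: $\text{prox}_{β ‖\cdot‖_*}(\mathbf V) = \text{SVST}(\mathbf V, β)$.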
ISTA algorithm for solving (NN-min):
initialize $\mathbf X_0 = \mathbf Y$ (zero-fill missing entries)
for $k=0,1,2,…$
$[\hat{\mathbf X}_k]_{i,j} = \begin{cases}[\mathbf X_k]_{i,j} & (i,j) ∉ Ω \\ [\mathbf Y]_{i,j} & (i,j) ∈ Ω \end{cases}$ (Put back the known entries)
$\mathbf X_{k+1} = \text{SVST}(\hat{\mathbf X}_k,β)$ (Singular value soft-thresholding)
end
The first step is a gradient step with unit step size: the gradient of the data-fit term is $P_Ω^*(P_Ω(\mathbf X) - P_Ω(\mathbf Y))$, which has Lipschitz constant 1, and subtracting it from $\mathbf X_k$ simply restores the observed entries of $\mathbf Y$.
Apply ISTA:
niter = 400
beta = 0.01 # chosen by trial-and-error here
function lrmc_ista(Y)
X = copy(Y)
Xold = copy(X)
cost_ista = zeros(niter+1)
cost_ista[1] = costfun(X,beta)
for k in 1:niter
X[Ω] = Y[Ω]
X = SVST(X,beta)
cost_ista[k+1] = costfun(X,beta)
end
return X, cost_ista
end;
if !@isdefined(Xista)
Xista, cost_ista = lrmc_ista(Y)
pj_ista = jim(Xista, "ISTA result at $niter iterations")
end
What went wrong? Let's investigate. First, let's see if the above solution is actually low-rank.
s_ista = svdvals(Xista)
s0 = svdvals(Y)
plot(title = "singular values",
xtick = [1, sum(s0 .> 20*eps()), minimum(size(Y))])
scatter!(s0, color=:black, label="Y (initialization)")
scatter!(s_ista, color=:red, label="X (ISTA)")
prompt()
Now let's check the cost function descent:
scatter(cost_ista, color=:red,
title = "cost vs. iteration",
xlabel = "iteration",
ylabel = "cost function value",
label = "ISTA",
)
prompt()
Fast Iterative Soft-Thresholding Algorithm (FISTA)
Modification of ISTA that includes Nesterov acceleration for faster convergence.
Reference:
- Beck, A. and Teboulle, M., 2009. A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM J. on Imaging Sciences, 2(1), pp.183-202.
FISTA algorithm for solving (NN-min)
initialize matrices $\mathbf Z_0 = \mathbf X_0 = \mathbf Y$
for $k=0,1,2,…$
$[\hat{\mathbf Z}_k]_{i,j} = \begin{cases}[\mathbf Z_k]_{i,j} & (i,j) ∉ Ω \\ [\mathbf Y]_{i,j} & (i,j) ∈ Ω \end{cases}$ (Put back the known entries)
$\mathbf X_{k+1} = \text{SVST}(\hat{\mathbf Z}_k,\beta)$
$t_{k+1} = \frac{1 + \sqrt{1+4t_k^2}}{2}$ (Nesterov step-size)
$\mathbf Z_{k+1} = \mathbf X_{k+1} + \frac{t_k-1}{t_{k+1}}(\mathbf X_{k+1}-\mathbf X_{k})$ (Momentum update)
end
Run FISTA:
niter = 200
function lrmc_fista(Y)
X = copy(Y)
Z = copy(X)
Xold = copy(X)
told = 1
cost_fista = zeros(niter+1)
cost_fista[1] = costfun(X,beta)
for k in 1:niter
Z[Ω] = Y[Ω]
X = SVST(Z,beta)
t = (1 + sqrt(1+4*told^2))/2
Z = X + ((told-1)/t)*(X-Xold)
Xold = X
told = t
cost_fista[k+1] = costfun(X,beta) # comment out to speed-up
end
return X, cost_fista
end;
if !@isdefined(Xfista)
Xfista, cost_fista = lrmc_fista(Y)
pj_fista = jim(Xfista, "FISTA result at $niter iterations")
end
plot(title = "cost vs. iteration",
xlabel="iteration", ylabel = "cost function value")
scatter!(cost_ista, label="ISTA", color=:red)
scatter!(cost_fista, label="FISTA", color=:blue)
prompt()
See if the FISTA result is "low rank"
s_fista = svdvals(Xfista)
effective_rank = count(>(0.01*s_fista[1]), s_fista)
19
ps = plot(title="singular values",
xtick = [1, effective_rank, count(>(20*eps()), s_fista), minimum(size(Y))])
scatter!(s0, label="Y (initial)", color=:black)
scatter!(s_fista, label="X (FISTA)", color=:blue)
prompt()
Exercise: think about why $σ_1(X) > σ_1(Y)$!
Alternating direction method of multipliers (ADMM)
ADMM is another approach that uses SVST as a sub-routine; it is closely related to proximal gradient methods.
It can converge faster than FISTA, but it requires a tuning parameter $μ$. (Here we use $μ = β$.)
References:
- Cai, J.F., Candès, E.J. and Shen, Z., 2010. A singular value thresholding algorithm for matrix completion. SIAM J. Optimization, 20(4), pp. 1956-1982.
- Boyd, S., Parikh, N., Chu, E., Peleato, B. and Eckstein, J., 2011. Distributed optimization and statistical learning via the alternating direction method of multipliers. Foundations and Trends in Machine Learning, 3(1), pp. 1-122.
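For reference, here is one standard variable splitting that is consistent with the update rules in the lrmc_admm code below (a sketch of the derivation; see the references above for the general framework): introduce an auxiliary matrix $\mathbf Z$ and solve
\[\min_{\mathbf X, \mathbf Z} \frac{1}{2} ‖ P_Ω(\mathbf X) - P_Ω(\mathbf Y) ‖_2^2 + β ‖ \mathbf Z ‖_* \quad \text{s.t.} \quad \mathbf X = \mathbf Z.\]
With a scaled dual (multiplier) matrix $\mathbf L$, each ADMM iteration is
$\mathbf Z_{k+1} = \text{SVST}(\mathbf X_k + \mathbf L_k, β/μ)$
$[\mathbf X_{k+1}]_{i,j} = \big([\mathbf Y]_{i,j} + μ ([\mathbf Z_{k+1}]_{i,j} - [\mathbf L_k]_{i,j})\big) / (μ + 1_{\{(i,j) ∈ Ω\}})$
$\mathbf L_{k+1} = \mathbf L_k + \mathbf X_{k+1} - \mathbf Z_{k+1}$
corresponding to the three update lines inside the loop below.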
Run the alternating direction method of multipliers (ADMM) algorithm:
niter = 50
50
The choice of the parameter $μ$ can greatly affect the convergence rate.
function lrmc_admm(Y; mu::Real = beta)
X = copy(Y)
Z = zeros(size(X))
L = zeros(size(X))
cost_admm = zeros(niter+1)
cost_admm[1] = costfun(X,beta)
for k in 1:niter
Z = SVST(X + L, beta / mu) # Z update: prox of the nuclear norm via SVST
X = (Y + mu * (Z - L)) ./ (mu .+ Ω) # X update: elementwise least-squares solve
L = L + X - Z # dual (multiplier) update
cost_admm[k+1] = costfun(X,beta) # comment out to speed-up
end
return X, cost_admm
end;
if !@isdefined(Xadmm)
Xadmm, cost_admm = lrmc_admm(Y)
pj_admm = jim(Xadmm, "ADMM result at $niter iterations")
end
pc = plot(title = "cost vs. iteration",
xtick = [0, 50, 200, 400],
xlabel = "iteration", ylabel = "cost function value")
scatter!(0:400, cost_ista, label="ISTA", color=:red)
scatter!(0:200, cost_fista, label="FISTA", color=:blue)
scatter!(0:niter, cost_admm, label="ADMM", color=:magenta)
prompt()
All singular values
s_admm = svdvals(Xadmm)
scatter!(ps, s_admm, label="X (ADMM)", color=:magenta, marker=:square)
prompt()
For a suitable choice of $μ$, ADMM converges faster than FISTA.
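To see the effect of $μ$ firsthand, one could rerun lrmc_admm with a few trial values; this sweep is only an illustration and the particular values are arbitrary:
for mu_test in (0.1*beta, beta, 10*beta) # arbitrary trial values of μ
_, cost_tmp = lrmc_admm(Y; mu = mu_test)
@show mu_test cost_tmp[end] # final cost after niter iterations
end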
Proximal optimized gradient method (POGM)
The proximal optimized gradient method (POGM) with adaptive restart is faster than FISTA with very similar computation per iteration. Unlike ADMM, POGM does not require any algorithm tuning parameter $μ$, making it easier to use in many practical composite optimization problems.
if !@isdefined(Xpogm)
Fcost = X -> costfun(X, beta)
f_grad = X -> Ω .* (X - Y) # gradient of smooth term
f_L = 1 # Lipschitz constant of f_grad
g_prox = (X, c) -> SVST(X, c * beta)
fun = (iter, xk, yk, is_restart) -> (xk, Fcost(xk), is_restart)
niter = 150
Xpogm, out = pogm_restart(Y, Fcost, f_grad, f_L; g_prox, fun, niter)
cost_pogm = [o[2] for o in out]
pj_pogm = jim(Xpogm, "POGM result at $niter iterations")
end
scatter!(pc, 0:niter, cost_pogm, label="POGM", color=:green)
prompt()
Reproducibility
This page was generated with the following version of Julia:
using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
11-element Vector{SubString{String}}:
"Julia Version 1.11.1"
"Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)"
"Build Info:"
" Official https://julialang.org/ release"
"Platform Info:"
" OS: Linux (x86_64-linux-gnu)"
" CPU: 4 × AMD EPYC 7763 64-Core Processor"
" WORD_SIZE: 64"
" LLVM: libLLVM-16.0.6 (ORCJIT, znver3)"
"Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
""
And with the following package versions
import Pkg; Pkg.status()
Status `~/work/book-la-demo/book-la-demo/docs/Project.toml`
[6e4b80f9] BenchmarkTools v1.5.0
[aaaa29a8] Clustering v0.15.7
[35d6a980] ColorSchemes v3.27.1
⌅ [3da002f7] ColorTypes v0.11.5
⌃ [c3611d14] ColorVectorSpace v0.10.0
[717857b8] DSP v0.7.10
[72c85766] Demos v0.1.0 `~/work/book-la-demo/book-la-demo`
[e30172f5] Documenter v1.7.0
[4f61f5a4] FFTViews v0.3.2
[7a1cc6ca] FFTW v1.8.0
[587475ba] Flux v0.14.25
[a09fc81d] ImageCore v0.10.4
[71a99df6] ImagePhantoms v0.8.1
[b964fa9f] LaTeXStrings v1.4.0
[7031d0ef] LazyGrids v1.0.0
[599c1a8e] LinearMapsAA v0.12.0
[98b081ad] Literate v2.20.1
[7035ae7a] MIRT v0.18.2
[170b2178] MIRTjim v0.25.0
[eb30cadb] MLDatasets v0.7.18
[efe261a4] NFFT v0.13.5
[6ef6ca0d] NMF v1.0.3
[15e1cf62] NPZ v0.4.3
[0b1bfda6] OneHotArrays v0.2.5
[429524aa] Optim v1.10.0
[91a5bcdd] Plots v1.40.9
[f27b6e38] Polynomials v4.0.11
[2913bbd2] StatsBase v0.34.3
[d6d074c3] VideoIO v1.1.0
[b77e0a4c] InteractiveUtils v1.11.0
[37e2e46d] LinearAlgebra v1.11.0
[44cfe95a] Pkg v1.11.0
[9a3f8284] Random v1.11.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
This page was generated using Literate.jl.