Sampling

Work in progress (WIP): sampling using a score model.

This page comes from a single Julia file: 02-sample.jl.

You can access the source code for such Julia documentation using the 'Edit on GitHub' link in the top right. You can view the corresponding notebook in nbviewer here: 02-sample.ipynb, or open it in binder here: 02-sample.ipynb.

Setup

Packages needed here.

using MIRTjim: jim, prompt
using Distributions: Distribution, Normal, MixtureModel, logpdf, pdf
import Distributions # var, mean
import Distributions: logpdf, pdf
import ForwardDiff # derivative, gradient
using LaTeXStrings
using Printf: @sprintf
using Random: shuffle, seed!; seed!(0)
using StatsBase: mean, std
using Plots: Plot, plot, plot!, scatter, scatter!, histogram, quiver!
using Plots: @animate, gif
import Plots # Animation
using Plots: default, gui, savefig
using Plots.PlotMeasures: px
# Global Plots.jl defaults applied to the figures created below.
default(
    label = "",
    linewidth = 2,
    markerstrokecolor = :auto,
    labelfontsize = 14,
    legendfontsize = 14,
    tickfontsize = 12,
)

The following line is helpful when running this file as a script; this way the user is prompted to hit a key after each figure is displayed.

# Select the MIRTjim prompt mode according to the session type
# (interactive REPL vs. script run).
if isinteractive()
    jim(:prompt, true)
else
    prompt(:prompt)
end;

Overview

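Given a score function $\bm{s}(\bm{x}; \bm{θ}) = \nabla_{\bm{x}} \log p(\bm{x}; \bm{θ}),$ one can use Langevin dynamics to draw samples from $p(\bm{x}; \bm{θ}).$ Starting from an initial $\bm{x}_0$, each iteration applies $\bm{x}_{t+1} = \bm{x}_t + α_t \, \bm{s}(\bm{x}_t; \bm{θ}) + \sqrt{2 β α_t} \, \bm{z}_t$ with step size $α_t$ and $\bm{z}_t \sim \mathcal{N}(\bm{0}, \bm{I})$. The `sampler` function below implements this update for scalar $x$. By default it uses $β = 1$ and the decaying step size $α_t = α_0 \, (p_0^t)^2$.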

Illustration

Sampling from a Gaussian mixture distribution.

Some convenience methods

# Curried helpers: each takes a distribution (or a function) and returns a
# one-argument function of the evaluation point, convenient for plotting.
# NOTE(review): `logpdf`/`pdf` are imported from Distributions above, so these
# add methods to Distributions' own functions for its own `Distribution` type
# (type piracy). That is acceptable in a tutorial script, but avoid it in library code.
logpdf(dist::Distribution) = t -> logpdf(dist, t)
pdf(dist::Distribution) = t -> pdf(dist, t)

# Automatic differentiation wrappers around ForwardDiff.
derivative(fun::Function) = t -> ForwardDiff.derivative(fun, t)
gradient(fun::Function) = v -> ForwardDiff.gradient(fun, v)
# hessian(fun::Function) = v -> ForwardDiff.hessian(fun, v)

# Score function s(x) = d/dx log p(x), and its derivative (scalar x only).
score(dist::Distribution) = derivative(logpdf(dist))
score_deriv(dist::Distribution) = derivative(score(dist));


# Target: two-component Gaussian mixture with components
# Normal(3, 1) and Normal(13, 3) and weights 0.4 and 0.6.
mix = MixtureModel(Normal, [(3,1), (13,3)], [0.4, 0.6])

# Shared plot settings (reused by later figures).
left_margin = 20px
bottom_margin = 10px
xaxis = (L"x", (-4,24), [0, 3, 13, 20]) # (label, limits, ticks)

pmp = plot(pdf(mix);
    label = "Gaussian mixture pdf", color = :blue,
    left_margin, bottom_margin, xaxis, size = (600,300),
    yaxis = (L"p(x)", (0, 0.17), (0:3)*0.05),
)

# Score of the mixture, via automatic differentiation of its log-density.
score1 = score(mix)

ylabel_score1 = L"s(x) = \frac{\mathrm{d}}{\mathrm{d}x} \, \log \ p(x)"
ps1 = plot(score1;
    xaxis, color = :magenta, size = (600,300), label = "GMM score function",
    yaxis = (ylabel_score1, (-5,5), -4:2:4), left_margin, bottom_margin,
)

# Stack the pdf (top) and score (bottom) plots.
pps = plot(pmp, ps1; layout = (2,1))
Example block output
prompt() # pause after the figure when running as a script (see setup above)


"""
    sampler(; score, T, α0, p0, alpha, ntrial, beta, seed, init_mean, init_std)

Run `ntrial` independent chains of (unadjusted) Langevin dynamics for `T` steps:

    x_{t+1} = x_t + α_t s(x_t) + sqrt(2 β α_t) z_t,   z_t ~ N(0, 1)

Return an `ntrial × (T+1)` `Matrix{Float32}`. Row `k` is chain `k`. Column 1
holds the initial samples, and column `t+1` holds the samples after step `t`.

# Keywords
- `score`: scalar score function `s(x)`, applied element-wise (default `score1`).
- `T`: number of Langevin steps.
- `α0`, `p0`: parameters of the default step sizes `alpha[t] = α0 * (p0^t)^2`.
- `alpha`: step sizes; `alpha[t]` is used at step `t`, so at least `T` are needed.
- `ntrial`: number of independent chains (rows of the output).
- `beta`: scale factor on the injected noise variance (`2 β α_t`).
- `seed`: passed to `seed!`, i.e., reseeds the *global* RNG before sampling.
- `init_mean`, `init_std`: chains start at `init_mean + init_std * randn()`.
  The defaults use the moments of the global target `mix`. That is convenient
  for this demo but assumes knowledge of the target.

# Throws
- `ArgumentError` if `alpha` has fewer than `T` entries.
"""
function sampler( ;
    score::Function = score1,
    T::Int = 600,
    α0::Real = 1,
    p0::Real = 0.99,
    alpha::AbstractVector = (@. α0 * (p0 ^ (1:T))^2),
    ntrial::Int = 1000,
    beta::Real = 1,
    seed::Int = 0,
    init_mean::Real = Distributions.mean(mix), # todo: cheating?
    init_std::Real = sqrt(Distributions.var(mix)),
)
    # Fail early with a clear message rather than a BoundsError mid-run.
    length(alpha) >= T || throw(ArgumentError(
        "need at least T = $T step sizes, got length(alpha) = $(length(alpha))"))

    seed!(seed)
    xrun = Matrix{Float32}(undef, ntrial, T+1)
    xrun[:, 1] = init_mean .+ init_std * randn(ntrial)

    for it in 1:T
        old = @view xrun[:, it] # current state of all chains, without copying
        αt = alpha[it]
        # Langevin step: drift along the score plus scaled Gaussian noise.
        xrun[:, it+1] = old + αt * score.(old) + sqrt(2*beta*αt) * randn(ntrial)
    end
    return xrun
end;


# The `|| true` makes this branch always run, so `xrun` is regenerated every
# time; without it, an `xrun` already defined in the session would be reused.
if !@isdefined(xrun) || true
    T, ntrial = 600, 5000
    xrun = sampler(; T = T, ntrial = ntrial)
end;

ntrace = 50 # number of individual chains to draw
# Each row of `xrun` is one chain; transposing makes each chain a plotted series.
psl = plot(xrun[1:ntrace, :]';
    xlabel = "Iteration (t)",
    xticks = 0:100:T,
    yaxis = (L"x_t", (-4,26), [0, 3, 13, 20]),
    annotate = (T, -2, "$ntrace generated samples", :right),
)
Example block output
prompt()
# Uncomment to save the trace figure to a PDF file:
# savefig(psl, "gmm-prior-trace-$ntrace.pdf")


"""
    gmm_hist(it::Int)

Plot a normalized histogram of column `it` of the global sample matrix `xrun`
(all chains at that column; column 1 holds the initial samples). Overlay the
target density `pdf(mix)` and annotate the plot with `t = it`.
Return the plot.
"""
function gmm_hist(it::Int)
    # Count the chains actually stored in `xrun` rather than trusting the
    # global `ntrial`, so the legend always matches the plotted data.
    nsample = size(xrun, 1)
    ph = histogram(xrun[:, it];
        bins = -12:0.5:36, xaxis,
        label = "$nsample generated samples", normalize = true,
        yaxis = (L"p(x)", (0, 0.17), 0:0.1:0.2),
        annotate = (-3, 0.14, "t = $it", :left),
    )
    # `pdf(mix)` is already a function of x; no extra wrapper is needed.
    plot!(ph, pdf(mix);
        linewidth=3, color=:black, label="GMM Distribution",
    )
    return ph
end
gmm_hist (generic function with 1 method)

Animate the sampling process over time.

if !isinteractive()
    # Script / documentation build: record the histogram at selected
    # iterations and save the frames as an animated GIF.
    anim = @animate for it in [1:10; 20:10:100; 200:100:T]
        ph = gmm_hist(it)
        # tmp = @sprintf("%03d", it)
        # savefig(ph, "gmm-prior-sample-$ntrial,$tmp.pdf")
    end
    gif(anim, "gmm-hist.gif", fps = 6)
else
    # Interactive session: show only the histogram for it = T.
    ph = gmm_hist(T)
end
Example block output
prompt()
# Uncomment to save the histogram to a PDF file.
# NOTE(review): in the non-interactive branch above, `ph` is assigned inside the
# animation loop and may not exist at global scope; confirm before uncommenting.
# savefig(ph, "gmm-prior-sample-$ntrial.pdf")

Kernel density estimate and its score function

# Kernel density estimate: a mixture (default equal weights) of Normal
# components with standard deviation `gsig`, one centered at each training sample.
ntrain = 200
train_data = rand(mix, ntrain)
gsig = 0.9

kde = MixtureModel(Normal, map(xi -> (xi, gsig), train_data))
pkd = deepcopy(pmp) # copy so that the original pdf plot `pmp` is not modified
plot!(pkd, pdf(kde); xaxis, label = "KDE, σ=$gsig", widen = true, color = :green)
scatter!(pkd, train_data, zeros(ntrain); label = "data, N=$ntrain", color = :black)
Example block output
prompt()
# Uncomment to save the KDE pdf figure to a PDF file:
# savefig(pkd, "gmm-kde-pdf-$ntrain.pdf")

# Overlay the KDE score on a copy of the GMM score plot (leaving `ps1` as is).
pks = deepcopy(ps1)
plot!(pks, score(kde); label = "KDE score, σ=$gsig", color = :green)
Example block output
prompt()
# Uncomment to save the KDE score figure to a PDF file:
# savefig(pks, "gmm-kde-score-$ntrain.pdf")

# Optional: stack the KDE pdf and score plots in one figure.
# plot(pkd, pks; layout=(2,1))

This page was generated using Literate.jl.