Double Descent in LS

This example illustrates the phenomenon of double descent in least squares (LS) polynomial fitting using the Julia language.

Inspired by the article "Characterizations of Double Descent" by Manuchehr Aminian in SIAM News 58(10) Dec. 2025.

This page comes from a single Julia file: double-descent.jl.


Setup

Add the Julia packages used in this demo. Change false to true in the following code block if you are using any of the following packages for the first time.

if false
    import Pkg
    Pkg.add([
        "InteractiveUtils"
        "LaTeXStrings"
        "LinearAlgebra"
        "MIRTjim"
        "Plots"
    ])
end

Tell Julia to use the following packages. Run Pkg.add() in the preceding code block first, if needed.

using InteractiveUtils: versioninfo
using LaTeXStrings
using LinearAlgebra: diag, norm, I, svdvals
using MIRTjim: prompt, jim
using Plots: default, gui, plot, plot!, scatter, scatter!, savefig
default(); default(label="", markerstrokecolor=:auto, widen=true, linewidth=2,
    markersize = 6, tickfontsize=12, labelfontsize = 16, legendfontsize=14)

The following line is helpful when running this jl-file as a script; it prompts the user to hit a key after each image is displayed.

isinteractive() && prompt(:prompt);

# Simulate data
M = 100 # number of data points
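P = 99 # number of basis functions (highest polynomial degree is P-1 = 98)
t = range(-1, 1, M)
fun(t) = atan(2*t) # nonlinear function
y = fun.(t)
train = 1:(M÷2) # first half of the samples (t < 0)
test = (M÷2+1):M; # second half of the samples (t > 0)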

Legendre polynomial basis

Build Legendre polynomial basis using Bonnet's recursion formula:

\[(n+1) P_{n+1}(x) = (2n+1) x P_n(x) - n P_{n-1}(x)\]
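As a worked instance (a check on the index bookkeeping, not part of the original demo), starting from P₀(x) = 1 and P₁(x) = x, the cases n = 1 and n = 2 of the recursion give:

\[n=1:\quad 2P_2(x) = 3x\,P_1(x) - P_0(x) = 3x^2 - 1 \;\Longrightarrow\; P_2(x) = \tfrac{1}{2}\left(3x^2 - 1\right)\]

\[n=2:\quad 3P_3(x) = 5x\,P_2(x) - 2P_1(x) = \tfrac{1}{2}\left(15x^3 - 9x\right) \;\Longrightarrow\; P_3(x) = \tfrac{1}{2}\left(5x^3 - 3x\right)\]

In the code, column k of L holds P_{k-1}, so the update for column k uses n = k - 2.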

L = ones(M, P)
L[:,2] .= t
for k in 3:P
    n = k - 2 # caution: SIAM article had an error here
    L[:, k] = ((2n+1) * t .* L[:, k-1] - n * L[:, k-2]) / (n + 1)
end
pl = plot(t, L[:,1:5], title="First 5 Legendre polynomials", marker=:dot)
[Figure: First 5 Legendre polynomials]

Check the recursion for k=3: column 3 holds P₂, which is obtained with n=1 in Bonnet's recursion.

p2(x) = (1/2) * (3x^2 - 1)
@assert p2.(t) ≈ L[:,3]

Check the basis function normalization (continuous vs. discrete norms).

p = 0:(P-1)
normp = @. sqrt(2 / (2p+1)) # theoretical L₂[-1,1] norm
norme = norm.(eachcol(L)) / sqrt(M/2) # empirical norm, account for dx
plot(xlabel="degree", ylabel="norm")
plot!(p, norme, marker=:dot, color=:red, label="empirical")
plot!(p, normp, marker=:dot, color=:blue, label="analytical")
[Figure: empirical and analytical norms versus degree]
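The scaling by √(M/2) in norme follows from a Riemann-sum approximation of the integral on the uniform grid t, whose spacing is Δx = 2/(M-1) ≈ 2/M (a sketch of the reasoning behind the "account for dx" comment):

\[\|p\|_{L_2[-1,1]}^2 = \int_{-1}^{1} p(x)^2 \, dx \approx \Delta x \sum_{i=1}^{M} p(t_i)^2 \approx \frac{2}{M} \sum_{i=1}^{M} p(t_i)^2 \quad\Longrightarrow\quad \|p\|_{L_2[-1,1]} \approx \frac{1}{\sqrt{M/2}} \Bigl( \sum_{i=1}^{M} p(t_i)^2 \Bigr)^{1/2}\]

For the Legendre polynomials the exact value is \(\|P_n\|_{L_2[-1,1]}^2 = 2/(2n+1)\), which is what normp computes. Because this is only a Riemann-sum approximation, one would expect it to be less accurate at higher degrees, where P_n oscillates more rapidly relative to the grid spacing.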

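Normalize the basis functions using the empirical norms. Because norme already includes the factor 1/√(M/2), dividing by norme' and then by √(M/2) gives each column of L unit Euclidean norm, so L'L has a unit diagonal.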

L = L ./ norme' / sqrt(M/2);

Scree plot

Examine the singular values of the (normalized) Legendre basis matrix L. If L were semiunitary (L'L = I), all of its singular values would equal 1. Clearly L is not semiunitary, and the last ~15 singular values are very small. So fitting with more than ~80 components will be very unstable, even if all M samples were available.

scatter(svdvals(L), xaxis = ("k", (0,100), 0:10:100), ylabel = L"σ_k")
[Figure: singular values σ_k of L versus k]
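To make the semiunitary criterion concrete, the following minimal sketch compares the singular values of a smaller sampled Legendre basis (20 functions, an assumption chosen for speed) with those of the orthonormal factor Q from its QR decomposition. The helper legendre_basis is hypothetical; it simply repeats the recursion and normalization above.

```julia
using LinearAlgebra: qr, svdvals

# Hypothetical helper (not part of the original demo): samples P_0, ..., P_{nbasis-1}
# on the grid t with Bonnet's recursion and scales each column to unit Euclidean norm.
function legendre_basis(t, nbasis)
    B = ones(length(t), nbasis)
    nbasis > 1 && (B[:, 2] .= t)
    for k in 3:nbasis
        n = k - 2 # column k holds P_{n+1} = P_{k-1}
        B[:, k] = ((2n+1) * t .* B[:, k-1] - n * B[:, k-2]) / (n + 1)
    end
    return B ./ sqrt.(sum(abs2, B; dims=1))
end

t = range(-1, 1, 100)
B = legendre_basis(t, 20) # 100 × 20 with unit-norm, but not orthogonal, columns
σB = svdvals(B) # sorted in decreasing order

# Orthonormalize the same column space; Q is semiunitary (Q'Q = I)
Q = Matrix(qr(B).Q) # thin 100 × 20 factor
σQ = svdvals(Q)

println("σ_max / σ_min for B: ", σB[1] / σB[end])
println("max |σ - 1| for Q:   ", maximum(abs.(σQ .- 1)))
```

Since Q spans the same column space as B, least-squares fits with either matrix give the same fitted values in exact arithmetic; only the numerical conditioning differs.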

Examine orthogonality of the basis functions

The Legendre polynomials are orthogonal in $L_2[-1,1]$, but the following correlation figure (the Gram matrix L'L of the unit-norm columns) shows that they are not orthogonal when sampled on the uniform grid t.

pc = jim(p, p, L'L, "correlation")
[Figure: correlation matrix L'L]
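One structural feature of such a correlation matrix can be checked directly: the grid t is symmetric about 0 and P_n has the parity of n, so the sampled inner product of an even and an odd polynomial sums to (numerically) zero, while two distinct polynomials of the same parity generally have a nonzero sampled inner product. The sketch below assumes a 10-function basis and the same hypothetical legendre_basis helper as a stand-in for the code above.

```julia
using LinearAlgebra: I, norm

# Hypothetical helper (not part of the original demo): samples P_0, ..., P_{nbasis-1}
# on the grid t with Bonnet's recursion and scales each column to unit Euclidean norm.
function legendre_basis(t, nbasis)
    B = ones(length(t), nbasis)
    nbasis > 1 && (B[:, 2] .= t)
    for k in 3:nbasis
        n = k - 2 # column k holds P_{n+1} = P_{k-1}
        B[:, k] = ((2n+1) * t .* B[:, k-1] - n * B[:, k-2]) / (n + 1)
    end
    return B ./ sqrt.(sum(abs2, B; dims=1))
end

t = range(-1, 1, 100) # grid symmetric about 0
B = legendre_basis(t, 10) # columns hold P_0, ..., P_9
G = B' * B # Gram ("correlation") matrix with unit diagonal

# Column i holds P_{i-1}, so entry (i, j) pairs opposite parities exactly when i + j is odd
mixed = [G[i, j] for i in 1:10 for j in 1:10 if isodd(i + j)]
same = [G[i, j] for i in 1:10 for j in 1:10 if iseven(i + j) && i != j]

println("max |G[i,j]|, opposite parity: ", maximum(abs, mixed))
println("max |G[i,j]|, same parity:     ", maximum(abs, same))
println("Frobenius distance from I:     ", norm(G - I))
```

So on this grid the loss of orthogonality comes entirely from pairs of the same parity, which suggests a checkerboard pattern in the off-diagonal part of L'L.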

Evaluate OLS solutions for increasing k

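The training error does not increase with polynomial degree, because each basis L[:, 1:k] contains the previous one. Once k exceeds the number of training samples (50), the system A[train,:] x = y[train] is underdetermined; for a rectangular matrix, Julia's backslash operator returns the minimum-norm least-squares solution.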

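errors = zeros(P,3) # columns: training ℓ₂, training ℓ∞, and test ℓ₂ residual norms
for k in 1:P
    A = L[:, 1:k] # first k basis functions (degrees 0 to k-1)
    xhat = A[train,:] \ y[train] # LS fit; minimum-norm solution when k > length(train)
    residual = A*xhat - y
    errors[k,1] = norm(residual[train])
    errors[k,2] = norm(residual[train], Inf)
    errors[k,3] = norm(residual[test]) # ℓ₂ norm, so the test plot is an NRMSE
end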
ptrain = scatter(p, 100*errors[:,1]/norm(y[train]), title="NRMSE training",
 xlabel="Polynomial degree")
[Figure: NRMSE training versus polynomial degree]
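The two properties of backslash that this loop relies on can be checked on a small problem. In this sketch the toy data (4 training samples, monomial columns 1, x, ..., x^6) are hypothetical and unrelated to the demo's data: the residual for nested bases does not increase with k, and for an underdetermined full-row-rank system, A \ b agrees with the minimum-norm solution pinv(A) * b.

```julia
using LinearAlgebra: norm, pinv

# Hypothetical toy problem (not the demo's data): 4 training samples,
# up to 7 monomial basis functions 1, x, ..., x^6.
ttrain = [-1.0, -0.5, 0.0, 0.5]
ytrain = atan.(2 .* ttrain)
B = [x^j for x in ttrain, j in 0:6] # 4 × 7 matrix

# Training residual norm for the nested bases B[:, 1:k]
res = [norm(B[:, 1:k] * (B[:, 1:k] \ ytrain) - ytrain) for k in 1:7]
println("training residuals: ", res)

# Underdetermined case (7 unknowns, 4 equations): compare backslash with pinv
A = B[:, 1:7]
x_backslash = A \ ytrain
x_pinv = pinv(A) * ytrain
println("difference from pinv solution: ", norm(x_backslash - x_pinv))
```

At k = 4 the square system interpolates the 4 samples, so the residual drops to rounding level and stays there for larger k.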

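The test error exhibits double descent: besides the classical behavior at low degree, it descends a second time in the overparameterized regime, beyond the interpolation threshold where the number of basis functions k equals the number of training samples (50).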

ptest = plot(p, 100*errors[:,3]/norm(y[test]), title="NRMSE test",
    marker=:dot,
    xlabel = "Polynomial degree",
    yaxis = ("NRMSE (%)", (0, 100)),
)
[Figure: NRMSE test versus polynomial degree]
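A standard ingredient in explanations of the second descent (offered here as background, not as the analysis in the cited article): once k is at least the number of training samples and the training matrix has full row rank, any interpolating coefficient vector with k columns remains interpolating with k + 1 columns (append a zero), so the minimum coefficient norm can only stay the same or shrink as k grows. The sketch below checks this on a hypothetical toy problem (6 equispaced training samples on [-1, 0], monomial columns), using pinv for the minimum-norm solution.

```julia
using LinearAlgebra: norm, pinv

# Hypothetical toy problem (not the demo's data): 6 training samples on [-1, 0]
ttrain = range(-1, 0, 6)
ytrain = atan.(2 .* ttrain)
n = length(ttrain)

# Norm of the minimum-norm interpolating coefficients for k = n, ..., 3n basis functions
ks = n:3n
coef_norms = map(ks) do k
    A = [x^j for x in ttrain, j in 0:k-1] # n × k monomial matrix
    norm(pinv(A) * ytrain)
end
println("‖x̂‖ for k = ", first(ks), ":", last(ks), ": ", round.(coef_norms; sigdigits=3))
```

At k = n the interpolant is unique, and when that square system is ill-conditioned its coefficients can be large; adding columns beyond that point lets the minimum-norm fit spread the same interpolation over more basis functions.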

Show the data and the fits for small, medium, and large numbers of basis functions: k = 2, 10, and 99 (polynomial degrees 1, 9, and 98).

pfit = plot(
    xaxis = ("x", (-1,1), -1:1),
    yaxis = ("y", (-1,1) .* 1.5, -1:1),
)
scatter!(t, y)
for k in (2, 10, 99)
    A = L[:, 1:k]
    xhat = A[train,:] \ y[train]
    plot!(t, A * xhat, label="k=$k")
end
pfit
[Figure: data and fits for k = 2, 10, 99]

Reproducibility

This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
 "Julia Version 1.12.4"
 "Commit 01a2eadb047 (2026-01-06 16:56 UTC)"
 "Build Info:"
 "  Official https://julialang.org release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)"
 "  GC: Built with stock GC"
 "Threads: 1 default, 1 interactive, 1 GC (on 4 virtual cores)"
 ""

And with the following package versions:

import Pkg; Pkg.status()
Status `~/work/book-la-demo/book-la-demo/docs/Project.toml`
  [6e4b80f9] BenchmarkTools v1.6.3
  [aaaa29a8] Clustering v0.15.8
  [35d6a980] ColorSchemes v3.31.0
  [3da002f7] ColorTypes v0.12.1
  [c3611d14] ColorVectorSpace v0.11.0
  [717857b8] DSP v0.8.4
  [72c85766] Demos v0.1.0 `~/work/book-la-demo/book-la-demo`
  [e30172f5] Documenter v1.16.1
  [4f61f5a4] FFTViews v0.3.2
  [7a1cc6ca] FFTW v1.10.0
  [587475ba] Flux v0.16.8
  [a09fc81d] ImageCore v0.10.5
  [71a99df6] ImagePhantoms v0.8.1
  [b964fa9f] LaTeXStrings v1.4.0
  [7031d0ef] LazyGrids v1.1.0
  [599c1a8e] LinearMapsAA v0.12.0
  [98b081ad] Literate v2.21.0
  [7035ae7a] MIRT v0.18.3
  [170b2178] MIRTjim v0.26.0
  [eb30cadb] MLDatasets v0.7.20
  [efe261a4] NFFT v0.14.3
  [6ef6ca0d] NMF v1.0.3
  [15e1cf62] NPZ v0.4.3
  [0b1bfda6] OneHotArrays v0.2.10
  [429524aa] Optim v2.0.0
  [91a5bcdd] Plots v1.41.4
  [f27b6e38] Polynomials v4.1.0
  [2913bbd2] StatsBase v0.34.10
  [d6d074c3] VideoIO v1.4.0
  [b77e0a4c] InteractiveUtils v1.11.0
  [37e2e46d] LinearAlgebra v1.12.0
  [44cfe95a] Pkg v1.12.1
  [9a3f8284] Random v1.11.0

This page was generated using Literate.jl.