From 3080e05e64292a934f9ce32e908c1c8a88dc40ec Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Thu, 20 Jun 2024 17:09:59 +1000 Subject: add stopping criteria --- pca.fut | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'pca.fut') diff --git a/pca.fut b/pca.fut index 9d1fd22..8ffca59 100644 --- a/pca.fut +++ b/pca.fut @@ -42,13 +42,13 @@ def quantile [n] (q: f64) (xs: [n]f64) : f64 = in p -- Optimiser -def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) : [n]f64 = - let step (i, par, s) = +def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) (f: [n]f64 -> f64): [n]f64 = + let step (i, l, _, par, s) = let d = df par let s' = map2 (\s x -> beta * s + (1-beta)*x**2) s d let par' = map3 (\x s d -> x - 1e-3 * d / (f64.epsilon + f64.sqrt s)) par s' d - in (i-1, par', s') - let (_, par, _) = iterate_until (\(i,_,_) -> i <= 0) step (iters, par, replicate n 0) + in (i - 1, f par', l, par', s') + let (_, _, _, par, _) = iterate_until (\(i, l',l,_, _) -> l-l' <= f64.epsilon || i <= 0) step (iters, f (replicate n 0), f64.highest, par, replicate n 0) in par def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) = @@ -62,8 +62,9 @@ def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64 let (a, b) = split xs in (unflatten a, unflatten b) let df (par: [n*k+k*d]f64) : [n*k+k*d]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par + let f (par: [n*k+k*d]f64) : f64 = (uncurry (loss X') <-< unpack) par let init = flatten B ++ flatten U - let par = rmsprop iters 0.999 init df + let par = rmsprop iters 0.999 init df f let (Bf, Uf ) = unpack par in (Bf, Uf, loss X' Bf Uf) -- cgit v1.2.3