diff options
Diffstat (limited to 'pca.fut')
-rw-r--r-- | pca.fut | 11 |
1 files changed, 6 insertions, 5 deletions
@@ -42,13 +42,13 @@ def quantile [n] (q: f64) (xs: [n]f64) : f64 = in p -- Optimiser -def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) : [n]f64 = - let step (i, par, s) = +def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) (f: [n]f64 -> f64): [n]f64 = + let step (i, l, _, par, s) = let d = df par let s' = map2 (\s x -> beta * s + (1-beta)*x**2) s d let par' = map3 (\x s d -> x - 1e-3 * d / (f64.epsilon + f64.sqrt s)) par s' d - in (i-1, par', s') - let (_, par, _) = iterate_until (\(i,_,_) -> i <= 0) step (iters, par, replicate n 0) + in (i - 1, f par', l, par', s') + let (_, _, _, par, _) = iterate_until (\(i, l',l,_, _) -> l-l' <= f64.epsilon || i <= 0) step (iters, f (replicate n 0), f64.highest, par, replicate n 0) in par def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) = @@ -62,8 +62,9 @@ def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64 let (a, b) = split xs in (unflatten a, unflatten b) let df (par: [n*k+k*d]f64) : [n*k+k*d]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par + let f (par: [n*k+k*d]f64) : f64 = (uncurry (loss X') <-< unpack) par let init = flatten B ++ flatten U - let par = rmsprop iters 0.999 init df + let par = rmsprop iters 0.999 init df f let (Bf, Uf ) = unpack par in (Bf, Uf, loss X' Bf Uf) |