summaryrefslogtreecommitdiff
path: root/pca.fut
diff options
context:
space:
mode:
Diffstat (limited to 'pca.fut')
-rw-r--r--pca.fut11
1 files changed, 6 insertions, 5 deletions
diff --git a/pca.fut b/pca.fut
index 9d1fd22..8ffca59 100644
--- a/pca.fut
+++ b/pca.fut
@@ -42,13 +42,13 @@ def quantile [n] (q: f64) (xs: [n]f64) : f64 =
in p
-- Optimiser
-def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) : [n]f64 =
- let step (i, par, s) =
+def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) (f: [n]f64 -> f64): [n]f64 =
+ let step (i, l, _, par, s) =
let d = df par
let s' = map2 (\s x -> beta * s + (1-beta)*x**2) s d
let par' = map3 (\x s d -> x - 1e-3 * d / (f64.epsilon + f64.sqrt s)) par s' d
- in (i-1, par', s')
- let (_, par, _) = iterate_until (\(i,_,_) -> i <= 0) step (iters, par, replicate n 0)
+ in (i - 1, f par', l, par', s')
+ let (_, _, _, par, _) = iterate_until (\(i, l',l,_, _) -> l-l' <= f64.epsilon || i <= 0) step (iters, f (replicate n 0), f64.highest, par, replicate n 0)
in par
def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) =
@@ -62,8 +62,9 @@ def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64
let (a, b) = split xs
in (unflatten a, unflatten b)
let df (par: [n*k+k*d]f64) : [n*k+k*d]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par
+ let f (par: [n*k+k*d]f64) : f64 = (uncurry (loss X') <-< unpack) par
let init = flatten B ++ flatten U
- let par = rmsprop iters 0.999 init df
+ let par = rmsprop iters 0.999 init df f
let (Bf, Uf ) = unpack par
in (Bf, Uf, loss X' Bf Uf)