From 42edc4f3b5dcd460aef75af536ee541c975b0e8b Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Thu, 9 Feb 2023 09:36:19 +1100 Subject: allow setting number of iterations for optimiser --- pca.fut | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'pca.fut') diff --git a/pca.fut b/pca.fut index a52df80..ed4d155 100644 --- a/pca.fut +++ b/pca.fut @@ -24,11 +24,11 @@ def embed [n][d][k] (B: [n][k]f64) (U: [k][d]f64) : [n][d]f64 = let C = tabulate_2d k k (\i j -> (if i == j then 1f64 else 0f64) - f64.from_fraction 1 k) in (map (map f64.exp) (matmul_f64 (matmul_f64 (unitfstT (copy B)) C) U)) -entry loss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : f64 = +def loss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : f64 = let Y = embed B U in sum (map2 (\x y -> - psi y - dot (map2 (-) x y) (vjp psi y 1)) X Y) -entry dloss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : ([n][k]f64, [k][d]f64) = vjp (uncurry (loss X)) (B, U) 1 +def dloss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : ([n][k]f64, [k][d]f64) = vjp (uncurry (loss X)) (B, U) 1 -- Approximate quantiles def quantile [n] (q: f64) (xs: [n]f64) : f64 = @@ -51,7 +51,7 @@ def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) : let (_, par, _) = iterate_until (\(i,_,_) -> i <= 0) step (iters, par, replicate n 0) in par -def pca [n][d] (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) = +def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) = let mX = sum (flatten X) let X' = map (map (/ mX)) X let fp x = x - f64.floor x @@ -63,13 +63,13 @@ def pca [n][d] (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) = in (unflatten n k a, unflatten k d b) let df [m] (par: [m]f64) : [m]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par :> [m]f64 let init = flatten B ++ flatten U - let par = rmsprop 10000 0.999 init df + let par = rmsprop iters 0.999 init df let (Bf, Uf ) = unpack par in (Bf, Uf, loss X' Bf Uf) -entry pcaWithQuantile [n][d] (q: f64) (k: i64) (X: [n][d]f64) : ([n][k]f64, [d][k]f64, [n][d]f64, f64) = +entry pcaWithQuantile [n][d] (iters: i32) (q: f64) (k: i64) (X: [n][d]f64) : ([n][k]f64, [d][k]f64, [n][d]f64, f64) = let qs = map (quantile q) X let Y = tabulate_2d n (d+1) (\i j -> if j == d then qs[i] else X[i][j]) - let (B, U, l) = pca (1+k) Y + let (B, U, l) = pca iters (1+k) Y let Y' = embed B U in (map (\x -> x[0:k]) B, (transpose U[0:k])[0:d], map (\x -> x[0:d]) Y', l) -- cgit v1.2.3