From 6d3684d1b32087b385bebce9ea0fa22cb522ab21 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Thu, 20 Jun 2024 17:08:56 +1000 Subject: update futhark --- pca.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'pca.fut') diff --git a/pca.fut b/pca.fut index 29df69c..9d1fd22 100644 --- a/pca.fut +++ b/pca.fut @@ -59,9 +59,9 @@ def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64 let U = tabulate_2d k d (\i j -> 1e-3 * (fp (0.5 + 1.7548 * f64.from_fraction (i * k + j) 1) - 0.5)) let pack [n][d][k] (B: [n][k]f64) (U: [k][d]f64) : []f64 = flatten B ++ flatten U let unpack xs = - let (a, b) = split (n*k) xs - in (unflatten n k a, unflatten k d b) - let df [m] (par: [m]f64) : [m]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par :> [m]f64 + let (a, b) = split xs + in (unflatten a, unflatten b) + let df (par: [n*k+k*d]f64) : [n*k+k*d]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par let init = flatten B ++ flatten U let par = rmsprop iters 0.999 init df let (Bf, Uf ) = unpack par -- cgit v1.2.3