summaryrefslogtreecommitdiff
path: root/pca.fut
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2024-06-20 17:08:56 +1000
committerJustin Bedo <cu@cua0.org>2024-06-20 17:09:58 +1000
commit6d3684d1b32087b385bebce9ea0fa22cb522ab21 (patch)
tree58a5a803524108fa74a9d97483b763ff09fea6bc /pca.fut
parent56d3507d774357da1281fa3b0c49ed4e0466800d (diff)
update futhark
Diffstat (limited to 'pca.fut')
-rw-r--r--pca.fut6
1 files changed, 3 insertions, 3 deletions
diff --git a/pca.fut b/pca.fut
index 29df69c..9d1fd22 100644
--- a/pca.fut
+++ b/pca.fut
@@ -59,9 +59,9 @@ def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64
let U = tabulate_2d k d (\i j -> 1e-3 * (fp (0.5 + 1.7548 * f64.from_fraction (i * k + j) 1) - 0.5))
let pack [n][d][k] (B: [n][k]f64) (U: [k][d]f64) : []f64 = flatten B ++ flatten U
let unpack xs =
- let (a, b) = split (n*k) xs
- in (unflatten n k a, unflatten k d b)
- let df [m] (par: [m]f64) : [m]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par :> [m]f64
+ let (a, b) = split xs
+ in (unflatten a, unflatten b)
+ let df (par: [n*k+k*d]f64) : [n*k+k*d]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par
let init = flatten B ++ flatten U
let par = rmsprop iters 0.999 init df
let (Bf, Uf ) = unpack par