diff options
author | Justin Bedo <cu@cua0.org> | 2023-02-09 10:09:47 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2023-02-15 15:13:06 +1100 |
commit | 4ecc8d1b42cb2c4992fe227719ed4662ca1c8168 (patch) | |
tree | 32b91431a4a09793d757a079a3c097c5fdc721fb /pca.fut | |
parent | 42edc4f3b5dcd460aef75af536ee541c975b0e8b (diff) |
return log-space embedding rather than count space
Diffstat (limited to 'pca.fut')
-rw-r--r-- | pca.fut | 4 |
1 files changed, 2 insertions, 2 deletions
@@ -22,10 +22,10 @@ def psi [d] (x: [d]f64) : f64 = sum (map ((*x[d-1]) <-< phi <-< (/(f64.epsilon+x def embed [n][d][k] (B: [n][k]f64) (U: [k][d]f64) : [n][d]f64 = let C = tabulate_2d k k (\i j -> (if i == j then 1f64 else 0f64) - f64.from_fraction 1 k) - in (map (map f64.exp) (matmul_f64 (matmul_f64 (unitfstT (copy B)) C) U)) + in (matmul_f64 (matmul_f64 (unitfstT (copy B)) C) U) def loss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : f64 = - let Y = embed B U + let Y = map (map f64.exp) (embed B U) in sum (map2 (\x y -> - psi y - dot (map2 (-) x y) (vjp psi y 1)) X Y) def dloss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : ([n][k]f64, [k][d]f64) = vjp (uncurry (loss X)) (B, U) 1 |