summaryrefslogtreecommitdiff
path: root/pca.fut
blob: 8ffca599db29dab2c775e5e873f06c3eff9fa085 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
-- basics
-- Generic matrix product, parameterised by the element addition `add`,
-- the element multiplication `mul` and the starting value `zero`.
-- Entry (i, j) of the result is `reduce add zero` over the pairwise `mul`
-- of row i of `A` and column j of `B`.
def matmul [n][m][p] 'a
           (add: a -> a -> a) (mul: a -> a -> a) (zero: a)
           (A: [n][m]a) (B: [m][p]a) : [n][p]a =
  -- Columns of B laid out as rows, so they can be paired with rows of A.
  let cols = transpose B
  -- Generalised inner product of two length-m vectors.
  let inner (u: [m]a) (v: [m]a) : a = reduce add zero (map2 mul u v)
  in map (\row -> map (inner row) cols) A
-- Matrix product on `f64` using ordinary `+`, `*` and `0`.
def matmul_f64 [n][m][p] (A: [n][m]f64) (B: [m][p]f64) : [n][p]f64 =
  matmul (+) (*) 0f64 A B
-- Sum of all elements of an `f64` vector (0 for an empty vector).
def sum [n] (xs: [n]f64) : f64 = reduce (+) 0f64 xs
-- Inner product of two `f64` vectors of equal length.
def dot [n] (u: [n]f64) (v: [n]f64) : f64 =
  map2 (*) u v |> sum

-- Overwrites the last row of `X` with ones and returns the updated matrix.
-- `X` is consumed by the update.
def unitlst [n][m] (X: *[n][m]f64): *[n][m]f64 =
  let ones = replicate m 1f64
  in X with [n-1] = ones
-- Overwrites the last column of `X` with ones: transpose, set the last row
-- via `unitlst`, transpose back. `X` is consumed.
def unitlstT [n][m] (X: *[n][m]f64): *[n][m]f64 =
  let Xt = transpose X
  in transpose (unitlst Xt)

-- Bregman generators
-- Scalar generator: z * log z - z.
def phi (z: f64) : f64 =
  let lz = f64.log z
  in z * lz - z
-- Vector generator built from `phi`: with `last = x[d-1]`, returns the sum
-- over all components x_i of `last * phi (x_i / (f64.epsilon + last))`.
-- The `f64.epsilon` term keeps the divisor away from exactly zero.
-- Indexes `x[d-1]`, so `x` must be non-empty.
def psi [d] (x: [d]f64) : f64 =
  let last = x[d-1]
  let scale = f64.epsilon + last
  in sum (map (\xi -> phi (xi / scale) * last) x)

-- Computes B1 * C * U, where B1 is a copy of `B` with its last column set
-- to one and C is the k x k matrix with `1 - 1/k` on the diagonal and
-- `-1/k` elsewhere.
def embed [n][d][k] (B: [n][k]f64) (U: [k][d]f64) : [n][d]f64 =
  let inv_k = f64.from_fraction 1 k
  let C = tabulate_2d k k (\r c -> (if r == c then 1f64 else 0f64) - inv_k)
  -- `unitlstT` consumes its argument, hence the copy of B.
  let B1 = unitlstT (copy B)
  let BC = matmul_f64 B1 C
  in matmul_f64 BC U

-- Objective for the factorisation (B, U) against the data `X`.
-- Y is the elementwise exponential of `embed B U`; the result is the sum
-- over rows (x, y) of `-psi y - <x - y, grad psi (y)>`, with the gradient
-- obtained by reverse-mode AD (`vjp`).
-- NOTE(review): presumably the Bregman divergence generated by `psi`, minus
-- the `psi x` term that does not depend on B or U -- confirm.
def loss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : f64 =
  let Y = map (map f64.exp) (embed B U)
  let row_term (x: [d]f64) (y: [d]f64) : f64 =
    let grad = vjp psi y 1
    let resid = map2 (-) x y
    in - psi y - dot resid grad
  in sum (map2 row_term X Y)

-- Gradient of `loss X` with respect to the pair (B, U), computed with
-- reverse-mode AD and a unit seed.
def dloss [n][d][k] (X: [n][d]f64) (B: [n][k]f64) (U: [k][d]f64) : ([n][k]f64, [k][d]f64) =
  let objective (b: [n][k]f64, u: [k][d]f64) : f64 = loss X b u
  in vjp objective (B, U) 1f64

-- Approximate quantiles
-- Approximate `q`-quantile of `xs` by repeated pivoting (quickselect-like).
--
-- The target rank is `floor (q * n)`. Each iteration takes the head of the
-- remaining elements as pivot, partitions the rest into `<= pivot` and
-- `> pivot`, and continues in the side that contains the target rank. The
-- loop stops once at most one element remains and the most recent pivot is
-- returned, which is why the result is only approximate.
--
-- Returns NaN for an empty input. For a single-element input the loop never
-- runs, so the estimate is seeded with that element instead of NaN.
def quantile [n] (q: f64) (xs: [n]f64) : f64 =
  let start = if n == 0 then f64.nan else xs[0]
  let (p, _, _) = loop (_, xs, idx) = (start, xs, f64.to_i64 (q*f64.from_fraction n 1)) while length xs > 1 do
    let pivot = head xs
    let (left, right) = partition (<=pivot) (tail xs)
    -- Number of elements (pivot included) that are not greater than the pivot.
    let m = 1+length left
    in if m > idx
      then (pivot, left, idx)
      else (pivot, right, idx - m)
  in p

-- Optimiser
-- RMSProp-style minimiser.
--
-- iters: maximum number of update steps.
-- beta:  decay factor of the running average of squared gradients.
-- par:   starting parameters.
-- df:    gradient of the objective.
-- f:     objective, used only for the stopping test.
--
-- The loop state is (remaining steps, current loss, previous loss,
-- parameters, running average of squared gradients). Iteration stops when
-- the loss improved by at most `f64.epsilon` in the last step or the step
-- budget is used up; the most recent parameters are returned.
def rmsprop [n] (iters: i32) (beta: f64) (par: [n]f64) (df: [n]f64 -> [n]f64) (f: [n]f64 -> f64): [n]f64 =
  -- Fixed step size.
  let lr = 1e-3
  let step (i, l, _, par, s) =
    let d = df par
    let s' = map2 (\s x -> beta * s + (1-beta)*x**2) s d
    let par' = map3 (\x s d -> x - lr * d / (f64.epsilon + f64.sqrt s)) par s' d
    in (i - 1, f par', l, par', s')
  -- The initial "current loss" is evaluated at the starting point `par`, so
  -- the first convergence test compares the first step against where the
  -- optimisation actually started. The previous loss starts at `f64.highest`
  -- so that at least one step is attempted when `iters > 0`.
  let (_, _, _, par, _) = iterate_until (\(i, l',l,_, _) -> l-l' <= f64.epsilon || i <= 0) step (iters, f par, f64.highest, par, replicate n 0)
  in par

-- Fits factors (B, U) with `k` components to `X` by minimising `loss` with
-- `rmsprop`, and returns them together with the final loss value.
--
-- iters: maximum number of optimiser steps.
-- k:     number of columns of B / rows of U.
-- X:     data matrix; it is divided by the sum of all its entries first
--        (the result is not finite if that sum is zero).
def pca [n][d] (iters: i32) (k: i64) (X: [n][d]f64) : ([n][k]f64, [k][d]f64, f64) =
  let mX = sum (flatten X)
  let X' = map (map (/ mX)) X
  -- Fractional part of x.
  let fp x = x - f64.floor x
  -- Deterministic small initial values in [-0.5e-3, 0.5e-3), generated from
  -- the row-major index of each entry. The last column of B is set to 1.
  let B = tabulate_2d n k (\i j -> if j == k-1 then 1 else 1e-3 * (fp (0.5 + 1.3247 * f64.from_fraction (i * k + j) 1) - 0.5))
  -- U is k x d, so its row-major index is i * d + j; using the row length of
  -- B here would give distinct entries the same value whenever d > k.
  let U = tabulate_2d k d (\i j -> 1e-3 * (fp (0.5 + 1.7548 * f64.from_fraction (i * d + j) 1) - 0.5))
  -- The optimiser works on one flat vector: B flattened, followed by U.
  let pack [n][d][k] (B: [n][k]f64) (U: [k][d]f64) : []f64 = flatten B ++ flatten U
  let unpack xs =
    let (a, b) = split xs
    in (unflatten a, unflatten b)
  let df (par: [n*k+k*d]f64) : [n*k+k*d]f64 = (uncurry pack <-< uncurry (dloss X') <-< unpack) par
  let f (par: [n*k+k*d]f64) : f64 = (uncurry (loss X') <-< unpack) par
  let init = flatten B ++ flatten U
  let par = rmsprop iters 0.999 init df f
  let (Bf, Uf ) = unpack par
  in (Bf, Uf, loss X' Bf Uf)

-- Entry point: runs `pca` on `X` extended by one extra column that holds the
-- approximate `q`-quantile of each row, using `k + 1` components.
--
-- Returns, in order:
--   * the first k columns of B,
--   * the first k rows of U, transposed and cut to the first d rows,
--   * `embed B U` with its last column subtracted from each of the first d
--     columns,
--   * the final loss reported by `pca`.
entry pcaWithQuantile [n][d] (iters: i32) (q: f64) (k: i64) (X: [n][d]f64) : ([n][k]f64, [d][k]f64, [n][d]f64, f64) =
  let qs = map (quantile q) X
  -- Column d of Y is the per-row quantile; columns 0..d-1 are X itself.
  let Y = tabulate_2d n (d+1) (\i j -> if j == d then qs[i] else X[i][j])
  let (B, U, l) = pca iters (1+k) Y
  let Y' = embed B U
  let B_k = map (\b -> b[0:k]) B
  let Ut_k = (transpose U[0:k])[0:d]
  let shifted = map (\y -> let off = y[d] in map (\v -> v - off) y[0:d]) Y'
  in (B_k, Ut_k, shifted, l)