aboutsummaryrefslogtreecommitdiff
path: root/bin/cluster.hs
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2022-12-16 10:57:55 +1100
committerJustin Bedo <cu@cua0.org>2023-01-16 09:02:49 +1100
commita09d5c9d9f7097dd4baf4e9611e488ee1d12ca2f (patch)
tree5fbabdaa8fbc3929dddacf2ddc19eded9a0c692d /bin/cluster.hs
init
Diffstat (limited to 'bin/cluster.hs')
-rw-r--r--bin/cluster.hs88
1 files changed, 88 insertions, 0 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs
new file mode 100644
index 0000000..70e22c7
--- /dev/null
+++ b/bin/cluster.hs
@@ -0,0 +1,88 @@
+{-# LANGUAGE ViewPatterns #-}
+
+module Main where
+
+import Data.Fixed (mod')
+import Data.Foldable (toList)
+import Data.List
+import qualified Data.Map as M
+import qualified Data.Vector.Unboxed as V
+import Numeric.Log hiding (sum)
+import PPL
+import System.Random (mkStdGen, setStdGen)
+
+cumsum = scanl1 (+)
+
+first f xs = snd . head . filter (f . fst) $ zip xs [0 ..]
+
+stirling n = n * log n - n
+
+pois lambda (fromIntegral -> k) = lambda' ** k' * Exp (negate lambda) / Exp (stirling k)
+ where
+ lambda' = Exp $ log lambda
+ k' = Exp $ log k
+
+-- (infinite) binary trees
+data Tree a = Empty | Tree a (Tree a) (Tree a)
+ deriving (Show)
+
+instance Foldable Tree where
+ foldMap f t = bftrav [t]
+ where
+ bftrav [] = mempty
+ bftrav (Empty : ts) = bftrav ts
+ bftrav ((Tree a l r) : ts) = f a <> bftrav (ts <> [l, r])
+
+-- Infinite trees from infinite lists
+-- NB: it's harder to partition a list so that it folds back to
+-- equivalence. It doesn't really matter here since we're only
+-- unfolding random uniforms anyway.
+treeFromList (x : xs) = Tree x (treeFromList lpart) (treeFromList rpart)
+ where
+ (lpart, rpart) = unzip $ partition xs
+ partition (a : b : xs) = (a, b) : partition xs
+
+-- Constrain trees so leaves sum to node value
+normTree :: Tree Double -> Tree Double
+normTree (Tree x l r) = go $ Tree x l r
+ where
+ go (Tree x (Tree u l r) (Tree v l' r')) =
+ let s = x / (u + v)
+ in Tree x (go $ Tree (s * u) l r) (go $ Tree (s * v) l' r')
+
+drawTreeProbs = toList . normTree . treeFromList <$> iid uniform
+
+model xs = do
+ (clmuAcinar, clmuDuctal, params, clusters) <- sample $ do
+ -- Sample hyperparameters
+ a <- bounded 1 10
+
+ -- CRP style
+ dir <- cumsum <$> dirichletProcess a
+ rs <- iid uniform
+ clmuAcinar <- drawTreeProbs
+ clmuDuctal <- drawTreeProbs
+ let clusters = map (\r -> first (>= r) dir) rs
+ params = map (\i -> zip clmuAcinar clmuDuctal !! i) clusters
+ pure (clmuAcinar, clmuDuctal, params, clusters)
+ mapM_ scoreLog $ zipWith likelihood params (V.toList xs)
+ let cls = take (V.length xs) clusters
+ k = maximum cls + 1
+ pure (take k clmuAcinar, take k clmuDuctal, cls)
+ where
+ likelihood (ap, dp) (ac, ad, dc, dd) =
+ max (pois (fromIntegral ad * ap) ac) (pois (fromIntegral ad * ap / 2) ac)
+ * max (pois (fromIntegral dd * dp) dc) (pois (fromIntegral dd * dp / 2) dc)
+
+main = do
+ setStdGen $ mkStdGen 42
+ (hdr : lines) <- lines <$> readFile "/nix/store/9xkb2apajv9sy37akz24x3jj6kw5hn7h-bionix-dump-counts"
+ let parsed = V.fromList $ map ((\[_, ac, ad, dc, dd] -> (dbl ac, dbl ad, dbl dc, dbl dd)) . words) lines
+ dbl = round . read :: String -> Int
+ ((a, d, cl), _) <- foldl1' (\a c -> if mml a < mml c then a else c) . take 100000 <$> mh 0.3 (model parsed)
+ writeFile "/data/props" . unlines $ zipWith (\a b -> show a <> "," <> show b) a d
+ writeFile "/data/clusters" . unlines $ map show cl
+ where
+ mml ((a, d, cl), lik) = sum (map (log . (+ 1)) a) + sum (map (log . (+ 1)) d) + sum (map (log . (+ 1)) tab) - sum (map stirling tab) - ln lik
+ where
+ tab = M.elems $ M.fromListWith (+) [(c, 1) | c <- cl]