1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
{-# LANGUAGE ViewPatterns #-}
module Main where
import Data.Fixed (mod')
import Data.Foldable (toList)
import Data.List
import qualified Data.Map as M
import qualified Data.Vector.Unboxed as V
import Numeric.Log hiding (sum)
import PPL
import System.Random (mkStdGen, setStdGen)
-- | Running totals of a list: element @i@ of the result is the sum of the
-- first @i + 1@ inputs. The result is produced lazily, so infinite lists are
-- fine, and the empty list maps to the empty list.
--
-- The explicit signature keeps the function polymorphic; without it the
-- monomorphism restriction fixes a single element type for the whole module.
cumsum :: Num a => [a] -> [a]
cumsum = scanl1 (+)
-- | Zero-based position of the first element of @xs@ that satisfies @f@.
--
-- Only the prefix up to the first match is inspected, so infinite lists are
-- fine as long as a match exists. If no element of a finite list matches, the
-- call fails with a descriptive error rather than the anonymous exception
-- raised by @head@ on an empty list.
first :: (Num i, Enum i) => (a -> Bool) -> [a] -> i
first f xs =
  case [i | (x, i) <- zip xs [0 ..], f x] of
    i : _ -> i
    [] -> error "first: no element satisfies the predicate"
-- | Stirling-style approximation to @log (n!)@, using the two leading terms
-- @n * log n - n@.
--
-- At @n == 0@ the raw formula evaluates @0 * log 0@, which is NaN in IEEE
-- arithmetic. The limit of the expression there is 0, which is also the exact
-- value of @log (0!)@, so that case is answered directly.
stirling :: (Eq a, Floating a) => a -> a
stirling 0 = 0
stirling n = n * log n - n
-- | Poisson probability of observing @count@ events at rate @lambda@, built
-- from log-domain ('Exp') factors. The factorial in the denominator is
-- replaced by the 'stirling' approximation, so the result is approximate.
pois lambda count = rate ** power * Exp (negate lambda) / Exp (stirling k)
  where
    -- The count arrives as an integral value; convert it once for the
    -- floating-point arithmetic below.
    k = fromIntegral count
    -- Rate and exponent wrapped with 'Exp' so that the power is taken on
    -- log-domain values.
    rate = Exp (log lambda)
    power = Exp (log k)
-- | Binary tree carrying a value at every node. The fields are lazy, which is
-- what allows the infinite trees built elsewhere in this file.
data Tree a
  = -- | No node at this position.
    Empty
  | -- | A value together with its left and right subtrees.
    Tree a (Tree a) (Tree a)
  deriving (Show)
-- | Folds visit values breadth first: the root, then each level from left to
-- right. 'Empty' subtrees contribute nothing.
instance Foldable Tree where
  foldMap f root = walk [root]
    where
      -- The list is a FIFO queue of subtrees still waiting to be visited;
      -- children are appended at the back as their parent is consumed.
      walk queue = case queue of
        [] -> mempty
        Empty : rest -> walk rest
        Tree a l r : rest -> f a <> walk (rest ++ [l, r])
-- | Build a tree from a list: the head becomes the root and the remaining
-- elements are dealt alternately to the left and right subtrees, recursively.
-- An infinite list yields an infinite tree; where a list runs out the result
-- is 'Empty' instead of a pattern-match failure.
--
-- NB: it is harder to partition the list so that folding the tree gives the
-- original list back. That does not really matter here, since the input is
-- only a stream of random uniforms anyway.
treeFromList :: [a] -> Tree a
treeFromList [] = Empty
treeFromList (x : xs) = Tree x (treeFromList lefts) (treeFromList rights)
  where
    (lefts, rights) = deal xs
    -- Alternate elements between the two halves. The recursive result is bound
    -- lazily so that an infinite input is only consumed on demand.
    deal (a : b : rest) = (a : ls, b : rs)
      where
        (ls, rs) = deal rest
    -- Zero or one element left over: it goes to the left half.
    deal leftover = (leftover, [])
-- | Rescale a tree top-down so that, at every node, the values of its two
-- children add up to the node's own (already rescaled) value. The root value
-- is kept and children keep their relative proportions.
--
-- A node with a single child passes its whole value to that child, and leaves
-- and 'Empty' are returned unchanged, so finite trees no longer fail with a
-- pattern-match error.
--
-- NOTE: if the two child values of a node sum to zero, the scale factor
-- divides by zero and that subtree fills with non-finite values.
normTree :: Tree Double -> Tree Double
normTree Empty = Empty
normTree (Tree x (Tree u l r) (Tree v l' r')) =
  Tree x (normTree (Tree (s * u) l r)) (normTree (Tree (s * v) l' r'))
  where
    -- Factor that makes the two rescaled children sum to x.
    s = x / (u + v)
normTree (Tree x (Tree _ l r) Empty) = Tree x (normTree (Tree x l r)) Empty
normTree (Tree x Empty (Tree _ l r)) = Tree x Empty (normTree (Tree x l r))
normTree leaf = leaf
-- | Take the values produced by @iid uniform@, arrange them into a tree,
-- normalise the tree, and flatten it back to a list in the tree's fold order.
--
-- NOTE(review): @iid uniform@ is presumably an endless stream of independent
-- uniform draws from the PPL library; confirm against that module.
drawTreeProbs = fmap flatten (iid uniform)
  where
    flatten = toList . normTree . treeFromList
-- | Clustering model over the observations in the unboxed vector @xs@, whose
-- elements are 4-tuples @(ac, ad, dc, dd)@.
--
-- Each observation is assigned a cluster index, each cluster index selects a
-- pair of parameters (one from each of two 'drawTreeProbs' streams), and the
-- model is scored with the product of two Poisson-style terms per observation.
-- The result is the parameters of the clusters in use (up to the largest index
-- assigned) together with the per-observation cluster indices.
--
-- NOTE(review): @ac@/@ad@ and @dc@/@dd@ look like count/depth pairs for the
-- two parameter streams; confirm against the input file format.
model xs = do
  (clmuAcinar, clmuDuctal, params, clusters) <- sample $ do
    -- Sample hyperparameters
    a <- bounded 1 10
    -- CRP style: cumulative weights, plus one uniform draw per observation
    -- that selects the first cluster whose cumulative weight reaches it.
    dir <- cumsum <$> dirichletProcess a
    rs <- iid uniform
    clmuAcinar <- drawTreeProbs
    clmuDuctal <- drawTreeProbs
    let clusters = map (\r -> first (>= r) dir) rs
        -- Pair the two parameter streams once rather than re-zipping them for
        -- every lookup.
        pairs = zip clmuAcinar clmuDuctal
        params = map (pairs !!) clusters
    pure (clmuAcinar, clmuDuctal, params, clusters)
  mapM_ scoreLog $ zipWith likelihood params (V.toList xs)
  let cls = take (V.length xs) clusters
      -- Number of cluster slots in use; zero when there are no observations
      -- (maximum would fail on the empty list).
      k = if null cls then 0 else maximum cls + 1
  pure (take k clmuAcinar, take k clmuDuctal, cls)
  where
    -- For each of the two count/total pairs, take the larger of the Poisson
    -- terms at the full rate and at half the rate, then multiply the two.
    likelihood (ap, dp) (ac, ad, dc, dd) =
      max (pois (fromIntegral ad * ap) ac) (pois (fromIntegral ad * ap / 2) ac)
        * max (pois (fromIntegral dd * dp) dc) (pois (fromIntegral dd * dp / 2) dc)
-- | Entry point.
--
-- Seeds the global random generator with 42, reads the counts file, drops its
-- first line, and parses columns 2-5 of every remaining line into a 4-tuple of
-- 'Int's. It then runs @mh 0.3@ on the model, keeps the sample with the
-- smallest 'mml' score among the first 100000 (the later sample wins ties),
-- and writes the two parameter lists to @/data/props@ and the cluster
-- assignments to @/data/clusters@.
main = do
  setStdGen (mkStdGen 42)
  (_header : rows) <- fmap lines (readFile "/nix/store/9xkb2apajv9sy37akz24x3jj6kw5hn7h-bionix-dump-counts")
  let parsed = V.fromList (map (parseRow . words) rows)
  ((a, d, cl), _) <- fmap (foldl1' keepBest . take 100000) (mh 0.3 (model parsed))
  writeFile "/data/props" (unlines (zipWith (\p q -> show p <> "," <> show q) a d))
  writeFile "/data/clusters" (unlines (map show cl))
  where
    -- A row is five whitespace-separated fields; the first is ignored.
    parseRow [_, ac, ad, dc, dd] = (toCount ac, toCount ad, toCount dc, toCount dd)

    -- Fields are read as floating-point numbers and rounded to integers.
    toCount :: String -> Int
    toCount field = round (read field :: Double)

    -- Keep the running best only when its score is strictly smaller.
    keepBest best candidate
      | mml best < mml candidate = best
      | otherwise = candidate

    -- Score of a sample: log (v + 1) summed over both parameter lists and over
    -- the cluster sizes, minus 'stirling' of each cluster size, minus the log
    -- of the sample's likelihood weight.
    mml ((acinar, ductal, assignment), lik) =
      logSucc acinar + logSucc ductal + logSucc sizes - sum (map stirling sizes) - ln lik
      where
        -- How many observations fell into each cluster.
        sizes = M.elems (M.fromListWith (+) [(c, 1) | c <- assignment])
        logSucc vs = sum (map (log . (+ 1)) vs)
|