{-# LANGUAGE ViewPatterns #-}

-- | Fits a clustering model to per-site count data from two samples
-- (named "Acinar" and "Ductal" in the code) using the PPL library, then
-- writes the best-scoring sample's per-cluster values and the cluster
-- assignments to disk.
module Main where

import Data.Fixed (mod')
import Data.Foldable (toList)
import Data.List
import qualified Data.Map as M
import qualified Data.Vector.Unboxed as V
import Numeric.Log hiding (sum)
import PPL
import System.Random (mkStdGen, setStdGen)

-- | Running (prefix) sums: @cumsum [a, b, c] == [a, a + b, a + b + c]@.
cumsum :: Num a => [a] -> [a]
cumsum = scanl1 (+)

-- | Zero-based index of the first element satisfying the predicate.
--
-- Raises an error if a finite list has no such element; on an infinite
-- list with no match it never returns.
first :: (Num i, Enum i) => (a -> Bool) -> [a] -> i
first p xs = case [i | (x, i) <- zip xs [0 ..], p x] of
  (i : _) -> i
  [] -> error "first: no element satisfies the predicate"

-- | Stirling's approximation to @log (n!)@, namely @n log n - n@.
--
-- Defined as 0 for @n <= 0@: the formula would evaluate @0 * log 0@, which
-- is NaN in floating point, whereas its limit (and @log 0!@) is 0.
stirling :: (Ord a, Floating a) => a -> a
stirling n
  | n <= 0 = 0
  | otherwise = n * log n - n

-- | Poisson probability of @k@ events at rate @lambda@, in the log domain:
-- @lambda^k * exp (-lambda) / k!@, with @log k!@ approximated by 'stirling'.
--
-- @k == 0@ is handled explicitly so that a zero count yields
-- @exp (-lambda)@ (including for @lambda == 0@) instead of NaN.
pois :: (Ord a, Floating a, Integral i) => a -> i -> Log a
pois lambda (fromIntegral -> k)
  | k == 0 = Exp (negate lambda)
  | otherwise = Exp (k * log lambda - lambda - stirling k)

-- | Binary trees. Fields are deliberately lazy: the trees built in this
-- program are infinite.
data Tree a = Empty | Tree a (Tree a) (Tree a)
  deriving (Show)

-- | Breadth-first (level-order) fold, so 'toList' on an infinite tree still
-- reaches every node. Each level is a finite list, which avoids the
-- quadratic cost of repeatedly appending to the end of a list-based queue.
instance Foldable Tree where
  foldMap f t = levels [t]
    where
      levels [] = mempty
      levels level = foldMap visit level <> levels (concatMap children level)
      visit Empty = mempty
      visit (Tree a _ _) = f a
      children Empty = []
      children (Tree _ l r) = [l, r]

-- | Builds a tree from a list: the head becomes the root, and the remaining
-- elements alternate between the left and right subtrees.
--
-- NB: it's harder to partition a list so that it folds back to
-- equivalence. It doesn't really matter here since we're only
-- unfolding random uniforms anyway.
--
-- Infinite input gives an infinite tree; for finite input, an empty list
-- gives 'Empty' and a trailing unpaired element is dropped.
treeFromList :: [a] -> Tree a
treeFromList [] = Empty
treeFromList (x : xs) = Tree x (treeFromList lefts) (treeFromList rights)
  where
    (lefts, rights) = unzip (pairUp xs)
    pairUp (a : b : rest) = (a, b) : pairUp rest
    pairUp _ = []

-- | Rescales every pair of siblings so that they sum to their parent's
-- value; the root keeps its own value. A node without two non-empty
-- children is returned unchanged.
normTree :: Tree Double -> Tree Double
normTree (Tree x (Tree u l r) (Tree v l' r')) =
  let s = x / (u + v)
   in Tree x (normTree (Tree (s * u) l r)) (normTree (Tree (s * v) l' r'))
normTree t = t

-- | Breadth-first flattening of a normalised tree built from the draws of
-- @iid uniform@ (presumably an infinite stream of uniforms from PPL —
-- TODO confirm).
drawTreeProbs = toList . normTree . treeFromList <$> iid uniform

-- | The model. Samples once:
--
--   * a hyperparameter @a <- bounded 1 10@ passed to 'dirichletProcess';
--     the cumulative weights are used CRP style to map uniform draws to
--     cluster indices;
--   * two flattened trees of per-cluster values ('drawTreeProbs').
--
-- Each observation @(ac, ad, dc, dd)@ of @xs@ is scored with @likelihood@.
-- Returns the per-cluster values up to the largest cluster index used,
-- together with the cluster index of each observation.
model xs = do
  (clmuAcinar, clmuDuctal, params, clusters) <- sample $ do
    -- Sample hyperparameters
    a <- bounded 1 10
    -- CRP style
    dir <- cumsum <$> dirichletProcess a
    rs <- iid uniform
    clmuAcinar <- drawTreeProbs
    clmuDuctal <- drawTreeProbs
    -- Each draw r picks the first index whose cumulative weight reaches r;
    -- the value pairs are zipped once rather than once per lookup.
    let clusters = map (\r -> first (>= r) dir) rs
        paired = zip clmuAcinar clmuDuctal
        params = map (paired !!) clusters
    pure (clmuAcinar, clmuDuctal, params, clusters)
  mapM_ scoreLog $ zipWith likelihood params (V.toList xs)
  let cls = take (V.length xs) clusters
      -- Number of cluster slots up to the largest index used; 0 when there
      -- is no data (where 'maximum' would otherwise throw).
      k = if null cls then 0 else maximum cls + 1
  pure (take k clmuAcinar, take k clmuDuctal, cls)
  where
    -- For each of the two samples, the larger of the Poisson probabilities
    -- at rate depth * p and depth * p / 2. Presumably (ac, ad) and (dc, dd)
    -- are (count, depth) for each sample — TODO confirm against the data.
    likelihood (ap, dp) (ac, ad, dc, dd) =
      max (pois (fromIntegral ad * ap) ac) (pois (fromIntegral ad * ap / 2) ac)
        * max (pois (fromIntegral dd * dp) dc) (pois (fromIntegral dd * dp / 2) dc)

-- | Reads the count table, runs the sampler from a fixed seed, keeps the
-- lowest-'mml' sample among the first 100000, and writes its per-cluster
-- values and cluster assignments.
main = do
  setStdGen $ mkStdGen 42
  -- The first line is a header and is skipped.
  (_hdr : rows) <-
    lines <$> readFile "/nix/store/9xkb2apajv9sy37akz24x3jj6kw5hn7h-bionix-dump-counts"
  let parsed = V.fromList $ map (parseRow . words) rows
  -- Score each sample once, then keep the best one.
  ((a, d, cl), _) <-
    snd . foldl1' keepLower . map (\s -> (mml s, s)) . take 100000
      <$> mh 0.3 (model parsed)
  writeFile "/data/props" . unlines $ zipWith (\p q -> show p <> "," <> show q) a d
  writeFile "/data/clusters" . unlines $ map show cl
  where
    -- A row is "<ignored> ac ad dc dd"; each number is read as a Double and
    -- rounded to Int.
    parseRow [_, ac, ad, dc, dd] = (dbl ac, dbl ad, dbl dc, dbl dd)
    parseRow row = error $ "malformed row: " <> unwords row
    dbl = round . read :: String -> Int
    -- Keep the earlier candidate only if its score is strictly lower, so
    -- ties (and NaN scores) resolve to the later sample.
    keepLower best cand = if fst best < fst cand then best else cand
    -- Selection score (lower wins): log (1 + x) summed over both value lists
    -- and over the cluster sizes, minus 'stirling' of each cluster size,
    -- minus the sample's log-likelihood.
    mml ((a, d, cl), lik) =
      sum (map (log . (+ 1)) a)
        + sum (map (log . (+ 1)) d)
        + sum (map (log . (+ 1)) tab)
        - sum (map stirling tab)
        - ln lik
      where
        tab = M.elems $ M.fromListWith (+) [(c, 1) | c <- cl]