diff options
Diffstat (limited to 'bin/cluster.hs')
-rw-r--r-- | bin/cluster.hs | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs index 3c9dd6c..6bcac29 100644 --- a/bin/cluster.hs +++ b/bin/cluster.hs @@ -6,7 +6,6 @@ import Data.Fixed (mod') import Data.Foldable (toList) import Data.List import qualified Data.Map as M -import qualified Data.Vector.Unboxed as V import Numeric.Log hiding (sum) import Options.Applicative import PPL @@ -53,27 +52,31 @@ normTree (Tree x l r) = go $ Tree x l r drawTreeProbs = toList . normTree . treeFromList <$> iid uniform +model :: [[Int]] -> Meas ([[Double]], [Int]) model xs = do - (clmuAcinar, clmuDuctal, params, clusters) <- sample $ do + (ps, params, clusters) <- sample $ do -- Sample hyperparameters a <- bounded 1 10 -- CRP style dir <- cumsum <$> dirichletProcess a rs <- iid uniform - clmuAcinar <- drawTreeProbs - clmuDuctal <- drawTreeProbs + ps <- iid drawTreeProbs let clusters = map (\r -> first (>= r) dir) rs - params = map (\i -> zip clmuAcinar clmuDuctal !! i) clusters - pure (clmuAcinar, clmuDuctal, params, clusters) - mapM_ scoreLog $ zipWith likelihood params (V.toList xs) - let cls = take (V.length xs) clusters + params = map (transpose ps !!) clusters + pure (ps, params, clusters) + mapM_ scoreLog $ zipWith likelihood params xs + let cls = take (length xs) clusters k = maximum cls + 1 - pure (take k clmuAcinar, take k clmuDuctal, cls) + n = length $ head xs + pure (map (take k) $ take n ps, cls) where - likelihood (ap, dp) (ac, ad, dc, dd) = - max (pois (fromIntegral ad * ap) ac) (pois (fromIntegral ad * ap / 2) ac) - * max (pois (fromIntegral dd * dp) dc) (pois (fromIntegral dd * dp / 2) dc) + likelihood ps cnts = product $ zipWith go ps (pairs cnts) + where + go p (c, d) = max (pois (fromIntegral d * p) c) (pois (fromIntegral d * p / 2) c) + pairs (a : b : rs) = (a, b) : pairs rs + pairs [] = [] + pairs _ = error "unexpected number of columns, expecting count/depth pairs" -- Command line args data Options = Options @@ -106,12 +109,13 @@ main = run =<< execParser opts run opts = do setStdGen . mkStdGen $ seed opts (hdr : lines) <- lines <$> readFile (input opts) - let parsed = V.fromList $ map ((\[_, ac, ad, dc, dd] -> (dbl ac, dbl ad, dbl dc, dbl dd)) . words) lines + let parsed = map (map dbl . tail . words) lines dbl = round . read :: String -> Int - ((a, d, cl), _) <- foldl1' (\a c -> if mml a < mml c then a else c) . take (nsamples opts) <$> mh (mhfrac opts) (model parsed) - writeFile (propsPath opts) . unlines $ zipWith (\a b -> show a <> "," <> show b) a d + ((ps, cl), _) <- foldl1' (\a c -> if mml a < mml c then a else c) . take (nsamples opts) <$> mh (mhfrac opts) (model parsed) + writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps writeFile (clusterPath opts) . unlines $ map show cl where - mml ((a, d, cl), lik) = sum (map (log . (+ 1)) a) + sum (map (log . (+ 1)) d) + sum (map (log . (+ 1)) tab) - sum (map stirling tab) - ln lik + mml ((ps, cl), lik) = sum' (sum' (log . (+ 1))) ps + sum' (log . (+ 1)) tab - sum' stirling tab - ln lik where tab = M.elems $ M.fromListWith (+) [(c, 1) | c <- cl] + sum' f = sum . map f |