aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2023-01-16 09:25:40 +1100
committerJustin Bedo <cu@cua0.org>2023-01-16 10:05:54 +1100
commit3668d6b45478ef9e8a48fc469f931b52f6dad9c6 (patch)
tree4bc654316c5dbdd122a1c1f8854287da2e7e91a8
parent7f4bcef26d58934045ba48e1d53e491a319184d0 (diff)
generalise to any number of samples
-rw-r--r--bin/cluster.hs36
1 files changed, 20 insertions, 16 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs
index 3c9dd6c..6bcac29 100644
--- a/bin/cluster.hs
+++ b/bin/cluster.hs
@@ -6,7 +6,6 @@ import Data.Fixed (mod')
import Data.Foldable (toList)
import Data.List
import qualified Data.Map as M
-import qualified Data.Vector.Unboxed as V
import Numeric.Log hiding (sum)
import Options.Applicative
import PPL
@@ -53,27 +52,31 @@ normTree (Tree x l r) = go $ Tree x l r
drawTreeProbs = toList . normTree . treeFromList <$> iid uniform
+model :: [[Int]] -> Meas ([[Double]], [Int])
model xs = do
- (clmuAcinar, clmuDuctal, params, clusters) <- sample $ do
+ (ps, params, clusters) <- sample $ do
-- Sample hyperparameters
a <- bounded 1 10
-- CRP style
dir <- cumsum <$> dirichletProcess a
rs <- iid uniform
- clmuAcinar <- drawTreeProbs
- clmuDuctal <- drawTreeProbs
+ ps <- iid drawTreeProbs
let clusters = map (\r -> first (>= r) dir) rs
- params = map (\i -> zip clmuAcinar clmuDuctal !! i) clusters
- pure (clmuAcinar, clmuDuctal, params, clusters)
- mapM_ scoreLog $ zipWith likelihood params (V.toList xs)
- let cls = take (V.length xs) clusters
+ params = map (transpose ps !!) clusters
+ pure (ps, params, clusters)
+ mapM_ scoreLog $ zipWith likelihood params xs
+ let cls = take (length xs) clusters
k = maximum cls + 1
- pure (take k clmuAcinar, take k clmuDuctal, cls)
+ n = length $ head xs
+ pure (map (take k) $ take n ps, cls)
where
- likelihood (ap, dp) (ac, ad, dc, dd) =
- max (pois (fromIntegral ad * ap) ac) (pois (fromIntegral ad * ap / 2) ac)
- * max (pois (fromIntegral dd * dp) dc) (pois (fromIntegral dd * dp / 2) dc)
+ likelihood ps cnts = product $ zipWith go ps (pairs cnts)
+ where
+ go p (c, d) = max (pois (fromIntegral d * p) c) (pois (fromIntegral d * p / 2) c)
+ pairs (a : b : rs) = (a, b) : pairs rs
+ pairs [] = []
+ pairs _ = error "unexpected number of columns, expecting count/depth pairs"
-- Command line args
data Options = Options
@@ -106,12 +109,13 @@ main = run =<< execParser opts
run opts = do
setStdGen . mkStdGen $ seed opts
(hdr : lines) <- lines <$> readFile (input opts)
- let parsed = V.fromList $ map ((\[_, ac, ad, dc, dd] -> (dbl ac, dbl ad, dbl dc, dbl dd)) . words) lines
+ let parsed = map (map dbl . tail . words) lines
dbl = round . read :: String -> Int
- ((a, d, cl), _) <- foldl1' (\a c -> if mml a < mml c then a else c) . take (nsamples opts) <$> mh (mhfrac opts) (model parsed)
- writeFile (propsPath opts) . unlines $ zipWith (\a b -> show a <> "," <> show b) a d
+ ((ps, cl), _) <- foldl1' (\a c -> if mml a < mml c then a else c) . take (nsamples opts) <$> mh (mhfrac opts) (model parsed)
+ writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps
writeFile (clusterPath opts) . unlines $ map show cl
where
- mml ((a, d, cl), lik) = sum (map (log . (+ 1)) a) + sum (map (log . (+ 1)) d) + sum (map (log . (+ 1)) tab) - sum (map stirling tab) - ln lik
+ mml ((ps, cl), lik) = sum' (sum' (log . (+ 1))) ps + sum' (log . (+ 1)) tab - sum' stirling tab - ln lik
where
tab = M.elems $ M.fromListWith (+) [(c, 1) | c <- cl]
+ sum' f = sum . map f