diff options
-rw-r--r-- | bin/cluster.hs | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs index ba9e9f1..e78e261 100644 --- a/bin/cluster.hs +++ b/bin/cluster.hs @@ -6,7 +6,7 @@ module Main where import Control.Monad import Data.Fixed (mod') import Data.Foldable (toList) -import Data.List +import Data.List hiding (group) import qualified Data.Map as M import GHC.Compact import Numeric.Log hiding (sum) @@ -47,10 +47,10 @@ instance Foldable Tree where bftrav [] = mempty bftrav ((Tree a l r) : ts) = f a <> bftrav (ts <> [l, r]) -{-# INLINE pairs #-} -pairs (a : b : rs) = (a, b) : pairs rs -pairs [] = [] -pairs _ = error "unexpected number of columns, expecting pairs" +{-# INLINE group #-} +group (a : b : c: d: rs) = (a, b, c, d) : group rs +group [] = [] +group s = error $ "unexpected number of columns, expecting 4:" <> show s -- Infinite trees from infinite lists -- NB: it's harder to partition a list so that it folds back to @@ -59,6 +59,7 @@ pairs _ = error "unexpected number of columns, expecting pairs" treeFromList (x : xs) = Tree x (treeFromList lpart) (treeFromList rpart) where (lpart, rpart) = unzip $ pairs xs + pairs (a:b:rs) = (a, b) : pairs rs -- Constrain trees so leaves sum to node value normTree :: Tree Double -> Tree Double @@ -87,9 +88,9 @@ model xs = do n = length (head xs) `div` 2 pure (map (take k) $ take n ps, cls) where - likelihood ps cnts = product $ zipWith go ps (pairs cnts) + likelihood ps cnts = product $ zipWith go ps (group cnts) where - go p (c, d) = binom d c p + go p (c, d, e, f) = binom d c (min 1 (p * fromIntegral (1+e-f) / fromIntegral (1+f))) -- Tabulate list tabulate xs = M.elems $ M.fromListWith (+) [(c, 1) | c <- xs] |