diff options
| -rw-r--r-- | bin/cluster.hs | 15 | 
1 files changed, 8 insertions, 7 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs index ba9e9f1..e78e261 100644 --- a/bin/cluster.hs +++ b/bin/cluster.hs @@ -6,7 +6,7 @@ module Main where  import Control.Monad  import Data.Fixed (mod')  import Data.Foldable (toList) -import Data.List +import Data.List hiding (group)  import qualified Data.Map as M  import GHC.Compact  import Numeric.Log hiding (sum) @@ -47,10 +47,10 @@ instance Foldable Tree where        bftrav [] = mempty        bftrav ((Tree a l r) : ts) = f a <> bftrav (ts <> [l, r]) -{-# INLINE pairs #-} -pairs (a : b : rs) = (a, b) : pairs rs -pairs [] = [] -pairs _ = error "unexpected number of columns, expecting pairs" +{-# INLINE group #-} +group (a : b : c: d: rs) = (a, b, c, d) : group rs +group [] = [] +group s = error $ "unexpected number of columns, expecting 4:" <> show s  -- Infinite trees from infinite lists  -- NB: it's harder to partition a list so that it folds back to @@ -59,6 +59,7 @@ pairs _ = error "unexpected number of columns, expecting pairs"  treeFromList (x : xs) = Tree x (treeFromList lpart) (treeFromList rpart)    where      (lpart, rpart) = unzip $ pairs xs +    pairs (a:b:rs) = (a, b) : pairs rs  -- Constrain trees so leaves sum to node value  normTree :: Tree Double -> Tree Double @@ -87,9 +88,9 @@ model xs = do        n = length (head xs) `div` 2    pure (map (take k) $ take n ps, cls)    where -    likelihood ps cnts = product $ zipWith go ps (pairs cnts) +    likelihood ps cnts = product $ zipWith go ps (group cnts)        where -        go p (c, d) = binom d c p +        go p (c, d, e, f) = binom d c (min 1 (p * fromIntegral (1+e-f) / fromIntegral (1+f)))  -- Tabulate list  tabulate xs = M.elems $ M.fromListWith (+) [(c, 1) | c <- xs]  | 
