aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2023-01-17 09:33:44 +1100
committerJustin Bedo <cu@cua0.org>2023-01-17 09:37:05 +1100
commit4f82de5bbcd6dc17256d73551bb68b4ee8de454e (patch)
tree354dd6a359b49307f1767f4e525f5d4db2bf47cf
parentddf49e4cce3978368cc08e015e5d1492a5b7c93f (diff)
simplify code
-rw-r--r--bin/cluster.hs22
1 files changed, 10 insertions, 12 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs
index 553f6e0..8d1745d 100644
--- a/bin/cluster.hs
+++ b/bin/cluster.hs
@@ -24,32 +24,33 @@ pois lambda (fromIntegral -> k) = lambda' ** k' * Exp (negate lambda) / Exp (sti
k' = Exp $ log k
-- (infinite) binary trees
-data Tree a = Empty | Tree a (Tree a) (Tree a)
+data Tree a = Tree a (Tree a) (Tree a)
deriving (Show)
instance Foldable Tree where
foldMap f t = bftrav [t]
where
bftrav [] = mempty
- bftrav (Empty : ts) = bftrav ts
bftrav ((Tree a l r) : ts) = f a <> bftrav (ts <> [l, r])
+{-# INLINE pairs #-}
+pairs (a : b : rs) = (a, b) : pairs rs
+pairs [] = []
+pairs _ = error "unexpected number of columns, expecting pairs"
+
-- Infinite trees from infinite lists
-- NB: it's harder to partition a list so that it folds back to
-- equivalence. It doesn't really matter here since we're only
-- unfolding random uniforms anyway.
treeFromList (x : xs) = Tree x (treeFromList lpart) (treeFromList rpart)
where
- (lpart, rpart) = unzip $ partition xs
- partition (a : b : xs) = (a, b) : partition xs
+ (lpart, rpart) = unzip $ pairs xs
-- Constrain trees so leaves sum to node value
normTree :: Tree Double -> Tree Double
-normTree (Tree x l r) = go $ Tree x l r
- where
- go (Tree x (Tree u l r) (Tree v l' r')) =
- let s = x / (u + v)
- in Tree x (go $ Tree (s * u) l r) (go $ Tree (s * v) l' r')
+normTree (Tree x (Tree u l r) (Tree v l' r')) =
+ let s = x / (u + v)
+ in Tree x (normTree $ Tree (s * u) l r) (normTree $ Tree (s * v) l' r')
drawTreeProbs = toList . normTree . treeFromList <$> iid uniform
@@ -75,9 +76,6 @@ model xs = do
likelihood ps cnts = product $ zipWith go ps (pairs cnts)
where
go p (c, d) = max (pois (fromIntegral d * p) c) (pois (fromIntegral d * p / 2) c)
- pairs (a : b : rs) = (a, b) : pairs rs
- pairs [] = []
- pairs _ = error "unexpected number of columns, expecting count/depth pairs"
-- Tabulate list
tabulate xs = M.elems $ M.fromListWith (+) [(c, 1) | c <- xs]