aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2025-02-27 09:58:21 +1100
committerJustin Bedo <cu@cua0.org>2025-02-27 13:18:14 +1100
commitd621243bffa14214c722a9be181f4a98d41380bb (patch)
tree5e290c8caf3a0bb26d5a7ddaa2ec734f4a525d73
parentd1a1ffa811b199f5d7d02d14cf76631a0ecd2699 (diff)
ssmh: minor optimisations
-rw-r--r--src/PPL/Sampling.hs20
1 files changed, 8 insertions, 12 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 520f182..8e0ab8f 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -24,7 +24,6 @@ import qualified Streaming as S
import Streaming.Prelude (Of, Stream, yield)
import System.IO.Unsafe
import System.Random (StdGen, random, randoms)
-import Unsafe.Coerce
mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p m = step t0 t x w
@@ -48,7 +47,7 @@ mh g p m = step t0 t x w
-- Single site MH
-- Truncated trees
-data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show)
+data TTree = TTree Bool [Maybe TTree] deriving (Show)
type Site = [Int]
@@ -59,11 +58,11 @@ trunc = truncTree . asBox
getBoxedClosureData' t >>= \case
ConstrClosure _ [l, r] [] _ _ "Tree" ->
getBoxedClosureData' l >>= \case
- ConstrClosure {dataArgs = [d], name = "D#"} -> do
- TTree (Just $ unsafeCoerce d) <$> truncTrees r
+ ConstrClosure {dataArgs = [d], name = "D#"} ->
+ TTree True <$> truncTrees r
x -> error $ "truncTree:ConstrClosure:" ++ show x
ConstrClosure _ [r] [d] _ _ "Tree" ->
- TTree (Just $ unsafeCoerce d) <$> truncTrees r
+ TTree False <$> truncTrees r
x -> error $ "truncTree:" ++ show x
getBoxedClosureData' x =
@@ -75,11 +74,9 @@ trunc = truncTree . asBox
getBoxedClosureData' b >>= \case
ConstrClosure _ [l, r] [] _ _ ":" ->
getBoxedClosureData' l >>= \case
- ConstrClosure {name = "Tree"} -> do
- l' <- truncTree l
- r' <- truncTrees r
- pure $ Just l' : r'
- _ -> pure []
+ ConstrClosure {name = "Tree"} ->
+ ((:) . Just) <$> truncTree l <*> truncTrees r
+ _ -> (Nothing :) <$> truncTrees r
_ -> pure []
trunc' t x w = unsafePerformIO $ do
@@ -88,8 +85,7 @@ trunc' t x w = unsafePerformIO $ do
trunc t
sites :: Site -> TTree -> [Site]
-sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
-sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
+sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
mutate = M.foldrWithKey go
where