diff options
author | Justin Bedo <cu@cua0.org> | 2025-02-27 09:58:21 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2025-02-27 13:18:14 +1100 |
commit | d621243bffa14214c722a9be181f4a98d41380bb (patch) | |
tree | 5e290c8caf3a0bb26d5a7ddaa2ec734f4a525d73 | |
parent | d1a1ffa811b199f5d7d02d14cf76631a0ecd2699 (diff) |
ssmh: minor optimisations
-rw-r--r-- | src/PPL/Sampling.hs | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 520f182..8e0ab8f 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -24,7 +24,6 @@ import qualified Streaming as S import Streaming.Prelude (Of, Stream, yield) import System.IO.Unsafe import System.Random (StdGen, random, randoms) -import Unsafe.Coerce mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () mh g p m = step t0 t x w @@ -48,7 +47,7 @@ mh g p m = step t0 t x w -- Single site MH -- Truncated trees -data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show) +data TTree = TTree Bool [Maybe TTree] deriving (Show) type Site = [Int] @@ -59,11 +58,11 @@ trunc = truncTree . asBox getBoxedClosureData' t >>= \case ConstrClosure _ [l, r] [] _ _ "Tree" -> getBoxedClosureData' l >>= \case - ConstrClosure {dataArgs = [d], name = "D#"} -> do - TTree (Just $ unsafeCoerce d) <$> truncTrees r + ConstrClosure {dataArgs = [d], name = "D#"} -> + TTree True <$> truncTrees r x -> error $ "truncTree:ConstrClosure:" ++ show x ConstrClosure _ [r] [d] _ _ "Tree" -> - TTree (Just $ unsafeCoerce d) <$> truncTrees r + TTree False <$> truncTrees r x -> error $ "truncTree:" ++ show x getBoxedClosureData' x = @@ -75,11 +74,9 @@ trunc = truncTree . asBox getBoxedClosureData' b >>= \case ConstrClosure _ [l, r] [] _ _ ":" -> getBoxedClosureData' l >>= \case - ConstrClosure {name = "Tree"} -> do - l' <- truncTree l - r' <- truncTrees r - pure $ Just l' : r' - _ -> pure [] + ConstrClosure {name = "Tree"} -> + ((:) . Just) <$> truncTree l <*> truncTrees r + _ -> (Nothing :) <$> truncTrees r _ -> pure [] trunc' t x w = unsafePerformIO $ do @@ -88,8 +85,7 @@ trunc' t x w = unsafePerformIO $ do trunc t sites :: Site -> TTree -> [Site] -sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] -sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] +sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] mutate = M.foldrWithKey go where |