diff options
Diffstat (limited to 'src/PPL')
| -rw-r--r-- | src/PPL/Sampling.hs | 20 | 
1 files changed, 8 insertions, 12 deletions
| diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 520f182..8e0ab8f 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -24,7 +24,6 @@ import qualified Streaming as S  import Streaming.Prelude (Of, Stream, yield)  import System.IO.Unsafe  import System.Random (StdGen, random, randoms) -import Unsafe.Coerce  mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()  mh g p m = step t0 t x w @@ -48,7 +47,7 @@ mh g p m = step t0 t x w  -- Single site MH  -- Truncated trees -data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show) +data TTree = TTree Bool [Maybe TTree] deriving (Show)  type Site = [Int] @@ -59,11 +58,11 @@ trunc = truncTree . asBox        getBoxedClosureData' t >>= \case          ConstrClosure _ [l, r] [] _ _ "Tree" ->            getBoxedClosureData' l >>= \case -            ConstrClosure {dataArgs = [d], name = "D#"} -> do -              TTree (Just $ unsafeCoerce d) <$> truncTrees r +            ConstrClosure {dataArgs = [d], name = "D#"} -> +              TTree True <$> truncTrees r              x -> error $ "truncTree:ConstrClosure:" ++ show x          ConstrClosure _ [r] [d] _ _ "Tree" ->  -          TTree (Just $ unsafeCoerce d) <$> truncTrees r +          TTree False <$> truncTrees r          x -> error $ "truncTree:" ++ show x      getBoxedClosureData' x = @@ -75,11 +74,9 @@ trunc = truncTree . asBox        getBoxedClosureData' b >>= \case          ConstrClosure _ [l, r] [] _ _ ":" ->            getBoxedClosureData' l >>= \case -            ConstrClosure {name = "Tree"} -> do -              l' <- truncTree l -              r' <- truncTrees r -              pure $ Just l' : r' -            _ -> pure [] +            ConstrClosure {name = "Tree"} -> +              ((:) . Just) <$> truncTree l <*>  truncTrees r +            _ -> (Nothing :) <$> truncTrees r          _ -> pure []  trunc' t x w = unsafePerformIO $ do @@ -88,8 +85,7 @@ trunc' t x w = unsafePerformIO $ do    trunc t  sites :: Site -> TTree -> [Site] -sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] -sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] +sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]  mutate = M.foldrWithKey go    where | 
