diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/PPL/Internal.hs | 4 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 79 | 
2 files changed, 6 insertions, 77 deletions
| diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 49737d0..7273d23 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -28,7 +28,7 @@ import qualified Language.Haskell.TH.Syntax as TH  import Numeric.Log  import System.Random hiding (split, uniform)  import qualified System.Random as R -import qualified Data.Map.Strict as M +import qualified Data.HashMap.Strict as M  import Data.IORef  import System.IO.Unsafe @@ -43,7 +43,7 @@ split :: Tree -> (Tree, Tree)  split (Tree r (t : ts)) = (t, Tree r ts)  {-# INLINE newTree #-} -newTree :: IORef (M.Map [Int] Double, StdGen) -> Tree +newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree  newTree s = go []    where      go id = Tree (unsafePerformIO $ do diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 6807439..9a749e9 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -6,7 +6,6 @@  module PPL.Sampling    ( mh, -    ssmh,    )  where @@ -15,7 +14,7 @@ import Control.Exception (evaluate)  import Control.Monad.IO.Class  import Control.Monad.Trans.State  import Data.Bifunctor -import qualified Data.Map.Strict as M +import qualified Data.HashMap.Strict as M  import Data.Monoid  import GHC.Exts.Heap  import Numeric.Log @@ -50,7 +49,7 @@ mh g p m = do        yield (x'', w'')        step g2 omega'' x'' w'' -    mutate :: MonadIO m => StdGen -> IORef (M.Map [Int] Double, StdGen) ->  m (IORef (M.Map [Int] Double, StdGen)) +    mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) ->  m (IORef (M.HashMap [Int] Double, StdGen))      mutate g omega = do        (m, g0) <- liftIO $ readIORef omega        let (r:q:_) = R.randoms g @@ -70,76 +69,6 @@ mh g p m = do              pure y            else do              put g1 -            pure x +            pure x   +             --- Single site MH - --- Truncated trees -data TTree = TTree Bool [Maybe TTree] deriving (Show) - -type Site = [Int] - -trunc :: Tree -> IO TTree -trunc = truncTree . asBox -  where -    truncTree t = -      getBoxedClosureData' t >>= \case -        ConstrClosure _ [l, r] [] _ _ "Tree" -> -          getBoxedClosureData' l >>= \case -            ConstrClosure {dataArgs = [d], name = "D#"} -> -              TTree True <$> truncTrees r -            x -> error $ "truncTree:ConstrClosure:" ++ show x -        ConstrClosure _ [r] [d] _ _ "Tree" ->  -          TTree False <$> truncTrees r -        x -> error $ "truncTree:" ++ show x - -    getBoxedClosureData' x = -      getBoxedClosureData x >>= \c -> case c of -        BlackholeClosure _ t -> getBoxedClosureData' t -        _ -> pure c - -    truncTrees b = -      getBoxedClosureData' b >>= \case -        ConstrClosure _ [l, r] [] _ _ ":" -> -          getBoxedClosureData' l >>= \case -            ConstrClosure {name = "Tree"} -> -              ((:) . Just) <$> truncTree l <*>  truncTrees r -            _ -> (Nothing :) <$> truncTrees r -        _ -> pure [] - -trunc' t x w = unsafePerformIO $ do -  evaluate (rnf x) -  evaluate (rnf w) -  trunc t - -sites :: Site -> TTree -> [Site] -sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] - -mutate = M.foldrWithKey go -  where -    go [] d (Tree _ ts) = Tree d ts -    go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts - -ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () -ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w -  where -    (t0, t) = split $ randomTree g -    (x, w) = head $ samples m t0 - -    step !t !sub !tt !x !w = do -      let ss = sites [] tt -          (t1 : t2 : t3 : t4 : _) = splitTrees t -          i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate -          sub' = M.insert (reverse $ ss !! i) (draw t3) sub -          t' = mutate t0 sub' -          (x', w') = head $ samples m t' -          tt' = trunc' t' x' w' -          ratio = w' / w -          (Exp . log -> r) = draw t4 -          (sub'', tt'', x'', w'') = -            if r < ratio -              then (sub', tt', x', w') -              else (sub, tt, x, w) - -      yield (x'', w'') -      step t1 sub'' tt'' x'' w'' | 
