diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/PPL/Internal.hs | 3 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 108 | 
2 files changed, 99 insertions, 12 deletions
| diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 0ba6ae5..b4283af 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -13,8 +13,7 @@ module PPL.Internal      randomTree,      samples,      mutateTree, -    splitTrees, -    draw, +    Tree(..),    )  where diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 84e5ada..520f182 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,33 +1,121 @@ -{-# LANGUAGE ViewPatterns #-}  {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ViewPatterns #-} -module PPL.Sampling where +module PPL.Sampling +  ( mh, +    ssmh, +  ) +where +import Control.DeepSeq +import Control.Exception (evaluate)  import Control.Monad.IO.Class  import Control.Monad.Trans.State  import Data.Bifunctor +import qualified Data.Map.Strict as M  import Data.Monoid +import GHC.Exts.Heap  import Numeric.Log  import PPL.Distr  import PPL.Internal -import System.Random (StdGen, random, randoms)  import qualified Streaming as S -import Streaming.Prelude (Stream, yield, Of) +import Streaming.Prelude (Of, Stream, yield) +import System.IO.Unsafe +import System.Random (StdGen, random, randoms) +import Unsafe.Coerce -mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()  mh g p m = step t0 t x w    where -    (t0,t) = split $ randomTree g +    (t0, t) = split $ randomTree g      (x, w) = head $ samples m t      step !t0 !t !x !w = do -      let (t1:t2:t3:t4:_) = splitTrees t0 +      let (t1 : t2 : t3 : t4 : _) = splitTrees t0            t' = mutateTree p t1 t2 t            (x', w') = head $ samples m t'            ratio = w' / w            (Exp . log -> r) = draw t3 -          (t'', x'', w'') = if r < ratio -            then (t', x', w') -            else (t, x, w) +          (t'', x'', w'') = +            if r < ratio +              then (t', x', w') +              else (t, x, w)        yield (x'', w'')        step t4 t'' x'' w'' + +-- Single site MH + +-- Truncated trees +data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show) + +type Site = [Int] + +trunc :: Tree -> IO TTree +trunc = truncTree . asBox +  where +    truncTree t = +      getBoxedClosureData' t >>= \case +        ConstrClosure _ [l, r] [] _ _ "Tree" -> +          getBoxedClosureData' l >>= \case +            ConstrClosure {dataArgs = [d], name = "D#"} -> do +              TTree (Just $ unsafeCoerce d) <$> truncTrees r +            x -> error $ "truncTree:ConstrClosure:" ++ show x +        ConstrClosure _ [r] [d] _ _ "Tree" ->  +          TTree (Just $ unsafeCoerce d) <$> truncTrees r +        x -> error $ "truncTree:" ++ show x + +    getBoxedClosureData' x = +      getBoxedClosureData x >>= \c -> case c of +        BlackholeClosure _ t -> getBoxedClosureData' t +        _ -> pure c + +    truncTrees b = +      getBoxedClosureData' b >>= \case +        ConstrClosure _ [l, r] [] _ _ ":" -> +          getBoxedClosureData' l >>= \case +            ConstrClosure {name = "Tree"} -> do +              l' <- truncTree l +              r' <- truncTrees r +              pure $ Just l' : r' +            _ -> pure [] +        _ -> pure [] + +trunc' t x w = unsafePerformIO $ do +  evaluate (rnf x) +  evaluate (rnf w) +  trunc t + +sites :: Site -> TTree -> [Site] +sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] +sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] + +mutate = M.foldrWithKey go +  where +    go [] d (Tree _ ts) = Tree d ts +    go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts + +ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () +ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w +  where +    (t0, t) = split $ randomTree g +    (x, w) = head $ samples m t0 + +    step !t !sub !tt !x !w = do +      let ss = sites [] tt +          (t1 : t2 : t3 : t4 : _) = splitTrees t +          i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate +          sub' = M.insert (reverse $ ss !! i) (draw t3) sub +          t' = mutate t0 sub' +          (x', w') = head $ samples m t' +          tt' = trunc' t' x' w' +          ratio = w' / w +          (Exp . log -> r) = draw t4 +          (sub'', tt'', x'', w'') = +            if r < ratio +              then (sub', tt', x', w') +              else (sub, tt, x, w) + +      yield (x'', w'') +      step t1 sub'' tt'' x'' w'' | 
