diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/PPL/Internal.hs | 19 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 56 | 
2 files changed, 60 insertions, 15 deletions
| diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index b4283af..49737d0 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -13,6 +13,7 @@ module PPL.Internal      randomTree,      samples,      mutateTree, +    newTree,      Tree(..),    )  where @@ -27,6 +28,9 @@ import qualified Language.Haskell.TH.Syntax as TH  import Numeric.Log  import System.Random hiding (split, uniform)  import qualified System.Random as R +import qualified Data.Map.Strict as M +import Data.IORef +import System.IO.Unsafe  -- Reimplementation of the LazyPPL monads to avoid some dependencies @@ -38,13 +42,26 @@ data Tree = Tree  split :: Tree -> (Tree, Tree)  split (Tree r (t : ts)) = (t, Tree r ts) +{-# INLINE newTree #-} +newTree :: IORef (M.Map [Int] Double, StdGen) -> Tree +newTree s = go [] +  where +    go id = Tree (unsafePerformIO $ do +      (m, g) <- readIORef s +      case M.lookup id m of +        Nothing -> do +          let (x, g') = R.random g +          writeIORef s (M.insert id x m, g') +          pure x +        Just x -> pure x) [go (i:id) | i <- [0..]] +  {-# INLINE randomTree #-}  randomTree :: RandomGen g => g -> Tree  randomTree g = let (a, g') = random g in Tree a (randomTrees g')    where      randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 -{-# INLINE mutateTree #-} +{-# NOINLINE mutateTree #-}  mutateTree :: Double -> Tree -> Tree -> Tree -> Tree  mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =    if r < p diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 8e0ab8f..6807439 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -2,6 +2,7 @@  {-# LANGUAGE BlockArguments #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TupleSections #-}  module PPL.Sampling    ( mh, @@ -24,25 +25,52 @@ import qualified Streaming as S  import Streaming.Prelude (Of, Stream, yield)  import System.IO.Unsafe  import System.Random (StdGen, random, randoms) - -mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () -mh g p m = step t0 t x w +import qualified System.Random as R +import Data.IORef +import Control.Monad +import Debug.Trace + +mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh g p m = do +    let (g0, g1) = R.split g +    omega <- liftIO $ newIORef (mempty, g0) +    let (x, w) = head $ samples m $ newTree omega +    step g1 omega x w    where -    (t0, t) = split $ randomTree g -    (x, w) = head $ samples m t -    step !t0 !t !x !w = do -      let (t1 : t2 : t3 : t4 : _) = splitTrees t0 -          t' = mutateTree p t1 t2 t -          (x', w') = head $ samples m t' +    step !g0 !omega !x !w = do  +      let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0 +      omega' <- mutate g1 omega +      let (!x', !w') = head $ samples m $ newTree omega'            ratio = w' / w -          (Exp . log -> r) = draw t3 -          (t'', x'', w'') = +          (omega'', x'', w'') =              if r < ratio -              then (t', x', w') -              else (t, x, w) +              then (omega', x', w') +              else (omega, x, w)        yield (x'', w'') -      step t4 t'' x'' w'' +      step g2 omega'' x'' w'' +     +    mutate :: MonadIO m => StdGen -> IORef (M.Map [Int] Double, StdGen) ->  m (IORef (M.Map [Int] Double, StdGen)) +    mutate g omega = do +      (m, g0) <- liftIO $ readIORef omega +      let (r:q:_) = R.randoms g +          ks = M.keys m +          k = ks !! floor (r * join traceShow (fromIntegral (length ks))) +          m' = M.insert k q m +      liftIO $ newIORef $ (m',g0) +       +     where +      go x = do +        g <- get +        let (r, g1) = R.random g +            (y, g2) = R.random g1 +        if r < p  +          then do +            put g2 +            pure y +          else do +            put g1 +            pure x  -- Single site MH | 
