aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Sampling.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Sampling.hs')
-rw-r--r--src/PPL/Sampling.hs52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
new file mode 100644
index 0000000..a3f38db
--- /dev/null
+++ b/src/PPL/Sampling.hs
@@ -0,0 +1,52 @@
+{-# LANGUAGE ViewPatterns #-}
+
+module PPL.Sampling where
+
+import Control.Monad.IO.Class
+import Control.Monad.Trans.State
+import Data.Bifunctor
+import Data.Monoid
+import Numeric.Log
+import PPL.Distr
+import PPL.Internal hiding (split)
+import System.Random (getStdGen, newStdGen, random, randoms, split)
+
+importance :: MonadIO m => Int -> Meas a -> m [a]
+importance n m = do
+ newStdGen
+ g <- getStdGen
+ let ys = take n $ accumulate xs
+ max = snd $ last ys
+ xs = samples m $ randomTree g1
+ (g1, g2) = split g
+ let rs = randoms g2
+ pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max
+ where
+ cumsum = tail . scanl (+) 0
+ accumulate = uncurry zip . second cumsum . unzip
+
+mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)]
+mh p m = do
+ newStdGen
+ g <- getStdGen
+ let (g1, g2) = split g
+ (x, w) = head $ samples m t
+ t = randomTree g1
+ pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2
+ where
+ step (t, x, w) = do
+ g <- get
+ let (g1, g2) = split g
+ t' = mutateTree p g1 t
+ (x', w') = head $ samples m t'
+ ratio = w' / w
+ (Exp . log -> r, g3) = random g2
+ put g3
+ pure $
+ if r < ratio
+ then (t', x', w')
+ else (t, x, w)
+
+ iterateM f x = do
+ y <- f x
+ (y :) <$> iterateM f y