From e318b40cfc1a3079375a015a651b1b44f6279ad0 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Fri, 16 Dec 2022 10:57:55 +1100 Subject: init --- src/PPL/Sampling.hs | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/PPL/Sampling.hs (limited to 'src/PPL/Sampling.hs') diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs new file mode 100644 index 0000000..a3f38db --- /dev/null +++ b/src/PPL/Sampling.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE ViewPatterns #-} + +module PPL.Sampling where + +import Control.Monad.IO.Class +import Control.Monad.Trans.State +import Data.Bifunctor +import Data.Monoid +import Numeric.Log +import PPL.Distr +import PPL.Internal hiding (split) +import System.Random (getStdGen, newStdGen, random, randoms, split) + +importance :: MonadIO m => Int -> Meas a -> m [a] +importance n m = do + newStdGen + g <- getStdGen + let ys = take n $ accumulate xs + max = snd $ last ys + xs = samples m $ randomTree g1 + (g1, g2) = split g + let rs = randoms g2 + pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max + where + cumsum = tail . scanl (+) 0 + accumulate = uncurry zip . second cumsum . unzip + +mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)] +mh p m = do + newStdGen + g <- getStdGen + let (g1, g2) = split g + (x, w) = head $ samples m t + t = randomTree g1 + pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2 + where + step (t, x, w) = do + g <- get + let (g1, g2) = split g + t' = mutateTree p g1 t + (x', w') = head $ samples m t' + ratio = w' / w + (Exp . log -> r, g3) = random g2 + put g3 + pure $ + if r < ratio + then (t', x', w') + else (t, x, w) + + iterateM f x = do + y <- f x + (y :) <$> iterateM f y -- cgit v1.2.3