diff options
author | Justin Bedo <cu@cua0.org> | 2022-12-16 10:57:55 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2022-12-31 17:28:29 +1100 |
commit | e318b40cfc1a3079375a015a651b1b44f6279ad0 (patch) | |
tree | 40a72f1b807fd77a1dd53ebf3f0cc284e65f30ad /src/PPL/Sampling.hs |
init
Diffstat (limited to 'src/PPL/Sampling.hs')
-rw-r--r-- | src/PPL/Sampling.hs | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs new file mode 100644 index 0000000..a3f38db --- /dev/null +++ b/src/PPL/Sampling.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE ViewPatterns #-} + +module PPL.Sampling where + +import Control.Monad.IO.Class +import Control.Monad.Trans.State +import Data.Bifunctor +import Data.Monoid +import Numeric.Log +import PPL.Distr +import PPL.Internal hiding (split) +import System.Random (getStdGen, newStdGen, random, randoms, split) + +importance :: MonadIO m => Int -> Meas a -> m [a] +importance n m = do + newStdGen + g <- getStdGen + let ys = take n $ accumulate xs + max = snd $ last ys + xs = samples m $ randomTree g1 + (g1, g2) = split g + let rs = randoms g2 + pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max + where + cumsum = tail . scanl (+) 0 + accumulate = uncurry zip . second cumsum . unzip + +mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)] +mh p m = do + newStdGen + g <- getStdGen + let (g1, g2) = split g + (x, w) = head $ samples m t + t = randomTree g1 + pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2 + where + step (t, x, w) = do + g <- get + let (g1, g2) = split g + t' = mutateTree p g1 t + (x', w') = head $ samples m t' + ratio = w' / w + (Exp . log -> r, g3) = random g2 + put g3 + pure $ + if r < ratio + then (t', x', w') + else (t, x, w) + + iterateM f x = do + y <- f x + (y :) <$> iterateM f y |