From 7653e357f04aa39c1e96037bf1ea2e4338f8ae76 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Tue, 24 Jan 2023 16:41:53 +1100 Subject: reimplement sampler in streaming framework --- src/PPL/Sampling.hs | 48 +++++++++++++++--------------------------------- 1 file changed, 15 insertions(+), 33 deletions(-) (limited to 'src/PPL') diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index b3b8a90..ad11837 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,4 +1,5 @@ {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE BangPatterns #-} module PPL.Sampling where @@ -9,44 +10,25 @@ import Data.Monoid import Numeric.Log import PPL.Distr import PPL.Internal hiding (split) -import System.Random (getStdGen, newStdGen, random, randoms, split) +import System.Random (StdGen, random, randoms, split) +import qualified Streaming as S +import Streaming.Prelude (Stream, yield, Of) -importance :: MonadIO m => Int -> Meas a -> m [a] -importance n m = do - newStdGen - g <- getStdGen - let ys = take n $ accumulate xs - max = snd $ last ys - xs = samples m $ randomTree g1 - (g1, g2) = split g - let rs = randoms g2 - pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max +mh :: Monad m => StdGen -> Double -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh g p q m = step g2 t x w where - cumsum = tail . scanl (+) 0 - accumulate = uncurry zip . second cumsum . unzip + t = randomTree g1 + (g1,g2) = split g + (x, w) = head $ samples m t -mh :: MonadIO m => Double -> Double -> Meas a -> m [(a, Log Double)] -mh p q m = do - newStdGen - g <- getStdGen - let (g1, g2) = split g - (x, w) = head $ samples m t - t = randomTree g1 - pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2 - where - step (t, x, w) = do - g <- get + step !g !t !x !w = do let (g1, g2) = split g t' = mutateTree p q g1 t (x', w') = head $ samples m t' ratio = w' / w (Exp . log -> r, g3) = random g2 - put g3 - pure $! - if r < ratio - then (t', x', w') - else (t, x, w) - - iterateM f x = do - y <- f x - (y :) <$> iterateM f y + (t'', x'', w'') = if r < ratio + then (t', x', w') + else (t, x, w) + yield (x'', w'') + step g3 t'' x'' w'' -- cgit v1.2.3