aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2023-01-24 16:41:53 +1100
committerJustin Bedo <cu@cua0.org>2023-01-26 14:14:11 +1100
commit7653e357f04aa39c1e96037bf1ea2e4338f8ae76 (patch)
tree8d2a21425d93a29984aff31f122420de6f7e99a6 /src
parent275ed22d5050488e6d40bb5800f1ade9c30d8a76 (diff)
reimplement sampler in streaming framework
Diffstat (limited to 'src')
-rw-r--r--src/PPL/Sampling.hs48
1 files changed, 15 insertions, 33 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index b3b8a90..ad11837 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE BangPatterns #-}
module PPL.Sampling where
@@ -9,44 +10,25 @@ import Data.Monoid
import Numeric.Log
import PPL.Distr
import PPL.Internal hiding (split)
-import System.Random (getStdGen, newStdGen, random, randoms, split)
+import System.Random (StdGen, random, randoms, split)
+import qualified Streaming as S
+import Streaming.Prelude (Stream, yield, Of)
-importance :: MonadIO m => Int -> Meas a -> m [a]
-importance n m = do
- newStdGen
- g <- getStdGen
- let ys = take n $ accumulate xs
- max = snd $ last ys
- xs = samples m $ randomTree g1
- (g1, g2) = split g
- let rs = randoms g2
- pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max
+mh :: Monad m => StdGen -> Double -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh g p q m = step g2 t x w
where
- cumsum = tail . scanl (+) 0
- accumulate = uncurry zip . second cumsum . unzip
+ t = randomTree g1
+ (g1,g2) = split g
+ (x, w) = head $ samples m t
-mh :: MonadIO m => Double -> Double -> Meas a -> m [(a, Log Double)]
-mh p q m = do
- newStdGen
- g <- getStdGen
- let (g1, g2) = split g
- (x, w) = head $ samples m t
- t = randomTree g1
- pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2
- where
- step (t, x, w) = do
- g <- get
+ step !g !t !x !w = do
let (g1, g2) = split g
t' = mutateTree p q g1 t
(x', w') = head $ samples m t'
ratio = w' / w
(Exp . log -> r, g3) = random g2
- put g3
- pure $!
- if r < ratio
- then (t', x', w')
- else (t, x, w)
-
- iterateM f x = do
- y <- f x
- (y :) <$> iterateM f y
+ (t'', x'', w'') = if r < ratio
+ then (t', x', w')
+ else (t, x, w)
+ yield (x'', w'')
+ step g3 t'' x'' w''