{-# LANGUAGE ViewPatterns #-}

-- | Sampling back-ends for 'Meas' programs: importance sampling with
-- resampling, and a Metropolis–Hastings chain over random trees.
module PPL.Sampling where

import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Monoid
import Numeric.Log
import PPL.Distr
import PPL.Internal hiding (split)
import System.Random (getStdGen, newStdGen, random, randoms, split)

-- | Importance sampling followed by resampling.
--
-- Takes the first @n@ weighted samples that @m@ produces from a single
-- fresh random tree. It then returns an infinite stream of values drawn from
-- those samples with probability proportional to their weights.
--
-- Returns @[]@ when there is nothing to resample from (e.g. @n <= 0@).
importance :: MonadIO m => Int -> Meas a -> m [a]
importance n m = do
  -- 'newStdGen' splits the global generator, keeps one half globally and
  -- returns the other. We use the returned half, which no later call can
  -- observe. (Reading the global back with 'getStdGen' would hand this call
  -- the very generator the next call splits, correlating successive runs.)
  g <- newStdGen
  let (gTree, gUniform) = split g
      weighted = accumulate . take n . samples m $ randomTree gTree
  pure $ resample weighted (randoms gUniform)
  where
    -- Pair each value with the running total of the weights up to and
    -- including it.
    accumulate :: Num w => [(b, w)] -> [(b, w)]
    accumulate = uncurry zip . second cumsum . unzip

    cumsum :: Num w => [w] -> [w]
    cumsum = tail . scanl (+) 0

    -- Inverse-CDF resampling. Each uniform draw u selects the first value
    -- whose running total reaches u * total.
    resample :: [(b, Log Double)] -> [Double] -> [b]
    resample [] _ = []
    resample ws us = map pick us
      where
        total = snd (last ws)
        reaches u (_, c) = c >= Exp (log u) * total
        pick u = case filter (reaches u) ws of
          (x, _) : _ -> x
          [] ->
            error
              "PPL.Sampling.importance: no sample reaches the resampling threshold (NaN or infinite weights?)"

-- | Metropolis–Hastings over random trees.
--
-- Starts from the first weighted sample of @m@ on a fresh random tree. Each
-- step proposes @mutateTree p@ of the current tree. The proposal is accepted
-- when a uniform draw is below @w' / w@ (new weight over current weight),
-- i.e. with probability @min 1 (w' / w)@. Otherwise the chain stays put.
--
-- Returns the infinite chain of @(value, weight)@ states, excluding the
-- initial state. The chain is built in the lazy State monad
-- ("Control.Monad.Trans.State" re-exports the lazy variant), which lets the
-- unbounded 'iterateM' recursion be consumed incrementally.
mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)]
mh p m = do
  -- Use the half of the split that is not stored globally; see 'importance'.
  g <- newStdGen
  let (gTree, gChain) = split g
      tree0 = randomTree gTree
      (x0, w0) = firstSample tree0
  pure . map dropTree $ evalState (iterateM step (tree0, x0, w0)) gChain
  where
    firstSample t = case samples m t of
      s : _ -> s
      [] -> error "PPL.Sampling.mh: the program produced no samples for a tree"

    dropTree (_, x, w) = (x, w)

    -- One MH transition: mutate the tree, re-score it, then accept or reject.
    step (t, x, w) = do
      gen <- get
      let (gMutate, gAccept) = split gen
          t' = mutateTree p gMutate t
          (x', w') = firstSample t'
          -- Lift the uniform draw into the log domain to compare with weights.
          (Exp . log -> u, gen') = random gAccept
      put gen'
      pure $ if u < w' / w then (t', x', w') else (t, x, w)

    -- Unfold an infinite list of states by repeatedly applying a monadic step.
    iterateM f s = do
      s' <- f s
      (s' :) <$> iterateM f s'