aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Sampling.hs
blob: 1d208388e8c0c8533d5de448125d4de0d8ab88df (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
{-# LANGUAGE ViewPatterns #-}

module PPL.Sampling where

import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Monoid
import Numeric.Log
import PPL.Distr
import PPL.Internal hiding (split)
import System.Random (getStdGen, newStdGen, random, randoms, split)

-- | Draw values from a measure by importance resampling.
--
-- The first @n@ weighted samples of @m@ are generated from a fresh random
-- tree and their weights are replaced by running totals.  Every element of
-- the result is then chosen by drawing a uniform number @r@ and returning the
-- first sample whose running total reaches @r@ times the grand total.
--
-- The result is an infinite, lazily produced list (one element per uniform
-- draw); callers are expected to 'take' what they need.  When no weighted
-- samples are available (for example @n <= 0@) the result is the empty list.
importance :: MonadIO m => Int -> Meas a -> m [a]
importance n m = do
  -- 'newStdGen' returns the half of the split global generator that is NOT
  -- left behind in the global state.  Reading the global state back with
  -- 'getStdGen' instead would hand us the very generator the next caller is
  -- going to split, making successive runs share randomness.
  g <- newStdGen
  let (g1, g2) = split g
      -- Samples paired with the cumulative sum of their weights.
      ys = take n . accumulate . samples m $ randomTree g1
      total = snd $ last ys
      -- 'head' is safe for non-empty @ys@ as long as @total@ is not NaN: the
      -- last running total equals @total@, and @r@ never exceeds 1.
      pick r = fst . head $ filter (\(_, w) -> w >= Exp (log r) * total) ys
  pure $ if null ys then [] else map pick (randoms g2)
  where
    cumsum = scanl1 (+)
    accumulate = uncurry zip . second cumsum . unzip

-- | Metropolis-Hastings sampling over random trees.
--
-- Starts from a fresh random tree and, at every step, proposes a new tree via
-- @mutateTree p@, re-runs @m@ on the proposal and accepts it when a uniform
-- draw falls below the weight ratio @w' / w@; otherwise the current state is
-- kept.
--
-- [@p@] Forwarded unchanged to @mutateTree@ (presumably the mutation
--       probability; confirm against its definition).
-- [@m@] The measure to sample from.
--
-- Returns the infinite, lazily produced chain of @(value, weight)@ pairs, one
-- per step.  The initial state itself is not part of the output.
mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)]
mh p m = do
  -- 'newStdGen' returns the half of the split global generator that is NOT
  -- left behind in the global state.  Reading the global state back with
  -- 'getStdGen' instead would hand us the very generator the next caller is
  -- going to split, making successive runs share randomness.
  g <- newStdGen
  let (g1, g2) = split g
      t0 = randomTree g1
      (x0, w0) = head $ samples m t0
  pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t0, x0, w0)) g2
  where
    -- One transition of the chain: propose, weigh, then accept or stay put.
    -- The generator is threaded through the State monad.
    step (t, x, w) = do
      g <- get
      let (g1, g2) = split g
          t' = mutateTree p g1 t
          (x', w') = head $ samples m t'
          ratio = w' / w
          -- Uniform draw, moved into the log domain to compare with @ratio@.
          (Exp . log -> r, g3) = random g2
      put g3
      -- '$!' evaluates the accept/reject choice to weak head normal form
      -- instead of returning an unevaluated @if@ thunk.
      pure $!
        if r < ratio
          then (t', x', w')
          else (t, x, w)

    -- Infinite monadic iteration.  It is only productive because
    -- "Control.Monad.Trans.State" is the lazy State monad.
    iterateM f x = do
      y <- f x
      (y :) <$> iterateM f y