aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Sampling.hs
blob: ad118374b01f4fa38d621820aa3b8db9209db9b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}

module PPL.Sampling where

import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Monoid
import Numeric.Log
import PPL.Distr
import PPL.Internal hiding (split)
import System.Random (StdGen, random, randoms, split)
import qualified Streaming as S
import Streaming.Prelude (Stream, yield, Of)

-- | Metropolis–Hastings sampler over the randomness trees of a measure.
--
-- The chain starts from a tree drawn with 'randomTree'. Each step does the
-- following:
--
--   * proposes a new tree with @'mutateTree' p q@;
--   * scores the proposal with the first entry of @'samples' m@;
--   * accepts it when a uniform draw @u@ satisfies @u < w' / w@;
--   * yields the (possibly unchanged) current sample and weight.
--
-- Rejected proposals therefore repeat the previous sample in the output. The
-- stream never terminates on its own; consumers should take a prefix.
--
-- NOTE(review): the acceptance test has no proposal-density correction, so
-- this assumes 'mutateTree' is a symmetric proposal. Confirm this against
-- "PPL.Internal". The meaning of @p@ and @q@ is defined by 'mutateTree'; they
-- are forwarded unchanged.
mh :: Monad m => StdGen -> Double -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p q m = chain gChain t0 x0 w0
  where
    (gTree, gChain) = split g
    t0 = randomTree gTree
    (x0, w0) = firstSample t0

    -- Score a tree by the first element of 'samples'. Pattern-match instead
    -- of using 'head' so an empty result fails with a message that names
    -- this function.
    firstSample tree = case samples m tree of
      s : _ -> s
      [] -> error "PPL.Sampling.mh: samples returned an empty list"

    -- One MH transition: propose, accept or reject, emit, recurse.
    -- The bangs force the chain state to WHNF on every iteration, so the
    -- if-expression below cannot accumulate as a growing chain of thunks.
    chain !gen !tree !val !wt = do
      let (gMutate, gCoin) = split gen
          tree' = mutateTree p q gMutate tree
          (val', wt') = firstSample tree'
          (u, gen') = random gCoin
          -- Exp (log u) is the uniform draw u re-expressed as a Log Double,
          -- so it can be compared directly with the log-domain weight ratio.
          accept = Exp (log u) < wt' / wt
          (tree'', val'', wt'') =
            if accept
              then (tree', val', wt')
              else (tree, val, wt)
      yield (val'', wt'')
      chain gen' tree'' val'' wt''