aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Sampling.hs
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2023-01-27 17:00:48 +1100
committerJustin Bedo <cu@cua0.org>2023-01-27 18:08:42 +1100
commit368ebf9f96e4fded6b69ca609f9db53eb19b4589 (patch)
treeb853e7effd96e70e433e340ba39a848f955e1107 /src/PPL/Sampling.hs
parent7653e357f04aa39c1e96037bf1ea2e4338f8ae76 (diff)
rewrite tree mutation function to avoid leaking memory
Diffstat (limited to 'src/PPL/Sampling.hs')
-rw-r--r--src/PPL/Sampling.hs21
1 files changed, 10 insertions, 11 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index ad11837..84e5ada 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -9,26 +9,25 @@ import Data.Bifunctor
import Data.Monoid
import Numeric.Log
import PPL.Distr
-import PPL.Internal hiding (split)
-import System.Random (StdGen, random, randoms, split)
+import PPL.Internal
+import System.Random (StdGen, random, randoms)
import qualified Streaming as S
import Streaming.Prelude (Stream, yield, Of)
-mh :: Monad m => StdGen -> Double -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
-mh g p q m = step g2 t x w
+mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh g p m = step t0 t x w
where
- t = randomTree g1
- (g1,g2) = split g
+ (t0,t) = split $ randomTree g
(x, w) = head $ samples m t
- step !g !t !x !w = do
- let (g1, g2) = split g
- t' = mutateTree p q g1 t
+ step !t0 !t !x !w = do
+ let (t1:t2:t3:t4:_) = splitTrees t0
+ t' = mutateTree p t1 t2 t
(x', w') = head $ samples m t'
ratio = w' / w
- (Exp . log -> r, g3) = random g2
+ (Exp . log -> r) = draw t3
(t'', x'', w'') = if r < ratio
then (t', x', w')
else (t, x, w)
yield (x'', w'')
- step g3 t'' x'' w''
+ step t4 t'' x'' w''