aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2023-01-20 09:17:05 +1100
committerJustin Bedo <cu@cua0.org>2023-01-23 12:44:44 +1100
commit190a4dc5f5398e6646823aa41637f54cc8cb54aa (patch)
tree361e9503d5c0b5c8d84269821d1d1f0a6cc11913
parent0718de4acc18df39152fc55d6bd279af56d7e2af (diff)
add symmetry
-rw-r--r--src/PPL/Internal.hs42
-rw-r--r--src/PPL/Sampling.hs6
2 files changed, 32 insertions, 16 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index d927508..de48a3f 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,20 +1,31 @@
-{-# LANGUAGE TemplateHaskell #-}
-{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TemplateHaskell #-}
-module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample,
-randomTree, samples, mutateTree) where
+module PPL.Internal
+ ( uniform,
+ split,
+ Prob (..),
+ Meas,
+ score,
+ scoreLog,
+ sample,
+ randomTree,
+ samples,
+ mutateTree,
+ )
+where
import Control.Monad
+import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
+import Data.Bifunctor
import Data.Monoid
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
-import System.Random hiding (uniform, split)
+import System.Random hiding (split, uniform)
import qualified System.Random as R
-import Data.Bifunctor
-import Control.Monad.IO.Class
-- Reimplementation of the LazyPPL monads to avoid some dependencies
@@ -30,15 +41,20 @@ randomTree g = let (a, g') = random g in Tree a (randomTrees g')
randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
{-# INLINE mutateTree #-}
-mutateTree :: RandomGen g => Double -> g -> Tree -> Tree
-mutateTree p g (Tree a ts) =
+mutateTree :: RandomGen g => Double -> Double -> g -> Tree -> Tree
+mutateTree p q g (Tree a ts) =
let (r, g1) = random g
(b, g2) = random g1
- in Tree (if r < p then b else a) (mutateTrees p g2 ts)
+ in if r >= p
+ then Tree a (mutateTrees p q g1 ts)
+ else
+ if r < p * q
+ then Tree (1 - a) (mutateTrees p q g1 ts)
+ else Tree b (mutateTrees p q g2 ts)
where
- mutateTrees p g (t:ts) =
+ mutateTrees p q g (t : ts) =
let (g1, g2) = R.split g
- in mutateTree p g1 t : mutateTrees p g2 ts
+ in mutateTree p q g1 t : mutateTrees p q g2 ts
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -75,4 +91,4 @@ sample = Meas . lift
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = map (second getProduct) . runProb f
where
- f = runWriterT m >>= \x -> (x:) <$> f
+ f = runWriterT m >>= \x -> (x :) <$> f
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 1d20838..b3b8a90 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -25,8 +25,8 @@ importance n m = do
cumsum = tail . scanl (+) 0
accumulate = uncurry zip . second cumsum . unzip
-mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)]
-mh p m = do
+mh :: MonadIO m => Double -> Double -> Meas a -> m [(a, Log Double)]
+mh p q m = do
newStdGen
g <- getStdGen
let (g1, g2) = split g
@@ -37,7 +37,7 @@ mh p m = do
step (t, x, w) = do
g <- get
let (g1, g2) = split g
- t' = mutateTree p g1 t
+ t' = mutateTree p q g1 t
(x', w') = head $ samples m t'
ratio = w' / w
(Exp . log -> r, g3) = random g2