From 190a4dc5f5398e6646823aa41637f54cc8cb54aa Mon Sep 17 00:00:00 2001
From: Justin Bedo <cu@cua0.org>
Date: Fri, 20 Jan 2023 09:17:05 +1100
Subject: add symmetry

---
 src/PPL/Internal.hs | 42 +++++++++++++++++++++++++++++-------------
 src/PPL/Sampling.hs |  6 +++---
 2 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index d927508..de48a3f 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,20 +1,31 @@
-{-# LANGUAGE TemplateHaskell #-}
-{-# LANGUAGE RankNTypes #-}
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TemplateHaskell #-}
 
-module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample,
-randomTree, samples, mutateTree) where
+module PPL.Internal
+  ( uniform,
+    split,
+    Prob (..),
+    Meas,
+    score,
+    scoreLog,
+    sample,
+    randomTree,
+    samples,
+    mutateTree,
+  )
+where
 
 import Control.Monad
+import Control.Monad.IO.Class
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Writer
+import Data.Bifunctor
 import Data.Monoid
 import qualified Language.Haskell.TH.Syntax as TH
 import Numeric.Log
-import System.Random hiding (uniform, split)
+import System.Random hiding (split, uniform)
 import qualified System.Random as R
-import Data.Bifunctor
-import Control.Monad.IO.Class
 
 -- Reimplementation of the LazyPPL monads to avoid some dependencies
 
@@ -30,15 +41,20 @@ randomTree g = let (a, g') = random g in Tree a (randomTrees g')
     randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
 
 {-# INLINE mutateTree #-}
-mutateTree :: RandomGen g => Double -> g -> Tree -> Tree
-mutateTree p g (Tree a ts) =
+mutateTree :: RandomGen g => Double -> Double -> g -> Tree -> Tree
+mutateTree p q g (Tree a ts) =
   let (r, g1) = random g
       (b, g2) = random g1
-  in Tree (if r < p then b else a) (mutateTrees p g2 ts)
+   in if r >= p
+        then Tree a (mutateTrees p q g1 ts)
+        else
+          if r < p * q
+            then Tree (1 - a) (mutateTrees p q g1 ts)
+            else Tree b (mutateTrees p q g2 ts)
   where
-    mutateTrees p g (t:ts) =
+    mutateTrees p q g (t : ts) =
       let (g1, g2) = R.split g
-      in mutateTree p g1 t : mutateTrees p g2 ts
+       in mutateTree p q g1 t : mutateTrees p q g2 ts
 
 newtype Prob a = Prob {runProb :: Tree -> a}
 
@@ -75,4 +91,4 @@ sample = Meas . lift
 samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
 samples (Meas m) = map (second getProduct) . runProb f
   where
-    f = runWriterT m >>= \x -> (x:) <$> f
+    f = runWriterT m >>= \x -> (x :) <$> f
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 1d20838..b3b8a90 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -25,8 +25,8 @@ importance n m = do
     cumsum = tail . scanl (+) 0
     accumulate = uncurry zip . second cumsum . unzip
 
-mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)]
-mh p m = do
+mh :: MonadIO m => Double -> Double -> Meas a -> m [(a, Log Double)]
+mh p q m = do
   newStdGen
   g <- getStdGen
   let (g1, g2) = split g
@@ -37,7 +37,7 @@ mh p m = do
     step (t, x, w) = do
       g <- get
       let (g1, g2) = split g
-          t' = mutateTree p g1 t
+          t' = mutateTree p q g1 t
           (x', w') = head $ samples m t'
           ratio = w' / w
           (Exp . log -> r, g3) = random g2
-- 
cgit v1.2.3