diff options
-rw-r--r-- | src/PPL/Internal.hs | 42 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 6 |
2 files changed, 32 insertions, 16 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index d927508..de48a3f 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,20 +1,31 @@ -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TemplateHaskell #-} -module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample, -randomTree, samples, mutateTree) where +module PPL.Internal + ( uniform, + split, + Prob (..), + Meas, + score, + scoreLog, + sample, + randomTree, + samples, + mutateTree, + ) +where import Control.Monad +import Control.Monad.IO.Class import Control.Monad.Trans.Class import Control.Monad.Trans.Writer +import Data.Bifunctor import Data.Monoid import qualified Language.Haskell.TH.Syntax as TH import Numeric.Log -import System.Random hiding (uniform, split) +import System.Random hiding (split, uniform) import qualified System.Random as R -import Data.Bifunctor -import Control.Monad.IO.Class -- Reimplementation of the LazyPPL monads to avoid some dependencies @@ -30,15 +41,20 @@ randomTree g = let (a, g') = random g in Tree a (randomTrees g') randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 {-# INLINE mutateTree #-} -mutateTree :: RandomGen g => Double -> g -> Tree -> Tree -mutateTree p g (Tree a ts) = +mutateTree :: RandomGen g => Double -> Double -> g -> Tree -> Tree +mutateTree p q g (Tree a ts) = let (r, g1) = random g (b, g2) = random g1 - in Tree (if r < p then b else a) (mutateTrees p g2 ts) + in if r >= p + then Tree a (mutateTrees p q g1 ts) + else + if r < p * q + then Tree (1 - a) (mutateTrees p q g1 ts) + else Tree b (mutateTrees p q g2 ts) where - mutateTrees p g (t:ts) = + mutateTrees p q g (t : ts) = let (g1, g2) = R.split g - in mutateTree p g1 t : mutateTrees p g2 ts + in mutateTree p q g1 t : mutateTrees p q g2 ts newtype Prob a = Prob {runProb :: Tree -> a} @@ -75,4 +91,4 @@ sample = Meas . lift samples :: forall a. Meas a -> Tree -> [(a, Log Double)] samples (Meas m) = map (second getProduct) . runProb f where - f = runWriterT m >>= \x -> (x:) <$> f + f = runWriterT m >>= \x -> (x :) <$> f diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 1d20838..b3b8a90 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -25,8 +25,8 @@ importance n m = do cumsum = tail . scanl (+) 0 accumulate = uncurry zip . second cumsum . unzip -mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)] -mh p m = do +mh :: MonadIO m => Double -> Double -> Meas a -> m [(a, Log Double)] +mh p q m = do newStdGen g <- getStdGen let (g1, g2) = split g @@ -37,7 +37,7 @@ mh p m = do step (t, x, w) = do g <- get let (g1, g2) = split g - t' = mutateTree p g1 t + t' = mutateTree p q g1 t (x', w') = head $ samples m t' ratio = w' / w (Exp . log -> r, g3) = random g2 |