diff options
author | Justin Bedo <cu@cua0.org> | 2022-12-16 10:57:55 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2022-12-31 17:28:29 +1100 |
commit | e318b40cfc1a3079375a015a651b1b44f6279ad0 (patch) | |
tree | 40a72f1b807fd77a1dd53ebf3f0cc284e65f30ad /src/PPL/Internal.hs |
init
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r-- | src/PPL/Internal.hs | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs new file mode 100644 index 0000000..a729d3d --- /dev/null +++ b/src/PPL/Internal.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample, +randomTree, samples, mutateTree) where + +import Control.Monad +import Control.Monad.Trans.Class +import Control.Monad.Trans.Writer +import Data.Monoid +import qualified Language.Haskell.TH.Syntax as TH +import Numeric.Log +import System.Random hiding (uniform, split) +import qualified System.Random as R +import Data.Bifunctor +import Control.Monad.IO.Class + +-- Reimplementation of the LazyPPL monads to avoid some dependencies + +data Tree = Tree Double [Tree] + +split :: Tree -> (Tree, Tree) +split (Tree r (t : ts)) = (t, Tree r ts) + +{-# INLINE randomTree #-} +randomTree :: RandomGen g => g -> Tree +randomTree g = let (a, g') = random g in Tree a (randomTrees g') + +{-# INLINE randomTrees #-} +randomTrees :: RandomGen g => g -> [Tree] +randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 + +{-# INLINE mutateTree #-} +mutateTree :: RandomGen g => Double -> g -> Tree -> Tree +mutateTree p g (Tree a ts) = + let (r, g1) = random g + (b, g2) = random g1 + in Tree (if r < p then b else a) (mutateTrees p g2 ts) + +{-# INLINE mutateTrees #-} +mutateTrees :: RandomGen g => Double -> g -> [Tree] -> [Tree] +mutateTrees p g (t:ts) = + let (g1, g2) = R.split g + in mutateTree p g1 t : mutateTrees p g2 ts + +newtype Prob a = Prob {runProb :: Tree -> a} + +instance Monad Prob where + Prob f >>= g = Prob $ \t -> + let (t1, t2) = split t + (Prob g') = g (f t1) + in g' t2 + +instance Functor Prob where fmap = liftM + +instance Applicative Prob where pure = Prob . const; (<*>) = ap + +uniform = Prob $ \(Tree r _) -> r + +newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) + deriving (Functor, Applicative, Monad) + +{-# INLINE score #-} +score :: Double -> Meas () +score = scoreLog . Exp . log . max eps + where + eps = $(TH.lift (until ((== 1) . (1 +)) (/ 2) (1 :: Double))) -- machine epsilon, force compile time eval + +{-# INLINE scoreLog #-} +scoreLog :: Log Double -> Meas () +scoreLog = Meas . tell . Product + +sample :: Prob a -> Meas a +sample = Meas . lift + +{-# INLINE samples #-} +samples :: forall a. Meas a -> Tree -> [(a, Log Double)] +samples (Meas m) t = map (second getProduct) $ runProb f t + where + f = runWriterT m >>= \x -> (x:) <$> f |