{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Reimplementation of the LazyPPL monads ('Prob' and 'Meas'), written here
-- to avoid some dependencies.
module PPL.Internal
  ( uniform
  , split
  , Prob (..)
  , Meas
  , score
  , scoreLog
  , sample
  , randomTree
  , samples
  , mutateTree
  )
where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
import Data.Monoid
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (split, uniform)
import qualified System.Random as R

-- | A tree of random 'Double's. Each node holds one value and a list of
-- subtrees. Trees built by 'randomTree' and 'mutateTree' have infinitely many
-- subtrees at every node, so the supply of randomness never runs out.
data Tree = Tree !Double [Tree]

-- | Split a tree into two parts that share no nodes:
--
--   * the first subtree;
--   * the same root with the remaining subtrees.
--
-- A node without subtrees cannot occur for trees produced by 'randomTree' or
-- 'mutateTree'. The explicit error makes that invariant visible instead of
-- failing with an anonymous pattern-match error.
split :: Tree -> (Tree, Tree)
split (Tree r (t : ts)) = (t, Tree r ts)
split (Tree _ []) =
  error "PPL.Internal.split: tree has no subtrees left (trees must be infinite)"

-- | Lazily build an infinite random tree from a generator.
--
-- The root value comes from 'random'. Each subtree gets its own generator,
-- obtained by repeatedly splitting the generator that remains after drawing
-- the root.
{-# INLINE randomTree #-}
randomTree :: RandomGen g => g -> Tree
randomTree g =
  let (a, g') = random g
   in Tree a (randomForest g')
  where
    -- Infinite list of subtrees, each seeded by a fresh split.
    randomForest gen =
      let (gHere, gRest) = R.split gen
       in randomTree gHere : randomForest gRest

-- | Perturb a tree node by node.
--
-- At every node, two values @r@ and @b@ are drawn with 'random'. If @r < p@,
-- the node's value is replaced by @b@; otherwise it is kept. Subtrees are
-- mutated recursively, each with its own split generator.
--
-- Consequently @p <= 0@ leaves every value unchanged. The result is as lazy as
-- the input tree.
{-# INLINE mutateTree #-}
mutateTree :: RandomGen g => Double -> g -> Tree -> Tree
mutateTree p g (Tree a ts) =
  let (r, g1) = random g
      (b, g2) = random g1
   in Tree (if r < p then b else a) (mutateChildren g2 ts)
  where
    -- Mutate each subtree with an independent split of the generator.
    -- The '[]' case keeps this total for finite subtree lists.
    mutateChildren _ [] = []
    mutateChildren gen (c : cs) =
      let (gHere, gRest) = R.split gen
       in mutateTree p gHere c : mutateChildren gRest cs

-- | A computation that reads its randomness from a 'Tree'.
newtype Prob a = Prob {runProb :: Tree -> a}

instance Functor Prob where
  fmap = liftM

instance Applicative Prob where
  pure = Prob . const
  (<*>) = ap

-- | Sequencing uses 'split': the first computation runs on the first subtree,
-- and the continuation runs on the remaining tree. The two therefore read
-- disjoint parts of the tree.
instance Monad Prob where
  Prob m >>= k = Prob $ \t ->
    let (t1, t2) = split t
        Prob m' = k (m t1)
     in m' t2

-- | Return the 'Double' stored at the root of the current tree. This is the
-- value produced by 'random' when the tree was built by 'randomTree' or
-- resampled by 'mutateTree'.
uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r

-- | A weighted 'Prob' computation. It accumulates a multiplicative weight,
-- stored as a 'Log' 'Double' (i.e. in log space).
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
  deriving (Functor, Applicative, Monad)

-- | Multiply the current weight by the given factor.
--
-- The factor is clamped from below to machine epsilon before its logarithm
-- is taken. Zero and negative inputs therefore never produce a log-weight of
-- @-Infinity@ or @NaN@.
{-# INLINE score #-}
score :: Double -> Meas ()
score = scoreLog . Exp . log . max eps
  where
    -- Machine epsilon: the smallest power of two @x@ (reached by halving
    -- from 1) with @1 + x /= 1@. The Template Haskell splice forces it to
    -- be evaluated at compile time.
    eps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double)))

-- | Multiply the current weight by a factor that is already in log space.
{-# INLINE scoreLog #-}
scoreLog :: Log Double -> Meas ()
scoreLog = Meas . tell . Product

-- | Run a 'Prob' computation inside 'Meas' without changing the weight.
{-# INLINE sample #-}
sample :: Prob a -> Meas a
sample = Meas . lift

-- | Run a measure repeatedly on one tree.
--
-- The result is an infinite, lazy list of @(result, weight)@ pairs. Each run
-- consumes its own subtree (via the 'Prob' bind), so different runs read
-- disjoint parts of the tree.
{-# INLINE samples #-}
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = map (second getProduct) . runProb draws
  where
    -- One weighted run, followed by the infinite stream of later runs.
    draws = runWriterT m >>= \x -> (x :) <$> draws