{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

module PPL.Internal
  ( uniform,
    split,
    Prob (..),
    Meas,
    score,
    scoreLog,
    sample,
    randomTree,
    samples,
    mutateTree,
    splitTrees,
    draw,
  )
where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
import Data.Monoid
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (split, uniform)
import qualified System.Random as R

-- Reimplementation of the LazyPPL monads to avoid some dependencies

-- | A rose tree of random numbers: a strictly stored 'Double' at the node and
-- a lazily built list of subtrees. Trees produced by 'randomTree' have an
-- infinite list of subtrees.
data Tree = Tree
  { -- | The number stored at this node.
    draw :: !Double,
    -- | The subtrees hanging off this node.
    splitTrees :: [Tree]
  }

-- | Split a tree into two: its first subtree, and the same node with that
-- first subtree removed.
--
-- The tree must have at least one subtree, which always holds for trees
-- built by 'randomTree'. A tree with no subtrees is reported with an explicit
-- error instead of an anonymous pattern-match failure.
split :: Tree -> (Tree, Tree)
split (Tree r (t : ts)) = (t, Tree r ts)
split (Tree _ []) =
  error "PPL.Internal.split: tree has no subtrees left to split off"

{-# INLINE randomTree #-}

-- | Build a tree from a random generator: the node value is drawn with
-- 'random', and each subtree is built from a generator obtained by repeatedly
-- splitting the remaining generator, so the list of subtrees never ends.
randomTree :: RandomGen g => g -> Tree
randomTree g =
  let (a, g') = random g
   in Tree a (randomTrees g')
  where
    -- One subtree per split of the generator; the recursion is productive
    -- because each step emits a list cell before recursing.
    randomTrees gen =
      let (g1, g2) = R.split gen
       in randomTree g1 : randomTrees g2

{-# INLINE mutateTree #-}

-- | @mutateTree p decisions replacement original@ walks three trees in
-- lockstep. At each position, if the number in @decisions@ is below @p@, the
-- whole subtree of @replacement@ at that position is returned. Otherwise the
-- node value of @original@ is kept and the children are processed
-- recursively (via 'zipWith3', so the result has as many children as the
-- shortest of the three child lists).
mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts)
  | r < p = b
  | otherwise = Tree a (zipWith3 (mutateTree p) rs bs ts)

-- | A computation that reads its randomness from a 'Tree'.
newtype Prob a = Prob {runProb :: Tree -> a}

-- | Sequencing splits the tree: the first computation runs on the first half
-- returned by 'split' and the continuation runs on the second half, so the
-- two never read the same part of the tree.
instance Monad Prob where
  Prob f >>= g = Prob $ \t ->
    let (t1, t2) = split t
        (Prob g') = g (f t1)
     in g' t2

instance Functor Prob where
  fmap = liftM

instance Applicative Prob where
  pure = Prob . const
  (<*>) = ap

-- | Read the number stored at the root of the tree.
-- NOTE(review): presumably uniform on the unit interval, as produced by
-- 'random' at type 'Double' in 'randomTree' — confirm against the @random@
-- package version in use.
uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r

-- | A 'Prob' computation that additionally accumulates a multiplicative
-- weight in the log domain.
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
  deriving (Functor, Applicative, Monad)

{-# INLINE score #-}

-- | Multiply the accumulated weight by the given factor. The factor is
-- clamped from below at @eps@ before its logarithm is taken.
score :: Double -> Meas ()
score = scoreLog . Exp . log . max eps
  where
    -- machine epsilon, force compile time eval
    eps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double)))

{-# INLINE scoreLog #-}

-- | Multiply the accumulated weight by a factor already in the log domain.
scoreLog :: Log Double -> Meas ()
scoreLog = Meas . tell . Product

{-# INLINE sample #-}

-- | Run a 'Prob' computation inside 'Meas' without changing the weight.
sample :: Prob a -> Meas a
sample = Meas . lift

{-# INLINE samples #-}

-- | Run the given 'Meas' computation over and over on parts of the supplied
-- tree, producing a lazy, never-ending list of results paired with the
-- weight each run accumulated.
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = map (second getProduct) . runProb f
  where
    -- Each bind in 'Prob' splits the tree, so every run of @m@ reads a
    -- different subtree; the list is consumed lazily by the caller.
    f = runWriterT m >>= \x -> (x :) <$> f