aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
blob: d9275080740e16e96ab97329d3792225d368ba74 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample,
randomTree, samples, mutateTree) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Monoid
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (uniform, split)
import qualified System.Random as R
import Data.Bifunctor
import Control.Monad.IO.Class

-- Reimplementation of the LazyPPL monads to avoid some dependencies

-- | A rose tree of 'Double' labels used as the source of randomness for
-- 'Prob'. The label is strict; the child list is lazy, which lets trees such
-- as those built by 'randomTree' be infinite and consumed on demand.
data Tree
  = Tree
      !Double -- ^ label stored at this node
      [Tree] -- ^ child subtrees

-- | Detach the first child of a tree.
--
-- Returns that child together with the original node minus the detached
-- child, so repeated splitting walks through the children one by one.
--
-- The tree must still have at least one child. Trees built by 'randomTree'
-- have infinitely many. Calling this on a childless node is a programming
-- error and is reported with an explicit message instead of an anonymous
-- pattern-match failure.
split :: Tree -> (Tree, Tree)
split (Tree r (t : ts)) = (t, Tree r ts)
split (Tree _ []) =
  error "PPL.Internal.split: tree has no children left to split off"

{-# INLINE randomTree #-}
-- | Build an infinite, lazily generated tree of random 'Double's.
--
-- The root label is drawn with 'random'. The leftover generator is then split
-- once per child, so every subtree is generated from its own stream.
randomTree :: RandomGen g => g -> Tree
randomTree gen = Tree label (subtrees rest)
  where
    (label, rest) = random gen
    -- The left half of each split seeds one child; the right half seeds the
    -- remaining siblings.
    subtrees s =
      let (here, later) = R.split s
       in randomTree here : subtrees later

{-# INLINE mutateTree #-}
-- | Independently resample the labels of a tree.
--
-- At each node two 'Double's are drawn from the generator: a coin @r@ and a
-- candidate label @b@. The node's label becomes @b@ when @r < p@ and is kept
-- otherwise. The remaining generator is split once per child, so each
-- subtree is mutated with its own stream.
--
-- Finite child lists are supported: an empty child list stays empty. The
-- previous version had no @[]@ case and crashed on any finite tree.
mutateTree :: RandomGen g => Double -> g -> Tree -> Tree
mutateTree p g (Tree a ts) =
  let (r, g1) = random g
      (b, g2) = random g1
   in Tree (if r < p then b else a) (mutateTrees g2 ts)
  where
    -- Mutate each child with the left half of a split; thread the right half
    -- on to the next sibling. Uses the outer 'p' rather than re-binding it.
    mutateTrees _ [] = []
    mutateTrees gen (t : rest) =
      let (gl, gr) = R.split gen
       in mutateTree p gl t : mutateTrees gr rest

-- | A probabilistic computation: a function that reads all of its randomness
-- from a 'Tree' (typically one produced by 'randomTree').
newtype Prob a = Prob
  { runProb :: Tree -> a
  }

-- | Sequencing splits the randomness tree. The first computation runs on the
-- detached first child; the continuation runs on the remainder.
instance Monad Prob where
  m >>= k = Prob $ \tree ->
    let (left, right) = split tree
        x = runProb m left
     in runProb (k x) right

-- | Mapping goes through '>>=' (exactly as 'liftM' does), so 'fmap' also
-- splits the tree rather than reading it directly.
instance Functor Prob where
  fmap f m = m >>= (pure . f)

-- | 'pure' ignores the tree. '<*>' runs the function computation, then the
-- argument computation, via '>>=' (the same definition as 'ap').
instance Applicative Prob where
  pure x = Prob (\_ -> x)
  pf <*> px = pf >>= \f -> px >>= \x -> pure (f x)

-- | Read the label stored at the root of the tree.
--
-- For trees built by 'randomTree', that label was drawn with 'random' at type
-- 'Double'.
uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r

-- | A weighted probabilistic computation: a 'Prob' computation that also
-- accumulates a multiplicative weight (a 'Product' of 'Log' 'Double's)
-- through 'WriterT'.
newtype Meas a
  = Meas (WriterT (Product (Log Double)) Prob a)
  deriving (Functor, Applicative, Monad)

{-# INLINE score #-}
-- | Multiply the current weight by the given factor.
--
-- Factors below machine epsilon, including zero and negative values, are
-- raised to epsilon before the logarithm is taken. The recorded factor is
-- therefore never below epsilon.
score :: Double -> Meas ()
score w = scoreLog (Exp (log (max eps w)))
  where
    -- Machine epsilon for Double: halve 1 until adding it to 1 no longer
    -- changes the sum, then double. The Template Haskell splice makes this
    -- search run at compile time.
    eps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double)))

{-# INLINE scoreLog #-}
-- | Multiply the current weight by a factor already given as a 'Log' value.
scoreLog :: Log Double -> Meas ()
scoreLog w = Meas (tell (Product w))

{-# INLINE sample #-}
-- | Run a 'Prob' computation inside 'Meas'. It is lifted into 'WriterT', so
-- the accumulated weight is unchanged.
sample :: Prob a -> Meas a
sample p = Meas (lift p)

{-# INLINE samples #-}
-- | Run a 'Meas' computation repeatedly against one tree. The result is a
-- lazy, never-ending list of outcomes, each paired with its weight.
--
-- Each run uses a different part of the tree, because every '>>=' in
-- 'draws' splits the tree before recursing.
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) tree = map unwrap (runProb draws tree)
  where
    draws = do
      result <- runWriterT m
      rest <- draws
      pure (result : rest)
    -- Lazy pattern: the same strictness as 'second' on pairs.
    unwrap ~(x, w) = (x, getProduct w)