aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2022-12-16 10:57:55 +1100
committerJustin Bedo <cu@cua0.org>2022-12-31 17:28:29 +1100
commite318b40cfc1a3079375a015a651b1b44f6279ad0 (patch)
tree40a72f1b807fd77a1dd53ebf3f0cc284e65f30ad /src/PPL/Internal.hs
init
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r--src/PPL/Internal.hs81
1 files changed, 81 insertions, 0 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
new file mode 100644
index 0000000..a729d3d
--- /dev/null
+++ b/src/PPL/Internal.hs
@@ -0,0 +1,81 @@
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+
+module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample,
+randomTree, samples, mutateTree) where
+
+import Control.Monad
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Writer
+import Data.Monoid
+import qualified Language.Haskell.TH.Syntax as TH
+import Numeric.Log
+import System.Random hiding (uniform, split)
+import qualified System.Random as R
+import Data.Bifunctor
+import Control.Monad.IO.Class
+
+-- Reimplementation of the LazyPPL monads to avoid some dependencies
+
+data Tree = Tree Double [Tree]
+
+split :: Tree -> (Tree, Tree)
+split (Tree r (t : ts)) = (t, Tree r ts)
+
+{-# INLINE randomTree #-}
+randomTree :: RandomGen g => g -> Tree
+randomTree g = let (a, g') = random g in Tree a (randomTrees g')
+
+{-# INLINE randomTrees #-}
+randomTrees :: RandomGen g => g -> [Tree]
+randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
+
+{-# INLINE mutateTree #-}
+mutateTree :: RandomGen g => Double -> g -> Tree -> Tree
+mutateTree p g (Tree a ts) =
+ let (r, g1) = random g
+ (b, g2) = random g1
+ in Tree (if r < p then b else a) (mutateTrees p g2 ts)
+
+{-# INLINE mutateTrees #-}
+mutateTrees :: RandomGen g => Double -> g -> [Tree] -> [Tree]
+mutateTrees p g (t:ts) =
+ let (g1, g2) = R.split g
+ in mutateTree p g1 t : mutateTrees p g2 ts
+
+newtype Prob a = Prob {runProb :: Tree -> a}
+
+instance Monad Prob where
+ Prob f >>= g = Prob $ \t ->
+ let (t1, t2) = split t
+ (Prob g') = g (f t1)
+ in g' t2
+
+instance Functor Prob where fmap = liftM
+
+instance Applicative Prob where pure = Prob . const; (<*>) = ap
+
+uniform = Prob $ \(Tree r _) -> r
+
+newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
+ deriving (Functor, Applicative, Monad)
+
+{-# INLINE score #-}
+score :: Double -> Meas ()
+score = scoreLog . Exp . log . max eps
+ where
+ eps = $(TH.lift (until ((== 1) . (1 +)) (/ 2) (1 :: Double))) -- machine epsilon, force compile time eval
+
+{-# INLINE scoreLog #-}
+scoreLog :: Log Double -> Meas ()
+scoreLog = Meas . tell . Product
+
+sample :: Prob a -> Meas a
+sample = Meas . lift
+
+{-# INLINE samples #-}
+samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
+samples (Meas m) t = map (second getProduct) $ runProb f t
+ where
+ f = runWriterT m >>= \x -> (x:) <$> f