aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r--src/PPL/Internal.hs29
1 files changed, 11 insertions, 18 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index b7d49ee..0ba6ae5 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -13,6 +13,8 @@ module PPL.Internal
randomTree,
samples,
mutateTree,
+ splitTrees,
+ draw,
)
where
@@ -29,7 +31,10 @@ import qualified System.Random as R
-- Reimplementation of the LazyPPL monads to avoid some dependencies
-data Tree = Tree !Double [Tree]
+data Tree = Tree
+ { draw :: !Double,
+ splitTrees :: [Tree]
+ }
split :: Tree -> (Tree, Tree)
split (Tree r (t : ts)) = (t, Tree r ts)
@@ -41,23 +46,11 @@ randomTree g = let (a, g') = random g in Tree a (randomTrees g')
randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
{-# INLINE mutateTree #-}
-mutateTree :: RandomGen g => Double -> Double -> g -> Tree -> Tree
-mutateTree p q g (Tree a ts) =
- let (r, g1) = random g
- (b, g2) = random g1
- in Tree
- ( if r >= p
- then a
- else
- if r < p * q
- then 1 - a
- else b
- )
- (mutateTrees p q g2 ts)
- where
- mutateTrees p q g (t : ts) =
- let (g1, g2) = R.split g
- in mutateTree p q g1 t : mutateTrees p q g2 ts
+mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
+mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
+ if r < p
+ then b
+ else Tree a $ zipWith3 (mutateTree p) rs bs ts
newtype Prob a = Prob {runProb :: Tree -> a}