aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r--src/PPL/Internal.hs19
1 files changed, 18 insertions, 1 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index b4283af..49737d0 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -13,6 +13,7 @@ module PPL.Internal
randomTree,
samples,
mutateTree,
+ newTree,
Tree(..),
)
where
@@ -27,6 +28,9 @@ import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (split, uniform)
import qualified System.Random as R
+import qualified Data.Map.Strict as M
+import Data.IORef
+import System.IO.Unsafe
-- Reimplementation of the LazyPPL monads to avoid some dependencies
@@ -38,13 +42,26 @@ data Tree = Tree
split :: Tree -> (Tree, Tree)
split (Tree r (t : ts)) = (t, Tree r ts)
+{-# INLINE newTree #-}
+newTree :: IORef (M.Map [Int] Double, StdGen) -> Tree
+newTree s = go []
+ where
+ go id = Tree (unsafePerformIO $ do
+ (m, g) <- readIORef s
+ case M.lookup id m of
+ Nothing -> do
+ let (x, g') = R.random g
+ writeIORef s (M.insert id x m, g')
+ pure x
+ Just x -> pure x) [go (i:id) | i <- [0..]]
+
{-# INLINE randomTree #-}
randomTree :: RandomGen g => g -> Tree
randomTree g = let (a, g') = random g in Tree a (randomTrees g')
where
randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
-{-# INLINE mutateTree #-}
+{-# NOINLINE mutateTree #-}
mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
if r < p