diff options
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r-- | src/PPL/Internal.hs | 38 |
1 files changed, 13 insertions, 25 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 7273d23..960b19a 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,19 +1,18 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE LambdaCase #-} module PPL.Internal ( uniform, - split, Prob (..), Meas, score, scoreLog, sample, - randomTree, samples, - mutateTree, newTree, + HashMap, Tree(..), ) where @@ -28,45 +27,34 @@ import qualified Language.Haskell.TH.Syntax as TH import Numeric.Log import System.Random hiding (split, uniform) import qualified System.Random as R -import qualified Data.HashMap.Strict as M import Data.IORef import System.IO.Unsafe +import qualified Data.Vector.Mutable as VM +import qualified Data.Vector.Unboxed.Mutable as UM +import qualified Data.Vector.Hashtables as H + +type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree { draw :: !Double, - splitTrees :: [Tree] + split :: (Tree, Tree) } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - {-# INLINE newTree #-} -newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree +newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree newTree s = go [] where go id = Tree (unsafePerformIO $ do (m, g) <- readIORef s - case M.lookup id m of + H.lookup m id >>= \case Nothing -> do let (x, g') = R.random g - writeIORef s (M.insert id x m, g') + H.insert m id x + writeIORef s (m, g') pure x - Just x -> pure x) [go (i:id) | i <- [0..]] - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') - where - randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# NOINLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = - if r < p - then b - else Tree a $ zipWith3 (mutateTree p) rs bs ts + Just x -> pure x) (go (False : id), go (True : id)) newtype Prob a = Prob {runProb :: Tree -> a} |