aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r--src/PPL/Internal.hs66
1 files changed, 46 insertions, 20 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index b4283af..a81ba37 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,55 +1,80 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module PPL.Internal
( uniform,
- split,
Prob (..),
Meas,
score,
scoreLog,
sample,
- randomTree,
samples,
- mutateTree,
- Tree(..),
+ newTree,
+ HashMap,
+ Tree (..),
)
where
import Control.Monad
-import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
+import Data.Bits
+import Data.IORef
import Data.Monoid
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector.Unboxed.Mutable as UM
+import Data.Word
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
+import System.IO.Unsafe
import System.Random hiding (split, uniform)
import qualified System.Random as R
+type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v
+
+-- Simple hashing scheme into 64-bit space based on mzHash64
+data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show)
+
+pushbit :: Bool -> Hash -> Hash
+pushbit b (Hash state) = Hash $ (0xB2DEEF07CB4ACD43 * fromBool b) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2)
+ where
+ fromBool True = 1
+ fromBool False = 0
+
+initHash :: Hash
+initHash = Hash 0xE297DA430DB2DF1A
+
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
{ draw :: !Double,
- splitTrees :: [Tree]
+ split :: (Tree, Tree)
}
-split :: Tree -> (Tree, Tree)
-split (Tree r (t : ts)) = (t, Tree r ts)
-
-{-# INLINE randomTree #-}
-randomTree :: RandomGen g => g -> Tree
-randomTree g = let (a, g') = random g in Tree a (randomTrees g')
+{-# INLINE newTree #-}
+newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree
+newTree s = go initHash
where
- randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
-
-{-# INLINE mutateTree #-}
-mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
-mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
- if r < p
- then b
- else Tree a $ zipWith3 (mutateTree p) rs bs ts
+ go :: Hash -> Tree
+ go id =
+ Tree
+ ( unsafePerformIO $ do
+ (m, g) <- readIORef s
+ H.lookup m (unhash id) >>= \case
+ Nothing -> do
+ let (x, g') = R.random g
+ H.insert m (unhash id) x
+ writeIORef s (m, g')
+ pure x
+ Just x -> pure x
+ )
+ (go (pushbit False id), go (pushbit True id))
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -63,6 +88,7 @@ instance Functor Prob where fmap = liftM
instance Applicative Prob where pure = Prob . const; (<*>) = ap
+uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)