diff options
Diffstat (limited to 'src/PPL/Internal.hs')
| -rw-r--r-- | src/PPL/Internal.hs | 66 | 
1 files changed, 46 insertions, 20 deletions
| diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index b4283af..a81ba37 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,55 +1,80 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-}  {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TemplateHaskell #-}  module PPL.Internal    ( uniform, -    split,      Prob (..),      Meas,      score,      scoreLog,      sample, -    randomTree,      samples, -    mutateTree, -    Tree(..), +    newTree, +    HashMap, +    Tree (..),    )  where  import Control.Monad -import Control.Monad.IO.Class  import Control.Monad.Trans.Class  import Control.Monad.Trans.Writer  import Data.Bifunctor +import Data.Bits +import Data.IORef  import Data.Monoid +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed.Mutable as UM +import Data.Word  import qualified Language.Haskell.TH.Syntax as TH  import Numeric.Log +import System.IO.Unsafe  import System.Random hiding (split, uniform)  import qualified System.Random as R +type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v + +-- Simple hashing scheme into 64-bit space based on mzHash64 +data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show) + +pushbit :: Bool -> Hash -> Hash +pushbit b (Hash state) = Hash $ (0xB2DEEF07CB4ACD43 * fromBool b) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2) +  where +    fromBool True = 1 +    fromBool False = 0 + +initHash :: Hash +initHash = Hash 0xE297DA430DB2DF1A +  -- Reimplementation of the LazyPPL monads to avoid some dependencies  data Tree = Tree    { draw :: !Double, -    splitTrees :: [Tree] +    split :: (Tree, Tree)    } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') +{-# INLINE newTree #-} +newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree +newTree s = go initHash    where -    randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# INLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = -  if r < p -    then b -    else Tree a $ zipWith3 (mutateTree p) rs bs ts +    go :: Hash -> Tree +    go id = +      Tree +        ( unsafePerformIO $ do +            (m, g) <- readIORef s +            H.lookup m (unhash id) >>= \case +              Nothing -> do +                let (x, g') = R.random g +                H.insert m (unhash id) x +                writeIORef s (m, g') +                pure x +              Just x -> pure x +        ) +        (go (pushbit False id), go (pushbit True id))  newtype Prob a = Prob {runProb :: Tree -> a} @@ -63,6 +88,7 @@ instance Functor Prob where fmap = liftM  instance Applicative Prob where pure = Prob . const; (<*>) = ap +uniform :: Prob Double  uniform = Prob $ \(Tree r _) -> r  newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) | 
