diff options
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r-- | src/PPL/Internal.hs | 66 |
1 files changed, 46 insertions, 20 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index b4283af..a81ba37 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,55 +1,80 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} module PPL.Internal ( uniform, - split, Prob (..), Meas, score, scoreLog, sample, - randomTree, samples, - mutateTree, - Tree(..), + newTree, + HashMap, + Tree (..), ) where import Control.Monad -import Control.Monad.IO.Class import Control.Monad.Trans.Class import Control.Monad.Trans.Writer import Data.Bifunctor +import Data.Bits +import Data.IORef import Data.Monoid +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed.Mutable as UM +import Data.Word import qualified Language.Haskell.TH.Syntax as TH import Numeric.Log +import System.IO.Unsafe import System.Random hiding (split, uniform) import qualified System.Random as R +type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v + +-- Simple hashing scheme into 64-bit space based on mzHash64 +data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show) + +pushbit :: Bool -> Hash -> Hash +pushbit b (Hash state) = Hash $ (0xB2DEEF07CB4ACD43 * fromBool b) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2) + where + fromBool True = 1 + fromBool False = 0 + +initHash :: Hash +initHash = Hash 0xE297DA430DB2DF1A + -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree { draw :: !Double, - splitTrees :: [Tree] + split :: (Tree, Tree) } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') +{-# INLINE newTree #-} +newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree +newTree s = go initHash where - randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# INLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = - if r < p - then b - else Tree a $ zipWith3 (mutateTree p) rs bs ts + go :: Hash -> Tree + go id = + Tree + ( unsafePerformIO $ do + (m, g) <- readIORef s + H.lookup m (unhash id) >>= \case + Nothing -> do + let (x, g') = R.random g + H.insert m (unhash id) x + writeIORef s (m, g') + pure x + Just x -> pure x + ) + (go (pushbit False id), go (pushbit True id)) newtype Prob a = Prob {runProb :: Tree -> a} @@ -63,6 +88,7 @@ instance Functor Prob where fmap = liftM instance Applicative Prob where pure = Prob . const; (<*>) = ap +uniform :: Prob Double uniform = Prob $ \(Tree r _) -> r newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) |