diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/PPL/Internal.hs | 38 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 5 | 
2 files changed, 34 insertions, 9 deletions
| diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 960b19a..462be20 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,7 +1,10 @@  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-}  module PPL.Internal    ( uniform, @@ -29,11 +32,31 @@ import System.Random hiding (split, uniform)  import qualified System.Random as R  import Data.IORef  import System.IO.Unsafe -import qualified Data.Vector.Mutable as VM  import qualified Data.Vector.Unboxed.Mutable as UM  import qualified Data.Vector.Hashtables as H +import Data.Word +import Data.Bits -type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v +type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v + + +-- Simple hashing scheme into 64-bit space based on mzHash64 +data Hash = Hash Int Int Word8 Word64 deriving (Eq, Ord, Show) + +unhash :: Hash -> Word64 +unhash h@(Hash _ j w8 _) = let Hash _ _ _ s = addhash (fromIntegral j) $ addhash (fromIntegral w8) h in s + +addhash :: Word8 -> Hash -> Hash +addhash x (Hash i j bits state) = Hash (i+1) j bits $ (0xB2DEEF07CB4ACD43 * (fromIntegral i + fromIntegral x)) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2) +initHash = Hash 0 0 0 0xE297DA430DB2DF1A + +pushbit :: Bool -> Hash -> Hash +pushbit b h@(Hash i j bits state) +  | j == 8 = let Hash i _ _ state = addhash bits h in Hash i 1 (fromBool b) state +  | otherwise = Hash i (j+1) (bits*2 + fromBool b) state +  where +    fromBool True = 1 +    fromBool False = 0  -- Reimplementation of the LazyPPL monads to avoid some dependencies @@ -43,18 +66,19 @@ data Tree = Tree    }  {-# INLINE newTree #-} -newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree -newTree s = go [] +newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree +newTree s = go initHash    where +    go :: Hash -> Tree      go id = Tree (unsafePerformIO $ do        (m, g) <- readIORef s -      H.lookup m id >>= \case +      H.lookup m (unhash id) >>= \case          Nothing -> do            let (x, g') = R.random g -          H.insert m id x +          H.insert m (unhash id) x            writeIORef s (m, g')            pure x -        Just x -> pure x) (go (False : id), go (True : id)) +        Just x -> pure x) (go (pushbit False id), go (pushbit True id))  newtype Prob a = Prob {runProb :: Tree -> a} diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 5ad346d..59126ba 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -27,7 +27,8 @@ import qualified System.Random as R  import Data.IORef  import Control.Monad  import qualified Data.Vector.Hashtables as H -import qualified Data.Vector as V +import qualified Data.Vector.Unboxed as V +import Data.Word  mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()  mh g p m = do @@ -50,7 +51,7 @@ mh g p m = do        yield (x'', w'')        step g2 omega'' x'' w'' -    mutate :: MonadIO m => StdGen -> IORef (HashMap [Bool] Double, StdGen) ->  m (IORef (HashMap [Bool] Double, StdGen)) +    mutate :: MonadIO m => StdGen -> IORef (HashMap Word64 Double, StdGen) ->  m (IORef (HashMap Word64 Double, StdGen))      mutate g omega = liftIO $ do        (m, g0) <- readIORef omega        m' <- H.clone m | 
