diff options
author | Justin Bedo <cu@cua0.org> | 2025-03-05 18:26:04 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2025-03-06 09:58:26 +1100 |
commit | dab819d62c6bf139dbbb3e36fd6f835f7681b595 (patch) | |
tree | e934f5d61d2dd488e6c796d15ded47149586df03 | |
parent | 51c32c174a168db6f97a7f93dcc58bcb7c351a65 (diff) |
use hashed ids
-rw-r--r-- | src/PPL/Internal.hs | 38 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 5 |
2 files changed, 34 insertions, 9 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 960b19a..462be20 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,7 +1,10 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} module PPL.Internal ( uniform, @@ -29,11 +32,31 @@ import System.Random hiding (split, uniform) import qualified System.Random as R import Data.IORef import System.IO.Unsafe -import qualified Data.Vector.Mutable as VM import qualified Data.Vector.Unboxed.Mutable as UM import qualified Data.Vector.Hashtables as H +import Data.Word +import Data.Bits -type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v +type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v + + +-- Simple hashing scheme into 64-bit space based on mzHash64 +data Hash = Hash Int Int Word8 Word64 deriving (Eq, Ord, Show) + +unhash :: Hash -> Word64 +unhash h@(Hash _ j w8 _) = let Hash _ _ _ s = addhash (fromIntegral j) $ addhash (fromIntegral w8) h in s + +addhash :: Word8 -> Hash -> Hash +addhash x (Hash i j bits state) = Hash (i+1) j bits $ (0xB2DEEF07CB4ACD43 * (fromIntegral i + fromIntegral x)) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2) +initHash = Hash 0 0 0 0xE297DA430DB2DF1A + +pushbit :: Bool -> Hash -> Hash +pushbit b h@(Hash i j bits state) + | j == 8 = let Hash i _ _ state = addhash bits h in Hash i 1 (fromBool b) state + | otherwise = Hash i (j+1) (bits*2 + fromBool b) state + where + fromBool True = 1 + fromBool False = 0 -- Reimplementation of the LazyPPL monads to avoid some dependencies @@ -43,18 +66,19 @@ data Tree = Tree } {-# INLINE newTree #-} -newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree -newTree s = go [] +newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree +newTree s = go initHash where + go :: Hash -> Tree go id = Tree (unsafePerformIO $ do (m, g) <- readIORef s - H.lookup m id >>= \case + H.lookup m (unhash id) >>= \case Nothing -> do let (x, g') = R.random g - H.insert m id x + H.insert m (unhash id) x writeIORef s (m, g') pure x - Just x -> pure x) (go (False : id), go (True : id)) + Just x -> pure x) (go (pushbit False id), go (pushbit True id)) newtype Prob a = Prob {runProb :: Tree -> a} diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 5ad346d..59126ba 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -27,7 +27,8 @@ import qualified System.Random as R import Data.IORef import Control.Monad import qualified Data.Vector.Hashtables as H -import qualified Data.Vector as V +import qualified Data.Vector.Unboxed as V +import Data.Word mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () mh g p m = do @@ -50,7 +51,7 @@ mh g p m = do yield (x'', w'') step g2 omega'' x'' w'' - mutate :: MonadIO m => StdGen -> IORef (HashMap [Bool] Double, StdGen) -> m (IORef (HashMap [Bool] Double, StdGen)) + mutate :: MonadIO m => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen)) mutate g omega = liftIO $ do (m, g0) <- readIORef omega m' <- H.clone m |