aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2025-03-05 18:26:04 +1100
committerJustin Bedo <cu@cua0.org>2025-03-06 09:58:26 +1100
commitdab819d62c6bf139dbbb3e36fd6f835f7681b595 (patch)
treee934f5d61d2dd488e6c796d15ded47149586df03
parent51c32c174a168db6f97a7f93dcc58bcb7c351a65 (diff)
use hashed ids
-rw-r--r--src/PPL/Internal.hs38
-rw-r--r--src/PPL/Sampling.hs5
2 files changed, 34 insertions, 9 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 960b19a..462be20 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,7 +1,10 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
module PPL.Internal
( uniform,
@@ -29,11 +32,31 @@ import System.Random hiding (split, uniform)
import qualified System.Random as R
import Data.IORef
import System.IO.Unsafe
-import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Data.Vector.Hashtables as H
+import Data.Word
+import Data.Bits
-type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v
+type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v
+
+
+-- Simple hashing scheme into 64-bit space based on mzHash64
+data Hash = Hash Int Int Word8 Word64 deriving (Eq, Ord, Show)
+
+unhash :: Hash -> Word64
+unhash h@(Hash _ j w8 _) = let Hash _ _ _ s = addhash (fromIntegral j) $ addhash (fromIntegral w8) h in s
+
+addhash :: Word8 -> Hash -> Hash
+addhash x (Hash i j bits state) = Hash (i+1) j bits $ (0xB2DEEF07CB4ACD43 * (fromIntegral i + fromIntegral x)) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2)
+initHash = Hash 0 0 0 0xE297DA430DB2DF1A
+
+pushbit :: Bool -> Hash -> Hash
+pushbit b h@(Hash i j bits state)
+ | j == 8 = let Hash i _ _ state = addhash bits h in Hash i 1 (fromBool b) state
+ | otherwise = Hash i (j+1) (bits*2 + fromBool b) state
+ where
+ fromBool True = 1
+ fromBool False = 0
-- Reimplementation of the LazyPPL monads to avoid some dependencies
@@ -43,18 +66,19 @@ data Tree = Tree
}
{-# INLINE newTree #-}
-newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree
-newTree s = go []
+newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree
+newTree s = go initHash
where
+ go :: Hash -> Tree
go id = Tree (unsafePerformIO $ do
(m, g) <- readIORef s
- H.lookup m id >>= \case
+ H.lookup m (unhash id) >>= \case
Nothing -> do
let (x, g') = R.random g
- H.insert m id x
+ H.insert m (unhash id) x
writeIORef s (m, g')
pure x
- Just x -> pure x) (go (False : id), go (True : id))
+ Just x -> pure x) (go (pushbit False id), go (pushbit True id))
newtype Prob a = Prob {runProb :: Tree -> a}
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 5ad346d..59126ba 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -27,7 +27,8 @@ import qualified System.Random as R
import Data.IORef
import Control.Monad
import qualified Data.Vector.Hashtables as H
-import qualified Data.Vector as V
+import qualified Data.Vector.Unboxed as V
+import Data.Word
mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p m = do
@@ -50,7 +51,7 @@ mh g p m = do
yield (x'', w'')
step g2 omega'' x'' w''
- mutate :: MonadIO m => StdGen -> IORef (HashMap [Bool] Double, StdGen) -> m (IORef (HashMap [Bool] Double, StdGen))
+ mutate :: MonadIO m => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen))
mutate g omega = liftIO $ do
(m, g0) <- readIORef omega
m' <- H.clone m