aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/PPL/Internal.hs27
-rw-r--r--src/PPL/Sampling.hs4
2 files changed, 10 insertions, 21 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 7ac884b..ea37490 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -29,6 +29,7 @@ import Data.Bits
import Data.IORef
import Data.Monoid
import Data.Vector.Hashtables qualified as H
+import Data.Vector.Mutable qualified as M
import Data.Vector.Unboxed.Mutable qualified as UM
import Data.Word
import Language.Haskell.TH.Syntax qualified as TH
@@ -37,19 +38,7 @@ import System.IO.Unsafe
import System.Random hiding (split, uniform)
import System.Random qualified as R
-type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v
-
--- Simple hashing scheme into 64-bit space based on mzHash64
-data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show)
-
-pushbit :: Bool -> Hash -> Hash
-pushbit b (Hash state) = Hash $ (16738720027710993212 * fromBool b) `xor` (state `shiftL` 3) `xor` (state `shiftR` 1)
- where
- fromBool True = 1
- fromBool False = 0
-
-initHash :: Hash
-initHash = Hash 6223102867371013753
+type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v
-- Reimplementation of the LazyPPL monads to avoid some dependencies
@@ -59,23 +48,23 @@ data Tree = Tree
}
{-# INLINE newTree #-}
-newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree
-newTree s = go initHash
+newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
+newTree s = go 1
where
- go :: Hash -> Tree
+ go :: Integer -> Tree
go id =
Tree
( unsafePerformIO $ do
(m, g) <- readIORef s
- H.lookup m (unhash id) >>= \case
+ H.lookup m id >>= \case
Nothing -> do
let (x, g') = R.random g
- H.insert m (unhash id) x
+ H.insert m id x
writeIORef s (m, g')
pure x
Just x -> pure x
)
- (go (pushbit False id), go (pushbit True id))
+ (go (2 * id), go (2 * id + 1))
newtype Prob a = Prob {runProb :: Tree -> a}
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index d397ca8..256e67b 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -13,8 +13,8 @@ where
import Control.Monad
import Control.Monad.IO.Class
import Data.IORef
+import Data.Vector qualified as V
import Data.Vector.Hashtables qualified as H
-import Data.Vector.Unboxed qualified as V
import Data.Word
import Numeric.Log
import PPL.Internal
@@ -42,7 +42,7 @@ mh g p m = do
yield (x'', w'')
step g2 omega'' x'' w''
- mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen))
+ mutate :: (MonadIO m) => StdGen -> IORef (HashMap Integer Double, StdGen) -> m (IORef (HashMap Integer Double, StdGen))
mutate g omega = liftIO $ do
(m, g0) <- readIORef omega
m' <- H.clone m