aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Sampling.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Sampling.hs')
-rw-r--r--src/PPL/Sampling.hs32
1 files changed, 24 insertions, 8 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 71a61fd..4deaf36 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -14,19 +14,35 @@ where
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
+import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.List (foldl')
import Data.Vector qualified as V
import Data.Vector.Hashtables qualified as H
+import Data.Word (Word64)
import Numeric.Log hiding (sum)
import PPL.Internal hiding (newTree, split)
import Streaming.Prelude (Of, Stream, yield)
import System.IO.Unsafe
import System.Random (StdGen, split)
+import System.Random qualified as R
+import Unsafe.Coerce
-{-# INLINE newTree #-}
-newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
-newTree s = go 1
+type MemoTree = IORef (HashMap Integer Word64, StdGen)
+
+-- Let's draw doubles more uniformly than the standard generator
+-- based on https://www.corsix.org/content/higher-quality-random-floats
+random :: StdGen -> (Double, StdGen)
+random g0 =
+ let (r1 :: Word64, g1) = R.random g0
+ (r2 :: Word64, g2) = R.random g1
+ e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
+ dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+ in (unsafeCoerce dbl, g2)
+
+{-# INLINE newMemoTree #-}
+newMemoTree :: MemoTree -> Tree
+newMemoTree s = go 1
where
go :: Integer -> Tree
go id =
@@ -35,7 +51,7 @@ newTree s = go 1
(m, g) <- readIORef s
H.lookup m id >>= \case
Nothing -> do
- let (x, g') = random g
+ let (x, g') = R.random g
H.insert m id x
writeIORef s (m, g')
pure x
@@ -63,13 +79,13 @@ mh g p m = do
let (g0, g1) = split g
hm <- liftIO $ H.initialize 0
omega <- liftIO $ newIORef (hm, g0)
- let (x, w) = head $ samples m $ newTree omega
+ let (x, w) = head $ samples m $ newMemoTree omega
step g1 omega x w
where
step !g0 !omega !x !w = do
let (Exp . log -> r, split -> (g1, g2)) = random g0
omega' <- mutate g1 omega
- let (!x', !w') = head $ samples m $ newTree omega'
+ let (!x', !w') = head $ samples m $ newMemoTree omega'
ratio = w' / w
(omega'', x'', w'') =
if r < ratio
@@ -78,12 +94,12 @@ mh g p m = do
yield (x'', w'')
step g2 omega'' x'' w''
- mutate :: (MonadIO m) => StdGen -> IORef (HashMap Integer Double, StdGen) -> m (IORef (HashMap Integer Double, StdGen))
+ mutate :: (MonadIO m) => StdGen -> MemoTree -> m MemoTree
mutate g omega = liftIO $ do
(m, g0) <- readIORef omega
m' <- H.clone m
ks <- H.keys m
- let (rs :: [Double], qs :: [Double]) = (randoms *** randoms) (split g)
+ let (rs :: [Word64], qs :: [Word64]) = (R.randoms *** R.randoms) (split g)
ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks
n = fromIntegral (V.length ks)
void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs