From 6767617cb5bcf4b32fd02c12eb235d66e3dda7d5 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Sat, 21 Feb 2026 08:48:25 +1100 Subject: refactor bounded sampling --- src/PPL/Sampling.hs | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) (limited to 'src/PPL/Sampling.hs') diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 71a61fd..4deaf36 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -14,19 +14,35 @@ where import Control.Arrow import Control.Monad import Control.Monad.IO.Class +import Data.Bits (countTrailingZeros, shiftL, shiftR) import Data.IORef import Data.List (foldl') import Data.Vector qualified as V import Data.Vector.Hashtables qualified as H +import Data.Word (Word64) import Numeric.Log hiding (sum) import PPL.Internal hiding (newTree, split) import Streaming.Prelude (Of, Stream, yield) import System.IO.Unsafe import System.Random (StdGen, split) +import System.Random qualified as R +import Unsafe.Coerce -{-# INLINE newTree #-} -newTree :: IORef (HashMap Integer Double, StdGen) -> Tree -newTree s = go 1 +type MemoTree = IORef (HashMap Integer Word64, StdGen) + +-- Let's draw doubles more uniformly than the standard generator +-- based on https://www.corsix.org/content/higher-quality-random-floats +random :: StdGen -> (Double, StdGen) +random g0 = + let (r1 :: Word64, g1) = R.random g0 + (r2 :: Word64, g2) = R.random g1 + e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r + dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52) + in (unsafeCoerce dbl, g2) + +{-# INLINE newMemoTree #-} +newMemoTree :: MemoTree -> Tree +newMemoTree s = go 1 where go :: Integer -> Tree go id = @@ -35,7 +51,7 @@ newTree s = go 1 (m, g) <- readIORef s H.lookup m id >>= \case Nothing -> do - let (x, g') = random g + let (x, g') = R.random g H.insert m id x writeIORef s (m, g') pure x @@ -63,13 +79,13 @@ mh g p m = do let (g0, g1) = split g hm <- liftIO $ H.initialize 0 omega <- liftIO $ newIORef (hm, g0) - let (x, w) = head $ samples m $ newTree omega + let (x, w) = head $ samples m $ newMemoTree omega step g1 omega x w where step !g0 !omega !x !w = do let (Exp . log -> r, split -> (g1, g2)) = random g0 omega' <- mutate g1 omega - let (!x', !w') = head $ samples m $ newTree omega' + let (!x', !w') = head $ samples m $ newMemoTree omega' ratio = w' / w (omega'', x'', w'') = if r < ratio @@ -78,12 +94,12 @@ mh g p m = do yield (x'', w'') step g2 omega'' x'' w'' - mutate :: (MonadIO m) => StdGen -> IORef (HashMap Integer Double, StdGen) -> m (IORef (HashMap Integer Double, StdGen)) + mutate :: (MonadIO m) => StdGen -> MemoTree -> m MemoTree mutate g omega = liftIO $ do (m, g0) <- readIORef omega m' <- H.clone m ks <- H.keys m - let (rs :: [Double], qs :: [Double]) = (randoms *** randoms) (split g) + let (rs :: [Word64], qs :: [Word64]) = (R.randoms *** R.randoms) (split g) ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks n = fromIntegral (V.length ks) void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs -- cgit v1.2.3