From 6767617cb5bcf4b32fd02c12eb235d66e3dda7d5 Mon Sep 17 00:00:00 2001
From: Justin Bedo
Date: Sat, 21 Feb 2026 08:48:25 +1100
Subject: refactor bounded sampling

Tree nodes now carry a raw Word64 instead of a Double.  `uniform` moves
from PPL.Internal to PPL.Distr and is assembled from raw words obtained via
`unbounded`; the 52-bit helpers (unboundedWord52/boundedWord52) are replaced
by a single rejection sampler, bounded', that works on full 64-bit words.
The Double-producing `random` moves to PPL.Sampling, whose memoised tree
constructor is renamed newMemoTree (with a MemoTree alias for its state).

Review fixes folded in:
  * bounded' masks the raw word instead of taking it `mod` nextPow2 r:
    1 `shiftL` 64 is 0 for Word64, so ranges of 2^63 or more divided by
    zero now that the old 52-bit guard is gone.
  * bounded' adds the offset in the caller's type rather than round-tripping
    `lower` through Word64.
  * uniform and unbounded get explicit type signatures; uniform ends in
    unsafeCoerce and would otherwise be inferred as `Prob a`.
---
 src/PPL/Distr.hs    | 55 ++++++++++++++++++++++++++++++++++++-------------------
 src/PPL/Internal.hs | 29 +++--------------------------
 src/PPL/Sampling.hs | 32 ++++++++++++++++++++++++--------
 3 files changed, 63 insertions(+), 53 deletions(-)

(limited to 'src/PPL')

diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index fae825b..54db358 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,8 +1,10 @@
+{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
 
 module PPL.Distr where
 
-import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.))
+import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.))
 import Data.Functor ((<&>))
 import Data.Word (Word64)
 import PPL.Internal
@@ -100,13 +102,22 @@ geom p = first 0 <$> iid (bern p)
     first n (_ : xs) = first (n + 1) xs
     first _ _ = undefined
 
--- Uses idea from 10.1145/3503512
+-- | Uniform 'Double' assembled bit-by-bit from raw 64-bit draws, following
+-- https://www.corsix.org/content/higher-quality-random-floats.  The explicit
+-- signature matters: the body ends in 'unsafeCoerce', so without it the
+-- binding would be inferred as @Prob a@ and could be used at any type.
+uniform :: Prob Double
+uniform = do
+  z <- unbounded
+  let r = countTrailingZeros z - 11
+  e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r
+  pure . unsafeCoerce $ (1 + (z `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+
+-- Uses algorithm from 10.1145/3503512
 bounded :: Double -> Double -> Prob Double
-bounded lower upper
-  | hi - 2 <= 0xfffffffffffff = do
-    (k1, k2) <- split <$> boundedWord52 1 (hi - 2)
-    pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
-  | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x)
+bounded lower upper = do
+  (k1, k2) <- split <$> bounded' 1 (hi - 2)
+  pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
   where
     hi = ceilint lower upper g
     g = max (up lower - lower) (upper - down upper)
@@ -114,6 +125,7 @@ bounded lower upper
     split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
 
     up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1)
+
     down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1)
 
     ceilint :: Double -> Double -> Double -> Word64
@@ -122,21 +134,26 @@ bounded lower upper
         s = b / g - a / g
         eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)
 
--- Random mantissa part of a uniform double
-unboundedWord52 = (.&. 0xfffffffffffff) . (unsafeCoerce @Double @Word64) <$> uniform
-
-boundedWord52 l r = do
-  x <- (`mod` nextPow2 r) <$> unboundedWord52
-  if x <= r then pure (x + l) else boundedWord52 l r
-  where
-    nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
+-- | One raw 64-bit word read from the current node of the random tree.
+unbounded :: Prob Word64
+unbounded = Prob draw
 
+-- | Uniform integer in the closed range @[lower, upper]@: draw a raw word,
+-- keep only the bits that can occur in @upper - lower@, and retry whenever
+-- the result overshoots.  Assumes @lower <= upper@ and that @upper - lower@
+-- fits in both @a@ and 'Word64'.
 bounded' :: Integral a => a -> a -> Prob a
-bounded' lower upper
-  | r <= 0xfffffffffffff = fromIntegral <$> boundedWord52 (fromIntegral lower) (fromIntegral r)
-  | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+bounded' lower upper = do
+  z <- (.&. mask) <$> unbounded
+  if z <= r
+    then pure $ lower + fromIntegral z
+    else bounded' lower upper
   where
-    r = upper - lower
+    r :: Word64 = fromIntegral $ upper - lower
+    -- All-ones mask covering the bits of r.  Masking is used instead of
+    -- @`mod` 2^k@ because 2^k wraps to 0 in Word64 once r >= 2^63, which
+    -- would make `mod` divide by zero.
+    mask = maxBound `shiftR` countLeadingZeros r
 
 cat :: [Double] -> Prob Int
 cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 3c54078..ec9b0c8 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -9,8 +9,7 @@
 {-# LANGUAGE ViewPatterns #-}
 
 module PPL.Internal
-  ( uniform,
-    Prob (..),
+  ( Prob (..),
     Meas,
     score,
     scoreLog,
@@ -18,8 +17,6 @@ module PPL.Internal
     memoize,
     samples,
     HashMap,
-    random,
-    randoms,
     newTree,
     Tree (..),
   )
@@ -29,7 +26,6 @@ import Control.Monad
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Writer
 import Data.Bifunctor
-import Data.Bits (countTrailingZeros, shiftL, shiftR)
 import Data.IORef
 import Data.Map qualified as Q
 import Data.Monoid
@@ -42,36 +38,20 @@ import Numeric.Log
 import System.IO.Unsafe
 import System.Random hiding (random, randoms, split, uniform)
 import System.Random qualified as R
-import Unsafe.Coerce
 
 type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v
 
 -- Reimplementation of the LazyPPL monads to avoid some dependencies
 
 data Tree = Tree
-  { draw :: Double,
+  { draw :: Word64,
     split :: (Tree, Tree)
   }
 
--- Let's draw doubles more uniformly than the standard generator
--- based on https://www.corsix.org/content/higher-quality-random-floats
-random :: StdGen -> (Double, StdGen)
-random g0 =
-  let (r1 :: Word64, g1) = R.random g0
-      (r2 :: Word64, g2) = R.random g1
-      e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
-      dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
-   in (unsafeCoerce dbl, g2)
-
-randoms :: StdGen -> [Double]
-randoms g =
-  let (x, g') = random g
-   in x : randoms g'
-
 newTree :: StdGen -> Tree
 newTree g0 = Tree x (newTree g1, newTree g2)
   where
-    (x, R.split -> (g1, g2)) = random g0
+    (x, R.split -> (g1, g2)) = R.random g0
 
 newtype Prob a = Prob {runProb :: Tree -> a}
 
@@ -85,9 +65,6 @@ instance Functor Prob where fmap = liftM
 
 instance Applicative Prob where pure = Prob . const; (<*>) = ap
 
-uniform :: Prob Double
-uniform = Prob $ \(Tree r _) -> r
-
 newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
   deriving (Functor, Applicative, Monad)
 
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 71a61fd..4deaf36 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -14,19 +14,35 @@ where
 import Control.Arrow
 import Control.Monad
 import Control.Monad.IO.Class
+import Data.Bits (countTrailingZeros, shiftL, shiftR)
 import Data.IORef
 import Data.List (foldl')
 import Data.Vector qualified as V
 import Data.Vector.Hashtables qualified as H
+import Data.Word (Word64)
 import Numeric.Log hiding (sum)
 import PPL.Internal hiding (newTree, split)
 import Streaming.Prelude (Of, Stream, yield)
 import System.IO.Unsafe
 import System.Random (StdGen, split)
+import System.Random qualified as R
+import Unsafe.Coerce
 
-{-# INLINE newTree #-}
-newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
-newTree s = go 1
+type MemoTree = IORef (HashMap Integer Word64, StdGen)
+
+-- Let's draw doubles more uniformly than the standard generator
+-- based on https://www.corsix.org/content/higher-quality-random-floats
+random :: StdGen -> (Double, StdGen)
+random g0 =
+  let (r1 :: Word64, g1) = R.random g0
+      (r2 :: Word64, g2) = R.random g1
+      e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
+      dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+   in (unsafeCoerce dbl, g2)
+
+{-# INLINE newMemoTree #-}
+newMemoTree :: MemoTree -> Tree
+newMemoTree s = go 1
   where
     go :: Integer -> Tree
     go id =
@@ -35,7 +51,7 @@ newTree s = go 1
             (m, g) <- readIORef s
             H.lookup m id >>= \case
               Nothing -> do
-                let (x, g') = random g
+                let (x, g') = R.random g
                 H.insert m id x
                 writeIORef s (m, g')
                 pure x
@@ -63,13 +79,13 @@ mh g p m = do
   let (g0, g1) = split g
   hm <- liftIO $ H.initialize 0
   omega <- liftIO $ newIORef (hm, g0)
-  let (x, w) = head $ samples m $ newTree omega
+  let (x, w) = head $ samples m $ newMemoTree omega
   step g1 omega x w
   where
     step !g0 !omega !x !w = do
       let (Exp . log -> r, split -> (g1, g2)) = random g0
       omega' <- mutate g1 omega
-      let (!x', !w') = head $ samples m $ newTree omega'
+      let (!x', !w') = head $ samples m $ newMemoTree omega'
           ratio = w' / w
           (omega'', x'', w'') =
             if r < ratio
@@ -78,12 +94,12 @@ mh g p m = do
       yield (x'', w'')
       step g2 omega'' x'' w''
 
-    mutate :: (MonadIO m) => StdGen -> IORef (HashMap Integer Double, StdGen) -> m (IORef (HashMap Integer Double, StdGen))
+    mutate :: (MonadIO m) => StdGen -> MemoTree -> m MemoTree
     mutate g omega = liftIO $ do
       (m, g0) <- readIORef omega
       m' <- H.clone m
       ks <- H.keys m
-      let (rs :: [Double], qs :: [Double]) = (randoms *** randoms) (split g)
+      let (rs :: [Word64], qs :: [Word64]) = (R.randoms *** R.randoms) (split g)
           ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks
           n = fromIntegral (V.length ks)
       void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs
-- 
cgit v1.2.3