aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Distr.hs
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2026-02-21 08:48:25 +1100
committerJustin Bedo <cu@cua0.org>2026-02-21 08:48:25 +1100
commit6767617cb5bcf4b32fd02c12eb235d66e3dda7d5 (patch)
treea581934ad4b9c2c51177aa71565499de5ad13d86 /src/PPL/Distr.hs
parent8b96f69225cba41c4b4217290d2a0feef810b0f2 (diff)
refactor bounded sampling
Diffstat (limited to 'src/PPL/Distr.hs')
-rw-r--r--src/PPL/Distr.hs41
1 files changed, 22 insertions, 19 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index fae825b..54db358 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,8 +1,10 @@
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
module PPL.Distr where
-import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.))
+import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.))
import Data.Functor ((<&>))
import Data.Word (Word64)
import PPL.Internal
@@ -100,13 +102,17 @@ geom p = first 0 <$> iid (bern p)
first n (_ : xs) = first (n + 1) xs
first _ _ = undefined
--- Uses idea from 10.1145/3503512
+uniform = do
+ z <- unbounded
+ let r = countTrailingZeros z - 11
+ e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r
+ pure . unsafeCoerce $ (1 + (z `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+
+-- Uses algorithm from 10.1145/3503512
bounded :: Double -> Double -> Prob Double
-bounded lower upper
- | hi - 2 <= 0xfffffffffffff = do
- (k1, k2) <- split <$> boundedWord52 1 (hi - 2)
- pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
- | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x)
+bounded lower upper = do
+ (k1, k2) <- split <$> bounded' 1 (hi - 2)
+ pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
where
hi = ceilint lower upper g
g = max (up lower - lower) (upper - down upper)
@@ -114,6 +120,7 @@ bounded lower upper
split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1)
+
down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1)
ceilint :: Double -> Double -> Double -> Word64
@@ -122,21 +129,17 @@ bounded lower upper
s = b / g - a / g
eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)
--- Random mantissa part of a uniform double
-unboundedWord52 = (.&. 0xfffffffffffff) . (unsafeCoerce @Double @Word64) <$> uniform
-
-boundedWord52 l r = do
- x <- (`mod` nextPow2 r) <$> unboundedWord52
- if x <= r then pure (x + l) else boundedWord52 l r
- where
- nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
+unbounded = Prob draw
bounded' :: Integral a => a -> a -> Prob a
-bounded' lower upper
- | r <= 0xfffffffffffff = fromIntegral <$> boundedWord52 (fromIntegral lower) (fromIntegral r)
- | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+bounded' lower upper = do
+ z <- (`mod` nextPow2 r) <$> unbounded
+ if z <= r
+ then pure $ fromIntegral $ fromIntegral lower + z
+ else bounded' lower upper
where
- r = upper - lower
+ r :: Word64 = fromIntegral $ upper - lower
+ nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
cat :: [Double] -> Prob Int
cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform