diff options
| author | Justin Bedo <cu@cua0.org> | 2026-02-20 17:05:41 +1100 |
|---|---|---|
| committer | Justin Bedo <cu@cua0.org> | 2026-02-20 22:55:09 +1100 |
| commit | 8b96f69225cba41c4b4217290d2a0feef810b0f2 (patch) | |
| tree | d64e7ea39ee18262ca002e5fe57b81f8812484a9 /src/PPL/Distr.hs | |
| parent | 0e87d9352381c865ba91cffda353e394d0cf418f (diff) | |
improve bounded sampling
Diffstat (limited to 'src/PPL/Distr.hs')
| -rw-r--r-- | src/PPL/Distr.hs | 44 |
1 files changed, 40 insertions, 4 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index 8ffbd0f..fae825b 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,6 +1,12 @@ +{-# LANGUAGE TypeApplications #-} + module PPL.Distr where +import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.)) +import Data.Functor ((<&>)) +import Data.Word (Word64) import PPL.Internal +import Unsafe.Coerce -- Acklam's approximation -- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ @@ -94,13 +100,43 @@ geom p = first 0 <$> iid (bern p) first n (_ : xs) = first (n + 1) xs first _ _ = undefined +-- Uses idea from 10.1145/3503512 bounded :: Double -> Double -> Prob Double -bounded lower upper = do - z <- uniform - pure $ 2 * (lower / 2 + (upper / 2 - lower / 2) * z) +bounded lower upper + | hi - 2 <= 0xfffffffffffff = do + (k1, k2) <- split <$> boundedWord52 1 (hi - 2) + pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g + | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x) + where + hi = ceilint lower upper g + g = max (up lower - lower) (upper - down upper) + + split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3) + + up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1) + down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1) + + ceilint :: Double -> Double -> Double -> Word64 + ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0 + where + s = b / g - a / g + eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g) + +-- Random mantissa part of a uniform double +unboundedWord52 = (.&. 0xfffffffffffff) . (unsafeCoerce @Double @Word64) <$> uniform + +boundedWord52 l r = do + x <- (`mod` nextPow2 r) <$> unboundedWord52 + if x <= r then pure (x + l) else boundedWord52 l r + where + nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x) bounded' :: Integral a => a -> a -> Prob a -bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper) +bounded' lower upper + | r <= 0xfffffffffffff = fromIntegral <$> boundedWord52 (fromIntegral lower) (fromIntegral r) + | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper) + where + r = upper - lower cat :: [Double] -> Prob Int cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform |
