aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/PPL/Distr.hs44
1 files changed, 40 insertions, 4 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index 8ffbd0f..fae825b 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,6 +1,12 @@
+{-# LANGUAGE TypeApplications #-}
+
module PPL.Distr where
+import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.))
+import Data.Functor ((<&>))
+import Data.Word (Word64)
import PPL.Internal
+import Unsafe.Coerce
-- Acklam's approximation
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
@@ -94,13 +100,43 @@ geom p = first 0 <$> iid (bern p)
first n (_ : xs) = first (n + 1) xs
first _ _ = undefined
+-- Uses idea from 10.1145/3503512
bounded :: Double -> Double -> Prob Double
-bounded lower upper = do
- z <- uniform
- pure $ 2 * (lower / 2 + (upper / 2 - lower / 2) * z)
+bounded lower upper
+ | hi - 2 <= 0xfffffffffffff = do
+ (k1, k2) <- split <$> boundedWord52 1 (hi - 2)
+ pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
+ | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x)
+ where
+ hi = ceilint lower upper g
+ g = max (up lower - lower) (upper - down upper)
+
+ split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
+
+ up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1)
+ down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1)
+
+ ceilint :: Double -> Double -> Double -> Word64
+ ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0
+ where
+ s = b / g - a / g
+ eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)
+
+-- Random mantissa part of a uniform double
+unboundedWord52 = (.&. 0xfffffffffffff) . (unsafeCoerce @Double @Word64) <$> uniform
+
+boundedWord52 l r = do
+ x <- (`mod` nextPow2 r) <$> unboundedWord52
+ if x <= r then pure (x + l) else boundedWord52 l r
+ where
+ nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
bounded' :: Integral a => a -> a -> Prob a
-bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+bounded' lower upper
+ | r <= 0xfffffffffffff = fromIntegral <$> boundedWord52 (fromIntegral lower) (fromIntegral r)
+ | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+ where
+ r = upper - lower
cat :: [Double] -> Prob Int
cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform