diff options
Diffstat (limited to 'src/PPL/Distr.hs')
| -rw-r--r-- | src/PPL/Distr.hs | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index fae825b..54db358 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} module PPL.Distr where -import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.)) +import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.)) import Data.Functor ((<&>)) import Data.Word (Word64) import PPL.Internal @@ -100,13 +102,17 @@ geom p = first 0 <$> iid (bern p) first n (_ : xs) = first (n + 1) xs first _ _ = undefined --- Uses idea from 10.1145/3503512 +uniform = do + z <- unbounded + let r = countTrailingZeros z - 11 + e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r + pure . unsafeCoerce $ (1 + (z `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52) + +-- Uses algorithm from 10.1145/3503512 bounded :: Double -> Double -> Prob Double -bounded lower upper - | hi - 2 <= 0xfffffffffffff = do - (k1, k2) <- split <$> boundedWord52 1 (hi - 2) - pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g - | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x) +bounded lower upper = do + (k1, k2) <- split <$> bounded' 1 (hi - 2) + pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g where hi = ceilint lower upper g g = max (up lower - lower) (upper - down upper) @@ -114,6 +120,7 @@ bounded lower upper split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3) up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1) + down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1) ceilint :: Double -> Double -> Double -> Word64 @@ -122,21 +129,17 @@ bounded lower upper s = b / g - a / g eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g) --- Random mantissa part of a uniform double -unboundedWord52 = (.&. 0xfffffffffffff) . (unsafeCoerce @Double @Word64) <$> uniform - -boundedWord52 l r = do - x <- (`mod` nextPow2 r) <$> unboundedWord52 - if x <= r then pure (x + l) else boundedWord52 l r - where - nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x) +unbounded = Prob draw bounded' :: Integral a => a -> a -> Prob a -bounded' lower upper - | r <= 0xfffffffffffff = fromIntegral <$> boundedWord52 (fromIntegral lower) (fromIntegral r) - | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper) +bounded' lower upper = do + z <- (`mod` nextPow2 r) <$> unbounded + if z <= r + then pure $ fromIntegral $ fromIntegral lower + z + else bounded' lower upper where - r = upper - lower + r :: Word64 = fromIntegral $ upper - lower + nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x) cat :: [Double] -> Prob Int cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform |
