aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Distr.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Distr.hs')
-rw-r--r--src/PPL/Distr.hs88
1 files changed, 74 insertions, 14 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index 797379b..d265f05 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,7 +1,15 @@
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
+
module PPL.Distr where
+import Debug.Trace
+import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.))
+import Data.Functor ((<&>))
+import Data.Word (Word64)
import PPL.Internal
-import qualified PPL.Internal as I
+import Unsafe.Coerce
-- Acklam's approximation
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
@@ -9,14 +17,15 @@ import qualified PPL.Internal as I
probit :: Double -> Double
probit p
| p < lower =
- let q = sqrt (-2 * log p)
- in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
- / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
+ let q = sqrt (-2 * log p)
+ in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
+ / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
| p < 1 - lower =
- let q = p - 0.5
- r = q * q
- in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
- / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
+ let q = p - 0.5
+ r = q * q
+ in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
+ * q
+ / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
| otherwise = -probit (1 - p)
where
a1 = -3.969683028665376e+01
@@ -49,11 +58,14 @@ probit p
iid :: Prob a -> Prob [a]
iid = sequence . repeat
+gauss :: Prob Double
gauss = probit <$> uniform
+norm :: Double -> Double -> Prob Double
norm m s = (+ m) . (* s) <$> gauss
-- Marsaglia's fast gamma rejection sampling
+gamma :: Double -> Prob Double
gamma a = do
x <- gauss
u <- uniform
@@ -64,15 +76,24 @@ gamma a = do
d = a - 1 / 3
v x = (1 + x / sqrt (9 * d)) ** 3
+beta :: Double -> Double -> Prob Double
beta a b = do
x <- gamma a
y <- gamma b
pure $ x / (x + y)
+beta' :: Double -> Double -> Prob Double
+beta' a b = do
+ p <- beta a b
+ pure $ p / (1 - p)
+
+bern :: Double -> Prob Bool
bern p = (< p) <$> uniform
+binom :: Int -> Double -> Prob Int
binom n = fmap (length . filter id . take n) . iid . bern
+exponential :: Double -> Prob Double
exponential lambda = negate . (/ lambda) . log <$> uniform
geom :: Double -> Prob Int
@@ -80,10 +101,48 @@ geom p = first 0 <$> iid (bern p)
where
first n (True : _) = n
first n (_ : xs) = first (n + 1) xs
+ first _ _ = undefined
+
+uniform :: Prob Double
+uniform = do
+ z <- unbounded
+ let r = countTrailingZeros z - 12
+ e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r
+ pure . unsafeCoerce $ z `shiftR` 12 - ((unsafeCoerce e - 1010) `shiftL` 52)
+
+-- Uses algorithm from 10.1145/3503512
+bounded :: Double -> Double -> Prob Double
+bounded lower upper = do
+ (k1, k2) <- split <$> bounded' 1 (hi - 2)
+ pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
+ where
+ hi = ceilint lower upper g
+ g = max (up lower - lower) (upper - down upper)
+
+ split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
+
+ up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1)
-bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform
+ down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1)
-bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+ ceilint :: Double -> Double -> Double -> Word64
+ ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0
+ where
+ s = b / g - a / g
+ eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)
+
+unbounded :: Prob Word64
+unbounded = Prob draw
+
+bounded' :: Integral a => a -> a -> Prob a
+bounded' lower upper = do
+ z <- (`mod` nextPow2 r) <$> unbounded
+ if z <= r
+ then pure $ fromIntegral $ fromIntegral lower + z
+ else bounded' lower upper
+ where
+ r :: Word64 = fromIntegral $ upper - lower
+ nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
cat :: [Double] -> Prob Int
cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
@@ -93,8 +152,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
| x > r = i
| otherwise = search (i + 1) xs r
+dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = go 1
- where
- go rest = do
- x <- beta 1 p
- (x*rest:) <$> go (rest - x*rest)
+ where
+ go rest = do
+ x <- beta 1 p
+ (x * rest :) <$> go (rest - x * rest)