blob: fae825bcf49b0d47f019533ca23147c0e94db6ae (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
{-# LANGUAGE TypeApplications #-}
module PPL.Distr where
import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.))
import Data.Functor ((<&>))
import Data.Word (Word64)
import GHC.Float (castDoubleToWord64, castWord64ToDouble)
import Unsafe.Coerce

import PPL.Internal
-- | Inverse of the standard normal CDF (the probit function).
--
-- Uses Acklam's rational approximation
-- (https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/):
-- a rational function of @sqrt (-2 * log p)@ for the lower tail
-- (@p < 0.02425@), a rational function of @(p - 0.5)^2@ for the central
-- region, and the upper tail by the symmetry @probit p = -probit (1 - p)@.
--
-- Boundary behaviour: @probit 0 = -Infinity@, @probit 1 = +Infinity@, and any
-- argument outside @[0, 1]@ (including NaN) yields NaN.
probit :: Double -> Double
{-# INLINE probit #-}
probit p
  -- Must come first: NaN fails every comparison below and would otherwise
  -- reach the symmetric branch and recurse forever.
  | isNaN p || p < 0 || p > 1 = 0 / 0
  -- The tail formula evaluates to Infinity / Infinity (NaN) at the endpoints,
  -- so return the exact limits instead.
  | p == 0 = negate (1 / 0)
  | p == 1 = 1 / 0
  | p < lower =
      let q = sqrt (-2 * log p)
       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
  | p < 1 - lower =
      let q = p - 0.5
          r = q * q
       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
            * q
            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
  -- Here p is in [1 - lower, 1), so 1 - p is in (0, lower]; the recursive
  -- call lands in one of the branches above and terminates.
  | otherwise = -probit (1 - p)
  where
    a1 = -3.969683028665376e+01
    a2 = 2.209460984245205e+02
    a3 = -2.759285104469687e+02
    a4 = 1.383577518672690e+02
    a5 = -3.066479806614716e+01
    a6 = 2.506628277459239e+00
    b1 = -5.447609879822406e+01
    b2 = 1.615858368580409e+02
    b3 = -1.556989798598866e+02
    b4 = 6.680131188771972e+01
    b5 = -1.328068155288572e+01
    c1 = -7.784894002430293e-03
    c2 = -3.223964580411365e-01
    c3 = -2.400758277161838e+00
    c4 = -2.549732539343734e+00
    c5 = 4.374664141464968e+00
    c6 = 2.938163982698783e+00
    d1 = 7.784695709041462e-03
    d2 = 3.224671290700398e-01
    d3 = 2.445134137142996e+00
    d4 = 3.754408661907416e+00
    -- Boundary between the tail and central approximations.
    lower = 0.02425
-- | An infinite stream of repeated draws from a single distribution.
--
-- Callers such as 'binom' and 'geom' only inspect a finite prefix, so this
-- relies on 'Prob' sequencing an infinite list lazily.
-- NOTE(review): confirm that laziness guarantee in PPL.Internal.
iid :: Prob a -> Prob [a]
iid dist = sequence (repeat dist)
-- | A standard normal sample, obtained by pushing a 'uniform' draw through
-- the inverse normal CDF ('probit').
gauss :: Prob Double
gauss = fmap probit uniform
-- | A normal sample with location @m@ and scale @s@: a standard 'gauss' draw
-- @z@ mapped to @z * s + m@.
norm :: Double -> Double -> Prob Double
norm m s = fmap (\z -> z * s + m) gauss
-- | Sample from the Gamma distribution with shape @a@ and unit scale.
--
-- For @a >= 1@ this is Marsaglia & Tsang's rejection method. With
-- @d = a - 1/3@ and @c = 1 / sqrt (9 d)@, it draws @x ~ N(0,1)@ and a uniform
-- @u@, sets @v = (1 + c x)^3@, and accepts @d * v@ when @v > 0@ and either
-- condition holds:
--
-- * the cheap squeeze @u < 1 - 0.0331 x^4@, or
-- * the exact test @log u < x^2 / 2 + d (1 - v + log v)@.
--
-- The squeeze alone is not a valid acceptance rule.
--
-- For @0 < a < 1@ it uses the boost @Gamma(a) = Gamma(a + 1) * u^(1/a)@,
-- because the method above needs @d > 0@. A non-positive or NaN shape yields
-- NaN.
gamma :: Double -> Prob Double
gamma a
  | not (a > 0) = pure (0 / 0)
  | a < 1 = do
      x <- gamma (a + 1)
      u <- uniform
      pure (x * u ** (1 / a))
  | otherwise = marsagliaTsang
  where
    d = a - 1 / 3
    c = 1 / sqrt (9 * d)
    marsagliaTsang = do
      x <- gauss
      u <- uniform
      let t = 1 + c * x
          v = t * t * t
          -- v > 0 is checked first so that log v is never taken on a
          -- non-positive value.
          accept =
            v > 0
              && ( u < 1 - 0.0331 * x * x * x * x
                     || log u < 0.5 * x * x + d * (1 - v + log v)
                 )
      if accept then pure (d * v) else marsagliaTsang
-- | Sample from the Beta distribution with shapes @a@ and @b@. It takes a
-- @Gamma(a)@ draw @ga@ followed by a @Gamma(b)@ draw @gb@ and returns
-- @ga / (ga + gb)@.
beta :: Double -> Double -> Prob Double
beta a b =
  gamma a >>= \ga ->
    gamma b >>= \gb ->
      pure (ga / (ga + gb))
-- | Sample from the beta prime distribution with shapes @a@ and @b@.
--
-- The sample is the ratio @x / y@ of a @Gamma(a)@ draw and a @Gamma(b)@
-- draw, taken in that order. Computing it as @p / (1 - p)@ with
-- @p = x / (x + y)@ is algebraically the same, but it loses precision to
-- cancellation in @1 - p@ when @p@ is near 1. It also returns Infinity as
-- soon as @p@ rounds to exactly 1.
beta' :: Double -> Double -> Prob Double
beta' a b = do
  x <- gamma a
  y <- gamma b
  pure (x / y)
-- | A Bernoulli trial: 'True' exactly when a 'uniform' draw is below @p@.
bern :: Double -> Prob Bool
bern p = fmap (\u -> u < p) uniform
-- | The number of successes among the first @n@ trials of an infinite stream
-- of @bern p@ draws. A negative @n@ counts no trials and yields 0.
binom :: Int -> Double -> Prob Int
binom n p = countSuccesses <$> iid (bern p)
  where
    countSuccesses trials = length [() | True <- take n trials]
-- | Sample from the exponential distribution with rate @lambda@ by the
-- inverse-CDF transform @-(log u / lambda)@ of a 'uniform' draw @u@.
exponential :: Double -> Prob Double
exponential lambda = fmap (\u -> negate (log u / lambda)) uniform
-- | The number of failed @bern p@ trials before the first success, so the
-- support starts at 0.
--
-- This counts with @length . takeWhile not@, which is total and uses a
-- strict counter. The previous hand-written loop had an @undefined@ fallback
-- and built a chain of @n + 1@ thunks across long runs of failures.
--
-- NOTE(review): the search never ends if no trial can succeed (e.g.
-- @p <= 0@, assuming 'uniform' never returns a negative value).
geom :: Double -> Prob Int
geom p = length . takeWhile not <$> iid (bern p)
-- | Sample a 'Double' between @lower@ and @upper@.
--
-- This follows the γ-section idea from Goualard, "Drawing random
-- floating-point numbers from an interval" (doi:10.1145/3503512).
--
-- * @g@ is the larger of the gaps between adjacent doubles at the two
--   endpoints, and @hi = ceilint lower upper g@.
-- * Each result is @upper - k * g@, or @lower + k * g@ when @lower@ has the
--   larger magnitude, for an integer @k@ in @[1, hi - 1]@.
-- * @k@ is split as @4 * k1 + k2@, and the result is assembled as
--   @4 * (upper / 4 - k1 * g) - k2 * g@. This is presumably done to keep
--   intermediate values from overflowing near the top of the 'Double' range.
--
-- When @hi - 2@ does not fit in 52 bits, it falls back to an affine map of a
-- 'uniform' draw. Because @hi@ is a 'Word64', @hi < 2@ makes @hi - 2@ wrap
-- around, which also selects the fallback.
--
-- NOTE(review): assumes finite @lower < upper@; confirm callers guarantee it.
bounded :: Double -> Double -> Prob Double
bounded lower upper
  | hi - 2 <= 0xfffffffffffff = do
      (k1, k2) <- split <$> boundedWord52 1 (hi - 2)
      pure $
        if abs lower <= abs upper
          then 4 * (upper / 4 - k1 * g) - k2 * g
          else 4 * (lower / 4 + k1 * g) + k2 * g
  | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x)
  where
    hi = ceilint lower upper g
    g = max (up lower - lower) (upper - down upper)
    -- k = 4 * k1 + k2, where k2 holds the low two bits of k.
    split :: Word64 -> (Double, Double)
    split x = (fromIntegral (x `shiftR` 2), fromIntegral (x .&. 0x3))
    -- Next representable Double above / below x (for finite x).
    -- Incrementing the raw bit pattern moves away from zero, so the
    -- direction depends on the sign. Zero (either sign) steps to the
    -- smallest subnormal. The old unconditional +1 / -1 went the wrong way
    -- for negative inputs, and turned down 0 into NaN.
    up :: Double -> Double
    up x
      | x == 0 = castWord64ToDouble 1
      | x > 0 = castWord64ToDouble (castDoubleToWord64 x + 1)
      | otherwise = castWord64ToDouble (castDoubleToWord64 x - 1)
    down :: Double -> Double
    down x
      | x == 0 = negate (castWord64ToDouble 1)
      | x > 0 = castWord64ToDouble (castDoubleToWord64 x - 1)
      | otherwise = castWord64ToDouble (castDoubleToWord64 x + 1)
-- | @ceilint a b g@ computes the ceiling of @(b - a) / g@ as a 'Word64'.
--
-- The quotient is evaluated as @s = b / g - a / g@. When @s@ comes out as a
-- whole number, a correction is applied: if the error term @eps@ is positive,
-- the exact quotient lies just above @s@, so one extra step is added.
ceilint :: Double -> Double -> Double -> Word64
ceilint a b g
  | landsOnInteger && eps > 0 = steps + 1
  | otherwise = steps
  where
    s = b / g - a / g
    steps = ceiling s
    landsOnInteger = s == fromIntegral (ceiling s :: Integer)
    -- Appears to be the rounding error of the subtraction that produced s
    -- (following Goualard's ceilint); only its sign is used.
    eps
      | abs a <= abs b = negate a / g - (s - b / g)
      | otherwise = b / g - (s + a / g)
-- | The 52 fraction (mantissa) bits of a 'uniform' draw, as a 'Word64' in
-- @[0, 2^52 - 1]@.
--
-- The bit pattern comes from 'castDoubleToWord64'. That is the supported
-- bit-cast; 'unsafeCoerce' between the boxed 'Double' and 'Word64' types
-- only works by accident of heap layout.
--
-- NOTE(review): the result is uniform only if 'uniform' yields doubles whose
-- fraction bits are uniformly random; confirm against PPL.Internal.
unboundedWord52 :: Prob Word64
unboundedWord52 = (.&. 0xfffffffffffff) . castDoubleToWord64 <$> uniform
-- | A uniform 'Word64' in @[l, l + r]@, drawn by rejection sampling.
--
-- Each 52-bit draw from 'unboundedWord52' is reduced modulo the smallest
-- power of two greater than @r@. Draws that exceed @r@ are retried, and each
-- draw is accepted with probability above 1/2.
--
-- Every value is reachable only when @r < 2^52@, which the callers check.
-- For @r >= 2^63@ the modulus shifts out to 0 and 'mod' throws.
boundedWord52 :: Word64 -> Word64 -> Prob Word64
boundedWord52 l r = loop
  where
    modulus = 1 `shiftL` (64 - countLeadingZeros r)
    loop = do
      candidate <- (`mod` modulus) <$> unboundedWord52
      if candidate > r then loop else pure (l + candidate)
-- | Sample an integer uniformly from the inclusive range @[lower, upper]@.
--
-- Ranges of at most 2^52 values are sampled exactly: draw an offset with
-- 'boundedWord52' and add it to @lower@. The width and the sum are computed
-- in 'Integer', which matters in two cases:
--
-- * Narrow signed types such as 'Int8' still take this path. In those types
--   the literal @0xfffffffffffff@ wraps to @-1@.
-- * Unbounded types such as 'Integer' get exact results when @lower@ is
--   negative or large. A wrapping 'Word64' round trip would corrupt them.
--
-- Wider ranges fall back to rounding a 'bounded' 'Double' sample.
-- NOTE(review): that fallback is only approximately uniform. The endpoints
-- likely get about half weight, and a 'Double' cannot represent every
-- integer in a huge range.
--
-- Calls 'error' when @upper < lower@. That input previously failed with a
-- divide-by-zero inside 'boundedWord52'.
bounded' :: Integral a => a -> a -> Prob a
bounded' lower upper
  | width < 0 = error "PPL.Distr.bounded': upper bound is below lower bound"
  | width <= 0xfffffffffffff =
      (\k -> fromInteger (toInteger lower + toInteger k))
        <$> boundedWord52 0 (fromInteger width)
  | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper)
  where
    width = toInteger upper - toInteger lower
-- | Sample an index from the categorical distribution with weights @xs@.
--
-- Returns the first index whose cumulative weight exceeds a 'uniform' draw
-- scaled by the total weight. Scaling has two effects:
--
-- * Unnormalised weights work.
-- * The index stays in range when weights meant to sum to 1 fall just short
--   in floating point. For example, @[0.7, 0.2, 0.1]@ sums to
--   @0.9999999999999999@, and an unscaled draw could land past the last
--   cumulative weight and return @length xs@.
--
-- The total is computed with the same left-to-right additions as the last
-- cumulative sum, so the two agree exactly.
--
-- NOTE(review): assumes non-negative weights and a 'uniform' in @[0, 1)@.
-- An empty or all-zero weight list yields @length xs@.
cat :: [Double] -> Prob Int
cat xs = pick . (* total) <$> uniform
  where
    cumulative = scanl1 (+) xs
    total = sum xs
    pick r = length (takeWhile (<= r) cumulative)
-- | Stick-breaking weights with concentration @p@.
--
-- The result is an infinite list. Its k-th element is a @beta 1 p@ fraction
-- of whatever remains of a unit-length stick after the first k - 1 breaks.
dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = breakStick 1
  where
    breakStick remaining = do
      frac <- beta 1 p
      let piece = frac * remaining
      (piece :) <$> breakStick (remaining - piece)
|