aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Distr.hs
blob: d265f056b7026d759e860c950e02a31fecef8e92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module PPL.Distr where
import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.))
import Data.Functor ((<&>))
import Data.Word (Word64)
import Debug.Trace
import GHC.Float (castDoubleToWord64, castWord64ToDouble)
import Unsafe.Coerce

import PPL.Internal

{-# INLINE probit #-}
-- | Inverse of the standard normal CDF (the probit function), using Acklam's
-- rational approximation:
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
--
-- The lower tail (@p < lower@) and the central region use separate rational
-- functions; the upper tail is obtained by symmetry, @probit p = -probit (1 - p)@.
--
-- Edge cases are handled explicitly:
--
-- * @p == 0@ gives -Infinity and @p == 1@ gives +Infinity.
-- * NaN or @p@ outside [0, 1] gives NaN. Without this guard a NaN would
--   fall through to the symmetry branch and recurse forever.
probit :: Double -> Double
probit p
  | isNaN p || p < 0 || p > 1 = 0 / 0
  | p == 0 = negate (1 / 0)
  | p == 1 = 1 / 0
  | p < lower =
      let q = sqrt (-2 * log p)
       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
  | p < 1 - lower =
      let q = p - 0.5
          r = q * q
       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
            * q
            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
  | otherwise = -probit (1 - p)
  where
    -- Central-region numerator coefficients.
    a1 = -3.969683028665376e+01
    a2 = 2.209460984245205e+02
    a3 = -2.759285104469687e+02
    a4 = 1.383577518672690e+02
    a5 = -3.066479806614716e+01
    a6 = 2.506628277459239e+00

    -- Central-region denominator coefficients.
    b1 = -5.447609879822406e+01
    b2 = 1.615858368580409e+02
    b3 = -1.556989798598866e+02
    b4 = 6.680131188771972e+01
    b5 = -1.328068155288572e+01

    -- Tail numerator coefficients.
    c1 = -7.784894002430293e-03
    c2 = -3.223964580411365e-01
    c3 = -2.400758277161838e+00
    c4 = -2.549732539343734e+00
    c5 = 4.374664141464968e+00
    c6 = 2.938163982698783e+00

    -- Tail denominator coefficients.
    d1 = 7.784695709041462e-03
    d2 = 3.224671290700398e-01
    d3 = 2.445134137142996e+00
    d4 = 3.754408661907416e+00

    -- Breakpoint between the tail and central approximations.
    lower = 0.02425

-- | An unbounded stream of independent draws from one distribution.
-- Callers in this module only ever inspect a finite prefix (via 'take' or by
-- stopping at the first match); presumably 'Prob' is lazy enough to make
-- sequencing an infinite list productive — confirm in "PPL.Internal".
iid :: Prob a -> Prob [a]
iid dist = sequenceA (repeat dist)

-- | Standard normal draw by inverse-transform sampling: a uniform draw is
-- mapped through 'probit'.
gauss :: Prob Double
gauss = fmap probit uniform

-- | Normal draw with mean @m@ and standard deviation @s@: a standard normal
-- draw, scaled by @s@ and then shifted by @m@.
norm :: Double -> Double -> Prob Double
norm m s = gauss <&> \z -> z * s + m

-- | Sample from Gamma(a, 1) (shape @a@, unit scale) using Marsaglia & Tsang's
-- rejection method ("A Simple Method for Generating Gamma Variables",
-- ACM TOMS 26(3), 2000).
--
-- For @a >= 1@, set @d = a - 1/3@ and @c = 1 / sqrt (9 d)@, then repeat:
--
-- 1. Draw a standard normal @x@ and let @v = (1 + c x)^3@; reject if @v <= 0@.
-- 2. Otherwise draw a uniform @u@ and accept @d * v@ if either the cheap
--    squeeze @u < 1 - 0.0331 x^4@ or the exact test
--    @log u < x^2 / 2 + d (1 - v + log v)@ holds.
-- 3. If neither test passes, retry.
--
-- The method needs @a >= 1@. For @a < 1@ the usual boost
-- @Gamma(a) = Gamma(a + 1) * U^(1/a)@ is applied. A NaN shape yields NaN
-- rather than retrying forever.
-- NOTE(review): non-positive shapes are not rejected — confirm callers
-- always pass @a > 0@.
gamma :: Double -> Prob Double
gamma a
  | isNaN a = pure a
  | a < 1 = do
      g <- gamma (a + 1)
      u <- uniform
      pure $ g * u ** (1 / a)
  | otherwise = attempt
  where
    d = a - 1 / 3
    c = 1 / sqrt (9 * d)

    -- One Marsaglia–Tsang trial; loops until a candidate is accepted.
    attempt = do
      x <- gauss
      let v = (1 + c * x) ^ (3 :: Int)
      if v <= 0
        then attempt -- (1 + c x) <= 0: outside the transform's domain
        else do
          u <- uniform
          let x2 = x * x
              squeeze = u < 1 - 0.0331 * x2 * x2
              exact = log u < 0.5 * x2 + d * (1 - v + log v)
          -- `||` is lazy, so the logs are only evaluated when the squeeze fails.
          if squeeze || exact
            then pure $ d * v
            else attempt

-- | Beta(a, b) draw built from two independent gamma draws:
-- @X / (X + Y)@ with @X ~ Gamma(a)@ and @Y ~ Gamma(b)@.
beta :: Double -> Double -> Prob Double
beta a b =
  gamma a >>= \gx ->
    gamma b >>= \gy ->
      pure (gx / (gx + gy))

-- | Beta-prime(a, b) draw: a Beta(a, b) draw @p@ converted to its odds
-- @p / (1 - p)@.
beta' :: Double -> Double -> Prob Double
beta' a b = odds <$> beta a b
  where
    odds q = q / (1 - q)

-- | Bernoulli trial: 'True' exactly when a uniform draw falls below @p@.
bern :: Double -> Prob Bool
bern p = uniform <&> \u -> u < p

-- | Binomial draw: the number of successes among the first @n@ of an
-- independent stream of Bernoulli(@p@) trials. A non-positive @n@ yields 0.
binom :: Int -> Double -> Prob Int
binom n p = countSuccesses . take n <$> iid (bern p)
  where
    countSuccesses = length . filter id

-- | Exponential draw with rate @lambda@ by inverse transform:
-- @-(log u) / lambda@ for a uniform draw @u@.
exponential :: Double -> Prob Double
exponential lambda = uniform <&> \u -> negate (log u / lambda)

-- | Geometric draw: the number of failures before the first success in an
-- independent stream of Bernoulli(@p@) trials (support 0, 1, 2, ...).
--
-- Counting the leading 'False's with 'takeWhile' replaces a hand-written
-- recursion that carried a lazy counter and an @undefined@ fallback case.
-- Does not terminate when no trial can succeed (e.g. @p <= 0@, since
-- 'uniform' never draws a negative value).
geom :: Double -> Prob Int
geom p = length . takeWhile not <$> iid (bern p)

-- | Uniform draw from [0, 1), assembled directly from IEEE-754 bits.
--
-- The top 52 bits of one random word become the mantissa. The result lies
-- in the binade @[2^-k, 2^(1-k))@, where the exponent @k@ is chosen with
-- probability @2^-k@:
--
-- * the trailing-zero count of the 12 low bits gives @k@ in 1..12;
-- * if those 12 bits are all zero, a second word is drawn and its
--   trailing-zero count extends @k@ to 13..77 (77 absorbs the tail).
--
-- The result is therefore never 0 (its minimum is @2^-77@) and never 1.
-- Bit reinterpretation uses 'castWord64ToDouble' and 'fromIntegral' instead
-- of 'unsafeCoerce' between unrelated boxed types; the bits produced are the
-- same.
uniform :: Prob Double
uniform = do
  w <- unbounded
  let lowZeros = countTrailingZeros w - 12
  extra <-
    if lowZeros >= 0
      then countTrailingZeros <$> unbounded
      else pure lowZeros
  let mantissa = w `shiftR` 12
      -- 1010 - extra is in [946, 1022], i.e. a biased exponent for 2^-(13 + extra).
      biasedExponent = fromIntegral (1010 - extra) :: Word64
  pure $ castWord64ToDouble (mantissa + (biasedExponent `shiftL` 52))

-- | Random double between @lower@ and @upper@ using the γ-section algorithm
-- of doi:10.1145/3503512 (Goualard, "Drawing random floating-point numbers
-- from an interval"). Assumes finite bounds with @lower < upper@.
--
-- The interval is divided into steps of size @g@: the larger of the gaps
-- between @lower@ and its successor and between @upper@ and its predecessor.
-- A step count @k@ is then drawn. The result is @upper - k*g@ or
-- @lower + k*g@, measured from whichever end has the smaller magnitude.
-- Writing @k = 4*k1 + k2@ splits the subtraction as @4*(x/4 ∓ k1*g) ∓ k2*g@.
bounded :: Double -> Double -> Prob Double
bounded lower upper = do
  (k1, k2) <- split <$> bounded' 1 (hi - 2)
  pure $
    if abs lower <= abs upper
      then 4 * (upper / 4 - k1 * g) - k2 * g
      else 4 * (lower / 4 + k1 * g) + k2 * g
  where
    -- NOTE(review): k is drawn from [1, hi - 2]. In Word64 this range is empty
    -- or wraps around when hi < 3 (very narrow intervals). Confirm the range
    -- and the small-hi case against the paper.
    hi = ceilint lower upper g
    g = max (up lower - lower) (upper - down upper)

    -- Split k into (k `div` 4, k `mod` 4) as Doubles.
    split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)

    -- Smallest double strictly greater than x. Incrementing the raw bit
    -- pattern only moves toward +Infinity for positive numbers. For negative
    -- numbers it moves away from zero, so the pattern is decremented instead.
    -- Both zeros step to the smallest positive subnormal.
    up x
      | x == 0 = castWord64ToDouble 1
      | x > 0 = castWord64ToDouble (castDoubleToWord64 x + 1)
      | otherwise = castWord64ToDouble (castDoubleToWord64 x - 1)

    -- Largest double strictly less than x, by symmetry with 'up'.
    down x = negate (up (negate x))

    -- Compute ceil((b - a) / step). When the rounded quotient s lands
    -- exactly on an integer but the residual eps is positive (the true
    -- quotient is above s), add one.
    ceilint :: Double -> Double -> Double -> Word64
    ceilint a b step = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0
      where
        s = b / step - a / step
        eps = if abs a <= abs b then negate a / step - (s - b / step) else b / step - (s + a / step)

-- | One raw 64-bit random word, wrapping 'draw' from "PPL.Internal" in the
-- 'Prob' constructor. 'uniform' and 'bounded'' are built directly on it.
-- NOTE(review): presumably each draw is uniform over all 2^64 values (the
-- bit tricks in 'uniform' rely on that) — confirm in "PPL.Internal".
unbounded :: Prob Word64
unbounded = Prob draw

-- | Uniform integer in the closed range [@lower@, @upper@] by rejection
-- sampling. Each try draws 64 random bits and keeps only those up to the
-- highest set bit of the span @upper - lower@. It retries when the masked
-- value exceeds the span, so every outcome is equally likely and fewer than
-- half of the tries are rejected on average.
--
-- The span and the final addition are done in Word64 with wraparound, so
-- full-width 64-bit ranges (e.g. minBound..maxBound for Int) work. The
-- previous @mod@ by @1 `shiftL` (64 - clz r)@ became @mod 0@ (divide by zero)
-- whenever the span had its top bit set; masking does not have that problem.
-- Assumes @lower <= upper@ and that @upper - lower@ is meaningful in @a@.
bounded' :: Integral a => a -> a -> Prob a
bounded' lower upper = do
  z <- (.&. mask) <$> unbounded
  if z <= r
    then pure $ fromIntegral $ fromIntegral lower + z
    else bounded' lower upper
  where
    r :: Word64 = fromIntegral $ upper - lower

    -- All ones from bit 0 up to the highest set bit of r (0 when r == 0).
    -- For r < 2^63 this equals (nextPow2 r - 1), so `.&.` matches `mod`.
    mask :: Word64
    mask = maxBound `shiftR` countLeadingZeros r

-- | Categorical draw: the index of the first running total of the weights
-- that exceeds a uniform draw @u@.
--
-- The weights are expected to sum to 1 and are deliberately not normalised,
-- so only as much of the list as needed is inspected. This keeps it usable
-- with infinite weight streams such as those from 'dirichletProcess'. For a
-- finite list whose running totals never exceed @u@, the result is
-- @length xs@, one past the last index.
cat :: [Double] -> Prob Int
cat xs = uniform <&> \u -> length (takeWhile (not . (> u)) (scanl1 (+) xs))

-- | Stick-breaking weights with concentration @p@: an infinite list whose
-- k-th element is a Beta(1, @p@) fraction of the stick length left after the
-- first k - 1 breaks. At every step, the pieces taken so far plus the
-- remaining stick total 1.
dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = breakStick 1
  where
    breakStick remaining = do
      frac <- beta 1 p
      let piece = frac * remaining
      (piece :) <$> breakStick (remaining - piece)