aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Distr.hs
blob: fae825bcf49b0d47f019533ca23147c0e94db6ae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
{-# LANGUAGE TypeApplications #-}

module PPL.Distr where

import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.))
import Data.Functor ((<&>))
import Data.Word (Word64)
import GHC.Float (castDoubleToWord64, castWord64ToDouble)
import Unsafe.Coerce

import PPL.Internal

-- | Inverse of the standard normal CDF (the probit function), using Acklam's
-- rational approximation:
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
--
-- The central region @[pLow, pHigh]@ uses a rational function of @p - 0.5@;
-- the tails use a rational function of @sqrt (-2 * log t)@, with the upper
-- tail obtained from the lower one by symmetry.
--
-- Edge cases: @probit 0 = -Infinity@, @probit 1 = Infinity@, and NaN or any
-- input outside @[0, 1]@ yields NaN. The function is not self-recursive, so
-- the INLINE pragma can take effect.
{-# INLINE probit #-}
probit :: Double -> Double
probit p
  | isNaN p || p < 0 || p > 1 = 0 / 0
  | p == 0 = -1 / 0
  | p == 1 = 1 / 0
  | p < pLow = tailQuantile p
  | p <= pHigh = centralQuantile (p - 0.5)
  | otherwise = negate (tailQuantile (1 - p))
  where
    -- Lower-tail approximation, for 0 < t < pLow.
    tailQuantile t =
      let q = sqrt (-2 * log t)
       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)

    -- Central approximation, where q = p - 0.5.
    centralQuantile q =
      let r = q * q
       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
            * q
            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)

    a1 = -3.969683028665376e+01
    a2 = 2.209460984245205e+02
    a3 = -2.759285104469687e+02
    a4 = 1.383577518672690e+02
    a5 = -3.066479806614716e+01
    a6 = 2.506628277459239e+00

    b1 = -5.447609879822406e+01
    b2 = 1.615858368580409e+02
    b3 = -1.556989798598866e+02
    b4 = 6.680131188771972e+01
    b5 = -1.328068155288572e+01

    c1 = -7.784894002430293e-03
    c2 = -3.223964580411365e-01
    c3 = -2.400758277161838e+00
    c4 = -2.549732539343734e+00
    c5 = 4.374664141464968e+00
    c6 = 2.938163982698783e+00

    d1 = 7.784695709041462e-03
    d2 = 3.224671290700398e-01
    d3 = 2.445134137142996e+00
    d4 = 3.754408661907416e+00

    -- Region boundaries from Acklam's algorithm.
    pLow = 0.02425
    pHigh = 1 - pLow

-- | An infinite list of independent draws from the given distribution.
-- Callers should only inspect a finite prefix; this relies on 'Prob'
-- sequencing an infinite list lazily.
iid :: Prob a -> Prob [a]
iid draw = sequenceA (repeat draw)

-- | Standard normal draw by inverse-transform sampling: a 'uniform' draw is
-- mapped through 'probit'.
gauss :: Prob Double
gauss = fmap probit uniform

-- | Normal distribution with mean @m@ and standard deviation @s@, obtained by
-- scaling and then shifting a 'gauss' draw.
norm :: Double -> Double -> Prob Double
norm m s = fmap (\z -> z * s + m) gauss

-- | Gamma distribution with shape @a@ and unit scale, using Marsaglia & Tsang's
-- squeeze/rejection method ("A simple method for generating gamma
-- variables", ACM TOMS 26(3), 2000).
--
-- The rejection method needs @a >= 1@. For @0 < a < 1@ it uses the boosting
-- identity @Gamma(a) = Gamma(a + 1) * U ** (1 / a)@ with @U@ uniform.
-- A non-positive or NaN shape yields NaN.
gamma :: Double -> Prob Double
gamma a
  | not (a > 0) = pure (0 / 0)
  | a < 1 = do
      boosted <- gamma (a + 1)
      u <- uniform
      pure (boosted * u ** (1 / a))
  | otherwise = attempt
  where
    d = a - 1 / 3
    c = 1 / sqrt (9 * d)
    attempt = do
      x <- gauss
      let v = (1 + c * x) ^ (3 :: Int)
      if v <= 0
        then attempt -- the cube transform is only valid for v > 0
        else do
          u <- uniform
          let x2 = x * x
          -- Fast squeeze test first; otherwise the exact acceptance test.
          if u < 1 - 0.0331 * x2 * x2 || log u < 0.5 * x2 + d * (1 - v + log v)
            then pure (d * v)
            else attempt

-- | Beta distribution with parameters @a@ and @b@, built from two Gamma draws
-- as @x / (x + y)@ with @x ~ Gamma(a)@ and @y ~ Gamma(b)@.
beta :: Double -> Double -> Prob Double
beta a b =
  gamma a >>= \ga ->
    gamma b >>= \gb ->
      pure (ga / (ga + gb))

-- | Beta-prime distribution: the odds @p / (1 - p)@ of a @'beta' a b@ draw.
beta' :: Double -> Double -> Prob Double
beta' a b = beta a b >>= \q -> pure (odds q)
  where
    odds t = t / (1 - t)

-- | Bernoulli trial: 'True' exactly when a 'uniform' draw falls below @p@.
bern :: Double -> Prob Bool
bern p = fmap (\u -> u < p) uniform

-- | Binomial distribution: the number of successes among the first @n@ of an
-- infinite stream of independent @'bern' p@ trials.
binom :: Int -> Double -> Prob Int
binom n p = countSuccesses <$> iid (bern p)
  where
    countSuccesses trials = length [() | True <- take n trials]

-- | Exponential distribution with rate @lambda@, by inverse-transform
-- sampling: @-(log u) / lambda@ for a 'uniform' draw @u@.
exponential :: Double -> Prob Double
exponential lambda = fmap (\u -> negate (log u / lambda)) uniform

-- | Geometric distribution: the number of failed @'bern' p@ trials before the
-- first success, so the support is @0, 1, 2, ...@.
--
-- Never terminates if no trial can succeed (presumably @p <= 0@, assuming
-- 'uniform' draws are non-negative).
geom :: Double -> Prob Int
geom p = length . takeWhile not <$> iid (bern p)

-- | Draw a 'Double' between @lower@ and @upper@ using the gamma-section idea
-- from doi:10.1145/3503512.
--
-- Values are taken from an evenly spaced grid with step @g@, the larger of the
-- gaps between adjacent doubles at the two endpoints, so every grid point is
-- exactly representable. The grid index is drawn from @[1, hi - 1]@, so this
-- branch never returns either endpoint.
--
-- When there are fewer than two grid steps, or more indices than the 52
-- random bits of 'boundedWord52' can cover, it falls back to an affine map of
-- 'uniform'.
--
-- NOTE(review): assumes @lower <= upper@ and both bounds finite; not checked.
bounded :: Double -> Double -> Prob Double
bounded lower upper
  | hi >= 2 && hi - 2 <= 0xfffffffffffff = do
      (k1, k2) <- splitIndex <$> boundedWord52 1 (hi - 2)
      -- The index is k = 4 * k1 + k2. Dividing the bound by 4 first keeps the
      -- intermediate finite when a bound is near the largest Double.
      pure $
        if abs lower <= abs upper
          then 4 * (upper / 4 - k1 * g) - k2 * g
          else 4 * (lower / 4 + k1 * g) + k2 * g
  | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x)
  where
    hi = ceilint lower upper g
    g = max (nextUp lower - lower) (upper - nextDown upper)

    splitIndex :: Word64 -> (Double, Double)
    splitIndex k = (fromIntegral (k `shiftR` 2), fromIntegral (k .&. 0x3))

    -- IEEE-754 successor of a finite Double. Incrementing the raw bits moves
    -- away from zero, so negative inputs need a decrement instead. Zero of
    -- either sign steps to the smallest positive subnormal.
    nextUp :: Double -> Double
    nextUp x
      | x == 0 = castWord64ToDouble 1
      | x > 0 = castWord64ToDouble (castDoubleToWord64 x + 1)
      | otherwise = castWord64ToDouble (castDoubleToWord64 x - 1)

    -- IEEE-754 predecessor of a finite Double.
    nextDown :: Double -> Double
    nextDown x = negate (nextUp (negate x))

    -- Number of grid steps from a to b: ceiling of s = b/step - a/step.
    -- eps estimates the rounding error of s, so that an exactly integral s
    -- whose true value lies just above it is bumped by one.
    ceilint :: Double -> Double -> Double -> Word64
    ceilint a b step =
      ceiling s + if s == fromIntegral (ceiling s :: Integer) && eps > 0 then 1 else 0
      where
        s = b / step - a / step
        eps =
          if abs a <= abs b
            then negate a / step - (s - b / step)
            else b / step - (s + a / step)

-- | 52 random bits: the fraction (mantissa) field of a 'uniform' draw.
--
-- The bits are reinterpreted with 'castDoubleToWord64', which is the supported
-- way to read a Double's bit pattern; 'unsafeCoerce' between boxed 'Double'
-- and 'Word64' is not.
--
-- NOTE(review): these bits are only uniform if 'uniform' fills every mantissa
-- bit at every exponent. A generator of the form @k / 2^53@ leaves the low
-- bits zero for small draws. Confirm against PPL.Internal.
unboundedWord52 :: Prob Word64
unboundedWord52 = (.&. 0xfffffffffffff) . castDoubleToWord64 <$> uniform

-- | Uniform 'Word64' in @[l, l + r]@ by rejection sampling. It takes 52 random
-- bits from 'unboundedWord52', reduces them modulo the smallest power of two
-- above @r@, and retries while the result exceeds @r@.
--
-- Only 52 random bits are available, so @r@ must be below @2^52@ for the result
-- to be uniform. The callers in this module check that bound first.
boundedWord52 :: Word64 -> Word64 -> Prob Word64
boundedWord52 l r = do
  x <- (`mod` nextPow2 r) <$> unboundedWord52
  if x <= r then pure (x + l) else boundedWord52 l r
  where
    -- Smallest power of two strictly greater than n (1 when n == 0).
    nextPow2 :: Word64 -> Word64
    nextPow2 n = 1 `shiftL` (64 - countLeadingZeros n)

-- | Draw an integer uniformly from the closed interval @[lower, upper]@.
--
-- The width @upper - lower@ is computed in 'Integer'. It therefore cannot
-- overflow, and the 2^52 threshold is not truncated in narrow types such as
-- 'Int32'.
--
-- Widths below 2^52 are sampled exactly through 'boundedWord52'. The offset is
-- added back in the target type, which keeps a negative @lower@ correct even
-- for 'Integer'. Wider ranges fall back to rounding a draw from 'bounded'.
bounded' :: Integral a => a -> a -> Prob a
bounded' lower upper
  | r < 0 = error "PPL.Distr.bounded': lower bound exceeds upper bound"
  | r <= 0xfffffffffffff =
      (\k -> lower + fromIntegral k) <$> boundedWord52 0 (fromInteger r)
  | otherwise =
      -- NOTE(review): 'round' gives each endpoint only half the weight of an
      -- interior integer; negligible at widths of 2^52 and above.
      round <$> bounded (fromIntegral lower) (fromIntegral upper)
  where
    r = toInteger upper - toInteger lower

-- | Categorical distribution over the indices of a weight list. It returns the
-- first index whose running total of weights exceeds a single 'uniform' draw.
-- The weights are expected to sum to 1. The list is scanned lazily, so it may
-- be infinite, e.g. the weights from 'dirichletProcess'.
--
-- If rounding leaves the draw at or above the total of a finite list (ten
-- weights of 0.1 sum to 0.9999999999999999), the last index with a positive
-- weight is returned instead of the out-of-range @length xs@. If no weight is
-- positive, @length xs@ is still returned.
cat :: [Double] -> Prob Int
cat xs = pick <$> uniform
  where
    pick r = go 0 Nothing 0 xs
      where
        -- Arguments: the current index, the last index seen with a positive
        -- weight, and the total weight of the entries before the current one.
        go :: Int -> Maybe Int -> Double -> [Double] -> Int
        go i lastPositive _ [] = maybe i id lastPositive
        go i lastPositive acc (w : ws)
          | acc' > r = i
          | otherwise =
              -- Force the counters so long scans do not build thunk chains.
              i `seq` lastPositive' `seq` go (i + 1) lastPositive' acc' ws
          where
            acc' = acc + w
            lastPositive' = if w > 0 then Just i else lastPositive

-- | Infinite lazy list of stick-breaking weights. Each step draws
-- @frac ~ 'beta' 1 p@ and breaks off @frac@ of the stick that remains after
-- the previous pieces; the stick starts with length 1.
dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = sticks 1
  where
    sticks remaining =
      beta 1 p >>= \frac ->
        let piece = frac * remaining
         in (piece :) <$> sticks (remaining - piece)