-- Origin: src/PPL/Distr.hs (blob 093e1021063bd1cc86c52692da973a9703ddabcf)
module PPL.Distr where

import PPL.Internal
import qualified PPL.Internal as I

-- Inverse of the standard normal CDF (the probit function), computed with
-- Acklam's rational approximation:
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
--
-- Arguments strictly inside (0, 1) are handled by one of two rational
-- polynomials: a tail formula below `lower`, a central formula up to
-- `1 - lower`, and the upper tail by the symmetry probit p = -probit (1 - p).
-- The endpoints map to the infinities (probit 0 = -Infinity,
-- probit 1 = +Infinity); anything outside [0, 1], including NaN, yields NaN.
{-# INLINE probit #-}
probit :: Double -> Double
probit p
  -- Written as a negated conjunction so that NaN (for which every comparison
  -- is False) is rejected here instead of reaching the recursive branch,
  -- where it would recurse forever.
  | not (p >= 0 && p <= 1) = nan
  -- The tail formula would evaluate to -Infinity / Infinity = NaN at 0.
  | p == 0 = -infinity
  | p == 1 = infinity
  | p < lower = lowerTail p
  | p < 1 - lower = central p
  -- Upper tail: 1 - p lies in (0, lower] up to rounding, so this recursion
  -- ends after a single step.
  | otherwise = -probit (1 - p)
  where
    nan = 0 / 0
    infinity = 1 / 0

    -- Rational approximation for the lower tail, in q = sqrt (-2 ln t).
    lowerTail t =
      let q = sqrt (-2 * log t)
       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)

    -- Rational approximation for the central region, in q = t - 1/2.
    central t =
      let q = t - 0.5
          r = q * q
       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)

    -- Numerator coefficients, central region.
    a1 = -3.969683028665376e+01
    a2 = 2.209460984245205e+02
    a3 = -2.759285104469687e+02
    a4 = 1.383577518672690e+02
    a5 = -3.066479806614716e+01
    a6 = 2.506628277459239e+00

    -- Denominator coefficients, central region.
    b1 = -5.447609879822406e+01
    b2 = 1.615858368580409e+02
    b3 = -1.556989798598866e+02
    b4 = 6.680131188771972e+01
    b5 = -1.328068155288572e+01

    -- Numerator coefficients, tail region.
    c1 = -7.784894002430293e-03
    c2 = -3.223964580411365e-01
    c3 = -2.400758277161838e+00
    c4 = -2.549732539343734e+00
    c5 = 4.374664141464968e+00
    c6 = 2.938163982698783e+00

    -- Denominator coefficients, tail region.
    d1 = 7.784695709041462e-03
    d2 = 3.224671290700398e-01
    d3 = 2.445134137142996e+00
    d4 = 3.754408661907416e+00

    -- Break-point between the tail and the central approximation.
    lower = 0.02425

-- Run the same sampler over and over and collect the results in an
-- unbounded list.
-- NOTE(review): this is only productive if Prob's bind is lazy enough to
-- sequence an infinite list — confirm against PPL.Internal.
iid :: Prob a -> Prob [a]
iid sampler = sequence (repeat sampler)

-- Push a `uniform` draw through `probit` (inverse-CDF sampling).
-- NOTE(review): presumably `uniform` is uniform on the unit interval, which
-- makes this a standard normal — confirm against PPL.Internal.
gauss = fmap probit uniform

-- A `gauss` draw scaled by `s` and then shifted by `m`.
norm m s = fmap (\z -> z * s + m) gauss

-- Gamma sampler with shape `a` and unit scale, using Marsaglia and Tsang's
-- rejection method.
--
-- * A shape that is not strictly positive (including NaN) yields NaN.
-- * For 0 < a < 1 the method itself does not apply, so a Gamma(a + 1) draw
--   is scaled by u ** (1 / a) for a fresh uniform u, which is distributed as
--   Gamma(a).
-- * For a >= 1 a candidate d * v with v = (1 + x / sqrt (9 * d)) ^ 3 is
--   proposed from a normal draw x and accepted either by the cheap
--   polynomial squeeze or by the exact logarithmic test.
gamma a
  | not (a > 0) = pure (0 / 0)
  | a < 1 = do
      g <- gamma (a + 1)
      u <- uniform
      pure $ g * u ** (1 / a)
  | otherwise = do
      x <- gauss
      let t = 1 + x / sqrt (9 * d)
          v = t * t * t
          x2 = x * x
      if t <= 0
        -- A non-positive v would produce a negative "gamma" value; redraw.
        then gamma a
        else do
          u <- uniform
          if u < 1 - 0.0331 * x2 * x2 -- squeeze: avoids the logarithms
            || log u < 0.5 * x2 + d * (1 - v + log v) -- exact acceptance
            then pure $ d * v
            else gamma a
  where
    d = a - 1 / 3

-- Draws x from `gamma a` and y from `gamma b` and returns x's share of
-- their sum.
beta a b = share <$> gamma a <*> gamma b
  where
    share x y = x / (x + y)

-- Odds p / (1 - p) of a `beta a b` draw p.
beta' a b = toOdds <$> beta a b
  where
    toOdds p = p / (1 - p)

-- True exactly when a `uniform` draw is strictly below `p`.
bern p = fmap (\u -> u < p) uniform

-- Number of True outcomes among the first `n` elements of an endless
-- stream of `bern p` draws.
binom n p = successes <$> iid (bern p)
  where
    successes = length . filter id . take n

-- Negated logarithm of a `uniform` draw, divided by `lambda`.
exponential lambda = fmap (\u -> negate (log u / lambda)) uniform

-- Number of False outcomes that precede the first True in an endless
-- stream of `bern p` draws, so the smallest possible result is 0.
-- The count never finishes if no draw ever comes up True.
geom :: Double -> Prob Int
geom p = length . takeWhile not <$> iid (bern p)

-- Affine image of a `uniform` draw u: lower + u * (upper - lower).
bounded lower upper = fmap (\u -> u * (upper - lower) + lower) uniform

-- Integer draw between `lower` and `upper`, both inclusive; the bounds may
-- be given in either order.
--
-- A continuous draw from [lo, hi + 1) is floored so that every integer owns
-- an interval of width one. Rounding a draw from [lo, hi] instead would give
-- the two end values only half the weight of the interior ones.
-- The final `min hi` keeps the result in range even if the underlying draw
-- lands exactly on the upper end of its interval.
-- NOTE(review): equal weights assume `uniform` is uniform on the unit
-- interval — confirm against PPL.Internal.
bounded' lower upper =
  fromInteger . min hi . floor <$> bounded (fromInteger lo) (fromInteger hi + 1)
  where
    lo = min (toInteger lower) (toInteger upper)
    hi = max (toInteger lower) (toInteger upper)

-- Pick an index by comparing a `uniform` draw against the running totals of
-- the weights: the result is the number of leading running totals that are
-- not greater than the draw. When no running total exceeds the draw (which
-- includes the empty list) the result is the length of the list.
cat :: [Double] -> Prob Int
cat weights = pick <$> uniform
  where
    runningTotals = drop 1 (scanl (+) 0 weights)
    pick r = length (takeWhile (\total -> not (total > r)) runningTotals)

-- Stick-breaking construction: starting from a stick of length 1, each step
-- draws a fraction from `beta 1 p`, emits that fraction of what is left and
-- carries on with the remainder. The resulting list never ends.
dirichletProcess p = breakStick 1
  where
    breakStick remaining = do
      fraction <- beta 1 p
      let piece = fraction * remaining
      (piece :) <$> breakStick (remaining - piece)