aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Distr.hs
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2022-12-16 10:57:55 +1100
committerJustin Bedo <cu@cua0.org>2022-12-31 17:28:29 +1100
commite318b40cfc1a3079375a015a651b1b44f6279ad0 (patch)
tree40a72f1b807fd77a1dd53ebf3f0cc284e65f30ad /src/PPL/Distr.hs
init
Diffstat (limited to 'src/PPL/Distr.hs')
-rw-r--r--src/PPL/Distr.hs100
1 files changed, 100 insertions, 0 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
new file mode 100644
index 0000000..797379b
--- /dev/null
+++ b/src/PPL/Distr.hs
@@ -0,0 +1,100 @@
+module PPL.Distr where
+
+import PPL.Internal
+import qualified PPL.Internal as I
+
+-- Acklam's approximation
+-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
+{-# INLINE probit #-}
+probit :: Double -> Double
+probit p
+ | p < lower =
+ let q = sqrt (-2 * log p)
+ in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
+ / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
+ | p < 1 - lower =
+ let q = p - 0.5
+ r = q * q
+ in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
+ / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
+ | otherwise = -probit (1 - p)
+ where
+ a1 = -3.969683028665376e+01
+ a2 = 2.209460984245205e+02
+ a3 = -2.759285104469687e+02
+ a4 = 1.383577518672690e+02
+ a5 = -3.066479806614716e+01
+ a6 = 2.506628277459239e+00
+
+ b1 = -5.447609879822406e+01
+ b2 = 1.615858368580409e+02
+ b3 = -1.556989798598866e+02
+ b4 = 6.680131188771972e+01
+ b5 = -1.328068155288572e+01
+
+ c1 = -7.784894002430293e-03
+ c2 = -3.223964580411365e-01
+ c3 = -2.400758277161838e+00
+ c4 = -2.549732539343734e+00
+ c5 = 4.374664141464968e+00
+ c6 = 2.938163982698783e+00
+
+ d1 = 7.784695709041462e-03
+ d2 = 3.224671290700398e-01
+ d3 = 2.445134137142996e+00
+ d4 = 3.754408661907416e+00
+
+ lower = 0.02425
+
+iid :: Prob a -> Prob [a]
+iid = sequence . repeat
+
+gauss = probit <$> uniform
+
+norm m s = (+ m) . (* s) <$> gauss
+
+-- Marsaglia's fast gamma rejection sampling
+gamma a = do
+ x <- gauss
+ u <- uniform
+ if u < 1 - 0.03331 * x ** 4
+ then pure $ d * v x
+ else gamma a
+ where
+ d = a - 1 / 3
+ v x = (1 + x / sqrt (9 * d)) ** 3
+
+beta a b = do
+ x <- gamma a
+ y <- gamma b
+ pure $ x / (x + y)
+
+bern p = (< p) <$> uniform
+
+binom n = fmap (length . filter id . take n) . iid . bern
+
+exponential lambda = negate . (/ lambda) . log <$> uniform
+
+geom :: Double -> Prob Int
+geom p = first 0 <$> iid (bern p)
+ where
+ first n (True : _) = n
+ first n (_ : xs) = first (n + 1) xs
+
+bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform
+
+bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+
+cat :: [Double] -> Prob Int
+cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
+ where
+ search i [] _ = i
+ search i (x : xs) r
+ | x > r = i
+ | otherwise = search (i + 1) xs r
+
+dirichletProcess p = go 1
+ where
+ go rest = do
+ x <- beta 1 p
+ (x*rest:) <$> go (rest - x*rest)