-- | Samplers for common probability distributions, all built from the
-- 'uniform' primitive (and the 'Prob' monad) provided by "PPL.Internal".
--
-- NOTE(review): several samplers below assume 'uniform' draws from the unit
-- interval ([0, 1) or (0, 1)) — confirm against "PPL.Internal".
module PPL.Distr where

import PPL.Internal
import qualified PPL.Internal as I

-- | Inverse CDF of the standard normal distribution (the probit function),
-- using Peter Acklam's rational approximation:
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
--
-- The unit interval is split into a lower tail @(0, lower)@, a central
-- region @[lower, 1 - lower)@, and an upper tail handled by symmetry.
--
-- Edge cases: @probit 0 = -Infinity@, @probit 1 = Infinity@, and any input
-- outside @[0, 1]@ (including NaN) yields NaN.
{-# INLINE probit #-}
probit :: Double -> Double
probit p
  -- Written as a negated conjunction so NaN is caught here too; otherwise NaN
  -- would fall through to the symmetry case and recurse forever.
  | not (p >= 0 && p <= 1) = 0 / 0
  | p == 0 = -1 / 0
  | p == 1 = 1 / 0
  | p < lower =
      let q = sqrt (-2 * log p)
       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
  | p < 1 - lower =
      let q = p - 0.5
          r = q * q
       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
  -- Upper tail by symmetry. Here 0 < 1 - p <= lower, so this recurses once.
  | otherwise = negate (probit (1 - p))
  where
    -- Central-region coefficients.
    a1 = -3.969683028665376e+01
    a2 = 2.209460984245205e+02
    a3 = -2.759285104469687e+02
    a4 = 1.383577518672690e+02
    a5 = -3.066479806614716e+01
    a6 = 2.506628277459239e+00
    b1 = -5.447609879822406e+01
    b2 = 1.615858368580409e+02
    b3 = -1.556989798598866e+02
    b4 = 6.680131188771972e+01
    b5 = -1.328068155288572e+01
    -- Tail-region coefficients.
    c1 = -7.784894002430293e-03
    c2 = -3.223964580411365e-01
    c3 = -2.400758277161838e+00
    c4 = -2.549732539343734e+00
    c5 = 4.374664141464968e+00
    c6 = 2.938163982698783e+00
    d1 = 7.784695709041462e-03
    d2 = 3.224671290700398e-01
    d3 = 2.445134137142996e+00
    d4 = 3.754408661907416e+00
    -- Breakpoint between the tail and central approximations.
    lower = 0.02425

-- | An infinite stream of independent draws from the given distribution.
--
-- Consumers such as 'geom' and 'binom' only inspect a finite prefix.
-- NOTE(review): this presumes the 'Prob' monad from "PPL.Internal" is lazy
-- enough for 'sequence' over an infinite list to be productive — confirm.
iid :: Prob a -> Prob [a]
iid = sequence . repeat

-- | Standard normal variate, by inverse-transform sampling through 'probit'.
gauss :: Prob Double
gauss = probit <$> uniform

-- | Normal variate with mean @m@ and standard deviation @s@.
norm :: Double -> Double -> Prob Double
norm m s = (+ m) . (* s) <$> gauss

-- | Gamma variate with shape @a@ and unit scale, using Marsaglia & Tsang's
-- rejection method (2000). Requires @a > 0@.
--
-- For @a >= 1@ the method is applied directly. For @0 < a < 1@ it uses the
-- standard boost: if G ~ Gamma(a + 1) and U ~ Uniform(0, 1), then
-- G * U^(1/a) ~ Gamma(a).
gamma :: Double -> Prob Double
gamma a
  | a < 1 = do
      g <- gamma (a + 1)
      u <- uniform
      pure $ g * u ** (1 / a)
  | otherwise = attempt
  where
    d = a - 1 / 3
    c = 1 / sqrt (9 * d)
    -- One proposal; retries until accepted.
    attempt = do
      x <- gauss
      let t = 1 + c * x
      -- v = t^3 must be positive, so such proposals are rejected outright.
      if t <= 0
        then attempt
        else do
          u <- uniform
          let v = t ^ (3 :: Int)
              x2 = x * x
          -- The cheap squeeze test first, then the exact log acceptance test.
          -- The exact test is required: the squeeze alone biases the result.
          if u < 1 - 0.0331 * x2 * x2
            || log u < 0.5 * x2 + d * (1 - v + log v)
            then pure (d * v)
            else attempt

-- | Beta(a, b) variate, as X / (X + Y) with X ~ Gamma(a), Y ~ Gamma(b).
beta :: Double -> Double -> Prob Double
beta a b = do
  x <- gamma a
  y <- gamma b
  pure $ x / (x + y)

-- | Bernoulli trial that is 'True' with probability @p@.
bern :: Double -> Prob Bool
bern p = (< p) <$> uniform

-- | Binomial variate: the number of successes in @n@ independent
-- Bernoulli(@p@) trials.
binom :: Int -> Double -> Prob Int
binom n p = length . filter id . take n <$> iid (bern p)

-- | Exponential variate with rate @lambda@, by inversion: @-log U / lambda@.
exponential :: Double -> Prob Double
exponential lambda = negate . (/ lambda) . log <$> uniform

-- | Geometric variate: the number of failures before the first success in
-- independent Bernoulli(@p@) trials. Diverges when @p <= 0@, because no
-- success ever occurs.
geom :: Double -> Prob Int
geom p = length . takeWhile not <$> iid (bern p)

-- | Continuous uniform variate between @lower@ and @upper@.
bounded :: Double -> Double -> Prob Double
bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform

-- | Discrete uniform variate over the integers between @lower@ and @upper@,
-- both inclusive. Each integer is equally likely; the arguments may be given
-- in either order.
--
-- This previously rounded a continuous draw, which gave the two end values
-- only half the probability of the interior ones.
bounded' :: (Integral a, Integral b, Integral c) => a -> b -> Prob c
bounded' lower upper = pick <$> uniform
  where
    lo = min (toInteger lower) (toInteger upper)
    hi = max (toInteger lower) (toInteger upper)
    count = hi - lo + 1
    -- The clamp guards against u * count rounding up to count (or u == 1).
    pick u = fromInteger (lo + min (count - 1) (floor (u * fromInteger count)))

-- | Categorical variate: returns index @i@ with probability proportional to
-- @weights !! i@. The weights need not sum to 1.
--
-- For an empty weight list this returns 0, as before. NOTE(review): callers
-- should avoid that case.
cat :: [Double] -> Prob Int
cat weights = pick . (* total) <$> uniform
  where
    total = sum weights
    cumulative = scanl1 (+) weights
    -- Index of the first cumulative weight exceeding r. The last category
    -- absorbs any floating-point overshoot, so the index is never out of
    -- range.
    pick r = go 0 cumulative
      where
        go i [] = i
        go i [_] = i
        go i (w : ws)
          | w > r = i
          | otherwise = go (i + 1) ws

-- | Stick-breaking weights of a Dirichlet process with concentration @p@.
--
-- The result is an infinite list whose k-th element is
-- @beta_k * product [1 - beta_j | j < k]@, with each @beta_j ~ Beta(1, p)@.
-- Consumers must take only a finite prefix.
dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = go 1
  where
    -- rest is the length of stick not yet broken off.
    go rest = do
      x <- beta 1 p
      (x * rest :) <$> go (rest - x * rest)