{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

-- | Samplers for common probability distributions, built on the 'Prob'
-- monad and the raw 64-bit word source ('draw') from "PPL.Internal".
module PPL.Distr where

import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.), (.|.))
import Data.Functor ((<&>))
import Data.List (findIndex)
import Data.Maybe (fromMaybe)
import Data.Word (Word64)
import Debug.Trace
import GHC.Float (castDoubleToWord64, castWord64ToDouble)
import PPL.Internal
import Unsafe.Coerce

-- | Approximate inverse CDF of the standard normal distribution.
--
-- Acklam's rational approximation:
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
--
-- A central rational approximation is used for @lower <= p < 1 - lower@. A
-- tail approximation in @q = sqrt (-2 * log p)@ is used below @lower@. The
-- upper tail is handled through the symmetry @probit p = -probit (1 - p)@.
-- A NaN input is returned unchanged. Without that guard it would fall through
-- to the symmetry branch forever.
{-# INLINE probit #-}
probit :: Double -> Double
probit p
  | isNaN p = p
  | p < lower =
      let q = sqrt (-2 * log p)
       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
  | p < 1 - lower =
      let q = p - 0.5
          r = q * q
       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
  | otherwise = -probit (1 - p)
  where
    -- Central-region numerator / denominator coefficients.
    a1 = -3.969683028665376e+01
    a2 = 2.209460984245205e+02
    a3 = -2.759285104469687e+02
    a4 = 1.383577518672690e+02
    a5 = -3.066479806614716e+01
    a6 = 2.506628277459239e+00
    b1 = -5.447609879822406e+01
    b2 = 1.615858368580409e+02
    b3 = -1.556989798598866e+02
    b4 = 6.680131188771972e+01
    b5 = -1.328068155288572e+01
    -- Tail-region numerator / denominator coefficients.
    c1 = -7.784894002430293e-03
    c2 = -3.223964580411365e-01
    c3 = -2.400758277161838e+00
    c4 = -2.549732539343734e+00
    c5 = 4.374664141464968e+00
    c6 = 2.938163982698783e+00
    d1 = 7.784695709041462e-03
    d2 = 3.224671290700398e-01
    d3 = 2.445134137142996e+00
    d4 = 3.754408661907416e+00
    -- Boundary between the tail and central approximations.
    lower = 0.02425

-- | An infinite list of independent draws from the given sampler.
-- NOTE(review): relies on 'Prob' being lazy enough for 'sequence' over an
-- infinite list to be productive; confirm against "PPL.Internal".
iid :: Prob a -> Prob [a]
iid = sequence . repeat

-- | Standard normal sample via the inverse-CDF transform of 'uniform'.
gauss :: Prob Double
gauss = probit <$> uniform

-- | Normal sample with mean @m@ and standard deviation @s@ (@m + s * z@).
norm :: Double -> Double -> Prob Double
norm m s = (+ m) . (* s) <$> gauss

-- | Draw from the Gamma distribution with shape @a@ and unit scale.
--
-- For @a >= 1@ this is Marsaglia & Tsang's rejection method. Let
-- @d = a - 1/3@, @c = 1 / sqrt (9 * d)@, @x ~ N(0,1)@ and @v = (1 + c*x)^3@.
-- The sample @d * v@ is accepted when @v > 0@ and either of two tests holds:
--
-- * the cheap squeeze @u < 1 - 0.0331 * x^4@, or
-- * the exact test @log u < x^2 / 2 + d * (1 - v + log v)@.
--
-- The squeeze only covers part of the acceptance region, so the exact test
-- is required for unbiased samples.
--
-- For @0 < a < 1@ the shape is boosted. If @y ~ Gamma(a + 1)@ and
-- @u ~ Uniform(0,1)@, then @y * u ** (1 / a) ~ Gamma(a)@.
--
-- A non-positive or NaN shape is invalid and yields NaN.
gamma :: Double -> Prob Double
gamma a
  | isNaN a || a <= 0 = pure (0 / 0)
  | a < 1 = do
      y <- gamma (a + 1)
      u <- uniform
      pure $ y * u ** (1 / a)
  | otherwise = marsagliaTsang
  where
    d = a - 1 / 3
    c = 1 / sqrt (9 * d)
    marsagliaTsang :: Prob Double
    marsagliaTsang = do
      x <- gauss
      u <- uniform
      let t = 1 + c * x
          v = t * t * t
          x2 = x * x
          squeeze = u < 1 - 0.0331 * x2 * x2
          -- Only forced when v > 0 and the squeeze fails, so 'log v' is safe.
          exact = log u < 0.5 * x2 + d * (1 - v + log v)
      if v > 0 && (squeeze || exact)
        then pure (d * v)
        else marsagliaTsang

-- | Beta(@a@, @b@) sample as @x / (x + y)@ with @x ~ Gamma(a)@, @y ~ Gamma(b)@.
beta :: Double -> Double -> Prob Double
beta a b = do
  x <- gamma a
  y <- gamma b
  pure $ x / (x + y)

-- | Beta-prime(@a@, @b@) sample as @p / (1 - p)@ with @p ~ Beta(a, b)@.
beta' :: Double -> Double -> Prob Double
beta' a b = do
  p <- beta a b
  pure $ p / (1 - p)

-- | Bernoulli trial: 'True' when a uniform draw falls below @p@.
bern :: Double -> Prob Bool
bern p = (< p) <$> uniform

-- | Binomial sample: number of successes among the first @n@ Bernoulli(@p@)
-- trials of an independent stream.
binom :: Int -> Double -> Prob Int
binom n = fmap (length . filter id . take n) . iid . bern

-- | Exponential sample with rate @lambda@ by inversion: @-log u / lambda@.
exponential :: Double -> Prob Double
exponential lambda = negate . (/ lambda) . log <$> uniform

-- | Geometric sample: number of failures before the first success in an
-- independent stream of Bernoulli(@p@) trials (support 0, 1, 2, ...).
geom :: Double -> Prob Int
geom p = length . takeWhile not <$> iid (bern p)

-- | Uniform 'Double' built directly from random bits.
--
-- The high 52 bits of a random word become the mantissa. The exponent is
-- geometric: each trailing zero bit halves the binade. The low 12 bits of
-- the first word are used for this count. If they are all zero, the
-- trailing-zero count of a second word is used instead.
--
-- Results lie in @[2^-77, 1 - 2^-53]@, so 0 and 1 are never returned.
uniform :: Prob Double
uniform = do
  z <- unbounded
  let r = countTrailingZeros z - 12
  e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r
  -- 1010 - e lies in [946, 1022], so it fits the 11-bit exponent field and
  -- never overlaps the 52 mantissa bits.
  let exponentField = fromIntegral (1010 - e) :: Word64
  pure . castWord64ToDouble $ (z `shiftR` 12) .|. (exponentField `shiftL` 52)

-- | Draw a 'Double' from between @lower@ and @upper@ with the γ-section
-- algorithm from Goualard, "Drawing random floating-point numbers from an
-- interval" (doi:10.1145/3503512).
--
-- The interval is cut into steps of width @g@ (the larger ulp at the two
-- ends). A step count @k = 4 * k1 + k2@ is taken from the endpoint with the
-- larger magnitude. The @4 * (x / 4 ± k1 * g) ± k2 * g@ form presumably
-- follows the paper's guard against intermediate overflow.
--
-- NOTE(review): @k@ is drawn from @[1, hi - 2]@. Confirm this against the
-- paper's interval variant. When @hi < 3@ the 'Word64' bound @hi - 2@
-- underflows.
bounded :: Double -> Double -> Prob Double
bounded lower upper = do
  (k1, k2) <- split <$> bounded' 1 (hi - 2)
  pure $
    if abs lower <= abs upper
      then 4 * (upper / 4 - k1 * g) - k2 * g
      else 4 * (lower / 4 + k1 * g) + k2 * g
  where
    hi = ceilint lower upper g
    g = max (up lower - lower) (upper - down upper)
    -- Split k into (k `div` 4, k `mod` 4) as Doubles.
    split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
    -- IEEE 754 nextUp for finite values. Bits are sign-magnitude, so the raw
    -- pattern moves away from zero for positive x and toward zero for
    -- negative x. Both zeros step to the smallest positive subnormal.
    up x
      | x == 0 = castWord64ToDouble 1
      | x > 0 = castWord64ToDouble (castDoubleToWord64 x + 1)
      | otherwise = castWord64ToDouble (castDoubleToWord64 x - 1)
    -- IEEE 754 nextDown, by symmetry with nextUp.
    down x = negate (up (negate x))

-- | Step count for 'bounded': @ceiling (b / g - a / g)@, bumped by one when
-- the rounded quotient @s@ landed exactly on an integer but the correction
-- term @eps@ (an estimate of the rounding error in @s@) is positive.
ceilint :: Double -> Double -> Double -> Word64
ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0
  where
    s = b / g - a / g
    eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)

-- | One raw 64-bit random word, wrapping 'draw' from "PPL.Internal".
unbounded :: Prob Word64
unbounded = Prob draw

-- | Uniform integer in the closed range @[lower, upper]@ by rejection
-- sampling. A random word is masked down to the smallest power-of-two range
-- covering @upper - lower@ and redrawn until it lands inside.
bounded' :: Integral a => a -> a -> Prob a
bounded' lower upper = do
  z <- (.&. mask) <$> unbounded
  if z <= r
    then pure $ fromIntegral $ fromIntegral lower + z
    else bounded' lower upper
  where
    r :: Word64 = fromIntegral $ upper - lower
    -- When the top bit of r is set, the shift by 64 yields 0 and the mask
    -- wraps to maxBound (all 64 bits). A `mod` by that power would divide
    -- by zero.
    mask :: Word64
    mask = (1 `shiftL` (64 - countLeadingZeros r)) - 1

-- | Categorical sample: the index of the first cumulative weight exceeding a
-- uniform draw scaled by the total weight.
--
-- Weights are normalised, so they need not sum to exactly 1. If rounding
-- leaves the scaled draw at or above every cumulative weight, the last index
-- is returned rather than an out-of-range one. An empty list yields 0.
cat :: [Double] -> Prob Int
cat ws = pick . (* total) <$> uniform
  where
    cumulative = scanl1 (+) ws
    total = sum ws
    pick r = fromMaybe (max 0 (length ws - 1)) (findIndex (> r) cumulative)

-- | Stick-breaking weights. Each step takes a Beta(1, @p@) fraction of the
-- stick that remains and emits that piece, producing an infinite list.
-- NOTE(review): relies on 'Prob' being lazy, as with 'iid'.
dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = go 1
  where
    go rest = do
      x <- beta 1 p
      (x * rest :) <$> go (rest - x * rest)