diff options
Diffstat (limited to 'src/PPL/Distr.hs')
| -rw-r--r-- | src/PPL/Distr.hs | 88 |
1 files changed, 74 insertions, 14 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index 797379b..d265f05 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,7 +1,15 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} + module PPL.Distr where +import Debug.Trace +import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.)) +import Data.Functor ((<&>)) +import Data.Word (Word64) import PPL.Internal -import qualified PPL.Internal as I +import Unsafe.Coerce -- Acklam's approximation -- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ @@ -9,14 +17,15 @@ import qualified PPL.Internal as I probit :: Double -> Double probit p | p < lower = - let q = sqrt (-2 * log p) - in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) - / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) + let q = sqrt (-2 * log p) + in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) + / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) | p < 1 - lower = - let q = p - 0.5 - r = q * q - in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q - / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) + let q = p - 0.5 + r = q * q + in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) + * q + / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) | otherwise = -probit (1 - p) where a1 = -3.969683028665376e+01 @@ -49,11 +58,14 @@ probit p iid :: Prob a -> Prob [a] iid = sequence . repeat +gauss :: Prob Double gauss = probit <$> uniform +norm :: Double -> Double -> Prob Double norm m s = (+ m) . (* s) <$> gauss -- Marsaglia's fast gamma rejection sampling +gamma :: Double -> Prob Double gamma a = do x <- gauss u <- uniform @@ -64,15 +76,24 @@ gamma a = do d = a - 1 / 3 v x = (1 + x / sqrt (9 * d)) ** 3 +beta :: Double -> Double -> Prob Double beta a b = do x <- gamma a y <- gamma b pure $ x / (x + y) +beta' :: Double -> Double -> Prob Double +beta' a b = do + p <- beta a b + pure $ p / (1 - p) + +bern :: Double -> Prob Bool bern p = (< p) <$> uniform +binom :: Int -> Double -> Prob Int binom n = fmap (length . filter id . take n) . iid . bern +exponential :: Double -> Prob Double exponential lambda = negate . (/ lambda) . log <$> uniform geom :: Double -> Prob Int @@ -80,10 +101,48 @@ geom p = first 0 <$> iid (bern p) where first n (True : _) = n first n (_ : xs) = first (n + 1) xs + first _ _ = undefined + +uniform :: Prob Double +uniform = do + z <- unbounded + let r = countTrailingZeros z - 12 + e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r + pure . unsafeCoerce $ z `shiftR` 12 - ((unsafeCoerce e - 1010) `shiftL` 52) + +-- Uses algorithm from 10.1145/3503512 +bounded :: Double -> Double -> Prob Double +bounded lower upper = do + (k1, k2) <- split <$> bounded' 1 (hi - 2) + pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g + where + hi = ceilint lower upper g + g = max (up lower - lower) (upper - down upper) + + split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3) + + up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1) -bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform + down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1) -bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper) + ceilint :: Double -> Double -> Double -> Word64 + ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0 + where + s = b / g - a / g + eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g) + +unbounded :: Prob Word64 +unbounded = Prob draw + +bounded' :: Integral a => a -> a -> Prob a +bounded' lower upper = do + z <- (`mod` nextPow2 r) <$> unbounded + if z <= r + then pure $ fromIntegral $ fromIntegral lower + z + else bounded' lower upper + where + r :: Word64 = fromIntegral $ upper - lower + nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x) cat :: [Double] -> Prob Int cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform @@ -93,8 +152,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform | x > r = i | otherwise = search (i + 1) xs r +dirichletProcess :: Double -> Prob [Double] dirichletProcess p = go 1 - where - go rest = do - x <- beta 1 p - (x*rest:) <$> go (rest - x*rest) + where + go rest = do + x <- beta 1 p + (x * rest :) <$> go (rest - x * rest) |
