diff options
Diffstat (limited to 'src/PPL/Distr.hs')
| -rw-r--r-- | src/PPL/Distr.hs | 40 | 
1 files changed, 28 insertions, 12 deletions
| diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index 797379b..e1fd5aa 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,7 +1,6 @@  module PPL.Distr where  import PPL.Internal -import qualified PPL.Internal as I  -- Acklam's approximation  -- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ @@ -9,14 +8,15 @@ import qualified PPL.Internal as I  probit :: Double -> Double  probit p    | p < lower = -    let q = sqrt (-2 * log p) -     in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) -          / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) +      let q = sqrt (-2 * log p) +       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) +            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)    | p < 1 - lower = -    let q = p - 0.5 -        r = q * q -     in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q -          / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) +      let q = p - 0.5 +          r = q * q +       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) +            * q +            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)    | otherwise = -probit (1 - p)    where      a1 = -3.969683028665376e+01 @@ -49,11 +49,14 @@ probit p  iid :: Prob a -> Prob [a]  iid = sequence . repeat +gauss :: Prob Double  gauss = probit <$> uniform +norm :: Double -> Double -> Prob Double  norm m s = (+ m) . (* s) <$> gauss  -- Marsaglia's fast gamma rejection sampling +gamma :: Double -> Prob Double  gamma a = do    x <- gauss    u <- uniform @@ -64,15 +67,24 @@ gamma a = do      d = a - 1 / 3      v x = (1 + x / sqrt (9 * d)) ** 3 +beta :: Double -> Double -> Prob Double  beta a b = do    x <- gamma a    y <- gamma b    pure $ x / (x + y) +beta' :: Double -> Double -> Prob Double +beta' a b = do +  p <- beta a b +  pure $ p / (1 - p) + +bern :: Double -> Prob Bool  bern p = (< p) <$> uniform +binom :: Int -> Double -> Prob Int  binom n = fmap (length . filter id . take n) . iid . bern +exponential :: Double -> Prob Double  exponential lambda = negate . (/ lambda) . log <$> uniform  geom :: Double -> Prob Int @@ -80,9 +92,12 @@ geom p = first 0 <$> iid (bern p)    where      first n (True : _) = n      first n (_ : xs) = first (n + 1) xs +    first _ _ = undefined +bounded :: Double -> Double -> Prob Double  bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform +bounded' :: Integral a => a -> a -> Prob a  bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)  cat :: [Double] -> Prob Int @@ -93,8 +108,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform        | x > r = i        | otherwise = search (i + 1) xs r +dirichletProcess :: Double -> Prob [Double]  dirichletProcess p = go 1 - where -   go rest = do -     x <- beta 1 p -     (x*rest:) <$> go (rest - x*rest) +  where +    go rest = do +      x <- beta 1 p +      (x * rest :) <$> go (rest - x * rest) | 
