diff options
author | Justin Bedo <cu@cua0.org> | 2022-12-16 10:57:55 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2022-12-31 17:28:29 +1100 |
commit | e318b40cfc1a3079375a015a651b1b44f6279ad0 (patch) | |
tree | 40a72f1b807fd77a1dd53ebf3f0cc284e65f30ad /src |
init
Diffstat (limited to 'src')
-rw-r--r-- | src/PPL.hs | 5 | ||||
-rw-r--r-- | src/PPL/Distr.hs | 100 | ||||
-rw-r--r-- | src/PPL/Internal.hs | 81 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 52 |
4 files changed, 238 insertions, 0 deletions
diff --git a/src/PPL.hs b/src/PPL.hs new file mode 100644 index 0000000..21bb87c --- /dev/null +++ b/src/PPL.hs @@ -0,0 +1,5 @@ +module PPL(module PPL.Internal, module PPL.Sampling, module PPL.Distr) where + +import PPL.Internal +import PPL.Sampling +import PPL.Distr diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs new file mode 100644 index 0000000..797379b --- /dev/null +++ b/src/PPL/Distr.hs @@ -0,0 +1,100 @@ +module PPL.Distr where + +import PPL.Internal +import qualified PPL.Internal as I + +-- Acklam's approximation +-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ +{-# INLINE probit #-} +probit :: Double -> Double +probit p + | p < lower = + let q = sqrt (-2 * log p) + in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) + / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) + | p < 1 - lower = + let q = p - 0.5 + r = q * q + in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q + / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) + | otherwise = -probit (1 - p) + where + a1 = -3.969683028665376e+01 + a2 = 2.209460984245205e+02 + a3 = -2.759285104469687e+02 + a4 = 1.383577518672690e+02 + a5 = -3.066479806614716e+01 + a6 = 2.506628277459239e+00 + + b1 = -5.447609879822406e+01 + b2 = 1.615858368580409e+02 + b3 = -1.556989798598866e+02 + b4 = 6.680131188771972e+01 + b5 = -1.328068155288572e+01 + + c1 = -7.784894002430293e-03 + c2 = -3.223964580411365e-01 + c3 = -2.400758277161838e+00 + c4 = -2.549732539343734e+00 + c5 = 4.374664141464968e+00 + c6 = 2.938163982698783e+00 + + d1 = 7.784695709041462e-03 + d2 = 3.224671290700398e-01 + d3 = 2.445134137142996e+00 + d4 = 3.754408661907416e+00 + + lower = 0.02425 + +iid :: Prob a -> Prob [a] +iid = sequence . repeat + +gauss = probit <$> uniform + +norm m s = (+ m) . (* s) <$> gauss + +-- Marsaglia's fast gamma rejection sampling +gamma a = do + x <- gauss + u <- uniform + if u < 1 - 0.03331 * x ** 4 + then pure $ d * v x + else gamma a + where + d = a - 1 / 3 + v x = (1 + x / sqrt (9 * d)) ** 3 + +beta a b = do + x <- gamma a + y <- gamma b + pure $ x / (x + y) + +bern p = (< p) <$> uniform + +binom n = fmap (length . filter id . take n) . iid . bern + +exponential lambda = negate . (/ lambda) . log <$> uniform + +geom :: Double -> Prob Int +geom p = first 0 <$> iid (bern p) + where + first n (True : _) = n + first n (_ : xs) = first (n + 1) xs + +bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform + +bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper) + +cat :: [Double] -> Prob Int +cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform + where + search i [] _ = i + search i (x : xs) r + | x > r = i + | otherwise = search (i + 1) xs r + +dirichletProcess p = go 1 + where + go rest = do + x <- beta 1 p + (x*rest:) <$> go (rest - x*rest) diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs new file mode 100644 index 0000000..a729d3d --- /dev/null +++ b/src/PPL/Internal.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample, +randomTree, samples, mutateTree) where + +import Control.Monad +import Control.Monad.Trans.Class +import Control.Monad.Trans.Writer +import Data.Monoid +import qualified Language.Haskell.TH.Syntax as TH +import Numeric.Log +import System.Random hiding (uniform, split) +import qualified System.Random as R +import Data.Bifunctor +import Control.Monad.IO.Class + +-- Reimplementation of the LazyPPL monads to avoid some dependencies + +data Tree = Tree Double [Tree] + +split :: Tree -> (Tree, Tree) +split (Tree r (t : ts)) = (t, Tree r ts) + +{-# INLINE randomTree #-} +randomTree :: RandomGen g => g -> Tree +randomTree g = let (a, g') = random g in Tree a (randomTrees g') + +{-# INLINE randomTrees #-} +randomTrees :: RandomGen g => g -> [Tree] +randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 + +{-# INLINE mutateTree #-} +mutateTree :: RandomGen g => Double -> g -> Tree -> Tree +mutateTree p g (Tree a ts) = + let (r, g1) = random g + (b, g2) = random g1 + in Tree (if r < p then b else a) (mutateTrees p g2 ts) + +{-# INLINE mutateTrees #-} +mutateTrees :: RandomGen g => Double -> g -> [Tree] -> [Tree] +mutateTrees p g (t:ts) = + let (g1, g2) = R.split g + in mutateTree p g1 t : mutateTrees p g2 ts + +newtype Prob a = Prob {runProb :: Tree -> a} + +instance Monad Prob where + Prob f >>= g = Prob $ \t -> + let (t1, t2) = split t + (Prob g') = g (f t1) + in g' t2 + +instance Functor Prob where fmap = liftM + +instance Applicative Prob where pure = Prob . const; (<*>) = ap + +uniform = Prob $ \(Tree r _) -> r + +newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) + deriving (Functor, Applicative, Monad) + +{-# INLINE score #-} +score :: Double -> Meas () +score = scoreLog . Exp . log . max eps + where + eps = $(TH.lift (until ((== 1) . (1 +)) (/ 2) (1 :: Double))) -- machine epsilon, force compile time eval + +{-# INLINE scoreLog #-} +scoreLog :: Log Double -> Meas () +scoreLog = Meas . tell . Product + +sample :: Prob a -> Meas a +sample = Meas . lift + +{-# INLINE samples #-} +samples :: forall a. Meas a -> Tree -> [(a, Log Double)] +samples (Meas m) t = map (second getProduct) $ runProb f t + where + f = runWriterT m >>= \x -> (x:) <$> f diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs new file mode 100644 index 0000000..a3f38db --- /dev/null +++ b/src/PPL/Sampling.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE ViewPatterns #-} + +module PPL.Sampling where + +import Control.Monad.IO.Class +import Control.Monad.Trans.State +import Data.Bifunctor +import Data.Monoid +import Numeric.Log +import PPL.Distr +import PPL.Internal hiding (split) +import System.Random (getStdGen, newStdGen, random, randoms, split) + +importance :: MonadIO m => Int -> Meas a -> m [a] +importance n m = do + newStdGen + g <- getStdGen + let ys = take n $ accumulate xs + max = snd $ last ys + xs = samples m $ randomTree g1 + (g1, g2) = split g + let rs = randoms g2 + pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max + where + cumsum = tail . scanl (+) 0 + accumulate = uncurry zip . second cumsum . unzip + +mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)] +mh p m = do + newStdGen + g <- getStdGen + let (g1, g2) = split g + (x, w) = head $ samples m t + t = randomTree g1 + pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2 + where + step (t, x, w) = do + g <- get + let (g1, g2) = split g + t' = mutateTree p g1 t + (x', w') = head $ samples m t' + ratio = w' / w + (Exp . log -> r, g3) = random g2 + put g3 + pure $ + if r < ratio + then (t', x', w') + else (t, x, w) + + iterateM f x = do + y <- f x + (y :) <$> iterateM f y |