From e318b40cfc1a3079375a015a651b1b44f6279ad0 Mon Sep 17 00:00:00 2001
From: Justin Bedo <cu@cua0.org>
Date: Fri, 16 Dec 2022 10:57:55 +1100
Subject: init

---
 src/PPL/Distr.hs    | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/PPL/Internal.hs |  81 ++++++++++++++++++++++++++++++++++++++++++
 src/PPL/Sampling.hs |  52 +++++++++++++++++++++++++++
 3 files changed, 233 insertions(+)
 create mode 100644 src/PPL/Distr.hs
 create mode 100644 src/PPL/Internal.hs
 create mode 100644 src/PPL/Sampling.hs

(limited to 'src/PPL')

diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
new file mode 100644
index 0000000..797379b
--- /dev/null
+++ b/src/PPL/Distr.hs
@@ -0,0 +1,100 @@
+module PPL.Distr where
+
+import PPL.Internal
+import qualified PPL.Internal as I
+
+-- Acklam's approximation
+-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
+{-# INLINE probit #-}
+probit :: Double -> Double
+probit p
+  | p < lower =
+    let q = sqrt (-2 * log p)
+     in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
+          / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
+  | p < 1 - lower =
+    let q = p - 0.5
+        r = q * q
+     in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
+          / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
+  | otherwise = -probit (1 - p)
+  where
+    a1 = -3.969683028665376e+01
+    a2 = 2.209460984245205e+02
+    a3 = -2.759285104469687e+02
+    a4 = 1.383577518672690e+02
+    a5 = -3.066479806614716e+01
+    a6 = 2.506628277459239e+00
+
+    b1 = -5.447609879822406e+01
+    b2 = 1.615858368580409e+02
+    b3 = -1.556989798598866e+02
+    b4 = 6.680131188771972e+01
+    b5 = -1.328068155288572e+01
+
+    c1 = -7.784894002430293e-03
+    c2 = -3.223964580411365e-01
+    c3 = -2.400758277161838e+00
+    c4 = -2.549732539343734e+00
+    c5 = 4.374664141464968e+00
+    c6 = 2.938163982698783e+00
+
+    d1 = 7.784695709041462e-03
+    d2 = 3.224671290700398e-01
+    d3 = 2.445134137142996e+00
+    d4 = 3.754408661907416e+00
+
+    lower = 0.02425
+
+iid :: Prob a -> Prob [a]
+iid = sequence . repeat
+
+gauss = probit <$> uniform
+
+norm m s = (+ m) . (* s) <$> gauss
+
+-- Marsaglia's fast gamma rejection sampling
+gamma a = do
+  x <- gauss
+  u <- uniform
+  if u < 1 - 0.03331 * x ** 4
+    then pure $ d * v x
+    else gamma a
+  where
+    d = a - 1 / 3
+    v x = (1 + x / sqrt (9 * d)) ** 3
+
+beta a b = do
+  x <- gamma a
+  y <- gamma b
+  pure $ x / (x + y)
+
+bern p = (< p) <$> uniform
+
+binom n = fmap (length . filter id . take n) . iid . bern
+
+exponential lambda = negate . (/ lambda) . log <$> uniform
+
+geom :: Double -> Prob Int
+geom p = first 0 <$> iid (bern p)
+  where
+    first n (True : _) = n
+    first n (_ : xs) = first (n + 1) xs
+
+bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform
+
+bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+
+cat :: [Double] -> Prob Int
+cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
+  where
+    search i [] _ = i
+    search i (x : xs) r
+      | x > r = i
+      | otherwise = search (i + 1) xs r
+
+dirichletProcess p = go 1
+ where
+   go rest = do
+     x <- beta 1 p
+     (x*rest:) <$> go (rest - x*rest)
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
new file mode 100644
index 0000000..a729d3d
--- /dev/null
+++ b/src/PPL/Internal.hs
@@ -0,0 +1,81 @@
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+
+module PPL.Internal (uniform, split, Prob(..), Meas, score, scoreLog, sample,
+randomTree, samples, mutateTree) where
+
+import Control.Monad
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Writer
+import Data.Monoid
+import qualified Language.Haskell.TH.Syntax as TH
+import Numeric.Log
+import System.Random hiding (uniform, split)
+import qualified System.Random as R
+import Data.Bifunctor
+import Control.Monad.IO.Class
+
+-- Reimplementation of the LazyPPL monads to avoid some dependencies
+
+data Tree = Tree Double [Tree]
+
+split :: Tree -> (Tree, Tree)
+split (Tree r (t : ts)) = (t, Tree r ts)
+
+{-# INLINE randomTree #-}
+randomTree :: RandomGen g => g -> Tree
+randomTree g = let (a, g') = random g in Tree a (randomTrees g')
+
+{-# INLINE randomTrees #-}
+randomTrees :: RandomGen g => g -> [Tree]
+randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
+
+{-# INLINE mutateTree #-}
+mutateTree :: RandomGen g => Double -> g -> Tree -> Tree
+mutateTree p g (Tree a ts) =
+  let (r, g1) = random g
+      (b, g2) = random g1
+  in Tree (if r < p then b else a) (mutateTrees p g2 ts)
+
+{-# INLINE mutateTrees #-}
+mutateTrees :: RandomGen g => Double -> g -> [Tree] -> [Tree]
+mutateTrees p g (t:ts) =
+  let (g1, g2) = R.split g
+  in mutateTree p g1 t : mutateTrees p g2 ts
+
+newtype Prob a = Prob {runProb :: Tree -> a}
+
+instance Monad Prob where
+  Prob f >>= g = Prob $ \t ->
+    let (t1, t2) = split t
+        (Prob g') = g (f t1)
+     in g' t2
+
+instance Functor Prob where fmap = liftM
+
+instance Applicative Prob where pure = Prob . const; (<*>) = ap
+
+uniform = Prob $ \(Tree r _) -> r
+
+newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
+  deriving (Functor, Applicative, Monad)
+
+{-# INLINE score #-}
+score :: Double -> Meas ()
+score = scoreLog . Exp . log . max eps
+  where
+    eps = $(TH.lift (until ((== 1) . (1 +)) (/ 2) (1 :: Double))) -- machine epsilon, force compile time eval
+
+{-# INLINE scoreLog #-}
+scoreLog :: Log Double -> Meas ()
+scoreLog = Meas . tell . Product
+
+sample :: Prob a -> Meas a
+sample = Meas . lift
+
+{-# INLINE samples #-}
+samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
+samples (Meas m) t = map (second getProduct) $ runProb f t
+  where
+    f = runWriterT m >>= \x -> (x:) <$> f
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
new file mode 100644
index 0000000..a3f38db
--- /dev/null
+++ b/src/PPL/Sampling.hs
@@ -0,0 +1,52 @@
+{-# LANGUAGE ViewPatterns #-}
+
+module PPL.Sampling where
+
+import Control.Monad.IO.Class
+import Control.Monad.Trans.State
+import Data.Bifunctor
+import Data.Monoid
+import Numeric.Log
+import PPL.Distr
+import PPL.Internal hiding (split)
+import System.Random (getStdGen, newStdGen, random, randoms, split)
+
+importance :: MonadIO m => Int -> Meas a -> m [a]
+importance n m = do
+  newStdGen
+  g <- getStdGen
+  let ys = take n $ accumulate xs
+      max = snd $ last ys
+      xs = samples m $ randomTree g1
+      (g1, g2) = split g
+  let rs = randoms g2
+  pure $ flip map rs $ \r -> fst . head $ flip filter ys $ \(x, w) -> w >= Exp (log r) * max
+  where
+    cumsum = tail . scanl (+) 0
+    accumulate = uncurry zip . second cumsum . unzip
+
+mh :: MonadIO m => Double -> Meas a -> m [(a, Log Double)]
+mh p m = do
+  newStdGen
+  g <- getStdGen
+  let (g1, g2) = split g
+      (x, w) = head $ samples m t
+      t = randomTree g1
+  pure $ map (\(_, x, w) -> (x, w)) $ evalState (iterateM step (t, x, w)) g2
+  where
+    step (t, x, w) = do
+      g <- get
+      let (g1, g2) = split g
+          t' = mutateTree p g1 t
+          (x', w') = head $ samples m t'
+          ratio = w' / w
+          (Exp . log -> r, g3) = random g2
+      put g3
+      pure $
+        if r < ratio
+          then (t', x', w')
+          else (t, x, w)
+
+    iterateM f x = do
+      y <- f x
+      (y :) <$> iterateM f y
-- 
cgit v1.2.3