aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/PPL.hs4
-rw-r--r--src/PPL/Distr.hs40
-rw-r--r--src/PPL/Internal.hs67
-rw-r--r--src/PPL/Sampling.hs65
4 files changed, 118 insertions, 58 deletions
diff --git a/src/PPL.hs b/src/PPL.hs
index 21bb87c..d1dc17a 100644
--- a/src/PPL.hs
+++ b/src/PPL.hs
@@ -1,5 +1,5 @@
-module PPL(module PPL.Internal, module PPL.Sampling, module PPL.Distr) where
+module PPL (module PPL.Internal, module PPL.Sampling, module PPL.Distr) where
+import PPL.Distr
import PPL.Internal
import PPL.Sampling
-import PPL.Distr
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index 797379b..e1fd5aa 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,7 +1,6 @@
module PPL.Distr where
import PPL.Internal
-import qualified PPL.Internal as I
-- Acklam's approximation
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
@@ -9,14 +8,15 @@ import qualified PPL.Internal as I
probit :: Double -> Double
probit p
| p < lower =
- let q = sqrt (-2 * log p)
- in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
- / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
+ let q = sqrt (-2 * log p)
+ in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
+ / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
| p < 1 - lower =
- let q = p - 0.5
- r = q * q
- in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
- / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
+ let q = p - 0.5
+ r = q * q
+ in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
+ * q
+ / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
| otherwise = -probit (1 - p)
where
a1 = -3.969683028665376e+01
@@ -49,11 +49,14 @@ probit p
iid :: Prob a -> Prob [a]
iid = sequence . repeat
+gauss :: Prob Double
gauss = probit <$> uniform
+norm :: Double -> Double -> Prob Double
norm m s = (+ m) . (* s) <$> gauss
-- Marsaglia's fast gamma rejection sampling
+gamma :: Double -> Prob Double
gamma a = do
x <- gauss
u <- uniform
@@ -64,15 +67,24 @@ gamma a = do
d = a - 1 / 3
v x = (1 + x / sqrt (9 * d)) ** 3
+beta :: Double -> Double -> Prob Double
beta a b = do
x <- gamma a
y <- gamma b
pure $ x / (x + y)
+beta' :: Double -> Double -> Prob Double
+beta' a b = do
+ p <- beta a b
+ pure $ p / (1 - p)
+
+bern :: Double -> Prob Bool
bern p = (< p) <$> uniform
+binom :: Int -> Double -> Prob Int
binom n = fmap (length . filter id . take n) . iid . bern
+exponential :: Double -> Prob Double
exponential lambda = negate . (/ lambda) . log <$> uniform
geom :: Double -> Prob Int
@@ -80,9 +92,12 @@ geom p = first 0 <$> iid (bern p)
where
first n (True : _) = n
first n (_ : xs) = first (n + 1) xs
+ first _ _ = undefined
+bounded :: Double -> Double -> Prob Double
bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform
+bounded' :: Integral a => a -> a -> Prob a
bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
cat :: [Double] -> Prob Int
@@ -93,8 +108,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
| x > r = i
| otherwise = search (i + 1) xs r
+dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = go 1
- where
- go rest = do
- x <- beta 1 p
- (x*rest:) <$> go (rest - x*rest)
+ where
+ go rest = do
+ x <- beta 1 p
+ (x * rest :) <$> go (rest - x * rest)
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 0ba6ae5..a5b3758 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,56 +1,80 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module PPL.Internal
( uniform,
- split,
Prob (..),
Meas,
score,
scoreLog,
sample,
- randomTree,
samples,
- mutateTree,
- splitTrees,
- draw,
+ newTree,
+ HashMap,
+ Tree (..),
)
where
import Control.Monad
-import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
+import Data.Bits
+import Data.IORef
import Data.Monoid
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector.Unboxed.Mutable as UM
+import Data.Word
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
+import System.IO.Unsafe
import System.Random hiding (split, uniform)
import qualified System.Random as R
+type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v
+
+-- Simple hashing scheme into 64-bit space based on mzHash64
+data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show)
+
+pushbit :: Bool -> Hash -> Hash
+pushbit b (Hash state) = Hash $ (16738720027710993212 * fromBool b) `xor` (state `shiftL` 3) `xor` (state `shiftR` 1)
+ where
+ fromBool True = 1
+ fromBool False = 0
+
+initHash :: Hash
+initHash = Hash 6223102867371013753
+
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
{ draw :: !Double,
- splitTrees :: [Tree]
+ split :: (Tree, Tree)
}
-split :: Tree -> (Tree, Tree)
-split (Tree r (t : ts)) = (t, Tree r ts)
-
-{-# INLINE randomTree #-}
-randomTree :: RandomGen g => g -> Tree
-randomTree g = let (a, g') = random g in Tree a (randomTrees g')
+{-# INLINE newTree #-}
+newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree
+newTree s = go initHash
where
- randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
-
-{-# INLINE mutateTree #-}
-mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
-mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
- if r < p
- then b
- else Tree a $ zipWith3 (mutateTree p) rs bs ts
+ go :: Hash -> Tree
+ go id =
+ Tree
+ ( unsafePerformIO $ do
+ (m, g) <- readIORef s
+ H.lookup m (unhash id) >>= \case
+ Nothing -> do
+ let (x, g') = R.random g
+ H.insert m (unhash id) x
+ writeIORef s (m, g')
+ pure x
+ Just x -> pure x
+ )
+ (go (pushbit False id), go (pushbit True id))
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -64,6 +88,7 @@ instance Functor Prob where fmap = liftM
instance Applicative Prob where pure = Prob . const; (<*>) = ap
+uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 84e5ada..7c2cb54 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,33 +1,52 @@
-{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE ViewPatterns #-}
-module PPL.Sampling where
+module PPL.Sampling
+ ( mh,
+ )
+where
+import Control.Monad
import Control.Monad.IO.Class
-import Control.Monad.Trans.State
-import Data.Bifunctor
-import Data.Monoid
+import Data.IORef
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector.Unboxed as V
+import Data.Word
import Numeric.Log
-import PPL.Distr
import PPL.Internal
-import System.Random (StdGen, random, randoms)
-import qualified Streaming as S
-import Streaming.Prelude (Stream, yield, Of)
+import Streaming.Prelude (Of, Stream, yield)
+import System.Random (StdGen)
+import qualified System.Random as R
-mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
-mh g p m = step t0 t x w
+mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh g p m = do
+ let (g0, g1) = R.split g
+ hm <- liftIO $ H.initialize 0
+ omega <- liftIO $ newIORef (hm, g0)
+ let (x, w) = head $ samples m $ newTree omega
+ step g1 omega x w
where
- (t0,t) = split $ randomTree g
- (x, w) = head $ samples m t
-
- step !t0 !t !x !w = do
- let (t1:t2:t3:t4:_) = splitTrees t0
- t' = mutateTree p t1 t2 t
- (x', w') = head $ samples m t'
+ step !g0 !omega !x !w = do
+ let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0
+ omega' <- mutate g1 omega
+ let (!x', !w') = head $ samples m $ newTree omega'
ratio = w' / w
- (Exp . log -> r) = draw t3
- (t'', x'', w'') = if r < ratio
- then (t', x', w')
- else (t, x, w)
+ (omega'', x'', w'') =
+ if r < ratio
+ then (omega', x', w')
+ else (omega, x, w)
yield (x'', w'')
- step t4 t'' x'' w''
+ step g2 omega'' x'' w''
+
+ mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen))
+ mutate g omega = liftIO $ do
+ (m, g0) <- readIORef omega
+ m' <- H.clone m
+ ks <- H.keys m
+ let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g)
+ n = fromIntegral (V.length ks)
+ void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs
+ newIORef (m', g0)