about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Justin Bedo <cu@cua0.org>, 2025-03-04 17:30:41 +1100
committer: Justin Bedo <cu@cua0.org>, 2025-03-07 21:05:38 +1100
commit: be4be249004af1e7039a29d0228b506139983ea2 (patch)
tree: 92f9a768c82ab38841098188cdccbaa8df28a036
parent: d621243bffa14214c722a9be181f4a98d41380bb (diff)
switch to hashmap based table (HEAD, master)
Allows significant speed increase, memory reduction, and simpler single site mode.
-rw-r--r-- package.yaml 4
-rw-r--r-- ppl.cabal 4
-rw-r--r-- src/PPL.hs 4
-rw-r--r-- src/PPL/Distr.hs 38
-rw-r--r-- src/PPL/Internal.hs 66
-rw-r--r-- src/PPL/Sampling.hs 133
6 files changed, 111 insertions, 138 deletions
diff --git a/package.yaml b/package.yaml
index 4e4d3bc..9efd300 100644
--- a/package.yaml
+++ b/package.yaml
@@ -12,8 +12,8 @@ dependencies:
- template-haskell
- containers
- streaming
- - ghc-heap
- - deepseq
+ - vector
+ - vector-hashtables
library:
source-dirs: src
diff --git a/ppl.cabal b/ppl.cabal
index b2d83f5..75dd0bc 100644
--- a/ppl.cabal
+++ b/ppl.cabal
@@ -24,11 +24,11 @@ library
build-depends:
base >=4.9 && <5
, containers
- , deepseq
- , ghc-heap
, log-domain
, random
, streaming
, template-haskell
, transformers
+ , vector
+ , vector-hashtables
default-language: Haskell2010
diff --git a/src/PPL.hs b/src/PPL.hs
index 21bb87c..d1dc17a 100644
--- a/src/PPL.hs
+++ b/src/PPL.hs
@@ -1,5 +1,5 @@
-module PPL(module PPL.Internal, module PPL.Sampling, module PPL.Distr) where
+module PPL (module PPL.Internal, module PPL.Sampling, module PPL.Distr) where
+import PPL.Distr
import PPL.Internal
import PPL.Sampling
-import PPL.Distr
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index 093e102..e1fd5aa 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,7 +1,6 @@
module PPL.Distr where
import PPL.Internal
-import qualified PPL.Internal as I
-- Acklam's approximation
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
@@ -9,14 +8,15 @@ import qualified PPL.Internal as I
probit :: Double -> Double
probit p
| p < lower =
- let q = sqrt (-2 * log p)
- in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
- / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
+ let q = sqrt (-2 * log p)
+ in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
+ / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
| p < 1 - lower =
- let q = p - 0.5
- r = q * q
- in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
- / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
+ let q = p - 0.5
+ r = q * q
+ in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
+ * q
+ / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
| otherwise = -probit (1 - p)
where
a1 = -3.969683028665376e+01
@@ -49,11 +49,14 @@ probit p
iid :: Prob a -> Prob [a]
iid = sequence . repeat
+gauss :: Prob Double
gauss = probit <$> uniform
+norm :: Double -> Double -> Prob Double
norm m s = (+ m) . (* s) <$> gauss
-- Marsaglia's fast gamma rejection sampling
+gamma :: Double -> Prob Double
gamma a = do
x <- gauss
u <- uniform
@@ -64,19 +67,24 @@ gamma a = do
d = a - 1 / 3
v x = (1 + x / sqrt (9 * d)) ** 3
+beta :: Double -> Double -> Prob Double
beta a b = do
x <- gamma a
y <- gamma b
pure $ x / (x + y)
+beta' :: Double -> Double -> Prob Double
beta' a b = do
p <- beta a b
- pure $ p / (1-p)
+ pure $ p / (1 - p)
+bern :: Double -> Prob Bool
bern p = (< p) <$> uniform
+binom :: Int -> Double -> Prob Int
binom n = fmap (length . filter id . take n) . iid . bern
+exponential :: Double -> Prob Double
exponential lambda = negate . (/ lambda) . log <$> uniform
geom :: Double -> Prob Int
@@ -84,9 +92,12 @@ geom p = first 0 <$> iid (bern p)
where
first n (True : _) = n
first n (_ : xs) = first (n + 1) xs
+ first _ _ = undefined
+bounded :: Double -> Double -> Prob Double
bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform
+bounded' :: Integral a => a -> a -> Prob a
bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
cat :: [Double] -> Prob Int
@@ -97,8 +108,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
| x > r = i
| otherwise = search (i + 1) xs r
+dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = go 1
- where
- go rest = do
- x <- beta 1 p
- (x*rest:) <$> go (rest - x*rest)
+ where
+ go rest = do
+ x <- beta 1 p
+ (x * rest :) <$> go (rest - x * rest)
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index b4283af..a81ba37 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,55 +1,80 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module PPL.Internal
( uniform,
- split,
Prob (..),
Meas,
score,
scoreLog,
sample,
- randomTree,
samples,
- mutateTree,
- Tree(..),
+ newTree,
+ HashMap,
+ Tree (..),
)
where
import Control.Monad
-import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
+import Data.Bits
+import Data.IORef
import Data.Monoid
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector.Unboxed.Mutable as UM
+import Data.Word
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
+import System.IO.Unsafe
import System.Random hiding (split, uniform)
import qualified System.Random as R
+type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v
+
+-- Simple hashing scheme into 64-bit space based on mzHash64
+data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show)
+
+pushbit :: Bool -> Hash -> Hash
+pushbit b (Hash state) = Hash $ (0xB2DEEF07CB4ACD43 * fromBool b) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2)
+ where
+ fromBool True = 1
+ fromBool False = 0
+
+initHash :: Hash
+initHash = Hash 0xE297DA430DB2DF1A
+
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
{ draw :: !Double,
- splitTrees :: [Tree]
+ split :: (Tree, Tree)
}
-split :: Tree -> (Tree, Tree)
-split (Tree r (t : ts)) = (t, Tree r ts)
-
-{-# INLINE randomTree #-}
-randomTree :: RandomGen g => g -> Tree
-randomTree g = let (a, g') = random g in Tree a (randomTrees g')
+{-# INLINE newTree #-}
+newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree
+newTree s = go initHash
where
- randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
-
-{-# INLINE mutateTree #-}
-mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
-mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
- if r < p
- then b
- else Tree a $ zipWith3 (mutateTree p) rs bs ts
+ go :: Hash -> Tree
+ go id =
+ Tree
+ ( unsafePerformIO $ do
+ (m, g) <- readIORef s
+ H.lookup m (unhash id) >>= \case
+ Nothing -> do
+ let (x, g') = R.random g
+ H.insert m (unhash id) x
+ writeIORef s (m, g')
+ pure x
+ Just x -> pure x
+ )
+ (go (pushbit False id), go (pushbit True id))
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -63,6 +88,7 @@ instance Functor Prob where fmap = liftM
instance Applicative Prob where pure = Prob . const; (<*>) = ap
+uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 8e0ab8f..7c2cb54 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,117 +1,52 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module PPL.Sampling
( mh,
- ssmh,
)
where
-import Control.DeepSeq
-import Control.Exception (evaluate)
+import Control.Monad
import Control.Monad.IO.Class
-import Control.Monad.Trans.State
-import Data.Bifunctor
-import qualified Data.Map.Strict as M
-import Data.Monoid
-import GHC.Exts.Heap
+import Data.IORef
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector.Unboxed as V
+import Data.Word
import Numeric.Log
-import PPL.Distr
import PPL.Internal
-import qualified Streaming as S
import Streaming.Prelude (Of, Stream, yield)
-import System.IO.Unsafe
-import System.Random (StdGen, random, randoms)
-
-mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
-mh g p m = step t0 t x w
- where
- (t0, t) = split $ randomTree g
- (x, w) = head $ samples m t
-
- step !t0 !t !x !w = do
- let (t1 : t2 : t3 : t4 : _) = splitTrees t0
- t' = mutateTree p t1 t2 t
- (x', w') = head $ samples m t'
- ratio = w' / w
- (Exp . log -> r) = draw t3
- (t'', x'', w'') =
- if r < ratio
- then (t', x', w')
- else (t, x, w)
- yield (x'', w'')
- step t4 t'' x'' w''
-
--- Single site MH
-
--- Truncated trees
-data TTree = TTree Bool [Maybe TTree] deriving (Show)
-
-type Site = [Int]
-
-trunc :: Tree -> IO TTree
-trunc = truncTree . asBox
+import System.Random (StdGen)
+import qualified System.Random as R
+
+mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh g p m = do
+ let (g0, g1) = R.split g
+ hm <- liftIO $ H.initialize 0
+ omega <- liftIO $ newIORef (hm, g0)
+ let (x, w) = head $ samples m $ newTree omega
+ step g1 omega x w
where
- truncTree t =
- getBoxedClosureData' t >>= \case
- ConstrClosure _ [l, r] [] _ _ "Tree" ->
- getBoxedClosureData' l >>= \case
- ConstrClosure {dataArgs = [d], name = "D#"} ->
- TTree True <$> truncTrees r
- x -> error $ "truncTree:ConstrClosure:" ++ show x
- ConstrClosure _ [r] [d] _ _ "Tree" ->
- TTree False <$> truncTrees r
- x -> error $ "truncTree:" ++ show x
-
- getBoxedClosureData' x =
- getBoxedClosureData x >>= \c -> case c of
- BlackholeClosure _ t -> getBoxedClosureData' t
- _ -> pure c
-
- truncTrees b =
- getBoxedClosureData' b >>= \case
- ConstrClosure _ [l, r] [] _ _ ":" ->
- getBoxedClosureData' l >>= \case
- ConstrClosure {name = "Tree"} ->
- ((:) . Just) <$> truncTree l <*> truncTrees r
- _ -> (Nothing :) <$> truncTrees r
- _ -> pure []
-
-trunc' t x w = unsafePerformIO $ do
- evaluate (rnf x)
- evaluate (rnf w)
- trunc t
-
-sites :: Site -> TTree -> [Site]
-sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
-
-mutate = M.foldrWithKey go
- where
- go [] d (Tree _ ts) = Tree d ts
- go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts
-
-ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m ()
-ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w
- where
- (t0, t) = split $ randomTree g
- (x, w) = head $ samples m t0
-
- step !t !sub !tt !x !w = do
- let ss = sites [] tt
- (t1 : t2 : t3 : t4 : _) = splitTrees t
- i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate
- sub' = M.insert (reverse $ ss !! i) (draw t3) sub
- t' = mutate t0 sub'
- (x', w') = head $ samples m t'
- tt' = trunc' t' x' w'
+ step !g0 !omega !x !w = do
+ let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0
+ omega' <- mutate g1 omega
+ let (!x', !w') = head $ samples m $ newTree omega'
ratio = w' / w
- (Exp . log -> r) = draw t4
- (sub'', tt'', x'', w'') =
+ (omega'', x'', w'') =
if r < ratio
- then (sub', tt', x', w')
- else (sub, tt, x, w)
-
+ then (omega', x', w')
+ else (omega, x, w)
yield (x'', w'')
- step t1 sub'' tt'' x'' w''
+ step g2 omega'' x'' w''
+
+ mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen))
+ mutate g omega = liftIO $ do
+ (m, g0) <- readIORef omega
+ m' <- H.clone m
+ ks <- H.keys m
+ let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g)
+ n = fromIntegral (V.length ks)
+ void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs
+ newIORef (m', g0)