diff options
author | Justin Bedo <cu@cua0.org> | 2025-03-04 17:30:41 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2025-03-07 21:05:38 +1100 |
commit | be4be249004af1e7039a29d0228b506139983ea2 (patch) | |
tree | 92f9a768c82ab38841098188cdccbaa8df28a036 | |
parent | d621243bffa14214c722a9be181f4a98d41380bb (diff) |
This change allows a significant speed increase, a memory reduction, and a
simpler single-site mode.
-rw-r--r-- | package.yaml | 4 | ||||
-rw-r--r-- | ppl.cabal | 4 | ||||
-rw-r--r-- | src/PPL.hs | 4 | ||||
-rw-r--r-- | src/PPL/Distr.hs | 38 | ||||
-rw-r--r-- | src/PPL/Internal.hs | 66 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 133 |
6 files changed, 111 insertions, 138 deletions
diff --git a/package.yaml b/package.yaml index 4e4d3bc..9efd300 100644 --- a/package.yaml +++ b/package.yaml @@ -12,8 +12,8 @@ dependencies: - template-haskell - containers - streaming - - ghc-heap - - deepseq + - vector + - vector-hashtables library: source-dirs: src @@ -24,11 +24,11 @@ library build-depends: base >=4.9 && <5 , containers - , deepseq - , ghc-heap , log-domain , random , streaming , template-haskell , transformers + , vector + , vector-hashtables default-language: Haskell2010 @@ -1,5 +1,5 @@ -module PPL(module PPL.Internal, module PPL.Sampling, module PPL.Distr) where +module PPL (module PPL.Internal, module PPL.Sampling, module PPL.Distr) where +import PPL.Distr import PPL.Internal import PPL.Sampling -import PPL.Distr diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index 093e102..e1fd5aa 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,7 +1,6 @@ module PPL.Distr where import PPL.Internal -import qualified PPL.Internal as I -- Acklam's approximation -- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ @@ -9,14 +8,15 @@ import qualified PPL.Internal as I probit :: Double -> Double probit p | p < lower = - let q = sqrt (-2 * log p) - in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) - / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) + let q = sqrt (-2 * log p) + in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) + / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) | p < 1 - lower = - let q = p - 0.5 - r = q * q - in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q - / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) + let q = p - 0.5 + r = q * q + in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) + * q + / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) | otherwise = -probit (1 - p) where a1 = -3.969683028665376e+01 @@ -49,11 +49,14 @@ probit p iid :: Prob a -> Prob [a] iid = sequence . 
repeat +gauss :: Prob Double gauss = probit <$> uniform +norm :: Double -> Double -> Prob Double norm m s = (+ m) . (* s) <$> gauss -- Marsaglia's fast gamma rejection sampling +gamma :: Double -> Prob Double gamma a = do x <- gauss u <- uniform @@ -64,19 +67,24 @@ gamma a = do d = a - 1 / 3 v x = (1 + x / sqrt (9 * d)) ** 3 +beta :: Double -> Double -> Prob Double beta a b = do x <- gamma a y <- gamma b pure $ x / (x + y) +beta' :: Double -> Double -> Prob Double beta' a b = do p <- beta a b - pure $ p / (1-p) + pure $ p / (1 - p) +bern :: Double -> Prob Bool bern p = (< p) <$> uniform +binom :: Int -> Double -> Prob Int binom n = fmap (length . filter id . take n) . iid . bern +exponential :: Double -> Prob Double exponential lambda = negate . (/ lambda) . log <$> uniform geom :: Double -> Prob Int @@ -84,9 +92,12 @@ geom p = first 0 <$> iid (bern p) where first n (True : _) = n first n (_ : xs) = first (n + 1) xs + first _ _ = undefined +bounded :: Double -> Double -> Prob Double bounded lower upper = (+ lower) . 
(* (upper - lower)) <$> uniform +bounded' :: Integral a => a -> a -> Prob a bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper) cat :: [Double] -> Prob Int @@ -97,8 +108,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform | x > r = i | otherwise = search (i + 1) xs r +dirichletProcess :: Double -> Prob [Double] dirichletProcess p = go 1 - where - go rest = do - x <- beta 1 p - (x*rest:) <$> go (rest - x*rest) + where + go rest = do + x <- beta 1 p + (x * rest :) <$> go (rest - x * rest) diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index b4283af..a81ba37 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,55 +1,80 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} module PPL.Internal ( uniform, - split, Prob (..), Meas, score, scoreLog, sample, - randomTree, samples, - mutateTree, - Tree(..), + newTree, + HashMap, + Tree (..), ) where import Control.Monad -import Control.Monad.IO.Class import Control.Monad.Trans.Class import Control.Monad.Trans.Writer import Data.Bifunctor +import Data.Bits +import Data.IORef import Data.Monoid +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed.Mutable as UM +import Data.Word import qualified Language.Haskell.TH.Syntax as TH import Numeric.Log +import System.IO.Unsafe import System.Random hiding (split, uniform) import qualified System.Random as R +type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v + +-- Simple hashing scheme into 64-bit space based on mzHash64 +data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show) + +pushbit :: Bool -> Hash -> Hash +pushbit b (Hash state) = Hash $ (0xB2DEEF07CB4ACD43 * fromBool b) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2) + where + fromBool True = 1 + fromBool False = 0 + +initHash :: Hash 
+initHash = Hash 0xE297DA430DB2DF1A + -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree { draw :: !Double, - splitTrees :: [Tree] + split :: (Tree, Tree) } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') +{-# INLINE newTree #-} +newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree +newTree s = go initHash where - randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# INLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = - if r < p - then b - else Tree a $ zipWith3 (mutateTree p) rs bs ts + go :: Hash -> Tree + go id = + Tree + ( unsafePerformIO $ do + (m, g) <- readIORef s + H.lookup m (unhash id) >>= \case + Nothing -> do + let (x, g') = R.random g + H.insert m (unhash id) x + writeIORef s (m, g') + pure x + Just x -> pure x + ) + (go (pushbit False id), go (pushbit True id)) newtype Prob a = Prob {runProb :: Tree -> a} @@ -63,6 +88,7 @@ instance Functor Prob where fmap = liftM instance Applicative Prob where pure = Prob . 
const; (<*>) = ap +uniform :: Prob Double uniform = Prob $ \(Tree r _) -> r newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 8e0ab8f..7c2cb54 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,117 +1,52 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE BlockArguments #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} module PPL.Sampling ( mh, - ssmh, ) where -import Control.DeepSeq -import Control.Exception (evaluate) +import Control.Monad import Control.Monad.IO.Class -import Control.Monad.Trans.State -import Data.Bifunctor -import qualified Data.Map.Strict as M -import Data.Monoid -import GHC.Exts.Heap +import Data.IORef +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed as V +import Data.Word import Numeric.Log -import PPL.Distr import PPL.Internal -import qualified Streaming as S import Streaming.Prelude (Of, Stream, yield) -import System.IO.Unsafe -import System.Random (StdGen, random, randoms) - -mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () -mh g p m = step t0 t x w - where - (t0, t) = split $ randomTree g - (x, w) = head $ samples m t - - step !t0 !t !x !w = do - let (t1 : t2 : t3 : t4 : _) = splitTrees t0 - t' = mutateTree p t1 t2 t - (x', w') = head $ samples m t' - ratio = w' / w - (Exp . log -> r) = draw t3 - (t'', x'', w'') = - if r < ratio - then (t', x', w') - else (t, x, w) - yield (x'', w'') - step t4 t'' x'' w'' - --- Single site MH - --- Truncated trees -data TTree = TTree Bool [Maybe TTree] deriving (Show) - -type Site = [Int] - -trunc :: Tree -> IO TTree -trunc = truncTree . 
asBox +import System.Random (StdGen) +import qualified System.Random as R + +mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh g p m = do + let (g0, g1) = R.split g + hm <- liftIO $ H.initialize 0 + omega <- liftIO $ newIORef (hm, g0) + let (x, w) = head $ samples m $ newTree omega + step g1 omega x w where - truncTree t = - getBoxedClosureData' t >>= \case - ConstrClosure _ [l, r] [] _ _ "Tree" -> - getBoxedClosureData' l >>= \case - ConstrClosure {dataArgs = [d], name = "D#"} -> - TTree True <$> truncTrees r - x -> error $ "truncTree:ConstrClosure:" ++ show x - ConstrClosure _ [r] [d] _ _ "Tree" -> - TTree False <$> truncTrees r - x -> error $ "truncTree:" ++ show x - - getBoxedClosureData' x = - getBoxedClosureData x >>= \c -> case c of - BlackholeClosure _ t -> getBoxedClosureData' t - _ -> pure c - - truncTrees b = - getBoxedClosureData' b >>= \case - ConstrClosure _ [l, r] [] _ _ ":" -> - getBoxedClosureData' l >>= \case - ConstrClosure {name = "Tree"} -> - ((:) . Just) <$> truncTree l <*> truncTrees r - _ -> (Nothing :) <$> truncTrees r - _ -> pure [] - -trunc' t x w = unsafePerformIO $ do - evaluate (rnf x) - evaluate (rnf w) - trunc t - -sites :: Site -> TTree -> [Site] -sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] - -mutate = M.foldrWithKey go - where - go [] d (Tree _ ts) = Tree d ts - go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts - -ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () -ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w - where - (t0, t) = split $ randomTree g - (x, w) = head $ samples m t0 - - step !t !sub !tt !x !w = do - let ss = sites [] tt - (t1 : t2 : t3 : t4 : _) = splitTrees t - i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate - sub' = M.insert (reverse $ ss !! 
i) (draw t3) sub - t' = mutate t0 sub' - (x', w') = head $ samples m t' - tt' = trunc' t' x' w' + step !g0 !omega !x !w = do + let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0 + omega' <- mutate g1 omega + let (!x', !w') = head $ samples m $ newTree omega' ratio = w' / w - (Exp . log -> r) = draw t4 - (sub'', tt'', x'', w'') = + (omega'', x'', w'') = if r < ratio - then (sub', tt', x', w') - else (sub, tt, x, w) - + then (omega', x', w') + else (omega, x, w) yield (x'', w'') - step t1 sub'' tt'' x'' w'' + step g2 omega'' x'' w'' + + mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen)) + mutate g omega = liftIO $ do + (m, g0) <- readIORef omega + m' <- H.clone m + ks <- H.keys m + let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g) + n = fromIntegral (V.length ks) + void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs + newIORef (m', g0) |