diff options
| author | Justin Bedo <cu@cua0.org> | 2025-03-04 17:30:41 +1100 | 
|---|---|---|
| committer | Justin Bedo <cu@cua0.org> | 2025-03-07 21:05:38 +1100 | 
| commit | be4be249004af1e7039a29d0228b506139983ea2 (patch) | |
| tree | 92f9a768c82ab38841098188cdccbaa8df28a036 /src/PPL | |
| parent | d621243bffa14214c722a9be181f4a98d41380bb (diff) | |
switch to hashmap-based table
Allows a significant speed increase, reduced memory use, and a simpler
single-site mode.
Diffstat (limited to 'src/PPL')
| -rw-r--r-- | src/PPL/Distr.hs | 38 | ||||
| -rw-r--r-- | src/PPL/Internal.hs | 66 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 133 | 
3 files changed, 105 insertions, 132 deletions
| diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index 093e102..e1fd5aa 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,7 +1,6 @@  module PPL.Distr where  import PPL.Internal -import qualified PPL.Internal as I  -- Acklam's approximation  -- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ @@ -9,14 +8,15 @@ import qualified PPL.Internal as I  probit :: Double -> Double  probit p    | p < lower = -    let q = sqrt (-2 * log p) -     in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) -          / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) +      let q = sqrt (-2 * log p) +       in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) +            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)    | p < 1 - lower = -    let q = p - 0.5 -        r = q * q -     in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q -          / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) +      let q = p - 0.5 +          r = q * q +       in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) +            * q +            / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)    | otherwise = -probit (1 - p)    where      a1 = -3.969683028665376e+01 @@ -49,11 +49,14 @@ probit p  iid :: Prob a -> Prob [a]  iid = sequence . repeat +gauss :: Prob Double  gauss = probit <$> uniform +norm :: Double -> Double -> Prob Double  norm m s = (+ m) . 
(* s) <$> gauss  -- Marsaglia's fast gamma rejection sampling +gamma :: Double -> Prob Double  gamma a = do    x <- gauss    u <- uniform @@ -64,19 +67,24 @@ gamma a = do      d = a - 1 / 3      v x = (1 + x / sqrt (9 * d)) ** 3 +beta :: Double -> Double -> Prob Double  beta a b = do    x <- gamma a    y <- gamma b    pure $ x / (x + y) +beta' :: Double -> Double -> Prob Double  beta' a b = do    p <- beta a b -  pure $ p / (1-p) +  pure $ p / (1 - p) +bern :: Double -> Prob Bool  bern p = (< p) <$> uniform +binom :: Int -> Double -> Prob Int  binom n = fmap (length . filter id . take n) . iid . bern +exponential :: Double -> Prob Double  exponential lambda = negate . (/ lambda) . log <$> uniform  geom :: Double -> Prob Int @@ -84,9 +92,12 @@ geom p = first 0 <$> iid (bern p)    where      first n (True : _) = n      first n (_ : xs) = first (n + 1) xs +    first _ _ = undefined +bounded :: Double -> Double -> Prob Double  bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform +bounded' :: Integral a => a -> a -> Prob a  bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)  cat :: [Double] -> Prob Int @@ -97,8 +108,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform        | x > r = i        | otherwise = search (i + 1) xs r +dirichletProcess :: Double -> Prob [Double]  dirichletProcess p = go 1 - where -   go rest = do -     x <- beta 1 p -     (x*rest:) <$> go (rest - x*rest) +  where +    go rest = do +      x <- beta 1 p +      (x * rest :) <$> go (rest - x * rest) diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index b4283af..a81ba37 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,55 +1,80 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-}  {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TemplateHaskell #-}  module PPL.Internal    ( uniform, -    split,      Prob (..),      Meas,   
   score,      scoreLog,      sample, -    randomTree,      samples, -    mutateTree, -    Tree(..), +    newTree, +    HashMap, +    Tree (..),    )  where  import Control.Monad -import Control.Monad.IO.Class  import Control.Monad.Trans.Class  import Control.Monad.Trans.Writer  import Data.Bifunctor +import Data.Bits +import Data.IORef  import Data.Monoid +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed.Mutable as UM +import Data.Word  import qualified Language.Haskell.TH.Syntax as TH  import Numeric.Log +import System.IO.Unsafe  import System.Random hiding (split, uniform)  import qualified System.Random as R +type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v + +-- Simple hashing scheme into 64-bit space based on mzHash64 +data Hash = Hash {unhash :: Word64} deriving (Eq, Ord, Show) + +pushbit :: Bool -> Hash -> Hash +pushbit b (Hash state) = Hash $ (0xB2DEEF07CB4ACD43 * fromBool b) `xor` (state `shiftL` 2) `xor` (state `shiftR` 2) +  where +    fromBool True = 1 +    fromBool False = 0 + +initHash :: Hash +initHash = Hash 0xE297DA430DB2DF1A +  -- Reimplementation of the LazyPPL monads to avoid some dependencies  data Tree = Tree    { draw :: !Double, -    splitTrees :: [Tree] +    split :: (Tree, Tree)    } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') +{-# INLINE newTree #-} +newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree +newTree s = go initHash    where -    randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# INLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = -  if r < p -    then b -    else Tree a $ zipWith3 (mutateTree p) rs bs ts +    go :: Hash -> Tree +    go id = +      Tree +        ( unsafePerformIO $ do +           
 (m, g) <- readIORef s +            H.lookup m (unhash id) >>= \case +              Nothing -> do +                let (x, g') = R.random g +                H.insert m (unhash id) x +                writeIORef s (m, g') +                pure x +              Just x -> pure x +        ) +        (go (pushbit False id), go (pushbit True id))  newtype Prob a = Prob {runProb :: Tree -> a} @@ -63,6 +88,7 @@ instance Functor Prob where fmap = liftM  instance Applicative Prob where pure = Prob . const; (<*>) = ap +uniform :: Prob Double  uniform = Prob $ \(Tree r _) -> r  newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 8e0ab8f..7c2cb54 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,117 +1,52 @@  {-# LANGUAGE BangPatterns #-}  {-# LANGUAGE BlockArguments #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-}  {-# LANGUAGE ViewPatterns #-}  module PPL.Sampling    ( mh, -    ssmh,    )  where -import Control.DeepSeq -import Control.Exception (evaluate) +import Control.Monad  import Control.Monad.IO.Class -import Control.Monad.Trans.State -import Data.Bifunctor -import qualified Data.Map.Strict as M -import Data.Monoid -import GHC.Exts.Heap +import Data.IORef +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed as V +import Data.Word  import Numeric.Log -import PPL.Distr  import PPL.Internal -import qualified Streaming as S  import Streaming.Prelude (Of, Stream, yield) -import System.IO.Unsafe -import System.Random (StdGen, random, randoms) - -mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () -mh g p m = step t0 t x w -  where -    (t0, t) = split $ randomTree g -    (x, w) = head $ samples m t - -    step !t0 !t !x !w = do -      let (t1 : t2 : t3 : t4 : _) = splitTrees t0 -          t' = mutateTree p t1 t2 t -          (x', w') = head $ samples m t' -          ratio = w' / w -          (Exp . 
log -> r) = draw t3 -          (t'', x'', w'') = -            if r < ratio -              then (t', x', w') -              else (t, x, w) -      yield (x'', w'') -      step t4 t'' x'' w'' - --- Single site MH - --- Truncated trees -data TTree = TTree Bool [Maybe TTree] deriving (Show) - -type Site = [Int] - -trunc :: Tree -> IO TTree -trunc = truncTree . asBox +import System.Random (StdGen) +import qualified System.Random as R + +mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh g p m = do +  let (g0, g1) = R.split g +  hm <- liftIO $ H.initialize 0 +  omega <- liftIO $ newIORef (hm, g0) +  let (x, w) = head $ samples m $ newTree omega +  step g1 omega x w    where -    truncTree t = -      getBoxedClosureData' t >>= \case -        ConstrClosure _ [l, r] [] _ _ "Tree" -> -          getBoxedClosureData' l >>= \case -            ConstrClosure {dataArgs = [d], name = "D#"} -> -              TTree True <$> truncTrees r -            x -> error $ "truncTree:ConstrClosure:" ++ show x -        ConstrClosure _ [r] [d] _ _ "Tree" ->  -          TTree False <$> truncTrees r -        x -> error $ "truncTree:" ++ show x - -    getBoxedClosureData' x = -      getBoxedClosureData x >>= \c -> case c of -        BlackholeClosure _ t -> getBoxedClosureData' t -        _ -> pure c - -    truncTrees b = -      getBoxedClosureData' b >>= \case -        ConstrClosure _ [l, r] [] _ _ ":" -> -          getBoxedClosureData' l >>= \case -            ConstrClosure {name = "Tree"} -> -              ((:) . Just) <$> truncTree l <*>  truncTrees r -            _ -> (Nothing :) <$> truncTrees r -        _ -> pure [] - -trunc' t x w = unsafePerformIO $ do -  evaluate (rnf x) -  evaluate (rnf w) -  trunc t - -sites :: Site -> TTree -> [Site] -sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] 
ts] - -mutate = M.foldrWithKey go -  where -    go [] d (Tree _ ts) = Tree d ts -    go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts - -ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () -ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w -  where -    (t0, t) = split $ randomTree g -    (x, w) = head $ samples m t0 - -    step !t !sub !tt !x !w = do -      let ss = sites [] tt -          (t1 : t2 : t3 : t4 : _) = splitTrees t -          i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate -          sub' = M.insert (reverse $ ss !! i) (draw t3) sub -          t' = mutate t0 sub' -          (x', w') = head $ samples m t' -          tt' = trunc' t' x' w' +    step !g0 !omega !x !w = do +      let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0 +      omega' <- mutate g1 omega +      let (!x', !w') = head $ samples m $ newTree omega'            ratio = w' / w -          (Exp . log -> r) = draw t4 -          (sub'', tt'', x'', w'') = +          (omega'', x'', w'') =              if r < ratio -              then (sub', tt', x', w') -              else (sub, tt, x, w) - +              then (omega', x', w') +              else (omega, x, w)        yield (x'', w'') -      step t1 sub'' tt'' x'' w'' +      step g2 omega'' x'' w'' + +    mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen)) +    mutate g omega = liftIO $ do +      (m, g0) <- readIORef omega +      m' <- H.clone m +      ks <- H.keys m +      let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g) +          n = fromIntegral (V.length ks) +      void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs +      newIORef (m', g0) | 
