diff options
author | Justin Bedo <cu@cua0.org> | 2025-03-05 11:07:12 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2025-03-05 11:07:15 +1100 |
commit | 3e98e71b49c3e7dbba3af43d18dea106726fef41 (patch) | |
tree | 794517785bedc340bdb88591845786aef70f4d79 | |
parent | a4cbf2ff3f1839dc302f5956b9b27c6bb28b3f30 (diff) |
switch to hashmap
-rw-r--r-- | package.yaml | 1 | ||||
-rw-r--r-- | ppl.cabal | 1 | ||||
-rw-r--r-- | src/PPL/Internal.hs | 4 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 79 |
4 files changed, 8 insertions, 77 deletions
diff --git a/package.yaml b/package.yaml index 4e4d3bc..46fc800 100644 --- a/package.yaml +++ b/package.yaml @@ -14,6 +14,7 @@ dependencies: - streaming - ghc-heap - deepseq + - unordered-containers library: source-dirs: src @@ -31,4 +31,5 @@ library , streaming , template-haskell , transformers + , unordered-containers default-language: Haskell2010 diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 49737d0..7273d23 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -28,7 +28,7 @@ import qualified Language.Haskell.TH.Syntax as TH import Numeric.Log import System.Random hiding (split, uniform) import qualified System.Random as R -import qualified Data.Map.Strict as M +import qualified Data.HashMap.Strict as M import Data.IORef import System.IO.Unsafe @@ -43,7 +43,7 @@ split :: Tree -> (Tree, Tree) split (Tree r (t : ts)) = (t, Tree r ts) {-# INLINE newTree #-} -newTree :: IORef (M.Map [Int] Double, StdGen) -> Tree +newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree newTree s = go [] where go id = Tree (unsafePerformIO $ do diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 6807439..9a749e9 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -6,7 +6,6 @@ module PPL.Sampling ( mh, - ssmh, ) where @@ -15,7 +14,7 @@ import Control.Exception (evaluate) import Control.Monad.IO.Class import Control.Monad.Trans.State import Data.Bifunctor -import qualified Data.Map.Strict as M +import qualified Data.HashMap.Strict as M import Data.Monoid import GHC.Exts.Heap import Numeric.Log @@ -50,7 +49,7 @@ mh g p m = do yield (x'', w'') step g2 omega'' x'' w'' - mutate :: MonadIO m => StdGen -> IORef (M.Map [Int] Double, StdGen) -> m (IORef (M.Map [Int] Double, StdGen)) + mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) -> m (IORef (M.HashMap [Int] Double, StdGen)) mutate g omega = do (m, g0) <- liftIO $ readIORef omega let (r:q:_) = R.randoms g @@ -70,76 +69,6 @@ mh g p m = do pure y else do put g1 - pure x + pure x + --- Single site MH - --- Truncated trees -data TTree = TTree Bool [Maybe TTree] deriving (Show) - -type Site = [Int] - -trunc :: Tree -> IO TTree -trunc = truncTree . asBox - where - truncTree t = - getBoxedClosureData' t >>= \case - ConstrClosure _ [l, r] [] _ _ "Tree" -> - getBoxedClosureData' l >>= \case - ConstrClosure {dataArgs = [d], name = "D#"} -> - TTree True <$> truncTrees r - x -> error $ "truncTree:ConstrClosure:" ++ show x - ConstrClosure _ [r] [d] _ _ "Tree" -> - TTree False <$> truncTrees r - x -> error $ "truncTree:" ++ show x - - getBoxedClosureData' x = - getBoxedClosureData x >>= \c -> case c of - BlackholeClosure _ t -> getBoxedClosureData' t - _ -> pure c - - truncTrees b = - getBoxedClosureData' b >>= \case - ConstrClosure _ [l, r] [] _ _ ":" -> - getBoxedClosureData' l >>= \case - ConstrClosure {name = "Tree"} -> - ((:) . Just) <$> truncTree l <*> truncTrees r - _ -> (Nothing :) <$> truncTrees r - _ -> pure [] - -trunc' t x w = unsafePerformIO $ do - evaluate (rnf x) - evaluate (rnf w) - trunc t - -sites :: Site -> TTree -> [Site] -sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] - -mutate = M.foldrWithKey go - where - go [] d (Tree _ ts) = Tree d ts - go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts - -ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () -ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w - where - (t0, t) = split $ randomTree g - (x, w) = head $ samples m t0 - - step !t !sub !tt !x !w = do - let ss = sites [] tt - (t1 : t2 : t3 : t4 : _) = splitTrees t - i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate - sub' = M.insert (reverse $ ss !! i) (draw t3) sub - t' = mutate t0 sub' - (x', w') = head $ samples m t' - tt' = trunc' t' x' w' - ratio = w' / w - (Exp . log -> r) = draw t4 - (sub'', tt'', x'', w'') = - if r < ratio - then (sub', tt', x', w') - else (sub, tt, x, w) - - yield (x'', w'') - step t1 sub'' tt'' x'' w'' |