aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2025-03-05 11:07:12 +1100
committerJustin Bedo <cu@cua0.org>2025-03-05 11:07:15 +1100
commit3e98e71b49c3e7dbba3af43d18dea106726fef41 (patch)
tree794517785bedc340bdb88591845786aef70f4d79
parenta4cbf2ff3f1839dc302f5956b9b27c6bb28b3f30 (diff)
switch to hashmap
-rw-r--r--package.yaml1
-rw-r--r--ppl.cabal1
-rw-r--r--src/PPL/Internal.hs4
-rw-r--r--src/PPL/Sampling.hs79
4 files changed, 8 insertions, 77 deletions
diff --git a/package.yaml b/package.yaml
index 4e4d3bc..46fc800 100644
--- a/package.yaml
+++ b/package.yaml
@@ -14,6 +14,7 @@ dependencies:
- streaming
- ghc-heap
- deepseq
+ - unordered-containers
library:
source-dirs: src
diff --git a/ppl.cabal b/ppl.cabal
index b2d83f5..6474f11 100644
--- a/ppl.cabal
+++ b/ppl.cabal
@@ -31,4 +31,5 @@ library
, streaming
, template-haskell
, transformers
+ , unordered-containers
default-language: Haskell2010
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 49737d0..7273d23 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -28,7 +28,7 @@ import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (split, uniform)
import qualified System.Random as R
-import qualified Data.Map.Strict as M
+import qualified Data.HashMap.Strict as M
import Data.IORef
import System.IO.Unsafe
@@ -43,7 +43,7 @@ split :: Tree -> (Tree, Tree)
split (Tree r (t : ts)) = (t, Tree r ts)
{-# INLINE newTree #-}
-newTree :: IORef (M.Map [Int] Double, StdGen) -> Tree
+newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree
newTree s = go []
where
go id = Tree (unsafePerformIO $ do
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 6807439..9a749e9 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -6,7 +6,6 @@
module PPL.Sampling
( mh,
- ssmh,
)
where
@@ -15,7 +14,7 @@ import Control.Exception (evaluate)
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
-import qualified Data.Map.Strict as M
+import qualified Data.HashMap.Strict as M
import Data.Monoid
import GHC.Exts.Heap
import Numeric.Log
@@ -50,7 +49,7 @@ mh g p m = do
yield (x'', w'')
step g2 omega'' x'' w''
- mutate :: MonadIO m => StdGen -> IORef (M.Map [Int] Double, StdGen) -> m (IORef (M.Map [Int] Double, StdGen))
+ mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) -> m (IORef (M.HashMap [Int] Double, StdGen))
mutate g omega = do
(m, g0) <- liftIO $ readIORef omega
let (r:q:_) = R.randoms g
@@ -70,76 +69,6 @@ mh g p m = do
pure y
else do
put g1
- pure x
+ pure x
+
--- Single site MH
-
--- Truncated trees
-data TTree = TTree Bool [Maybe TTree] deriving (Show)
-
-type Site = [Int]
-
-trunc :: Tree -> IO TTree
-trunc = truncTree . asBox
- where
- truncTree t =
- getBoxedClosureData' t >>= \case
- ConstrClosure _ [l, r] [] _ _ "Tree" ->
- getBoxedClosureData' l >>= \case
- ConstrClosure {dataArgs = [d], name = "D#"} ->
- TTree True <$> truncTrees r
- x -> error $ "truncTree:ConstrClosure:" ++ show x
- ConstrClosure _ [r] [d] _ _ "Tree" ->
- TTree False <$> truncTrees r
- x -> error $ "truncTree:" ++ show x
-
- getBoxedClosureData' x =
- getBoxedClosureData x >>= \c -> case c of
- BlackholeClosure _ t -> getBoxedClosureData' t
- _ -> pure c
-
- truncTrees b =
- getBoxedClosureData' b >>= \case
- ConstrClosure _ [l, r] [] _ _ ":" ->
- getBoxedClosureData' l >>= \case
- ConstrClosure {name = "Tree"} ->
- ((:) . Just) <$> truncTree l <*> truncTrees r
- _ -> (Nothing :) <$> truncTrees r
- _ -> pure []
-
-trunc' t x w = unsafePerformIO $ do
- evaluate (rnf x)
- evaluate (rnf w)
- trunc t
-
-sites :: Site -> TTree -> [Site]
-sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
-
-mutate = M.foldrWithKey go
- where
- go [] d (Tree _ ts) = Tree d ts
- go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts
-
-ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m ()
-ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w
- where
- (t0, t) = split $ randomTree g
- (x, w) = head $ samples m t0
-
- step !t !sub !tt !x !w = do
- let ss = sites [] tt
- (t1 : t2 : t3 : t4 : _) = splitTrees t
- i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate
- sub' = M.insert (reverse $ ss !! i) (draw t3) sub
- t' = mutate t0 sub'
- (x', w') = head $ samples m t'
- tt' = trunc' t' x' w'
- ratio = w' / w
- (Exp . log -> r) = draw t4
- (sub'', tt'', x'', w'') =
- if r < ratio
- then (sub', tt', x', w')
- else (sub, tt, x, w)
-
- yield (x'', w'')
- step t1 sub'' tt'' x'' w''