From be4be249004af1e7039a29d0228b506139983ea2 Mon Sep 17 00:00:00 2001
From: Justin Bedo <cu@cua0.org>
Date: Tue, 4 Mar 2025 17:30:41 +1100
Subject: switch to hashmap based table

Allows a significant speed increase, reduced memory use, and a simpler
single-site mode.
---
 src/PPL/Sampling.hs | 133 ++++++++++++++--------------------------------------
 1 file changed, 34 insertions(+), 99 deletions(-)

(limited to 'src/PPL/Sampling.hs')

diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 8e0ab8f..7c2cb54 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,117 +1,52 @@
 {-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE BlockArguments #-}
 {-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
 {-# LANGUAGE ViewPatterns #-}
 
 module PPL.Sampling
   ( mh,
-    ssmh,
   )
 where
 
-import Control.DeepSeq
-import Control.Exception (evaluate)
+import Control.Monad
 import Control.Monad.IO.Class
-import Control.Monad.Trans.State
-import Data.Bifunctor
-import qualified Data.Map.Strict as M
-import Data.Monoid
-import GHC.Exts.Heap
+import Data.IORef
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector.Unboxed as V
+import Data.Word
 import Numeric.Log
-import PPL.Distr
 import PPL.Internal
-import qualified Streaming as S
 import Streaming.Prelude (Of, Stream, yield)
-import System.IO.Unsafe
-import System.Random (StdGen, random, randoms)
-
-mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
-mh g p m = step t0 t x w
-  where
-    (t0, t) = split $ randomTree g
-    (x, w) = head $ samples m t
-
-    step !t0 !t !x !w = do
-      let (t1 : t2 : t3 : t4 : _) = splitTrees t0
-          t' = mutateTree p t1 t2 t
-          (x', w') = head $ samples m t'
-          ratio = w' / w
-          (Exp . log -> r) = draw t3
-          (t'', x'', w'') =
-            if r < ratio
-              then (t', x', w')
-              else (t, x, w)
-      yield (x'', w'')
-      step t4 t'' x'' w''
-
--- Single site MH
-
--- Truncated trees
-data TTree = TTree Bool [Maybe TTree] deriving (Show)
-
-type Site = [Int]
-
-trunc :: Tree -> IO TTree
-trunc = truncTree . asBox
+import System.Random (StdGen)
+import qualified System.Random as R
+
+mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () -- streams (value, weight) pairs; p sizes each proposal in mutate
+mh g p m = do
+  let (g0, g1) = R.split g -- g0 is stored next to the table, g1 drives the chain in step
+  hm <- liftIO $ H.initialize 0 -- fresh table; 0 is presumably the initial capacity (confirm against vector-hashtables)
+  omega <- liftIO $ newIORef (hm, g0)
+  let (x, w) = head $ samples m $ newTree omega -- first element of samples is the initial value and weight
+  step g1 omega x w
   where
-    truncTree t =
-      getBoxedClosureData' t >>= \case
-        ConstrClosure _ [l, r] [] _ _ "Tree" ->
-          getBoxedClosureData' l >>= \case
-            ConstrClosure {dataArgs = [d], name = "D#"} ->
-              TTree True <$> truncTrees r
-            x -> error $ "truncTree:ConstrClosure:" ++ show x
-        ConstrClosure _ [r] [d] _ _ "Tree" -> 
-          TTree False <$> truncTrees r
-        x -> error $ "truncTree:" ++ show x
-
-    getBoxedClosureData' x =
-      getBoxedClosureData x >>= \c -> case c of
-        BlackholeClosure _ t -> getBoxedClosureData' t
-        _ -> pure c
-
-    truncTrees b =
-      getBoxedClosureData' b >>= \case
-        ConstrClosure _ [l, r] [] _ _ ":" ->
-          getBoxedClosureData' l >>= \case
-            ConstrClosure {name = "Tree"} ->
-              ((:) . Just) <$> truncTree l <*>  truncTrees r
-            _ -> (Nothing :) <$> truncTrees r
-        _ -> pure []
-
-trunc' t x w = unsafePerformIO $ do
-  evaluate (rnf x)
-  evaluate (rnf w)
-  trunc t
-
-sites :: Site -> TTree -> [Site]
-sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
-
-mutate = M.foldrWithKey go
-  where
-    go [] d (Tree _ ts) = Tree d ts
-    go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts
-
-ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m ()
-ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w
-  where
-    (t0, t) = split $ randomTree g
-    (x, w) = head $ samples m t0
-
-    step !t !sub !tt !x !w = do
-      let ss = sites [] tt
-          (t1 : t2 : t3 : t4 : _) = splitTrees t
-          i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate
-          sub' = M.insert (reverse $ ss !! i) (draw t3) sub
-          t' = mutate t0 sub'
-          (x', w') = head $ samples m t'
-          tt' = trunc' t' x' w'
+    step !g0 !omega !x !w = do -- one accept/reject iteration, then recurse with the state that was kept
+      let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0 -- r: random draw moved to the log domain; g1 feeds mutate, g2 the next iteration
+      omega' <- mutate g1 omega -- proposal: a modified copy of the table, omega itself is left as is
+      let (!x', !w') = head $ samples m $ newTree omega' -- value and weight obtained from the proposed table
           ratio = w' / w
-          (Exp . log -> r) = draw t4
-          (sub'', tt'', x'', w'') =
+          (omega'', x'', w'') =
             if r < ratio
-              then (sub', tt', x', w')
-              else (sub, tt, x, w)
-
+              then (omega', x', w') -- accept the proposal
+              else (omega, x, w) -- reject: keep the current state
       yield (x'', w'')
-      step t1 sub'' tt'' x'' w''
+      step g2 omega'' x'' w''
+
+    -- Proposal step: copy the table held in omega and overwrite 1 + floor (p * (n - 1)) randomly picked keys (n = table size, picks may repeat) with fresh draws; an empty table is only copied.
+    mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen))
+    mutate g omega = liftIO $ do
+      (m, g0) <- readIORef omega -- this m is the table; it shadows the model m of mh
+      (m', ks) <- (,) <$> H.clone m <*> H.keys m -- all writes go to the copy m'
+      let n = V.length ks
+          (rs, qs) = splitAt (1 + floor (p * fromIntegral (n - 1))) (R.randoms g) -- rs pick the keys, qs are the replacement values
+      unless (n == 0) $ zipWithM_ (\r q -> H.insert m' (ks V.! min (n - 1) (floor (r * fromIntegral n))) q) rs qs -- index clamped in case r is exactly 1
+      newIORef (m', g0)
-- 
cgit v1.2.3