From a4cbf2ff3f1839dc302f5956b9b27c6bb28b3f30 Mon Sep 17 00:00:00 2001
From: Justin Bedo <cu@cua0.org>
Date: Tue, 4 Mar 2025 17:30:41 +1100
Subject: disable inlinine ,c disable inlining

---
 src/PPL/Internal.hs | 19 +++++++++++++++++-
 src/PPL/Sampling.hs | 56 +++++++++++++++++++++++++++++++++++++++--------------
 2 files changed, 60 insertions(+), 15 deletions(-)

(limited to 'src')

diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index b4283af..49737d0 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -13,6 +13,7 @@ module PPL.Internal
     randomTree,
     samples,
     mutateTree,
+    newTree,
     Tree(..),
   )
 where
@@ -27,6 +28,9 @@ import qualified Language.Haskell.TH.Syntax as TH
 import Numeric.Log
 import System.Random hiding (split, uniform)
 import qualified System.Random as R
+import qualified Data.Map.Strict as M
+import Data.IORef
+import System.IO.Unsafe
 
 -- Reimplementation of the LazyPPL monads to avoid some dependencies
 
@@ -38,13 +42,26 @@ data Tree = Tree
 split :: Tree -> (Tree, Tree)
 split (Tree r (t : ts)) = (t, Tree r ts)
 
+{-# INLINE newTree #-}
+newTree :: IORef (M.Map [Int] Double, StdGen) -> Tree
+newTree s = go []
+  where
+    go id = Tree (unsafePerformIO $ do
+      (m, g) <- readIORef s
+      case M.lookup id m of
+        Nothing -> do
+          let (x, g') = R.random g
+          writeIORef s (M.insert id x m, g')
+          pure x
+        Just x -> pure x) [go (i:id) | i <- [0..]]
+
 {-# INLINE randomTree #-}
 randomTree :: RandomGen g => g -> Tree
 randomTree g = let (a, g') = random g in Tree a (randomTrees g')
   where
     randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
 
-{-# INLINE mutateTree #-}
+{-# NOINLINE mutateTree #-}
 mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
 mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
   if r < p
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 8e0ab8f..6807439 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -2,6 +2,7 @@
 {-# LANGUAGE BlockArguments #-}
 {-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE TupleSections #-}
 
 module PPL.Sampling
   ( mh,
@@ -24,25 +25,52 @@ import qualified Streaming as S
 import Streaming.Prelude (Of, Stream, yield)
 import System.IO.Unsafe
 import System.Random (StdGen, random, randoms)
-
-mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
-mh g p m = step t0 t x w
+import qualified System.Random as R
+import Data.IORef
+import Control.Monad
+import Debug.Trace
+
+mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh g p m = do
+    let (g0, g1) = R.split g
+    omega <- liftIO $ newIORef (mempty, g0)
+    let (x, w) = head $ samples m $ newTree omega
+    step g1 omega x w
   where
-    (t0, t) = split $ randomTree g
-    (x, w) = head $ samples m t
 
-    step !t0 !t !x !w = do
-      let (t1 : t2 : t3 : t4 : _) = splitTrees t0
-          t' = mutateTree p t1 t2 t
-          (x', w') = head $ samples m t'
+    step !g0 !omega !x !w = do 
+      let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0
+      omega' <- mutate g1 omega
+      let (!x', !w') = head $ samples m $ newTree omega'
           ratio = w' / w
-          (Exp . log -> r) = draw t3
-          (t'', x'', w'') =
+          (omega'', x'', w'') =
             if r < ratio
-              then (t', x', w')
-              else (t, x, w)
+              then (omega', x', w')
+              else (omega, x, w)
       yield (x'', w'')
-      step t4 t'' x'' w''
+      step g2 omega'' x'' w''
+    
+    mutate :: MonadIO m => StdGen -> IORef (M.Map [Int] Double, StdGen) ->  m (IORef (M.Map [Int] Double, StdGen))
+    mutate g omega = do
+      (m, g0) <- liftIO $ readIORef omega
+      let (r:q:_) = R.randoms g
+          ks = M.keys m
+          k = ks !! floor (r * join traceShow (fromIntegral (length ks)))
+          m' = M.insert k q m
+      liftIO $ newIORef $ (m',g0)
+      
+     where
+      go x = do
+        g <- get
+        let (r, g1) = R.random g
+            (y, g2) = R.random g1
+        if r < p 
+          then do
+            put g2
+            pure y
+          else do
+            put g1
+            pure x
 
 -- Single site MH
 
-- 
cgit v1.2.3