From d1a1ffa811b199f5d7d02d14cf76631a0ecd2699 Mon Sep 17 00:00:00 2001
From: Justin Bedo <cu@cua0.org>
Date: Wed, 26 Feb 2025 16:21:33 +1100
Subject: port prototype single site mh from lazyppl

---
 src/PPL/Sampling.hs | 108 +++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 98 insertions(+), 10 deletions(-)

(limited to 'src/PPL/Sampling.hs')

diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 84e5ada..520f182 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,33 +1,121 @@
-{-# LANGUAGE ViewPatterns #-}
 {-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ViewPatterns #-}
 
-module PPL.Sampling where
+module PPL.Sampling
+  ( mh,
+    ssmh,
+  )
+where
 
+import Control.DeepSeq
+import Control.Exception (evaluate)
 import Control.Monad.IO.Class
 import Control.Monad.Trans.State
 import Data.Bifunctor
+import qualified Data.Map.Strict as M
 import Data.Monoid
+import GHC.Exts.Heap
 import Numeric.Log
 import PPL.Distr
 import PPL.Internal
-import System.Random (StdGen, random, randoms)
 import qualified Streaming as S
-import Streaming.Prelude (Stream, yield, Of)
+import Streaming.Prelude (Of, Stream, yield)
+import System.IO.Unsafe
+import System.Random (StdGen, random, randoms)
+import Unsafe.Coerce
 
-mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
 mh g p m = step t0 t x w
   where
-    (t0,t) = split $ randomTree g
+    (t0, t) = split $ randomTree g
     (x, w) = head $ samples m t
 
     step !t0 !t !x !w = do
-      let (t1:t2:t3:t4:_) = splitTrees t0
+      let (t1 : t2 : t3 : t4 : _) = splitTrees t0
           t' = mutateTree p t1 t2 t
           (x', w') = head $ samples m t'
           ratio = w' / w
           (Exp . log -> r) = draw t3
-          (t'', x'', w'') = if r < ratio
-            then (t', x', w')
-            else (t, x, w)
+          (t'', x'', w'') =
+            if r < ratio
+              then (t', x', w')
+              else (t, x, w)
       yield (x'', w'')
       step t4 t'' x'' w''
+
+-- Single site MH
+
+-- Truncated trees
+data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show)
+
+type Site = [Int]
+
+trunc :: Tree -> IO TTree
+trunc = truncTree . asBox
+  where
+    truncTree t =
+      getBoxedClosureData' t >>= \case
+        ConstrClosure _ [l, r] [] _ _ "Tree" ->
+          getBoxedClosureData' l >>= \case
+            ConstrClosure {dataArgs = [d], name = "D#"} -> do
+              TTree (Just $ unsafeCoerce d) <$> truncTrees r
+            x -> error $ "truncTree:ConstrClosure:" ++ show x
+        ConstrClosure _ [r] [d] _ _ "Tree" -> 
+          TTree (Just $ unsafeCoerce d) <$> truncTrees r
+        x -> error $ "truncTree:" ++ show x
+
+    getBoxedClosureData' x =
+      getBoxedClosureData x >>= \c -> case c of
+        BlackholeClosure _ t -> getBoxedClosureData' t
+        _ -> pure c
+
+    truncTrees b =
+      getBoxedClosureData' b >>= \case
+        ConstrClosure _ [l, r] [] _ _ ":" ->
+          getBoxedClosureData' l >>= \case
+            ConstrClosure {name = "Tree"} -> do
+              l' <- truncTree l
+              r' <- truncTrees r
+              pure $ Just l' : r'
+            _ -> pure []
+        _ -> pure []
+
+trunc' t x w = unsafePerformIO $ do
+  evaluate (rnf x)
+  evaluate (rnf w)
+  trunc t
+
+sites :: Site -> TTree -> [Site]
+sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
+sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
+
+mutate = M.foldrWithKey go
+  where
+    go [] d (Tree _ ts) = Tree d ts
+    go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts
+
+ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m ()
+ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w
+  where
+    (t0, t) = split $ randomTree g
+    (x, w) = head $ samples m t0
+
+    step !t !sub !tt !x !w = do
+      let ss = sites [] tt
+          (t1 : t2 : t3 : t4 : _) = splitTrees t
+          i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate
+          sub' = M.insert (reverse $ ss !! i) (draw t3) sub
+          t' = mutate t0 sub'
+          (x', w') = head $ samples m t'
+          tt' = trunc' t' x' w'
+          ratio = w' / w
+          (Exp . log -> r) = draw t4
+          (sub'', tt'', x'', w'') =
+            if r < ratio
+              then (sub', tt', x', w')
+              else (sub, tt, x, w)
+
+      yield (x'', w'')
+      step t1 sub'' tt'' x'' w''
-- 
cgit v1.2.3