aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2025-02-26 16:21:33 +1100
committerJustin Bedo <cu@cua0.org>2025-02-26 19:48:52 +1100
commitd1a1ffa811b199f5d7d02d14cf76631a0ecd2699 (patch)
tree719e8ae2fa90e16af95941f8b70309931639681a
parent906ae5ad3a960e4c7d7c32528ced79ba63ee2238 (diff)
port prototype single site mh from lazyppl
-rw-r--r--package.yaml2
-rw-r--r--ppl.cabal4
-rw-r--r--src/PPL/Internal.hs3
-rw-r--r--src/PPL/Sampling.hs108
4 files changed, 104 insertions, 13 deletions
diff --git a/package.yaml b/package.yaml
index 703c5a2..4e4d3bc 100644
--- a/package.yaml
+++ b/package.yaml
@@ -12,6 +12,8 @@ dependencies:
- template-haskell
- containers
- streaming
+ - ghc-heap
+ - deepseq
library:
source-dirs: src
diff --git a/ppl.cabal b/ppl.cabal
index 5d8db40..b2d83f5 100644
--- a/ppl.cabal
+++ b/ppl.cabal
@@ -1,6 +1,6 @@
cabal-version: 1.12
--- This file has been generated from package.yaml by hpack version 0.34.7.
+-- This file has been generated from package.yaml by hpack version 0.36.1.
--
-- see: https://github.com/sol/hpack
@@ -24,6 +24,8 @@ library
build-depends:
base >=4.9 && <5
, containers
+ , deepseq
+ , ghc-heap
, log-domain
, random
, streaming
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 0ba6ae5..b4283af 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -13,8 +13,7 @@ module PPL.Internal
randomTree,
samples,
mutateTree,
- splitTrees,
- draw,
+ Tree(..),
)
where
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 84e5ada..520f182 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,33 +1,121 @@
-{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ViewPatterns #-}
-module PPL.Sampling where
+module PPL.Sampling
+ ( mh,
+ ssmh,
+ )
+where
+import Control.DeepSeq
+import Control.Exception (evaluate)
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
+import qualified Data.Map.Strict as M
import Data.Monoid
+import GHC.Exts.Heap
import Numeric.Log
import PPL.Distr
import PPL.Internal
-import System.Random (StdGen, random, randoms)
import qualified Streaming as S
-import Streaming.Prelude (Stream, yield, Of)
+import Streaming.Prelude (Of, Stream, yield)
+import System.IO.Unsafe
+import System.Random (StdGen, random, randoms)
+import Unsafe.Coerce
-mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p m = step t0 t x w
where
- (t0,t) = split $ randomTree g
+ (t0, t) = split $ randomTree g
(x, w) = head $ samples m t
step !t0 !t !x !w = do
- let (t1:t2:t3:t4:_) = splitTrees t0
+ let (t1 : t2 : t3 : t4 : _) = splitTrees t0
t' = mutateTree p t1 t2 t
(x', w') = head $ samples m t'
ratio = w' / w
(Exp . log -> r) = draw t3
- (t'', x'', w'') = if r < ratio
- then (t', x', w')
- else (t, x, w)
+ (t'', x'', w'') =
+ if r < ratio
+ then (t', x', w')
+ else (t, x, w)
yield (x'', w'')
step t4 t'' x'' w''
+
+-- Single site MH
+
+-- Truncated trees
+data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show)
+
+type Site = [Int]
+
+trunc :: Tree -> IO TTree
+trunc = truncTree . asBox
+ where
+ truncTree t =
+ getBoxedClosureData' t >>= \case
+ ConstrClosure _ [l, r] [] _ _ "Tree" ->
+ getBoxedClosureData' l >>= \case
+ ConstrClosure {dataArgs = [d], name = "D#"} -> do
+ TTree (Just $ unsafeCoerce d) <$> truncTrees r
+ x -> error $ "truncTree:ConstrClosure:" ++ show x
+ ConstrClosure _ [r] [d] _ _ "Tree" ->
+ TTree (Just $ unsafeCoerce d) <$> truncTrees r
+ x -> error $ "truncTree:" ++ show x
+
+ getBoxedClosureData' x =
+ getBoxedClosureData x >>= \c -> case c of
+ BlackholeClosure _ t -> getBoxedClosureData' t
+ _ -> pure c
+
+ truncTrees b =
+ getBoxedClosureData' b >>= \case
+ ConstrClosure _ [l, r] [] _ _ ":" ->
+ getBoxedClosureData' l >>= \case
+ ConstrClosure {name = "Tree"} -> do
+ l' <- truncTree l
+ r' <- truncTrees r
+ pure $ Just l' : r'
+ _ -> pure []
+ _ -> pure []
+
+trunc' t x w = unsafePerformIO $ do
+ evaluate (rnf x)
+ evaluate (rnf w)
+ trunc t
+
+sites :: Site -> TTree -> [Site]
+sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
+sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts]
+
+mutate = M.foldrWithKey go
+ where
+ go [] d (Tree _ ts) = Tree d ts
+ go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts
+
+ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m ()
+ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w
+ where
+ (t0, t) = split $ randomTree g
+ (x, w) = head $ samples m t0
+
+ step !t !sub !tt !x !w = do
+ let ss = sites [] tt
+ (t1 : t2 : t3 : t4 : _) = splitTrees t
+ i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate
+ sub' = M.insert (reverse $ ss !! i) (draw t3) sub
+ t' = mutate t0 sub'
+ (x', w') = head $ samples m t'
+ tt' = trunc' t' x' w'
+ ratio = w' / w
+ (Exp . log -> r) = draw t4
+ (sub'', tt'', x'', w'') =
+ if r < ratio
+ then (sub', tt', x', w')
+ else (sub, tt, x, w)
+
+ yield (x'', w'')
+ step t1 sub'' tt'' x'' w''