diff options
author | Justin Bedo <cu@cua0.org> | 2025-02-26 16:21:33 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2025-02-26 19:48:52 +1100 |
commit | d1a1ffa811b199f5d7d02d14cf76631a0ecd2699 (patch) | |
tree | 719e8ae2fa90e16af95941f8b70309931639681a | |
parent | 906ae5ad3a960e4c7d7c32528ced79ba63ee2238 (diff) |
port prototype single site mh from lazyppl
-rw-r--r-- | package.yaml | 2 | ||||
-rw-r--r-- | ppl.cabal | 4 | ||||
-rw-r--r-- | src/PPL/Internal.hs | 3 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 108 |
4 files changed, 104 insertions, 13 deletions
diff --git a/package.yaml b/package.yaml index 703c5a2..4e4d3bc 100644 --- a/package.yaml +++ b/package.yaml @@ -12,6 +12,8 @@ dependencies: - template-haskell - containers - streaming + - ghc-heap + - deepseq library: source-dirs: src @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.34.7. +-- This file has been generated from package.yaml by hpack version 0.36.1. -- -- see: https://github.com/sol/hpack @@ -24,6 +24,8 @@ library build-depends: base >=4.9 && <5 , containers + , deepseq + , ghc-heap , log-domain , random , streaming diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 0ba6ae5..b4283af 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -13,8 +13,7 @@ module PPL.Internal randomTree, samples, mutateTree, - splitTrees, - draw, + Tree(..), ) where diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 84e5ada..520f182 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,33 +1,121 @@ -{-# LANGUAGE ViewPatterns #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ViewPatterns #-} -module PPL.Sampling where +module PPL.Sampling + ( mh, + ssmh, + ) +where +import Control.DeepSeq +import Control.Exception (evaluate) import Control.Monad.IO.Class import Control.Monad.Trans.State import Data.Bifunctor +import qualified Data.Map.Strict as M import Data.Monoid +import GHC.Exts.Heap import Numeric.Log import PPL.Distr import PPL.Internal -import System.Random (StdGen, random, randoms) import qualified Streaming as S -import Streaming.Prelude (Stream, yield, Of) +import Streaming.Prelude (Of, Stream, yield) +import System.IO.Unsafe +import System.Random (StdGen, random, randoms) +import Unsafe.Coerce -mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () mh g p m = step t0 t x w where - (t0,t) = split $ randomTree g + (t0, t) = split $ randomTree g (x, w) = head $ samples m t step !t0 !t !x !w = do - let (t1:t2:t3:t4:_) = splitTrees t0 + let (t1 : t2 : t3 : t4 : _) = splitTrees t0 t' = mutateTree p t1 t2 t (x', w') = head $ samples m t' ratio = w' / w (Exp . log -> r) = draw t3 - (t'', x'', w'') = if r < ratio - then (t', x', w') - else (t, x, w) + (t'', x'', w'') = + if r < ratio + then (t', x', w') + else (t, x, w) yield (x'', w'') step t4 t'' x'' w'' + +-- Single site MH + +-- Truncated trees +data TTree = TTree (Maybe Double) [Maybe TTree] deriving (Show) + +type Site = [Int] + +trunc :: Tree -> IO TTree +trunc = truncTree . asBox + where + truncTree t = + getBoxedClosureData' t >>= \case + ConstrClosure _ [l, r] [] _ _ "Tree" -> + getBoxedClosureData' l >>= \case + ConstrClosure {dataArgs = [d], name = "D#"} -> do + TTree (Just $ unsafeCoerce d) <$> truncTrees r + x -> error $ "truncTree:ConstrClosure:" ++ show x + ConstrClosure _ [r] [d] _ _ "Tree" -> + TTree (Just $ unsafeCoerce d) <$> truncTrees r + x -> error $ "truncTree:" ++ show x + + getBoxedClosureData' x = + getBoxedClosureData x >>= \c -> case c of + BlackholeClosure _ t -> getBoxedClosureData' t + _ -> pure c + + truncTrees b = + getBoxedClosureData' b >>= \case + ConstrClosure _ [l, r] [] _ _ ":" -> + getBoxedClosureData' l >>= \case + ConstrClosure {name = "Tree"} -> do + l' <- truncTree l + r' <- truncTrees r + pure $ Just l' : r' + _ -> pure [] + _ -> pure [] + +trunc' t x w = unsafePerformIO $ do + evaluate (rnf x) + evaluate (rnf w) + trunc t + +sites :: Site -> TTree -> [Site] +sites acc (TTree (Just v) ts) = acc : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] +sites acc (TTree Nothing ts) = concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] + +mutate = M.foldrWithKey go + where + go [] d (Tree _ ts) = Tree d ts + go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts + +ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () +ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w + where + (t0, t) = split $ randomTree g + (x, w) = head $ samples m t0 + + step !t !sub !tt !x !w = do + let ss = sites [] tt + (t1 : t2 : t3 : t4 : _) = splitTrees t + i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate + sub' = M.insert (reverse $ ss !! i) (draw t3) sub + t' = mutate t0 sub' + (x', w') = head $ samples m t' + tt' = trunc' t' x' w' + ratio = w' / w + (Exp . log -> r) = draw t4 + (sub'', tt'', x'', w'') = + if r < ratio + then (sub', tt', x', w') + else (sub, tt, x, w) + + yield (x'', w'') + step t1 sub'' tt'' x'' w'' |