{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}

-- | Metropolis-Hastings samplers for 'Meas' programs.
--
-- * 'mh' keeps a map from paths to 'Double's in an 'IORef' that is handed to
--   'newTree', and perturbs one stored entry per step.
-- * 'ssmh' (single-site MH) inspects the heap after each run to find which
--   nodes of the lazy random 'Tree' were evaluated, and perturbs one of them.
module PPL.Sampling
  ( mh,
    ssmh,
  )
where

import Control.DeepSeq
import Control.Exception (evaluate)
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.IORef
import qualified Data.Map.Strict as M
import Data.Monoid
import Debug.Trace
import GHC.Exts.Heap
import System.IO.Unsafe

import Numeric.Log
import qualified Streaming as S
import Streaming.Prelude (Of, Stream, yield)
import System.Random (StdGen, random, randoms)
import qualified System.Random as R

import PPL.Distr
import PPL.Internal

-- | Metropolis-Hastings over a mutable store of random choices.
--
-- The store is a map from paths ('[Int]') to 'Double's, paired with a
-- generator, held in an 'IORef' that is passed to 'newTree'. Each step copies
-- the store into a fresh 'IORef' and overwrites one stored entry with a fresh
-- value from 'R.randoms'. It then reruns the program and accepts when a
-- uniform draw @r@ satisfies @r < w' / w@. The current state is yielded on
-- every step, whether or not the proposal was accepted.
--
-- NOTE(review): the @p@ argument is currently unused. The only code that read
-- it was a helper that was never called. Confirm whether @p@ should drive the
-- proposal.
mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p m = do
  let (g0, g1) = R.split g
  omega <- liftIO $ newIORef (mempty, g0)
  let (x, w) = head $ samples m $ newTree omega
  step g1 omega x w
  where
    step !g0 !omega !x !w = do
      -- One uniform draw (moved into the log domain) for the accept test,
      -- plus two fresh generators: one for the proposal, one for the next step.
      let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0
      omega' <- mutate g1 omega
      let (!x', !w') = head $ samples m $ newTree omega'
          ratio = w' / w
          (omega'', x'', w'') =
            if r < ratio
              then (omega', x', w')
              else (omega, x, w)
      yield (x'', w'')
      step g2 omega'' x'' w''

    -- Copy the store into a new IORef, replacing the value at one key chosen
    -- uniformly by index. With no keys there is nothing to perturb, so the
    -- copy is returned unchanged instead of indexing an empty list.
    mutate :: MonadIO m => StdGen -> IORef (M.Map [Int] Double, StdGen) -> m (IORef (M.Map [Int] Double, StdGen))
    mutate g omega = do
      (sub, g0) <- liftIO $ readIORef omega
      -- 'R.randoms' yields an infinite list, so this pattern always matches.
      let (r : q : _) = R.randoms g
          ks = M.keys sub
          sub'
            | null ks = sub
            | otherwise = M.insert (ks !! floor (r * fromIntegral (length ks))) q sub
      liftIO $ newIORef (sub', g0)

-- Single site MH

-- Truncated trees

-- | Snapshot of which parts of a 'Tree' were found evaluated on the heap.
-- The 'Bool' is the flag 'trunc' records for the node itself. The list has one
-- entry per child-list cell that had been evaluated: 'Just' when the cell's
-- element was a @Tree@ constructor, 'Nothing' otherwise.
data TTree = TTree Bool [Maybe TTree]
  deriving (Show)

-- | A path of child indices from the root of a 'Tree'.
type Site = [Int]

-- | Walk the heap representation of a 'Tree' with 'getBoxedClosureData' and
-- record which nodes and child-list cells have already been evaluated.
--
-- NOTE(review): the two @Tree@ layouts matched below differ. One has two
-- pointer fields; the other has one pointer field plus one raw data word.
-- Which layout occurs depends on how 'Tree' is declared in "PPL.Internal".
-- Any other closure shape (e.g. a value that is still a thunk) raises 'error'.
-- Confirm these are the only shapes that can occur.
trunc :: Tree -> IO TTree
trunc = truncTree . asBox
  where
    truncTree t =
      getBoxedClosureData' t >>= \case
        ConstrClosure _ [l, r] [] _ _ "Tree" ->
          getBoxedClosureData' l >>= \case
            ConstrClosure {dataArgs = [_], name = "D#"} -> TTree True <$> truncTrees r
            x -> error $ "truncTree:ConstrClosure:" ++ show x
        ConstrClosure _ [r] [_] _ _ "Tree" -> TTree False <$> truncTrees r
        x -> error $ "truncTree:" ++ show x

    -- Follow 'BlackholeClosure' indirections to the closure they point at.
    getBoxedClosureData' x =
      getBoxedClosureData x >>= \c -> case c of
        BlackholeClosure _ t -> getBoxedClosureData' t
        _ -> pure c

    -- Walk a child list while its cells are evaluated @(:)@ constructors;
    -- stop at the first cell that is anything else.
    truncTrees b =
      getBoxedClosureData' b >>= \case
        ConstrClosure _ [l, r] [] _ _ ":" ->
          getBoxedClosureData' l >>= \case
            ConstrClosure {name = "Tree"} -> ((:) . Just) <$> truncTree l <*> truncTrees r
            _ -> (Nothing :) <$> truncTrees r
        _ -> pure []

-- | Fully evaluate @x@ and @w@, then snapshot which parts of @t@ are
-- evaluated. 'unsafePerformIO' is used because the snapshot only reads the
-- heap. NOINLINE guards against GHC duplicating or floating the call, which
-- could change when the heap is inspected.
{-# NOINLINE trunc' #-}
trunc' :: (NFData a, NFData b) => Tree -> a -> b -> TTree
trunc' t x w = unsafePerformIO $ do
  evaluate (rnf x)
  evaluate (rnf w)
  trunc t

-- | Sites of all evaluated nodes in a truncated tree, in pre-order.
-- Each site is accumulated in reverse (deepest index first), so callers
-- reverse it before using it as a path from the root. Unevaluated nodes
-- contribute no site, which keeps them from inflating the chance of picking
-- the root site @[]@.
sites :: Site -> TTree -> [Site]
sites acc (TTree eval ts) =
  [acc | eval] ++ concat [sites (i : acc) t | (i, Just t) <- zip [0 ..] ts]

-- | Apply a substitution map to a tree. For each @(path, d)@ entry, the value
-- of the node at @path@ (child indices from the root) is replaced by @d@.
-- Subtrees off the substituted paths are shared with the input tree.
mutate :: Tree -> M.Map Site Double -> Tree
mutate = M.foldrWithKey go
  where
    go [] d (Tree _ ts) = Tree d ts
    go (n : ns) d (Tree v ts) =
      Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts

-- | Single-site Metropolis-Hastings over a lazily evaluated random 'Tree'.
--
-- The state is a substitution map from sites to values, applied to a fixed
-- base tree @t0@. Each step:
--
-- 1. Lists the sites found evaluated for the current state (via 'trunc'').
-- 2. Picks one uniformly and gives it a fresh 'draw'.
-- 3. Reruns the program on the substituted tree.
-- 4. Accepts when a uniform draw @r@ satisfies @r < w' / w@.
--
-- When no site is found, the proposal leaves the state unchanged. The current
-- state is yielded on every step.
--
-- NOTE(review): two points to confirm. First, the acceptance ratio has no
-- Hastings correction for a change in the number of sites (|K| / |K'|).
-- Second, subtrees shared with @t0@ stay evaluated across runs, so the site
-- list may include sites used only in earlier runs.
ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m ()
ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w
  where
    (t0, t) = split $ randomTree g
    (x, w) = head $ samples m t0
    step !t !sub !tt !x !w = do
      let ss = sites [] tt
          (t1 : t2 : t3 : t4 : _) = splitTrees t
          sub' = case ss of
            [] -> sub
            _ ->
              let i = floor $ draw t2 * fromIntegral (length ss) -- site to mutate
               in M.insert (reverse $ ss !! i) (draw t3) sub
          t' = mutate t0 sub'
          (x', w') = head $ samples m t'
          tt' = trunc' t' x' w'
          ratio = w' / w
          (Exp . log -> r) = draw t4
          (sub'', tt'', x'', w'') =
            if r < ratio
              then (sub', tt', x', w')
              else (sub, tt, x, w)
      yield (x'', w'')
      step t1 sub'' tt'' x'' w''