{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}

-- | Markov-chain Monte Carlo samplers for 'Meas' models.
module PPL.Sampling
  ( mh,
  )
where

import Control.DeepSeq
import Control.Exception (evaluate)
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
import qualified Data.HashMap.Strict as M
import Data.Monoid
import GHC.Exts.Heap
import Numeric.Log
import PPL.Distr
import PPL.Internal
import qualified Streaming as S
import Streaming.Prelude (Of, Stream, yield)
import System.IO.Unsafe
import System.Random (StdGen, random, randoms)
import qualified System.Random as R
import Data.IORef
import Control.Monad
import Debug.Trace

-- | Metropolis–Hastings sampler over the random choices of a model.
--
-- The model's random choices are kept as a map from @[Int]@ keys
-- (presumably site addresses — confirm in "PPL.Internal") to 'Double'
-- values. The map is stored, together with a generator, in an 'IORef' that
-- is handed to 'newTree'. Each iteration:
--
--   1. proposes a new map by independently redrawing every stored site with
--      probability @p@ (the stored generator is carried over unchanged);
--   2. re-runs the model on the proposal, obtaining @(x', w')@;
--   3. accepts the proposal when a unit-interval draw @u@ satisfies
--      @u < w' / w@, otherwise keeps the current state.
--
-- The state after every accept/reject decision is yielded, so a rejected
-- proposal repeats the previous sample. The stream never terminates.
mh ::
  (MonadIO m) =>
  -- | Seed. It is split into the generator stored next to the (initially
  -- empty) site map and the generator driving proposals and acceptance.
  StdGen ->
  -- | Probability with which each stored site is redrawn per proposal.
  Double ->
  -- | The model to sample from.
  Meas a ->
  Stream (Of (a, Log Double)) m ()
mh g p m = do
  let (gSites, gChain) = R.split g
  omega <- liftIO $ newIORef (mempty, gSites)
  let (x, w) = firstSample omega
  step gChain omega x w
  where
    -- Run the model against the tree backed by @omega@ and take its first
    -- (sample, weight) pair. An empty result is reported explicitly rather
    -- than through 'head''s generic "empty list" error.
    firstSample omega = case samples m (newTree omega) of
      (s : _) -> s
      [] -> error "PPL.Sampling.mh: model produced no samples"

    -- One Metropolis–Hastings transition from state (omega, x, w), followed
    -- by an unconditional recursive call: this loop is what makes the
    -- stream infinite.
    step !gen !omega !x !w = do
      -- The uniform draw is lifted into the log domain so it can be compared
      -- directly with the 'Log Double' weight ratio. The generator returned
      -- by that draw is split between the proposal and the next iteration.
      let (Exp . log -> u, R.split -> (genMutate, genNext)) = R.random gen
      omega' <- mutate genMutate omega
      let (!x', !w') = firstSample omega'
          ratio = w' / w
          (omega'', x'', w'') =
            if u < ratio
              then (omega', x', w')
              else (omega, x, w)
      yield (x'', w'')
      step genNext omega'' x'' w''

    -- Build a proposal in a fresh 'IORef', leaving the current state's
    -- 'IORef' untouched so a rejection can simply keep it.
    mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) -> m (IORef (M.HashMap [Int] Double, StdGen))
    mutate gen omega = liftIO $ do
      (sites, freshGen) <- readIORef omega
      -- 'M.traverseWithKey' from Data.HashMap.Strict stores every result
      -- value forced, and 'evaluate' forces the map before it is stored.
      -- With the lazy 'mapM' used previously, sites the model never reads
      -- kept a thunk wrapping the previous iteration's value plus the lazy
      -- State generator chain, growing by one layer per iteration. The
      -- traversal order, and hence every drawn value, is unchanged.
      sites' <- evaluate $ evalState (M.traverseWithKey (const resample) sites) gen
      newIORef (sites', freshGen)

    -- Redraw one site with probability @p@. Two draws are always made (the
    -- coin and the candidate value), but the state only advances past the
    -- candidate draw when it is actually used.
    resample :: Double -> State StdGen Double
    resample old = state \gen0 ->
      let (coin, gen1) = R.random gen0
          (new, gen2) = R.random gen1
       in if coin < p
            then (new, gen2)
            else (old, gen1)