diff options
Diffstat (limited to 'src/PPL/Sampling.hs')
-rw-r--r-- | src/PPL/Sampling.hs | 133 |
1 files changed, 34 insertions, 99 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 8e0ab8f..7c2cb54 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,117 +1,52 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE BlockArguments #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} module PPL.Sampling ( mh, - ssmh, ) where -import Control.DeepSeq -import Control.Exception (evaluate) +import Control.Monad import Control.Monad.IO.Class -import Control.Monad.Trans.State -import Data.Bifunctor -import qualified Data.Map.Strict as M -import Data.Monoid -import GHC.Exts.Heap +import Data.IORef +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector.Unboxed as V +import Data.Word import Numeric.Log -import PPL.Distr import PPL.Internal -import qualified Streaming as S import Streaming.Prelude (Of, Stream, yield) -import System.IO.Unsafe -import System.Random (StdGen, random, randoms) - -mh :: (Monad m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () -mh g p m = step t0 t x w - where - (t0, t) = split $ randomTree g - (x, w) = head $ samples m t - - step !t0 !t !x !w = do - let (t1 : t2 : t3 : t4 : _) = splitTrees t0 - t' = mutateTree p t1 t2 t - (x', w') = head $ samples m t' - ratio = w' / w - (Exp . log -> r) = draw t3 - (t'', x'', w'') = - if r < ratio - then (t', x', w') - else (t, x, w) - yield (x'', w'') - step t4 t'' x'' w'' - --- Single site MH - --- Truncated trees -data TTree = TTree Bool [Maybe TTree] deriving (Show) - -type Site = [Int] - -trunc :: Tree -> IO TTree -trunc = truncTree . asBox +import System.Random (StdGen) +import qualified System.Random as R + +mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh g p m = do + let (g0, g1) = R.split g + hm <- liftIO $ H.initialize 0 + omega <- liftIO $ newIORef (hm, g0) + let (x, w) = head $ samples m $ newTree omega + step g1 omega x w where - truncTree t = - getBoxedClosureData' t >>= \case - ConstrClosure _ [l, r] [] _ _ "Tree" -> - getBoxedClosureData' l >>= \case - ConstrClosure {dataArgs = [d], name = "D#"} -> - TTree True <$> truncTrees r - x -> error $ "truncTree:ConstrClosure:" ++ show x - ConstrClosure _ [r] [d] _ _ "Tree" -> - TTree False <$> truncTrees r - x -> error $ "truncTree:" ++ show x - - getBoxedClosureData' x = - getBoxedClosureData x >>= \c -> case c of - BlackholeClosure _ t -> getBoxedClosureData' t - _ -> pure c - - truncTrees b = - getBoxedClosureData' b >>= \case - ConstrClosure _ [l, r] [] _ _ ":" -> - getBoxedClosureData' l >>= \case - ConstrClosure {name = "Tree"} -> - ((:) . Just) <$> truncTree l <*> truncTrees r - _ -> (Nothing :) <$> truncTrees r - _ -> pure [] - -trunc' t x w = unsafePerformIO $ do - evaluate (rnf x) - evaluate (rnf w) - trunc t - -sites :: Site -> TTree -> [Site] -sites acc (TTree eval ts) = (if eval then acc else mempty) : concat [sites (x : acc) t | (x, Just t) <- zip [0 ..] ts] - -mutate = M.foldrWithKey go - where - go [] d (Tree _ ts) = Tree d ts - go (n : ns) d (Tree v ts) = Tree v $ take n ts ++ go ns d (ts !! n) : drop (n + 1) ts - -ssmh :: (Show a, NFData a, Monad m) => StdGen -> Meas a -> Stream (Of (a, Log Double)) m () -ssmh g m = step t (mempty :: M.Map Site Double) (trunc' t0 x w) x w - where - (t0, t) = split $ randomTree g - (x, w) = head $ samples m t0 - - step !t !sub !tt !x !w = do - let ss = sites [] tt - (t1 : t2 : t3 : t4 : _) = splitTrees t - i = floor $ draw t2 * (fromIntegral $ length ss) -- site to mutate - sub' = M.insert (reverse $ ss !! i) (draw t3) sub - t' = mutate t0 sub' - (x', w') = head $ samples m t' - tt' = trunc' t' x' w' + step !g0 !omega !x !w = do + let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0 + omega' <- mutate g1 omega + let (!x', !w') = head $ samples m $ newTree omega' ratio = w' / w - (Exp . log -> r) = draw t4 - (sub'', tt'', x'', w'') = + (omega'', x'', w'') = if r < ratio - then (sub', tt', x', w') - else (sub, tt, x, w) - + then (omega', x', w') + else (omega, x, w) yield (x'', w'') - step t1 sub'' tt'' x'' w'' + step g2 omega'' x'' w'' + + mutate :: (MonadIO m) => StdGen -> IORef (HashMap Word64 Double, StdGen) -> m (IORef (HashMap Word64 Double, StdGen)) + mutate g omega = liftIO $ do + (m, g0) <- readIORef omega + m' <- H.clone m + ks <- H.keys m + let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g) + n = fromIntegral (V.length ks) + void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs + newIORef (m', g0) |