diff options
Diffstat (limited to 'src/PPL')
| -rw-r--r-- | src/PPL/Internal.hs | 46 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 33 |
2 files changed, 52 insertions, 27 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index d09d8ff..3c54078 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -17,8 +17,10 @@ module PPL.Internal sample, memoize, samples, - newTree, HashMap, + random, + randoms, + newTree, Tree (..), ) where @@ -27,17 +29,20 @@ import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Writer import Data.Bifunctor +import Data.Bits (countTrailingZeros, shiftL, shiftR) import Data.IORef import Data.Map qualified as Q import Data.Monoid import Data.Vector.Hashtables qualified as H import Data.Vector.Mutable qualified as M import Data.Vector.Unboxed.Mutable qualified as UM +import Data.Word (Word64) import Language.Haskell.TH.Syntax qualified as TH import Numeric.Log import System.IO.Unsafe -import System.Random hiding (split, uniform) +import System.Random hiding (random, randoms, split, uniform) import System.Random qualified as R +import Unsafe.Coerce type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v @@ -48,24 +53,25 @@ data Tree = Tree split :: (Tree, Tree) } -{-# INLINE newTree #-} -newTree :: IORef (HashMap Integer Double, StdGen) -> Tree -newTree s = go 1 +-- Let's draw doubles more uniformly than the standard generator +-- based on https://www.corsix.org/content/higher-quality-random-floats +random :: StdGen -> (Double, StdGen) +random g0 = + let (r1 :: Word64, g1) = R.random g0 + (r2 :: Word64, g2) = R.random g1 + e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r + dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52) + in (unsafeCoerce dbl, g2) + +randoms :: StdGen -> [Double] +randoms g = + let (x, g') = random g + in x : randoms g' + +newTree :: StdGen -> Tree +newTree g0 = Tree x (newTree g1, newTree g2) where - go :: Integer -> Tree - go id = - Tree - ( unsafePerformIO $ do - (m, g) <- readIORef s - H.lookup m id >>= \case - Nothing -> do - let (x, g') = R.random g - H.insert m id x - writeIORef s (m, g') - pure x - Just x -> pure x - ) - (go (2 * id), go (2 * id + 1)) + (x, R.split -> (g1, g2)) = random g0 newtype Prob a = Prob {runProb :: Tree -> a} @@ -106,7 +112,7 @@ samples (Meas m) = map (second getProduct) . runProb f f = runWriterT m >>= \x -> (x :) <$> f {-# NOINLINE memoize #-} -memoize :: Ord a => (a -> Prob b) -> Prob (a -> b) +memoize :: (Ord a) => (a -> Prob b) -> Prob (a -> b) memoize f = unsafePerformIO $ do ref <- newIORef mempty pure $ Prob $ \t -> \x -> unsafePerformIO $ do diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 574025a..71a61fd 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -18,11 +18,30 @@ import Data.IORef import Data.List (foldl') import Data.Vector qualified as V import Data.Vector.Hashtables qualified as H -import Numeric.Log -import PPL.Internal +import Numeric.Log hiding (sum) +import PPL.Internal hiding (newTree, split) import Streaming.Prelude (Of, Stream, yield) -import System.Random (StdGen) -import System.Random qualified as R +import System.IO.Unsafe +import System.Random (StdGen, split) + +{-# INLINE newTree #-} +newTree :: IORef (HashMap Integer Double, StdGen) -> Tree +newTree s = go 1 + where + go :: Integer -> Tree + go id = + Tree + ( unsafePerformIO $ do + (m, g) <- readIORef s + H.lookup m id >>= \case + Nothing -> do + let (x, g') = random g + H.insert m id x + writeIORef s (m, g') + pure x + Just x -> pure x + ) + (go (2 * id), go (2 * id + 1)) data MinHeap a k = Node a k (MinHeap a k) (MinHeap a k) | Empty deriving (Show) @@ -41,14 +60,14 @@ toList (Node k v l r) = v : toList (merge l r) mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () mh g p m = do - let (g0, g1) = R.split g + let (g0, g1) = split g hm <- liftIO $ H.initialize 0 omega <- liftIO $ newIORef (hm, g0) let (x, w) = head $ samples m $ newTree omega step g1 omega x w where step !g0 !omega !x !w = do - let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0 + let (Exp . log -> r, split -> (g1, g2)) = random g0 omega' <- mutate g1 omega let (!x', !w') = head $ samples m $ newTree omega' ratio = w' / w @@ -64,7 +83,7 @@ mh g p m = do (m, g0) <- readIORef omega m' <- H.clone m ks <- H.keys m - let (rs :: [Double], qs :: [Double]) = (R.randoms *** R.randoms) (R.split g) + let (rs :: [Double], qs :: [Double]) = (randoms *** randoms) (split g) ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks n = fromIntegral (V.length ks) void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs |
