aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/PPL/Distr.hs41
-rw-r--r--src/PPL/Internal.hs29
-rw-r--r--src/PPL/Sampling.hs32
3 files changed, 49 insertions, 53 deletions
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index fae825b..54db358 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,8 +1,10 @@
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
module PPL.Distr where
-import Data.Bits (countLeadingZeros, shiftL, shiftR, (.&.))
+import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.))
import Data.Functor ((<&>))
import Data.Word (Word64)
import PPL.Internal
@@ -100,13 +102,17 @@ geom p = first 0 <$> iid (bern p)
first n (_ : xs) = first (n + 1) xs
first _ _ = undefined
--- Uses idea from 10.1145/3503512
+uniform = do
+ z <- unbounded
+ let r = countTrailingZeros z - 11
+ e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r
+ pure . unsafeCoerce $ (1 + (z `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+
+-- Uses algorithm from 10.1145/3503512
bounded :: Double -> Double -> Prob Double
-bounded lower upper
- | hi - 2 <= 0xfffffffffffff = do
- (k1, k2) <- split <$> boundedWord52 1 (hi - 2)
- pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
- | otherwise = uniform <&> \x -> 2 * (lower / 2 + (upper / 2 - lower / 2) * x)
+bounded lower upper = do
+ (k1, k2) <- split <$> bounded' 1 (hi - 2)
+ pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
where
hi = ceilint lower upper g
g = max (up lower - lower) (upper - down upper)
@@ -114,6 +120,7 @@ bounded lower upper
split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1)
+
down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1)
ceilint :: Double -> Double -> Double -> Word64
@@ -122,21 +129,17 @@ bounded lower upper
s = b / g - a / g
eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)
--- Random mantissa part of a uniform double
-unboundedWord52 = (.&. 0xfffffffffffff) . (unsafeCoerce @Double @Word64) <$> uniform
-
-boundedWord52 l r = do
- x <- (`mod` nextPow2 r) <$> unboundedWord52
- if x <= r then pure (x + l) else boundedWord52 l r
- where
- nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
+unbounded = Prob draw
bounded' :: Integral a => a -> a -> Prob a
-bounded' lower upper
- | r <= 0xfffffffffffff = fromIntegral <$> boundedWord52 (fromIntegral lower) (fromIntegral r)
- | otherwise = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+bounded' lower upper = do
+ z <- (`mod` nextPow2 r) <$> unbounded
+ if z <= r
+ then pure $ fromIntegral $ fromIntegral lower + z
+ else bounded' lower upper
where
- r = upper - lower
+ r :: Word64 = fromIntegral $ upper - lower
+ nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
cat :: [Double] -> Prob Int
cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 3c54078..ec9b0c8 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -9,8 +9,7 @@
{-# LANGUAGE ViewPatterns #-}
module PPL.Internal
- ( uniform,
- Prob (..),
+ ( Prob (..),
Meas,
score,
scoreLog,
@@ -18,8 +17,6 @@ module PPL.Internal
memoize,
samples,
HashMap,
- random,
- randoms,
newTree,
Tree (..),
)
@@ -29,7 +26,6 @@ import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
-import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.Map qualified as Q
import Data.Monoid
@@ -42,36 +38,20 @@ import Numeric.Log
import System.IO.Unsafe
import System.Random hiding (random, randoms, split, uniform)
import System.Random qualified as R
-import Unsafe.Coerce
type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
- { draw :: Double,
+ { draw :: Word64,
split :: (Tree, Tree)
}
--- Let's draw doubles more uniformly than the standard generator
--- based on https://www.corsix.org/content/higher-quality-random-floats
-random :: StdGen -> (Double, StdGen)
-random g0 =
- let (r1 :: Word64, g1) = R.random g0
- (r2 :: Word64, g2) = R.random g1
- e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
- dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
- in (unsafeCoerce dbl, g2)
-
-randoms :: StdGen -> [Double]
-randoms g =
- let (x, g') = random g
- in x : randoms g'
-
newTree :: StdGen -> Tree
newTree g0 = Tree x (newTree g1, newTree g2)
where
- (x, R.split -> (g1, g2)) = random g0
+ (x, R.split -> (g1, g2)) = R.random g0
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -85,9 +65,6 @@ instance Functor Prob where fmap = liftM
instance Applicative Prob where pure = Prob . const; (<*>) = ap
-uniform :: Prob Double
-uniform = Prob $ \(Tree r _) -> r
-
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
deriving (Functor, Applicative, Monad)
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 71a61fd..4deaf36 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -14,19 +14,35 @@ where
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
+import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.List (foldl')
import Data.Vector qualified as V
import Data.Vector.Hashtables qualified as H
+import Data.Word (Word64)
import Numeric.Log hiding (sum)
import PPL.Internal hiding (newTree, split)
import Streaming.Prelude (Of, Stream, yield)
import System.IO.Unsafe
import System.Random (StdGen, split)
+import System.Random qualified as R
+import Unsafe.Coerce
-{-# INLINE newTree #-}
-newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
-newTree s = go 1
+type MemoTree = IORef (HashMap Integer Word64, StdGen)
+
+-- Let's draw doubles more uniformly than the standard generator
+-- based on https://www.corsix.org/content/higher-quality-random-floats
+random :: StdGen -> (Double, StdGen)
+random g0 =
+ let (r1 :: Word64, g1) = R.random g0
+ (r2 :: Word64, g2) = R.random g1
+ e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
+ dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+ in (unsafeCoerce dbl, g2)
+
+{-# INLINE newMemoTree #-}
+newMemoTree :: MemoTree -> Tree
+newMemoTree s = go 1
where
go :: Integer -> Tree
go id =
@@ -35,7 +51,7 @@ newTree s = go 1
(m, g) <- readIORef s
H.lookup m id >>= \case
Nothing -> do
- let (x, g') = random g
+ let (x, g') = R.random g
H.insert m id x
writeIORef s (m, g')
pure x
@@ -63,13 +79,13 @@ mh g p m = do
let (g0, g1) = split g
hm <- liftIO $ H.initialize 0
omega <- liftIO $ newIORef (hm, g0)
- let (x, w) = head $ samples m $ newTree omega
+ let (x, w) = head $ samples m $ newMemoTree omega
step g1 omega x w
where
step !g0 !omega !x !w = do
let (Exp . log -> r, split -> (g1, g2)) = random g0
omega' <- mutate g1 omega
- let (!x', !w') = head $ samples m $ newTree omega'
+ let (!x', !w') = head $ samples m $ newMemoTree omega'
ratio = w' / w
(omega'', x'', w'') =
if r < ratio
@@ -78,12 +94,12 @@ mh g p m = do
yield (x'', w'')
step g2 omega'' x'' w''
- mutate :: (MonadIO m) => StdGen -> IORef (HashMap Integer Double, StdGen) -> m (IORef (HashMap Integer Double, StdGen))
+ mutate :: (MonadIO m) => StdGen -> MemoTree -> m MemoTree
mutate g omega = liftIO $ do
(m, g0) <- readIORef omega
m' <- H.clone m
ks <- H.keys m
- let (rs :: [Double], qs :: [Double]) = (randoms *** randoms) (split g)
+ let (rs :: [Word64], qs :: [Word64]) = (R.randoms *** R.randoms) (split g)
ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks
n = fromIntegral (V.length ks)
void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs