about | summary | refs | log | tree | commit | diff
path: root/src/PPL
diff options
context:
space:
mode:
author Justin Bedo <cu@cua0.org> 2026-02-20 09:58:32 +1100
committer Justin Bedo <cu@cua0.org> 2026-02-20 09:58:32 +1100
commit 339dc051c0b37952ab9aa55a0983ec35ba03033f (patch)
tree 9f8e104c0e50da6378eb78c041a4a64b8e15c76d /src/PPL
parent 9665e05e1a643dc24cc73fa914144f9990605015 (diff)
improve floating point random generation and include tree for stdgen
Diffstat (limited to 'src/PPL')
-rw-r--r-- src/PPL/Internal.hs 46
-rw-r--r-- src/PPL/Sampling.hs 33
2 files changed, 52 insertions, 27 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index d09d8ff..3c54078 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -17,8 +17,10 @@ module PPL.Internal
sample,
memoize,
samples,
- newTree,
HashMap,
+ random,
+ randoms,
+ newTree,
Tree (..),
)
where
@@ -27,17 +29,20 @@ import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
+import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.Map qualified as Q
import Data.Monoid
import Data.Vector.Hashtables qualified as H
import Data.Vector.Mutable qualified as M
import Data.Vector.Unboxed.Mutable qualified as UM
+import Data.Word (Word64)
import Language.Haskell.TH.Syntax qualified as TH
import Numeric.Log
import System.IO.Unsafe
-import System.Random hiding (split, uniform)
+import System.Random hiding (random, randoms, split, uniform)
import System.Random qualified as R
+import Unsafe.Coerce
type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v
@@ -48,24 +53,25 @@ data Tree = Tree
split :: (Tree, Tree)
}
-{-# INLINE newTree #-}
-newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
-newTree s = go 1
+-- Draw a Double more uniformly than the standard generator does, consuming two Word64 draws to build its bit pattern;
+-- based on https://www.corsix.org/content/higher-quality-random-floats
+random :: StdGen -> (Double, StdGen)
+random g0 =
+ let (r1 :: Word64, g1) = R.random g0
+ (r2 :: Word64, g2) = R.random g1
+ e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
+ dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+ in (unsafeCoerce dbl, g2)
+
+randoms :: StdGen -> [Double]
+randoms g =
+ let (x, g') = random g
+ in x : randoms g'
+
+newTree :: StdGen -> Tree
+newTree g0 = Tree x (newTree g1, newTree g2)
where
- go :: Integer -> Tree
- go id =
- Tree
- ( unsafePerformIO $ do
- (m, g) <- readIORef s
- H.lookup m id >>= \case
- Nothing -> do
- let (x, g') = R.random g
- H.insert m id x
- writeIORef s (m, g')
- pure x
- Just x -> pure x
- )
- (go (2 * id), go (2 * id + 1))
+ (x, R.split -> (g1, g2)) = random g0
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -106,7 +112,7 @@ samples (Meas m) = map (second getProduct) . runProb f
f = runWriterT m >>= \x -> (x :) <$> f
{-# NOINLINE memoize #-}
-memoize :: Ord a => (a -> Prob b) -> Prob (a -> b)
+memoize :: (Ord a) => (a -> Prob b) -> Prob (a -> b)
memoize f = unsafePerformIO $ do
ref <- newIORef mempty
pure $ Prob $ \t -> \x -> unsafePerformIO $ do
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 574025a..71a61fd 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -18,11 +18,30 @@ import Data.IORef
import Data.List (foldl')
import Data.Vector qualified as V
import Data.Vector.Hashtables qualified as H
-import Numeric.Log
-import PPL.Internal
+import Numeric.Log hiding (sum)
+import PPL.Internal hiding (newTree, split)
import Streaming.Prelude (Of, Stream, yield)
-import System.Random (StdGen)
-import System.Random qualified as R
+import System.IO.Unsafe
+import System.Random (StdGen, split)
+
+{-# INLINE newTree #-}
+newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
+newTree s = go 1
+ where
+ go :: Integer -> Tree
+ go id =
+ Tree
+ ( unsafePerformIO $ do
+ (m, g) <- readIORef s
+ H.lookup m id >>= \case
+ Nothing -> do
+ let (x, g') = random g
+ H.insert m id x
+ writeIORef s (m, g')
+ pure x
+ Just x -> pure x
+ )
+ (go (2 * id), go (2 * id + 1))
data MinHeap a k = Node a k (MinHeap a k) (MinHeap a k) | Empty
deriving (Show)
@@ -41,14 +60,14 @@ toList (Node k v l r) = v : toList (merge l r)
mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p m = do
- let (g0, g1) = R.split g
+ let (g0, g1) = split g
hm <- liftIO $ H.initialize 0
omega <- liftIO $ newIORef (hm, g0)
let (x, w) = head $ samples m $ newTree omega
step g1 omega x w
where
step !g0 !omega !x !w = do
- let (Exp . log -> r, R.split -> (g1, g2)) = R.random g0
+ let (Exp . log -> r, split -> (g1, g2)) = random g0
omega' <- mutate g1 omega
let (!x', !w') = head $ samples m $ newTree omega'
ratio = w' / w
@@ -64,7 +83,7 @@ mh g p m = do
(m, g0) <- readIORef omega
m' <- H.clone m
ks <- H.keys m
- let (rs :: [Double], qs :: [Double]) = (R.randoms *** R.randoms) (R.split g)
+ let (rs :: [Double], qs :: [Double]) = (randoms *** randoms) (split g)
ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks
n = fromIntegral (V.length ks)
void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs