From 6d55adfed12df75a09dfd9aaf275b36151d77104 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Tue, 13 Jan 2026 14:12:54 +1100 Subject: increase lazyness and add probilistic memoisation --- src/PPL/Internal.hs | 25 ++++++++++++++++++++++--- src/PPL/Sampling.hs | 3 +-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index ea37490..d09d8ff 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -6,6 +6,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} module PPL.Internal ( uniform, @@ -14,6 +15,7 @@ module PPL.Internal score, scoreLog, sample, + memoize, samples, newTree, HashMap, @@ -25,13 +27,12 @@ import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Writer import Data.Bifunctor -import Data.Bits import Data.IORef +import Data.Map qualified as Q import Data.Monoid import Data.Vector.Hashtables qualified as H import Data.Vector.Mutable qualified as M import Data.Vector.Unboxed.Mutable qualified as UM -import Data.Word import Language.Haskell.TH.Syntax qualified as TH import Numeric.Log import System.IO.Unsafe @@ -43,7 +44,7 @@ type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree - { draw :: !Double, + { draw :: Double, split :: (Tree, Tree) } @@ -103,3 +104,21 @@ samples :: forall a. Meas a -> Tree -> [(a, Log Double)] samples (Meas m) = map (second getProduct) . runProb f where f = runWriterT m >>= \x -> (x :) <$> f + +{-# NOINLINE memoize #-} +memoize :: Ord a => (a -> Prob b) -> Prob (a -> b) +memoize f = unsafePerformIO $ do + ref <- newIORef mempty + pure $ Prob $ \t -> \x -> unsafePerformIO $ do + m <- readIORef ref + case Q.lookup x m of + Just z -> pure z + _ -> do + let Prob g = f x + z = g (getNode (Q.size m) t) + m' = Q.insert x z m + writeIORef ref m' + pure z + where + getNode 0 t = t + getNode i (split -> (t0, t1)) = getNode (i `div` 2) (if i `mod` 2 == 1 then t0 else t1) diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 256e67b..8a74854 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -15,7 +15,6 @@ import Control.Monad.IO.Class import Data.IORef import Data.Vector qualified as V import Data.Vector.Hashtables qualified as H -import Data.Word import Numeric.Log import PPL.Internal import Streaming.Prelude (Of, Stream, yield) @@ -49,5 +48,5 @@ mh g p m = do ks <- H.keys m let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g) n = fromIntegral (V.length ks) - void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs + when (n > 0) $ void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs newIORef (m', g0) -- cgit v1.2.3