diff options
| author | Justin Bedo <cu@cua0.org> | 2026-01-13 14:12:54 +1100 |
|---|---|---|
| committer | Justin Bedo <cu@cua0.org> | 2026-01-13 14:12:54 +1100 |
| commit | 6d55adfed12df75a09dfd9aaf275b36151d77104 (patch) | |
| tree | 29d8e29f6d6fcb81528f228584bbb52d68a0689e /src/PPL/Internal.hs | |
| parent | 5cdb9b15563786195cc98cd87e7eb64151a519fc (diff) | |
increase lazyness and add probilistic memoisation
Diffstat (limited to 'src/PPL/Internal.hs')
| -rw-r--r-- | src/PPL/Internal.hs | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index ea37490..d09d8ff 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -6,6 +6,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} module PPL.Internal ( uniform, @@ -14,6 +15,7 @@ module PPL.Internal score, scoreLog, sample, + memoize, samples, newTree, HashMap, @@ -25,13 +27,12 @@ import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Writer import Data.Bifunctor -import Data.Bits import Data.IORef +import Data.Map qualified as Q import Data.Monoid import Data.Vector.Hashtables qualified as H import Data.Vector.Mutable qualified as M import Data.Vector.Unboxed.Mutable qualified as UM -import Data.Word import Language.Haskell.TH.Syntax qualified as TH import Numeric.Log import System.IO.Unsafe @@ -43,7 +44,7 @@ type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree - { draw :: !Double, + { draw :: Double, split :: (Tree, Tree) } @@ -103,3 +104,21 @@ samples :: forall a. Meas a -> Tree -> [(a, Log Double)] samples (Meas m) = map (second getProduct) . runProb f where f = runWriterT m >>= \x -> (x :) <$> f + +{-# NOINLINE memoize #-} +memoize :: Ord a => (a -> Prob b) -> Prob (a -> b) +memoize f = unsafePerformIO $ do + ref <- newIORef mempty + pure $ Prob $ \t -> \x -> unsafePerformIO $ do + m <- readIORef ref + case Q.lookup x m of + Just z -> pure z + _ -> do + let Prob g = f x + z = g (getNode (Q.size m) t) + m' = Q.insert x z m + writeIORef ref m' + pure z + where + getNode 0 t = t + getNode i (split -> (t0, t1)) = getNode (i `div` 2) (if i `mod` 2 == 1 then t0 else t1) |
