diff options
author | Justin Bedo <cu@cua0.org> | 2025-03-05 13:25:10 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2025-03-05 14:49:58 +1100 |
commit | 51c32c174a168db6f97a7f93dcc58bcb7c351a65 (patch) | |
tree | 300cf52fbf2e6abfa6856b0788b06d8dedc1c8ca | |
parent | c7524e19c481a50a9095b1d567e3d23316915f82 (diff) |
use vector-hashtable
-rw-r--r-- | package.yaml | 3 | ||||
-rw-r--r-- | ppl.cabal | 3 | ||||
-rw-r--r-- | src/PPL/Internal.hs | 38 | ||||
-rw-r--r-- | src/PPL/Sampling.hs | 41 |
4 files changed, 31 insertions, 54 deletions
diff --git a/package.yaml b/package.yaml index 46fc800..24f14da 100644 --- a/package.yaml +++ b/package.yaml @@ -14,7 +14,8 @@ dependencies: - streaming - ghc-heap - deepseq - - unordered-containers + - vector + - vector-hashtables library: source-dirs: src @@ -31,5 +31,6 @@ library , streaming , template-haskell , transformers - , unordered-containers + , vector + , vector-hashtables default-language: Haskell2010 diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 7273d23..960b19a 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,19 +1,18 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE LambdaCase #-} module PPL.Internal ( uniform, - split, Prob (..), Meas, score, scoreLog, sample, - randomTree, samples, - mutateTree, newTree, + HashMap, Tree(..), ) where @@ -28,45 +27,34 @@ import qualified Language.Haskell.TH.Syntax as TH import Numeric.Log import System.Random hiding (split, uniform) import qualified System.Random as R -import qualified Data.HashMap.Strict as M import Data.IORef import System.IO.Unsafe +import qualified Data.Vector.Mutable as VM +import qualified Data.Vector.Unboxed.Mutable as UM +import qualified Data.Vector.Hashtables as H + +type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree { draw :: !Double, - splitTrees :: [Tree] + split :: (Tree, Tree) } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - {-# INLINE newTree #-} -newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree +newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree newTree s = go [] where go id = Tree (unsafePerformIO $ do (m, g) <- readIORef s - case M.lookup id m of + H.lookup m id >>= \case Nothing -> do let (x, g') = R.random g - writeIORef s (M.insert id x m, g') + H.insert m id x + writeIORef s (m, g') pure x - Just x -> pure x) [go (i:id) | i <- [0..]] - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') - where - randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# NOINLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = - if r < p - then b - else Tree a $ zipWith3 (mutateTree p) rs bs ts + Just x -> pure x) (go (False : id), go (True : id)) newtype Prob a = Prob {runProb :: Tree -> a} diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 2d0c380..5ad346d 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -14,7 +14,6 @@ import Control.Exception (evaluate) import Control.Monad.IO.Class import Control.Monad.Trans.State import Data.Bifunctor -import qualified Data.HashMap.Strict as M import Data.Monoid import GHC.Exts.Heap import Numeric.Log @@ -27,12 +26,14 @@ import System.Random (StdGen, random, randoms) import qualified System.Random as R import Data.IORef import Control.Monad -import Debug.Trace +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector as V mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () mh g p m = do let (g0, g1) = R.split g - omega <- liftIO $ newIORef (mempty, g0) + hm <- liftIO $ H.initialize 0 + omega <- liftIO $ newIORef (hm, g0) let (x, w) = head $ samples m $ newTree omega step g1 omega x w where @@ -49,27 +50,13 @@ mh g p m = do yield (x'', w'') step g2 omega'' x'' w'' - mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) -> m (IORef (M.HashMap [Int] Double, StdGen)) - mutate g omega = do - (m, g0) <- liftIO $ readIORef omega - {-let (r:q:_) = R.randoms g - ks = M.keys m - k = ks !! floor (r * join traceShow (fromIntegral (length ks))) - m' = M.insert k q m - liftIO $ newIORef $ (m',g0)-} - liftIO $ newIORef $ (,g0) $ flip evalState g $ mapM go m - - where - go x = do - g <- get - let (r, g1) = R.random g - (y, g2) = R.random g1 - if r < p - then do - put g2 - pure y - else do - put g1 - pure x - - + mutate :: MonadIO m => StdGen -> IORef (HashMap [Bool] Double, StdGen) -> m (IORef (HashMap [Bool] Double, StdGen)) + mutate g omega = liftIO $ do + (m, g0) <- readIORef omega + m' <- H.clone m + ks <- H.keys m + let (rs, qs) = splitAt (1 + floor (p * (n-1))) (R.randoms g) + n = fromIntegral (V.length ks) + zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs + newIORef (m',g0) + |