aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2025-03-05 13:25:10 +1100
committerJustin Bedo <cu@cua0.org>2025-03-05 14:49:58 +1100
commit51c32c174a168db6f97a7f93dcc58bcb7c351a65 (patch)
tree300cf52fbf2e6abfa6856b0788b06d8dedc1c8ca
parentc7524e19c481a50a9095b1d567e3d23316915f82 (diff)
use vector-hashtable
-rw-r--r--package.yaml3
-rw-r--r--ppl.cabal3
-rw-r--r--src/PPL/Internal.hs38
-rw-r--r--src/PPL/Sampling.hs41
4 files changed, 31 insertions, 54 deletions
diff --git a/package.yaml b/package.yaml
index 46fc800..24f14da 100644
--- a/package.yaml
+++ b/package.yaml
@@ -14,7 +14,8 @@ dependencies:
- streaming
- ghc-heap
- deepseq
- - unordered-containers
+ - vector
+ - vector-hashtables
library:
source-dirs: src
diff --git a/ppl.cabal b/ppl.cabal
index 6474f11..ccaf4f0 100644
--- a/ppl.cabal
+++ b/ppl.cabal
@@ -31,5 +31,6 @@ library
, streaming
, template-haskell
, transformers
- , unordered-containers
+ , vector
+ , vector-hashtables
default-language: Haskell2010
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 7273d23..960b19a 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,19 +1,18 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE LambdaCase #-}
module PPL.Internal
( uniform,
- split,
Prob (..),
Meas,
score,
scoreLog,
sample,
- randomTree,
samples,
- mutateTree,
newTree,
+ HashMap,
Tree(..),
)
where
@@ -28,45 +27,34 @@ import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (split, uniform)
import qualified System.Random as R
-import qualified Data.HashMap.Strict as M
import Data.IORef
import System.IO.Unsafe
+import qualified Data.Vector.Mutable as VM
+import qualified Data.Vector.Unboxed.Mutable as UM
+import qualified Data.Vector.Hashtables as H
+
+type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
{ draw :: !Double,
- splitTrees :: [Tree]
+ split :: (Tree, Tree)
}
-split :: Tree -> (Tree, Tree)
-split (Tree r (t : ts)) = (t, Tree r ts)
-
{-# INLINE newTree #-}
-newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree
+newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree
newTree s = go []
where
go id = Tree (unsafePerformIO $ do
(m, g) <- readIORef s
- case M.lookup id m of
+ H.lookup m id >>= \case
Nothing -> do
let (x, g') = R.random g
- writeIORef s (M.insert id x m, g')
+ H.insert m id x
+ writeIORef s (m, g')
pure x
- Just x -> pure x) [go (i:id) | i <- [0..]]
-
-{-# INLINE randomTree #-}
-randomTree :: RandomGen g => g -> Tree
-randomTree g = let (a, g') = random g in Tree a (randomTrees g')
- where
- randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
-
-{-# NOINLINE mutateTree #-}
-mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
-mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
- if r < p
- then b
- else Tree a $ zipWith3 (mutateTree p) rs bs ts
+ Just x -> pure x) (go (False : id), go (True : id))
newtype Prob a = Prob {runProb :: Tree -> a}
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 2d0c380..5ad346d 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -14,7 +14,6 @@ import Control.Exception (evaluate)
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Bifunctor
-import qualified Data.HashMap.Strict as M
import Data.Monoid
import GHC.Exts.Heap
import Numeric.Log
@@ -27,12 +26,14 @@ import System.Random (StdGen, random, randoms)
import qualified System.Random as R
import Data.IORef
import Control.Monad
-import Debug.Trace
+import qualified Data.Vector.Hashtables as H
+import qualified Data.Vector as V
mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
mh g p m = do
let (g0, g1) = R.split g
- omega <- liftIO $ newIORef (mempty, g0)
+ hm <- liftIO $ H.initialize 0
+ omega <- liftIO $ newIORef (hm, g0)
let (x, w) = head $ samples m $ newTree omega
step g1 omega x w
where
@@ -49,27 +50,13 @@ mh g p m = do
yield (x'', w'')
step g2 omega'' x'' w''
- mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) -> m (IORef (M.HashMap [Int] Double, StdGen))
- mutate g omega = do
- (m, g0) <- liftIO $ readIORef omega
- {-let (r:q:_) = R.randoms g
- ks = M.keys m
- k = ks !! floor (r * join traceShow (fromIntegral (length ks)))
- m' = M.insert k q m
- liftIO $ newIORef $ (m',g0)-}
- liftIO $ newIORef $ (,g0) $ flip evalState g $ mapM go m
-
- where
- go x = do
- g <- get
- let (r, g1) = R.random g
- (y, g2) = R.random g1
- if r < p
- then do
- put g2
- pure y
- else do
- put g1
- pure x
-
-
+ mutate :: MonadIO m => StdGen -> IORef (HashMap [Bool] Double, StdGen) -> m (IORef (HashMap [Bool] Double, StdGen))
+ mutate g omega = liftIO $ do
+ (m, g0) <- readIORef omega
+ m' <- H.clone m
+ ks <- H.keys m
+ let (rs, qs) = splitAt (1 + floor (p * (n-1))) (R.randoms g)
+ n = fromIntegral (V.length ks)
+ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs
+ newIORef (m',g0)
+