diff options
| author | Justin Bedo <cu@cua0.org> | 2025-03-05 13:25:10 +1100 | 
|---|---|---|
| committer | Justin Bedo <cu@cua0.org> | 2025-03-05 14:49:58 +1100 | 
| commit | 51c32c174a168db6f97a7f93dcc58bcb7c351a65 (patch) | |
| tree | 300cf52fbf2e6abfa6856b0788b06d8dedc1c8ca /src/PPL | |
| parent | c7524e19c481a50a9095b1d567e3d23316915f82 (diff) | |
use vector-hashtable
Diffstat (limited to 'src/PPL')
| -rw-r--r-- | src/PPL/Internal.hs | 38 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 41 | 
2 files changed, 27 insertions, 52 deletions
| diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 7273d23..960b19a 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,19 +1,18 @@  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE LambdaCase #-}  module PPL.Internal    ( uniform, -    split,      Prob (..),      Meas,      score,      scoreLog,      sample, -    randomTree,      samples, -    mutateTree,      newTree, +    HashMap,      Tree(..),    )  where @@ -28,45 +27,34 @@ import qualified Language.Haskell.TH.Syntax as TH  import Numeric.Log  import System.Random hiding (split, uniform)  import qualified System.Random as R -import qualified Data.HashMap.Strict as M  import Data.IORef  import System.IO.Unsafe +import qualified Data.Vector.Mutable as VM +import qualified Data.Vector.Unboxed.Mutable as UM +import qualified Data.Vector.Hashtables as H + +type HashMap k v = H.Dictionary (H.PrimState IO) VM.MVector k UM.MVector v  -- Reimplementation of the LazyPPL monads to avoid some dependencies  data Tree = Tree    { draw :: !Double, -    splitTrees :: [Tree] +    split :: (Tree, Tree)    } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) -  {-# INLINE newTree #-} -newTree :: IORef (M.HashMap [Int] Double, StdGen) -> Tree +newTree :: IORef (HashMap [Bool] Double, StdGen) -> Tree  newTree s = go []    where      go id = Tree (unsafePerformIO $ do        (m, g) <- readIORef s -      case M.lookup id m of +      H.lookup m id >>= \case          Nothing -> do            let (x, g') = R.random g -          writeIORef s (M.insert id x m, g') +          H.insert m id x +          writeIORef s (m, g')            pure x -        Just x -> pure x) [go (i:id) | i <- [0..]] - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') -  where -    randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# NOINLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = -  if r < p -    then b -    else Tree a $ zipWith3 (mutateTree p) rs bs ts +        Just x -> pure x) (go (False : id), go (True : id))  newtype Prob a = Prob {runProb :: Tree -> a} diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 2d0c380..5ad346d 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -14,7 +14,6 @@ import Control.Exception (evaluate)  import Control.Monad.IO.Class  import Control.Monad.Trans.State  import Data.Bifunctor -import qualified Data.HashMap.Strict as M  import Data.Monoid  import GHC.Exts.Heap  import Numeric.Log @@ -27,12 +26,14 @@ import System.Random (StdGen, random, randoms)  import qualified System.Random as R  import Data.IORef  import Control.Monad -import Debug.Trace +import qualified Data.Vector.Hashtables as H +import qualified Data.Vector as V  mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()  mh g p m = do      let (g0, g1) = R.split g -    omega <- liftIO $ newIORef (mempty, g0) +    hm <- liftIO $ H.initialize 0 +    omega <- liftIO $ newIORef (hm, g0)      let (x, w) = head $ samples m $ newTree omega      step g1 omega x w    where @@ -49,27 +50,13 @@ mh g p m = do        yield (x'', w'')        step g2 omega'' x'' w'' -    mutate :: MonadIO m => StdGen -> IORef (M.HashMap [Int] Double, StdGen) ->  m (IORef (M.HashMap [Int] Double, StdGen)) -    mutate g omega = do -      (m, g0) <- liftIO $ readIORef omega -      {-let (r:q:_) = R.randoms g -          ks = M.keys m -          k = ks !! floor (r * join traceShow (fromIntegral (length ks))) -          m' = M.insert k q m -      liftIO $ newIORef $ (m',g0)-} -      liftIO $ newIORef $ (,g0) $ flip evalState g $ mapM go m -       -     where -      go x = do -        g <- get -        let (r, g1) = R.random g -            (y, g2) = R.random g1 -        if r < p  -          then do -            put g2 -            pure y -          else do -            put g1 -            pure x   -             - +    mutate :: MonadIO m => StdGen -> IORef (HashMap [Bool] Double, StdGen) ->  m (IORef (HashMap [Bool] Double, StdGen)) +    mutate g omega = liftIO $ do +      (m, g0) <- readIORef omega +      m' <- H.clone m +      ks <- H.keys m +      let (rs, qs) = splitAt (1 + floor (p * (n-1))) (R.randoms g) +          n = fromIntegral (V.length ks) +      zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs +      newIORef (m',g0) +      | 
