diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/PPL/Sampling.hs | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 8a74854..e674b2b 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -2,6 +2,7 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} @@ -10,9 +11,11 @@ module PPL.Sampling ) where +import Control.Arrow import Control.Monad import Control.Monad.IO.Class import Data.IORef +import Data.List (sort) import Data.Vector qualified as V import Data.Vector.Hashtables qualified as H import Numeric.Log @@ -46,7 +49,8 @@ mh g p m = do (m, g0) <- readIORef omega m' <- H.clone m ks <- H.keys m - let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g) + let (rs :: [Double], qs :: [Double]) = (R.randoms *** R.randoms) (R.split g) + ks' = map snd $ sort $ zip rs $ V.toList ks n = fromIntegral (V.length ks) - when (n > 0) $ void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs + void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs newIORef (m', g0) |
