aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/PPL/Sampling.hs8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 8a74854..e674b2b 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
@@ -10,9 +11,11 @@ module PPL.Sampling
)
where
+import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
import Data.IORef
+import Data.List (sort)
import Data.Vector qualified as V
import Data.Vector.Hashtables qualified as H
import Numeric.Log
@@ -46,7 +49,8 @@ mh g p m = do
(m, g0) <- readIORef omega
m' <- H.clone m
ks <- H.keys m
- let (rs, qs) = splitAt (1 + floor (p * (n - 1))) (R.randoms g)
+ let (rs :: [Double], qs :: [Double]) = (R.randoms *** R.randoms) (R.split g)
+ ks' = map snd $ sort $ zip rs $ V.toList ks
n = fromIntegral (V.length ks)
- when (n > 0) $ void $ zipWithM (\r q -> H.insert m' (ks V.! floor (r * n)) q) rs qs
+ void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs
newIORef (m', g0)