aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/PPL/Internal.hs')
-rw-r--r--src/PPL/Internal.hs29
1 files changed, 3 insertions, 26 deletions
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 3c54078..ec9b0c8 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -9,8 +9,7 @@
{-# LANGUAGE ViewPatterns #-}
module PPL.Internal
- ( uniform,
- Prob (..),
+ ( Prob (..),
Meas,
score,
scoreLog,
@@ -18,8 +17,6 @@ module PPL.Internal
memoize,
samples,
HashMap,
- random,
- randoms,
newTree,
Tree (..),
)
@@ -29,7 +26,6 @@ import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
-import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.Map qualified as Q
import Data.Monoid
@@ -42,36 +38,20 @@ import Numeric.Log
import System.IO.Unsafe
import System.Random hiding (random, randoms, split, uniform)
import System.Random qualified as R
-import Unsafe.Coerce
type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
- { draw :: Double,
+ { draw :: Word64,
split :: (Tree, Tree)
}
--- Let's draw doubles more uniformly than the standard generator
--- based on https://www.corsix.org/content/higher-quality-random-floats
-random :: StdGen -> (Double, StdGen)
-random g0 =
- let (r1 :: Word64, g1) = R.random g0
- (r2 :: Word64, g2) = R.random g1
- e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
- dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
- in (unsafeCoerce dbl, g2)
-
-randoms :: StdGen -> [Double]
-randoms g =
- let (x, g') = random g
- in x : randoms g'
-
newTree :: StdGen -> Tree
newTree g0 = Tree x (newTree g1, newTree g2)
where
- (x, R.split -> (g1, g2)) = random g0
+ (x, R.split -> (g1, g2)) = R.random g0
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -85,9 +65,6 @@ instance Functor Prob where fmap = liftM
instance Applicative Prob where pure = Prob . const; (<*>) = ap
-uniform :: Prob Double
-uniform = Prob $ \(Tree r _) -> r
-
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
deriving (Functor, Applicative, Monad)