diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/PPL.hs | 4 | ||||
| -rw-r--r-- | src/PPL/Distr.hs | 88 | ||||
| -rw-r--r-- | src/PPL/Internal.hs | 75 | ||||
| -rw-r--r-- | src/PPL/Sampling.hs | 123 |
4 files changed, 220 insertions, 70 deletions
@@ -1,5 +1,5 @@ -module PPL(module PPL.Internal, module PPL.Sampling, module PPL.Distr) where +module PPL (module PPL.Internal, module PPL.Sampling, module PPL.Distr) where +import PPL.Distr import PPL.Internal import PPL.Sampling -import PPL.Distr diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs index 797379b..d265f05 100644 --- a/src/PPL/Distr.hs +++ b/src/PPL/Distr.hs @@ -1,7 +1,15 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} + module PPL.Distr where +import Debug.Trace +import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.)) +import Data.Functor ((<&>)) +import Data.Word (Word64) import PPL.Internal -import qualified PPL.Internal as I +import Unsafe.Coerce -- Acklam's approximation -- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/ @@ -9,14 +17,15 @@ import qualified PPL.Internal as I probit :: Double -> Double probit p | p < lower = - let q = sqrt (-2 * log p) - in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) - / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) + let q = sqrt (-2 * log p) + in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) + / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1) | p < 1 - lower = - let q = p - 0.5 - r = q * q - in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q - / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) + let q = p - 0.5 + r = q * q + in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) + * q + / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1) | otherwise = -probit (1 - p) where a1 = -3.969683028665376e+01 @@ -49,11 +58,14 @@ probit p iid :: Prob a -> Prob [a] iid = sequence . repeat +gauss :: Prob Double gauss = probit <$> uniform +norm :: Double -> Double -> Prob Double norm m s = (+ m) . (* s) <$> gauss -- Marsaglia's fast gamma rejection sampling +gamma :: Double -> Prob Double gamma a = do x <- gauss u <- uniform @@ -64,15 +76,24 @@ gamma a = do d = a - 1 / 3 v x = (1 + x / sqrt (9 * d)) ** 3 +beta :: Double -> Double -> Prob Double beta a b = do x <- gamma a y <- gamma b pure $ x / (x + y) +beta' :: Double -> Double -> Prob Double +beta' a b = do + p <- beta a b + pure $ p / (1 - p) + +bern :: Double -> Prob Bool bern p = (< p) <$> uniform +binom :: Int -> Double -> Prob Int binom n = fmap (length . filter id . take n) . iid . bern +exponential :: Double -> Prob Double exponential lambda = negate . (/ lambda) . log <$> uniform geom :: Double -> Prob Int @@ -80,10 +101,48 @@ geom p = first 0 <$> iid (bern p) where first n (True : _) = n first n (_ : xs) = first (n + 1) xs + first _ _ = undefined + +uniform :: Prob Double +uniform = do + z <- unbounded + let r = countTrailingZeros z - 12 + e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r + pure . unsafeCoerce $ z `shiftR` 12 - ((unsafeCoerce e - 1010) `shiftL` 52) + +-- Uses algorithm from 10.1145/3503512 +bounded :: Double -> Double -> Prob Double +bounded lower upper = do + (k1, k2) <- split <$> bounded' 1 (hi - 2) + pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g + where + hi = ceilint lower upper g + g = max (up lower - lower) (upper - down upper) + + split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3) + + up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1) -bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform + down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1) -bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper) + ceilint :: Double -> Double -> Double -> Word64 + ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0 + where + s = b / g - a / g + eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g) + +unbounded :: Prob Word64 +unbounded = Prob draw + +bounded' :: Integral a => a -> a -> Prob a +bounded' lower upper = do + z <- (`mod` nextPow2 r) <$> unbounded + if z <= r + then pure $ fromIntegral $ fromIntegral lower + z + else bounded' lower upper + where + r :: Word64 = fromIntegral $ upper - lower + nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x) cat :: [Double] -> Prob Int cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform @@ -93,8 +152,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform | x > r = i | otherwise = search (i + 1) xs r +dirichletProcess :: Double -> Prob [Double] dirichletProcess p = go 1 - where - go rest = do - x <- beta 1 p - (x*rest:) <$> go (rest - x*rest) + where + go rest = do + x <- beta 1 p + (x * rest :) <$> go (rest - x * rest) diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs index 0ba6ae5..ec9b0c8 100644 --- a/src/PPL/Internal.hs +++ b/src/PPL/Internal.hs @@ -1,56 +1,57 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} module PPL.Internal - ( uniform, - split, - Prob (..), + ( Prob (..), Meas, score, scoreLog, sample, - randomTree, + memoize, samples, - mutateTree, - splitTrees, - draw, + HashMap, + newTree, + Tree (..), ) where import Control.Monad -import Control.Monad.IO.Class import Control.Monad.Trans.Class import Control.Monad.Trans.Writer import Data.Bifunctor +import Data.IORef +import Data.Map qualified as Q import Data.Monoid -import qualified Language.Haskell.TH.Syntax as TH +import Data.Vector.Hashtables qualified as H +import Data.Vector.Mutable qualified as M +import Data.Vector.Unboxed.Mutable qualified as UM +import Data.Word (Word64) +import Language.Haskell.TH.Syntax qualified as TH import Numeric.Log -import System.Random hiding (split, uniform) -import qualified System.Random as R +import System.IO.Unsafe +import System.Random hiding (random, randoms, split, uniform) +import System.Random qualified as R + +type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v -- Reimplementation of the LazyPPL monads to avoid some dependencies data Tree = Tree - { draw :: !Double, - splitTrees :: [Tree] + { draw :: Word64, + split :: (Tree, Tree) } -split :: Tree -> (Tree, Tree) -split (Tree r (t : ts)) = (t, Tree r ts) - -{-# INLINE randomTree #-} -randomTree :: RandomGen g => g -> Tree -randomTree g = let (a, g') = random g in Tree a (randomTrees g') +newTree :: StdGen -> Tree +newTree g0 = Tree x (newTree g1, newTree g2) where - randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2 - -{-# INLINE mutateTree #-} -mutateTree :: Double -> Tree -> Tree -> Tree -> Tree -mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) = - if r < p - then b - else Tree a $ zipWith3 (mutateTree p) rs bs ts + (x, R.split -> (g1, g2)) = R.random g0 newtype Prob a = Prob {runProb :: Tree -> a} @@ -64,8 +65,6 @@ instance Functor Prob where fmap = liftM instance Applicative Prob where pure = Prob . const; (<*>) = ap -uniform = Prob $ \(Tree r _) -> r - newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a) deriving (Functor, Applicative, Monad) @@ -88,3 +87,21 @@ samples :: forall a. Meas a -> Tree -> [(a, Log Double)] samples (Meas m) = map (second getProduct) . runProb f where f = runWriterT m >>= \x -> (x :) <$> f + +{-# NOINLINE memoize #-} +memoize :: (Ord a) => (a -> Prob b) -> Prob (a -> b) +memoize f = unsafePerformIO $ do + ref <- newIORef mempty + pure $ Prob $ \t -> \x -> unsafePerformIO $ do + m <- readIORef ref + case Q.lookup x m of + Just z -> pure z + _ -> do + let Prob g = f x + z = g (getNode (Q.size m) t) + m' = Q.insert x z m + writeIORef ref m' + pure z + where + getNode 0 t = t + getNode i (split -> (t0, t1)) = getNode (i `div` 2) (if i `mod` 2 == 1 then t0 else t1) diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs index 84e5ada..4deaf36 100644 --- a/src/PPL/Sampling.hs +++ b/src/PPL/Sampling.hs @@ -1,33 +1,106 @@ -{-# LANGUAGE ViewPatterns #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} -module PPL.Sampling where +module PPL.Sampling + ( mh, + ) +where +import Control.Arrow +import Control.Monad import Control.Monad.IO.Class -import Control.Monad.Trans.State -import Data.Bifunctor -import Data.Monoid -import Numeric.Log -import PPL.Distr -import PPL.Internal -import System.Random (StdGen, random, randoms) -import qualified Streaming as S -import Streaming.Prelude (Stream, yield, Of) - -mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () -mh g p m = step t0 t x w +import Data.Bits (countTrailingZeros, shiftL, shiftR) +import Data.IORef +import Data.List (foldl') +import Data.Vector qualified as V +import Data.Vector.Hashtables qualified as H +import Data.Word (Word64) +import Numeric.Log hiding (sum) +import PPL.Internal hiding (newTree, split) +import Streaming.Prelude (Of, Stream, yield) +import System.IO.Unsafe +import System.Random (StdGen, split) +import System.Random qualified as R +import Unsafe.Coerce + +type MemoTree = IORef (HashMap Integer Word64, StdGen) + +-- Let's draw doubles more uniformly than the standard generator +-- based on https://www.corsix.org/content/higher-quality-random-floats +random :: StdGen -> (Double, StdGen) +random g0 = + let (r1 :: Word64, g1) = R.random g0 + (r2 :: Word64, g2) = R.random g1 + e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r + dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52) + in (unsafeCoerce dbl, g2) + +{-# INLINE newMemoTree #-} +newMemoTree :: MemoTree -> Tree +newMemoTree s = go 1 where - (t0,t) = split $ randomTree g - (x, w) = head $ samples m t + go :: Integer -> Tree + go id = + Tree + ( unsafePerformIO $ do + (m, g) <- readIORef s + H.lookup m id >>= \case + Nothing -> do + let (x, g') = R.random g + H.insert m id x + writeIORef s (m, g') + pure x + Just x -> pure x + ) + (go (2 * id), go (2 * id + 1)) + +data MinHeap a k = Node a k (MinHeap a k) (MinHeap a k) | Empty + deriving (Show) - step !t0 !t !x !w = do - let (t1:t2:t3:t4:_) = splitTrees t0 - t' = mutateTree p t1 t2 t - (x', w') = head $ samples m t' +insert m k v = merge m (Node k v Empty Empty) + +merge Empty n = n +merge n Empty = n +merge n0@(Node k0 v0 l0 r0) n1@(Node k1 v1 l1 r1) = + if k0 < k1 + then Node k0 v0 (merge r0 n1) l0 + else Node k1 v1 r1 (merge l1 n0) + +toList Empty = [] +toList (Node k v l r) = v : toList (merge l r) + +mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m () +mh g p m = do + let (g0, g1) = split g + hm <- liftIO $ H.initialize 0 + omega <- liftIO $ newIORef (hm, g0) + let (x, w) = head $ samples m $ newMemoTree omega + step g1 omega x w + where + step !g0 !omega !x !w = do + let (Exp . log -> r, split -> (g1, g2)) = random g0 + omega' <- mutate g1 omega + let (!x', !w') = head $ samples m $ newMemoTree omega' ratio = w' / w - (Exp . log -> r) = draw t3 - (t'', x'', w'') = if r < ratio - then (t', x', w') - else (t, x, w) + (omega'', x'', w'') = + if r < ratio + then (omega', x', w') + else (omega, x, w) yield (x'', w'') - step t4 t'' x'' w'' + step g2 omega'' x'' w'' + + mutate :: (MonadIO m) => StdGen -> MemoTree -> m MemoTree + mutate g omega = liftIO $ do + (m, g0) <- readIORef omega + m' <- H.clone m + ks <- H.keys m + let (rs :: [Word64], qs :: [Word64]) = (R.randoms *** R.randoms) (split g) + ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks + n = fromIntegral (V.length ks) + void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs + newIORef (m', g0) |
