aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flake.nix13
-rw-r--r--package.yaml2
-rw-r--r--ppl.cabal4
-rw-r--r--src/PPL.hs4
-rw-r--r--src/PPL/Distr.hs88
-rw-r--r--src/PPL/Internal.hs75
-rw-r--r--src/PPL/Sampling.hs123
7 files changed, 234 insertions, 75 deletions
diff --git a/flake.nix b/flake.nix
index 76c754c..432d3a0 100644
--- a/flake.nix
+++ b/flake.nix
@@ -3,7 +3,7 @@
extra-trusted-public-keys = ["hydra.iohk.io:f/Ea+s+dFdN+3Y/G+FDgSq+a5NEWhJGzdjvKNGv0/EQ="];
extra-substituters = ["https://cache.iog.io"];
};
- description = "Bayesian phylogentic playground";
+ description = "Probabilistic Programming Language";
inputs.haskellNix.url = "github:input-output-hk/haskell.nix";
inputs.nixpkgs.follows = "haskellNix/nixpkgs-unstable";
inputs.flake-utils.url = "github:numtide/flake-utils";
@@ -17,14 +17,19 @@
overlays = [
haskellNix.overlay
(final: prev: {
- phylogenies = final.haskell-nix.project' {
+ ppl = final.haskell-nix.project' {
src = ./.;
compiler-nix-name = "ghc925";
shell.tools = {
hlint = {};
ormolu = {};
+ haskell-language-server = {};
};
- shell.buildInputs = with pkgs; [
+ shell.buildInputs = [
+ (pkgs.writeScriptBin "haskell-language-server-wrapper" ''
+ #!${pkgs.stdenv.shell}
+ exec haskell-language-server "$@"
+ '')
];
};
})
@@ -33,7 +38,7 @@
inherit system overlays;
inherit (haskellNix) config;
};
- flake = pkgs.phylogenies.flake {};
+ flake = pkgs.ppl.flake {};
in
flake // { packages.default = flake.packages."ppl:lib:ppl";});
}
diff --git a/package.yaml b/package.yaml
index 703c5a2..9efd300 100644
--- a/package.yaml
+++ b/package.yaml
@@ -12,6 +12,8 @@ dependencies:
- template-haskell
- containers
- streaming
+ - vector
+ - vector-hashtables
library:
source-dirs: src
diff --git a/ppl.cabal b/ppl.cabal
index 5d8db40..75dd0bc 100644
--- a/ppl.cabal
+++ b/ppl.cabal
@@ -1,6 +1,6 @@
cabal-version: 1.12
--- This file has been generated from package.yaml by hpack version 0.34.7.
+-- This file has been generated from package.yaml by hpack version 0.36.1.
--
-- see: https://github.com/sol/hpack
@@ -29,4 +29,6 @@ library
, streaming
, template-haskell
, transformers
+ , vector
+ , vector-hashtables
default-language: Haskell2010
diff --git a/src/PPL.hs b/src/PPL.hs
index 21bb87c..d1dc17a 100644
--- a/src/PPL.hs
+++ b/src/PPL.hs
@@ -1,5 +1,5 @@
-module PPL(module PPL.Internal, module PPL.Sampling, module PPL.Distr) where
+module PPL (module PPL.Internal, module PPL.Sampling, module PPL.Distr) where
+import PPL.Distr
import PPL.Internal
import PPL.Sampling
-import PPL.Distr
diff --git a/src/PPL/Distr.hs b/src/PPL/Distr.hs
index 797379b..d265f05 100644
--- a/src/PPL/Distr.hs
+++ b/src/PPL/Distr.hs
@@ -1,7 +1,15 @@
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
+
module PPL.Distr where
+import Debug.Trace
+import Data.Bits (countLeadingZeros, countTrailingZeros, shiftL, shiftR, (.&.))
+import Data.Functor ((<&>))
+import Data.Word (Word64)
import PPL.Internal
-import qualified PPL.Internal as I
+import Unsafe.Coerce
-- Acklam's approximation
-- https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/
@@ -9,14 +17,15 @@ import qualified PPL.Internal as I
probit :: Double -> Double
probit p
| p < lower =
- let q = sqrt (-2 * log p)
- in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
- / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
+ let q = sqrt (-2 * log p)
+ in (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
+ / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1)
| p < 1 - lower =
- let q = p - 0.5
- r = q * q
- in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
- / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
+ let q = p - 0.5
+ r = q * q
+ in (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6)
+ * q
+ / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1)
| otherwise = -probit (1 - p)
where
a1 = -3.969683028665376e+01
@@ -49,11 +58,14 @@ probit p
iid :: Prob a -> Prob [a]
iid = sequence . repeat
+gauss :: Prob Double
gauss = probit <$> uniform
+norm :: Double -> Double -> Prob Double
norm m s = (+ m) . (* s) <$> gauss
-- Marsaglia's fast gamma rejection sampling
+gamma :: Double -> Prob Double
gamma a = do
x <- gauss
u <- uniform
@@ -64,15 +76,24 @@ gamma a = do
d = a - 1 / 3
v x = (1 + x / sqrt (9 * d)) ** 3
+beta :: Double -> Double -> Prob Double
beta a b = do
x <- gamma a
y <- gamma b
pure $ x / (x + y)
+beta' :: Double -> Double -> Prob Double
+beta' a b = do
+ p <- beta a b
+ pure $ p / (1 - p)
+
+bern :: Double -> Prob Bool
bern p = (< p) <$> uniform
+binom :: Int -> Double -> Prob Int
binom n = fmap (length . filter id . take n) . iid . bern
+exponential :: Double -> Prob Double
exponential lambda = negate . (/ lambda) . log <$> uniform
geom :: Double -> Prob Int
@@ -80,10 +101,48 @@ geom p = first 0 <$> iid (bern p)
where
first n (True : _) = n
first n (_ : xs) = first (n + 1) xs
+ first _ _ = undefined
+
+uniform :: Prob Double
+uniform = do
+ z <- unbounded
+ let r = countTrailingZeros z - 12
+ e <- if r >= 0 then countTrailingZeros <$> unbounded else pure r
+ pure . unsafeCoerce $ z `shiftR` 12 - ((unsafeCoerce e - 1010) `shiftL` 52)
+
+-- Uses algorithm from 10.1145/3503512
+bounded :: Double -> Double -> Prob Double
+bounded lower upper = do
+ (k1, k2) <- split <$> bounded' 1 (hi - 2)
+ pure $ if abs lower <= abs upper then 4 * (upper / 4 - k1 * g) - k2 * g else 4 * (lower / 4 + k1 * g) + k2 * g
+ where
+ hi = ceilint lower upper g
+ g = max (up lower - lower) (upper - down upper)
+
+ split x = (fromIntegral $ x `shiftR` 2, fromIntegral $ x .&. 0x3)
+
+ up x = unsafeCoerce @Word64 @Double (unsafeCoerce x + 1)
-bounded lower upper = (+ lower) . (* (upper - lower)) <$> uniform
+ down x = unsafeCoerce @Word64 @Double (unsafeCoerce x - 1)
-bounded' lower upper = round <$> bounded (fromIntegral lower) (fromIntegral upper)
+ ceilint :: Double -> Double -> Double -> Word64
+ ceilint a b g = ceiling s + if s == fromIntegral (ceiling s) && eps > 0 then 1 else 0
+ where
+ s = b / g - a / g
+ eps = if abs a <= abs b then negate a / g - (s - b / g) else b / g - (s + a / g)
+
+unbounded :: Prob Word64
+unbounded = Prob draw
+
+bounded' :: Integral a => a -> a -> Prob a
+bounded' lower upper = do
+ z <- (`mod` nextPow2 r) <$> unbounded
+ if z <= r
+ then pure $ fromIntegral $ fromIntegral lower + z
+ else bounded' lower upper
+ where
+ r :: Word64 = fromIntegral $ upper - lower
+ nextPow2 x = 1 `shiftL` (64 - countLeadingZeros x)
cat :: [Double] -> Prob Int
cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
@@ -93,8 +152,9 @@ cat xs = search 0 (tail $ scanl (+) 0 xs) <$> uniform
| x > r = i
| otherwise = search (i + 1) xs r
+dirichletProcess :: Double -> Prob [Double]
dirichletProcess p = go 1
- where
- go rest = do
- x <- beta 1 p
- (x*rest:) <$> go (rest - x*rest)
+ where
+ go rest = do
+ x <- beta 1 p
+ (x * rest :) <$> go (rest - x * rest)
diff --git a/src/PPL/Internal.hs b/src/PPL/Internal.hs
index 0ba6ae5..ec9b0c8 100644
--- a/src/PPL/Internal.hs
+++ b/src/PPL/Internal.hs
@@ -1,56 +1,57 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE ViewPatterns #-}
module PPL.Internal
- ( uniform,
- split,
- Prob (..),
+ ( Prob (..),
Meas,
score,
scoreLog,
sample,
- randomTree,
+ memoize,
samples,
- mutateTree,
- splitTrees,
- draw,
+ HashMap,
+ newTree,
+ Tree (..),
)
where
import Control.Monad
-import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
+import Data.IORef
+import Data.Map qualified as Q
import Data.Monoid
-import qualified Language.Haskell.TH.Syntax as TH
+import Data.Vector.Hashtables qualified as H
+import Data.Vector.Mutable qualified as M
+import Data.Vector.Unboxed.Mutable qualified as UM
+import Data.Word (Word64)
+import Language.Haskell.TH.Syntax qualified as TH
import Numeric.Log
-import System.Random hiding (split, uniform)
-import qualified System.Random as R
+import System.IO.Unsafe
+import System.Random hiding (random, randoms, split, uniform)
+import System.Random qualified as R
+
+type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v
-- Reimplementation of the LazyPPL monads to avoid some dependencies
data Tree = Tree
- { draw :: !Double,
- splitTrees :: [Tree]
+ { draw :: Word64,
+ split :: (Tree, Tree)
}
-split :: Tree -> (Tree, Tree)
-split (Tree r (t : ts)) = (t, Tree r ts)
-
-{-# INLINE randomTree #-}
-randomTree :: RandomGen g => g -> Tree
-randomTree g = let (a, g') = random g in Tree a (randomTrees g')
+newTree :: StdGen -> Tree
+newTree g0 = Tree x (newTree g1, newTree g2)
where
- randomTrees g = let (g1, g2) = R.split g in randomTree g1 : randomTrees g2
-
-{-# INLINE mutateTree #-}
-mutateTree :: Double -> Tree -> Tree -> Tree -> Tree
-mutateTree p (Tree r rs) b@(Tree _ bs) (Tree a ts) =
- if r < p
- then b
- else Tree a $ zipWith3 (mutateTree p) rs bs ts
+ (x, R.split -> (g1, g2)) = R.random g0
newtype Prob a = Prob {runProb :: Tree -> a}
@@ -64,8 +65,6 @@ instance Functor Prob where fmap = liftM
instance Applicative Prob where pure = Prob . const; (<*>) = ap
-uniform = Prob $ \(Tree r _) -> r
-
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
deriving (Functor, Applicative, Monad)
@@ -88,3 +87,21 @@ samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = map (second getProduct) . runProb f
where
f = runWriterT m >>= \x -> (x :) <$> f
+
+{-# NOINLINE memoize #-}
+memoize :: (Ord a) => (a -> Prob b) -> Prob (a -> b)
+memoize f = unsafePerformIO $ do
+ ref <- newIORef mempty
+ pure $ Prob $ \t -> \x -> unsafePerformIO $ do
+ m <- readIORef ref
+ case Q.lookup x m of
+ Just z -> pure z
+ _ -> do
+ let Prob g = f x
+ z = g (getNode (Q.size m) t)
+ m' = Q.insert x z m
+ writeIORef ref m'
+ pure z
+ where
+ getNode 0 t = t
+ getNode i (split -> (t0, t1)) = getNode (i `div` 2) (if i `mod` 2 == 1 then t0 else t1)
diff --git a/src/PPL/Sampling.hs b/src/PPL/Sampling.hs
index 84e5ada..4deaf36 100644
--- a/src/PPL/Sampling.hs
+++ b/src/PPL/Sampling.hs
@@ -1,33 +1,106 @@
-{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE ViewPatterns #-}
-module PPL.Sampling where
+module PPL.Sampling
+ ( mh,
+ )
+where
+import Control.Arrow
+import Control.Monad
import Control.Monad.IO.Class
-import Control.Monad.Trans.State
-import Data.Bifunctor
-import Data.Monoid
-import Numeric.Log
-import PPL.Distr
-import PPL.Internal
-import System.Random (StdGen, random, randoms)
-import qualified Streaming as S
-import Streaming.Prelude (Stream, yield, Of)
-
-mh :: Monad m => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
-mh g p m = step t0 t x w
+import Data.Bits (countTrailingZeros, shiftL, shiftR)
+import Data.IORef
+import Data.List (foldl')
+import Data.Vector qualified as V
+import Data.Vector.Hashtables qualified as H
+import Data.Word (Word64)
+import Numeric.Log hiding (sum)
+import PPL.Internal hiding (newTree, split)
+import Streaming.Prelude (Of, Stream, yield)
+import System.IO.Unsafe
+import System.Random (StdGen, split)
+import System.Random qualified as R
+import Unsafe.Coerce
+
+type MemoTree = IORef (HashMap Integer Word64, StdGen)
+
+-- Let's draw doubles more uniformly than the standard generator
+-- based on https://www.corsix.org/content/higher-quality-random-floats
+random :: StdGen -> (Double, StdGen)
+random g0 =
+ let (r1 :: Word64, g1) = R.random g0
+ (r2 :: Word64, g2) = R.random g1
+ e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
+ dbl = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((unsafeCoerce e - 1011) `shiftL` 52)
+ in (unsafeCoerce dbl, g2)
+
+{-# INLINE newMemoTree #-}
+newMemoTree :: MemoTree -> Tree
+newMemoTree s = go 1
where
- (t0,t) = split $ randomTree g
- (x, w) = head $ samples m t
+ go :: Integer -> Tree
+ go id =
+ Tree
+ ( unsafePerformIO $ do
+ (m, g) <- readIORef s
+ H.lookup m id >>= \case
+ Nothing -> do
+ let (x, g') = R.random g
+ H.insert m id x
+ writeIORef s (m, g')
+ pure x
+ Just x -> pure x
+ )
+ (go (2 * id), go (2 * id + 1))
+
+data MinHeap a k = Node a k (MinHeap a k) (MinHeap a k) | Empty
+ deriving (Show)
- step !t0 !t !x !w = do
- let (t1:t2:t3:t4:_) = splitTrees t0
- t' = mutateTree p t1 t2 t
- (x', w') = head $ samples m t'
+insert m k v = merge m (Node k v Empty Empty)
+
+merge Empty n = n
+merge n Empty = n
+merge n0@(Node k0 v0 l0 r0) n1@(Node k1 v1 l1 r1) =
+ if k0 < k1
+ then Node k0 v0 (merge r0 n1) l0
+ else Node k1 v1 r1 (merge l1 n0)
+
+toList Empty = []
+toList (Node k v l r) = v : toList (merge l r)
+
+mh :: (MonadIO m) => StdGen -> Double -> Meas a -> Stream (Of (a, Log Double)) m ()
+mh g p m = do
+ let (g0, g1) = split g
+ hm <- liftIO $ H.initialize 0
+ omega <- liftIO $ newIORef (hm, g0)
+ let (x, w) = head $ samples m $ newMemoTree omega
+ step g1 omega x w
+ where
+ step !g0 !omega !x !w = do
+ let (Exp . log -> r, split -> (g1, g2)) = random g0
+ omega' <- mutate g1 omega
+ let (!x', !w') = head $ samples m $ newMemoTree omega'
ratio = w' / w
- (Exp . log -> r) = draw t3
- (t'', x'', w'') = if r < ratio
- then (t', x', w')
- else (t, x, w)
+ (omega'', x'', w'') =
+ if r < ratio
+ then (omega', x', w')
+ else (omega, x, w)
yield (x'', w'')
- step t4 t'' x'' w''
+ step g2 omega'' x'' w''
+
+ mutate :: (MonadIO m) => StdGen -> MemoTree -> m MemoTree
+ mutate g omega = liftIO $ do
+ (m, g0) <- readIORef omega
+ m' <- H.clone m
+ ks <- H.keys m
+ let (rs :: [Word64], qs :: [Word64]) = (R.randoms *** R.randoms) (split g)
+ ks' = toList $ foldl' (\m -> uncurry (insert m)) Empty $ zip rs $ V.toList ks
+ n = fromIntegral (V.length ks)
+ void $ zipWithM_ (\k q -> H.insert m' k q) (take (1 + floor (p * fromIntegral n)) ks') qs
+ newIORef (m', g0)