1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module PPL.Internal
( uniform,
Prob (..),
Meas,
score,
scoreLog,
sample,
memoize,
samples,
HashMap,
random,
randoms,
newTree,
Tree (..),
)
where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.Map qualified as Q
import Data.Monoid
import Data.Vector.Hashtables qualified as H
import Data.Vector.Mutable qualified as M
import Data.Vector.Unboxed.Mutable qualified as UM
import Data.Word (Word64)
import GHC.Float (castWord64ToDouble)
import Language.Haskell.TH.Syntax qualified as TH
import Numeric.Log
import System.IO.Unsafe
import System.Random hiding (random, randoms, split, uniform)
import System.Random qualified as R
import Unsafe.Coerce
-- | Mutable hash table living in 'IO': a 'H.Dictionary' whose keys are kept
-- in a boxed mutable vector ('M.MVector') and whose values are kept in an
-- unboxed mutable vector ('UM.MVector').
type HashMap key val =
  H.Dictionary (H.PrimState IO) M.MVector key UM.MVector val
-- Reimplementation of the LazyPPL monads to avoid some dependencies

-- | A binary tree of random numbers with no leaves: every node stores one
-- 'Double' and always has two children. Both fields are lazy, so trees
-- such as those built by 'newTree' are infinite and generated on demand.
data Tree = Tree
  { -- | The number stored at this node.
    draw :: Double,
    -- | The (left, right) children of this node.
    split :: (Tree, Tree)
  }
-- | Draw a 'Double' more uniformly than the standard generator, following
-- <https://www.corsix.org/content/higher-quality-random-floats>.
--
-- The exponent is chosen geometrically from trailing-zero counts. The
-- mantissa comes from the high 53 bits of the first word. As a result,
-- small results keep a full-precision mantissa. The result lies in
-- (0, 1]; exactly 1 is reachable only when the mantissa rounding carries.
--
-- The bit pattern is reinterpreted with 'castWord64ToDouble'. An
-- 'unsafeCoerce' from 'Word64' to 'Double' is not a supported use: its
-- documentation states that integral/floating casts do not work.
random :: StdGen -> (Double, StdGen)
random g0 =
  let (r1 :: Word64, g1) = R.random g0
      (r2 :: Word64, g2) = R.random g1
      -- A set bit among the low 11 bits of r1 gives e in [-11, -1]. If all
      -- 11 are zero, continue the geometric count with a second word.
      e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
      -- Round the top 53 bits of r1 to a 52-bit mantissa (a carry bumps the
      -- exponent) and add a biased exponent of 1011 - e. 'fromIntegral'
      -- wraps a negative e modulo 2^64, which is the intended bit pattern.
      bits = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((fromIntegral e - 1011) `shiftL` 52)
   in (castWord64ToDouble bits, g2)
-- | Infinite lazy stream of draws from 'random'. Each element's generator
-- output seeds the next draw.
randoms :: StdGen -> [Double]
randoms gen = val : randoms gen'
  where
    (val, gen') = random gen
-- | Lazily build the infinite tree of draws for a generator. The root value
-- comes from 'random'. The generator it leaves behind is split to seed the
-- left and right subtrees.
newTree :: StdGen -> Tree
newTree seed =
  let (value, seed') = random seed
      (seedL, seedR) = R.split seed'
   in Tree value (newTree seedL, newTree seedR)
-- | A probabilistic computation: a deterministic function of a 'Tree' of
-- random numbers.
newtype Prob x = Prob
  { runProb :: Tree -> x
  }
-- | Sequencing splits the tree. The first computation runs on the left
-- subtree, and the continuation runs on the right subtree.
instance Monad Prob where
  m >>= k = Prob $ \tree ->
    let (left, right) = split tree
     in runProb (k (runProb m left)) right
-- | Mapping behaves like 'liftM' (bind, then return). The underlying
-- computation therefore runs on the left subtree of the tree it is given,
-- not on the tree itself.
instance Functor Prob where
  fmap h (Prob f) = Prob (h . f . fst . split)
-- | 'pure' ignores the tree. '<*>' is monadic sequencing (as in 'ap'): the
-- function is computed from the left subtree, and the argument from inside
-- the right subtree.
instance Applicative Prob where
  pure x = Prob (\_ -> x)
  mf <*> mx = do
    h <- mf
    x <- mx
    pure (h x)
-- | Read the number stored at the root of the tree. For trees built by
-- 'newTree', this is a draw from 'random'.
uniform :: Prob Double
uniform = Prob draw
-- | Weighted computations: a 'Prob' that also accumulates a multiplicative
-- weight, stored in log space ('Log' 'Double') through a 'WriterT' over the
-- 'Product' monoid.
newtype Meas x
  = Meas (WriterT (Product (Log Double)) Prob x)
  deriving (Functor, Applicative, Monad)
{-# INLINE score #-}

-- | Multiply the current weight by a likelihood given as a plain 'Double'.
-- Arguments below machine epsilon (including zero and negative values) are
-- raised to epsilon before the logarithm is taken.
score :: Double -> Meas ()
score w = scoreLog (Exp (log (max eps w)))
  where
    -- Machine epsilon: halve until adding to 1 no longer changes it, then
    -- double. Lifted with Template Haskell so it is computed at compile time.
    eps :: Double
    eps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double)))
{-# INLINE scoreLog #-}

-- | Multiply the current weight by a likelihood that is already in log
-- space.
scoreLog :: Log Double -> Meas ()
scoreLog w = Meas (tell (Product w))
{-# INLINE sample #-}

-- | Embed a 'Prob' computation in 'Meas' without changing the weight.
sample :: Prob a -> Meas a
sample p = Meas (lift p)
{-# INLINE samples #-}

-- | Run a 'Meas' over and over on one tree, producing an infinite lazy list
-- of (result, weight) pairs. Each run consumes the left subtree. The
-- remaining runs continue inside the right subtree.
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) tree = map (second getProduct) (runProb stream tree)
  where
    stream :: Prob [(a, Product (Log Double))]
    stream = do
      x <- runWriterT m
      (x :) <$> stream
{-# NOINLINE memoize #-}

-- | Stochastic memoization: turn a random function @a -> Prob b@ into a
-- random function whose results are cached per argument. Within one run
-- (one tree), repeated calls with the same argument return the same value.
-- Each distinct argument gets its own subtree, allocated in the order in
-- which arguments are first demanded. Values are computed lazily.
--
-- The memo table is created per tree. Running the result on a different
-- tree (e.g. the successive runs in 'samples') therefore gives a fresh
-- function instead of replaying values cached from an earlier tree.
memoize :: (Ord a) => (a -> Prob b) -> Prob (a -> b)
memoize f = Prob $ \t -> unsafePerformIO $ do
  -- This action mentions @t@, so it cannot be floated out of the tree
  -- lambda; every tree gets its own table.
  ref <- newIORef Q.empty
  pure $ \x -> unsafePerformIO $ do
    m <- readIORef ref
    case Q.lookup x m of
      Just z -> pure z
      Nothing -> do
        let Prob g = f x
            -- Left unevaluated: the draw happens only when the value is used.
            z = g (getNode (Q.size m) t)
        -- Store the map in WHNF so no chain of insert thunks builds up.
        writeIORef ref $! Q.insert x z m
        pure z
  where
    -- Subtree reserved for the i-th distinct argument. Index i is encoded
    -- as a path: for each binary digit of i (least significant first), step
    -- right and then take the left child for a 1 or the right child for a
    -- 0. A final left step terminates the path. These paths form a
    -- prefix-free set, so no argument's subtree lies inside another's, and
    -- the depth stays O(log i).
    getNode :: Int -> Tree -> Tree
    getNode i tr
      | i == 0 = fst (split tr)
      | otherwise =
          let (oneBranch, zeroBranch) = split (snd (split tr))
           in getNode (i `div` 2) (if odd i then oneBranch else zeroBranch)
|