aboutsummaryrefslogtreecommitdiff
path: root/src/PPL/Internal.hs
blob: 3c54078ac3866c3c14240018167a3b098b2671f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}

module PPL.Internal
  ( uniform,
    Prob (..),
    Meas,
    score,
    scoreLog,
    sample,
    memoize,
    samples,
    HashMap,
    random,
    randoms,
    newTree,
    Tree (..),
  )
where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
import Data.Bits (countTrailingZeros, shiftL, shiftR)
import Data.IORef
import Data.Map qualified as Q
import Data.Monoid
import Data.Vector.Hashtables qualified as H
import Data.Vector.Mutable qualified as M
import Data.Vector.Unboxed.Mutable qualified as UM
import Data.Word (Word64)
import GHC.Float (castWord64ToDouble)
import Language.Haskell.TH.Syntax qualified as TH
import Numeric.Log
import System.IO.Unsafe
import System.Random hiding (random, randoms, split, uniform)
import System.Random qualified as R
import Unsafe.Coerce

-- | A mutable hash table living in 'IO' (a "Data.Vector.Hashtables"
-- 'H.Dictionary'). Key storage uses a boxed mutable vector ('M.MVector') and
-- value storage an unboxed one ('UM.MVector'), so @v@ presumably needs an
-- unboxed-vector instance — confirm against the vector-hashtables docs.
-- Not used within this module; exported for use elsewhere.
type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v

-- Reimplementation of the LazyPPL monads to avoid some dependencies

-- | A lazily built binary tree of random draws ('newTree' makes it infinite).
-- Every node stores one 'Double' and two child subtrees; the 'Prob' monad
-- hands disjoint subtrees to the parts of a program it sequences. Both fields
-- must stay lazy so the infinite tree is only materialised where it is read.
data Tree = Tree
  { draw  :: Double        -- ^ the value drawn at this node
  , split :: (Tree, Tree)  -- ^ the left and right subtrees
  }

-- | Draw a 'Double' in (0, 1] with finer resolution near zero than the usual
-- "53 random bits scaled by 2^-53" construction.
-- Based on https://www.corsix.org/content/higher-quality-random-floats
--
-- Mantissa: the top 53 bits of the first word, rounded to 52 bits. Rounding
-- up can carry into the exponent, which is how exactly 1.0 can come out.
-- Exponent: geometric. The trailing zeros of the low 11 bits of the first
-- word pick it. When those 11 bits are all zero, the trailing zeros of a
-- second word are used instead.
-- Both words are always drawn, so the returned generator has always
-- advanced by two 'Word64's.
random :: StdGen -> (Double, StdGen)
random g0 =
  let (r1 :: Word64, g1) = R.random g0
      (r2 :: Word64, g2) = R.random g1
      -- e is in [-11, -1] when r1 decides the exponent, or in [0, 64] when
      -- r2 does. The resulting biased exponent is 1011 - e, which is 1022
      -- (values in [1/2, 1)) when the lowest bit of r1 is set.
      e = let r = countTrailingZeros r1 - 11 in if r >= 0 then countTrailingZeros r2 else r
      -- fromIntegral wraps a negative Int modulo 2^64, giving the
      -- two's-complement bits this subtraction relies on. Unlike unsafeCoerce
      -- it is well-defined, and it is also correct where Int is 32 bits.
      bits = (1 + (r1 `shiftR` 11)) `shiftR` 1 - ((fromIntegral e - 1011) `shiftL` 52)
   in -- Reinterpret the IEEE-754 bit pattern; unsafeCoerce between integral
      -- and floating-point types is not supported by GHC.
      (castWord64ToDouble bits, g2)

-- | An infinite stream of draws from 'random', each step feeding the
-- generator returned by the previous draw into the next one.
randoms :: StdGen -> [Double]
randoms = map fst . iterate (random . snd) . random

-- | Build an infinite, lazily evaluated tree of draws from a generator.
-- The root value comes from 'random'. The generator left over after that
-- draw is split to seed the two subtrees.
newTree :: StdGen -> Tree
newTree g0 =
  let (x, g') = random g0
      (gLeft, gRight) = R.split g'
   in Tree x (newTree gLeft, newTree gRight)

-- | A random computation: a function that reads all of its randomness from
-- the 'Tree' it is given.
newtype Prob a = Prob
  { runProb :: Tree -> a
  }

-- Sequencing splits the tree: the left subtree drives the first computation,
-- and the right subtree drives whatever the continuation builds from its
-- result.
instance Monad Prob where
  m >>= k = Prob $ \t ->
    let (left, right) = split t
        a = runProb m left
     in runProb (k a) right

-- fmap is defined through (>>=), exactly like liftM, so it splits the tree
-- the same way the monad does.
instance Functor Prob where
  fmap f m = m >>= pure . f

-- pure ignores the tree; (<*>) is sequencing through the monad.
instance Applicative Prob where
  pure x = Prob (const x)
  (<*>) = ap

-- | A single draw: the value stored at the root of the tree.
uniform :: Prob Double
uniform = Prob draw

-- | A random computation that also carries a weight: a writer over 'Prob'
-- accumulating a 'Product' of 'Log' 'Double' factors. 'score' and 'scoreLog'
-- multiply into the weight, 'sample' lifts a plain 'Prob', and 'samples'
-- runs it. The instances are presumably obtained via
-- GeneralizedNewtypeDeriving (enabled in this module) from the underlying
-- 'WriterT'.
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
  deriving (Functor, Applicative, Monad)

{-# INLINE score #-}
-- | Multiply the current weight by a likelihood given as a plain 'Double'.
-- Values below machine epsilon are clamped up to it; this includes zero and
-- negative values. The logarithm therefore never sees a non-positive
-- argument.
score :: Double -> Meas ()
score w = scoreLog (Exp (log (max machineEps w)))
  where
    -- 2^-52, the gap between 1 and the next Double. The splice runs the
    -- halving search at compile time.
    machineEps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double)))

{-# INLINE scoreLog #-}
-- | Multiply the current weight by a factor that is already in the log
-- domain.
scoreLog :: Log Double -> Meas ()
scoreLog w = Meas (tell (Product w))

{-# INLINE sample #-}
-- | Run a random computation inside 'Meas' without touching the weight.
sample :: Prob a -> Meas a
sample p = Meas (lift p)

{-# INLINE samples #-}
-- | Run a weighted computation over and over on one tree. The result is an
-- infinite lazy list of (value, weight) pairs. Each run starts from the
-- writer's empty weight and reads a different part of the tree.
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = fmap (second getProduct) . runProb stream
  where
    -- One run on the left subtree, the rest of the stream on the right
    -- (the splitting is done by Prob's (>>=)).
    stream :: Prob [(a, Product (Log Double))]
    stream = do
      x <- runWriterT m
      rest <- stream
      pure (x : rest)

{-# NOINLINE memoize #-}
-- | Stochastic memoization: turn a family of random computations into a
-- random function that computes its value at each argument at most once.
--
-- Each distinct argument gets its own subtree of the random tree, assigned
-- in the order in which arguments are first queried. The result is cached,
-- so later queries at the same argument return the same value.
--
-- The cache is created afresh every time the returned 'Prob' is run on a
-- tree. A cache created once per @memoize f@ would be shared between runs:
-- every element of 'samples' after the first would reuse values computed
-- from an earlier tree.
--
-- NOTE: the assignment of subtrees depends on query order, and the cache
-- relies on 'unsafePerformIO'. NOINLINE is kept as a precaution for that.
memoize :: (Ord a) => (a -> Prob b) -> Prob (a -> b)
memoize f = Prob $ \t -> unsafePerformIO $ do
  ref <- newIORef Q.empty
  pure $ \x -> unsafePerformIO $ do
    m <- readIORef ref
    case Q.lookup x m of
      Just z -> pure z
      Nothing -> do
        let Prob g = f x
            -- Left lazy: the value is only computed when it is demanded.
            z = g (nodeFor (Q.size m) t)
        -- Forces only the map spine; z itself stays unevaluated.
        writeIORef ref $! Q.insert x z m
        pure z
  where
    -- Subtree reserved for the i-th distinct argument (i >= 0).
    -- The path is a prefix-free code for i:
    --   * each binary digit of i, least significant first, is a right step
    --     followed by a left step (digit 1) or a right step (digit 0);
    --   * a final left step ends the code.
    -- No code is a prefix of another, so different arguments get disjoint
    -- subtrees and can never read the same draws. The root is never handed
    -- out, since it contains every other node.
    nodeFor :: Int -> Tree -> Tree
    nodeFor 0 (split -> (l, _)) = l
    nodeFor i (split -> (_, r)) =
      let (q, digit) = i `quotRem` 2
          (r0, r1) = split r
       in nodeFor q (if digit == 1 then r0 else r1)