{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}

-- | Core probability ('Prob') and measure ('Meas') monads driven by a lazy,
-- infinite binary tree of random draws ('Tree').
module PPL.Internal
  ( uniform,
    Prob (..),
    Meas,
    score,
    scoreLog,
    sample,
    memoize,
    samples,
    newTree,
    HashMap,
    Tree (..),
  )
where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
import Data.IORef
import Data.Map qualified as Q
import Data.Monoid
import Data.Vector.Hashtables qualified as H
import Data.Vector.Mutable qualified as M
import Data.Vector.Unboxed.Mutable qualified as UM
import Language.Haskell.TH.Syntax qualified as TH
import Numeric.Log
import System.IO.Unsafe
import System.Random hiding (split, uniform)
import System.Random qualified as R

-- | Mutable hash table living in 'IO', with boxed keys and unboxed values.
type HashMap k v = H.Dictionary (H.PrimState IO) M.MVector k UM.MVector v

-- Reimplementation of the LazyPPL monads to avoid some dependencies

-- | An infinite binary tree of random numbers: every node carries one 'draw'
-- and can be 'split' into two child trees. The fields are deliberately lazy,
-- since the tree is only ever materialised on demand.
data Tree = Tree
  { draw :: Double,
    split :: (Tree, Tree)
  }

-- | Build the lazy 'Tree' backed by the given mutable state.
--
-- Nodes are numbered from 1 at the root, with the children of node @i@ being
-- @2 * i@ and @2 * i + 1@. The first time a node's 'draw' is forced, a value
-- is taken from the generator stored in the 'IORef', recorded in the hash
-- table under the node's number, and the advanced generator is written back.
-- Later trees built over the same 'IORef' find the recorded value instead of
-- drawing again.
--
-- The read-modify-write of the 'IORef' is not atomic; 'unsafePerformIO' is
-- used so that the draws stay lazy.
{-# INLINE newTree #-}
newTree :: IORef (HashMap Integer Double, StdGen) -> Tree
newTree s = go 1
  where
    go :: Integer -> Tree
    go nodeId =
      Tree
        ( unsafePerformIO $ do
            (m, g) <- readIORef s
            H.lookup m nodeId >>= \case
              Just x -> pure x
              Nothing -> do
                let (x, g') = R.random g
                -- The table is mutated in place; only the generator changes
                -- in the pair written back.
                H.insert m nodeId x
                writeIORef s (m, g')
                pure x
        )
        (go (2 * nodeId), go (2 * nodeId + 1))

-- | A probabilistic computation: a function from a tree of random draws to a
-- result.
newtype Prob a = Prob {runProb :: Tree -> a}

-- | Bind splits the tree: the first computation runs on the first half and
-- the continuation on the second half, so the two never read the same node.
instance Monad Prob where
  Prob f >>= g = Prob $ \t ->
    let (t1, t2) = split t
        (Prob g') = g (f t1)
     in g' t2

instance Functor Prob where
  fmap = liftM

instance Applicative Prob where
  pure = Prob . const
  (<*>) = ap

-- | The draw stored at the root of the tree.
uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r

-- | A 'Prob' computation that additionally accumulates a multiplicative
-- weight kept in the log domain.
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
  deriving (Functor, Applicative, Monad)

-- | Multiply the current weight by the given (linear-scale) factor.
--
-- The factor is clamped from below to machine epsilon before taking the
-- logarithm, so a zero or negative argument yields a tiny weight rather than
-- @log 0@.
{-# INLINE score #-}
score :: Double -> Meas ()
score = scoreLog . Exp . log . max eps
  where
    eps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double))) -- machine epsilon, force compile time eval

-- | Multiply the current weight by a factor already given in the log domain.
{-# INLINE scoreLog #-}
scoreLog :: Log Double -> Meas ()
scoreLog = Meas . tell . Product

-- | Run a 'Prob' computation inside 'Meas' without changing the weight.
{-# INLINE sample #-}
sample :: Prob a -> Meas a
sample = Meas . lift

-- | Lazily produce an infinite list of weighted results from the tree.
--
-- Each element comes from running the measure on a fresh part of the tree
-- (via the splitting done by 'Prob''s bind); the list is only productive
-- because that bind is lazy.
{-# INLINE samples #-}
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = map (second getProduct) . runProb f
  where
    f = runWriterT m >>= \x -> (x :) <$> f

-- | Turn a family of random values indexed by @a@ into a single random
-- function: the result for each key is computed on first use and then
-- remembered, so asking for the same key again returns the same value.
--
-- Two properties are required for this to be sound, and both are ensured
-- here:
--
-- * The memo table is allocated per tree (inside the @\\t -> ...@ lambda).
--   Allocating it once per 'Prob' value would make every run of the same
--   'Prob' value share one table, so values computed from one tree would be
--   returned for another.
--
-- * Each new key gets a subtree that is disjoint from the subtree of every
--   other key, so values for different keys never read the same draw.
--
-- 'unsafePerformIO' is used so that the table can be filled lazily from pure
-- code; the 'NOINLINE' pragma keeps GHC from duplicating the definition.
{-# NOINLINE memoize #-}
memoize :: Ord a => (a -> Prob b) -> Prob (a -> b)
memoize f = Prob $ \t -> unsafePerformIO $ do
  ref <- newIORef Q.empty
  pure $ \x -> unsafePerformIO $ do
    m <- readIORef ref
    case Q.lookup x m of
      Just z -> pure z
      Nothing -> do
        -- The number of keys seen so far selects the next unused subtree.
        let Prob g = f x
            z = g (getNode (Q.size m) t)
        writeIORef ref (Q.insert x z m)
        pure z
  where
    -- Address subtree number @i@ with a prefix-free path: for every binary
    -- digit of @i@ (least significant first) step right and then pick a child
    -- by that digit; finish with a single step left. No such path is a prefix
    -- of another, so the selected subtrees never overlap, and the depth is
    -- logarithmic in @i@.
    getNode :: Int -> Tree -> Tree
    getNode 0 (split -> (l, _)) = l
    getNode i (split -> (_, split -> (odds, evens))) =
      getNode (i `div` 2) (if odd i then odds else evens)