{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Core probability ('Prob') and measure ('Meas') monads, driven by a
-- lazily materialised binary 'Tree' of random draws whose values are
-- memoised in a mutable hash table keyed by a hash of the node's path.
module PPL.Internal
  ( uniform,
    Prob (..),
    Meas,
    score,
    scoreLog,
    sample,
    samples,
    newTree,
    HashMap,
    Tree (..),
  )
where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Bifunctor
import Data.Monoid
import qualified Language.Haskell.TH.Syntax as TH
import Numeric.Log
import System.Random hiding (split, uniform)
import qualified System.Random as R
import Data.IORef
import System.IO.Unsafe
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Data.Vector.Hashtables as H
import Data.Word
import Data.Bits

-- | Mutable dictionary living in 'IO' whose keys and values are both stored
-- in unboxed mutable vectors.
type HashMap k v = H.Dictionary (H.PrimState IO) UM.MVector k UM.MVector v

-- | Incremental hash of a path (a sequence of bits) into 64-bit space.
-- Simple hashing scheme based on mzHash64.
--
-- The fields are, in order:
--
-- 1. the number of bytes absorbed into the state so far,
-- 2. the number of bits currently buffered in the pending byte,
-- 3. the pending byte itself (bits are shifted in from the right),
-- 4. the running 64-bit state.
data Hash = Hash Int Int Word8 Word64
  deriving (Eq, Ord, Show)

-- | Finalise a 'Hash' into its 64-bit key.
--
-- The pending byte is absorbed first, followed by the count of buffered bits
-- (truncated to 'Word8'). Mixing in the count means two paths whose pending
-- bytes have the same numeric value but different lengths (e.g. differing
-- only in leading 'False' bits) feed different input into the hash.
unhash :: Hash -> Word64
unhash h@(Hash _ pendingCount pendingBits _) = finalState
  where
    Hash _ _ _ finalState =
      addhash (fromIntegral pendingCount) (addhash pendingBits h)

-- | Absorb one byte into the hash state and bump the byte counter.
-- The bit buffer (second and third fields) is passed through untouched.
addhash :: Word8 -> Hash -> Hash
addhash x (Hash i j bits state) = Hash (i + 1) j bits mixed
  where
    mixed =
      (0xB2DEEF07CB4ACD43 * (fromIntegral i + fromIntegral x))
        `xor` (state `shiftL` 2)
        `xor` (state `shiftR` 2)

-- | Hash of the empty path: no bytes absorbed, no bits buffered, and the
-- fixed seed as the initial state.
initHash :: Hash
initHash = Hash 0 0 0 0xE297DA430DB2DF1A

-- | Append one bit to the path being hashed.
--
-- Bits accumulate in the pending byte. Once eight bits are buffered, the
-- next push first absorbs the full byte into the state and then starts a new
-- buffer that holds only the incoming bit.
pushbit :: Bool -> Hash -> Hash
pushbit b h@(Hash i j bits state)
  | j == 8 =
      let Hash i' _ _ state' = addhash bits h
       in Hash i' 1 (fromBool b) state'
  | otherwise = Hash i (j + 1) (bits * 2 + fromBool b) state
  where
    fromBool :: Bool -> Word8
    fromBool True = 1
    fromBool False = 0

-- Reimplementation of the LazyPPL monads to avoid some dependencies

-- | Infinite binary tree of random draws. The draw at a node is strict; the
-- pair of subtrees is lazy, so only the visited part of the tree is built.
data Tree = Tree
  { draw :: !Double,
    split :: (Tree, Tree)
  }

-- | Build a 'Tree' backed by the given mutable state.
--
-- Every node is identified by the hash of its path from the root ('False'
-- for the first subtree, 'True' for the second). A node's draw is looked up
-- in the table under that key; if it is absent, a fresh value is generated
-- with 'R.random', stored in the table, and the advanced generator is
-- written back to the 'IORef'. Replacing entries in the table therefore
-- changes the value seen by trees created afterwards from the same state.
--
-- 'unsafePerformIO' is used so that the tree can be consumed by pure code;
-- because 'draw' is a strict field, the table access for a node runs when
-- that node is forced.
--
-- NOTE(review): the read-then-write on the 'IORef' is not atomic, so this
-- presumably assumes the tree is only forced from a single thread — confirm
-- against callers. Distinct paths whose 64-bit hashes collide share a draw.
{-# INLINE newTree #-}
newTree :: IORef (HashMap Word64 Double, StdGen) -> Tree
newTree s = go initHash
  where
    go :: Hash -> Tree
    go path =
      Tree
        (unsafePerformIO (drawAt (unhash path)))
        (go (pushbit False path), go (pushbit True path))

    -- Fetch the memoised draw stored under @key@, generating and recording
    -- it on first access.
    drawAt :: Word64 -> IO Double
    drawAt key = do
      (m, g) <- readIORef s
      H.lookup m key >>= \case
        Just x -> pure x
        Nothing -> do
          let (x, g') = R.random g
          H.insert m key x
          writeIORef s (m, g')
          pure x

-- | A probabilistic computation: a function from a 'Tree' of random draws to
-- a result.
newtype Prob a = Prob {runProb :: Tree -> a}

instance Functor Prob where
  fmap = liftM

instance Applicative Prob where
  pure = Prob . const
  (<*>) = ap

-- | Sequencing splits the tree: the first computation consumes the first
-- subtree and the continuation consumes the second, so the two never read
-- the same node.
instance Monad Prob where
  Prob f >>= g = Prob $ \t ->
    let (t1, t2) = split t
        Prob g' = g (f t1)
     in g' t2

-- | Return the draw stored at the root of the supplied tree.
uniform :: Prob Double
uniform = Prob $ \(Tree r _) -> r

-- | A weighted probabilistic computation: 'Prob' with a multiplicative
-- log-domain weight accumulated through a writer.
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
  deriving (Functor, Applicative, Monad)

-- | Multiply the current weight by a linear-domain factor.
--
-- The factor is clamped from below at machine epsilon before taking the
-- logarithm, so a zero (or negative) score yields a tiny finite weight
-- rather than @log 0@.
{-# INLINE score #-}
score :: Double -> Meas ()
score = scoreLog . Exp . log . max eps
  where
    -- machine epsilon, force compile time eval
    eps = $(TH.lift (2 * until ((== 1) . (1 +)) (/ 2) (1 :: Double)))

-- | Multiply the current weight by a factor already in the log domain.
{-# INLINE scoreLog #-}
scoreLog :: Log Double -> Meas ()
scoreLog = Meas . tell . Product

-- | Lift a 'Prob' computation into 'Meas' without changing the weight.
{-# INLINE sample #-}
sample :: Prob a -> Meas a
sample = Meas . lift

-- | Run a 'Meas' computation repeatedly against a tree, producing an
-- infinite lazy list of results paired with their accumulated weights.
-- Each run draws from a different subtree because of how 'Prob' sequencing
-- splits the tree.
{-# INLINE samples #-}
samples :: forall a. Meas a -> Tree -> [(a, Log Double)]
samples (Meas m) = map (second getProduct) . runProb f
  where
    f :: Prob [(a, Product (Log Double))]
    f = runWriterT m >>= \x -> (x :) <$> f