{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}

module Main where

import Control.Monad
import Data.Fixed (mod')
import Data.Foldable (toList)
import Data.List
import qualified Data.Map as M
import GHC.Compact
import Numeric.Log hiding (sum)
import Options.Applicative
import PPL hiding (binom)
import System.IO (BufferMode (NoBuffering), hSetBuffering, stdout)
import System.ProgressBar
import System.Random (StdGen, mkStdGen, random, setStdGen, split)
import qualified Streaming as S
import Streaming.Prelude (Of, Stream, each, fold, yield)
import qualified Streaming.Prelude as S

-- | Running (prefix) sums: @cumsum [a, b, c] == [a, a + b, a + b + c]@.
cumsum :: Num a => [a] -> [a]
cumsum = scanl1 (+)

-- | Index of the first element of @xs@ satisfying @f@.
--
-- Raises an error when a finite list has no matching element. On an infinite
-- list with no match it does not terminate.
first :: Num i => (a -> Bool) -> [a] -> i
first f xs =
  maybe (error "first: no element satisfies the predicate") fromIntegral $
    findIndex f xs

-- | Stirling's approximation to @n!@, returned in the log domain:
-- @log n! ~ n (log n - 1) + log (sqrt (2 pi)) + (log n) / 2@.
-- @0!@ and @1!@ are returned exactly as 1.
stirling 0 = 1
stirling 1 = 1
stirling n = Exp $ n * (log n - 1) + log (sqrt (2 * pi)) + log n / 2

-- | Lift an integral value into the log domain.
logFromInt :: (Integral i, Floating a) => i -> Log a
logFromInt = logFrom . fromIntegral

-- | Lift a (non-negative) value into the log domain.
logFrom :: Floating a => a -> Log a
logFrom = Exp . log

-- | Binomial probability of @k@ successes in @n@ trials with success
-- probability @p@, in the log domain. The binomial coefficient uses the
-- 'stirling' approximation of the factorials.
--
-- The power terms are combined directly as @k log p + (n - k) log (1 - p)@.
-- This avoids raising a 'Log' to a 'Log' power and avoids log-domain
-- subtraction for @n - k@, both of which add rounding error.
binom :: Int -> Int -> Double -> Log Double
binom n k p = n `choose` k * Exp (k' * log p + (n' - k') * log (1 - p))
  where
    n' = fromIntegral n :: Double
    k' = fromIntegral k :: Double
    choose (fromIntegral -> a) (fromIntegral -> b) =
      stirling a / stirling b / stirling (a - b)

-- | Infinite binary trees: every node holds a value and exactly two subtrees.
data Tree a = Tree a (Tree a) (Tree a)
  deriving (Show)

-- | Folds visit nodes in breadth-first order: the root, then its two children
-- left to right, then the four grandchildren, and so on. The fold is lazy
-- enough for 'toList' to stream an infinite tree.
instance Foldable Tree where
  foldMap f t = levels [t]
    where
      -- Fold one whole level, then move on to the next one. Building each
      -- level with 'concatMap' avoids the quadratic cost of repeatedly
      -- appending to the end of a list used as a queue.
      levels [] = mempty
      levels level =
        foldMap (\(Tree a _ _) -> f a) level <> levels (concatMap children level)
      children (Tree _ l r) = [l, r]

-- | Group a list into consecutive pairs:
-- @pairs [a, b, c, d] == [(a, b), (c, d)]@. Odd-length input is an error.
pairs :: [a] -> [(a, a)]
pairs (a : b : rs) = (a, b) : pairs rs
pairs [] = []
pairs _ = error "unexpected number of columns, expecting pairs"

-- Infinite trees from infinite lists.
-- NB: the split used below deals the elements alternately to the left and
-- right subtrees. Folding the resulting tree back with 'toList' therefore does
-- not reproduce the original list order. That does not matter here, because
-- we only ever unfold iid random uniforms.
-- | Build an infinite binary tree from an infinite list. The head becomes the
-- root, and the remaining elements are dealt alternately to the left and
-- right subtrees, which are built recursively in the same way.
treeFromList :: [a] -> Tree a
treeFromList [] = error "treeFromList: expected an infinite list"
treeFromList (x : xs) = Tree x (treeFromList lpart) (treeFromList rpart)
  where
    (lpart, rpart) = unzip $ pairs xs

-- | Constrain a tree so that each pair of children sums to its parent's value.
-- The root keeps its value, and every level below is rescaled recursively.
normTree :: Tree Double -> Tree Double
normTree (Tree x (Tree u l r) (Tree v l' r')) =
  Tree x (normTree $ Tree (s * u) l r) (normTree $ Tree (s * v) l' r')
  where
    s = x / (u + v)

-- | One random tree of node values, flattened breadth-first. The iid uniforms
-- are arranged with 'treeFromList' and rescaled with 'normTree'.
drawTreeProbs = toList . normTree . treeFromList <$> iid uniform

-- | Clustering model over the rows of @xs@.
--
-- Each row is a flat list of @(c, d)@ column pairs. Every row is assigned a
-- cluster index. Pair @i@ of a row is scored with @binom d c p@, where @p@ is
-- the entry at that cluster index of the @i@-th tree drawn by 'drawTreeProbs'.
--
-- Returns two things:
--
-- * the first @k@ values of each of the first @n@ trees, where @k@ is the
--   largest cluster index used plus one and @n@ is the number of column pairs
--   per row;
-- * the cluster index of every row.
model :: [[Int]] -> Meas ([[Double]], [Int])
model xs = do
  (ps, params, clusters) <- sample $ do
    -- Sample hyperparameters
    a <- bounded 1 10
    -- CRP style: @dir@ holds cumulative weights, so a uniform @r@ picks the
    -- first index whose cumulative weight reaches it.
    dir <- cumsum <$> dirichletProcess a
    rs <- iid uniform
    ps <- iid drawTreeProbs
    let clusters = map (\r -> first (>= r) dir) rs
        params = map (transpose ps !!) clusters
    pure (ps, params, clusters)
  mapM_ scoreLog $ zipWith likelihood params xs
  let cls = take (length xs) clusters
      -- Largest cluster index used, plus one (0 when there are no rows).
      k = 1 + maximum ((-1) : cls)
      -- Number of column pairs per row (0 when there are no rows).
      n = case xs of
        row : _ -> length row `div` 2
        [] -> 0
  pure (map (take k) $ take n ps, cls)
  where
    likelihood ps cnts = product $ zipWith go ps (pairs cnts)
      where
        go p (c, d) = binom d c p

-- | Count the occurrences of each distinct element. The counts are returned in
-- ascending order of the element.
tabulate :: (Ord k, Num c) => [k] -> [c]
tabulate xs = M.elems $ M.fromListWith (+) [(x, 1) | x <- xs]

-- | Command line args
data Options = Options
  { seed :: Int
  , nsamples :: Int
  , mhfrac :: Double
  , input :: FilePath
  , propsPath :: FilePath
  , clusterPath :: FilePath
  }

main :: IO ()
main = run =<< execParser opts
  where
    opts =
      info
        (parser <**> helper)
        ( fullDesc
            <> progDesc "Infer a phylogeny from SNV calls in multiple samples"
            <> header "phylogeny - Bayesian phylogeny inference"
        )
    parser =
      Options
        <$> option auto (long "seed" <> short 's' <> help "random seed" <> showDefault <> value 42 <> metavar "INT")
        <*> option auto (long "nsamples" <> short 'n' <> help "number of samples from posterior" <> value 100000 <> metavar "INT" <> showDefault)
        <*> option probability (long "mhfrac" <> short 'm' <> help "Metropolis-Hastings mutation probability" <> value 0.3 <> metavar "(0,1]" <> showDefault)
        <*> argument str (metavar "INPUT")
        <*> argument str (metavar "PROPS")
        <*> argument str (metavar "TREE")
    -- Accept only numbers in the half-open interval (0, 1].
    probability :: ReadM Double
    probability = eitherReader $ \arg -> case reads arg of
      [(r, "")] | r > 0 && r <= 1 -> Right r
      _ -> Left "mhfrac not a valid probability"

-- | Take the first @n@ elements of a stream, advancing a terminal progress bar
-- (total @n@) once for each element that passes through.
takeWithProgress :: S.MonadIO m => Int -> Stream (Of a) m r -> Stream (Of a) m ()
takeWithProgress n str = do
  pb <- S.liftIO $ newProgressBar defStyle 10 (Progress 0 n ())
  S.mapM (tick pb) $ S.take n str
  where
    tick pb x = x <$ S.liftIO (incProgress pb 1)

-- | Read the input table and run 'mh' on 'model' for @nsamples@ steps with a
-- progress bar. Then write out the visited state with the lowest 'mml' score:
-- tree values as comma-separated rows to @propsPath@, and one cluster index per
-- line to @clusterPath@.
run :: Options -> IO ()
run opts = do
  -- The first line is a header. In every other line the first column is
  -- skipped, and the remaining cells are read as numbers and rounded to Int.
  (_header : rows) <- lines <$> readFile (input opts)
  let readCell :: String -> Int
      readCell = round . (read :: String -> Double)
      g = mkStdGen $ seed opts
      -- Skip blank lines instead of crashing on an empty row.
      table = map (map readCell . drop 1) . filter (not . null) $ map words rows
  parsed <- compact table
  hSetBuffering stdout NoBuffering
  (_, (ps, cl)) <-
    S.fold_ keepBest (1 / 0, ([[]], [])) id
      . takeWithProgress (nsamples opts)
      $ mh g (mhfrac opts) 0.5 (model $ getCompact parsed)
  writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps
  writeFile (clusterPath opts) . unlines $ map show cl
  where
    -- Keep the lower-scoring state, caching its score so it is not recomputed
    -- on every step. On ties the newer state wins. The placeholder's score is
    -- +Infinity as a Double, so the first real state always replaces it. (A
    -- @Log Double@ literal such as @-1/0@ is NaN, because log-domain 'negate'
    -- of a non-zero value yields NaN.)
    keepBest best@(bestScore, _) candidate@(state, _)
      | bestScore < candidateScore = best
      | otherwise = (candidateScore, state)
      where
        candidateScore = mml candidate
    -- Score to minimise (the name suggests a minimum-message-length criterion).
    -- It adds log (1 + x) over all tree values and over all cluster sizes,
    -- subtracts the log of the 'stirling'-approximated factorial of each
    -- cluster size, and subtracts ln of the weight @lik@ reported by 'mh'.
    mml ((ps, cl), lik) =
      sum' (sum' (log . (+ 1))) ps
        + sum' (log . (+ 1)) tab
        - sum' (ln . stirling) tab
        - ln lik
      where
        tab = tabulate cl
        sum' f = sum . map f