From 480b6e000afeec19aa65857232f1261d2b21de76 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Wed, 26 Feb 2025 07:37:19 +1100 Subject: output MML score --- bin/cluster.hs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'bin/cluster.hs') diff --git a/bin/cluster.hs b/bin/cluster.hs index d0328a6..46a1fa9 100644 --- a/bin/cluster.hs +++ b/bin/cluster.hs @@ -100,8 +100,7 @@ data Options = Options nsamples :: Int, mhfrac :: Double, input :: FilePath, - propsPath :: FilePath, - clusterPath :: FilePath + outputPath :: FilePath } main = run =<< execParser opts @@ -115,8 +114,7 @@ main = run =<< execParser opts <*> option auto (long "nsamples" <> short 'n' <> help "number of samples from posterior" <> value 100000 <> metavar "INT" <> showDefault) <*> option probability (long "mhfrac" <> short 'm' <> help "Metropolis-Hastings mutation probability" <> value 0.01 <> metavar "(0,1]" <> showDefault) <*> argument str (metavar "INPUT") - <*> argument str (metavar "PROPS") - <*> argument str (metavar "TREE") + <*> argument str (metavar "OUTPUTDIR") probability = eitherReader $ \arg -> case reads arg of [(r, "")] -> if r <= 1 && r > 0 then Right r else Left "mhfrac not a valid probability" @@ -137,9 +135,10 @@ run opts = do g = mkStdGen $ seed opts parsed <- compact $ map (map dbl . tail . words) lines hSetBuffering stdout NoBuffering - ((ps, cl), _) <- S.fold_ (\l r -> if mml l < mml r then l else r) (([[]], []), -1 / 0) id . takeWithProgress (nsamples opts) $ mh g (mhfrac opts) (model $ getCompact parsed) - writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps - writeFile (clusterPath opts) . unlines $ map show cl + m@((ps, cl), _) <- S.fold_ (\l r -> if mml l < mml r then l else r) (([[]], []), -1 / 0) id . takeWithProgress (nsamples opts) $ mh g (mhfrac opts) (model $ getCompact parsed) + writeFile (outputPath opts <> "/props") . unlines $ map (intercalate "," . map show) ps + writeFile (outputPath opts <> "/tree") . unlines $ map show cl + writeFile (outputPath opts <> "/mml") $ show (mml m) where mml ((ps, cl), lik) = sum' (sum' (log . (+ 1))) ps + sum' (log . (+ 1)) tab - sum' (ln . stirling) tab - ln lik where -- cgit v1.2.3