diff options
Diffstat (limited to 'bin/cluster.hs')
-rw-r--r-- | bin/cluster.hs | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs index 6b0d917..4c182a1 100644 --- a/bin/cluster.hs +++ b/bin/cluster.hs @@ -1,4 +1,6 @@ {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE LambdaCase #-} module Main where @@ -11,7 +13,12 @@ import GHC.Compact import Numeric.Log hiding (sum) import Options.Applicative import PPL hiding (binom) -import System.Random (mkStdGen, setStdGen) +import System.Random (mkStdGen, setStdGen, StdGen, random, split) +import qualified Streaming as S +import qualified Streaming.Prelude as S +import Streaming.Prelude (each, fold, Of, Stream, yield) +import System.IO ( hSetBuffering, BufferMode(NoBuffering), stdout) +import System.ProgressBar cumsum = scanl1 (+) @@ -116,12 +123,23 @@ main = run =<< execParser opts [(r, "")] -> if r <= 1 && r > 0 then Right r else Left "mhfrac not a valid probability" _ -> Left "mhfrac not a valid probability" + +takeWithProgress :: S.MonadIO m => Int -> Stream (Of a) m r -> Stream (Of a) m () +takeWithProgress n str = do + pb <- S.liftIO $ newProgressBar defStyle 10 (Progress 0 n ()) + S.mapM (update pb) $ S.take n str + where + update pb x = do + S.liftIO $ incProgress pb 1 + pure x + run opts = do - setStdGen . mkStdGen $ seed opts (hdr : lines) <- lines <$> readFile (input opts) let dbl = round . read :: String -> Int + g = mkStdGen $ seed opts parsed <- compact $ map (map dbl . tail . words) lines - ((ps, cl), _) <- foldl1' (\a c -> if mml a < mml c then a else c) . take (nsamples opts) <$> mh (mhfrac opts) 0.5 (model $ getCompact parsed) + hSetBuffering stdout NoBuffering + ((ps, cl), _) <- S.fold_ (\l r -> if mml l < mml r then l else r) (([[]],[]), -1/0) id . takeWithProgress (nsamples opts) $ mh g (mhfrac opts) 0.5 (model $ getCompact parsed) writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps writeFile (clusterPath opts) . unlines $ map show cl where @@ -129,3 +147,5 @@ run opts = do where tab = tabulate cl sum' f = sum . map f + + |