diff options
-rw-r--r-- | bin/cluster.hs | 14 | ||||
-rw-r--r-- | bin/draw.hs | 15 |
2 files changed, 13 insertions, 16 deletions
diff --git a/bin/cluster.hs b/bin/cluster.hs index 6c8ea52..d0328a6 100644 --- a/bin/cluster.hs +++ b/bin/cluster.hs @@ -1,6 +1,5 @@ -{-# LANGUAGE ViewPatterns #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ViewPatterns #-} module Main where @@ -13,12 +12,12 @@ import GHC.Compact import Numeric.Log hiding (sum) import Options.Applicative import PPL hiding (binom) -import System.Random (mkStdGen, setStdGen, StdGen, random, split) import qualified Streaming as S +import Streaming.Prelude (Of, Stream, each, fold, yield) import qualified Streaming.Prelude as S -import Streaming.Prelude (each, fold, Of, Stream, yield) -import System.IO ( hSetBuffering, BufferMode(NoBuffering), stdout) +import System.IO (BufferMode (NoBuffering), hSetBuffering, stdout) import System.ProgressBar +import System.Random (StdGen, mkStdGen, random, setStdGen, split) cumsum = scanl1 (+) @@ -123,7 +122,6 @@ main = run =<< execParser opts [(r, "")] -> if r <= 1 && r > 0 then Right r else Left "mhfrac not a valid probability" _ -> Left "mhfrac not a valid probability" - takeWithProgress :: S.MonadIO m => Int -> Stream (Of a) m r -> Stream (Of a) m () takeWithProgress n str = do pb <- S.liftIO $ newProgressBar defStyle 10 (Progress 0 n ()) @@ -139,7 +137,7 @@ run opts = do g = mkStdGen $ seed opts parsed <- compact $ map (map dbl . tail . words) lines hSetBuffering stdout NoBuffering - ((ps, cl), _) <- S.fold_ (\l r -> if mml l < mml r then l else r) (([[]],[]), -1/0) id . takeWithProgress (nsamples opts) $ mh g (mhfrac opts) (model $ getCompact parsed) + ((ps, cl), _) <- S.fold_ (\l r -> if mml l < mml r then l else r) (([[]], []), -1 / 0) id . takeWithProgress (nsamples opts) $ mh g (mhfrac opts) (model $ getCompact parsed) writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps writeFile (clusterPath opts) . unlines $ map show cl where @@ -147,5 +145,3 @@ run opts = do where tab = tabulate cl sum' f = sum . map f - - diff --git a/bin/draw.hs b/bin/draw.hs index 4e19992..89658f6 100644 --- a/bin/draw.hs +++ b/bin/draw.hs @@ -1,18 +1,19 @@ {-# LANGUAGE ViewPatterns #-} + module Main where -import Options.Applicative +import Data.List import Data.List.Split import qualified Data.Map as M -import Data.List +import Options.Applicative type Props = [[Double]] + type Nodes = [Int] -- Command line args data Options = Options - { - nthr :: Int, + { nthr :: Int, propsPath :: FilePath, clusterPath :: FilePath, dotPath :: FilePath @@ -25,10 +26,10 @@ data Tree = Tree [Double] Int [Tree] children (Tree _ _ cs) = cs addChild :: Tree -> Tree -> Tree -(Tree a b cs) `addChild` child = Tree a b (child:cs) +(Tree a b cs) `addChild` child = Tree a b (child : cs) buildTree :: Props -> Nodes -> Tree -buildTree (transpose -> ps) cl = merge n [Tree (ps !! i) (tab !! i) [] | i <- [0..n]] +buildTree (transpose -> ps) cl = merge n [Tree (ps !! i) (tab !! i) [] | i <- [0 .. n]] where tab = tabulate cl n = length tab - 1 @@ -38,7 +39,7 @@ buildTree (transpose -> ps) cl = merge n [Tree (ps !! i) (tab !! i) [] | i <- [0 merge 0 ts = head ts merge i ts = let j = parent i - in merge (i-1) $ take j ts <> [(ts !! j) `addChild` (ts !! i)] <> drop (j+1) ts + in merge (i - 1) $ take j ts <> [(ts !! j) `addChild` (ts !! i)] <> drop (j + 1) ts -- Prune a tree according to a node threshold prune th = fix prune' |