aboutsummaryrefslogtreecommitdiff
path: root/bin/cluster.hs
blob: 6bcac297e7b53f4aabe55fa91bb6f8f4f80ab817 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
{-# LANGUAGE ViewPatterns #-}

module Main where

import Data.Fixed (mod')
import Data.Foldable (toList)
import Data.List
import qualified Data.Map as M
import Numeric.Log hiding (sum)
import Options.Applicative
import PPL
import System.Random (mkStdGen, setStdGen)

-- | Running totals of a list: element @i@ of the result is the sum of the
-- first @i + 1@ inputs. The empty list maps to the empty list, and the
-- result is produced lazily, so infinite inputs are fine.
--
-- The explicit signature keeps this polymorphic; without it the point-free
-- binding falls under the monomorphism restriction and its type is fixed by
-- whichever use site happens to exist.
cumsum :: Num a => [a] -> [a]
cumsum = scanl1 (+)

-- | Zero-based position of the first element of @xs@ that satisfies @f@.
--
-- The list is scanned lazily, so an infinite list is fine as long as some
-- element matches. If a finite list contains no match this raises an error
-- naming the function instead of the opaque "empty list" failure of 'head'.
first :: (Num b, Enum b) => (a -> Bool) -> [a] -> b
first f xs = case [i | (x, i) <- zip xs [0 ..], f x] of
  (i : _) -> i
  [] -> error "first: no element satisfies the predicate"

-- | Stirling-style approximation of @log n!@, namely @n * log n - n@.
--
-- @n == 0@ is special-cased to @0@ (the exact value of @log 0!@): evaluating
-- the formula there gives @0 * log 0 = 0 * (-Infinity) = NaN@, which would
-- otherwise propagate into anything built from it.
stirling :: (Eq a, Floating a) => a -> a
stirling n
  | n == 0 = 0
  | otherwise = n * log n - n

-- | Poisson-shaped likelihood of observing count @n@ under rate @lambda@,
-- built from log-domain ('Exp') factors:
-- @lambda ^ k * exp (-lambda) / exp (stirling k)@, i.e. the Poisson mass
-- function with @log k!@ replaced by 'stirling'.
pois lambda n = power * decay / norm
  where
    -- The count converted to the same floating type as the rate.
    k = fromIntegral n
    -- lambda raised to the power k, with both operands lifted to log domain.
    power = Exp (log lambda) ** Exp (log k)
    decay = Exp (negate lambda)
    norm = Exp (stirling k)

-- | (Possibly infinite) binary trees. Fields are lazy, so a tree may be
-- unbounded as long as it is only consumed incrementally.
data Tree a = Empty | Tree a (Tree a) (Tree a)
  deriving (Show)

-- | Folds visit nodes in breadth-first (level) order, left to right,
-- skipping 'Empty' nodes. With a lazy monoid the traversal is productive, so
-- a prefix of the elements can be taken from an infinite tree.
instance Foldable Tree where
  foldMap f t = levels [t]
    where
      -- Handle one whole level at a time. This visits nodes in the same
      -- order as a FIFO queue would, but avoids appending children to the
      -- end of a list on every step (which costs time proportional to the
      -- queue length each time).
      levels [] = mempty
      levels ts =
        foldMap f [a | Tree a _ _ <- ts]
          <> levels (concat [[l, r] | Tree _ l r <- ts])

-- | Build a tree from a list: the head becomes the root and the remaining
-- elements are dealt out alternately to the left and right subtrees. An
-- infinite list yields an infinite tree; a finite list yields a finite tree
-- (subtrees become 'Empty' once the elements run out).
--
-- NB: folding the tree does not give back the original list order; it's
-- harder to partition a list so that it folds back to equivalence. It
-- doesn't really matter here since we're only unfolding random uniforms
-- anyway.
treeFromList :: [a] -> Tree a
treeFromList [] = Empty
treeFromList (x : xs) = Tree x (treeFromList lpart) (treeFromList rpart)
  where
    (lpart, rpart) = deal xs
    -- Lazily split a list into its elements at even and odd positions. Only
    -- two cells are inspected per step, so this is safe on infinite lists.
    deal (a : b : rest) = (a : ls, b : rs)
      where
        (ls, rs) = deal rest
    -- Zero or one element left: whatever remains goes to the left.
    deal rest = (rest, [])

-- | Constrain a tree so that the values of each node's two children sum to
-- the node's own value. The root value is kept as is; going down the tree,
-- both children of a node are rescaled by a common factor.
--
-- Total on finite trees as well: 'Empty' stays 'Empty', and an 'Empty'
-- child counts as zero when computing the scale factor (so a lone child
-- receives the whole parent value). If the children's values sum to zero
-- the factor is not finite, exactly as the plain formula dictates.
normTree :: Tree Double -> Tree Double
normTree Empty = Empty
normTree (Tree x l r) = Tree x (rescale l) (rescale r)
  where
    -- Factor that makes the two children's values add up to x. It is only
    -- evaluated if at least one child is a real node.
    s = x / (rootValue l + rootValue r)
    rootValue Empty = 0
    rootValue (Tree v _ _) = v
    -- Scale the child's own value first, then normalise below it so that
    -- its descendants are constrained by the already-scaled value.
    rescale Empty = Empty
    rescale (Tree v l' r') = normTree $ Tree (s * v) l' r'

-- | Draw a list of values by arranging the draws from @iid uniform@ into a
-- tree ('treeFromList'), normalising it ('normTree') and flattening it in
-- the tree's fold order ('toList').
-- NOTE(review): @iid uniform@ presumably yields an infinite stream of
-- uniform draws; confirm against PPL.
drawTreeProbs = fmap (\us -> toList (normTree (treeFromList us))) (iid uniform)

-- | Clustering model over rows of count/depth column pairs.
--
-- Every row of @xs@ is read as consecutive @(count, depth)@ pairs. Each row
-- is assigned a cluster index, and the row's likelihood is scored from the
-- per-pair parameters of that cluster.
--
-- Returns the parameter table restricted to what the data actually touches
-- (one row per count/depth pair, one column per cluster index up to the
-- largest one in use) together with the cluster index of every input row.
-- An empty input yields empty results.
--
-- NOTE(review): the exact distributions behind @bounded@,
-- @dirichletProcess@, @uniform@ and @iid@ come from PPL and are not visible
-- here.
model :: [[Int]] -> Meas ([[Double]], [Int])
model xs = do
  (ps, params, clusters) <- sample $ do
    -- Sample hyperparameters
    a <- bounded 1 10

    -- CRP style: each row draws r and picks the first cluster whose
    -- cumulative weight reaches it.
    dir <- cumsum <$> dirichletProcess a
    rs <- iid uniform
    ps <- iid drawTreeProbs
    let clusters = map (\r -> first (>= r) dir) rs
        -- ps is indexed [pair][cluster]; transposing gives, for a cluster,
        -- its parameter for every count/depth pair.
        params = map (transpose ps !!) clusters
    pure (ps, params, clusters)
  mapM_ scoreLog $ zipWith likelihood params xs
  let cls = take (length xs) clusters
      -- Number of cluster indices in use (none when there are no rows).
      k = if null cls then 0 else maximum cls + 1
      -- Rows hold count/depth pairs, so the number of parameter rows that
      -- the likelihood ever touches is half the column count. Taking the
      -- full column count would return extra rows that were never scored
      -- against any data.
      nPairs = case xs of
        [] -> 0
        (row : _) -> length row `div` 2
  pure (map (take k) $ take nPairs ps, cls)
  where
    -- Likelihood of one row: product over its (count, depth) pairs.
    likelihood ps cnts = product $ zipWith go ps (pairs cnts)
      where
        -- The better of two Poisson rates: depth * p and half of that.
        go p (c, d) = max (pois (fromIntegral d * p) c) (pois (fromIntegral d * p / 2) c)
        pairs (a : b : rs) = (a, b) : pairs rs
        pairs [] = []
        pairs _ = error "unexpected number of columns, expecting count/depth pairs"

-- | Settings collected from the command line.
data Options = Options
  { seed :: Int -- ^ seed for the global random number generator
  , nsamples :: Int -- ^ how many sampler outputs to examine
  , mhfrac :: Double -- ^ first argument handed to the @mh@ sampler
  , input :: FilePath -- ^ file the input table is read from
  , propsPath :: FilePath -- ^ file the parameter table is written to
  , clusterPath :: FilePath -- ^ file the cluster assignments are written to
  }

-- | Entry point: parse the command line into 'Options' and hand over to
-- 'run'. Invalid arguments are rejected by the parser before any work is
-- done.
main :: IO ()
main = run =<< execParser opts
  where
    opts =
      info
        (parser <**> helper)
        (fullDesc <> progDesc "Infer a phylogeny from SNV calls in multiple samples" <> header "phylogeny - Bayesian phylogeny inference")
    parser =
      Options <$> option auto (long "seed" <> short 's' <> help "random seed" <> showDefault <> value 42 <> metavar "INT")
        <*> option positive (long "nsamples" <> short 'n' <> help "number of samples from posterior" <> value 100000 <> metavar "INT" <> showDefault)
        <*> option probability (long "mhfrac" <> short 'm' <> help "Metropolis-Hastings mutation probability" <> value 0.3 <> metavar "(0,1]" <> showDefault)
        <*> argument str (metavar "INPUT")
        <*> argument str (metavar "PROPS")
        <*> argument str (metavar "TREE")

    -- At least one sample is needed: 'run' picks the best of the samples
    -- drawn, which is undefined for an empty selection.
    positive = eitherReader $ \arg -> case reads arg of
      [(n, "")] | n > (0 :: Int) -> Right n
      _ -> Left "nsamples must be a positive integer"

    -- Accepts a number in the half-open interval (0, 1].
    probability = eitherReader $ \arg -> case reads arg of
      [(r, "")] -> if r <= 1 && r > 0 then Right r else Left "mhfrac not a valid probability"
      _ -> Left "mhfrac not a valid probability"

-- | Run the sampler on the input table and write out the best sample.
--
-- The input's first line is a header and each following line starts with a
-- label column; both are dropped, and the remaining fields are parsed as
-- numbers and rounded to 'Int'. Of the first @nsamples@ outputs of @mh@,
-- the one with the smallest 'mml' score is kept; its parameter table goes to
-- @propsPath@ (comma-separated) and its cluster assignments to
-- @clusterPath@ (one per line).
--
-- Fails with a descriptive 'IOError' when the input is empty, a row is
-- blank, a field is not a number, or no samples were requested.
run :: Options -> IO ()
run opts = do
  setStdGen . mkStdGen $ seed opts
  contents <- readFile (input opts)
  parsed <- case lines contents of
    [] -> ioError . userError $ "empty input file: " ++ input opts
    -- Data rows start on line 2, after the header.
    (_header : rows) -> either (ioError . userError) pure $ traverse parseRow (zip [2 :: Int ..] rows)
  samples <- take (nsamples opts) <$> mh (mhfrac opts) (model parsed)
  case samples of
    [] -> ioError $ userError "no posterior samples drawn; nsamples must be positive"
    (s : rest) -> do
      -- Strict left fold keeping the lower-scoring sample; on ties the later
      -- sample wins.
      let ((ps, cl), _) = foldl' (\a c -> if mml a < mml c then a else c) s rest
      writeFile (propsPath opts) . unlines $ map (intercalate "," . map show) ps
      writeFile (clusterPath opts) . unlines $ map show cl
  where
    -- Drop the leading label column and parse every remaining field.
    parseRow (n, row) = case words row of
      [] -> Left ("line " ++ show n ++ ": empty row")
      (_label : fields) -> traverse (parseField n) fields
    -- Fields are read as floating point and rounded to the nearest integer.
    parseField n w = case reads w :: [(Double, String)] of
      [(v, "")] -> Right (round v)
      _ -> Left ("line " ++ show n ++ ": cannot parse number " ++ show w)
    -- Score used to rank samples (smaller is better): a log (p + 1) term for
    -- every parameter, a term over the cluster sizes, minus the log of the
    -- sample's weight.
    mml ((ps, cl), lik) = sum' (sum' (log . (+ 1))) ps + sum' (log . (+ 1)) tab - sum' stirling tab - ln lik
      where
        -- Number of rows assigned to each cluster in use.
        tab = M.elems $ M.fromListWith (+) [(c, 1) | c <- cl]
        sum' f = sum . map f