aboutsummaryrefslogtreecommitdiff
path: root/Math/LinProg/Types.hs
blob: d96fcb5148fe75926202a0a29a81cabde8668e8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
{-# LANGUAGE DeriveFunctor, FlexibleInstances, FlexibleContexts, UndecidableInstances #-}

module Math.LinProg.Types (
  LinExpr
  ,var
  ,vars
  ,varTerms
  ,getVar
  ,split
  ,LinProg
  ,LinProg'(..)
  ,obj
  ,(<:)
  ,(=:)
  ,(>:)
  ,eq
  ,leq
  ,geq
) where

import Data.List (nub)

import Control.Monad.Free
import Data.Functor.Foldable
import qualified Data.Map as M

-- | One layer of an expression tree. Recursive positions have type @a@, so
-- the full tree is obtained with 'Fix' and consumed with 'cata'.
data LinExpr' t v a
  = -- | A constant of type @t@.
    Lit t
  | -- | A variable named by a value of type @v@.
    Var v
  | -- | Sum of two subexpressions.
    Add a a
  | -- | Product of two subexpressions.
    Mul a a
  | -- | Negation of a subexpression.
    Negate a
  deriving (Show, Eq, Functor)

-- | An expression tree: 'LinExpr'' tied with 'Fix', holding constants of type
-- @t@ and variables of type @v@ combined by addition, multiplication and
-- negation. Nothing in the type enforces linearity: 'Mul' may multiply two
-- subterms that both contain variables.
type LinExpr t v = Fix (LinExpr' t v)

-- | An expression consisting of the single variable @v@.
var :: v -> LinExpr t v
var = Fix . Var

-- | Arithmetic on expressions builds trees instead of evaluating.
--
-- Subtraction is the class default (@a - b = a + negate b@), so it yields an
-- 'Add' over a 'Negate'. 'abs' and 'signum' have no tree constructor, so they
-- fail with an explanatory error instead of a bare 'undefined'.
instance Num t => Num (LinExpr t v) where
  a + b = Fix (Add a b)
  a * b = Fix (Mul a b)
  negate a = Fix (Negate a)
  fromInteger n = Fix (Lit (fromInteger n))
  abs _ = error "Math.LinProg.Types: abs is not supported on LinExpr"
  signum _ = error "Math.LinProg.Types: signum is not supported on LinExpr"

-- | The constant term of an expression: its value with every variable set
-- to 0.
consts :: Num t => LinExpr t v -> t
consts = cata $ \node -> case node of
  Lit c -> c
  Var _ -> 0
  Add l r -> l + r
  Mul l r -> l * r
  Negate e -> negate e

-- | Coefficient of the variable @target@ in an expression.
--
-- The expression is evaluated with @target@ set to 1 and every other variable
-- set to 0. The constant term ('consts', all variables 0) is then subtracted.
-- This equals the coefficient only when @target@ occurs linearly; for
-- @x * x@ the result is 1.
getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
getVar target expr = withTargetAtOne - consts expr
  where
    withTargetAtOne = cata evalAlg expr
    evalAlg node = case node of
      Var name
        | name == target -> 1
        | otherwise -> 0
      Lit c -> c
      Add l r -> l + r
      Mul l r -> l * r
      Negate e -> negate e

-- | Every variable occurrence in an expression, left to right.
--
-- Duplicates are kept: @vars (x + x)@ lists @x@ twice.
vars :: LinExpr t v -> [v]
vars = cata collect
  where
    collect (Var name) = [name]
    collect (Add l r) = l ++ r
    collect (Mul l r) = l ++ r
    collect (Negate e) = e
    collect (Lit _) = []

-- | The variable part of an expression, rebuilt as a right-nested sum of
-- @coefficient * variable@ terms. There is one term per distinct variable,
-- in order of first occurrence. The constant term is omitted (see 'consts').
-- An expression with no variables yields the literal 0.
--
-- Coefficients come from 'getVar', so the result is only faithful for
-- expressions that are linear in their variables.
varTerms :: (Num t, Eq v) => LinExpr t v -> LinExpr t v
varTerms expr = sumTerms terms
  where
    -- 'vars' lists every occurrence, but 'getVar' already returns a
    -- variable's total coefficient. Each variable must therefore appear
    -- once, or @x + x@ would become @2*x + 2*x@.
    distinctVars = nub (vars expr)
    terms =
      [ Fix (Mul (Fix (Lit (getVar v expr))) (Fix (Var v)))
      | v <- distinctVars
      ]
    -- Right-nested sum @t1 + (t2 + (...))@; the empty sum is the literal 0.
    sumTerms [] = Fix (Lit 0)
    sumTerms ts = foldr1 (\t rest -> Fix (Add t rest)) ts

-- | Split an expression into @(variable part, constant part)@, using
-- 'varTerms' for the former and 'consts' for the latter.
split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
split expr = (terms, constant)
  where
    terms = varTerms expr
    constant = consts expr

-- | Render an expression fully parenthesised:
-- sums as @(l+r)@, products as @(l×r)@ and negation as @(-e)@.
-- Literals and variables are rendered with 'show'.
prettyPrint :: (Show t, Show v) => LinExpr t v -> String
prettyPrint = cata render
  where
    render (Lit c) = show c
    render (Var x) = show x
    render (Add a b) = concat ["(", a, "+", b, ")"]
    render (Mul a b) = concat ["(", a, "×", b, ")"]
    -- Needed for any expression built with 'negate' or '-' (the Num default
    -- @a - b = a + negate b@); without it rendering crashes on such input.
    render (Negate a) = concat ["(-", a, ")"]

-- Free monad for describing linear programs: an objective plus equality and
-- less-than-or-equal constraints.

-- | One instruction of a linear program. The final field @a@ is the
-- continuation slot threaded through by the 'Free' monad.
data LinProg' t v a
  = -- | An objective expression.
    Objective (LinExpr t v) a
  | -- | @EqConstraint l r@: the constraint @l = r@.
    EqConstraint (LinExpr t v) (LinExpr t v) a
  | -- | @LeqConstraint l r@: the constraint @l <= r@.
    LeqConstraint (LinExpr t v) (LinExpr t v) a
  deriving (Show, Eq, Functor)

-- | A linear program under construction: the 'Free' monad over 'LinProg''
-- instructions. Its steps are produced by 'obj', 'eq', 'leq' and 'geq'.
type LinProg t v = Free (LinProg' t v)

-- | Record @e@ as the program's objective. No optimisation direction is
-- stored in the instruction.
obj :: MonadFree (LinProg' t v) m => LinExpr t v -> m ()
obj e = liftF (Objective e ())

-- | @eq l r@ adds the constraint @l = r@.
eq :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
eq l r = liftF (EqConstraint l r ())

-- | @leq l r@ adds the constraint @l <= r@.
leq :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
leq l r = liftF (LeqConstraint l r ())

-- | @geq l r@ adds the constraint @l >= r@. There is no greater-or-equal
-- instruction, so it is stored as @LeqConstraint r l@ (i.e. @r <= l@).
geq :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
geq l r = liftF (LeqConstraint r l ())

-- Without a fixity declaration these operators default to infixl 9 and bind
-- tighter than '+', so @x + y <: 10@ would parse as @x + (y <: 10)@.
-- 'infix 4' (the fixity of '==' and '<=') makes it parse as @(x + y) <: 10@.
infix 4 =:, <:, >:

-- | Infix forms of the constraint builders: @a =: b@ is @eq a b@ (a = b),
-- @a <: b@ is @leq a b@ (a <= b, non-strict) and @a >: b@ is @geq a b@
-- (a >= b, non-strict).
(=:), (<:), (>:)
  :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
a =: b = eq a b
a <: b = leq a b
a >: b = geq a b