aboutsummaryrefslogtreecommitdiff
path: root/Math/LinProg/Types.hs
blob: 4819dd3ea83a00d47bb23dfe0f9c0b4212e10fbe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
{-# LANGUAGE DeriveFunctor, FlexibleInstances, FlexibleContexts, UndecidableInstances #-}
{-|
Module      : Math.LinProg.Types
Description : Base types for equations and optimisation monad
Copyright   : (c) Justin Bedő, 2014
License     : BSD
Maintainer  : cu@cua0.org
Stability   : experimental

This module defines the base types for representing equations and linear
programs.  The linear program is created as a free monad, and equations as an
AST.  Note that expressions are assumed to be linear expressions and hence
there is no explicit checking for higher order terms.
-}
module Math.LinProg.Types (
  LinExpr
  ,LinExpr'(..)
  ,var
  ,vars
  ,varTerms
  ,getVar
  ,split
  ,LinProg
  ,LinProg'(..)
  ,obj
  ,(<:)
  ,(=:)
  ,(>:)
  ,bin
  ,int
) where

import Data.Functor.Foldable
import Control.Monad.Free
import qualified Data.HashMap.Strict as M
import Test.QuickCheck
import Control.Applicative
import Data.List
import Data.Hashable

-- | Base AST for expressions.  Expressions have factors of type t and
-- variables referenced by ids of type v.  The last parameter a marks the
-- recursive positions; 'LinExpr' ties the knot with 'Fix'.
data LinExpr' t v a =
  Lit !t
  -- ^ A constant.
  | Var !v
  -- ^ A variable with an implicit weight of one.
  | Wvar !t !v
  -- ^ A weighted variable: @Wvar w v@ stands for @w * v@.
  | Add !a !a
  -- ^ Sum of two sub-expressions.
  | Mul !a !a
  -- ^ Product of two sub-expressions.  Only products with a constant side
  -- are linear; 'varTerms' rejects any product it cannot reduce.
  | Negate !a
  -- ^ Negation of a sub-expression.
  deriving (Show, Eq, Functor)

-- | A linear expression: the fixed point of 'LinExpr''.
type LinExpr t v = Fix (LinExpr' t v)

-- | Creates a new variable for reference in equations.
--
-- The result is a bare 'Var' node; 'bin' and 'int' expect exactly this shape.
var :: v -> LinExpr t v
var = Fix . Var

-- | For convenient notation, expressions are declared as instances of 'Num'.
-- Arithmetic only builds AST nodes; nothing is simplified here.  Subtraction
-- uses the class default @a - b = a + negate b@, i.e. an 'Add' of a 'Negate'.
-- Linear expressions cannot implement absolute value or sign functions, so
-- 'abs' and 'signum' raise a descriptive error if they are ever evaluated.
instance Num t => Num (LinExpr t v) where
  a * b = Fix (Mul a b)
  a + b = Fix (Add a b)
  negate a = Fix (Negate a)
  fromInteger a = Fix (Lit (fromInteger a))
  abs _ = error "Math.LinProg.Types: abs is not supported for linear expressions"
  signum _ = error "Math.LinProg.Types: signum is not supported for linear expressions"

-- | Linear expressions can also be instances of 'Fractional'.
--
-- Division is only linear when the divisor is a constant.  So @a / b@ first
-- folds @b@ down to a number @c@ and then multiplies @a@ by @recip c@.
-- Dividing by an expression that mentions a variable raises an error.
instance Fractional t => Fractional (LinExpr t v) where
  a / b =
      case cata constant b of
        Just c -> Fix (Mul a (Fix (Lit (recip c))))
        Nothing -> error "Math.LinProg.Types: division by a non-constant expression is not linear"
    where
      -- Evaluate a variable-free expression to its value, or Nothing as soon
      -- as any variable occurs.
      constant (Lit c) = Just c
      constant (Var _) = Nothing
      constant (Wvar _ _) = Nothing
      constant (Add x y) = (+) <$> x <*> y
      constant (Mul x y) = (*) <$> x <*> y
      constant (Negate x) = negate <$> x
  fromRational a = Fix (Lit (fromRational a))

-- | Reduce a linear expression down to its constant term.  The expression is
-- evaluated with every variable node ('Var' and 'Wvar') treated as zero.
consts :: Num t => LinExpr t v -> t
consts = cata $ \node -> case node of
  Lit c -> c
  Add l r -> l + r
  Mul l r -> l * r
  Negate e -> negate e
  Var _ -> 0
  Wvar _ _ -> 0

-- | Gets the multiplier for a particular variable.
--
-- The expression is evaluated twice:
--
-- * once with the requested variable at one and every other variable at
--   zero (a matching 'Wvar' contributes its weight);
-- * once via 'consts', with all variables at zero.
--
-- For a linear expression the difference is the variable's coefficient.
getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
getVar target expr = withTargetAtOne - consts expr
  where
    withTargetAtOne = cata evalAt expr
    evalAt node = case node of
      Var v -> if v == target then 1 else 0
      Wvar w v -> if v == target then w else 0
      Lit c -> c
      Add l r -> l + r
      Mul l r -> l * r
      Negate e -> negate e

-- | Gets all variables used in an equation, without duplicates, in order of
-- first occurrence.  Both plain ('Var') and weighted ('Wvar') variable nodes
-- are reported.
vars :: Eq v => LinExpr t v -> [v]
vars = nub . cata vars' where
  vars' (Var x) = [x]
  vars' (Wvar _ x) = [x]
  vars' (Add a b) = a ++ b
  vars' (Mul a b) = a ++ b
  vars' (Negate a) = a
  vars' (Lit _) = []

-- | Expands terms to Wvars but does not collect like terms.
--
-- Works bottom-up:
--
-- * every 'Var' becomes a 'Wvar' with weight 1;
-- * constant factors are folded into the weight of the 'Wvar' they multiply;
-- * products are distributed over sums;
-- * negation is pushed into weighted variables, literals, sums and products.
--
-- A constant added directly to a single 'Wvar' is dropped.  The result
-- should therefore only be used for extracting variable terms, as 'varTerms'
-- does.  Products that cannot be reduced (e.g. of two variables) are left as
-- 'Mul' nodes.
rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v
rewrite = cata rewrite' where
  rewrite' (Var a) = Fix (Wvar 1 a)
  -- A constant next to a lone weighted variable carries no variable terms.
  rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a
  rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a
  rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c)
  rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c)
  rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))
  rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b))
  rewrite' (Lit a) = Fix (Lit a)
  rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
  rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
  rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b)
  rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a))
  rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b)))
  -- -(a * b) == (-a) * b.  Re-run the Mul rules so constant factors still
  -- fold and irreducible products remain Mul nodes (rejected by varTerms).
  rewrite' (Negate (Fix (Mul a b))) = rewrite' (Mul (rewrite' (Negate a)) b)
  rewrite' a = Fix a

-- | Reduces an expression to its variable terms.  The expression is first
-- expanded with 'rewrite'; weights of repeated variables are then summed.
-- Constant terms are discarded.  Any product that 'rewrite' could not reduce
-- to a weighted variable raises an error.
varTerms :: (Num t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)]
varTerms expr = M.toList (cata collect (rewrite expr))
  where
    collect node = case node of
      Wvar w v -> M.singleton v w
      Add l r -> M.unionWith (+) l r
      Mul _ _ -> error "Only linear terms supported"
      _ -> M.empty

-- | Splits an expression into the variable part and the constant term.  The
-- variable part is built as @eq - c@ (not simplified), so adding the
-- constant back yields an expression equivalent to the original.
split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
split eq = (eq - Fix (Lit c), c)
  where
    c = consts eq

-- | Render an expression as a fully parenthesised string.  Weighted
-- variables are shown as the weight immediately followed by the variable.
prettyPrint :: (Show t, Show v) => LinExpr t v -> String
prettyPrint = cata prettyPrint' where
  prettyPrint' (Lit a) = show a
  prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"]
  prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]
  -- Subtraction is encoded as Add/Negate, so this case is reachable from '-'.
  prettyPrint' (Negate a) = concat ["(-", a, ")"]
  prettyPrint' (Var x) = show x
  prettyPrint' (Wvar w x) = show w ++ show x

-- | Base functor of the free monad for linear programs.  The monad allows
-- definition of:
--
-- * the objective function;
-- * equality constraints;
-- * inequality constraints (only ≤ is represented in the data type; ≥ is
--   expressed by swapping the operands);
-- * integral and binary declarations for variables.
--
-- In every constructor the last field @a@ is the rest of the program.
data LinProg' t v a =
  Objective !(LinExpr t v) !a
  -- ^ Add a term to the objective function.
  | Integer !v !a
  -- ^ Declare a variable to be integral.
  | Binary !v !a
  -- ^ Declare a variable to be binary.
  | EqConstraint !(LinExpr t v) !(LinExpr t v) !a
  -- ^ @EqConstraint lhs rhs@ requires @lhs = rhs@.
  | LeqConstraint !(LinExpr t v) !(LinExpr t v) !a
  -- ^ @LeqConstraint lhs rhs@ requires @lhs ≤ rhs@.
  deriving (Show, Eq, Functor)

-- | Linear programs: the free monad over 'LinProg''.
type LinProg t v = Free (LinProg' t v)

-- | Define a term in the objective function.
obj expr = liftF $ Objective expr ()

-- | Define an equality constraint: @lhs =: rhs@ requires @lhs = rhs@.
(=:) lhs rhs = liftF $ EqConstraint lhs rhs ()

-- | Define an inequality (less than or equal) constraint:
-- @lhs <: rhs@ requires @lhs ≤ rhs@.
(<:) lhs rhs = liftF $ LeqConstraint lhs rhs ()

-- | Define an inequality (greater than or equal) constraint:
-- @lhs >: rhs@ requires @lhs ≥ rhs@.  It is stored as the equivalent
-- @rhs ≤ lhs@ by swapping the operands.
(>:) lhs rhs = liftF $ LeqConstraint rhs lhs ()

-- | Declare a variable to be binary.  The argument must be a bare variable as
-- produced by 'var'; any other expression is rejected with an error.
bin (Fix (Var v)) = liftF (Binary v ())
bin _ = error "Math.LinProg.Types.bin: argument must be a single variable created with 'var'"

-- | Declare a variable to be integral.  The argument must be a bare variable
-- as produced by 'var'; any other expression is rejected with an error.
int (Fix (Var v)) = liftF (Integer v ())
int _ = error "Math.LinProg.Types.int: argument must be a single variable created with 'var'"

-- Constraint operators are non-associative and bind more loosely than the
-- arithmetic operators used to build their operands.
infix 4 =:, <:, >:

-- QuickCheck properties

instance (Arbitrary t, Arbitrary v) => Arbitrary (Lin