aboutsummaryrefslogtreecommitdiff
path: root/Math/LinProg/Types.hs
blob: 6fbad93a08434b55a4765d09c1cbeae21b841fc5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
{-# LANGUAGE DeriveFunctor, FlexibleInstances, FlexibleContexts, UndecidableInstances #-}
{-|
Module      : Math.LinProg.Types
Description : Base types for equations and optimisation monad
Copyright   : (c) Justin Bedő, 2014
License     : BSD
Maintainer  : cu@cua0.org
Stability   : experimental

This module defines the base types for representing equations and linear
programs.  The linear program is created as a free monad, and equations as an
AST.  Note that expressions are assumed to be linear expressions and hence
there is no explicit checking for higher order terms.
-}
module Math.LinProg.Types (
  LinExpr
  ,LinExpr'(..)
  ,var
  ,vars
  ,varTerms
  ,getVar
  ,split
  ,LinProg
  ,LinProg'(..)
  ,obj
  ,(<:)
  ,(=:)
  ,(>:)
) where

import Data.List (nub)

import Control.Monad.Free
import Data.Functor.Foldable

-- | Base functor for the expression AST.  Expressions have factors of type
-- @t@ and variables referenced by ids of type @v@.  The parameter @a@ marks
-- the recursive positions; 'Fix' ties the knot to form 'LinExpr'.
data LinExpr' t v a =
  Lit t -- ^ a constant factor
  | Var v -- ^ a reference to the variable with the given id
  | Add a a -- ^ sum of two sub-expressions
  | Mul a a -- ^ product of two sub-expressions (assumed linear, not checked)
  | Negate a -- ^ negation of a sub-expression
  deriving (Show, Eq, Functor)

-- | Linear expressions: the fixed point of the base functor @LinExpr'@.
type LinExpr t v = Fix (LinExpr' t v)

-- | Creates a new variable for reference in equations.  The argument is the
-- identifier used to refer to the variable.
var :: v -> LinExpr t v
var = Fix . Var

-- | For convenient notation, expressions are declared as instances of 'Num'.
-- Each arithmetic operation just builds the corresponding AST node; products
-- are assumed to be linear (one side constant) and this is not checked.
-- Linear expressions cannot implement absolute value or sign functions, so
-- 'abs' and 'signum' raise a descriptive error if they are ever called.
instance Num t => Num (LinExpr t v) where
  a * b = Fix (Mul a b)
  a + b = Fix (Add a b)
  negate a = Fix (Negate a)
  fromInteger a = Fix (Lit (fromInteger a))
  abs _ =
    error "Math.LinProg.Types: abs is not defined for linear expressions"
  signum _ =
    error "Math.LinProg.Types: signum is not defined for linear expressions"

-- | Linear expressions can also be instances of 'Fractional'.  Division is
-- only linear when the divisor is a constant, so the divisor is reduced to
-- its constant term (via 'consts') and the dividend is multiplied by the
-- reciprocal of that constant.  Any variable terms in the divisor are
-- ignored, in line with this module's assumption that expressions are
-- linear and its lack of explicit checking.
instance Fractional t => Fractional (LinExpr t v) where
  -- The reciprocal is taken on the scalar type t.  Writing @1/b@ here would
  -- re-enter this very method and build an infinite @Mul 1 (Mul 1 ...)@ tree.
  a / b = Fix (Mul a (Fix (Lit (recip (consts b)))))
  fromRational a = Fix (Lit (fromRational a))

-- | Reduce a linear expression down to its constant term, i.e. the value the
-- expression takes when every variable is set to 0.
consts :: Num t => LinExpr t v -> t
consts = cata constAlg
  where
    constAlg node = case node of
      Lit c    -> c
      Var _    -> 0
      Negate p -> negate p
      Add p q  -> p + q
      Mul p q  -> p * q

-- | Gets the multiplier (coefficient) of a particular variable.  The
-- expression is evaluated with the requested variable set to 1 and every
-- other variable set to 0, and the constant term is then subtracted.  For a
-- linear expression, what remains is exactly that variable's coefficient.
getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
getVar target expr = withTargetAtOne - consts expr
  where
    withTargetAtOne = cata evalNode expr
    evalNode node = case node of
      Var w
        | w == target -> 1
        | otherwise   -> 0
      Lit c    -> c
      Add p q  -> p + q
      Mul p q  -> p * q
      Negate p -> negate p

-- | Gets all variables used in an expression, in left-to-right order.  Every
-- occurrence is reported, so a variable appearing more than once in the
-- expression is listed more than once.
vars :: LinExpr t v -> [v]
vars = cata collect
  where
    collect node = case node of
      Var x    -> [x]
      Add l r  -> l ++ r
      Mul l r  -> l ++ r
      Negate e -> e
      Lit _    -> []

-- | Reduces an expression to its variable terms, dropping the constant term.
-- The result is a right-nested sum of @coefficient × variable@ products with
-- one term per distinct variable, in order of first occurrence.  An
-- expression with no variables reduces to the literal 0.
varTerms :: (Num t, Eq v) => LinExpr t v -> LinExpr t v
varTerms eq = go eq' where
  go [t] = t
  go (t:ts) = Fix (Add t (go ts))
  go [] = Fix (Lit 0)

  eq' = zipWith (\v w -> Fix (Mul (Fix (Lit w)) (Fix (Var v)))) vs ws
  -- 'vars' lists every occurrence of a variable, while 'getVar' already
  -- returns its combined coefficient.  Without deduplication a repeated
  -- variable would be counted several times (x + x would become 2x + 2x).
  vs = nub (vars eq)
  ws = map (`getVar` eq) vs

-- | Splits an expression into its variable terms (as produced by 'varTerms')
-- and its constant term (as produced by 'consts').
split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
split expr = (terms, constant)
  where
    terms    = varTerms expr
    constant = consts expr

-- | Renders an expression as a fully parenthesised string.  Literals and
-- variable ids are rendered with 'show'.  A product renders as @(a×b)@, a sum
-- as @(a+b)@, and a negation as @(-a)@.
prettyPrint :: (Show t, Show v) => LinExpr t v -> String
prettyPrint = cata prettyPrint' where
  prettyPrint' (Lit a) = show a
  prettyPrint' (Var x) = show x
  -- Previously missing: without this case, printing any negated expression
  -- failed with a non-exhaustive pattern error.
  prettyPrint' (Negate a) = concat ["(-", a, ")"]
  prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"]
  prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]

-- | Base functor of the free monad for linear programs.  The monad allows
-- definition of the objective function, equality constraints, and inequality
-- constraints.  Only @≤@ inequalities are represented in the data type; a
-- @≥@ constraint is stored as a @≤@ constraint with its two sides swapped.
-- The last field @a@ of each constructor is the continuation position used
-- by 'Free'.
data LinProg' t v a =
  Objective (LinExpr t v) a -- ^ a term of the objective function
  | EqConstraint (LinExpr t v) (LinExpr t v) a -- ^ left side = right side
  | LeqConstraint (LinExpr t v) (LinExpr t v) a -- ^ left side ≤ right side
  deriving (Show, Eq, Functor)

-- | Linear programs: the free monad over @LinProg'@.
type LinProg t v = Free (LinProg' t v)

-- | Define a term in the objective function.  The signature is the most
-- general one (any 'MonadFree' over @LinProg'@), so it can be used directly
-- in 'LinProg' or in any other monad with the same base functor.
obj :: MonadFree (LinProg' t v) m => LinExpr t v -> m ()
obj a = liftF (Objective a ())

-- | Define an equality constraint: @a =: b@ requires @a = b@.
(=:) :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
a =: b = liftF (EqConstraint a b ())

-- | Define an inequality (less than or equal) constraint: @a <: b@ requires
-- @a ≤ b@.
(<:) :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
a <: b = liftF (LeqConstraint a b ())

-- | Define an inequality (greater than or equal) constraint: @b >: a@
-- requires @b ≥ a@ and is recorded as the @≤@ constraint @a ≤ b@.
(>:) :: MonadFree (LinProg' t v) m => LinExpr t v -> LinExpr t v -> m ()
b >: a = liftF (LeqConstraint a b ())

-- The constraint operators bind more loosely than arithmetic (infixl 6/7),
-- so @x + y <: 3@ parses as @(x + y) <: 3@.  They are non-associative.
infix 4 =:
infix 4 <:
infix 4 >: