aboutsummaryrefslogtreecommitdiff
path: root/Math/LinProg/Types.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Math/LinProg/Types.hs')
-rw-r--r--Math/LinProg/Types.hs88
1 files changed, 68 insertions, 20 deletions
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs
index 6fbad93..2a81918 100644
--- a/Math/LinProg/Types.hs
+++ b/Math/LinProg/Types.hs
@@ -30,15 +30,20 @@ module Math.LinProg.Types (
import Data.Functor.Foldable
import Control.Monad.Free
+import qualified Data.Map.Strict as M
+import Test.QuickCheck
+import Control.Applicative
+import Data.List
-- | Base AST for expressions. Expressions have factors or type t and
-- variables referenced by ids of type v.
data LinExpr' t v a =
- Lit t
- | Var v
- | Add a a
- | Mul a a
- | Negate a
+ Lit !t
+ | Var !v
+ | Wvar !t !v
+ | Add !a !a
+ | Mul !a !a
+ | Negate !a
deriving (Show, Eq, Functor)
type LinExpr t v = Fix (LinExpr' t v)
@@ -67,13 +72,15 @@ consts :: Num t => LinExpr t v -> t
consts = cata consts' where
consts' (Negate a) = negate a
consts' (Lit a) = a
- consts' (Var _) = 0
consts' (Add a b) = a + b
consts' (Mul a b) = a * b
+ consts' _ = 0
-- | Gets the multiplier for a particular variable.
getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
getVar id x = cata getVar' x - consts x where
+ getVar' (Wvar w x) | x == id = w
+ | otherwise = 0
getVar' (Var x) | x == id = 1
| otherwise = 0
getVar' (Lit a) = a
@@ -82,27 +89,44 @@ getVar id x = cata getVar' x - consts x where
getVar' (Negate a) = negate a
-- | Gets all variables used in an equation.
-vars :: LinExpr t v -> [v]
-vars = cata vars' where
+vars :: Eq v => LinExpr t v -> [v]
+vars = nub . cata vars' where
vars' (Var x) = [x]
vars' (Add a b) = a ++ b
vars' (Mul a b) = a ++ b
vars' (Negate a) = a
vars' _ = []
--- | Reduces an expression to the variable terms
-varTerms eq = go eq' where
- go [t] = t
- go (t:ts) = Fix (Add t (go ts))
- go [] = Fix (Lit 0)
+-- | Expands terms to Wvars but does not collect like terms
+rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v
+rewrite = cata rewrite' where
+ rewrite' (Var a) = Fix (Wvar 1 a)
+ rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a
+ rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a
+ rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c)
+ rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c)
+ rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))
+ rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b))
+ rewrite' (Lit a) = Fix (Lit a)
+ rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
+ rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
+ rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b)
+ rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a))
+ rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b)))
+ rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b)
+ rewrite' a = Fix a
- eq' = zipWith (\v w -> Fix (Mul (Fix (Lit w)) (Fix (Var v)))) vs ws
- vs = vars eq
- ws = map (`getVar` eq) vs
+-- | Reduces an expression to the variable terms
+varTerms :: (Num t, Eq t, Ord v) => LinExpr t v -> [(v, t)]
+varTerms = M.toList . cata go . rewrite where
+ go (Wvar w a) = M.fromList [(a, w)]
+ go (Add a b) = M.unionWith (+) a b
+ go (Mul _ _) = error "Only linear terms supported"
+ go _ = M.empty
-- | Splits an expression into the variables and the constant term
split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
-split eq = (varTerms eq, consts eq)
+split eq = (eq - (Fix (Lit (consts eq))), consts eq)
prettyPrint :: (Show t, Show v) => LinExpr t v -> String
prettyPrint = cata prettyPrint' where
@@ -110,14 +134,15 @@ prettyPrint = cata prettyPrint' where
prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"]
prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]
prettyPrint' (Var x) = show x
+ prettyPrint' (Wvar w x) = show w ++ show x
-- | Free monad for linear programs. The monad allows definition of the
-- objective function, equality constraints, and inequality constraints (≤ only
-- in the data type).
data LinProg' t v a =
- Objective (LinExpr t v) a
- | EqConstraint (LinExpr t v) (LinExpr t v) a
- | LeqConstraint (LinExpr t v) (LinExpr t v) a
+ Objective !(LinExpr t v) !a
+ | EqConstraint !(LinExpr t v) !(LinExpr t v) !a
+ | LeqConstraint !(LinExpr t v) !(LinExpr t v) !a
deriving (Show, Eq, Functor)
type LinProg t v = Free (LinProg' t v)
@@ -137,3 +162,26 @@ b >: a = liftF (LeqConstraint a b ())
infix 4 =:
infix 4 <:
infix 4 >:
+
+-- Quickcheck properties
+
+instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where
+ arbitrary = oneof [
+ (Fix . Var) <$> arbitrary
+ ,(Fix . Lit) <$> arbitrary
+ ,((Fix .) . Add) <$> arbitrary <*> arbitrary
+ ,((Fix .) . Mul) <$> arbitrary <*> arbitrary
+ ,(Fix . Negate) <$> arbitrary]
+
+prop_rewrite :: LinExpr Int Int -> Property
+prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq)
+ where
+ vs = vars eq
+ ws = map (flip getVar eq) vs
+ isLinear = (<= 1) . cata isLinear' where
+ isLinear' (Mul a b) = a + b
+ isLinear' (Add a b) = max a b
+ isLinear' (Var _) = 1
+ isLinear' (Wvar _ _) = 1
+ isLinear' (Lit _) = 0
+ isLinear' (Negate a) = a