diff options
Diffstat (limited to 'Math/LinProg/Types.hs')
-rw-r--r-- | Math/LinProg/Types.hs | 88 |
1 files changed, 68 insertions, 20 deletions
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs index 6fbad93..2a81918 100644 --- a/Math/LinProg/Types.hs +++ b/Math/LinProg/Types.hs @@ -30,15 +30,20 @@ module Math.LinProg.Types ( import Data.Functor.Foldable import Control.Monad.Free +import qualified Data.Map.Strict as M +import Test.QuickCheck +import Control.Applicative +import Data.List -- | Base AST for expressions. Expressions have factors or type t and -- variables referenced by ids of type v. data LinExpr' t v a = - Lit t - | Var v - | Add a a - | Mul a a - | Negate a + Lit !t + | Var !v + | Wvar !t !v + | Add !a !a + | Mul !a !a + | Negate !a deriving (Show, Eq, Functor) type LinExpr t v = Fix (LinExpr' t v) @@ -67,13 +72,15 @@ consts :: Num t => LinExpr t v -> t consts = cata consts' where consts' (Negate a) = negate a consts' (Lit a) = a - consts' (Var _) = 0 consts' (Add a b) = a + b consts' (Mul a b) = a * b + consts' _ = 0 -- | Gets the multiplier for a particular variable. getVar :: (Num t, Eq v) => v -> LinExpr t v -> t getVar id x = cata getVar' x - consts x where + getVar' (Wvar w x) | x == id = w + | otherwise = 0 getVar' (Var x) | x == id = 1 | otherwise = 0 getVar' (Lit a) = a @@ -82,27 +89,44 @@ getVar id x = cata getVar' x - consts x where getVar' (Negate a) = negate a -- | Gets all variables used in an equation. -vars :: LinExpr t v -> [v] -vars = cata vars' where +vars :: Eq v => LinExpr t v -> [v] +vars = nub . cata vars' where vars' (Var x) = [x] vars' (Add a b) = a ++ b vars' (Mul a b) = a ++ b vars' (Negate a) = a vars' _ = [] --- | Reduces an expression to the variable terms -varTerms eq = go eq' where - go [t] = t - go (t:ts) = Fix (Add t (go ts)) - go [] = Fix (Lit 0) +-- | Expands terms to Wvars but does not collect like terms +rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v +rewrite = cata rewrite' where + rewrite' (Var a) = Fix (Wvar 1 a) + rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a + rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a + rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c) + rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c) + rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b)) + rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b)) + rewrite' (Lit a) = Fix (Lit a) + rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) + rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) + rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b) + rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a)) + rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b))) + rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b) + rewrite' a = Fix a - eq' = zipWith (\v w -> Fix (Mul (Fix (Lit w)) (Fix (Var v)))) vs ws - vs = vars eq - ws = map (`getVar` eq) vs +-- | Reduces an expression to the variable terms +varTerms :: (Num t, Eq t, Ord v) => LinExpr t v -> [(v, t)] +varTerms = M.toList . cata go . rewrite where + go (Wvar w a) = M.fromList [(a, w)] + go (Add a b) = M.unionWith (+) a b + go (Mul _ _) = error "Only linear terms supported" + go _ = M.empty -- | Splits an expression into the variables and the constant term split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t) -split eq = (varTerms eq, consts eq) +split eq = (eq - (Fix (Lit (consts eq))), consts eq) prettyPrint :: (Show t, Show v) => LinExpr t v -> String prettyPrint = cata prettyPrint' where @@ -110,14 +134,15 @@ prettyPrint = cata prettyPrint' where prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"] prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"] prettyPrint' (Var x) = show x + prettyPrint' (Wvar w x) = show w ++ show x -- | Free monad for linear programs. The monad allows definition of the -- objective function, equality constraints, and inequality constraints (≤ only -- in the data type). data LinProg' t v a = - Objective (LinExpr t v) a - | EqConstraint (LinExpr t v) (LinExpr t v) a - | LeqConstraint (LinExpr t v) (LinExpr t v) a + Objective !(LinExpr t v) !a + | EqConstraint !(LinExpr t v) !(LinExpr t v) !a + | LeqConstraint !(LinExpr t v) !(LinExpr t v) !a deriving (Show, Eq, Functor) type LinProg t v = Free (LinProg' t v) @@ -137,3 +162,26 @@ b >: a = liftF (LeqConstraint a b ()) infix 4 =: infix 4 <: infix 4 >: + +-- Quickcheck properties + +instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where + arbitrary = oneof [ + (Fix . Var) <$> arbitrary + ,(Fix . Lit) <$> arbitrary + ,((Fix .) . Add) <$> arbitrary <*> arbitrary + ,((Fix .) . Mul) <$> arbitrary <*> arbitrary + ,(Fix . Negate) <$> arbitrary] + +prop_rewrite :: LinExpr Int Int -> Property +prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq) + where + vs = vars eq + ws = map (flip getVar eq) vs + isLinear = (<= 1) . cata isLinear' where + isLinear' (Mul a b) = a + b + isLinear' (Add a b) = max a b + isLinear' (Var _) = 1 + isLinear' (Wvar _ _) = 1 + isLinear' (Lit _) = 0 + isLinear' (Negate a) = a |