aboutsummaryrefslogtreecommitdiff
path: root/Math
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2014-10-27 10:17:23 +1100
committerJustin Bedo <cu@cua0.org>2014-10-27 10:18:31 +1100
commit544eef53181f52423f513227e2bd98c20815b243 (patch)
tree5af4b35386c603b0fbbe5c7705a325d5f9b0d69b /Math
parente60c072870ee428720dce1f22890f1ce075325e4 (diff)
Improve computational complexity of varTerms by rewriting equations.
Diffstat (limited to 'Math')
-rw-r--r--Math/LinProg/LP.hs9
-rw-r--r--Math/LinProg/LPSolve.hs36
-rw-r--r--Math/LinProg/Types.hs88
3 files changed, 85 insertions, 48 deletions
diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs
index 1fc59e0..18d2068 100644
--- a/Math/LinProg/LP.hs
+++ b/Math/LinProg/LP.hs
@@ -56,7 +56,7 @@ compile ast = compile' ast initCompilerS where
instance (Show t, Num t, Ord t) => Show (CompilerS t String) where
show s = unlines $ catMaybes [
Just "Minimize"
- ,Just (showEq $ varTerms (s ^. objective))
+ ,Just (showEq (s ^. objective))
,if hasST then Just "Subject to" else Nothing
,if hasEqs then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " = " ++ show (negate b)) $ s ^. equals) else Nothing
,if hasUnbounded then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " <= " ++ show (negate b)) unbounded) else Nothing
@@ -64,12 +64,7 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where
,if hasBounded then Just (intercalate "\n" $ map (\(l, v, u) -> show l ++ " <= " ++ v ++ " <= " ++ show u) bounded) else Nothing
]
where
- getVars eq = zip vs ws
- where
- vs = vars eq
- ws = map (`getVar` eq) vs
-
- showEq = unwords . map (\(a, b) -> render b ++ " " ++ a) . getVars
+ showEq = unwords . map (\(a, b) -> render b ++ " " ++ a) . varTerms
(bounded, unbounded) = findBounds $ s ^. leqs
hasBounded = not (null bounded)
diff --git a/Math/LinProg/LPSolve.hs b/Math/LinProg/LPSolve.hs
index 2031e0a..baa5d7e 100644
--- a/Math/LinProg/LPSolve.hs
+++ b/Math/LinProg/LPSolve.hs
@@ -26,10 +26,9 @@ import Math.LinProg.LPSolve.FFI hiding (solve)
import qualified Math.LinProg.LPSolve.FFI as F
import Math.LinProg.LP
import Math.LinProg.Types
-import qualified Data.Map as M
+import qualified Data.Map.Strict as M
import Prelude hiding (EQ)
--- | Solves an LP using lp_solve.
solve :: (Eq v, Ord v) => LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)])
solve = solveWithTimeout 0
@@ -44,30 +43,25 @@ solveWithTimeout t (compile -> lp) = do
-- Eqs
forM_ (zip [1..] $ lp ^. equals) $ \(i, eq) ->
- forM_ (M.keys varLUT) $ \v -> do
- let w = getVar v $ fst eq
- c = negate $ snd eq
- when (w /= 0) $ do
- setMat m i (varLUT M.! v) w
- setConstrType m i EQ
- setRHS m i c
- return ()
+ forM_ (varTerms (fst eq)) $ \(v, w) -> do
+ let c = negate $ snd eq
+ setMat m i (varLUT M.! v) w
+ setConstrType m i EQ
+ setRHS m i c
+ return ()
-- Leqs
forM_ (zip [1+nequals..] $ lp ^. leqs) $ \(i, eq) ->
- forM_ (M.keys varLUT) $ \v -> do
- let w = getVar v $ fst eq
- c = negate $ snd eq
- when (w /= 0) $ do
- setMat m i (varLUT M.! v) w
- setConstrType m i LE
- setRHS m i c
- return ()
+ forM_ (varTerms (fst eq)) $ \(v, w) -> do
+ let c = negate $ snd eq
+ setMat m i (varLUT M.! v) w
+ setConstrType m i LE
+ setRHS m i c
+ return ()
-- Objective
- forM_ (M.keys varLUT) $ \v -> do
- let w = getVar v $ lp ^. objective
- when (w /= 0) $ void $ setMat m 0 (varLUT M.! v) w
+ forM_ (varTerms (lp ^. objective)) $ \(v, w) -> do
+ void $ setMat m 0 (varLUT M.! v) w
res <- F.solve m
sol <- snd <$> getSol nvars m
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs
index 6fbad93..2a81918 100644
--- a/Math/LinProg/Types.hs
+++ b/Math/LinProg/Types.hs
@@ -30,15 +30,20 @@ module Math.LinProg.Types (
import Data.Functor.Foldable
import Control.Monad.Free
+import qualified Data.Map.Strict as M
+import Test.QuickCheck
+import Control.Applicative
+import Data.List
-- | Base AST for expressions. Expressions have factors or type t and
-- variables referenced by ids of type v.
data LinExpr' t v a =
- Lit t
- | Var v
- | Add a a
- | Mul a a
- | Negate a
+ Lit !t
+ | Var !v
+ | Wvar !t !v
+ | Add !a !a
+ | Mul !a !a
+ | Negate !a
deriving (Show, Eq, Functor)
type LinExpr t v = Fix (LinExpr' t v)
@@ -67,13 +72,15 @@ consts :: Num t => LinExpr t v -> t
consts = cata consts' where
consts' (Negate a) = negate a
consts' (Lit a) = a
- consts' (Var _) = 0
consts' (Add a b) = a + b
consts' (Mul a b) = a * b
+ consts' _ = 0
-- | Gets the multiplier for a particular variable.
getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
getVar id x = cata getVar' x - consts x where
+ getVar' (Wvar w x) | x == id = w
+ | otherwise = 0
getVar' (Var x) | x == id = 1
| otherwise = 0
getVar' (Lit a) = a
@@ -82,27 +89,44 @@ getVar id x = cata getVar' x - consts x where
getVar' (Negate a) = negate a
-- | Gets all variables used in an equation.
-vars :: LinExpr t v -> [v]
-vars = cata vars' where
+vars :: Eq v => LinExpr t v -> [v]
+vars = nub . cata vars' where
vars' (Var x) = [x]
vars' (Add a b) = a ++ b
vars' (Mul a b) = a ++ b
vars' (Negate a) = a
vars' _ = []
--- | Reduces an expression to the variable terms
-varTerms eq = go eq' where
- go [t] = t
- go (t:ts) = Fix (Add t (go ts))
- go [] = Fix (Lit 0)
+-- | Expands terms to Wvars but does not collect like terms
+rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v
+rewrite = cata rewrite' where
+ rewrite' (Var a) = Fix (Wvar 1 a)
+ rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a
+ rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a
+ rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c)
+ rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c)
+ rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))
+ rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b))
+ rewrite' (Lit a) = Fix (Lit a)
+ rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
+ rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
+ rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b)
+ rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a))
+ rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b)))
+ rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b)
+ rewrite' a = Fix a
- eq' = zipWith (\v w -> Fix (Mul (Fix (Lit w)) (Fix (Var v)))) vs ws
- vs = vars eq
- ws = map (`getVar` eq) vs
+-- | Reduces an expression to the variable terms
+varTerms :: (Num t, Eq t, Ord v) => LinExpr t v -> [(v, t)]
+varTerms = M.toList . cata go . rewrite where
+ go (Wvar w a) = M.fromList [(a, w)]
+ go (Add a b) = M.unionWith (+) a b
+ go (Mul _ _) = error "Only linear terms supported"
+ go _ = M.empty
-- | Splits an expression into the variables and the constant term
split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
-split eq = (varTerms eq, consts eq)
+split eq = (eq - (Fix (Lit (consts eq))), consts eq)
prettyPrint :: (Show t, Show v) => LinExpr t v -> String
prettyPrint = cata prettyPrint' where
@@ -110,14 +134,15 @@ prettyPrint = cata prettyPrint' where
prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"]
prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]
prettyPrint' (Var x) = show x
+ prettyPrint' (Wvar w x) = show w ++ show x
-- | Free monad for linear programs. The monad allows definition of the
-- objective function, equality constraints, and inequality constraints (≤ only
-- in the data type).
data LinProg' t v a =
- Objective (LinExpr t v) a
- | EqConstraint (LinExpr t v) (LinExpr t v) a
- | LeqConstraint (LinExpr t v) (LinExpr t v) a
+ Objective !(LinExpr t v) !a
+ | EqConstraint !(LinExpr t v) !(LinExpr t v) !a
+ | LeqConstraint !(LinExpr t v) !(LinExpr t v) !a
deriving (Show, Eq, Functor)
type LinProg t v = Free (LinProg' t v)
@@ -137,3 +162,26 @@ b >: a = liftF (LeqConstraint a b ())
infix 4 =:
infix 4 <:
infix 4 >:
+
+-- Quickcheck properties
+
+instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where
+ arbitrary = oneof [
+ (Fix . Var) <$> arbitrary
+ ,(Fix . Lit) <$> arbitrary
+ ,((Fix .) . Add) <$> arbitrary <*> arbitrary
+ ,((Fix .) . Mul) <$> arbitrary <*> arbitrary
+ ,(Fix . Negate) <$> arbitrary]
+
+prop_rewrite :: LinExpr Int Int -> Property
+prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq)
+ where
+ vs = vars eq
+ ws = map (flip getVar eq) vs
+ isLinear = (<= 1) . cata isLinear' where
+ isLinear' (Mul a b) = a + b
+ isLinear' (Add a b) = max a b
+ isLinear' (Var _) = 1
+ isLinear' (Wvar _ _) = 1
+ isLinear' (Lit _) = 0
+ isLinear' (Negate a) = a