From 544eef53181f52423f513227e2bd98c20815b243 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Mon, 27 Oct 2014 10:17:23 +1100 Subject: Improve computational complexity of varTerms by rewriting equations. --- Math/LinProg/LP.hs | 9 ++--- Math/LinProg/LPSolve.hs | 36 +++++++++----------- Math/LinProg/Types.hs | 88 ++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 85 insertions(+), 48 deletions(-) (limited to 'Math/LinProg') diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs index 1fc59e0..18d2068 100644 --- a/Math/LinProg/LP.hs +++ b/Math/LinProg/LP.hs @@ -56,7 +56,7 @@ compile ast = compile' ast initCompilerS where instance (Show t, Num t, Ord t) => Show (CompilerS t String) where show s = unlines $ catMaybes [ Just "Minimize" - ,Just (showEq $ varTerms (s ^. objective)) + ,Just (showEq (s ^. objective)) ,if hasST then Just "Subject to" else Nothing ,if hasEqs then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " = " ++ show (negate b)) $ s ^. equals) else Nothing ,if hasUnbounded then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " <= " ++ show (negate b)) unbounded) else Nothing @@ -64,12 +64,7 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where ,if hasBounded then Just (intercalate "\n" $ map (\(l, v, u) -> show l ++ " <= " ++ v ++ " <= " ++ show u) bounded) else Nothing ] where - getVars eq = zip vs ws - where - vs = vars eq - ws = map (`getVar` eq) vs - - showEq = unwords . map (\(a, b) -> render b ++ " " ++ a) . getVars + showEq = unwords . map (\(a, b) -> render b ++ " " ++ a) . varTerms (bounded, unbounded) = findBounds $ s ^. leqs hasBounded = not (null bounded) diff --git a/Math/LinProg/LPSolve.hs b/Math/LinProg/LPSolve.hs index 2031e0a..baa5d7e 100644 --- a/Math/LinProg/LPSolve.hs +++ b/Math/LinProg/LPSolve.hs @@ -26,10 +26,9 @@ import Math.LinProg.LPSolve.FFI hiding (solve) import qualified Math.LinProg.LPSolve.FFI as F import Math.LinProg.LP import Math.LinProg.Types -import qualified Data.Map as M +import qualified Data.Map.Strict as M import Prelude hiding (EQ) --- | Solves an LP using lp_solve. solve :: (Eq v, Ord v) => LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)]) solve = solveWithTimeout 0 @@ -44,30 +43,25 @@ solveWithTimeout t (compile -> lp) = do -- Eqs forM_ (zip [1..] $ lp ^. equals) $ \(i, eq) -> - forM_ (M.keys varLUT) $ \v -> do - let w = getVar v $ fst eq - c = negate $ snd eq - when (w /= 0) $ do - setMat m i (varLUT M.! v) w - setConstrType m i EQ - setRHS m i c - return () + forM_ (varTerms (fst eq)) $ \(v, w) -> do + let c = negate $ snd eq + setMat m i (varLUT M.! v) w + setConstrType m i EQ + setRHS m i c + return () -- Leqs forM_ (zip [1+nequals..] $ lp ^. leqs) $ \(i, eq) -> - forM_ (M.keys varLUT) $ \v -> do - let w = getVar v $ fst eq - c = negate $ snd eq - when (w /= 0) $ do - setMat m i (varLUT M.! v) w - setConstrType m i LE - setRHS m i c - return () + forM_ (varTerms (fst eq)) $ \(v, w) -> do + let c = negate $ snd eq + setMat m i (varLUT M.! v) w + setConstrType m i LE + setRHS m i c + return () -- Objective - forM_ (M.keys varLUT) $ \v -> do - let w = getVar v $ lp ^. objective - when (w /= 0) $ void $ setMat m 0 (varLUT M.! v) w + forM_ (varTerms (lp ^. objective)) $ \(v, w) -> do + void $ setMat m 0 (varLUT M.! v) w res <- F.solve m sol <- snd <$> getSol nvars m diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs index 6fbad93..2a81918 100644 --- a/Math/LinProg/Types.hs +++ b/Math/LinProg/Types.hs @@ -30,15 +30,20 @@ module Math.LinProg.Types ( import Data.Functor.Foldable import Control.Monad.Free +import qualified Data.Map.Strict as M +import Test.QuickCheck +import Control.Applicative +import Data.List -- | Base AST for expressions. Expressions have factors or type t and -- variables referenced by ids of type v. data LinExpr' t v a = - Lit t - | Var v - | Add a a - | Mul a a - | Negate a + Lit !t + | Var !v + | Wvar !t !v + | Add !a !a + | Mul !a !a + | Negate !a deriving (Show, Eq, Functor) type LinExpr t v = Fix (LinExpr' t v) @@ -67,13 +72,15 @@ consts :: Num t => LinExpr t v -> t consts = cata consts' where consts' (Negate a) = negate a consts' (Lit a) = a - consts' (Var _) = 0 consts' (Add a b) = a + b consts' (Mul a b) = a * b + consts' _ = 0 -- | Gets the multiplier for a particular variable. getVar :: (Num t, Eq v) => v -> LinExpr t v -> t getVar id x = cata getVar' x - consts x where + getVar' (Wvar w x) | x == id = w + | otherwise = 0 getVar' (Var x) | x == id = 1 | otherwise = 0 getVar' (Lit a) = a @@ -82,27 +89,44 @@ getVar id x = cata getVar' x - consts x where getVar' (Negate a) = negate a -- | Gets all variables used in an equation. -vars :: LinExpr t v -> [v] -vars = cata vars' where +vars :: Eq v => LinExpr t v -> [v] +vars = nub . cata vars' where vars' (Var x) = [x] vars' (Add a b) = a ++ b vars' (Mul a b) = a ++ b vars' (Negate a) = a vars' _ = [] --- | Reduces an expression to the variable terms -varTerms eq = go eq' where - go [t] = t - go (t:ts) = Fix (Add t (go ts)) - go [] = Fix (Lit 0) +-- | Expands terms to Wvars but does not collect like terms +rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v +rewrite = cata rewrite' where + rewrite' (Var a) = Fix (Wvar 1 a) + rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a + rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a + rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c) + rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c) + rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b)) + rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b)) + rewrite' (Lit a) = Fix (Lit a) + rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) + rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) + rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b) + rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a)) + rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b))) + rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b) + rewrite' a = Fix a - eq' = zipWith (\v w -> Fix (Mul (Fix (Lit w)) (Fix (Var v)))) vs ws - vs = vars eq - ws = map (`getVar` eq) vs +-- | Reduces an expression to the variable terms +varTerms :: (Num t, Eq t, Ord v) => LinExpr t v -> [(v, t)] +varTerms = M.toList . cata go . rewrite where + go (Wvar w a) = M.fromList [(a, w)] + go (Add a b) = M.unionWith (+) a b + go (Mul _ _) = error "Only linear terms supported" + go _ = M.empty -- | Splits an expression into the variables and the constant term split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t) -split eq = (varTerms eq, consts eq) +split eq = (eq - (Fix (Lit (consts eq))), consts eq) prettyPrint :: (Show t, Show v) => LinExpr t v -> String prettyPrint = cata prettyPrint' where @@ -110,14 +134,15 @@ prettyPrint = cata prettyPrint' where prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"] prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"] prettyPrint' (Var x) = show x + prettyPrint' (Wvar w x) = show w ++ show x -- | Free monad for linear programs. The monad allows definition of the -- objective function, equality constraints, and inequality constraints (≤ only -- in the data type). data LinProg' t v a = - Objective (LinExpr t v) a - | EqConstraint (LinExpr t v) (LinExpr t v) a - | LeqConstraint (LinExpr t v) (LinExpr t v) a + Objective !(LinExpr t v) !a + | EqConstraint !(LinExpr t v) !(LinExpr t v) !a + | LeqConstraint !(LinExpr t v) !(LinExpr t v) !a deriving (Show, Eq, Functor) type LinProg t v = Free (LinProg' t v) @@ -137,3 +162,26 @@ b >: a = liftF (LeqConstraint a b ()) infix 4 =: infix 4 <: infix 4 >: + +-- Quickcheck properties + +instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where + arbitrary = oneof [ + (Fix . Var) <$> arbitrary + ,(Fix . Lit) <$> arbitrary + ,((Fix .) . Add) <$> arbitrary <*> arbitrary + ,((Fix .) . Mul) <$> arbitrary <*> arbitrary + ,(Fix . Negate) <$> arbitrary] + +prop_rewrite :: LinExpr Int Int -> Property +prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq) + where + vs = vars eq + ws = map (flip getVar eq) vs + isLinear = (<= 1) . cata isLinear' where + isLinear' (Mul a b) = a + b + isLinear' (Add a b) = max a b + isLinear' (Var _) = 1 + isLinear' (Wvar _ _) = 1 + isLinear' (Lit _) = 0 + isLinear' (Negate a) = a -- cgit v1.2.3