From 19a7c7e4225ad6650d8a60d673edb7b23294f9d5 Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Tue, 11 Nov 2014 13:37:48 +1100 Subject: Correct bug with fractional handling --- Math/LinProg/LP.hs | 6 +++--- Math/LinProg/Types.hs | 35 +++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 15 deletions(-) (limited to 'Math/LinProg') diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs index 513bdad..e4984f7 100644 --- a/Math/LinProg/LP.hs +++ b/Math/LinProg/LP.hs @@ -45,7 +45,7 @@ data CompilerS t v = CompilerS { $(makeLenses ''CompilerS) -- | Compiles a linear programming monad to intermediate form which is easier to process -compile :: (Num t, Show t, Ord t, Eq v) => LinProg t v () -> CompilerS t v +compile :: (Num t, Fractional t, Show t, Ord t, Eq v) => LinProg t v () -> CompilerS t v compile ast = compile' ast initCompilerS where compile' (Free (Objective a c)) state = compile' c $ state & objective +~ a compile' (Free (EqConstraint a b c)) state = compile' c $ state & equals %~ (split (a-b):) @@ -62,7 +62,7 @@ compile ast = compile' ast initCompilerS where [] -- | Shows a compiled state as LP format. Requires variable ids are strings. -instance (Show t, Num t, Ord t) => Show (CompilerS t String) where +instance (Show t, Num t, Fractional t, Ord t) => Show (CompilerS t String) where show s = unlines $ catMaybes [ Just "Minimize" ,Just (showEq (s ^. objective)) @@ -89,7 +89,7 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where render x = (if x >= 0 then "+" else "") ++ show x -findBounds :: (Hashable v, Eq v, Num t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v]) +findBounds :: (Hashable v, Eq v, Num t, Fractional t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v]) findBounds eqs = (mapMaybe bound singleTerms, eqs \\ filter (isBounded . head . vars . fst) singleTermEqs) where singleTermEqs = filter (\(ts, _) -> length (vars ts) == 1) eqs diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs index d67e642..ebf7156 100644 --- a/Math/LinProg/Types.hs +++ b/Math/LinProg/Types.hs @@ -38,6 +38,7 @@ import Data.List import qualified Data.HashMap.Strict as M import qualified Data.HashSet as S import Test.QuickCheck +import Data.Ratio -- | Base AST for expressions. Expressions have factors or type t and -- variables referenced by ids of type v. @@ -47,6 +48,7 @@ data LinExpr' t v a = | Wvar !t !v | Add !a !a | Mul !a !a + | Div !a !a | Negate !a deriving (Show, Eq, Functor) @@ -68,20 +70,21 @@ instance Num t => Num (LinExpr t v) where -- | Linear expressions can also be instances of fractional. instance Fractional t => Fractional (LinExpr t v) where - a / b = Fix (Mul a (1/b)) + a / b = Fix (Div a b) fromRational a = Fix (Lit (fromRational a)) -- | Reduce a linear expression down to the constant factor. -consts :: Num t => LinExpr t v -> t +consts :: (Num t, Fractional t) => LinExpr t v -> t consts = cata consts' where consts' (Negate a) = negate a consts' (Lit a) = a consts' (Add a b) = a + b consts' (Mul a b) = a * b + consts' (Div a b) = a / b consts' _ = 0 -- | Gets the multiplier for a particular variable. -getVar :: (Num t, Eq v) => v -> LinExpr t v -> t +getVar :: (Num t, Fractional t, Eq v) => v -> LinExpr t v -> t getVar id x = cata getVar' x - consts x where getVar' (Wvar w x) | x == id = w | otherwise = 0 @@ -90,6 +93,7 @@ getVar id x = cata getVar' x - consts x where getVar' (Lit a) = a getVar' (Add a b) = a + b getVar' (Mul a b) = a * b + getVar' (Div a b) = a / b getVar' (Negate a) = negate a -- | Gets all variables used in an equation. @@ -103,26 +107,30 @@ vars = S.toList . cata vars' where vars' _ = S.empty -- | Expands terms to Wvars but does not collect like terms -rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v +rewrite :: (Eq t, Num t, Fractional t) => LinExpr t v -> LinExpr t v rewrite = cata rewrite' where + rewrite' (Lit a) = Fix (Lit a) rewrite' (Var a) = Fix (Wvar 1 a) rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a + rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b)) rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c) rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c) - rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b)) rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b)) - rewrite' (Lit a) = Fix (Lit a) rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) + rewrite' (Div (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a / b) c) + rewrite' (Div (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a / b) c) + rewrite' (Div (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a / b)) rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b) rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a)) rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b))) - rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b) + rewrite' (Negate (Fix (Mul a b))) = rewrite' (Mul (rewrite' (Negate a)) b) + rewrite' (Negate (Fix (Div a b))) = rewrite' (Div (rewrite' (Negate a)) b) rewrite' a = Fix a -- | Reduces an expression to the variable terms -varTerms :: (Num t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)] +varTerms :: (Num t, Fractional t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)] varTerms = M.toList . cata go . rewrite where go (Wvar w a) = M.fromList [(a, w)] go (Add a b) = M.unionWith (+) a b @@ -130,13 +138,15 @@ varTerms = M.toList . cata go . rewrite where go _ = M.empty -- | Splits an expression into the variables and the constant term -split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t) -split eq = (eq - (Fix (Lit (consts eq))), consts eq) +split :: (Num t, Fractional t, Eq v) => LinExpr t v -> (LinExpr t v, t) +split eq = (eq - Fix (Lit (consts eq)), consts eq) prettyPrint :: (Show t, Show v) => LinExpr t v -> String prettyPrint = cata prettyPrint' where prettyPrint' (Lit a) = show a + prettyPrint' (Negate a) = "-" ++ show a prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"] + prettyPrint' (Div a b) = concat ["(", a, "/", b, ")"] prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"] prettyPrint' (Var x) = show x prettyPrint' (Wvar w x) = show w ++ show x @@ -186,13 +196,14 @@ instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where ,((Fix .) . Mul) <$> arbitrary <*> arbitrary ,(Fix . Negate) <$> arbitrary] -prop_rewrite :: LinExpr Int Int -> Property +prop_rewrite :: LinExpr Rational Int -> Property prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq) where vs = vars eq - ws = map (flip getVar eq) vs + ws = map (`getVar` eq) vs isLinear = (<= 1) . cata isLinear' where isLinear' (Mul a b) = a + b + isLinear' (Div a b) = a + b isLinear' (Add a b) = max a b isLinear' (Var _) = 1 isLinear' (Wvar _ _) = 1 -- cgit v1.2.3