diff options
Diffstat (limited to 'Math/LinProg')
| -rw-r--r-- | Math/LinProg/LP.hs | 6 | ||||
| -rw-r--r-- | Math/LinProg/Types.hs | 35 | 
2 files changed, 26 insertions, 15 deletions
| diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs index 513bdad..e4984f7 100644 --- a/Math/LinProg/LP.hs +++ b/Math/LinProg/LP.hs @@ -45,7 +45,7 @@ data CompilerS t v = CompilerS {  $(makeLenses ''CompilerS)  -- | Compiles a linear programming monad to intermediate form which is easier to process -compile :: (Num t, Show t, Ord t, Eq v) => LinProg t v () -> CompilerS t v +compile :: (Num t, Fractional t, Show t, Ord t, Eq v) => LinProg t v () -> CompilerS t v  compile ast = compile' ast initCompilerS where    compile' (Free (Objective a c)) state = compile' c $ state & objective +~ a    compile' (Free (EqConstraint a b c)) state = compile' c $ state & equals %~ (split (a-b):) @@ -62,7 +62,7 @@ compile ast = compile' ast initCompilerS where      []  -- | Shows a compiled state as LP format.  Requires variable ids are strings. -instance (Show t, Num t, Ord t) => Show (CompilerS t String) where +instance (Show t, Num t, Fractional t, Ord t) => Show (CompilerS t String) where    show s = unlines $ catMaybes [        Just "Minimize"        ,Just (showEq (s ^. objective)) @@ -89,7 +89,7 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where        render x = (if x >= 0 then "+" else "") ++ show x -findBounds :: (Hashable v, Eq v, Num t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v]) +findBounds :: (Hashable v, Eq v, Num t, Fractional t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v])  findBounds eqs = (mapMaybe bound singleTerms, eqs \\ filter (isBounded . head . vars . fst) singleTermEqs)    where      singleTermEqs = filter (\(ts, _) -> length (vars ts) == 1) eqs diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs index d67e642..ebf7156 100644 --- a/Math/LinProg/Types.hs +++ b/Math/LinProg/Types.hs @@ -38,6 +38,7 @@ import Data.List  import qualified Data.HashMap.Strict as M  import qualified Data.HashSet as S  import Test.QuickCheck +import Data.Ratio  -- | Base AST for expressions.  Expressions have factors or type t and  -- variables referenced by ids of type v. @@ -47,6 +48,7 @@ data LinExpr' t v a =    | Wvar !t !v    | Add !a !a    | Mul !a !a +  | Div !a !a    | Negate !a    deriving (Show, Eq, Functor) @@ -68,20 +70,21 @@ instance Num t => Num (LinExpr t v) where  -- | Linear expressions can also be instances of fractional.  instance Fractional t => Fractional (LinExpr t v) where -  a / b = Fix (Mul a (1/b)) +  a / b = Fix (Div a b)    fromRational a = Fix (Lit (fromRational a))  -- | Reduce a linear expression down to the constant factor. -consts :: Num t => LinExpr t v -> t +consts :: (Num t, Fractional t) => LinExpr t v -> t  consts = cata consts' where    consts' (Negate a) = negate a    consts' (Lit a) = a    consts' (Add a b) = a + b    consts' (Mul a b) = a * b +  consts' (Div a b) = a / b    consts' _ = 0  -- | Gets the multiplier for a particular variable. -getVar :: (Num t, Eq v) => v -> LinExpr t v -> t +getVar :: (Num t, Fractional t, Eq v) => v -> LinExpr t v -> t  getVar id x = cata getVar' x - consts x where    getVar' (Wvar w x) | x == id = w                       | otherwise = 0 @@ -90,6 +93,7 @@ getVar id x = cata getVar' x - consts x where    getVar' (Lit a) = a    getVar' (Add a b) = a + b    getVar' (Mul a b) = a * b +  getVar' (Div a b) = a / b    getVar' (Negate a) = negate a  -- | Gets all variables used in an equation. @@ -103,26 +107,30 @@ vars = S.toList . cata vars' where    vars' _ = S.empty  -- | Expands terms to Wvars but does not collect like terms -rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v +rewrite :: (Eq t, Num t, Fractional t) => LinExpr t v -> LinExpr t v  rewrite = cata rewrite' where +  rewrite' (Lit a) = Fix (Lit a)    rewrite' (Var a) = Fix (Wvar 1 a)    rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a    rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a +  rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))    rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c)    rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c) -  rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))    rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b)) -  rewrite' (Lit a) = Fix (Lit a)    rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))    rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) +  rewrite' (Div (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a / b) c) +  rewrite' (Div (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a / b) c) +  rewrite' (Div (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a / b))    rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b)    rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a))    rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b))) -  rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b) +  rewrite' (Negate (Fix (Mul a b))) = rewrite' (Mul (rewrite' (Negate a)) b) +  rewrite' (Negate (Fix (Div a b))) = rewrite' (Div (rewrite' (Negate a)) b)    rewrite' a = Fix a  -- | Reduces an expression to the variable terms -varTerms :: (Num t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)] +varTerms :: (Num t, Fractional t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)]  varTerms = M.toList . cata go . rewrite where    go (Wvar w a) = M.fromList [(a, w)]    go (Add a b) = M.unionWith (+) a b @@ -130,13 +138,15 @@ varTerms = M.toList . cata go . rewrite where    go _ = M.empty  -- | Splits an expression into the variables and the constant term -split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t) -split eq = (eq - (Fix (Lit (consts eq))), consts eq) +split :: (Num t, Fractional t, Eq v) => LinExpr t v -> (LinExpr t v, t) +split eq = (eq - Fix (Lit (consts eq)), consts eq)  prettyPrint :: (Show t, Show v) => LinExpr t v -> String  prettyPrint = cata prettyPrint' where    prettyPrint' (Lit a) = show a +  prettyPrint' (Negate a) = "-" ++ show a    prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"] +  prettyPrint' (Div a b) = concat ["(", a, "/", b, ")"]    prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]    prettyPrint' (Var x) = show x    prettyPrint' (Wvar w x) = show w ++ show x @@ -186,13 +196,14 @@ instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where      ,((Fix .) . Mul) <$> arbitrary <*> arbitrary      ,(Fix . Negate) <$> arbitrary] -prop_rewrite :: LinExpr Int Int -> Property +prop_rewrite :: LinExpr Rational Int -> Property  prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq)    where      vs = vars eq -    ws = map (flip getVar eq) vs +    ws = map (`getVar` eq) vs      isLinear = (<= 1) . cata isLinear' where        isLinear' (Mul a b) = a + b +      isLinear' (Div a b) = a + b        isLinear' (Add a b) = max a b        isLinear' (Var _) = 1        isLinear' (Wvar _ _) = 1 | 
