diff options
author | Justin Bedo <cu@cua0.org> | 2014-11-11 13:37:48 +1100 |
---|---|---|
committer | Justin Bedo <cu@cua0.org> | 2014-11-11 13:37:48 +1100 |
commit | 19a7c7e4225ad6650d8a60d673edb7b23294f9d5 (patch) | |
tree | 55149c8792ea14be5ed05147bd77ad4f8589cdb5 /Math/LinProg/Types.hs | |
parent | aeafc1692afa5952f3d06195916bb38463c1674c (diff) |
Diffstat (limited to 'Math/LinProg/Types.hs')
-rw-r--r-- | Math/LinProg/Types.hs | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs index d67e642..ebf7156 100644 --- a/Math/LinProg/Types.hs +++ b/Math/LinProg/Types.hs @@ -38,6 +38,7 @@ import Data.List import qualified Data.HashMap.Strict as M import qualified Data.HashSet as S import Test.QuickCheck +import Data.Ratio -- | Base AST for expressions. Expressions have factors or type t and -- variables referenced by ids of type v. @@ -47,6 +48,7 @@ data LinExpr' t v a = | Wvar !t !v | Add !a !a | Mul !a !a + | Div !a !a | Negate !a deriving (Show, Eq, Functor) @@ -68,20 +70,21 @@ instance Num t => Num (LinExpr t v) where -- | Linear expressions can also be instances of fractional. instance Fractional t => Fractional (LinExpr t v) where - a / b = Fix (Mul a (1/b)) + a / b = Fix (Div a b) fromRational a = Fix (Lit (fromRational a)) -- | Reduce a linear expression down to the constant factor. -consts :: Num t => LinExpr t v -> t +consts :: (Num t, Fractional t) => LinExpr t v -> t consts = cata consts' where consts' (Negate a) = negate a consts' (Lit a) = a consts' (Add a b) = a + b consts' (Mul a b) = a * b + consts' (Div a b) = a / b consts' _ = 0 -- | Gets the multiplier for a particular variable. -getVar :: (Num t, Eq v) => v -> LinExpr t v -> t +getVar :: (Num t, Fractional t, Eq v) => v -> LinExpr t v -> t getVar id x = cata getVar' x - consts x where getVar' (Wvar w x) | x == id = w | otherwise = 0 @@ -90,6 +93,7 @@ getVar id x = cata getVar' x - consts x where getVar' (Lit a) = a getVar' (Add a b) = a + b getVar' (Mul a b) = a * b + getVar' (Div a b) = a / b getVar' (Negate a) = negate a -- | Gets all variables used in an equation. @@ -103,26 +107,30 @@ vars = S.toList . cata vars' where vars' _ = S.empty -- | Expands terms to Wvars but does not collect like terms -rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v +rewrite :: (Eq t, Num t, Fractional t) => LinExpr t v -> LinExpr t v rewrite = cata rewrite' where + rewrite' (Lit a) = Fix (Lit a) rewrite' (Var a) = Fix (Wvar 1 a) rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a + rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b)) rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c) rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c) - rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b)) rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b)) - rewrite' (Lit a) = Fix (Lit a) rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c))) + rewrite' (Div (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a / b) c) + rewrite' (Div (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a / b) c) + rewrite' (Div (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a / b)) rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b) rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a)) rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b))) - rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b) + rewrite' (Negate (Fix (Mul a b))) = rewrite' (Mul (rewrite' (Negate a)) b) + rewrite' (Negate (Fix (Div a b))) = rewrite' (Div (rewrite' (Negate a)) b) rewrite' a = Fix a -- | Reduces an expression to the variable terms -varTerms :: (Num t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)] +varTerms :: (Num t, Fractional t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)] varTerms = M.toList . cata go . rewrite where go (Wvar w a) = M.fromList [(a, w)] go (Add a b) = M.unionWith (+) a b @@ -130,13 +138,15 @@ varTerms = M.toList . cata go . rewrite where go _ = M.empty -- | Splits an expression into the variables and the constant term -split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t) -split eq = (eq - (Fix (Lit (consts eq))), consts eq) +split :: (Num t, Fractional t, Eq v) => LinExpr t v -> (LinExpr t v, t) +split eq = (eq - Fix (Lit (consts eq)), consts eq) prettyPrint :: (Show t, Show v) => LinExpr t v -> String prettyPrint = cata prettyPrint' where prettyPrint' (Lit a) = show a + prettyPrint' (Negate a) = "-" ++ show a prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"] + prettyPrint' (Div a b) = concat ["(", a, "/", b, ")"] prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"] prettyPrint' (Var x) = show x prettyPrint' (Wvar w x) = show w ++ show x @@ -186,13 +196,14 @@ instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where ,((Fix .) . Mul) <$> arbitrary <*> arbitrary ,(Fix . Negate) <$> arbitrary] -prop_rewrite :: LinExpr Int Int -> Property +prop_rewrite :: LinExpr Rational Int -> Property prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq) where vs = vars eq - ws = map (flip getVar eq) vs + ws = map (`getVar` eq) vs isLinear = (<= 1) . cata isLinear' where isLinear' (Mul a b) = a + b + isLinear' (Div a b) = a + b isLinear' (Add a b) = max a b isLinear' (Var _) = 1 isLinear' (Wvar _ _) = 1 |