aboutsummaryrefslogtreecommitdiff
path: root/Math/LinProg/Types.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Math/LinProg/Types.hs')
-rw-r--r--Math/LinProg/Types.hs35
1 files changed, 23 insertions, 12 deletions
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs
index d67e642..ebf7156 100644
--- a/Math/LinProg/Types.hs
+++ b/Math/LinProg/Types.hs
@@ -38,6 +38,7 @@ import Data.List
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as S
import Test.QuickCheck
+import Data.Ratio
-- | Base AST for expressions. Expressions have factors or type t and
-- variables referenced by ids of type v.
@@ -47,6 +48,7 @@ data LinExpr' t v a =
| Wvar !t !v
| Add !a !a
| Mul !a !a
+ | Div !a !a
| Negate !a
deriving (Show, Eq, Functor)
@@ -68,20 +70,21 @@ instance Num t => Num (LinExpr t v) where
-- | Linear expressions can also be instances of fractional.
instance Fractional t => Fractional (LinExpr t v) where
- a / b = Fix (Mul a (1/b))
+ a / b = Fix (Div a b)
fromRational a = Fix (Lit (fromRational a))
-- | Reduce a linear expression down to the constant factor.
-consts :: Num t => LinExpr t v -> t
+consts :: (Num t, Fractional t) => LinExpr t v -> t
consts = cata consts' where
consts' (Negate a) = negate a
consts' (Lit a) = a
consts' (Add a b) = a + b
consts' (Mul a b) = a * b
+ consts' (Div a b) = a / b
consts' _ = 0
-- | Gets the multiplier for a particular variable.
-getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
+getVar :: (Num t, Fractional t, Eq v) => v -> LinExpr t v -> t
getVar id x = cata getVar' x - consts x where
getVar' (Wvar w x) | x == id = w
| otherwise = 0
@@ -90,6 +93,7 @@ getVar id x = cata getVar' x - consts x where
getVar' (Lit a) = a
getVar' (Add a b) = a + b
getVar' (Mul a b) = a * b
+ getVar' (Div a b) = a / b
getVar' (Negate a) = negate a
-- | Gets all variables used in an equation.
@@ -103,26 +107,30 @@ vars = S.toList . cata vars' where
vars' _ = S.empty
-- | Expands terms to Wvars but does not collect like terms
-rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v
+rewrite :: (Eq t, Num t, Fractional t) => LinExpr t v -> LinExpr t v
rewrite = cata rewrite' where
+ rewrite' (Lit a) = Fix (Lit a)
rewrite' (Var a) = Fix (Wvar 1 a)
rewrite' (Add (Fix (Lit _)) a@(Fix (Wvar _ _))) = a
rewrite' (Add a@(Fix (Wvar _ _)) (Fix (Lit _))) = a
+ rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))
rewrite' (Mul (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a * b) c)
rewrite' (Mul (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a * b) c)
- rewrite' (Add (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a + b))
rewrite' (Mul (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a * b))
- rewrite' (Lit a) = Fix (Lit a)
rewrite' (Mul (Fix (Add a b)) c) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
rewrite' (Mul c (Fix (Add a b))) = rewrite' (Add (rewrite' (Mul a c)) (rewrite' (Mul b c)))
+ rewrite' (Div (Fix (Lit a)) (Fix (Wvar b c))) = Fix (Wvar (a / b) c)
+ rewrite' (Div (Fix (Wvar b c)) (Fix (Lit a))) = Fix (Wvar (a / b) c)
+ rewrite' (Div (Fix (Lit a)) (Fix (Lit b))) = Fix (Lit (a / b))
rewrite' (Negate (Fix (Wvar a b))) = Fix (Wvar (negate a) b)
rewrite' (Negate (Fix (Lit a))) = Fix (Lit (negate a))
rewrite' (Negate (Fix (Add a b))) = rewrite' (Add (rewrite' (Negate a)) (rewrite' (Negate b)))
- rewrite' (Negate (Fix (Mul a b))) = rewrite' (Add (rewrite' (Negate a)) b)
+ rewrite' (Negate (Fix (Mul a b))) = rewrite' (Mul (rewrite' (Negate a)) b)
+ rewrite' (Negate (Fix (Div a b))) = rewrite' (Div (rewrite' (Negate a)) b)
rewrite' a = Fix a
-- | Reduces an expression to the variable terms
-varTerms :: (Num t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)]
+varTerms :: (Num t, Fractional t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)]
varTerms = M.toList . cata go . rewrite where
go (Wvar w a) = M.fromList [(a, w)]
go (Add a b) = M.unionWith (+) a b
@@ -130,13 +138,15 @@ varTerms = M.toList . cata go . rewrite where
go _ = M.empty
-- | Splits an expression into the variables and the constant term
-split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
-split eq = (eq - (Fix (Lit (consts eq))), consts eq)
+split :: (Num t, Fractional t, Eq v) => LinExpr t v -> (LinExpr t v, t)
+split eq = (eq - Fix (Lit (consts eq)), consts eq)
prettyPrint :: (Show t, Show v) => LinExpr t v -> String
prettyPrint = cata prettyPrint' where
prettyPrint' (Lit a) = show a
+ prettyPrint' (Negate a) = "-" ++ show a
prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"]
+ prettyPrint' (Div a b) = concat ["(", a, "/", b, ")"]
prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]
prettyPrint' (Var x) = show x
prettyPrint' (Wvar w x) = show w ++ show x
@@ -186,13 +196,14 @@ instance (Arbitrary t, Arbitrary v) => Arbitrary (LinExpr t v) where
,((Fix .) . Mul) <$> arbitrary <*> arbitrary
,(Fix . Negate) <$> arbitrary]
-prop_rewrite :: LinExpr Int Int -> Property
+prop_rewrite :: LinExpr Rational Int -> Property
prop_rewrite eq = isLinear eq ==> sort (zip vs ws) == sort (varTerms eq)
where
vs = vars eq
- ws = map (flip getVar eq) vs
+ ws = map (`getVar` eq) vs
isLinear = (<= 1) . cata isLinear' where
isLinear' (Mul a b) = a + b
+ isLinear' (Div a b) = a + b
isLinear' (Add a b) = max a b
isLinear' (Var _) = 1
isLinear' (Wvar _ _) = 1