aboutsummaryrefslogtreecommitdiff
path: root/Math
diff options
context:
space:
mode:
authorJustin Bedo <cu@cua0.org>2014-10-06 16:59:31 +1100
committerJustin Bedo <cu@cua0.org>2014-10-06 16:59:31 +1100
commit10dd5dc7fef792b271d8fbbdf8222a12a910fe92 (patch)
tree3320d0ff63004c14283915d76f02863d749463c7 /Math
Initial implementation of EDSL and LP output formatter
Diffstat (limited to 'Math')
-rw-r--r--Math/LinProg/LP.hs88
-rw-r--r--Math/LinProg/Setup.hs2
-rw-r--r--Math/LinProg/Types.hs106
3 files changed, 196 insertions, 0 deletions
diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs
new file mode 100644
index 0000000..8670cdf
--- /dev/null
+++ b/Math/LinProg/LP.hs
@@ -0,0 +1,88 @@
+{-# LANGUAGE TemplateHaskell, FlexibleInstances, ScopedTypeVariables #-}
+
+module Math.LinProg.LP (
+ compile
+) where
+
+import Data.List
+import Math.LinProg.Types
+import Control.Lens
+import Control.Monad.State
+import Control.Monad.Free
+import Data.Maybe
+
+type Equation t v = (LinExpr t v, t) -- LHS and RHS
+
+data CompilerS t v = CompilerS {
+ _objective :: LinExpr t v
+ ,_equals :: [Equation t v]
+ ,_leqs :: [Equation t v]
+} deriving (Eq)
+
+$(makeLenses ''CompilerS)
+
+instance (Show t, Num t, Ord t) => Show (CompilerS t String) where
+ show s = unlines $ catMaybes [
+ Just "Minimize"
+ ,Just (showEq $ varTerms (s ^. objective))
+ ,if hasST then Just "Subject to" else Nothing
+ ,if hasEqs then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " = " ++ show (negate b)) $ s ^. equals) else Nothing
+ ,if hasUnbounded then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " <= " ++ show (negate b)) unbounded) else Nothing
+ ,if hasBounded then Just "Bounds" else Nothing
+ ,if hasBounded then Just (intercalate "\n" $ map (\(l, v, u) -> show l ++ " <= " ++ v ++ " <= " ++ show u) bounded) else Nothing
+ ]
+ where
+ getVars eq = zip vs ws
+ where
+ vs = vars eq
+ ws = map (`getVar` eq) vs
+
+ showEq = unwords . map (\(a, b) -> render b ++ " " ++ a) . getVars
+
+ (bounded, unbounded) = findBounds $ s ^. leqs
+ hasBounded = not (null bounded)
+ hasUnbounded = not (null unbounded)
+ hasEqs = not (null (s^.equals))
+ hasST = hasUnbounded || hasEqs
+
+ render x = (if x >= 0 then "+" else "") ++ show x
+
+findBounds :: (Eq v, Num t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v])
+findBounds eqs = (mapMaybe bound singleTerms, eqs \\ filter (isBounded . head . vars . fst) singleTermEqs)
+ where
+ singleTermEqs = filter (\(ts, _) -> length (vars ts) == 1) eqs
+ singleTerms = nub $ concatMap (vars . fst) singleTermEqs
+
+ upperBound x = mapMaybe (\(a, c) -> let w = getVar x a in if w == 1 then Nothing else Just c) singleTermEqs
+ lowerBound x = mapMaybe (\(a, c) -> let w = getVar x a in if w == -1 then Nothing else Just c) singleTermEqs
+
+ bound v = bound' (lowerBound v) (upperBound v) where
+ bound' [] _ = Nothing
+ bound' _ [] = Nothing
+ bound' ls us | l <= u = Just (l, v, u)
+ | otherwise = Nothing where
+ l = maximum ls
+ u = minimum us
+
+ isBounded v = isJust (bound v)
+
+compile :: (Num t, Show t, Ord t) => LinProg t String () -> String
+compile ast = show $ compile' ast initCompilerS where
+ compile' (Free (Objective a c)) state = compile' c $ state & objective +~ a
+ compile' (Free (EqConstraint a b c)) state = compile' c $ state & equals %~ (split (b-a):)
+ compile' (Free (LeqConstraint a b c)) state = compile' c $ state & leqs %~ (split (b-a):)
+ compile' _ state = state
+
+ initCompilerS = CompilerS
+ 0
+ []
+ []
+
+test :: LinProg Double String ()
+test = do
+ let [x, y] = map var ["x", "y"]
+ obj $ 1 + 5 * y + x
+ y =: (1 + x)
+ y >: (-5)
+ x <: 10
+ x >: 0
diff --git a/Math/LinProg/Setup.hs b/Math/LinProg/Setup.hs
new file mode 100644
index 0000000..9a994af
--- /dev/null
+++ b/Math/LinProg/Setup.hs
@@ -0,0 +1,2 @@
+import Distribution.Simple
+main = defaultMain
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs
new file mode 100644
index 0000000..d96fcb5
--- /dev/null
+++ b/Math/LinProg/Types.hs
@@ -0,0 +1,106 @@
+{-# LANGUAGE DeriveFunctor, FlexibleInstances, FlexibleContexts, UndecidableInstances #-}
+
+module Math.LinProg.Types (
+ LinExpr
+ ,var
+ ,vars
+ ,varTerms
+ ,getVar
+ ,split
+ ,LinProg
+ ,LinProg'(..)
+ ,obj
+ ,(<:)
+ ,(=:)
+ ,(>:)
+ ,eq
+ ,leq
+ ,geq
+) where
+
+import Data.Functor.Foldable
+import Control.Monad.Free
+import qualified Data.Map as M
+
+data LinExpr' t v a =
+ Lit t
+ | Var v
+ | Add a a
+ | Mul a a
+ | Negate a
+ deriving (Show, Eq, Functor)
+
+type LinExpr t v = Fix (LinExpr' t v)
+
+var = Fix . Var
+
+instance Num t => Num (LinExpr t v) where
+ a * b = Fix (Mul a b)
+ a + b = Fix (Add a b)
+ negate a = Fix (Negate a)
+ fromInteger a = Fix (Lit (fromInteger a))
+ abs = undefined
+ signum = undefined
+
+consts :: Num t => LinExpr t v -> t
+consts = cata consts' where
+ consts' (Negate a) = negate a
+ consts' (Lit a) = a
+ consts' (Var _) = 0
+ consts' (Add a b) = a + b
+ consts' (Mul a b) = a * b
+
+getVar :: (Num t, Eq v) => v -> LinExpr t v -> t
+getVar id x = cata getVar' x - consts x where
+ getVar' (Var x) | x == id = 1
+ | otherwise = 0
+ getVar' (Lit a) = a
+ getVar' (Add a b) = a + b
+ getVar' (Mul a b) = a * b
+ getVar' (Negate a) = negate a
+
+vars :: LinExpr t v -> [v]
+vars = cata vars' where
+ vars' (Var x) = [x]
+ vars' (Add a b) = a ++ b
+ vars' (Mul a b) = a ++ b
+ vars' (Negate a) = a
+ vars' _ = []
+
+varTerms eq = go eq' where
+ go [t] = t
+ go (t:ts) = Fix (Add t (go ts))
+ go [] = Fix (Lit 0)
+
+ eq' = zipWith (\v w -> Fix (Mul (Fix (Lit w)) (Fix (Var v)))) vs ws
+ vs = vars eq
+ ws = map (`getVar` eq) vs
+
+split :: (Num t, Eq v) => LinExpr t v -> (LinExpr t v, t)
+split eq = (varTerms eq, consts eq)
+
+prettyPrint :: (Show t, Show v) => LinExpr t v -> String
+prettyPrint = cata prettyPrint' where
+ prettyPrint' (Lit a) = show a
+ prettyPrint' (Mul a b) = concat ["(", a, "×", b, ")"]
+ prettyPrint' (Add a b) = concat ["(", a, "+", b, ")"]
+ prettyPrint' (Var x) = show x
+
+-- Monad for linear programs
+
+data LinProg' t v a =
+ Objective (LinExpr t v) a
+ | EqConstraint (LinExpr t v) (LinExpr t v) a
+ | LeqConstraint (LinExpr t v) (LinExpr t v) a
+ deriving (Show, Eq, Functor)
+
+type LinProg t v = Free (LinProg' t v)
+
+obj a = liftF (Objective a ())
+eq a b = liftF (EqConstraint a b ())
+leq a b = liftF (LeqConstraint a b ())
+geq b a = liftF (LeqConstraint a b ())
+
+a =: b = eq a b
+a <: b = leq a b
+a >: b = geq a b