aboutsummaryrefslogtreecommitdiff
path: root/Math/LinProg
diff options
context:
space:
mode:
Diffstat (limited to 'Math/LinProg')
-rw-r--r--Math/LinProg/LP.hs14
-rw-r--r--Math/LinProg/LPSolve.hs15
-rw-r--r--Math/LinProg/LPSolve/FFI.hs10
-rw-r--r--Math/LinProg/Types.hs17
4 files changed, 50 insertions, 6 deletions
diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs
index 18d2068..d1a1cf1 100644
--- a/Math/LinProg/LP.hs
+++ b/Math/LinProg/LP.hs
@@ -20,6 +20,8 @@ module Math.LinProg.LP (
,objective
,equals
,leqs
+ ,ints
+ ,bins
) where
import Data.List
@@ -35,6 +37,8 @@ data CompilerS t v = CompilerS {
_objective :: LinExpr t v
,_equals :: [Equation t v]
,_leqs :: [Equation t v]
+ ,_bins :: [v]
+ ,_ints :: [v]
} deriving (Eq)
$(makeLenses ''CompilerS)
@@ -45,12 +49,16 @@ compile ast = compile' ast initCompilerS where
compile' (Free (Objective a c)) state = compile' c $ state & objective +~ a
compile' (Free (EqConstraint a b c)) state = compile' c $ state & equals %~ (split (a-b):)
compile' (Free (LeqConstraint a b c)) state = compile' c $ state & leqs %~ (split (a-b):)
+ compile' (Free (Integer a b)) state = compile' b $ state & ints %~ (a:)
+ compile' (Free (Binary a b)) state = compile' b $ state & bins %~ (a:)
compile' _ state = state
initCompilerS = CompilerS
0
[]
[]
+ []
+ []
-- | Shows a compiled state as LP format. Requires variable ids are strings.
instance (Show t, Num t, Ord t) => Show (CompilerS t String) where
@@ -62,6 +70,10 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where
,if hasUnbounded then Just (intercalate "\n" $ map (\(a, b) -> showEq a ++ " <= " ++ show (negate b)) unbounded) else Nothing
,if hasBounded then Just "Bounds" else Nothing
,if hasBounded then Just (intercalate "\n" $ map (\(l, v, u) -> show l ++ " <= " ++ v ++ " <= " ++ show u) bounded) else Nothing
+ ,if hasInts then Just "General" else Nothing
+ ,if hasInts then Just (unwords $ s ^. ints) else Nothing
+ ,if hasBins then Just "Binary" else Nothing
+ ,if hasBins then Just (unwords $ s ^. bins) else Nothing
]
where
showEq = unwords . map (\(a, b) -> render b ++ " " ++ a) . varTerms
@@ -71,6 +83,8 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where
hasUnbounded = not (null unbounded)
hasEqs = not (null (s^.equals))
hasST = hasUnbounded || hasEqs
+ hasInts = not . null $ s ^. ints
+ hasBins = not . null $ s ^. bins
render x = (if x >= 0 then "+" else "") ++ show x
diff --git a/Math/LinProg/LPSolve.hs b/Math/LinProg/LPSolve.hs
index baa5d7e..5299d94 100644
--- a/Math/LinProg/LPSolve.hs
+++ b/Math/LinProg/LPSolve.hs
@@ -26,14 +26,15 @@ import Math.LinProg.LPSolve.FFI hiding (solve)
import qualified Math.LinProg.LPSolve.FFI as F
import Math.LinProg.LP
import Math.LinProg.Types
-import qualified Data.Map.Strict as M
+import qualified Data.HashMap.Strict as M
+import Data.Hashable
import Prelude hiding (EQ)
-solve :: (Eq v, Ord v) => LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)])
+solve :: (Hashable v, Eq v, Ord v) => LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)])
solve = solveWithTimeout 0
-- | Solves an LP using lp_solve.
-solveWithTimeout :: (Eq v, Ord v) => Integer -> LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)])
+solveWithTimeout :: (Hashable v, Eq v, Ord v) => Integer -> LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)])
solveWithTimeout t (compile -> lp) = do
model <- makeLP nconstr nvars
case model of
@@ -59,6 +60,14 @@ solveWithTimeout t (compile -> lp) = do
setRHS m i c
return ()
+ -- Ints
+ forM_ (lp ^. ints) $ \v -> do
+ setInt m (varLUT M.! v)
+
+ -- Bins
+ forM_ (lp ^. bins) $ \v -> do
+ setBin m (varLUT M.! v)
+
-- Objective
forM_ (varTerms (lp ^. objective)) $ \(v, w) -> do
void $ setMat m 0 (varLUT M.! v) w
diff --git a/Math/LinProg/LPSolve/FFI.hs b/Math/LinProg/LPSolve/FFI.hs
index ddc7798..ff0bc16 100644
--- a/Math/LinProg/LPSolve/FFI.hs
+++ b/Math/LinProg/LPSolve/FFI.hs
@@ -5,6 +5,8 @@ module Math.LinProg.LPSolve.FFI (
,LPRec
,setConstrType
,setTimeout
+ ,setInt
+ ,setBin
,makeLP
,freeLP
,setMat
@@ -52,6 +54,8 @@ foreign import ccall "solve" c_solve :: LPRec -> IO CInt
foreign import ccall "get_variables" c_get_variables :: LPRec -> Ptr CDouble -> IO CChar
foreign import ccall "set_constr_type" c_set_constr_type :: LPRec -> CInt -> CInt -> IO CChar
foreign import ccall "set_timeout" c_set_timeout :: LPRec -> CLong -> IO ()
+foreign import ccall "set_int" c_set_int :: LPRec -> CInt -> CChar -> IO CChar
+foreign import ccall "set_binary" c_set_binary :: LPRec -> CInt -> CChar -> IO CChar
setTimeout :: LPRec -> Integer -> IO ()
setTimeout lp x = c_set_timeout lp (fromIntegral x)
@@ -76,6 +80,12 @@ setMat a b c d = fromIntegral <$> c_set_mat a (fromIntegral b) (fromIntegral c)
setRHS :: LPRec -> Int -> Double -> IO Word8
setRHS a b c = fromIntegral <$> c_set_rh a (fromIntegral b) (realToFrac c)
+setInt :: LPRec -> Int -> IO Word8
+setInt m a = fromIntegral <$> c_set_int m (fromIntegral a) 1
+
+setBin :: LPRec -> Int -> IO Word8
+setBin m a = fromIntegral <$> c_set_binary m (fromIntegral a) 1
+
solve :: LPRec -> IO ResultCode
solve lp = (lut M.!) . fromIntegral <$> c_solve lp
where
diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs
index 2a81918..4819dd3 100644
--- a/Math/LinProg/Types.hs
+++ b/Math/LinProg/Types.hs
@@ -26,14 +26,17 @@ module Math.LinProg.Types (
,(<:)
,(=:)
,(>:)
+ ,bin
+ ,int
) where
import Data.Functor.Foldable
import Control.Monad.Free
-import qualified Data.Map.Strict as M
+import qualified Data.HashMap.Strict as M
import Test.QuickCheck
import Control.Applicative
import Data.List
+import Data.Hashable
-- | Base AST for expressions. Expressions have factors or type t and
-- variables referenced by ids of type v.
@@ -117,7 +120,7 @@ rewrite = cata rewrite' where
rewrite' a = Fix a
-- | Reduces an expression to the variable terms
-varTerms :: (Num t, Eq t, Ord v) => LinExpr t v -> [(v, t)]
+varTerms :: (Num t, Eq t, Hashable v, Eq v) => LinExpr t v -> [(v, t)]
varTerms = M.toList . cata go . rewrite where
go (Wvar w a) = M.fromList [(a, w)]
go (Add a b) = M.unionWith (+) a b
@@ -141,6 +144,8 @@ prettyPrint = cata prettyPrint' where
-- in the data type).
data LinProg' t v a =
Objective !(LinExpr t v) !a
+ | Integer !v !a
+ | Binary !v !a
| EqConstraint !(LinExpr t v) !(LinExpr t v) !a
| LeqConstraint !(LinExpr t v) !(LinExpr t v) !a
deriving (Show, Eq, Functor)
@@ -155,10 +160,16 @@ a =: b = liftF (EqConstraint a b ())
-- | Define an inequality (less than equal) contraint
a <: b = liftF (LeqConstraint a b ())
---
+
-- | Define an inequality (greater than equal) contraint
b >: a = liftF (LeqConstraint a b ())
+-- | Declare a variable to be binary
+bin (Fix (Var v)) = liftF (Binary v ())
+
+-- | Declare a variable to be integral
+int (Fix (Var v)) = liftF (Integer v ())
+
infix 4 =:
infix 4 <:
infix 4 >: