From aeafc1692afa5952f3d06195916bb38463c1674c Mon Sep 17 00:00:00 2001 From: Justin Bedo Date: Sat, 1 Nov 2014 10:18:31 +1100 Subject: Benchmarking & optimisations --- Math/LinProg/LP.hs | 9 +++++---- Math/LinProg/LPSolve.hs | 27 +++++++++++++++------------ Math/LinProg/LPSolve/FFI.hs | 13 ++++++++++++- Math/LinProg/LPSolve/bindings.c | 12 ++++++++++++ Math/LinProg/Types.hs | 22 ++++++++++++---------- 5 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 Math/LinProg/LPSolve/bindings.c (limited to 'Math/LinProg') diff --git a/Math/LinProg/LP.hs b/Math/LinProg/LP.hs index d1a1cf1..513bdad 100644 --- a/Math/LinProg/LP.hs +++ b/Math/LinProg/LP.hs @@ -24,11 +24,12 @@ module Math.LinProg.LP ( ,bins ) where -import Data.List -import Math.LinProg.Types import Control.Lens -import Data.Maybe import Control.Monad.Free +import Data.Hashable +import Data.List +import Data.Maybe +import Math.LinProg.Types type Equation t v = (LinExpr t v, t) -- LHS and RHS @@ -88,7 +89,7 @@ instance (Show t, Num t, Ord t) => Show (CompilerS t String) where render x = (if x >= 0 then "+" else "") ++ show x -findBounds :: (Eq v, Num t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v]) +findBounds :: (Hashable v, Eq v, Num t, Ord t, Eq t) => [Equation t v] -> ([(t, v, t)], [Equation t v]) findBounds eqs = (mapMaybe bound singleTerms, eqs \\ filter (isBounded . head . vars . fst) singleTermEqs) where singleTermEqs = filter (\(ts, _) -> length (vars ts) == 1) eqs diff --git a/Math/LinProg/LPSolve.hs b/Math/LinProg/LPSolve.hs index 427c5d7..4e8385a 100644 --- a/Math/LinProg/LPSolve.hs +++ b/Math/LinProg/LPSolve.hs @@ -19,16 +19,18 @@ module Math.LinProg.LPSolve ( ) where import Control.Applicative -import Control.Monad -import Data.List +import Control.Arrow import Control.Lens -import Math.LinProg.LPSolve.FFI hiding (solve) -import qualified Math.LinProg.LPSolve.FFI as F +import Control.Monad +import Data.Hashable +import Data.List hiding (nub) import Math.LinProg.LP +import Math.LinProg.LPSolve.FFI hiding (solve) import Math.LinProg.Types +import Prelude hiding (EQ, nub) import qualified Data.HashMap.Strict as M -import Data.Hashable -import Prelude hiding (EQ) +import qualified Data.HashSet as S +import qualified Math.LinProg.LPSolve.FFI as F solve :: (Hashable v, Eq v, Ord v) => LinProg Double v () -> IO (Maybe ResultCode, [(v, Double)]) solve = solveWithTimeout 0 @@ -47,8 +49,7 @@ solveWithTimeout t (compile -> lp) = do let c = negate $ snd eq setConstrType m i EQ setRHS m i c - forM_ (varTerms (fst eq)) $ \(v, w) -> - setMat m i (varLUT M.! v) w + setRow' m i (fst eq) return () -- Leqs @@ -56,8 +57,7 @@ solveWithTimeout t (compile -> lp) = do let c = negate $ snd eq setConstrType m i LE setRHS m i c - forM_ (varTerms (fst eq)) $ \(v, w) -> - setMat m i (varLUT M.! v) w + setRow' m i (fst eq) return () -- Ints @@ -69,8 +69,7 @@ solveWithTimeout t (compile -> lp) = do setBin m (varLUT M.! v) -- Objective - forM_ (varTerms (lp ^. objective)) $ \(v, w) -> - void $ setMat m 0 (varLUT M.! v) w + setRow' m 0 (lp ^. objective) res <- F.solve m sol <- snd <$> getSol nvars m @@ -89,3 +88,7 @@ solveWithTimeout t (compile -> lp) = do r <- f m freeLP m return r + + nub = S.toList . S.fromList + + setRow' m i eq = setRow m i (map (first ((M.!) varLUT)) $ varTerms eq) diff --git a/Math/LinProg/LPSolve/FFI.hs b/Math/LinProg/LPSolve/FFI.hs index 64919b9..edf0a43 100644 --- a/Math/LinProg/LPSolve/FFI.hs +++ b/Math/LinProg/LPSolve/FFI.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE ForeignFunctionInterface, ViewPatterns #-} module Math.LinProg.LPSolve.FFI ( ResultCode(..) ,ConstraintType(..) @@ -9,6 +9,7 @@ module Math.LinProg.LPSolve.FFI ( ,setBin ,makeLP ,freeLP + ,setRow ,setMat ,setRHS ,solve @@ -58,6 +59,7 @@ foreign import ccall "set_timeout" c_set_timeout :: LPRec -> CLong -> IO () foreign import ccall "set_int" c_set_int :: LPRec -> CInt -> CChar -> IO CChar foreign import ccall "set_binary" c_set_binary :: LPRec -> CInt -> CChar -> IO CChar foreign import ccall "print_debugdump" c_print_debugdump :: LPRec -> CString -> IO () +foreign import ccall "hs_set_row" c_hs_set_row :: LPRec -> CInt -> CInt -> Ptr CInt -> Ptr CDouble -> IO CChar debugDump :: LPRec -> FilePath -> IO () debugDump lp path = withCString path $ \str -> c_print_debugdump lp str @@ -82,6 +84,15 @@ freeLP m = with m $ \m' -> c_free_lp m' setMat :: LPRec -> Int -> Int -> Double -> IO Word8 setMat a b c d = fromIntegral <$> c_set_mat a (fromIntegral b) (fromIntegral c) (realToFrac d) +setRow :: LPRec -> Int -> [(Int, Double)] -> IO Word8 +setRow m row (unzip -> (cols, ws)) = fmap fromIntegral $ withArray (map fromIntegral cols) $ \c -> + withArray (map realToFrac ws) $ \w -> + c_hs_set_row m + (fromIntegral row) + (fromIntegral (length cols)) + c + w + setRHS :: LPRec -> Int -> Double -> IO Word8 setRHS a b c = fromIntegral <$> c_set_rh a (fromIntegral b) (realToFrac c) diff --git a/Math/LinProg/LPSolve/bindings.c b/Math/LinProg/LPSolve/bindings.c new file mode 100644 index 0000000..fbaf09f --- /dev/null +++ b/Math/LinProg/LPSolve/bindings.c @@ -0,0 +1,12 @@ +#include + +char +hs_set_row(void *model, int row, int n, int *cols, double *ws) +{ + int i; + + for(i = 0; i < n; i++) + set_mat(model, row, cols[i], ws[i]); + + return 0; +} diff --git a/Math/LinProg/Types.hs b/Math/LinProg/Types.hs index 4819dd3..d67e642 100644 --- a/Math/LinProg/Types.hs +++ b/Math/LinProg/Types.hs @@ -30,13 +30,14 @@ module Math.LinProg.Types ( ,int ) where -import Data.Functor.Foldable +import Control.Applicative import Control.Monad.Free +import Data.Functor.Foldable +import Data.Hashable +import Data.List import qualified Data.HashMap.Strict as M +import qualified Data.HashSet as S import Test.QuickCheck -import Control.Applicative -import Data.List -import Data.Hashable -- | Base AST for expressions. Expressions have factors or type t and -- variables referenced by ids of type v. @@ -92,13 +93,14 @@ getVar id x = cata getVar' x - consts x where getVar' (Negate a) = negate a -- | Gets all variables used in an equation. -vars :: Eq v => LinExpr t v -> [v] -vars = nub . cata vars' where - vars' (Var x) = [x] - vars' (Add a b) = a ++ b - vars' (Mul a b) = a ++ b +vars :: (Hashable v, Eq v) => LinExpr t v -> [v] +vars = S.toList . cata vars' where + vars' (Wvar _ x) = S.fromList [x] + vars' (Var x) = S.fromList [x] + vars' (Add a b) = S.union a b + vars' (Mul a b) = S.union a b vars' (Negate a) = a - vars' _ = [] + vars' _ = S.empty -- | Expands terms to Wvars but does not collect like terms rewrite :: (Eq t, Num t) => LinExpr t v -> LinExpr t v -- cgit v1.2.3