Abstract: We introduce MPAX (Mathematical Programming in JAX), a versatile and efficient toolbox for integrating mathematical programming into machine learning workflows. MPAX implements first-order methods in JAX, providing native support for hardware acceleration along with features such as batch solving, auto-differentiation, and device parallelism. Currently in beta, MPAX supports linear programming and will be extended to more general mathematical programming problems, along with specialized modules for common machine learning tasks. The solver is available at https://github.com/MIT-Lu-Lab/MPAX.
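The feature list above (batch solving, auto-differentiation, device parallelism) maps directly onto JAX transformations. The sketch below is not MPAX's API; it is a minimal hand-rolled PDHG loop for a standard-form LP, showing how `jax.vmap` and `jax.jit` provide the batching and hardware acceleration that such a first-order solver can exploit. All names here are illustrative; see the linked repository for MPAX's actual interface.

```python
import jax
import jax.numpy as jnp

def pdhg_solve(c, A, b, num_iters=5000):
    """Hand-rolled PDHG for min c@x s.t. A@x = b, x >= 0 (illustration only)."""
    m, n = A.shape
    # Step sizes satisfying tau * sigma * ||A||_2^2 < 1, a standard PDHG condition.
    op_norm = jnp.linalg.norm(A, 2)
    tau = sigma = 0.9 / op_norm
    init = (jnp.zeros(n), jnp.zeros(m))

    def step(carry, _):
        x, y = carry
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)  # primal step, projected onto x >= 0
        y_new = y + sigma * (A @ (2.0 * x_new - x) - b)    # dual step on the extrapolated point
        return (x_new, y_new), None

    (x, _), _ = jax.lax.scan(step, init, None, length=num_iters)
    return x

# Batch solving: vmap maps the solver over a stack of LP instances, and jit
# compiles the whole batch into a single accelerator computation.
batched_solve = jax.jit(jax.vmap(pdhg_solve))
```

Calling `batched_solve` on arrays of shape `(B, n)`, `(B, m, n)`, `(B, m)` solves `B` LPs in one compiled kernel; since `jax.lax.scan` is differentiable, `jax.grad` through the unrolled iterations is one way to obtain the auto-differentiation the abstract mentions.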
Abstract: There has been recent interest in first-order methods for linear programming (LP). In this paper, we propose a stochastic algorithm using variance reduction and restarts for solving sharp primal-dual problems such as LP. We show that the proposed stochastic method exhibits a linear convergence rate on sharp instances with high probability, improving on the complexity of existing deterministic and stochastic algorithms. In addition, we propose an efficient coordinate-based stochastic oracle for unconstrained bilinear problems, which has $\mathcal{O}(1)$ per-iteration cost and improves the total flop count required to reach a given accuracy.
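For intuition on the $\mathcal{O}(1)$ coordinate oracle, consider the bilinear problem $\min_x \max_y\, y^\top A x$, where the full gradients $A^\top y$ and $Ax$ cost $\mathcal{O}(mn)$. Sampling a single entry $A_{ij}$ and anchoring it to a periodically refreshed full-gradient snapshot gives an unbiased, variance-reduced estimate of one gradient coordinate at $\mathcal{O}(1)$ cost. The toy loop below illustrates only this estimator; it is not the paper's algorithm (the restarts and step-size choices are omitted), and a small regularizer `mu` is added purely so that plain descent-ascent converges in this sketch.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
m, n = 200, 300
A = jax.random.normal(key, (m, n)) / jnp.sqrt(m * n)
x, y = jnp.zeros(n), jnp.zeros(m)
eta, mu = 0.1, 0.1  # untuned step size and toy regularizer (not from the paper)

for epoch in range(20):
    # Snapshot: one full O(mn) gradient pair per epoch anchors the variance reduction.
    x_snap, y_snap = x, y
    gx_snap = A.T @ y_snap  # (A^T y) at the snapshot
    gy_snap = A @ x_snap    # (A x) at the snapshot
    for _ in range(500):
        key, sub = jax.random.split(key)
        idx = jax.random.randint(sub, (2,), 0, jnp.array([m, n]))
        i, j = idx[0], idx[1]
        # Unbiased O(1) estimates of (A^T y)_j and (A x)_i from one entry A[i, j]:
        gx_j = m * A[i, j] * (y[i] - y_snap[i]) + gx_snap[j]
        gy_i = n * A[i, j] * (x[j] - x_snap[j]) + gy_snap[i]
        x = x.at[j].add(-eta * (mu * x[j] + gx_j))  # primal descent on one coordinate
        y = y.at[i].add(eta * (gy_i - mu * y[i]))   # dual ascent on one coordinate
```

Each inner step reads one entry of $A$ and updates one coordinate of $x$ and of $y$, so its cost is independent of $m$ and $n$; the $\mathcal{O}(mn)$ snapshot is amortized over the inner loop.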