mushi.optimization.ThreeOpProxGrad

class ThreeOpProxGrad(g, grad, h1, prox1, h2, prox2, verbose=False, **line_search_kwargs)[source]

Bases: mushi.optimization.AccProxGrad

Three operator splitting proximal gradient method with backtracking line search [2].

The optimization problem solved is:

\[\arg\min_x g(x) + h_1(x) + h_2(x)\]

where \(g\) is differentiable, and the proximal operators for \(h_1\) and \(h_2\) are available.
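
Each iteration is based on the Davis–Yin three operator splitting update, sketched below with a fixed step size \(\gamma\) (the implementation instead chooses the step size adaptively by backtracking line search, following the reference, and the roles of \(h_1\) and \(h_2\) may be interchanged):

\[\begin{aligned}x_t &= \mathrm{prox}_{\gamma h_2}(z_t)\\ z_{t+1} &= z_t + \mathrm{prox}_{\gamma h_1}\!\left(2 x_t - z_t - \gamma \boldsymbol\nabla g(x_t)\right) - x_t\end{aligned}\]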

Parameters

  • g (Callable) – differentiable term \(g(x)\) of the objective

  • grad (Callable) – gradient of \(g\)

  • h1 (Callable) – first non-smooth term \(h_1(x)\) of the objective

  • prox1 (Callable) – proximal operator associated with \(h_1\)

  • h2 (Callable) – second non-smooth term \(h_2(x)\) of the objective

  • prox2 (Callable) – proximal operator associated with \(h_2\)

  • verbose (bool) – print convergence messages if True

  • line_search_kwargs – keyword arguments for the backtracking line search

References

[2]

F. Pedregosa and G. Gidel, "Adaptive Three Operator Splitting," in Proceedings of the 35th International Conference on Machine Learning, Proceedings of Machine Learning Research, J. Dy and A. Krause, Eds. (PMLR, 2018), pp. 4085–4094.

Examples

Usage is very similar to mushi.optimization.AccProxGrad(), except that two non-smooth terms (and their associated proximal operators) may be specified.

>>> import mushi.optimization as opt
>>> import numpy as np

We’ll use a squared loss term, a Lasso term and a box constraint for this example.

Define \(g(x)\) and \(\boldsymbol\nabla g(x)\):

>>> def g(x):
...     return 0.5 * np.sum(x ** 2)
>>> def grad(x):
...     return x

Define \(h_1(x)\) and \(\mathrm{prox}_{h_1}(u)\). We will use a Lasso term and the corresponding soft thresholding operator:

>>> def h1(x):
...     return np.linalg.norm(x, 1)
>>> def prox1(u, s):
...     return np.sign(u) * np.clip(np.abs(u) - s, 0, None)
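
As a quick illustration (values chosen arbitrarily), soft thresholding shrinks each coordinate of \(u\) toward zero by \(s\):

>>> prox1(np.array([2.0, -3.0]), 0.5)
array([ 1.5, -2.5])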

Define \(h_2(x)\) and \(\mathrm{prox}_{h_2}(u)\). We use a simple box constraint on the first coordinate, \(x_0 \ge 1\); note that this is rather artificial, since such a constraint alone would not require operator splitting.

>>> def h2(x):
...     if x[0] < 1:
...         return np.inf
...     return 0
>>> def prox2(u, s):
...     return np.clip(u, np.array([1, -np.inf]), None)
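
Since \(h_2\) is an indicator function, its proximal operator is simply Euclidean projection onto the feasible set, and the step size argument s is ignored. As a quick illustration (values chosen arbitrarily):

>>> prox2(np.array([0.2, -3.0]), 1.0)
array([ 1., -3.])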

Initialize optimizer and define initial point

>>> threeop = opt.ThreeOpProxGrad(g, grad, h1, prox1, h2, prox2)
>>> x = np.zeros(2)

Run optimization

>>> threeop.run(x)
array([1., 0.])

Evaluate cost at the solution point

>>> threeop.f()
1.5
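
As a sanity check, the solution point \(x^\star = (1, 0)\) gives \(g(x^\star) + h_1(x^\star) + h_2(x^\star) = 0.5 + 1 + 0 = 1.5\), in agreement with the value above.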

Methods

  • f – Evaluate cost function at current solution point.

  • run – Optimize until convergence criteria are met.

f()[source]

Evaluate cost function at current solution point.

run(x, tol=1e-06, max_iter=100)

Optimize until convergence criteria are met.

Parameters
  • x (ndarray) – initial point

  • tol (float64) – relative tolerance in objective function

  • max_iter (int) – maximum number of iterations

Returns

solution point

Return type

ndarray
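
For example, to tighten the convergence criteria (illustrative values, continuing the example above):

>>> x_opt = threeop.run(np.zeros(2), tol=1e-08, max_iter=1000)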