Rieoptax: Riemannian Optimization in JAX
Saiteja Utpala, Andi Han, Pratik Jawanpuria, Bamdev Mishra
Abstract
We present Rieoptax, an open source Python library for Riemannian optimization in JAX. We show that many differential geometric primitives, such as the Riemannian exponential and logarithm maps, are usually faster in Rieoptax than in existing frameworks in Python, both on CPU and GPU. We support a wide range of basic and advanced stochastic optimization solvers, such as Riemannian stochastic gradient descent, stochastic variance reduction, and adaptive gradient methods. A distinguishing feature of the proposed toolbox is that we also support differentially private optimization on Riemannian manifolds.
1 Introduction
Riemannian geometry is a generalization of Euclidean geometry [74, 55] to general Riemannian manifolds. It includes several nonlinear spaces such as the set of positive definite matrices [19, 105], the Grassmann manifold of subspaces [36, 16, 5], the Stiefel manifold of orthogonal matrices [36, 5, 27], Kendall shape spaces [66, 67, 80], hyperbolic spaces [108, 107], and the special Euclidean and orthogonal groups [98, 42, 101], to name a few.
Optimization with manifold-based constraints has become increasingly popular and has been employed in various applications such as low-rank matrix completion [22], learning taxonomy embeddings [85, 86], neural networks [60, 61, 62, 43, 84, 90], density estimation [57, 52], optimal transport [29, 9, 99, 82, 51], shape analysis [103, 59], and topological dimension reduction [63], among others.
In addition, privacy-preserving machine learning [34, 33, 32, 28, 102, 2, 81] has become crucial in real applications and has very recently been generalized to manifold-constrained problems [92, 109, 49]. Nevertheless, such a feature is absent in existing Riemannian optimization libraries [23, 17, 78, 70, 100, 106, 79].
In this work, we introduce Rieoptax (Riemannian Optimization in JAX), an open source Python library for Riemannian optimization in JAX [41, 24]. The proposed library is mainly driven by the need for efficient implementations of manifold-valued operations and optimization solvers that are readily compatible with GPU and even TPU processors, as well as the need for privacy-supported Riemannian optimization. To the best of our knowledge, Rieoptax is the first library to provide privacy guarantees within the Riemannian optimization framework.
Saiteja Utpala: Independent (saitejautpala@gmail.com). Andi Han: University of Sydney (andi.han@sydney.edu.au). Pratik Jawanpuria and Bamdev Mishra: Microsoft India (pratik.jawanpuria@microsoft.com, bamdevm@microsoft.com).
1.1 Background on Riemannian optimization, privacy, and JAX
Riemannian optimization. Riemannian optimization [5, 21] considers the following problem

$$\min_{w \in \mathcal{M}} f(w), \qquad (1)$$

where $f : \mathcal{M} \to \mathbb{R}$ and $\mathcal{M}$ denotes a Riemannian manifold. Instead of considering (1) as a constrained problem, Riemannian optimization [5, 21] views it as an unconstrained problem on the manifold space. Riemannian (stochastic) gradient descent [112, 20] generalizes Euclidean gradient descent with intrinsic updates on the manifold, i.e., $w_{t+1} = \mathrm{Exp}_{w_t}(-\eta_t\, \mathrm{grad} f(w_t))$, where $\mathrm{grad} f(w_t)$ is the Riemannian (stochastic) gradient, $\mathrm{Exp}_{w}(\cdot)$ is the Riemannian exponential map at $w$, and $\eta_t$ is the step size. Recent years have witnessed significant advancements in Riemannian optimization, where more advanced solvers are generalized from the Euclidean space to Riemannian manifolds. These include variance reduction methods [111, 96, 65, 114, 48, 47], adaptive gradient methods [15, 64], accelerated gradient methods [50, 76, 7, 113, 8], quasi-Newton methods [58, 89], zeroth-order methods [75], and second-order methods, such as trust region methods [4] and cubic regularized Newton's methods [6].
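As an illustration of the update rule above, the following is a minimal sketch of Riemannian gradient descent on the unit hypersphere; the objective, step size, and helper functions are illustrative choices, not code from the library.

```python
import jax
import jax.numpy as jnp

def egrad_to_rgrad(x, egrad):
    # Project the Euclidean gradient onto the tangent space at x.
    return egrad - jnp.dot(x, egrad) * x

def sphere_exp(x, u):
    # Riemannian exponential map on the sphere: move along the geodesic
    # from x in the tangent direction u.
    nu = jnp.linalg.norm(u) + 1e-12  # guard against division by zero
    return jnp.cos(nu) * x + jnp.sin(nu) * u / nu

def loss(w, z):
    # Illustrative smooth objective on the sphere.
    return -jnp.dot(w, z) ** 2

def rsgd_step(w, z, step_size=0.1):
    # w_{t+1} = Exp_{w_t}(-eta * grad f(w_t))
    rgrad = egrad_to_rgrad(w, jax.grad(loss)(w, z))
    return sphere_exp(w, -step_size * rgrad)

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (5,))
w = jnp.ones(5) / jnp.sqrt(5.0)  # a point on the unit sphere
w = rsgd_step(w, z)
```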
Differential privacy on Riemannian manifolds. Differential privacy (DP) provides a rigorous treatment of data privacy by precisely quantifying the deviation in the model's output distribution under modification of a small number of data points [34, 33, 32, 35]. Provable guarantees of DP, coupled with properties like immunity to arbitrary post-processing and graceful composability, have made it a de-facto standard of privacy with steadfast adoption in real applications [37, 10, 31, 83, 3]. Further, it has been shown empirically that DP models resist various kinds of leakage attacks that can cause privacy violations [91, 26, 95, 115, 13].
Recently, there has been a surge of interest in differential privacy over Riemannian manifolds, which has been explored in the context of Fréchet mean [39] computation [92, 109] and, more generally, empirical risk minimization problems where the parameters are constrained to lie on a Riemannian manifold [49].
JAX and its ecosystem. JAX [41, 24] is a recently introduced machine learning framework which supports automatic differentiation [14] via grad(). Further, some of the distinguishing features of JAX are just-in-time (JIT) compilation using the accelerated linear algebra (XLA) compiler [46] via jit(), automatic vectorization (batch-level parallelism) support with vmap(), and strong support for parallel computation via pmap(). All the above transformations can be composed arbitrarily because JAX follows the functional programming paradigm and implements these as pure functions.
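For instance, the gradient of a function can be vectorized over a batch and JIT-compiled by composing these transformations; the snippet below is a minimal illustration with an arbitrary function f:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Gradient of f, vectorized over a batch of inputs, then JIT-compiled.
batched_grad_f = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 inputs of size 3
print(batched_grad_f(xs).shape)      # (4, 3)
```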
Given that JAX has many interesting features, its ecosystem has been constantly expanding in the last couple of years. Examples include neural network modules (Flax [54], Haiku [56], Equinox [69], Jraph [44], Equivariant-MLP [38]), reinforcement learning agents (Rlax [12]), Euclidean optimization algorithms (Optax [12]), federated learning (Fedjax [93]), optimal transport toolboxes (Ott [30]), sampling algorithms (Blackjax [71]), differential equation solvers (Diffrax [68]), rigid body simulators (Brax [40]), and differentiable physics (Jax-md [97]), among others.
1.2 Rieoptax
We believe that the proposed framework for Riemannian optimization in JAX is a timely contribution that brings several benefits of JAX, along with new features such as privacy support, to the manifold optimization community. These are discussed below.
Automatic and efficient vectorization with vmap(). Functions that are written for inputs of size 1 can be converted to functions that take a batch of inputs by wrapping them with vmap(). For example, the function def dist(point_a, point_b) for computing the distance between a single point_a and a single point_b can be converted to a function that computes distances between a batch of point_a and/or a batch of point_b by wrapping dist with vmap(), without modifying the dist() function. This is useful in many cases, e.g., Fréchet mean computation $\min_{w \in \mathcal{M}} \frac{1}{n}\sum_{i=1}^{n} f_i(w) := \frac{1}{n}\sum_{i=1}^{n} \mathrm{dist}^2(w, z_i)$. Furthermore, vectorization with vmap() is usually faster than or on par with manual vectorization [24].
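The snippet below is a minimal sketch of this pattern, assuming a hypersphere geodesic distance for dist(); it is illustrative, not the library's implementation:

```python
import jax
import jax.numpy as jnp

def dist(point_a, point_b):
    # Geodesic distance between two points on the unit sphere.
    return jnp.arccos(jnp.clip(jnp.dot(point_a, point_b), -1.0, 1.0))

# Distances between a single w and a batch of points z_i, as needed in the
# Fréchet mean objective, without modifying dist() itself.
batched_dist = jax.vmap(dist, in_axes=(None, 0))

w = jnp.array([1.0, 0.0, 0.0])
zs = jnp.eye(3)  # a batch of three points on the sphere
frechet_objective = jnp.mean(batched_dist(w, zs) ** 2)
```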
Per-example gradient clipping. A key step in differentially private optimization is per-example gradient clipping, $\frac{1}{n}\sum_{i=1}^{n} \mathrm{clip}_\tau(\mathrm{grad} f_i(w))$, where $\mathrm{clip}_\tau$ ensures the norm is at most $\tau$. Here, the order of operations is important: the gradients are first clipped and then averaged. Popular libraries including Autograd [77], PyTorch [87], and TensorFlow [1] are heavily optimized to directly compute the mean gradient $\frac{1}{n}\sum_{i=1}^{n} \mathrm{grad} f_i(w)$ and hence do not expose the per-example gradients $\mathrm{grad} f_i(w)$. Hence, one has to resort to ad-hoc techniques [45, 94, 73] or come up with algorithmic modifications [25], which inherently have a speed versus performance trade-off. JAX, however, offers native support for handling such scenarios, and JAX-based differentially private Euclidean optimization methods have been shown to be much faster than their non-JAX counterparts [104]. We observe that JAX offers similar benefits for differentially private Riemannian optimization as well.
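The snippet below is a minimal sketch of per-example clipping via vmap() composed over grad(); the loss and the clip_by_norm helper are illustrative assumptions, not Rieoptax's API:

```python
import jax
import jax.numpy as jnp

def loss(w, z):
    # Illustrative per-example loss.
    return jnp.sum((w - z) ** 2)

def clip_by_norm(g, tau):
    # Rescale g so that its norm is at most tau.
    return g * jnp.minimum(1.0, tau / (jnp.linalg.norm(g) + 1e-12))

def clipped_mean_grad(w, batch, tau=1.0):
    # Per-example gradients first (vmap over grad), then clip, then average.
    per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, batch)
    clipped = jax.vmap(clip_by_norm, in_axes=(0, None))(per_example, tau)
    return jnp.mean(clipped, axis=0)

w = jnp.zeros(3)
batch = jnp.arange(12.0).reshape(4, 3)
g = clipped_mean_grad(w, batch)  # in DP-SGD, calibrated noise is added next
```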
Single Source Multiple Devices (SSMD) paradigm. JAX follows the SSMD paradigm, and hence, code written for CPUs can be run on GPUs and TPUs without any additional modification.
Rieoptax is available at https://github.com/SaitejaUtpala/Rieoptax/.
2 Design and Implementation overview
The package currently implements several commonly used geometries, optimization algorithms and
differentially private mechanisms on manifolds. More geometries and advanced solvers will be added
in the future.
2.1 Core
rieoptax.core.ManifoldArray:
lightweight wrapper of the
jax
device array with
manifold
attribute and used to model array constrained to manifold. It is registered as
Pytree
to ensure
compatibility jax primitives like grad() and vmap().
rieoptax.core.rgrad: Riemannian gradient operator.
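The snippet below is a minimal sketch of how such a wrapper can be registered as a pytree in JAX; it is illustrative and may differ from the actual ManifoldArray implementation:

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class ManifoldArray:
    # Illustrative wrapper; not the library's exact class.
    def __init__(self, value, manifold=None):
        self.value = value        # the underlying JAX device array
        self.manifold = manifold  # static metadata describing the geometry

    def tree_flatten(self):
        # value is a traceable leaf; manifold is static auxiliary data.
        return (self.value,), self.manifold

    @classmethod
    def tree_unflatten(cls, manifold, children):
        return cls(children[0], manifold)

x = ManifoldArray(jnp.array([0.6, 0.8]), manifold="hypersphere")
f = lambda p: jnp.sum(p.value ** 2)
print(jax.grad(f)(x).value)  # gradients flow through the wrapper
```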
2.2 Geometries
The geometries module contains manifolds equipped with different Riemannian metrics. Each geometry contains the Riemannian inner product inp(), induced norm norm(), Riemannian exponential exp() and logarithm log() maps, induced Riemannian distance dist(), parallel transport pt(), and the transformation from the Euclidean gradient to the Riemannian gradient, egrad_to_rgrad().
Manifolds include the symmetric positive definite (SPD) matrices $\mathrm{SPD}(m) := \{X \in \mathbb{R}^{m \times m} : X = X^\top, X \succ 0\}$, hyperbolic space, the Grassmann manifold $\mathcal{G}(m, r) := \{[X] : X \in \mathbb{R}^{m \times r}, X^\top X = I\}$, where $[X] := \{XO : O \in \mathcal{O}(r)\}$ and $\mathcal{O}(r)$ denotes the orthogonal group, and the hypersphere $\mathcal{S}(d) := \{x \in \mathbb{R}^d : x^\top x = 1\}$. We use $T_x\mathcal{M}$ to represent the tangent space at $x$ and $\langle u, v \rangle_x$ to represent the Riemannian inner product. For a more detailed treatment of these geometries, we refer to [5, 21, 108].
rieoptax.geometry.spd.SPDAffineInvariant: SPD matrices with the affine-invariant metric [88]: $\mathrm{SPD}(m)$ with $\langle U, V \rangle_X = \mathrm{tr}(X^{-1} U X^{-1} V)$ for $U, V \in T_X \mathrm{SPD}(m)$ (see the sketch after this list).
rieoptax.geometry.spd.SPDLogEuclidean: SPD matrices with the Log-Euclidean metric [11]: $\mathrm{SPD}(m)$ with $\langle U, V \rangle_X = \mathrm{tr}\big(\mathrm{D}_U \mathrm{logm}(X)\, \mathrm{D}_V \mathrm{logm}(X)\big)$, where $\mathrm{D}_U \mathrm{logm}(X)$ is the directional derivative of the matrix logarithm at $X$ along $U$.
rieoptax.geometry.hyperbolic.PoincareBall: the Poincaré-ball model of hyperbolic space with the Poincaré metric [108], i.e., $\mathbb{D}(d) := \{x \in \mathbb{R}^d : x^\top x < 1\}$ with $\langle u, v \rangle_x = 4\, u^\top v / (1 - x^\top x)^2$ for $u, v \in T_x \mathbb{D}(d)$.
rieoptax.geometry.hyperbolic.LorentzHyperboloid: the Lorentz hyperboloid model of hyperbolic space [108], i.e., $\mathbb{H}(d) = \{x \in \mathbb{R}^d : \langle x, x \rangle_{\mathcal{L}} = -1\}$ with $\langle u, v \rangle_x = \langle u, v \rangle_{\mathcal{L}}$ for $u, v \in T_x \mathbb{H}(d)$, where $\langle u, v \rangle_{\mathcal{L}} := -u_0 v_0 + u_1 v_1 + \cdots + u_{d-1} v_{d-1}$.
rieoptax.geometry.grassmann.GrassmannCanonicalMetric: the Grassmann manifold with the canonical metric [36], i.e., $\mathcal{G}(m, r)$ with $\langle U, V \rangle_X = \mathrm{tr}(U^\top V)$ for $U, V \in T_X \mathcal{G}(m, r)$.
rieoptax.geometry.hypersphere.HypersphereCanonicalMetric: the hypersphere manifold with the canonical metric [5, 21], i.e., $\mathcal{S}(d)$ with $\langle u, v \rangle_x = u^\top v$ for $u, v \in T_x \mathcal{S}(d)$.
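The snippet below is a minimal sketch of two of the formulas above, the affine-invariant inner product on SPD(m) and the Lorentz inner product with its induced hyperboloid distance; these are illustrative implementations, not the library's code:

```python
import jax.numpy as jnp

def spd_affine_invariant_inp(X, U, V):
    # <U, V>_X = tr(X^{-1} U X^{-1} V); uses linear solves instead of an
    # explicit inverse for numerical stability.
    return jnp.trace(jnp.linalg.solve(X, U) @ jnp.linalg.solve(X, V))

def lorentz_inp(u, v):
    # <u, v>_L = -u_0 v_0 + u_1 v_1 + ... + u_{d-1} v_{d-1}
    return -u[0] * v[0] + jnp.dot(u[1:], v[1:])

def hyperboloid_dist(x, y):
    # d(x, y) = arccosh(-<x, y>_L) for x, y on H(d); the clamp guards
    # against values slightly below 1 due to floating-point error.
    return jnp.arccosh(jnp.maximum(-lorentz_inp(x, y), 1.0))
```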
2.3 Optimizers
The optimizers module contains Riemannian optimization algorithms. The design of the optimizers follows Optax [12], which builds every optimizer by chaining a few common gradient transformations (a sketch of this pattern follows the list below).
rieoptax.optimizers.first_order.rsgd: Riemannian stochastic gradient descent [20].
rieoptax.optimizers.first_order.rsvrg: Riemannian stochastic variance reduced gradient descent [111].
rieoptax.optimizers.first_order.rsrg: Riemannian stochastic recursive gradient descent [65].
rieoptax.optimizers.first_order.rasa: Riemannian adaptive stochastic gradient algorithm [64].
rieoptax.optimizers.zeroth_order.zo_rgd: zeroth-order Riemannian gradient descent [75].
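The snippet below is a minimal sketch of this Optax-style chaining pattern; GradientTransformation and scale mirror Optax conventions and are illustrative, not the exact Rieoptax API:

```python
# A sketch of the Optax-style design: an optimizer is a pair of pure
# init()/update() functions, and optimizers are built by chaining simple
# gradient transformations. Names here are illustrative.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

class GradientTransformation(NamedTuple):
    init: Callable
    update: Callable

def scale(step_size):
    # A stateless transformation that rescales the (Riemannian) gradient;
    # richer optimizers chain several such transformations.
    def init(params):
        return ()
    def update(grads, state, params=None):
        return jax.tree_util.tree_map(lambda g: -step_size * g, grads), state
    return GradientTransformation(init, update)

# A bare-bones Riemannian SGD reduces to this single transformation; the
# manifold structure enters in the final update w_{t+1} = Exp_{w_t}(updates).
rsgd_like = scale(1e-2)
state = rsgd_like.init(None)
updates, state = rsgd_like.update(jnp.ones(3), state)
```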