1.1 Background on Riemannian optimization, privacy, and JAX
Riemannian optimization. Riemannian optimization [5, 21] considers the following problem
$\min_{w \in \mathcal{M}} f(w), \qquad (1)$

where $f : \mathcal{M} \rightarrow \mathbb{R}$ and $\mathcal{M}$ denotes a Riemannian manifold. Instead of considering (1) as a
constrained problem, Riemannian optimization [5, 21] views it as an unconstrained problem on the
manifold. Riemannian (stochastic) gradient descent [112, 20] generalizes Euclidean gradient descent
with intrinsic updates on the manifold, i.e., $w_{t+1} = \mathrm{Exp}_{w_t}(-\eta_t \, \mathrm{grad} f(w_t))$, where $\mathrm{grad} f(w_t)$ is
the Riemannian (stochastic) gradient, $\mathrm{Exp}_{w}(\cdot)$ is the Riemannian exponential map at $w$, and $\eta_t$ is
the step size. Recent years have witnessed significant advancements in Riemannian optimization,
with increasingly sophisticated solvers generalized from the Euclidean space to Riemannian manifolds.
These include variance reduction methods [111, 96, 65, 114, 48, 47], adaptive gradient methods
[15, 64], accelerated gradient methods [50, 76, 7, 113, 8], quasi-Newton methods [58, 89], zeroth-order
methods [75], and second-order methods such as trust region methods [4] and cubic regularized
Newton’s methods [6].
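To make the update rule above concrete, the following is a minimal JAX sketch of Riemannian gradient descent on the unit sphere; the Rayleigh-quotient objective and the helper names project_to_tangent, sphere_exp, and rgd_step are illustrative assumptions for this sketch, not part of any library discussed in this paper.

import jax
import jax.numpy as jnp

def f(w, A):
    # Toy smooth objective on the sphere (negative Rayleigh quotient).
    return -w @ A @ w

def project_to_tangent(w, g):
    # Riemannian gradient on the unit sphere: project the Euclidean gradient
    # onto the tangent space T_w = {v : <v, w> = 0}.
    return g - (g @ w) * w

def sphere_exp(w, v):
    # Exponential map on the sphere: Exp_w(v) = cos(||v||) w + sin(||v||) v / ||v||.
    n = jnp.linalg.norm(v)
    return jnp.cos(n) * w + jnp.sin(n) * v / jnp.maximum(n, 1e-12)

@jax.jit
def rgd_step(w, A, step_size):
    # One step of w_{t+1} = Exp_{w_t}(-eta_t * grad f(w_t)).
    egrad = jax.grad(f)(w, A)
    rgrad = project_to_tangent(w, egrad)
    return sphere_exp(w, -step_size * rgrad)

# Usage (illustrative): approximate the leading eigenvector of a random symmetric matrix.
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (5, 5))
A = (A + A.T) / 2.0
w = jnp.ones(5) / jnp.sqrt(5.0)
for _ in range(200):
    w = rgd_step(w, A, 0.1)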
Differential privacy on Riemannian manifolds.
Differential privacy (DP) provides a rigorous
treatment for data privacy by precisely quantifying the deviation in the model’s output distribution
under modification of a small number of data points [34, 33, 32, 35]. Provable guarantees of DP
coupled with properties like immunity to arbitrary post-processing and graceful composability have
made it a de-facto standard of privacy, with steadfast adoption in real-world applications [37, 10, 31, 83, 3].
Further, it has been shown empirically that DP models resist various kinds of leakage attacks that
can cause privacy violations [91, 26, 95, 115, 13].
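For reference, the deviation referred to above is typically formalized by the standard $(\varepsilon, \delta)$-definition of differential privacy (stated here in its textbook form, not quoted from the cited works): a randomized mechanism $\mathcal{M}$ is $(\varepsilon, \delta)$-differentially private if, for every pair of datasets $D$ and $D'$ differing in a single record and every measurable set $S$ of outputs,

$\Pr[\mathcal{M}(D) \in S] \le e^{\varepsilon} \, \Pr[\mathcal{M}(D') \in S] + \delta.$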
Recently, there has been a surge of interest in differential privacy over Riemannian manifolds, which
has been explored in the context of Fréchet mean [39] computation [92, 109] and, more generally,
empirical risk minimization problems where the parameters are constrained to lie on a Riemannian
manifold [49].
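For concreteness (a standard definition, not one quoted from the cited works), the Fréchet mean of points $x_1, \ldots, x_n \in \mathcal{M}$ generalizes the Euclidean average by minimizing the sum of squared geodesic distances,

$\bar{w} \in \arg\min_{w \in \mathcal{M}} \sum_{i=1}^{n} d^2(w, x_i),$

where $d(\cdot, \cdot)$ denotes the geodesic distance on $\mathcal{M}$.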
JAX and its ecosystem.
JAX [41, 24] is a recently introduced machine learning framework which supports automatic
differentiation [14] via grad(). Further, some of the distinguishing features of JAX are just-in-time
(JIT) compilation using the accelerated linear algebra (XLA) compiler [46] via jit(), automatic
vectorization (batch-level parallelism) via vmap(), and strong support for parallel computation via
pmap(). All of these transformations can be composed arbitrarily because JAX follows the functional
programming paradigm and implements them as pure functions.
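As a small, self-contained illustration of this composability (a sketch with a toy quadratic loss, not an example taken from the paper), the following composes grad(), vmap(), and jit() into a single batched, compiled gradient function.

import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy quadratic loss, chosen only to illustrate the transformations.
    return jnp.sum((x - w) ** 2)

# grad() differentiates loss with respect to its first argument (w),
# vmap() vectorizes the gradient over a batch of inputs x,
# and jit() compiles the composed function with XLA.
batched_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.zeros(3)
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 inputs
print(batched_grads(w, xs).shape)    # (4, 3): one gradient per batch element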
Given these features, the JAX ecosystem has been expanding steadily over the last couple of years.
Examples include neural network modules (Flax [54], Haiku [56], Equinox [69], Jraph [44],
Equivariant-MLP [38]), reinforcement learning agents (Rlax [12]), Euclidean optimization
algorithms (Optax [12]), federated learning (Fedjax [93]), optimal transport toolboxes (Ott [30]),
sampling algorithms (Blackjax [71]), differential equation solvers (Diffrax [68]), rigid body simulators
(Brax [40]), and differentiable physics (Jax-md [97]), among others.