Betty: An Automatic Differentiation Library for Multilevel Optimization
Main authors: | , , , |
---|---|
Format: | Article |
Language: | eng |
Keywords: | |
Online access: | Order full text |
Summary: | Gradient-based multilevel optimization (MLO) has gained attention as a framework for studying numerous problems, ranging from hyperparameter optimization and meta-learning to neural architecture search and reinforcement learning. However, gradients in MLO, which are obtained by composing best-response Jacobians via the chain rule, are notoriously difficult to implement and memory/compute intensive. We take an initial step towards closing this gap by introducing Betty, a software library for large-scale MLO. At its core, we devise a novel dataflow graph for MLO, which allows us to (1) develop efficient automatic differentiation for MLO that reduces the computational complexity from O(d^3) to O(d^2), (2) incorporate systems support such as mixed-precision and data-parallel training for scalability, and (3) facilitate implementation of MLO programs of arbitrary complexity while allowing a modular interface for diverse algorithmic and systems design choices. We empirically demonstrate that Betty can be used to implement an array of MLO programs, while also observing up to an 11% increase in test accuracy, a 14% decrease in GPU memory usage, and a 20% decrease in training wall time over existing implementations on multiple benchmarks. We also showcase that Betty enables scaling MLO to models with hundreds of millions of parameters. We open-source the code at https://github.com/leopard-ai/betty. |
---|---|
DOI: | 10.48550/arxiv.2207.02849 |
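
To make the abstract's key idea concrete, below is a minimal, self-contained sketch of gradient-based bilevel optimization, the simplest MLO instance, written in plain PyTorch rather than Betty's API: the outer gradient is obtained by composing the gradient of the outer (validation) loss with the best-response Jacobian of the inner (training) problem via the chain rule, which is the computation the abstract says Betty's dataflow graph automates. The toy regression setup and all names (`lam`, `w`, the losses) are illustrative assumptions, not from the paper.

```python
import torch

torch.manual_seed(0)

# Toy data; shapes and the regression task are illustrative assumptions.
X_tr, y_tr = torch.randn(32, 4), torch.randn(32, 1)
X_val, y_val = torch.randn(16, 4), torch.randn(16, 1)

lam = torch.tensor(0.1, requires_grad=True)  # outer variable: weight-decay hyperparameter
w = torch.zeros(4, 1, requires_grad=True)    # inner variable: model weights

inner_lr, outer_lr = 0.1, 0.01
for _ in range(100):
    # Inner level: one unrolled gradient step on the regularized training loss.
    # create_graph=True keeps the update differentiable w.r.t. lam, so autograd
    # tracks the (approximate) best-response Jacobian dw*/dlam.
    train_loss = ((X_tr @ w - y_tr) ** 2).mean() + lam * (w ** 2).sum()
    (g_w,) = torch.autograd.grad(train_loss, w, create_graph=True)
    w_star = w - inner_lr * g_w  # approximate best response w*(lam)

    # Outer level: backprop through w_star composes dL_val/dw* with dw*/dlam,
    # the chain-rule composition of best-response Jacobians from the abstract.
    val_loss = ((X_val @ w_star - y_val) ** 2).mean()
    (g_lam,) = torch.autograd.grad(val_loss, lam)

    with torch.no_grad():
        lam -= outer_lr * g_lam   # hypergradient step on the outer variable
        w.copy_(w_star.detach())  # warm-start the inner problem for the next step
```

Note that reverse-mode autodiff here evaluates vector-Jacobian products rather than materializing dw*/dlam explicitly; avoiding explicit Jacobians is the same general principle behind the O(d^3) to O(d^2) reduction the abstract claims, though Betty's actual algorithms and interface are described in the paper itself.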