Instance-Level Task Parameters: A Robust Multi-task Weighting Framework

Bibliographic Details
Published in: arXiv.org 2021-06
Main Authors: Pavan Kumar Anasosalu Vasu, Shreyas Saxena, Oncel Tuzel
Format: Article
Language: English
Online Access: Full text
Description
Summary: Recent works have shown that deep neural networks benefit from multi-task learning by learning a shared representation across several related tasks. However, the performance of such systems depends on the relative weighting of the various losses involved during training. Prior works on loss weighting schemes assume that instances are equally easy or hard for all tasks. To break this assumption, we let the training process dictate the optimal weighting of tasks for every instance in the dataset. More specifically, we equip every instance in the dataset with a set of learnable parameters (instance-level task parameters) whose cardinality is equal to the number of tasks learned by the model. These parameters model the weighting of each task for an instance. They are updated by gradient descent and do not require hand-crafted rules. We conduct extensive experiments on the SURREAL and CityScapes datasets for human shape and pose estimation, depth estimation, and semantic segmentation tasks. In these tasks, our approach outperforms recent dynamic loss weighting approaches, e.g. reducing surface estimation errors by 8.97% on SURREAL. When applied to datasets where one or more tasks can have noisy annotations, the proposed method learns to prioritize learning from clean labels for a given task, e.g. reducing surface estimation errors by up to 60%. We also show that we can reliably detect corrupt labels for a given task as a by-product of the learned instance-level task parameters.
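
The summary describes one learnable parameter per instance and per task, updated by gradient descent. The following is a minimal sketch of that idea, not the authors' implementation. The summary does not state the loss formula, so the combined loss exp(-s) * l + s (a common uncertainty-style weighting) is an assumption, as are all function names, the toy loss values, and the learning rate. The per-task losses are held fixed here for simplicity; in real training they would change as the network learns.

```python
import math

def weighted_loss(task_losses, task_params):
    """Combine one instance's per-task losses using its own parameters.

    Assumed form (not taken from the summary): sum_t exp(-s_t) * l_t + s_t,
    where s_t is the learnable parameter for task t on this instance.
    """
    return sum(math.exp(-s) * l + s for l, s in zip(task_losses, task_params))

def param_gradients(task_losses, task_params):
    """Analytic gradient of weighted_loss with respect to each s_t."""
    return [1.0 - math.exp(-s) * l for l, s in zip(task_losses, task_params)]

def update_params(task_losses, task_params, lr=0.1):
    """One plain gradient-descent step on this instance's parameters."""
    grads = param_gradients(task_losses, task_params)
    return [s - lr * g for s, g in zip(task_params, grads)]

# Hypothetical dataset: 3 instances, 2 tasks, one parameter per (instance, task).
num_instances, num_tasks = 3, 2
params = [[0.0] * num_tasks for _ in range(num_instances)]

# Made-up per-task losses. The last instance has a very large loss on the
# second task, as might happen when that task's label is corrupted.
losses = [[0.5, 0.4], [0.6, 0.5], [0.5, 8.0]]

for _ in range(200):
    params = [update_params(l, s) for l, s in zip(losses, params)]

# Effective weight exp(-s) of each task for each instance.
weights = [[math.exp(-s) for s in row] for row in params]
```

Under this assumed loss, each weight settles near the reciprocal of the corresponding loss, so the suspicious (instance, task) pair ends up with the smallest weight. Ranking pairs by learned weight is one plausible way to flag candidate corrupt labels, in the spirit of the by-product mentioned in the summary.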
ISSN:2331-8422