Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions
Main authors: | , , , , , , |
---|---|
Format: | Article |
Language: | eng |
Subjects: | |
Online access: | Order full text |
Summary: | Pretrained language models have shown superior performance on many natural
language processing tasks, yet they still struggle at multi-step formal
reasoning tasks like grade school math problems. One key challenge of
finetuning them to solve such math reasoning problems is that many existing
datasets only contain one reference solution for each problem, despite the fact
that there are often alternative solutions resembling different reasoning paths
to the final answer. This way, the finetuned models are biased towards the
limited reference solutions, which limits their generalization to unseen
examples. To mitigate this issue, we propose to let the model perform sampling
during training and learn from both self-sampled fully-correct solutions, which
yield the correct answer upon execution, and partially-correct solutions, whose
intermediate state matches an intermediate state of a known correct solution.
We show that our use of self-sampled correct and partially-correct solutions
can benefit learning and help guide the sampling process, leading to more
efficient exploration of the solution space. Additionally, we explore various
training objectives to support learning from multiple solutions per example and
find they greatly affect the performance. Experiments on two math reasoning
datasets show the effectiveness of our method compared to learning from a
single reference solution with MLE, where we improve PASS@100 from 35.5% to
44.5% on GSM8K and PASS@80 from 27.6% to 36.2% on MathQA. Such improvements are
also consistent across different model sizes. Our code is available at
https://github.com/microsoft/TraceCodegen. |
---|---|
DOI: | 10.48550/arxiv.2205.14318 |
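
The summary distinguishes fully-correct samples (the right answer upon execution) from partially-correct ones (an intermediate state that matches a state of a known correct solution). The following is a minimal sketch of that distinction, not the authors' implementation. It assumes solutions are lists of Python assignment lines, that the final answer is stored in a variable named `answer`, and that two states match when they hold the same multiset of values regardless of variable names. The helper names `run` and `classify` and the toy problem are hypothetical.

```python
import math
from typing import Dict, List, Tuple

State = Tuple[float, ...]


def run(lines: List[str]) -> Tuple[List[State], Dict[str, float]]:
    """Execute lines one by one; record the sorted variable values after each."""
    env: Dict[str, float] = {}
    states: List[State] = []
    for line in lines:
        try:
            # Empty builtins: only plain arithmetic on earlier variables is allowed.
            exec(line, {"__builtins__": {}}, env)
        except Exception:
            break  # stop tracing at the first line that fails to execute
        states.append(tuple(sorted(env.values())))
    return states, env


def classify(sample: List[str], reference: List[str]) -> Tuple[str, int]:
    """Return a label and the length of the longest state-matching prefix."""
    ref_states, ref_env = run(reference)
    states, env = run(sample)
    executed_fully = len(states) == len(sample)
    if executed_fully and "answer" in env and math.isclose(env["answer"], ref_env["answer"]):
        return "fully-correct", len(sample)
    known = set(ref_states)
    # Longest prefix whose resulting state also occurs in the reference trace.
    prefix = max((i + 1 for i, s in enumerate(states) if s in known), default=0)
    return ("partially-correct", prefix) if prefix else ("incorrect", 0)


# Toy problem (assumed): 3 boxes of 4 apples, 2 are eaten; how many remain?
reference = ["n0 = 3", "n1 = 4", "t0 = n0 * n1", "answer = t0 - 2"]
sample_a = ["a = 4", "b = 3", "c = a * b", "answer = c - 2"]      # different path, same answer
sample_b = ["n0 = 3", "n1 = 4", "t0 = n0 * n1", "answer = t0 + 2"]  # wrong last step
sample_c = ["x = 5", "answer = x + 1"]                              # unrelated computation

for name, s in [("a", sample_a), ("b", sample_b), ("c", sample_c)]:
    print(name, classify(s, reference))
```

In this sketch, `sample_b` is wrong overall but its first three lines reach a state that the reference also reaches, so that prefix could still serve as a learning signal in the sense the summary describes.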