metaseq
cli
train.py +27 -4
dataclass
configs.py +7 -0
checkpoint_utils.py +10 -5
trainer.py +3 -3
+ 27
- 4
@@ -87,6 +87,16 @@ def main(cfg: DictConfig) -> None:
@@ -334,7 +344,11 @@ def train(
@@ -407,6 +421,17 @@ def validate_and_save(
@@ -414,9 +439,7 @@ def validate_and_save(
+ 7
- 0
@@ -505,6 +505,13 @@ class CheckpointConfig(MetaseqDataclass):
+ 10
- 5
@@ -63,11 +63,15 @@ def save_checkpoint(
@@ -82,6 +86,7 @@ def save_checkpoint(
@@ -90,7 +95,7 @@ def save_checkpoint(
+ 3
- 3
@@ -398,13 +398,13 @@ class Trainer(object):
Created by: suchenzang
Should we gate this behind save_to_NFS logic somewhere?
Created by: ngoyal2707
Similar to answer on Stephen's comment, I am currently passing async_callback as None if
save_to_NFS == False
Created by: stephenroller
why do we no longer need the None protection here? I'm going to be annoyed if we get log spam that just says Async save failed: NoneType error every 25 updates
Created by: ngoyal2707
because protection is now while submitting in the threadpool executor at line number 407. Earlier we'd have had log spams where even with non-async saving, we'd seen these logs, it should be better now.
The reason earlier code was written this way is becasue earlier to Zach's async saving fix, we also were doing torch.save inside the async function. Now thats outside, we dont need to submit anything to thread pool executor if async callback is None