diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py
index 08ac0df3597c014a3552280bff295c9346364bac..7420a34ae3d575043d7bbf9a61700e85bb6a65b9 100644
--- a/metaseq/checkpoint_utils.py
+++ b/metaseq/checkpoint_utils.py
@@ -329,6 +329,13 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
         reset_meters=reset_meters,
     )
 
+    # A positive SLURM_RESTART_COUNT is treated as a requeue: override
+    # --reset-dataloader so the checkpointed iterator state can be restored.
+    if reset_dataloader and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0:
+        logger.info(
+            "Disregarding --reset-dataloader since we are continuing past a requeue"
+        )
+        reset_dataloader = False
     if extra_state is not None and not reset_dataloader:
         # restore iterator from checkpoint
         itr_state = extra_state["train_iterator"]