Administrator / metaseq / Merge requests / !598

add separate config option save locally

Merged: Administrator requested to merge ngoyal_add_separate_config_option_save_locally into bigbig (2 years ago)
Overview 7 · Commits 3 · Pipelines 0 · Changes 4

Created by: ngoyal2707

Comparing bigbig (base) and latest version 233d3525 (3 commits, 2 years ago)

4 files changed: +47 -12

metaseq/cli/train.py +27 -4
metaseq/dataclass/configs.py +7 -0
metaseq/checkpoint_utils.py +10 -5
metaseq/trainer.py +3 -3
metaseq/cli/train.py (+27 -4)

@@ -87,6 +87,16 @@ def main(cfg: DictConfig) -> None:
     ), "Must specify batch size either with --max-tokens or --batch-size"
     metrics.reset()
+    if cfg.checkpoint.local_save_interval_updates > 0:
+        assert (
+            cfg.checkpoint.save_interval_updates > 0
+        ), "local save must be used with --save-interval-updates > 0"
+        assert (
+            cfg.checkpoint.save_interval_updates
+            % cfg.checkpoint.local_save_interval_updates
+            == 0
+        ), "--save-interval-updates must be a multiple of --local-save-interval-updates"
     if cfg.common.log_file is not None:
         handler = logging.FileHandler(filename=cfg.common.log_file)
         logger.addHandler(handler)
@@ -334,7 +344,11 @@ def train(
             )
             continue
-        if distributed_utils.get_global_rank() == 0 and cfg.common.profile and i == 5:
+        if (
+            distributed_utils.get_global_rank() == 0
+            and cfg.common.profile
+            and i == 5
+        ):
             logger.info("STARTING PROFILER")
             with profiler.profile(
                 profile_memory=True, with_stack=True, record_shapes=True
@@ -407,6 +421,17 @@ def validate_and_save(
             f"num_updates: {num_updates} >= max_update: {max_update}"
         )
+    save_locally = (
+        cfg.checkpoint.local_save_interval_updates > 0
+        and num_updates > 0
+        and num_updates % cfg.checkpoint.local_save_interval_updates == 0
+    )
+    save_to_NFS = (
+        cfg.checkpoint.save_interval_updates > 0
+        and num_updates > 0
+        and num_updates % cfg.checkpoint.save_interval_updates == 0
+    )
     do_save = (
         (
             end_of_epoch
@@ -414,9 +439,7 @@ def validate_and_save(
             and epoch_itr.epoch % cfg.checkpoint.save_interval_epochs == 0
         )
         or (
-            cfg.checkpoint.save_interval_updates > 0
-            and num_updates > 0
-            and num_updates % cfg.checkpoint.save_interval_updates == 0
+            (save_locally or save_to_NFS)
             and num_updates >= cfg.dataset.validate_after_updates
             and was_successful_step
         )
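Taken together, the two validate_and_save hunks make do_save fire on either interval. Below is a minimal standalone sketch of the resulting decision logic, with cfg flattened into plain arguments; the function name and default interval values are illustrative, not part of the diff.

def should_save(
    num_updates: int,
    save_interval_updates: int = 100,       # --save-interval-updates (NFS)
    local_save_interval_updates: int = 25,  # --local-save-interval-updates (local SSD)
    validate_after_updates: int = 0,
    was_successful_step: bool = True,
) -> bool:
    # Mirrors the save_locally / save_to_NFS conditions added above.
    save_locally = (
        local_save_interval_updates > 0
        and num_updates > 0
        and num_updates % local_save_interval_updates == 0
    )
    save_to_NFS = (
        save_interval_updates > 0
        and num_updates > 0
        and num_updates % save_interval_updates == 0
    )
    return (
        (save_locally or save_to_NFS)
        and num_updates >= validate_after_updates
        and was_successful_step
    )

# With intervals 100/25, updates 25, 50, 75 trigger local-only saves and
# update 100 satisfies both conditions at once.
assert [n for n in range(1, 101) if should_save(n)] == [25, 50, 75, 100]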
metaseq/dataclass/configs.py (+7 -0)

@@ -505,6 +505,13 @@ class CheckpointConfig(MetaseqDataclass):
     save_interval_updates: int = field(
         default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
     )
+    local_save_interval_updates: int = field(
+        default=0,
+        metadata={
+            "help": "save a checkpoint (and validate) every N updates to local SSD. "
+            "Only applicable when copying to NFS asynchronously"
+        },
+    )
     save_last_checkpoint: bool = field(
         default=True,
         metadata={"help": "store a last checkpoint at the end of the training run."},
metaseq/checkpoint_utils.py (+10 -5)

@@ -63,11 +63,15 @@ def save_checkpoint(
         and epoch % cfg.save_interval_epochs == 0
     )
-    save_for_updates = (
-        not end_of_epoch
-        and cfg.save_interval_updates > 0
-        and updates % cfg.save_interval_updates == 0
-    )
+    save_locally = (
+        cfg.local_save_interval_updates > 0
+        and updates % cfg.local_save_interval_updates == 0
+    )
+    save_to_NFS = (
+        cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0
+    )
+    save_for_updates = not end_of_epoch and (save_to_NFS or save_locally)

     checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch
     checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates
@@ -82,6 +86,7 @@ def save_checkpoint(
     checkpoints = [
         os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
     ]
+    if len(checkpoints) > 0:
         if PathManager.islink(checkpoints[0]):
             PathManager.rm(checkpoints[0])
@@ -90,7 +95,7 @@ def save_checkpoint(
             checkpoints[0],
             extra_state,
             training_finished=training_finished,
-            async_callback_fn=async_callback_fn,
+            async_callback_fn=async_callback_fn if save_to_NFS else None,
         )
         write_timer.stop()
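This one-line change is what the review discussion on metaseq/trainer.py below refers to: when a checkpoint fires only on the local interval, no NFS-copy callback is forwarded at all. A minimal sketch of the gating; save_checkpoint_sketch and copy_to_nfs are hypothetical names, not metaseq APIs:

from typing import Callable, Optional

def save_checkpoint_sketch(
    filename: str,
    save_to_NFS: bool,
    async_callback_fn: Optional[Callable[[str], None]] = None,
) -> None:
    # Mirrors `async_callback_fn=async_callback_fn if save_to_NFS else None`:
    # local-SSD-only checkpoints never schedule an asynchronous NFS copy.
    effective_cb = async_callback_fn if save_to_NFS else None
    print(f"{filename}: async NFS copy scheduled = {effective_cb is not None}")

def copy_to_nfs(path: str) -> None:  # hypothetical NFS-copy callback
    pass

save_checkpoint_sketch("checkpoint_25.pt", save_to_NFS=False, async_callback_fn=copy_to_nfs)
save_checkpoint_sketch("checkpoint_100.pt", save_to_NFS=True, async_callback_fn=copy_to_nfs)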
metaseq/trainer.py (+3 -3)

@@ -398,13 +398,13 @@ class Trainer(object):
         def perform_save():
             try:
                 logger.info(f"Beginning asynchronous torch.save to {filename}")
-                if async_callback_fn is not None:
-                    async_callback_fn(filename)
+                async_callback_fn(filename)
Comment by suchenzang (2 years ago):

Should we gate this behind save_to_NFS logic somewhere?

    Reply by ngoyal2707 (2 years ago):

    Similar to my answer on Stephen's comment: I am currently passing async_callback as None if save_to_NFS == False.
Comment by stephenroller (2 years ago):

Why do we no longer need the None protection here? I'm going to be annoyed if we get log spam that just says "Async save failed: NoneType error" every 25 updates.

    Reply by ngoyal2707 (2 years ago):

    Because the protection now happens when submitting to the thread pool executor, at line number 407. Earlier we would get these log lines even with non-async saving; it should be better now.

    The reason the earlier code was written this way is that, before Zach's async-saving fix, we were also doing torch.save inside the async function. Now that's outside, so we don't need to submit anything to the thread pool executor if the async callback is None.
logger.info(f"Asynchronous torch.save to {filename} complete.")
except Exception as e:
logger.exception(f"Asynchronous save failed: {e}")
torch.save(state_dict, filename)
self.async_checkpoint.submit(perform_save)
if async_callback_fn is not None:
self.async_checkpoint.submit(perform_save)
logger.info(f"Finished saving checkpoint to {filename}")
def load_checkpoint(
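For clarity, a self-contained sketch of the pattern the thread above settles on: torch.save stays synchronous, and the executor only receives work when there is actually a callback to run, so a None callback can no longer raise inside the worker thread. ThreadPoolExecutor stands in for self.async_checkpoint, and the save/log calls are simplified stand-ins:

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional

def save_sketch(
    filename: str,
    state_dict: dict,
    async_callback_fn: Optional[Callable[[str], None]],
    executor: ThreadPoolExecutor,
) -> None:
    def perform_save():
        try:
            # Only the NFS-copy callback runs in the worker thread now.
            async_callback_fn(filename)
        except Exception as e:
            print(f"Asynchronous save failed: {e}")

    # torch.save(state_dict, filename) would happen here, synchronously,
    # in the caller's thread; it no longer lives inside perform_save.
    if async_callback_fn is not None:
        executor.submit(perform_save)  # the None protection moved to the submit site

with ThreadPoolExecutor(max_workers=1) as pool:
    save_sketch("ckpt.pt", {}, None, pool)            # nothing submitted, no log spam
    save_sketch("ckpt.pt", {}, lambda p: None, pool)  # callback runs asynchronously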
Assignees: none
Reviewers: none
Labels: cla signed
Milestone: none
Participants: 1 (Administrator)
Reference: root/metaseq!598
Source branch: ngoyal_add_separate_config_option_save_locally
