Commit bb770000 authored by Peter Albert, committed by Peter Albert

Automatically evaluate checkpoints after copying to NFS (#550)


* add async eval with dummy

* full testing setup, add configs

* fix naming

* only eval at frequency

* remove logging used for testing

* change to real frequencies

* remove logging

* added improvements

* remove eval last checkpoint

* flake8 lint

* change naming, add comment, always evaluate at end of training

* black lint

* rename to training_finished

Co-authored-by: Peter Albert <peteralbert@fb.com>
Showing with 99 additions and 43 deletions
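In outline: after each rank's shard of a checkpoint is copied to NFS, rank 0 polls the destination directory until every rank's shard has landed, then shells out to a user-supplied eval script. A minimal sketch of that poll-then-evaluate pattern, using illustrative names (wait_and_eval, the argument defaults) rather than the diff's exact code:

import os
import subprocess
import time


def wait_and_eval(ckpt_dir, world_size, script, attempts=10, wait_s=300):
    """Poll NFS until every rank's shard is visible, then run the eval script."""
    for _ in range(attempts):
        time.sleep(wait_s)
        # In-progress copies are prefixed with "_"; count only finished shards.
        finished = [f for f in os.listdir(ckpt_dir) if not f.startswith("_")]
        if len(finished) == world_size:
            res = subprocess.run(
                ["bash", script, os.path.join(ckpt_dir, "checkpoint.pt")],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            return res.returncode == 0
    return False  # shards never all arrived; the real code logs an error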
@@ -10,7 +10,6 @@ Train a new model on one or across multiple GPUs.
 import argparse
 import functools
 import logging
-import importlib
 import math
 import os
 import subprocess
@@ -439,7 +438,9 @@ def validate_and_save(
             trainer,
             epoch_itr,
             training_finished=should_stop,
-            async_callback_fn=functools.partial(post_checkpoint_callback, cfg)
+            async_callback_fn=functools.partial(
+                post_checkpoint_callback, cfg, num_updates, should_stop
+            )
             if cfg.checkpoint.cloud_upload_path
             else None,
         )
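The checkpoint writer invokes async_callback_fn with a single argument, the checkpoint filename, so the extra context must be bound ahead of time: functools.partial pins cfg, num_updates, and should_stop, leaving a one-argument callable. A stand-alone illustration with placeholder values:

import functools


def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
    print(cfg, num_updates, training_finished, filename)


# Bind the leading arguments now; the saver later calls callback(filename).
callback = functools.partial(post_checkpoint_callback, "cfg", 100, False)
callback("checkpoint_100/checkpoint.pt")  # -> cfg 100 False checkpoint_100/checkpoint.pt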
@@ -459,7 +460,7 @@ def _checkpoint_add_directory(basename):
     return m[1], f"checkpoint{m[3]}"


-def post_checkpoint_callback(cfg, filename):
+def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
     if cfg.checkpoint.cloud_upload_path is not None:
         if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path:
             azcopy_logs = filename + "_azcopy_logs"
@@ -526,6 +527,16 @@ def post_checkpoint_callback(cfg, filename):
                ),
            )
            os.remove(filename)
+
+            # Start running evals on uploaded checkpoint
+            nfs_evaluation(
+                cfg,
+                num_updates,
+                training_finished,
+                checkpoint_dir,
+                destination_checkpoints_dir,
+            )
+
        else:
            try:
                # PathManager only supports writing to S3, but this function call
@@ -541,28 +552,65 @@ def post_checkpoint_callback(cfg, filename):
             logger.info(f"could not upload {filename}: {e}")


-def _run_evaluations(
-    eval_module, cloud_upload_path, local_file, checkpoint_suffix, gloo_pg
+def nfs_evaluation(
+    cfg, num_updates, training_finished, checkpoint_dir, destination_checkpoints_dir
 ):
-    # Make sure all ranks have finished uploading checkpoints.
-    # If any rank doesn't hit the barrier within the timeout period, we throw an error and do
-    # not run evals. Error doesn't stop training run.
-    # dist.monitored_barrier(group=gloo_pg, timeout=timedelta(minutes=5))
-    # Run evals on rank 0
-    if distributed_utils.get_global_rank() != 0:
-        return
-    assert eval_module is not None, "--eval-module needs to be set."
-    module = importlib.import_module(eval_module)
-    if not hasattr(module, "eval_fn"):
-        raise RuntimeError(
-            f"{eval_module} must have a function called eval_fn to utilize for evaluations. "
-            "It expects the following signature:\n"
-            "def eval_fn(cloud_upload_path: str, checkpoint_name: str)"
-        )
-    checkpoint_name = local_file.split("/")[-1].replace(checkpoint_suffix, "")
-    logger.info(f"Kicking off eval_fn from: {module}")
-    module.eval_fn(cloud_upload_path, checkpoint_name)
-    logger.info(f"Successfully ran evaluation")
+    if (
+        cfg.checkpoint.nfs_eval_script_path is not None
+        and distributed_utils.get_global_rank() == 0
+        and (
+            (
+                cfg.checkpoint.nfs_eval_frequency > 0
+                and num_updates % cfg.checkpoint.nfs_eval_frequency == 0
+            )
+            or training_finished
+        )
+    ):
+        for retry in range(cfg.checkpoint.nfs_eval_num_attempts):
+            time.sleep(cfg.checkpoint.nfs_eval_attempt_wait_minutes * 60)
+            current_checkpoint_path = os.path.join(
+                destination_checkpoints_dir, checkpoint_dir
+            )
+            checkpoint_files = os.listdir(current_checkpoint_path)
+            # only count completed checkpoint parts (in-progress files start with "_")
+            finished_checkpoint_parts = len(
+                [f for f in checkpoint_files if not f.startswith("_")]
+            )
+            if (
+                finished_checkpoint_parts
+                == cfg.distributed_training.distributed_world_size
+            ):
+                logger.info(
+                    f"All checkpoint parts for {checkpoint_dir} are in NFS, will now start to run evals"
+                )
+                eval_script_path = os.path.join(
+                    os.environ.get("METASEQ_SAVE_DIR"),
+                    cfg.checkpoint.nfs_eval_script_path,
+                )
+                res = subprocess.run(
+                    [
+                        "bash",
+                        eval_script_path,
+                        os.path.join(current_checkpoint_path, "checkpoint.pt"),
+                    ],
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                )
+                if res.returncode == 0:
+                    logger.info(f"Successfully evaluated {checkpoint_dir}")
+                else:
+                    logger.error(f"Error during evaluation: {res.returncode}")
+                    logger.error(f"Eval script stdout = {res.stdout}")
+                    logger.error(f"Eval script stderr = {res.stderr}")
+                return
+        logger.error(
+            (
+                f"Did not evaluate {checkpoint_dir}, as only {finished_checkpoint_parts}/"
+                f"{cfg.distributed_training.distributed_world_size} "
+                "checkpoint parts were copied to NFS within the waiting time"
+            )
+        )
def _run_azcopy(cmd, stdout, stderr):
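The eval script's contract is deliberately small: rank 0 runs bash <script> <path to checkpoint.pt> and inspects only the exit code; stdout and stderr are captured and logged on failure. A quick smoke test of a script against that contract (the script and checkpoint paths here are hypothetical):

import subprocess

res = subprocess.run(
    ["bash", "evals/run_evals.sh", "/nfs/run/checkpoint_2000/checkpoint.pt"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
# Exit code 0 is the only success signal the trainer looks at.
print("ok" if res.returncode == 0 else res.stderr.decode())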
@@ -509,26 +509,6 @@ class CheckpointConfig(MetaseqDataclass):
         default=True,
         metadata={"help": "store a last checkpoint at the end of the training run."},
     )
-    eval_module: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Python module that is dynamically imported to run evaluations. It must have an eval_fn method. "
-                "Required args for eval_fn: "
-                "1. First one contains the cloud upload path. "
-                "2. Second one contains the filename of the checkpoints in the cloud"
-            )
-        },
-    )
-    evaluate_interval_updates: int = field(
-        default=0, metadata={"help": "run eval_fn from eval_module every N updates"}
-    )
-    evaluate_last_checkpoint: bool = field(
-        default=False,
-        metadata={
-            "help": "run the eval_fn from eval_module at the end of the training run"
-        },
-    )
     keep_last_epochs: int = field(
         default=-1, metadata={"help": "keep only the last N epoch checkpoints"}
     )
@@ -575,6 +555,34 @@ class CheckpointConfig(MetaseqDataclass):
             "argparse_alias": "--cloud-dir",
         },
     )
+    nfs_eval_script_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path of the eval script to run on checkpoints after they have been uploaded"
+        },
+    )
+    nfs_eval_num_attempts: int = field(
+        default=10,
+        metadata={
+            "help": "Number of attempts at running evals on an uploaded checkpoint"
+        },
+    )
+    nfs_eval_attempt_wait_minutes: int = field(
+        default=5,
+        metadata={
+            "help": "Minutes to wait between attempts at running evals on an uploaded checkpoint"
+        },
+    )
+    nfs_eval_frequency: int = field(
+        default=5000,
+        metadata={
+            "help": (
+                "Run evaluation only on uploaded checkpoints whose update "
+                "number is a multiple of this frequency"
+            ),
+        },
+    )
     # TODO(susanz): After https://github.com/fairinternal/fairseq-big-internal/issues/22 is tackled, modify this
     # to use ComputeEnvs constant
     cluster_env: str = field(
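Taken together, the four new options gate the feature: leaving nfs_eval_script_path unset disables it entirely, nfs_eval_frequency throttles which uploaded checkpoints are evaluated, and the final checkpoint is always evaluated regardless of frequency. Note that the script path is resolved relative to the METASEQ_SAVE_DIR environment variable. A sketch of setting the options programmatically, assuming CheckpointConfig is importable from metaseq.dataclass.configs (the import path is not shown in this diff):

from metaseq.dataclass.configs import CheckpointConfig  # assumed module path

cfg = CheckpointConfig(
    nfs_eval_script_path="evals/run_evals.sh",  # hypothetical, relative to $METASEQ_SAVE_DIR
    nfs_eval_frequency=2000,  # evaluate checkpoints from every 2000th update
    nfs_eval_num_attempts=10,  # poll NFS up to 10 times per checkpoint
    nfs_eval_attempt_wait_minutes=5,  # wait 5 minutes between polls
)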