metaseq · Merge request !603

Add test sequence parallel

Closed · Administrator requested to merge github/fork/bashnick/Add-test-sequence-parallel into main 2 years ago

Created by: bashnick

Patch Description

Adding tests for the sequence_parallel flag to check rough equivalence between going through the sequence-parallel code path (say, with MP 2) and the current non-sequence-parallel run.

Testing steps:
  • black .
  • flake8
  • python3 -m pytest -v
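
For context, the pytest step can also be scoped to just the two new GPU test files; a minimal sketch using pytest's Python entry point (file paths taken from the diff below; both tests skip themselves unless 4 CUDA GPUs are available):

import pytest

# run only the GPU tests added in this merge request
exit_code = pytest.main(
    [
        "-v",
        "gpu_tests/test_model_parallel_mp1_mp2.py",
        "gpu_tests/test_sequence_parallel.py",
    ]
)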

Compare: main (base) … latest version aceb27a7 (51 commits, 2 years ago)

4 files changed: +367 −2

  • .circleci/config.yml +0 −1
  • gpu_tests/test_model_parallel_mp1_mp2.py +175 −0
  • gpu_tests/test_sequence_parallel.py +190 −0
  • metaseq/launcher/opt_job_constants.py +2 −1
.circleci/config.yml
+ 0
- 1


@@ -170,7 +170,6 @@ commands:
  - store_test_results:
      path: test-results
# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
gpu_tests/test_model_parallel_mp1_mp2.py 0 → 100644
+ 175
- 0

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import subprocess
import json
import multiprocessing
from functools import partial, partialmethod
import unittest
from unittest.mock import patch

import torch

from metaseq.dataclass.configs import DistributedTrainingConfig
from metaseq.launcher.opt_baselines import cli_main as sweep_cli_main
from metaseq.cli.train import cli_main as train_cli_main
from metaseq.launcher.opt_job_constants import Size, M


@unittest.skipIf(not torch.cuda.is_available(), "test requires 4 GPUs, none found")
@unittest.skipIf(
    DistributedTrainingConfig.distributed_world_size != 4,
    "test requires 4 GPUs",
)
class TestModelParallel(unittest.TestCase):
    """
    The tests verify that the model can be trained with
    model_parallel = 1 and model_parallel = 2.
    The tests check that the number of training steps performed is correct
    and that the required loss is achieved on the last iteration.
    """

    def test_model_parallel_mp1(self):
        # parameters to train an mp1 model
        argv_injection = (
            "python3 metaseq/launcher/opt_baselines.py "
            "--prefix train.8m --model-size 8m_mp1 --checkpoints-dir ./test-checkpoint "
            "--tensorboard-logdir ./test-checkpoint --num-trials 1 --azure "
            "--num-gpus 4 --num-nodes 1 --seed 1 "
            "--local --disable-validation --max-epoch 5 --max-update 5 --benchmark "
        )
        max_update_first_run = 20
        size_patch_dict = {"8m_mp1": Size(4, 128, 2, 64, int(0.03125 * M), 1.0e-3, 1)}

        training_log_events = self._test_model_parallel(
            max_update_first_run=max_update_first_run,
            argv_injection=argv_injection,
            size_patch_dict=size_patch_dict,
        )

        # check that training ran correctly
        # check that the number of updates was correct
        self.assertNotEqual(training_log_events, [])
        self.assertIsNotNone(training_log_events[-1]["num_updates"])
        self.assertEqual(
            int(training_log_events[-1]["num_updates"]), max_update_first_run
        )
        # check the achieved loss is correct
        loss_val = float(training_log_events[-1]["loss"])
        self.assertAlmostEqual(loss_val, 14.736, 1)  # 1 digit precision

    def test_model_parallel_mp2(self):
        # parameters to train an mp2 model
        argv_injection = (
            "python3 metaseq/launcher/opt_baselines.py "
            "--prefix train.8m --model-size 8m --checkpoints-dir ./test-checkpoint "
            "--tensorboard-logdir ./test-checkpoint --num-trials 1 --azure "
            "--num-gpus 4 --num-nodes 1 --seed 1 "
            "--local --disable-validation --max-epoch 5 --max-update 5 --benchmark "
        )
        max_update_first_run = 20
        size_patch_dict = {"8m": Size(4, 128, 2, 64, int(0.03125 * M), 1.0e-3, 2)}

        training_log_events = self._test_model_parallel(
            max_update_first_run=max_update_first_run,
            argv_injection=argv_injection,
            size_patch_dict=size_patch_dict,
        )

        # check that training ran correctly
        # check that the number of updates was correct
        self.assertNotEqual(training_log_events, [])
        self.assertIsNotNone(training_log_events[-1]["num_updates"])
        self.assertEqual(
            int(training_log_events[-1]["num_updates"]), max_update_first_run
        )
        # check the achieved loss is correct
        loss_val = float(training_log_events[-1]["loss"])
        self.assertAlmostEqual(loss_val, 14.744, 1)  # 1 digit precision

    def _test_model_parallel(
        self, max_update_first_run, argv_injection, size_patch_dict
    ):
        """
        Helper function to run the test
        """
        # start the process for the model run
        multiprocessing.set_start_method("spawn", force=True)
        with torch.multiprocessing.Manager() as manager:
            events = manager.list()
            p = multiprocessing.Process(
                target=run_training,
                args=(max_update_first_run, events, argv_injection, size_patch_dict),
            )
            p.start()
            p.join()
            events_first_run = list(events)

        # cleanup of the checkpoint files
        cleanup_checkpoints = subprocess.Popen(
            "rm -r ./test-checkpoint".split(),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        _, _ = cleanup_checkpoints.communicate()

        # parse the log events collected by log_to_events()
        training_log_events = [
            json.loads(event["message"])
            for event in events_first_run
            if event["type"] == "log" and event["message"].startswith('{"epoch"')
        ]
        return training_log_events


def run_training(max_update, events, argv_injection, size_patch_dict):
    # clean any unused cache to reduce CUDA OOM
    torch.cuda.empty_cache()
    # main arguments to run the training script;
    # both patches are needed to run the job on the CircleCI GPUs
    with patch("sys.argv", argv_injection.split()[1:]), patch(
        "metaseq.launcher.slurm.local_run",
        partial(local_run_mock, max_update=max_update, events=events),
    ), patch.dict(
        "metaseq.launcher.opt_job_constants.MODEL_SIZES",
        # reduce the batch size for CUDA memory optimization
        size_patch_dict,
    ):
        sweep_cli_main()


def local_run_mock(args, env, train_cmd, dry_run, max_update, events):
    """
    The function introduces several patches for the arguments of the
    model training. These patches are needed to pass the GPU tests on
    the CircleCI GPUs (empirical knowledge).
    """
    train_cmd[train_cmd.index("--max-update") + 1] = str(max_update)
    train_cmd[train_cmd.index("--num-workers") + 1] = "1"
    with patch("logging.Logger._log", partialmethod(log_to_events, events=events)):
        with patch.dict("os.environ", env, clear=True):
            with patch("sys.argv", train_cmd[1:]):
                train_cli_main()


def log_to_events(self, info, message, args, events, **kwargs):
    """
    The function is used to collect logging info from the subprocesses
    and store it in the 'events' variable, which is then passed over
    to the main process for asserting that the model ran correctly.
    """
    print(self, message)
    if isinstance(message, str):
        events.append(
            {
                "type": "log",
                "message": message,
            }
        )


if __name__ == "__main__":
    unittest.main()
gpu_tests/test_sequence_parallel.py 0 → 100644
+ 190
- 0

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import subprocess
import json
import multiprocessing
from functools import partial, partialmethod
import unittest
from unittest.mock import patch

import torch

from metaseq.dataclass.configs import DistributedTrainingConfig
from metaseq.launcher.opt_baselines import cli_main as sweep_cli_main
from metaseq.cli.train import cli_main as train_cli_main
from metaseq.launcher.opt_job_constants import Size, M


@unittest.skipIf(not torch.cuda.is_available(), "test requires 4 GPUs, none found")
@unittest.skipIf(
    DistributedTrainingConfig.distributed_world_size != 4,
    "test requires 4 GPUs",
)
class TestSequenceParallel(unittest.TestCase):
    """
    The tests check rough equivalence between going through the
    sequence-parallel code path with MP 2 vs the current non
    sequence-parallel run for the 8M model.
    """

    def test_sequence_parallel(self):
        # parameters to train an mp2 model with the sequence_parallel flag
        argv_injection = (
            "python3 metaseq/launcher/opt_baselines.py "
            "--prefix train.8m --model-size 8m --checkpoints-dir ./test-checkpoint "
            "--tensorboard-logdir ./test-checkpoint --num-trials 1 --azure "
            "--num-gpus 4 --num-nodes 1 --seed 1 "
            "--local --disable-validation --max-epoch 5 --max-update 5 --benchmark "
        )
        max_update_first_run = 20
        size_patch_dict = {"8m": Size(4, 128, 2, 64, int(0.03125 * M), 1.0e-3, 2)}

        # train model with the sequence_parallel flag
        training_log_events_seq = self._test_model_parallel(
            max_update_first_run=max_update_first_run,
            argv_injection=argv_injection,
            size_patch_dict=size_patch_dict,
            is_sequence_parallel=True,
        )
        # train model without the sequence_parallel flag
        training_log_events = self._test_model_parallel(
            max_update_first_run=max_update_first_run,
            argv_injection=argv_injection,
            size_patch_dict=size_patch_dict,
            is_sequence_parallel=False,
        )

        # check that training ran correctly
        # check that the number of updates was correct
        self.assertNotEqual(training_log_events_seq, [])
        self.assertNotEqual(training_log_events, [])
        self.assertIsNotNone(training_log_events_seq[-1]["num_updates"])
        self.assertIsNotNone(training_log_events[-1]["num_updates"])
        self.assertEqual(
            int(training_log_events[-1]["num_updates"]), max_update_first_run
        )
        self.assertEqual(
            int(training_log_events_seq[-1]["num_updates"]), max_update_first_run
        )
        # check the achieved loss is similar between the seq and non-seq runs
        loss_val_seq = float(training_log_events_seq[-1]["loss"])
        loss_val = float(training_log_events[-1]["loss"])
        print("loss_val_seq: {} | loss_val: {}".format(loss_val_seq, loss_val))
        self.assertAlmostEqual(
            loss_val, loss_val_seq, 1
        )  # 1 digit precision; 14.702 - seq; 14.735 - non seq

    def _test_model_parallel(
        self,
        max_update_first_run,
        argv_injection,
        size_patch_dict,
        is_sequence_parallel,
    ):
        """
        Helper function to run the test
        """
        # start the process for the model run
        multiprocessing.set_start_method("spawn", force=True)
        with torch.multiprocessing.Manager() as manager:
            events = manager.list()
            p = multiprocessing.Process(
                target=run_training,
                args=(
                    max_update_first_run,
                    events,
                    argv_injection,
                    size_patch_dict,
                    is_sequence_parallel,
                ),
            )
            p.start()
            p.join()
            events_first_run = list(events)

        # cleanup of the checkpoint files
        cleanup_checkpoints = subprocess.Popen(
            "rm -r ./test-checkpoint".split(),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        _, _ = cleanup_checkpoints.communicate()

        # parse the log events collected by log_to_events()
        training_log_events = [
            json.loads(event["message"])
            for event in events_first_run
            if event["type"] == "log" and event["message"].startswith('{"epoch"')
        ]
        return training_log_events


def run_training(
    max_update, events, argv_injection, size_patch_dict, is_sequence_parallel
):
    # clean any unused cache to reduce CUDA OOM
    torch.cuda.empty_cache()
    # main arguments to run the training script;
    # both patches are needed to run the job on the CircleCI GPUs
    with patch("sys.argv", argv_injection.split()[1:]), patch(
        "metaseq.launcher.slurm.local_run",
        partial(
            local_run_mock,
            max_update=max_update,
            events=events,
            is_sequence_parallel=is_sequence_parallel,
        ),
    ), patch.dict(
        "metaseq.launcher.opt_job_constants.MODEL_SIZES",
        # reduce the batch size for CUDA memory optimization
        size_patch_dict,
    ):
        sweep_cli_main()


def local_run_mock(
    args, env, train_cmd, dry_run, max_update, events, is_sequence_parallel
):
    """
    The function introduces several patches for the arguments of the
    model training. These patches are needed to pass the GPU tests on
    the CircleCI GPUs and to enable the sequence_parallel parameter.
    """
    # update the parameters of the model training
    train_cmd[train_cmd.index("--max-update") + 1] = str(max_update)
    train_cmd[train_cmd.index("--num-workers") + 1] = "1"
    train_cmd[train_cmd.index("--dropout") + 1] = "0.0"
    train_cmd.remove("--checkpoint-activations")
    train_cmd.remove("--distribute-checkpointed-activations")
    # add the sequence_parallel argument to the model arguments
    if is_sequence_parallel:
        train_cmd.append("--sequence-parallel")
    with patch("logging.Logger._log", partialmethod(log_to_events, events=events)):
        with patch.dict("os.environ", env, clear=True):
            with patch("sys.argv", train_cmd[1:]):
                train_cli_main()


def log_to_events(self, info, message, args, events, **kwargs):
    """
    The function is used to collect logging info from the subprocesses
    and store it in the 'events' variable, which is then passed over
    to the main process for asserting that the model ran correctly.
    """
    print(self, message)
    if isinstance(message, str):
        events.append(
            {
                "type": "log",
                "message": message,
            }
        )


if __name__ == "__main__":
    unittest.main()
metaseq/launcher/opt_job_constants.py
+ 2
- 1


@@ -30,7 +30,8 @@ TOTAL_TRAIN_TOKENS = 300e9
 TOTAL_WARMUP_TOKENS = 375e6
 M = 1024 * 1024  # 1 million
 MODEL_SIZES = {
-    "8m": Size(4, 128, 2, 64, int(0.125 * M), 1.0e-3, 2),  # tiny
+    "8m_mp1": Size(4, 128, 2, 64, int(0.125 * M), 1.0e-3, 1),  # tiny with 1 mp
+    "8m": Size(4, 128, 2, 64, int(0.125 * M), 1.0e-3, 2),  # tiny with 2 mp
     "125m": Size(12, 768, 12, 64, int(0.5 * M), 6.0e-4, 2),  # small
     "350m": Size(24, 1024, 16, 64, int(0.5 * M), 3.0e-4, 2),  # medium
     "760m": Size(24, 1536, 16, 96, int(0.5 * M), 2.5e-4, 2),  # large