metaseq · Merge request !537

Support overriding sequence parallelism in the API

Merged: Administrator requested to merge punitkoura/sequence-parallelism-api into bigbig on Nov 30, 2022

Created by: punitkoura

Patch Description

  1. Support overriding sequence parallelism in the API
  2. Support loading FSDP sharded models through the API

Testing steps

  1. Take a sequence-parallel checkpoint and create a constants module file (custom_constants_module.py):
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

MAX_SEQ_LEN = 2048
BATCH_SIZE = 2048  # intentionally high, because we dynamically batch by MAX_BATCH_TOKENS
MAX_BATCH_TOKENS = 3072
DEFAULT_PORT = 6010
MODEL_PARALLEL = <REPLACE_MODEL_PARALLEL_SIZE_HERE>
TOTAL_WORLD_SIZE = <REPLACE_WORLD_SIZE_HERE>
MAX_BEAM = 16

CHECKPOINT_FOLDER = <INSERT_MODEL_CHECKPOINT_FOLDER_HERE>

# tokenizer files
HF_TOKENIZER = <INSERT_TOKENIZER_FILE_HERE>
MODEL_FILE = os.path.join(CHECKPOINT_FOLDER, "reshard.pt")


LAUNCH_ARGS = [
    f"--model-parallel-size {MODEL_PARALLEL}",
    f"--distributed-world-size {TOTAL_WORLD_SIZE}",
    "--ddp-backend fully_sharded",
    "--task language_modeling",
    "--bpe hf_byte_bpe",
    f"--hf-tokenizer {HF_TOKENIZER}",
    f"--path {MODEL_FILE}",
    "--beam 1 --nbest 1",
    "--distributed-port 13000",
    "--checkpoint-shard-count 1",
    "--use-sharded-state",
    f"--batch-size {BATCH_SIZE}",
    f"--buffer-size {BATCH_SIZE * MAX_SEQ_LEN}",
    f"--max-tokens {BATCH_SIZE * MAX_SEQ_LEN}",
    "/tmp",  # required "data" argument.
]

# Optional arg overrides which influence model loading during inference
INFERENCE_ARG_OVERRIDES = {"sequence_parallel": False}
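
For illustration, here is a minimal sketch of how an override dict like INFERENCE_ARG_OVERRIDES can be applied on top of the argument namespace recovered from a checkpoint. This is not metaseq's actual loading code; apply_arg_overrides and the saved_args namespace below are hypothetical.

from argparse import Namespace

def apply_arg_overrides(saved_args: Namespace, overrides: dict) -> Namespace:
    """Return a copy of saved_args with each override applied on top."""
    merged = Namespace(**vars(saved_args))
    for key, value in overrides.items():
        # e.g. sequence_parallel: True -> False for inference
        setattr(merged, key, value)
    return merged

# Force sequence parallelism off at inference time even though the
# checkpoint was trained with it enabled.
saved_args = Namespace(sequence_parallel=True, model_parallel_size=8)
inference_args = apply_arg_overrides(saved_args, {"sequence_parallel": False})
assert inference_args.sequence_parallel is False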
  2. Export the constants module's location so that it is visible to interactive_hosted.py:
export PYTHONPATH=$PYTHONPATH:/path/to/custom_constants_module METASEQ_SERVICE_CONSTANTS_MODULE=custom_constants_module
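
Under the hood, METASEQ_SERVICE_CONSTANTS_MODULE is just a module name resolved at import time; the mechanism looks roughly like the importlib sketch below. The fallback default shown here is an assumption, and metaseq's exact lookup may differ.

import importlib
import os

# Resolve the constants module from the environment, falling back to the
# bundled defaults (the fallback name is an assumption).
module_name = os.environ.get(
    "METASEQ_SERVICE_CONSTANTS_MODULE", "metaseq.service.constants"
)
constants = importlib.import_module(module_name)
print(constants.MAX_SEQ_LEN, constants.DEFAULT_PORT)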
  3. Run interactive_hosted.py:
srun --exclusive -N <NUMBER_OF_NODES> --gpus-per-node <NUMBER_OF_GPUS_PER_NODE> --tasks 1 -c 96 --partition <PUT_PARTITION_NAME_HERE> --time "1-00:00:00" --qos high --pty python metaseq/cli/interactive_hosted.py

NUMBER_OF_NODES * NUMBER_OF_GPUS_PER_NODE should equal MODEL_PARALLEL (the model-parallel size); see the sanity check below.
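
As a quick sanity check of that sizing rule (assuming the placeholders in custom_constants_module.py have been filled in and the module is on PYTHONPATH; the node and GPU counts below are hypothetical examples):

import custom_constants_module as consts

number_of_nodes = 1   # hypothetical value for -N
gpus_per_node = 8     # hypothetical value for --gpus-per-node
assert number_of_nodes * gpus_per_node == consts.MODEL_PARALLEL, (
    "srun must launch exactly one GPU per model-parallel worker"
)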

  4. Prompt the model:
curl -k http://<ip>:<port>/completions -H "Authorization: Bearer Punit" -H "Content-Type: application/json" \
-d '{
"prompt": [16853, 16947, 19678, 16709, 16647, 19493, 17495, 16688, 16397],
"temperature": 0.0,
"max_tokens": 0, "min_tokens": 0,
"top_p": 1.0, "n": 1, "best_of": 1,
"echo": true, "logprobs": 1, "seed": 1
}'
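
The same request expressed with Python's requests library, for scripting convenience; <ip>, <port>, and the bearer token are the placeholders from the curl example above.

import requests

resp = requests.post(
    "http://<ip>:<port>/completions",
    headers={"Authorization": "Bearer Punit"},
    json={
        "prompt": [16853, 16947, 19678, 16709, 16647, 19493, 17495, 16688, 16397],
        "temperature": 0.0,
        "max_tokens": 0,  # with echo + logprobs this scores the prompt
        "min_tokens": 0,
        "top_p": 1.0, "n": 1, "best_of": 1,
        "echo": True, "logprobs": 1, "seed": 1,
    },
    verify=False,  # equivalent of curl -k
)
print(resp.json())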