Commit 736a27a6 authored by Susan Zhang's avatar Susan Zhang
Browse files

Revert "Revert "Added dynamic configs (profiling) (#473)" (#509)"

This reverts commit 044291f4.
No related merge requests found
Showing with 122 additions and 4 deletions
+122 -4
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import pathlib
import os
from metaseq.dataclass.configs import DynamicConfig
class TestDynamicConfig(unittest.TestCase):
def test_malformed_config_load(self):
metaseq_dir = pathlib.Path(__file__).parent.parent.resolve()
malformed_config = os.path.join(metaseq_dir, "metaseq/config/malformed.json")
DynamicConfig(json_file_path=malformed_config)
def test_empty_config_load(self):
metaseq_dir = pathlib.Path(__file__).parent.parent.resolve()
empty_config = os.path.join(metaseq_dir, "metaseq/config/empty.json")
DynamicConfig(json_file_path=empty_config)
def test_nonexistent_config_load(self):
nonexistent_config = "404.json"
DynamicConfig(json_file_path=nonexistent_config)
if __name__ == "__main__":
unittest.main()
......@@ -33,6 +33,7 @@ from metaseq import (
)
from metaseq.data import iterators, data_utils
from metaseq.data.plasma_utils import PlasmaStore
from metaseq.dataclass.configs import DynamicConfig
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from metaseq.file_io import PathManager
......@@ -40,8 +41,12 @@ from metaseq.logging import meters, metrics, progress_bar
from metaseq.model_parallel.megatron_trainer import MegatronTrainer
from metaseq.trainer import Trainer
if "SLURM_PROCID" in os.environ:
format_string = f"slurm_procid {os.environ['SLURM_PROCID']} : %(asctime)s | %(levelname)s | %(name)s | %(message)s"
else:
format_string = f"%(asctime)s | %(levelname)s | %(name)s | %(message)s"
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
format=format_string,
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
......@@ -288,9 +293,18 @@ def train(
return valid_losses, should_stop
dcfg = DynamicConfig(
json_file_path=cfg.common.dynamic_config_path,
timeout=cfg.common.dynamic_config_timeout,
)
for i, samples in enumerate(progress):
if distributed_utils.get_global_rank() == 0 and cfg.common.profile and i == 5:
logger.info("STARTING PROFILER")
force_profile = dcfg["force_profile"]
do_profile = (distributed_utils.get_global_rank() == 0) and (
(cfg.common.profile and i == 5) or force_profile
)
if do_profile:
logger.info(f"STARTING PROFILER: step {i}")
with profiler.profile(
profile_memory=True, with_stack=True, record_shapes=True
) as prof:
......@@ -306,8 +320,9 @@ def train(
file=sourceFile,
)
prof.export_chrome_trace(
os.path.join(cfg.checkpoint.save_dir, "profiler_trace.json")
os.path.join(cfg.checkpoint.save_dir, f"profiler_trace_step_{i}.json")
)
logger.info(f"FINISHING PROFILER: step {i}")
else:
valid_losses, should_stop = train(i, samples)
if should_stop:
......
{"This is not event a vali
\ No newline at end of file
......@@ -4,6 +4,9 @@
# LICENSE file in the root directory of this source tree.
import sys
import time
import json
import logging
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, List, Optional
......@@ -194,6 +197,15 @@ class CommonConfig(MetaseqDataclass):
log_nvidia_smi: bool = field(
default=False, metadata={"help": "log output from nvidia-smi during training"}
)
dynamic_config_path: Optional[str] = field(
default=None,
metadata={
"help": "a path to place a file and load dynamic configuratinons from (it is being checked periodically)"
},
)
dynamic_config_timeout: float = field(
default=30.0, metadata={"help": "dynamic configuration state timeout"}
)
@dataclass
......@@ -745,3 +757,63 @@ class MetaseqConfig(MetaseqDataclass):
lr_scheduler: Any = MISSING
bpe: Any = MISSING
tokenizer: Any = None
class DynamicConfig:
"""
can be used instead of redis to store updatable key-value
test:
data = {"a": 1}
with open('data.json', 'w') as fp:
json.dump(data, fp)
dcfg = DynoCfg(json_file_path = "data.json", timeout = 1)
print(dcfg["a"]) # > 1
data = {"a": 4}
with open('data.json', 'w') as fp:
json.dump(data, fp)
print(dcfg["a"]) # > 1
time.sleep(1)
print(dcfg["a"]) # > 4
"""
default_state = {"force_profile": False}
valid_keys = ["force_profile"]
def __init__(self, json_file_path=None, timeout=30): # timeout in sec
self.data = DynamicConfig.default_state
self.timeout = timeout
self.timer_start = 0
self.json_file_path = json_file_path
self.refresh()
def validate(self):
for k, v in self.data.items():
if k not in DynamicConfig.valid_keys:
logging.warning(f"Encounterd unknown dynamic config option: {k} {v}")
def refresh(self):
if self.json_file_path is not None:
# if data have not been updated in a while, reload from the file
# can speed up if were to calculate hash sum of the file to monitor updates
timer_value = time.time() - self.timer_start
if timer_value > self.timeout:
try:
with open(self.json_file_path, "r") as json_file:
self.data = json.load(json_file)
self.validate()
self.timer_start = time.time()
except json.JSONDecodeError as jsonerror:
logging.warning(
f"""Error refreshing dynamic config: reading file {self.json_file_path}
resulted in JSONDecodeError {jsonerror}"""
)
except IOError as ioerror:
logging.warning(
f"""Error refreshing dynamic config: reading file {self.json_file_path}
resulted in IOError {ioerror}"""
)
def __getitem__(self, key):
self.refresh()
return self.data[key]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment