diff --git a/metaseq/cli/README.md b/metaseq/cli/README.md
index 1c7edb3089196b776fb318ef9dbe387c465cbd8a..beb67b8340657793d4ed9439fa2708f6977cf69c 100644
--- a/metaseq/cli/README.md
+++ b/metaseq/cli/README.md
@@ -47,7 +47,7 @@ LAUNCH_ARGS = [
     f"--merges-filename {BPE_MERGES}",
     f"--vocab-filename {BPE_VOCAB}",
     f"--path {MODEL_FILE}",
-    "--beam 1 --nbest 1",
+    "--beam 1",
     "--distributed-port 13000",
     "--checkpoint-shard-count 1",
     f"--batch-size {BATCH_SIZE}",
diff --git a/metaseq/cli/interactive_hosted.py b/metaseq/cli/interactive_hosted.py
index 7bae5effde4850017897cdec2f156ad04dac227e..d9be0ff80cecf5d1117461709f2d001ab5fff246 100644
--- a/metaseq/cli/interactive_hosted.py
+++ b/metaseq/cli/interactive_hosted.py
@@ -305,15 +305,17 @@ def completions(engine=None):
         generation_args["top_p"] = round(float(generation_args["top_p"]), 1)
     else:
         generation_args["top_p"] = 1.0
-    # beam search top n
-    if "n" in generation_args:
-        if int(generation_args["n"]) > MAX_BEAM:
-            logger.warning(
-                f'beam size/sampling size of {int(generation_args["n"])} too large, using {MAX_BEAM} to avoid OOM'
-            )
-        generation_args["n"] = min(MAX_BEAM, max(1, int(generation_args["n"])))
-    else:
+    if "n" not in generation_args:
         generation_args["n"] = 1
+    if "best_of" not in generation_args:
+        generation_args["best_of"] = generation_args["n"]
+    # beam search
+    if int(generation_args["best_of"]) > MAX_BEAM:
+        logger.warning(
+            f'beam size/sampling size of {int(generation_args["best_of"])} too large, using {MAX_BEAM} to avoid OOM'
+        )
+        generation_args["best_of"] = MAX_BEAM
+        generation_args["n"] = min(MAX_BEAM, int(generation_args["n"]))

     ret_queue = queue.Queue()
     for i, prompt in enumerate(prompts):
diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py
index b03c8130547ebca9eb081aef203dd6643b1c4fdb..73098a936e94f6bcb6d07d1770ec0c8f00c1d5d6 100644
--- a/metaseq/dataclass/configs.py
+++ b/metaseq/dataclass/configs.py
@@ -574,10 +574,6 @@ class GenerationConfig(MetaseqDataclass):
         default=5,
         metadata={"help": "beam size"},
     )
-    nbest: int = field(
-        default=1,
-        metadata={"help": "number of hypotheses to output"},
-    )
     max_len_a: float = field(
         default=0,
         metadata={
diff --git a/metaseq/hub_utils.py b/metaseq/hub_utils.py
index ee27babcbdd4e032bdb46841c0defa603b48608f..6797f9792b4be138d0feb02812684ffbb6095dad 100644
--- a/metaseq/hub_utils.py
+++ b/metaseq/hub_utils.py
@@ -197,8 +197,8 @@ class GeneratorInterface:
         temperature: softmax temperature
         top_p: nucleus probability
         log_probs: return this cutoff of the probability distribution
-        n: beam size
-        best_of: number of beams to return. must be <= n
+        best_of: beam size
+        n: number of beams to return. must be <= best_of
         echo: if true, returned text/tokens/scores includes the prompt.
             This is useful for getting PPL evaluations.
         stop: a list of terminating tokens
@@ -289,7 +289,7 @@
         # actually turn everything into strings
         for i in range(all_tokens.size(0)):
             beams = []
-            for j in range(best_of):
+            for j in range(n):
                 # first beam is always the highest scoring
                 tokens = all_tokens[i, j].tolist()
                 scores = all_scores[i, j].tolist()
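
For reference, below is a minimal standalone sketch of the n/best_of normalization that the patched completions() performs: "best_of" is now the beam size, "n" is the number of hypotheses returned, and both are capped at MAX_BEAM. The MAX_BEAM value and the normalize_beam_args helper name are placeholders for illustration only; they are not part of the patch.

import logging

logger = logging.getLogger(__name__)
MAX_BEAM = 16  # placeholder cap; the real limit comes from the serving constants


def normalize_beam_args(generation_args: dict) -> dict:
    # Default: return a single hypothesis from a beam of the same size.
    if "n" not in generation_args:
        generation_args["n"] = 1
    if "best_of" not in generation_args:
        generation_args["best_of"] = generation_args["n"]
    # Cap the beam size (and the number of returned hypotheses) to avoid OOM.
    if int(generation_args["best_of"]) > MAX_BEAM:
        logger.warning(
            f'beam size/sampling size of {int(generation_args["best_of"])} too large, '
            f"using {MAX_BEAM} to avoid OOM"
        )
        generation_args["best_of"] = MAX_BEAM
        generation_args["n"] = min(MAX_BEAM, int(generation_args["n"]))
    return generation_args


# Example: search with 10 beams but return only the top 2 hypotheses.
# normalize_beam_args({"n": 2, "best_of": 10}) -> {"n": 2, "best_of": 10}  (unchanged, since 10 <= MAX_BEAM)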