From 4f73d23b88fee2114aa3a08e22b62e7116cbe815 Mon Sep 17 00:00:00 2001
From: Ramakanth Pasunuru <ramakanth.1729@gmail.com>
Date: Thu, 15 Sep 2022 22:48:40 -0700
Subject: [PATCH] add sampling specific to pretrain data

---
 metaseq/tasks/streaming_language_modeling.py | 22 +++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/metaseq/tasks/streaming_language_modeling.py b/metaseq/tasks/streaming_language_modeling.py
index 73ba26f..dc39279 100644
--- a/metaseq/tasks/streaming_language_modeling.py
+++ b/metaseq/tasks/streaming_language_modeling.py
@@ -90,6 +90,10 @@ class StreamingLanguageModelingConfig(MetaseqDataclass):
         default=DEFAULT_MULTICORPUS_MAX,
         metadata={"help": "Maximum size for example proportional sampling"},
     )
+    pretrain_data_sampling_prob: Optional[float] = field(
+        default=0.0,
+        metadata={"help": "Sampling probability for the pretraining corpus. Use only during finetuning to mix in pretraining data."},
+    )
     data_subshard_count: int = field(
         default=1,
         metadata={
@@ -208,7 +212,23 @@ class StreamingLanguageModelingTask(LegacyTask):
             dtype=float,
         )
         logger.info(f"loaded total {dataset_lengths.sum()} blocks for all corpora")
-        sample_probs = self._get_sample_prob(dataset_lengths)
+        if self.args.pretrain_data_sampling_prob > 0:
+            assert self.args.pretrain_data_sampling_prob < 1.0
+            def prefix_match_index(names):
+                prefix = "pretrain__"
+                for i, s in enumerate(names):
+                    if prefix in s:
+                        return i
+                return -1
+            pretrain_corpus_index = prefix_match_index(corpora)
+            assert pretrain_corpus_index != -1
+            tmp_dataset_lengths = np.copy(dataset_lengths)
+            tmp_dataset_lengths[pretrain_corpus_index] = 0
+            sample_probs = self._get_sample_prob(tmp_dataset_lengths)
+            sample_probs = sample_probs * (1 - self.args.pretrain_data_sampling_prob)
+            sample_probs[pretrain_corpus_index] = self.args.pretrain_data_sampling_prob
+        else:
+            sample_probs = self._get_sample_prob(dataset_lengths)
         logger.info(
             "Sample probability by corpus: %s",
--
GitLab
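
For reference, below is a minimal standalone sketch of the sampling arithmetic this patch introduces. It is not part of metaseq: `_get_sample_prob` here is a simplified stand-in that normalizes dataset lengths into probabilities (the real method applies example-proportional sampling with a size cap), and the corpus names, lengths, and the `mix_pretrain_prob` helper are hypothetical example values used only for illustration.

import numpy as np


def _get_sample_prob(dataset_lengths):
    # Simplified stand-in for metaseq's example-proportional sampling
    # (the real method also applies the multicorpus size cap).
    return dataset_lengths / dataset_lengths.sum()


def mix_pretrain_prob(corpora, dataset_lengths, pretrain_data_sampling_prob):
    # Reserve `pretrain_data_sampling_prob` of the sampling mass for the corpus
    # whose name carries the "pretrain__" prefix; the remaining corpora share
    # the rest proportionally, mirroring the logic added in the patch.
    assert 0.0 < pretrain_data_sampling_prob < 1.0
    pretrain_corpus_index = next(
        (i for i, name in enumerate(corpora) if "pretrain__" in name), -1
    )
    assert pretrain_corpus_index != -1
    # Zero out the pretrain corpus so the other corpora split (1 - p) of the mass.
    tmp_dataset_lengths = np.copy(dataset_lengths)
    tmp_dataset_lengths[pretrain_corpus_index] = 0
    sample_probs = _get_sample_prob(tmp_dataset_lengths)
    sample_probs = sample_probs * (1 - pretrain_data_sampling_prob)
    sample_probs[pretrain_corpus_index] = pretrain_data_sampling_prob
    return sample_probs


if __name__ == "__main__":
    # Hypothetical corpora/lengths: mix 10% pretraining data into two finetuning sets.
    corpora = ["pretrain__pile", "finetune__taskA", "finetune__taskB"]
    lengths = np.array([1_000_000.0, 30_000.0, 10_000.0])
    print(mix_pretrain_prob(corpora, lengths, 0.10))  # [0.1   0.675 0.225]

With pretrain_data_sampling_prob set to 0.10, the pretraining corpus keeps exactly 10% of the probability mass regardless of its (typically much larger) size, and the finetuning corpora share the remaining 90% in proportion to their lengths.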