From 4f73d23b88fee2114aa3a08e22b62e7116cbe815 Mon Sep 17 00:00:00 2001
From: Ramakanth Pasunuru <ramakanth.1729@gmail.com>
Date: Thu, 15 Sep 2022 22:48:40 -0700
Subject: [PATCH] add sampling specific to pretrain data

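When pretrain_data_sampling_prob = p is set to a value in (0, 1), the corpus
whose name contains the "pretrain__" prefix is sampled with probability p,
and the remaining fine-tuning corpora share the leftover 1 - p in proportion
to their usual size-based sampling probabilities. As a rough illustration
(corpus names here are made up; assume two fine-tuning corpora that would
otherwise be sampled equally, and p = 0.1):

    pretrain__corpus   -> 0.10
    finetune_corpus_a  -> 0.45
    finetune_corpus_b  -> 0.45

At least one corpus name must carry the "pretrain__" prefix; the first match
is used as the pretrain corpus.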
---
 metaseq/tasks/streaming_language_modeling.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/metaseq/tasks/streaming_language_modeling.py b/metaseq/tasks/streaming_language_modeling.py
index 73ba26f..dc39279 100644
--- a/metaseq/tasks/streaming_language_modeling.py
+++ b/metaseq/tasks/streaming_language_modeling.py
@@ -90,6 +90,10 @@ class StreamingLanguageModelingConfig(MetaseqDataclass):
         default=DEFAULT_MULTICORPUS_MAX,
         metadata={"help": "Maximum size for example proportional sampling"},
     )
+    pretrain_data_sampling_prob: Optional[float] = field(
+        default=0.0,
+        metadata={"help": "Probability of drawing samples from pretraining data when mixing it into fine-tuning; use only during fine-tuning"},
+    )
     data_subshard_count: int = field(
         default=1,
         metadata={
@@ -208,7 +212,27 @@ class StreamingLanguageModelingTask(LegacyTask):
             dtype=float,
         )
         logger.info(f"loaded total {dataset_lengths.sum()} blocks for all corpora")
-        sample_probs = self._get_sample_prob(dataset_lengths)
+        if self.args.pretrain_data_sampling_prob > 0:
+            assert self.args.pretrain_data_sampling_prob < 1.0
+
+            def prefix_match_index(corpus_names):
+                prefix = "pretrain__"
+                for i, name in enumerate(corpus_names):
+                    if prefix in name:
+                        return i
+                return -1
+
+            pretrain_corpus_index = prefix_match_index(corpora)
+            assert pretrain_corpus_index != -1, "no corpus with 'pretrain__' prefix"
+            # Rescale the size-proportional probabilities of the non-pretrain corpora
+            # to sum to (1 - p) and give the pretrain corpus probability p.
+            tmp_dataset_lengths = np.copy(dataset_lengths)
+            tmp_dataset_lengths[pretrain_corpus_index] = 0
+            sample_probs = self._get_sample_prob(tmp_dataset_lengths)
+            sample_probs = sample_probs * (1 - self.args.pretrain_data_sampling_prob)
+            sample_probs[pretrain_corpus_index] = self.args.pretrain_data_sampling_prob
+        else:
+            sample_probs = self._get_sample_prob(dataset_lengths)
 
         logger.info(
             "Sample probability by corpus: %s",
-- 
GitLab