diff --git a/cpu_tests/test_streaming_iterators.py b/cpu_tests/test_streaming_iterators.py
index 08d23f105392d6c4d89ea73258090d6a98efa154..8fb8f530529383d60afb718b6babd594a73fc39d 100644
--- a/cpu_tests/test_streaming_iterators.py
+++ b/cpu_tests/test_streaming_iterators.py
@@ -12,6 +12,7 @@ from metaseq.data import (
     StreamingShuffleDataset,
     StreamingTokenBlockDataset,
     PartitionedStreamingDataset,
+    StreamingSrcTgtDataset,
 )
 from metaseq.data.document_to_sequence import DocumentToSequenceDataset, LockingArray
 
@@ -53,19 +54,24 @@ def get_simple_dataset():
 
 
 class FakeTensorData(torch.utils.data.Dataset):
-    def __init__(self):
+    def __init__(self, source_target):
         self.rng = random.Random(0)
         self.trng = torch.Generator()
         self.trng.manual_seed(0)
-        self.items = [
-            torch.randint(
-                256, size=(self.rng.randrange(512, 2048),), generator=self.trng
-            )
-            for _ in range(len(self))
-        ]
+        self.items = [self._gen_one(source_target) for _ in range(len(self))]
         self.queried = 0
         self.realized = [False for _ in self.items]
 
+    def _gen_one(self, source_target):
+        n = self.rng.randrange(512, 2048)
+        toks = torch.randint(256, size=(n,), generator=self.trng)
+        if not source_target:
+            return toks
+        src_tokens_len = self.rng.randrange(1, n)
+        tgt_tokens = torch.clone(toks)
+        tgt_tokens[:src_tokens_len] = 1
+        return (toks, tgt_tokens)
+
     def __len__(self):
         return 128
 
@@ -159,7 +165,7 @@ class TestStreamingIterators(unittest.TestCase):
         def create_dataset(
             break_mode="none", drop_last=True, sequence_size=2049, num_shards=1
         ):
-            dataset = FakeTensorData()
+            dataset = FakeTensorData(False)
             token_dataset = DocumentToSequenceDataset(
                 dataset,
                 # We generate blocks with one extra token, so that we have a target
@@ -282,10 +288,15 @@ class TestStreamingIterators(unittest.TestCase):
     def test_document_to_sequence(self):
         MAX_SEQ_LEN = 2048
 
-        def get_traditional_iterator(dataset, break_mode, drop_last):
+        def get_traditional_iterator(dataset, break_mode, drop_last, source_target):
             shuffle_dataset = StreamingShuffleDataset(dataset, seed=42)
             shuffle_dataset.set_epoch(0)
-            token_dataset = StreamingTokenBlockDataset(
+            Dataset = (
+                StreamingTokenBlockDataset
+                if not source_target
+                else StreamingSrcTgtDataset
+            )
+            token_dataset = Dataset(
                 shuffle_dataset,
                 # We generate blocks with one extra token, so that we have a target
                 # for the final input token. This results in slight data loss.
@@ -301,7 +312,9 @@ class TestStreamingIterators(unittest.TestCase):
             token_dataset.set_shuffle_buffer_size(4)
             return token_dataset
 
-        def get_document_to_sequence_iterator(dataset, break_mode, drop_last):
+        def get_document_to_sequence_iterator(
+            dataset, break_mode, drop_last, source_target
+        ):
             document_to_sequence_dataset = DocumentToSequenceDataset(
                 dataset,
                 # We generate blocks with one extra token, so that we have a target
@@ -314,15 +327,18 @@ class TestStreamingIterators(unittest.TestCase):
                 # 1284 is a randomly-generated offset to decouple the seed used here
                 # from the seed used above in StreamingShuffleDataset
                 seed=42,
+                source_target=source_target,
             )
             document_to_sequence_dataset.set_epoch(0)
             document_to_sequence_dataset.set_shuffle_buffer_size(4)
             return document_to_sequence_dataset
 
-        def compare(break_mode, drop_last):
-            a = get_traditional_iterator(FakeTensorData(), break_mode, drop_last)
+        def compare(break_mode, drop_last, source_target):
+            a = get_traditional_iterator(
+                FakeTensorData(source_target), break_mode, drop_last, source_target
+            )
             b = get_document_to_sequence_iterator(
-                FakeTensorData(), break_mode, drop_last
+                FakeTensorData(source_target), break_mode, drop_last, source_target
             )
             a_values = list(a)
             b_values = list(b)
@@ -330,15 +346,29 @@ class TestStreamingIterators(unittest.TestCase):
 
             for av, bv in zip(a_values, b_values):
                 self.assertTrue(torch.allclose(av["ids"], bv["ids"]))
-                self.assertTrue(torch.allclose(av["block"], bv["block"]))
-
-        compare("none", False)
-        compare("eos_pad_8", False)
-        compare("complete", False)
-
-        compare("none", True)
-        compare("eos_pad_8", True)
-        compare("complete", True)
+                if source_target:
+                    self.assertTrue(torch.allclose(av["src_block"], bv["src_block"]))
+                    self.assertTrue(torch.allclose(av["tgt_block"], bv["tgt_block"]))
+                else:
+                    self.assertTrue(torch.allclose(av["block"], bv["block"]))
+
+        # normal
+        compare("none", False, False)
+        compare("eos_pad_8", False, False)
+        compare("complete", False, False)
+
+        compare("none", True, False)
+        compare("eos_pad_8", True, False)
+        compare("complete", True, False)
+
+        # fine tuning
+        compare("none", False, True)
+        compare("eos_pad_8", False, True)
+        compare("complete", False, True)
+
+        compare("none", True, True)
+        compare("eos_pad_8", True, True)
+        compare("complete", True, True)
 
     def test_locking_array(self):
         l = LockingArray(20, 8)
diff --git a/metaseq/data/document_to_sequence.py b/metaseq/data/document_to_sequence.py
index 8b782d89b508c3d7fa36746802310cf6d3ef8ab8..cfcf181af81e6e55b922a16e380791c3f4c6a268 100644
--- a/metaseq/data/document_to_sequence.py
+++ b/metaseq/data/document_to_sequence.py
@@ -150,6 +150,7 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
             but only before iteration has begun.
         seed (int, optional): seed for shuffling
         permute_documents (bool, optional): randomly permute the order the documents are read (default: True)
+        source_target (bool, optional): the input dataset returns (source, target) tuples of tokens (default: False)
         to_skip (int, optional): skip the first to_skip sequences before iteration begins (Default: 0)
     """
 
@@ -165,6 +166,7 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
         len_cache=None,
         to_skip=0,
         permute_documents=True,
+        source_target=False,
     ):
         super().__init__()
         self.dataset = dataset
@@ -200,9 +202,13 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
         self.shuffle_buffer_size = shuffle_buffer_size
         self.to_skip = to_skip
         self.permute_documents = permute_documents
+        self.source_target = source_target
 
         if break_mode == "none":
-            self.block_iterator = yield_token_blocks
+            if self.source_target:
+                self.block_iterator = yield_doc_blocks
+            else:
+                self.block_iterator = yield_token_blocks
         elif break_mode == "eos_pad_8":
             self.block_iterator = yield_single_sentences_pad_8
         elif break_mode == "complete":
@@ -211,7 +217,8 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
             self.block_iterator = yield_passthrough
         else:
             raise ValueError(
-                f'Invalid value for break_mode = {break_mode}. Available options are "none", "eos_pad_8" or "complete".'
+                f"Invalid value for break_mode = {break_mode}. "
+                'Available options are "none", "eos_pad_8", "complete", or "passthrough".'
             )
 
         if not drop_last and padding_idx is None:
@@ -281,7 +288,10 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
                     # Cache miss: we don't know the number of tokens
                    # so we have to load and tokenize the document.
                     r = self.dataset[idx]
-                    ln = r.shape[0]
+                    if self.source_target:
+                        ln = r[0].shape[0]
+                    else:
+                        ln = r.shape[0]
                     self.len_cache.data[idx] = ln
                     yield (ln, [r])
                 else:
@@ -291,7 +301,7 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
                     # We create a single-element list here, so that we can replace the single element
                     # with the real Tensor value the first time _any_ SentenceFragment needs the
                     # real data from this document.
-                    yield (ln, [idx])
+                    yield (ln, [int(idx)])
 
         block_itr = self.block_iterator(documents(), self.block_size, self.drop_last)
 
@@ -358,19 +368,33 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
 
                     # A padding tensor (<padding_value>, 0, length)
                     if doc[0] == "padding":
-                        subsequences.append(
-                            subsequences[-1].new_full((ln,), self.padding_idx)
-                        )
+                        example = subsequences[-1]
+                        if self.source_target:
+                            example = example[0]
+                        padding_tensor = example.new_full((ln,), self.padding_idx)
+                        if self.source_target:
+                            padding_tensor = (padding_tensor, padding_tensor)
+                        subsequences.append(padding_tensor)
                     else:
                         # This single-element list is shared among all SequenceFragments that use
                         # the same document. We update the list to ensure we only
                         # ever tokenize the document once.
-                        if not isinstance(doc[0], torch.Tensor):
+                        if isinstance(doc[0], int):
                             # an index into dataset that hasn't been loaded yet
                             # load it now (and for all other SequenceFragments where it hasn't been loaded yet)
                             doc[0] = self.dataset[doc[0]]
-                        subsequences.append(doc[0][start : start + ln])
-                elem["block"] = torch.cat(subsequences)
+                        if self.source_target:
+                            subsequences.append(
+                                tuple(elem[start : start + ln] for elem in doc[0])
+                            )
+                        else:
+                            subsequences.append(doc[0][start : start + ln])
+                if self.source_target:
+                    del elem["block"]
+                    elem["src_block"] = torch.cat(tuple(s for s, t in subsequences))
+                    elem["tgt_block"] = torch.cat(tuple(t for s, t in subsequences))
+                else:
+                    elem["block"] = torch.cat(subsequences)
                 elem["skip_time"] = skip_time
                 yield elem
             except StopIteration:
diff --git a/metaseq/tasks/streaming_finetune_language_modeling.py b/metaseq/tasks/streaming_finetune_language_modeling.py
index 9b2345e1ca18bc9acd983a8afd9f7f2d04ad2626..8f9957d8836fc9d6c1f899734f0a7eb1170bc9b7 100644
--- a/metaseq/tasks/streaming_finetune_language_modeling.py
+++ b/metaseq/tasks/streaming_finetune_language_modeling.py
@@ -15,14 +15,14 @@ import torch
 
 from metaseq.data import (
     JsonlDataset,
-    StreamingShuffleDataset,
-    StreamingSrcTgtDataset,
     data_utils,
 )
 from metaseq.tasks.streaming_language_modeling import (
     StreamingLanguageModelingTask,
     StreamingLanguageModelingConfig,
 )
+from metaseq.tasks.streaming_language_modeling import DocumentToSequenceDataset
+
 from metaseq.tasks import register_task
 
 logger = logging.getLogger(__name__)
@@ -96,10 +96,7 @@ class StreamingFinetuneLanguageModelingTask(StreamingLanguageModelingTask):
 
         dataset = torch.utils.data.ConcatDataset(datasets)
 
-        # shuffle order across epochs
-        dataset = StreamingShuffleDataset(dataset, seed=self.args.seed)
-
-        self.datasets[split] = StreamingSrcTgtDataset(
+        self.datasets[split] = DocumentToSequenceDataset(
            dataset,
             # We generate blocks with one extra token, so that we have a target
             # for the final input token. This results in slight data loss.
@@ -108,10 +105,8 @@ class StreamingFinetuneLanguageModelingTask(StreamingLanguageModelingTask):
             # we drop the remainder block during training
             drop_last=(split == "train"),
             padding_idx=self.source_dictionary.pad(),
-            # 1284 is a randomly-generated offset to decouple the seed used here
-            # from the seed used above in StreamingShuffleDataset
-            # TODO: Track this seed to avoid collisions. See issue #65
-            seed=1284 + self.args.seed,
+            seed=self.args.seed,
+            source_target=True,
         )
 
     def _collate_fn(self, items: List[Dict[str, Any]]):
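
For reviewers who want to try the new path, below is a minimal usage sketch of `DocumentToSequenceDataset` with `source_target=True`. It is illustrative only: constructor keywords other than `source_target` (in particular `block_size` and `break_mode`) are assumed to match the ones used in the unit tests above, and `ToySrcTgtDataset` is a hypothetical stand-in for a real `(source, target)` dataset such as `FakeTensorData(True)`.

```python
# Illustrative sketch only: keywords other than source_target are assumed from
# the tests above; ToySrcTgtDataset is a hypothetical toy dataset.
import torch
from metaseq.data.document_to_sequence import DocumentToSequenceDataset


class ToySrcTgtDataset(torch.utils.data.Dataset):
    """Yields (source, target) tensor pairs, mimicking FakeTensorData(True)."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        g = torch.Generator()
        g.manual_seed(idx)
        src = torch.randint(256, size=(100,), generator=g)
        tgt = src.clone()
        tgt[:50] = 1  # mask the prompt positions of the target, as in the test
        return (src, tgt)


ds = DocumentToSequenceDataset(
    ToySrcTgtDataset(),
    block_size=128,         # assumed keyword; large enough to hold one document
    break_mode="complete",  # also exercised with source_target=True in the tests
    drop_last=False,
    padding_idx=1,          # required when drop_last=False
    seed=42,
    source_target=True,     # new flag introduced in this diff
)
ds.set_epoch(0)
ds.set_shuffle_buffer_size(4)

for elem in ds:
    # With source_target=True, elem["block"] is replaced by elem["src_block"]
    # and elem["tgt_block"], which are concatenated (and padded) in lockstep.
    assert elem["src_block"].shape == elem["tgt_block"].shape
```

This mirrors what the updated finetune task now does: a single `DocumentToSequenceDataset` with `source_target=True` replaces the previous `StreamingShuffleDataset` + `StreamingSrcTgtDataset` pair, since the document-to-sequence pipeline handles shuffling and the source/target split itself.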