diff --git a/cpu_tests/test_streaming_iterators.py b/cpu_tests/test_streaming_iterators.py
index 08d23f105392d6c4d89ea73258090d6a98efa154..8fb8f530529383d60afb718b6babd594a73fc39d 100644
--- a/cpu_tests/test_streaming_iterators.py
+++ b/cpu_tests/test_streaming_iterators.py
@@ -12,6 +12,7 @@ from metaseq.data import (
     StreamingShuffleDataset,
     StreamingTokenBlockDataset,
     PartitionedStreamingDataset,
+    StreamingSrcTgtDataset,
 )
 from metaseq.data.document_to_sequence import DocumentToSequenceDataset, LockingArray
 
@@ -53,19 +54,24 @@ def get_simple_dataset():
 
 
 class FakeTensorData(torch.utils.data.Dataset):
-    def __init__(self):
+    def __init__(self, source_target):
         self.rng = random.Random(0)
         self.trng = torch.Generator()
         self.trng.manual_seed(0)
-        self.items = [
-            torch.randint(
-                256, size=(self.rng.randrange(512, 2048),), generator=self.trng
-            )
-            for _ in range(len(self))
-        ]
+        self.items = [self._gen_one(source_target) for _ in range(len(self))]
         self.queried = 0
         self.realized = [False for _ in self.items]
 
+    def _gen_one(self, source_target):
+        n = self.rng.randrange(512, 2048)
+        toks = torch.randint(256, size=(n,), generator=self.trng)
+        if not source_target:
+            return toks
+        src_tokens_len = self.rng.randrange(1, n)
+        tgt_tokens = torch.clone(toks)
+        tgt_tokens[:src_tokens_len] = 1
+        return (toks, tgt_tokens)
+
     def __len__(self):
         return 128
 
@@ -159,7 +165,7 @@ class TestStreamingIterators(unittest.TestCase):
         def create_dataset(
             break_mode="none", drop_last=True, sequence_size=2049, num_shards=1
         ):
-            dataset = FakeTensorData()
+            dataset = FakeTensorData(False)
             token_dataset = DocumentToSequenceDataset(
                 dataset,
                 # We generate blocks with one extra token, so that we have a target
@@ -282,10 +288,15 @@ class TestStreamingIterators(unittest.TestCase):
     def test_document_to_sequence(self):
         MAX_SEQ_LEN = 2048
 
-        def get_traditional_iterator(dataset, break_mode, drop_last):
+        def get_traditional_iterator(dataset, break_mode, drop_last, source_target):
             shuffle_dataset = StreamingShuffleDataset(dataset, seed=42)
             shuffle_dataset.set_epoch(0)
-            token_dataset = StreamingTokenBlockDataset(
+            dataset_cls = (
+                StreamingSrcTgtDataset
+                if source_target
+                else StreamingTokenBlockDataset
+            )
+            token_dataset = dataset_cls(
                 shuffle_dataset,
                 # We generate blocks with one extra token, so that we have a target
                 # for the final input token. This results in slight data loss.
@@ -301,7 +312,9 @@ class TestStreamingIterators(unittest.TestCase):
             token_dataset.set_shuffle_buffer_size(4)
             return token_dataset
 
-        def get_document_to_sequence_iterator(dataset, break_mode, drop_last):
+        def get_document_to_sequence_iterator(
+            dataset, break_mode, drop_last, source_target
+        ):
             document_to_sequence_dataset = DocumentToSequenceDataset(
                 dataset,
                 # We generate blocks with one extra token, so that we have a target
@@ -314,15 +327,18 @@ class TestStreamingIterators(unittest.TestCase):
                 # 1284 is a randomly-generated offset to decouple the seed used here
                 # from the seed used above in StreamingShuffleDataset
                 seed=42,
+                source_target=source_target,
             )
             document_to_sequence_dataset.set_epoch(0)
             document_to_sequence_dataset.set_shuffle_buffer_size(4)
             return document_to_sequence_dataset
 
-        def compare(break_mode, drop_last):
-            a = get_traditional_iterator(FakeTensorData(), break_mode, drop_last)
+        def compare(break_mode, drop_last, source_target):
+            a = get_traditional_iterator(
+                FakeTensorData(source_target), break_mode, drop_last, source_target
+            )
             b = get_document_to_sequence_iterator(
-                FakeTensorData(), break_mode, drop_last
+                FakeTensorData(source_target), break_mode, drop_last, source_target
             )
             a_values = list(a)
             b_values = list(b)
@@ -330,15 +346,29 @@ class TestStreamingIterators(unittest.TestCase):
 
             for av, bv in zip(a_values, b_values):
                 self.assertTrue(torch.allclose(av["ids"], bv["ids"]))
-                self.assertTrue(torch.allclose(av["block"], bv["block"]))
-
-        compare("none", False)
-        compare("eos_pad_8", False)
-        compare("complete", False)
-
-        compare("none", True)
-        compare("eos_pad_8", True)
-        compare("complete", True)
+                if source_target:
+                    self.assertTrue(torch.allclose(av["src_block"], bv["src_block"]))
+                    self.assertTrue(torch.allclose(av["tgt_block"], bv["tgt_block"]))
+                else:
+                    self.assertTrue(torch.allclose(av["block"], bv["block"]))
+
+        # standard case: single token stream (source_target=False)
+        compare("none", False, False)
+        compare("eos_pad_8", False, False)
+        compare("complete", False, False)
+
+        compare("none", True, False)
+        compare("eos_pad_8", True, False)
+        compare("complete", True, False)
+
+        # fine-tuning case: (source, target) pairs (source_target=True)
+        compare("none", False, True)
+        compare("eos_pad_8", False, True)
+        compare("complete", False, True)
+
+        compare("none", True, True)
+        compare("eos_pad_8", True, True)
+        compare("complete", True, True)
 
     def test_locking_array(self):
         l = LockingArray(20, 8)
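
A minimal sketch (not part of the patch) of the (source, target) pair that
FakeTensorData._gen_one produces when source_target=True. The value 1 written
over the prompt prefix is simply what the test uses; reading it as a
pad/ignore index over prompt positions is an assumption for illustration.

    import torch

    toks = torch.tensor([7, 8, 9, 10, 11])  # full document: prompt followed by completion
    src_tokens_len = 3                       # length of the prompt prefix
    tgt_tokens = toks.clone()
    tgt_tokens[:src_tokens_len] = 1          # prompt positions are blanked out in the target
    # toks       -> tensor([ 7,  8,  9, 10, 11])
    # tgt_tokens -> tensor([ 1,  1,  1, 10, 11])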
diff --git a/metaseq/data/document_to_sequence.py b/metaseq/data/document_to_sequence.py
index 8b782d89b508c3d7fa36746802310cf6d3ef8ab8..cfcf181af81e6e55b922a16e380791c3f4c6a268 100644
--- a/metaseq/data/document_to_sequence.py
+++ b/metaseq/data/document_to_sequence.py
@@ -150,6 +150,7 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
             but only before iteration has begun.
         seed (int, optional): seed for shuffling
         permute_documents (bool, optional): randomly permute the order the documents are read (default: True)
+        source_target (bool, optional): if True, each item of the input dataset is a (source, target) pair of token tensors rather than a single tensor (default: False)
         to_skip (int, optional): skip the first to_skip sequences before iteration begins (Default: 0)
     """
 
@@ -165,6 +166,7 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
         len_cache=None,
         to_skip=0,
         permute_documents=True,
+        source_target=False,
     ):
         super().__init__()
         self.dataset = dataset
@@ -200,9 +202,13 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
         self.shuffle_buffer_size = shuffle_buffer_size
         self.to_skip = to_skip
         self.permute_documents = permute_documents
+        self.source_target = source_target
 
         if break_mode == "none":
-            self.block_iterator = yield_token_blocks
+            if self.source_target:
+                self.block_iterator = yield_doc_blocks
+            else:
+                self.block_iterator = yield_token_blocks
         elif break_mode == "eos_pad_8":
             self.block_iterator = yield_single_sentences_pad_8
         elif break_mode == "complete":
@@ -211,7 +217,8 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
             self.block_iterator = yield_passthrough
         else:
             raise ValueError(
-                f'Invalid value for break_mode = {break_mode}. Available options are "none", "eos_pad_8" or "complete".'
+                f"Invalid value for break_mode = {break_mode}."
+                'Available options are "none", "eos_pad_8", "complete", or "passthrough".'
             )
 
         if not drop_last and padding_idx is None:
@@ -281,7 +288,10 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
                     # Cache miss: we don't know the number of tokens
                     # so we have to load and tokenize the document.
                     r = self.dataset[idx]
-                    ln = r.shape[0]
+                    if self.source_target:
+                        ln = r[0].shape[0]
+                    else:
+                        ln = r.shape[0]
                     self.len_cache.data[idx] = ln
                     yield (ln, [r])
                 else:
@@ -291,7 +301,7 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
                     # We create a single-element list here, so that we can replace the single element
                     # with the real Tensor value the first time _any_ SentenceFragment needs the
                     # real data from this document.
-                    yield (ln, [idx])
+                    yield (ln, [int(idx)])
 
         block_itr = self.block_iterator(documents(), self.block_size, self.drop_last)
 
@@ -358,19 +368,33 @@ class DocumentToSequenceDataset(torch.utils.data.IterableDataset):
 
                     # A padding tensor (<padding_value>, 0, length)
                     if doc[0] == "padding":
-                        subsequences.append(
-                            subsequences[-1].new_full((ln,), self.padding_idx)
-                        )
+                        example = subsequences[-1]
+                        if self.source_target:
+                            example = example[0]
+                        padding_tensor = example.new_full((ln,), self.padding_idx)
+                        if self.source_target:
+                            padding_tensor = (padding_tensor, padding_tensor)
+                        subsequences.append(padding_tensor)
                     else:
                         # This single-element list is shared among all SequenceFragments that use
                         # the same document. We update the list to ensure we only
                         # ever tokenize the document once.
-                        if not isinstance(doc[0], torch.Tensor):
+                        if isinstance(doc[0], int):
                             # an index into dataset that hasn't been loaded yet
                             # load it now (and for all other SequenceFragments where it hasn't been loaded yet)
                             doc[0] = self.dataset[doc[0]]
-                        subsequences.append(doc[0][start : start + ln])
-                elem["block"] = torch.cat(subsequences)
+                        if self.source_target:
+                            subsequences.append(
+                                tuple(t[start : start + ln] for t in doc[0])
+                            )
+                        else:
+                            subsequences.append(doc[0][start : start + ln])
+                if self.source_target:
+                    del elem["block"]
+                    elem["src_block"] = torch.cat(tuple(s for s, t in subsequences))
+                    elem["tgt_block"] = torch.cat(tuple(t for s, t in subsequences))
+                else:
+                    elem["block"] = torch.cat(subsequences)
                 elem["skip_time"] = skip_time
                 yield elem
         except StopIteration:
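
A rough usage sketch (not from the patch) of DocumentToSequenceDataset in
source_target mode, mirroring the arguments visible in the test above; any
parameter names not shown in the diff (e.g. block_size) are assumptions. The
point is the change in output keys: "src_block"/"tgt_block" replace "block".

    from metaseq.data.document_to_sequence import DocumentToSequenceDataset

    # pair_dataset: any map-style dataset whose items are (source, target) token tensors
    ds = DocumentToSequenceDataset(
        pair_dataset,
        block_size=2049,
        break_mode="complete",
        drop_last=False,
        padding_idx=1,
        seed=42,
        source_target=True,
    )
    ds.set_epoch(0)
    ds.set_shuffle_buffer_size(4)

    elem = next(iter(ds))
    # elem["src_block"] and elem["tgt_block"] are equal-length 1-D tensors; with
    # drop_last=False the trailing block is padded with padding_idx in both.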
diff --git a/metaseq/tasks/streaming_finetune_language_modeling.py b/metaseq/tasks/streaming_finetune_language_modeling.py
index 9b2345e1ca18bc9acd983a8afd9f7f2d04ad2626..8f9957d8836fc9d6c1f899734f0a7eb1170bc9b7 100644
--- a/metaseq/tasks/streaming_finetune_language_modeling.py
+++ b/metaseq/tasks/streaming_finetune_language_modeling.py
@@ -15,14 +15,14 @@ import torch
 
 from metaseq.data import (
     JsonlDataset,
-    StreamingShuffleDataset,
-    StreamingSrcTgtDataset,
     data_utils,
 )
 from metaseq.tasks.streaming_language_modeling import (
     StreamingLanguageModelingTask,
     StreamingLanguageModelingConfig,
 )
+from metaseq.data.document_to_sequence import DocumentToSequenceDataset
+
 from metaseq.tasks import register_task
 
 logger = logging.getLogger(__name__)
@@ -96,10 +96,7 @@ class StreamingFinetuneLanguageModelingTask(StreamingLanguageModelingTask):
 
         dataset = torch.utils.data.ConcatDataset(datasets)
 
-        # shuffle order across epochs
-        dataset = StreamingShuffleDataset(dataset, seed=self.args.seed)
-
-        self.datasets[split] = StreamingSrcTgtDataset(
+        self.datasets[split] = DocumentToSequenceDataset(
             dataset,
             # We generate blocks with one extra token, so that we have a target
             # for the final input token. This results in slight data loss.
@@ -108,10 +105,8 @@ class StreamingFinetuneLanguageModelingTask(StreamingLanguageModelingTask):
             # we drop the remainder block during training
             drop_last=(split == "train"),
             padding_idx=self.source_dictionary.pad(),
-            # 1284 is a randomly-generated offset to decouple the seed used here
-            # from the seed used above in StreamingShuffleDataset
-            # TODO: Track this seed to avoid collisions. See issue #65
-            seed=1284 + self.args.seed,
+            seed=self.args.seed,
+            source_target=True,
         )
 
     def _collate_fn(self, items: List[Dict[str, Any]]):
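
Schematic before/after of the fine-tuning data pipeline this hunk replaces (a
sketch only; the elided arguments are the ones visible in the patch above).

    # before: explicit shuffle wrapper plus a separate src/tgt block dataset,
    # with a hard-coded 1284 offset to decouple the two seeds
    #   dataset = StreamingShuffleDataset(dataset, seed=self.args.seed)
    #   self.datasets[split] = StreamingSrcTgtDataset(dataset, ..., seed=1284 + self.args.seed)
    #
    # after: DocumentToSequenceDataset shuffles documents itself and, with
    # source_target=True, emits "src_block"/"tgt_block" pairs, so one seed suffices
    #   self.datasets[split] = DocumentToSequenceDataset(dataset, ..., seed=self.args.seed, source_target=True)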