diff --git a/data.py b/data.py
index 56ecf95baefa204c46c1764a1acfb6073526c533..81abadaed2fa20358d31a7fabc8ceade1c44f2f0 100644
--- a/data.py
+++ b/data.py
@@ -100,9 +100,9 @@ class DataSet():
         # Now one-hot it.
         label_hot = to_categorical(label_encoded, len(self.classes))
 
-        assert len(label_hot) == len(self.classes)
+        assert len(label_hot[0]) == len(self.classes)
 
-        return label_hot
+        return label_hot[0]
 
     def split_train_test(self):
         """Split the data into train and test groups."""