diff --git a/data.py b/data.py index 56ecf95baefa204c46c1764a1acfb6073526c533..81abadaed2fa20358d31a7fabc8ceade1c44f2f0 100644 --- a/data.py +++ b/data.py @@ -100,9 +100,9 @@ class DataSet(): # Now one-hot it. label_hot = to_categorical(label_encoded, len(self.classes)) - assert len(label_hot) == len(self.classes) + assert len(label_hot[0]) == len(self.classes) - return label_hot + return label_hot[0] def split_train_test(self): """Split the data into train and test groups."""