Skip to content

CIFAR100 multilabel + mean/std #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 39 additions & 43 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ class CIFAR10(data.Dataset):
['test_batch', '40351d587109b95175f43aff81a1287e'],
]

label_keys = ['labels']

This comment was marked as off-topic.

This comment was marked as off-topic.


mean = [125.30691805, 122.95039414, 113.86538318]
std = [62.99321928, 62.08870764, 66.70489964]

def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set

if download:
self.download()
Expand All @@ -62,44 +66,23 @@ def __init__(self, root, train=True,
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

# now load the picked numpy arrays
if self.train:
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
# now load the pickled numpy arrays
data_list = self.train_list if train else self.test_list
self.data = []
self.labels = {k: [] for k in self.label_keys}
for fname, _ in data_list:
with open(os.path.join(self.root, self.base_folder, fname), 'rb') as fo:
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
self.train_data.append(entry['data'])
if 'labels' in entry:
self.train_labels += entry['labels']
else:
self.train_labels += entry['fine_labels']
fo.close()

self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else:
f = self.test_list[0][0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
self.test_data = entry['data']
if 'labels' in entry:
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']

This comment was marked as off-topic.

This comment was marked as off-topic.

fo.close()
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
self.data.append(entry['data'])
for key in self.label_keys:
self.labels[key] += entry[key]

self.data = np.concatenate(self.data)
self.data = self.data.reshape((-1, 3, 32, 32))
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC

def __getitem__(self, index):
"""
Expand All @@ -109,10 +92,9 @@ def __getitem__(self, index):
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
img, target = self.data[index], [self.labels[k][index] for k in self.label_keys]
if len(self.label_keys) == 1:
target = target[0]

# doing this so that it is consistent with all other datasets
# to return a PIL Image
Expand All @@ -127,10 +109,7 @@ def __getitem__(self, index):
return img, target

def __len__(self):
if self.train:
return len(self.train_data)
else:
return len(self.test_data)
return len(self.data)

def _check_integrity(self):
root = self.root
Expand Down Expand Up @@ -164,6 +143,11 @@ class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

This is a subclass of the `CIFAR10` Dataset.

Additional Args:
include_coarse (bool, optional): If False (the default), the targets
are the fine-grained labels only. If True, targets are a pair
consisting of the fine-grained and the coarse labels, respectively.
"""
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
Expand All @@ -176,3 +160,15 @@ class CIFAR100(CIFAR10):
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]

This comment was marked as off-topic.

mean = [129.30416561, 124.0699627, 112.43405006]
std = [68.1702429, 65.39180804, 70.41837019]

def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False, include_coarse=False):
self.label_keys = ['fine_labels']
if include_coarse:
self.label_keys.append('coarse_labels')
super(CIFAR100, self).__init__(root=root, train=train, transform=transform,
target_transform=target_transform, download=download)