참고
Click here to download the full example code
파이토치(PyTorch) 기본 익히기 || 빠른 시작 || 텐서(Tensor) || Dataset과 Dataloader || 변형(Transform) || 신경망 모델 구성하기 || Autograd || 최적화(Optimization) || 모델 저장하고 불러오기
변형(Transform)¶
데이터가 항상 머신러닝 알고리즘 학습에 필요한 최종 처리가 된 형태로 제공되지는 않습니다. 변형(transform) 을 해서 데이터를 조작하고 학습에 적합하게 만듭니다.
모든 TorchVision 데이터셋들은 변형 로직을 갖는, 호출 가능한 객체(callable)를 받는 매개변수 두개
( 특징(feature)을 변경하기 위한 transform
과 정답(label)을 변경하기 위한 target_transform
)를 갖습니다
torchvision.transforms 모듈은
주로 사용하는 몇가지 변형(transform)을 제공합니다.
FashionMNIST 특징(feature)은 PIL Image 형식이며, 정답(label)은 정수(integer)입니다.
학습을 하려면 정규화(normalize)된 텐서 형태의 특징(feature)과 원-핫(one-hot)으로 부호화(encode)된 텐서 형태의
정답(label)이 필요합니다. 이러한 변형(transformation)을 하기 위해 ToTensor
와 Lambda
를 사용합니다.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/26421880 [00:00<?, ?it/s]
0%| | 32768/26421880 [00:00<03:40, 119811.04it/s]
0%| | 65536/26421880 [00:00<03:35, 122347.74it/s]
0%| | 131072/26421880 [00:00<02:29, 176323.99it/s]
1%| | 196608/26421880 [00:01<02:08, 204384.61it/s]
1%|1 | 393216/26421880 [00:01<01:06, 392098.14it/s]
3%|3 | 819200/26421880 [00:01<00:31, 800826.32it/s]
6%|6 | 1638400/26421880 [00:01<00:16, 1545674.38it/s]
12%|#2 | 3276800/26421880 [00:02<00:07, 3018579.68it/s]
24%|##3 | 6225920/26421880 [00:02<00:03, 5450813.10it/s]
30%|##9 | 7864320/26421880 [00:02<00:03, 5443745.66it/s]
36%|###5 | 9502720/26421880 [00:03<00:03, 5552113.51it/s]
45%|####4 | 11862016/26421880 [00:03<00:02, 4971859.82it/s]
56%|#####5 | 14712832/26421880 [00:03<00:01, 6512849.36it/s]
66%|######5 | 17432576/26421880 [00:04<00:01, 6844156.94it/s]
69%|######8 | 18186240/26421880 [00:04<00:01, 6225364.76it/s]
81%|######## | 21299200/26421880 [00:04<00:00, 7739751.00it/s]
86%|########6 | 22773760/26421880 [00:04<00:00, 6470667.97it/s]
89%|########8 | 23461888/26421880 [00:05<00:00, 5325640.16it/s]
99%|#########9| 26181632/26421880 [00:05<00:00, 7261945.14it/s]
100%|##########| 26421880/26421880 [00:05<00:00, 4776534.41it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 108711.87it/s]
100%|##########| 29515/29515 [00:00<00:00, 108426.51it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/4422102 [00:00<?, ?it/s]
1%| | 32768/4422102 [00:00<00:38, 114680.30it/s]
1%|1 | 65536/4422102 [00:00<00:36, 120587.11it/s]
3%|2 | 131072/4422102 [00:00<00:25, 168878.95it/s]
4%|4 | 196608/4422102 [00:01<00:21, 197762.02it/s]
6%|5 | 262144/4422102 [00:01<00:19, 216120.88it/s]
8%|8 | 360448/4422102 [00:01<00:15, 264138.81it/s]
10%|# | 458752/4422102 [00:01<00:13, 285561.15it/s]
13%|#2 | 557056/4422102 [00:02<00:12, 306905.20it/s]
15%|#4 | 655360/4422102 [00:02<00:11, 322802.55it/s]
17%|#7 | 753664/4422102 [00:02<00:11, 331565.79it/s]
19%|#9 | 851968/4422102 [00:03<00:10, 344584.37it/s]
22%|##2 | 983040/4422102 [00:03<00:08, 390393.28it/s]
25%|##5 | 1114112/4422102 [00:03<00:07, 422070.75it/s]
28%|##8 | 1245184/4422102 [00:03<00:07, 445246.59it/s]
31%|###1 | 1376256/4422102 [00:04<00:06, 461643.84it/s]
34%|###4 | 1507328/4422102 [00:04<00:06, 474132.87it/s]
37%|###7 | 1638400/4422102 [00:04<00:05, 482066.18it/s]
41%|#### | 1802240/4422102 [00:04<00:05, 516611.61it/s]
44%|####4 | 1966080/4422102 [00:05<00:04, 529782.06it/s]
48%|####8 | 2129920/4422102 [00:05<00:04, 557592.86it/s]
53%|#####2 | 2326528/4422102 [00:05<00:03, 613771.16it/s]
56%|#####6 | 2490368/4422102 [00:05<00:02, 741092.06it/s]
59%|#####8 | 2588672/4422102 [00:06<00:02, 688571.81it/s]
62%|######1 | 2719744/4422102 [00:06<00:02, 639874.25it/s]
66%|######5 | 2916352/4422102 [00:06<00:02, 677196.36it/s]
70%|####### | 3112960/4422102 [00:06<00:01, 803376.46it/s]
73%|#######2 | 3211264/4422102 [00:06<00:01, 806051.11it/s]
76%|#######5 | 3342336/4422102 [00:06<00:01, 807283.05it/s]
78%|#######7 | 3440640/4422102 [00:07<00:01, 840936.62it/s]
82%|########1 | 3604480/4422102 [00:07<00:01, 752564.77it/s]
87%|########6 | 3833856/4422102 [00:07<00:00, 927180.63it/s]
90%|########9 | 3964928/4422102 [00:07<00:00, 1000823.52it/s]
93%|#########3| 4128768/4422102 [00:07<00:00, 827450.19it/s]
99%|#########9| 4390912/4422102 [00:08<00:00, 1038838.35it/s]
100%|##########| 4422102/4422102 [00:08<00:00, 544202.69it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 19365270.84it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()¶
ToTensor
는 PIL Image나 NumPy ndarray
를 FloatTensor
로 변환하고, 이미지의 픽셀의 크기(intensity) 값을 [0., 1.] 범위로
비례하여 조정(scale)합니다.
Lambda 변형(Transform)¶
Lambda 변형은 사용자 정의 람다(lambda) 함수를 적용합니다. 여기에서는 정수를 원-핫으로 부호화된 텐서로 바꾸는
함수를 정의합니다.
이 함수는 먼저 (데이터셋 정답의 개수인) 크기 10짜리 영 텐서(zero tensor)를 만들고,
scatter_ 를 호출하여
주어진 정답 y
에 해당하는 인덱스에 value=1
을 할당합니다.
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
더 읽어보기¶
Total running time of the script: ( 0 minutes 19.238 seconds)