인공지능을 좋아하는 곧미남

pytorch DataLoader 본문

code_study/pytorch

pytorch DataLoader

곧미남 2022. 1. 10. 16:41

pytorch의 패키지를 사용하여 Image Data를 Load하는 코드 설명과 과정에 대해서 알아보겠습니다.

 

Data Augmentation과 os.listdir를 이용한 Window File folder에서 파일을 가져오는 방법도 있으니 참고하시면됩니다.

 

제가 pytorch에서 사용한 Dataloader관련 패키지는 "torch.utils.data.Dataset"과 "torch.utils.data.DataLoader"입니다.

 

오늘의 내용은 아래와 같이 간략히 정리됩니다.

 

- INDEX -

 

1. import os를 이용한 image file명 불러와 list에 저장하기.

 

2. torch.utils.data.Dataset 클래스를 사용하여 저장된 image file명 list에서 file경로를 불러와 각 이미지를 input data 형식으로 변환

   -> 변환내용: image load(opencv), transform(numpy array), augmentation(albumentation)

 

3. torch.utils.data.DataLoader 클래스를 사용하여 전체 데이터를 batch size, channel, height, width 형태의 tensor를 제네러이터 형식으로 생성.


1. import os를 이용한 image file명 불러와 list에 저장

train_input_path = './datasets/raw_data/train/input'
train_mask_path = './datasets/raw_data/train/mask'
validation_input_path = './datasets/raw_data/valid/input'
validation_mask_path = './datasets/raw_data/valid/mask'
test_input_path = './datasets/raw_data/test/input'
test_mask_path = './datasets/raw_data/test/mask'

sort_function = lambda f: int(''.join(filter(str.isdigit, f)))

input_list_train = os.listdir(train_input_path)
input_list_train = [file for file in input_list_train if file.endswith("jpg")]
input_list_train.sort(key=sort_function)
mask_list_train = os.listdir(train_mask_path)
mask_list_train = [file for file in mask_list_train if file.endswith('jpg')]
mask_list_train.sort(key=sort_function)

input_list_valid = os.listdir(validation_input_path)
input_list_valid = [file for file in input_list_valid if file.endswith("jpg")]
input_list_valid.sort(key=sort_function)

mask_list_valid = os.listdir(validation_mask_path)
mask_list_valid = [file for file in mask_list_valid if file.endswith('jpg')]
mask_list_valid.sort(key=sort_function)

input_list_test = os.listdir(test_input_path)
input_list_test = [file for file in input_list_test if file.endswith("jpg")]
input_list_test.sort(key=sort_function)
mask_list_test = os.listdir(test_mask_path)
mask_list_test = [file for file in mask_list_test if file.endswith('jpg')]
mask_list_test.sort(key=sort_function)

print(f"inputs train data num : {len(input_list_train)}")
print(f"masks train data num : {len(mask_list_train)}")
print(f"inputs valid data num : {len(input_list_valid)}")
print(f"masks valid data num : {len(mask_list_valid)}")
print(f"inputs test data num : {len(input_list_test)}")
print(f"masks test data num : {len(mask_list_test)}")

저는 Sort function을 특이하게 lambda를 사용하여 구현해보았습니다.

 

전체적인 Precess: image file path 접근 -> list에 file name 저장 -> sort


2. torch.utils.data.Dataset 클래스를 사용하여 INPUT 형식 변환

- torch.utils.data.Dataset을 통해서 파일 경로 정보를 전달받아 __init__, __len__, __getitem__ 메서드를 이용해 전체 데이터 집합을 생성합니다.

class CustomImageDataset(Dataset):
	def __init__(self, img_list, mask_list, img_path, mask_path, is_train_data = False, resize=None, transform=None, augmentation=None):
        self.img_list = img_list
        self.mask_list = mask_list
        self.img_path = img_path
        self.mask_path = mask_path
        self.transform = transform
        self.resize = resize
        self.is_train_data = is_train_data
        self.augmentation = augmentation

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        image = cv2.imread(os.path.join(self.img_path, self.img_list[idx]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(self.mask_path, self.mask_list[idx]), cv2.IMREAD_GRAYSCALE)
        mask = np.expand_dims(mask, -1)

        # if self.is_train_data == True:
        #     image, mask = self.augmentation(image, mask)

        if self.transform:
            image, mask = self.transform(image, mask, self.is_train_data)

        image, mask = self.resize(image, mask)

        return image, mask

- __init__: Class의 생성자가 생성되는 순간 동시에, class로 전달되는 parameters가 정의된다.

- __len__: 전체 이미지 file 수를 입력하면 그 수 만큼 반복한다.

- __getitem__: 반복할때, 순차적인 index를 불러와 해당 index를 가지고 image data를 load하고 저장 및 반환한다. 


3. torch.utils.data.DataLoader 클래스를 사용하여 데이터 제네레이터 생성

- torch.utils.data.DataLoader을 통해서 batch size만큼 묶은 최종 batch size, channel, height, width 형태의 data를 제네러이터 형식으로 저장하여 생성합니다.

    training_data = CustomImageDataset(input_list_train, mask_list_train, train_input_path, train_mask_path,
                                       is_train_data=True, resize=resize_data, transform=transform, augmentation=augmentation)
    validation_data = CustomImageDataset(input_list_valid, mask_list_valid, validation_input_path, validation_mask_path,
                                         is_train_data=False, resize=resize_data, transform=transform, augmentation=augmentation)
    test_data = CustomImageDataset(input_list_test, mask_list_test, test_input_path, test_mask_path,
                                   is_train_data=False, resize=resize_data, transform=transform, augmentation=augmentation)

    train_dataloader = DataLoader(training_data, batch_size=3, shuffle=True)
    valid_dataloader = DataLoader(validation_data, batch_size=3, shuffle=False)
    test_dataloader = DataLoader(test_data, batch_size=3, shuffle=False)

- DataLoader(Dataset_type_instance, batch_size, shuffle, ...) 여러 parameters가 존재하고 pytorch의 공식홈페이지에 보시면 더욱더 자세한 설명이 있습니다.

 

- train_dataloader = DataLoader(Dataset_type_instance, batch_size, shuffle, ...) 형식으로 구현하면, train_dataloader가 제네레이터가 됩니다.

 

제가 Semantic Segmentation 구현 모델에서 사용한 main.py, data_loader.py 코드가 저의 Github에 있으니 한번 보시고 참고하시면 좋을 것 같습니다. (My git url: https://github.com/sangheonEN/segmentation_pytorch_fcn_deeplabv3)

 

틀린 설명이나 부적절한 내용이 포함되어 있으면 거침없이 댓글해주시면 수정 보완하겠습니다.

 

감사합니다.

Comments