ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Pytorch Lightning 사용가이드 및 코드 예시(Scene Classification Dataset)
    Data Science/Pytorch 2022. 11. 17. 02:24
    반응형

    1. Pytorch Lightning이란?

    Tensorflow의 Keras와 같이 Pytorch를 위한 라이브러리이다.

    pytorch lightning의 전반적 동작 구조는 아래와 같다.

    2. 라이브러리 준비

    from glob import glob
    
    import torch
    import torch.utils.data as data
    from torchvision import transforms
    import PIL
    
    from torchvision.datasets import ImageFolder
    
    from torchvision.models import resnet18
    from torch.optim import Adam
    from torch.nn import functional, CrossEntropyLoss
    from torch import argmax
    from torchmetrics import Accuracy
    
    import torchinfo
    
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
    
    import pandas as pd

     

    3. 데이터셋 준비

    본 포스팅에서 사용하는 데이터셋은 Kaggle의 Scene Classification에서 가져왔다.

    https://www.kaggle.com/datasets/nitishabharathi/scene-classification

     

     

    Scene Classification 데이터셋에서 test는 validation을 의미하며 pred는 test를 의미한다.

    헷갈리니 주의하도록 하자.

    3.1 이미지 리스트 파싱 및 라벨링하기

    수치형 데이터와는 다르게 이미지 데이터는 계층적 폴더에 저장 후 사용된다.

    glob 라이브러리는 해당 경로 조건에 맞는 모든 파일 경로를 파싱한다.

    ( * 은 모든 문자라는 의미이다)

    # from glob import glob
    
    # 사용할 Train, Valid, Test 이미지 경로 모두 파싱
    train_images = glob("./md_scenes/seg_train/seg_train/**/*g")
    valid_images = glob("./md_scenes/seg_test/seg_test/**/*g")
    test_images = glob("./md_scenes/seg_pred/seg_pred/*g")
    
    print(train_images[0])
    #=> ./md_scenes/seg_train/seg_train/forest/8554.jpg

    각 이미지들은 해당하는 라벨(class)의 이름을 가지는 폴더안에 저장되어 있다.

    따라서 폴더명을 통해서 라벨링을 진행할수 있다.

    Test 데이터셋은 라벨이 제공되지 않는다.

    # Train, Valid의 경우 이미지 경로로부터 라벨 파싱
    def get_label(image_path):
        return image_path.split("/")[-2]
    
    train_labels = [get_label(image_path) for image_path in train_images]
    valid_labels = [get_label(image_path) for image_path in valid_images]

    이미지 라벨의 벡터화를 하기 위한 딕셔너리 선언

    # 이미지의 라벨링을 위한 딕셔너리 선언
    label_to_idx = {"buildings": 0, "forest": 1, "glacier": 2, "mountain": 3, "sea": 4, "street": 5}
    idx_to_label = {0: "buildings", 1: "forest", 2: "glacier", 3: "mountain", 4: "sea", 5: "street"}

    3.2 이미지에 적용할 전처리 선언

    수치형 데이터에서 scaling, transform 등 전처리를 수행하듯이, 이미지 데이터에도 전처리를 수행해주어야한다.

    대표적으로 사이즈 조절, 데이터타입 변환 등이 있다.

    이번 포스팅에서는 img_transform이라는 변수에 전처리 방법들을 넣고 사용한다.

    # 이미지 전처리에 사용될 공통된 기법을 전역변수로 선언
    img_transform = transforms.Compose([
                        transforms.Resize((224, 224)),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ])

    3.3 데이터셋 Class 선언 및 데이터셋 생성

    모델에 이미지 데이터를 넣기 위해서 이미지와 라벨을 불러올 수 있는 데이터 class가 필요하다.

    (Test의 경우는 이미지만 불러온다.)

    Train/Valid를 위한 ImageDatasetWithLabel와 Test를 위한 ImageDatasetWithoutLabel은 입력받는 dataset_str의 구조가 다르니 주의하자.

    # 라벨링이 되어있는 Train/Validation Set을 위한 DataSet Class
    # dataset_str 형태 : [(이미지경로,라벨),(이미지경로,라벨), ... ,(이미지경로,라벨)]
    class ImageDatasetWithLabel(data.Dataset):
        def __init__(self, dataset_str):
            self.dataset_str = dataset_str # class 객체 생성시에 입력되는 값
            self.transform = img_transform # 위 cell에서 선언한 이미지 전처리 방법 사용
    
        def __len__(self):
            return len(self.dataset_str)
    
        def __getitem__(self, index):
            image_path, label = self.dataset_str[index] #특정 인덱스의 이미지주소/라벨 선택
            image = PIL.Image.open(image_path).convert("RGB") # PIL 객체로 변환(RGB형태)
            image = self.transform(image) # 앞서 선언한 전처리 기법들 적용
            return image, label_to_idx[label] # 이미지 객체(PIL:RGB)와 라벨(숫자) 반환
    
    # 라벨링이 없는 Test Set을 위한 DataSet Class
    # dataset_str 형태 : [이미지경로, 이미지경로, ... ,이미지경로]
    class ImageDatasetWithoutLabel(data.Dataset):
        def __init__(self, dataset_str):
            self.dataset_str = dataset_str # class 객체 생성시에 입력되는 값
            self.transform = img_transform # 위 cell에서 선언한 이미지 전처리 방법 사용
    
        def __len__(self):
            return len(self.dataset_str)
    
        def __getitem__(self, index):
            image_path = self.dataset_str[index] #특정 인덱스의 이미지주소 선택(라벨없음!!)
            image = PIL.Image.open(image_path).convert("RGB") # PIL 객체로 변환(RGB형태)
            image = self.transform(image) # 앞서 선언한 전처리 기법들 적용
            return image # 이미지 객체(PIL:RGB) 반환(라벨 없음!!)

    선언한 data class에 train/valid/test 이미지(&라벨)을 넣어 객체를 생성해준다.

    train_dataset = ImageDatasetWithLabel(list(zip(train_images,train_labels)))
    valid_dataset = ImageDatasetWithLabel(list(zip(valid_images,valid_labels)))
    test_dataset = ImageDatasetWithoutLabel(test_images)

    zip의 간단한 활용 예시는 아래 그림과 같다.

    3.3 (2) Imagefolder를 통한 class ImageDatasetWithLabel 대체

    pytorch에서는 라벨이 있는 데이터에 대해서 Imagefolder라는 라이브러리를 제공하고 있다.

    위에서 우리가 구현한 ImageDatasetWithLabel 과 동일한 기능을 하며,

    데이터가 존재하는 root 폴더의 경로만 지정해주면 되서 더욱 편하다.

    #from torchvision.datasets import ImageFolder
    
    train_path = "./md_scenes/seg_train/seg_train"
    valid_path = "./md_scenes/seg_test/seg_test"
    
    # 계층적 폴더 구조를 가지고 있다면 자동적으로 이미지 객체 생성 및 라벨링을 해주는 라이브러리
    train_dataset = ImageFolder(train_path, transform=img_transform)
    valid_dataset = ImageFolder(valid_path, transform=img_transform)

    Imagefolder 라이브러리를 사용할 경우 ImageDatasetWithLabel는 선언/생성하지 않아도 된다.

    3.4 데이터로더(DataLoader) 생성

    데이터 class가 상품을 낱개 포장한거라면,

    데이터로더는 낱개 포장된 상품을 박스(batch)에 모아담는 역할을 한다.

    모든 데이터를 한번에 모델에 넣고 학습할 수 있으면 좋겠지만,

    하드웨어적, 시간적 여건상으로 인해서 batch로 분할하여 학습을 진행한다.

    데이터로더는 모델이 학습(및 평가)할 때 batch 사이즈에 맞게 데이터를 전달해주는 역할을 하며,

    데이터 shuffle이나 사용할 thread 개수 설정등의 편의를 제공해준다.

    # import torch.utils.data as data
    
    train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=16)
    valid_dataloader = data.DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=16)
    test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=16)

     

    4. 모델 생성 및 학습

    4.1 모델 선언

    pytorch lightning은 기존 pytorch에서 일일히 구현해야했던 역전파 등의 구간들을 자동화해주었다.

    사용자들은 이미 정의된 함수들을 Overriding하여 pytorch 구현을 더욱 쉽게 할수 있다.

    여기서 우리가 재정의 할수 있는 함수들은 아래와 같으며, 이번 포스팅에서 사용하지 않는 함수는 회색처리하였다.

    • __init__
    • forward
    • configure_optimizers
    • train_dataloader
    • training_step
    • val_dataloader
    • validation_step
    • test_dataloader
    • test_step
    • predict_step

    __init__은 class 내에서 사용될 필드값을 정의하는 구간이다.

    대표적으로 사용할 model의 Layer들을 정의하며 그 외에도 손실함수(loss func)나 학습률(learning rate) 등을 정의힌다.

    본 포스팅에서는 torchvision.model에서 제공하는 resnet18을 사용하여 Layer를 구성하는 과정을 생략했다.

     

    forward는 앞서 선언된 model이 학습될때 동작하는 방식을 정의해준다.

    이미 레이어 구성이 되어있는 resnet18을 model에 저장하였기에 model(x)만을 반환해준다.

     

    configure_optimizers는 최적화 함수를 선언하는 함수이다.

     

    training_step, validation_step는 각각 train 및 validation의 batch 단위에서 어떠한 task를 수행할지 지정해준다.

    batch 데이터를 모델에 넣고 결과값에 대해서 loss를 구하는 task등이 수행된다.

    train은 역전파를 수행해야하기에 loss를 return하고, validation은 역전파를 수행하지 않기때문에 return이 없다.

     

    self.log는 학습에 따른 로깅 기능을 의미한다.

    on_step은 step마다 log를 찍는것을 의미하며 on_epoch는 epoch마다 log를 찍는 것을 의미한다.

    둘다 True로 설정하면 둘다 찍히게 된다.

    아래 그림은 on_epoch=True, on_step=False 일 때의 예시이다.

    # from torchvision.models import resnet18
    # from torch.optim import Adam
    # from torch.nn import functional, CrossEntropyLoss
    # from torch import argmax
    # from torchmetrics import Accuracy
    
    class ResNetClassifier(pl.LightningModule):
        # class 객체 생성시에 초기화 및 모델선언 파트
        def __init__(self, class_count, lr=0.0001):
            super().__init__()
            self.model = resnet18(num_classes=class_count)
            self.learning_rate = lr
            self.num_classes = class_count
            self.loss_fn = CrossEntropyLoss()
            self.accuracy = Accuracy()
        
        # 모델의 forward(순전파) 정의 (tensorflow에선 call)
        def forward(self, X):
            return self.model(X)
        
        def configure_optimizers(self):
            return Adam(self.parameters(),lr=self.learning_rate)
        
        def training_step(self, batch, batch_idx):
            img, real_label = batch # batch에서 이미지와 라벨 분리
            pred_label = self(img) # 모델에 이미지를 넣어서 예측하기
            # real_label이 숫자이기 때문에 벡터로 변환 0 => [1,0,0,0,0,0]
            real_label = functional.one_hot(real_label, num_classes=self.num_classes).float()
            
            # loss 및 accuracy 측정
            loss = self.loss_fn(pred_label, real_label) # loss 계산
            #acc = self.accuracy(argmax(pred_label), argmax(real_label))
            acc = (argmax(real_label,1) == argmax(pred_label,1)).type(torch.FloatTensor).mean()
            
            # 학습시에 로깅이 되도록 설정
            self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return loss
        
        def validation_step(self, batch, batch_idx): # Train과 동일
            img, real_label = batch
    
            pred_label = self(img)
            real_label = functional.one_hot(real_label, num_classes=self.num_classes).float()
            
            loss = self.loss_fn(pred_label, real_label) # loss 계산
            #acc = self.accuracy(argmax(pred_label), argmax(real_label))
            acc = (argmax(real_label,1) == argmax(pred_label,1)).type(torch.FloatTensor).mean()
            
            self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
            self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    4.2 모델 생성

    모델 생성은 앞서 만들어준 class 객체를 생성하면 된다.

    이때 __init__ 함수에서 라벨의 개수에 대해서 입력하기로 했기때문에 6을 매개변수로 넣어준다.

    model = ResNetClassifier(6) # 6은 라벨의 개수

    모델의 구조를 보고싶다면 아래의 코드로 확인 가능하다.

    # import torchinfo
    
    torchinfo.summary(model)

    4.3 모델 학습

    모델 학습의 경우 pytorch lightning의 Trainer를 통해 쉽게 수행할 수 있다.

    pytorch에서 todevice로 직접 gpu를 지정해주던것을 Trainer에서 한번만 지정해주면 된다.

    또한 early stopping의 경우도 아래와 같이 쉽게 지정 가능하다.

    trainer.fit으로 모델을 학습 및 검증할 수 있다.

    (4.1에서 생략한 train_dataloader 등을 선언한다면 fit에서 데이터로더를 넣어주지 않아도 된다.)

    # import pytorch_lightning as pl
    # from pytorch_lightning.callbacks.early_stopping import EarlyStopping
    
    trainer = pl.Trainer(accelerator='gpu', devices=-1, # gpu 선언
                         max_epochs=3, # 최대 반복 학습 횟수 
                         # early stopping 기법 적용(validation의 loss 기준)
                         callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
    trainer.fit(model,train_dataloader,valid_dataloader)

    max_epoch의 설정에 따라서 3번만에 종료됐음을 알리는 메시지가 떠있다.

    train의 accuracy는 약 84% validation은 약 78%로 확인된다.

    (간단한 포스팅을 위하여 max_epoch을 3으로 설정했다.)

     

    5. 테스트

    pytorch lightning에서 예측의 경우 trainer.predict로 약속되어 있다.

    (위에서 생략한 predict_step 함수를 재정의하여 기능을 변경할 수도 있다.)

    trainer.predict의 매개변수로는 model과 사용할 test셋의 데이터로더가 필요하다.

    pred_lst = trainer.predict(model, test_dataloader)

    return된 pred_lst를 살펴보면 아래와 같다.

    라벨이 6개이기때문에 길이 6의 tensor들이 반환이 되었다.

    pred_lst[:3]

    예측 결과에 사용된 이미지이름과 라벨을 씌워서 사용자가 알기 쉽게 변경해보자

    #import pandas as pd
    
    pred_filenames = [x.split('/')[-1] for x in pred_images]
    preds = [label_num_dic[int(torch.argmax(x))] for x in pred_lst]
    
    pred_df = pd.DataFrame({"image": pred_filenames, "label": preds})
    pred_df.head()

    해당 결과를 csv로 저장한다.

    pred_df.to_csv("submission_pl.csv", index=False)
    반응형

    댓글

Designed by Tistory.