ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Tensorflow Generative Adversarial Network 예시 코드(GAN, MNIST Data Set)
    Data Science/Tensorflow 2022. 11. 1. 01:19
    반응형

    1. 기본 세팅

    1.1 라이브러리

    import os
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    import tensorflow_datasets as tfds
    from tensorflow.keras import Model, layers

    1.2 GPU 세팅

    os.environ["CUDA_VISIBLE_DEVICES"]="0"
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            tf.config.experimental.set_memory_growth(gpus[0], True)
        except RuntimeError as e:
            print(e)
            
    # GPU 할당 상태 확인
    tf.test.is_gpu_available()
    #=> True

    2. Data Load

    본 포스팅에서 사용되는 MNIST 데이터셋은 손으로 쓰여진 0~9 숫자들의 이미지를 가지는 데이터셋이다.

    출처 위키백과 MNIST 데이터베이스

    tensorflow_dataset.load()를 통해서 MNIST 데이터셋을 받아올 수 있다.

    train set 60000개, test set 10000개의 이미지로 구성되어 있다.

    각 이미지는 28*28 픽셀로 이루어져있으며, 각 픽셀에는 해당하는 색정보가 숫자로 들어있다.

    (색정보는 0~255의 숫자로 표현된다.)

    dataset = tfds.load('mnist', split='train')

    3. Data Preprocessing

    batch_size = 1024
    train_data = dataset.map(lambda data: tf.cast(data['image'], tf.float32) / 255.).batch(batch_size)

    4. GAN 모델 생성 및 학습

    4.1 Generator(생성자)

    # GAN의 생성자 (Generator)
    class Generator(Model):
        def __init__(self, latent_dim):
            super().__init__()
            
            self.latent_dim = latent_dim
            self.generator = tf.keras.Sequential([
                layers.Dense(7 * 7 * 32, activation='relu'),
                layers.Reshape((7, 7, 32)), # (batch, 7, 7, 32)
                layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu'),
                layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
                layers.Conv2DTranspose(1, 3, strides=1, padding='same', activation = 'sigmoid')
            ])
        
        def call(self, z):
            return self.generator(z)

    4.2 Discriminator(판별자)

    # GAN의 판별자 (Discriminator)
    class Discriminator(Model):
        def __init__(self, latent_dim):
            super().__init__()
            
            self.latent_dim = latent_dim
            self.discriminator = tf.keras.Sequential([
                layers.Conv2D(32, 3, strides=2, activation='relu', padding='same', input_shape=(28, 28, 1)), # (batch, 14, 14, 32)
                layers.Conv2D(64, 3, strides=2, activation='relu', padding='same'),
                layers.Flatten(),
                layers.Dense(1)
                ## TODO: 2개의 layer 추가
                ## 1: flatten layer
                ## 2: dense layer (dim: 1)
                ## TODO ##
            ])
            
        def call(self, x):
            return self.discriminator(x)

    4.3 GAN 구축

    n_epochs = 200
    latent_dim = 10
    log_interval = 20 # 학습성능 출력 간격
    
    # TODO: 생성자와 판별자 선언
    discriminator = Discriminator(latent_dim)
    generator = Generator(latent_dim)
    
    # TODO: Loss는 Binary Cross Entropy (BCE) 로 정의
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    # TODO: Optimizer 정의 (adam with learning rate 1e-3)
    optimizer_g = tf.keras.optimizers.Adam(1e-3)
    optimizer_d = tf.keras.optimizers.Adam(1e-3)
    def train_step(inputs):
        # 한 batch에서 사용할 난수생성 shape : (batch_size, latent_dim)
        # 즉, fake data를 만들기위한 random한 z 생성
        random_z = tf.random.normal([batch_size, latent_dim]) 
    
        with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
            # random_z로부터 생성자가 fake data 생성 => g_z
            g_z = generator(random_z)  
    
            # 판별자가 data의 진위여부 판별 (1이면 real이다!, 0이면 fake이다! 라고 판단한 것)
            real_output = discriminator(inputs) # real data를 입력받아 판별
            fake_output = discriminator(g_z) # fake data를 입력받아 판별
    
            # tf.ones_like()는 주어진 tensor와 동일한 차원정보를 가지고 1로 채워진 tensor 생성
            # 즉, 모든 data가 real이라고 판단된 경우를 나타내기 위한 코드이다.
            # tf.zeros_like()는 주어진 tensor와 동일한 차원정보를 가지고 0로 채워진 tensor 생성
            # 즉, 모든 data가 fake라고 판단된 경우를 나타내기 위한 코드이다.
            
            # 생성자 loss 계산 : fake_output은 fake data에 대한 판별자의 판별 결과
            # 생성된 fake data가 진짜같이 생겼을 경우 fake_output에 1(real로 판단)이 많이 나와야함
            # 즉, fake_output에 1이 많을수록 loss는 낮아짐
            loss_g = loss(tf.ones_like(fake_output), fake_output)
            
            # 판별자 loss 계산 : real_output은 real data에 대한 판별자의 판별 결과
            # real data는 1로 판단하고 fake data는 0으로 판단해야함
            # real data에 대해서는 1로, fake data에 대해서는 0으로 판단할수록 loss는 낮아짐
            loss_d = loss(tf.ones_like(real_output), real_output) + loss(tf.zeros_like(fake_output), fake_output)
    
            # 생성자,판별자에 대해서 역전파 수행
            grads_g = tape_g.gradient(loss_g, generator.trainable_variables) 
            grads_d = tape_d.gradient(loss_d, discriminator.trainable_variables) 
            optimizer_g.apply_gradients(zip(grads_g, generator.trainable_variables))
            optimizer_d.apply_gradients(zip(grads_d, discriminator.trainable_variables))
        
        return loss_g, loss_d

    4.4 GAN 학습

    # Training
    for epoch in range(1, n_epochs + 1):    
        total_loss_g, total_loss_d = 0, 0
        
        # 직접 배치별로 trin_step을 수행해서 loss를 받아옴
        for x in train_data:
            loss_g, loss_d = train_step(x)
            total_loss_g += loss_g
            total_loss_d += loss_d
        
        # 학습성능 출력간격별로 출력
        if epoch % log_interval == 0:
            print(f'Epoch {epoch:3d}: Generator loss {total_loss_g:.2f}, \
                  Discriminator loss {total_loss_d:.2f}')

    5. 데이터 생성(Generation) 및 확인

    # 랜덤한 한개의 random_z를 생성하여 이미지가 어떻게 생성되는지 확인
    random_z = tf.random.normal([1, latent_dim])
    
    # 생성자에 random_z입력하여서 fake data생성
    g_z = generator(random_z)
    
    # 생성된 fake data 시각화
    plt.imshow(g_z[0, :, :, 0], cmap='gray')
    plt.show()

    # Generator로부터 Image Sampling
    def plot_latent_images(n, digit_size=28):
        noises = tf.random.normal([n, n, latent_dim])
        image_width = digit_size * n
        image_height = image_width
        image = np.zeros((image_height, image_width))
    
        for i in range(n):
            for j in range(n):
                z = tf.random.normal([1, latent_dim])
                g_z = generator(z)
                image[i * digit_size : (i+1)*digit_size, j * digit_size : (j+1) * digit_size] = g_z[0, :, :, 0]
        
        plt.figure(figsize=(10,10))
        plt.imshow(image, cmap='Greys_r')
        plt.axis('Off')
        plt.show()

    반응형

    댓글

Designed by Tistory.