Data Science/Tensorflow

Tensorflow Generative Adversarial Network 예시 코드(GAN, MNIST Data Set)

상어군 2022. 11. 1. 01:19
반응형

1. 기본 세팅

1.1 라이브러리

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras import Model, layers

1.2 GPU 세팅

os.environ["CUDA_VISIBLE_DEVICES"]="0"
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)
        
# GPU 할당 상태 확인
tf.test.is_gpu_available()
#=> True

2. Data Load

본 포스팅에서 사용되는 MNIST 데이터셋은 손으로 쓰여진 0~9 숫자들의 이미지를 가지는 데이터셋이다.

출처 위키백과 MNIST 데이터베이스

tensorflow_dataset.load()를 통해서 MNIST 데이터셋을 받아올 수 있다.

train set 60000개, test set 10000개의 이미지로 구성되어 있다.

각 이미지는 28*28 픽셀로 이루어져있으며, 각 픽셀에는 해당하는 색정보가 숫자로 들어있다.

(색정보는 0~255의 숫자로 표현된다.)

dataset = tfds.load('mnist', split='train')

3. Data Preprocessing

batch_size = 1024
train_data = dataset.map(lambda data: tf.cast(data['image'], tf.float32) / 255.).batch(batch_size)

4. GAN 모델 생성 및 학습

4.1 Generator(생성자)

# GAN의 생성자 (Generator)
class Generator(Model):
    def __init__(self, latent_dim):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.generator = tf.keras.Sequential([
            layers.Dense(7 * 7 * 32, activation='relu'),
            layers.Reshape((7, 7, 32)), # (batch, 7, 7, 32)
            layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(1, 3, strides=1, padding='same', activation = 'sigmoid')
        ])
    
    def call(self, z):
        return self.generator(z)

4.2 Discriminator(판별자)

# GAN의 판별자 (Discriminator)
class Discriminator(Model):
    def __init__(self, latent_dim):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.discriminator = tf.keras.Sequential([
            layers.Conv2D(32, 3, strides=2, activation='relu', padding='same', input_shape=(28, 28, 1)), # (batch, 14, 14, 32)
            layers.Conv2D(64, 3, strides=2, activation='relu', padding='same'),
            layers.Flatten(),
            layers.Dense(1)
            ## TODO: 2개의 layer 추가
            ## 1: flatten layer
            ## 2: dense layer (dim: 1)
            ## TODO ##
        ])
        
    def call(self, x):
        return self.discriminator(x)

4.3 GAN 구축

n_epochs = 200
latent_dim = 10
log_interval = 20 # 학습성능 출력 간격

# TODO: 생성자와 판별자 선언
discriminator = Discriminator(latent_dim)
generator = Generator(latent_dim)

# TODO: Loss는 Binary Cross Entropy (BCE) 로 정의
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# TODO: Optimizer 정의 (adam with learning rate 1e-3)
optimizer_g = tf.keras.optimizers.Adam(1e-3)
optimizer_d = tf.keras.optimizers.Adam(1e-3)
def train_step(inputs):
    # 한 batch에서 사용할 난수생성 shape : (batch_size, latent_dim)
    # 즉, fake data를 만들기위한 random한 z 생성
    random_z = tf.random.normal([batch_size, latent_dim]) 

    with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
        # random_z로부터 생성자가 fake data 생성 => g_z
        g_z = generator(random_z)  

        # 판별자가 data의 진위여부 판별 (1이면 real이다!, 0이면 fake이다! 라고 판단한 것)
        real_output = discriminator(inputs) # real data를 입력받아 판별
        fake_output = discriminator(g_z) # fake data를 입력받아 판별

        # tf.ones_like()는 주어진 tensor와 동일한 차원정보를 가지고 1로 채워진 tensor 생성
        # 즉, 모든 data가 real이라고 판단된 경우를 나타내기 위한 코드이다.
        # tf.zeros_like()는 주어진 tensor와 동일한 차원정보를 가지고 0로 채워진 tensor 생성
        # 즉, 모든 data가 fake라고 판단된 경우를 나타내기 위한 코드이다.
        
        # 생성자 loss 계산 : fake_output은 fake data에 대한 판별자의 판별 결과
        # 생성된 fake data가 진짜같이 생겼을 경우 fake_output에 1(real로 판단)이 많이 나와야함
        # 즉, fake_output에 1이 많을수록 loss는 낮아짐
        loss_g = loss(tf.ones_like(fake_output), fake_output)
        
        # 판별자 loss 계산 : real_output은 real data에 대한 판별자의 판별 결과
        # real data는 1로 판단하고 fake data는 0으로 판단해야함
        # real data에 대해서는 1로, fake data에 대해서는 0으로 판단할수록 loss는 낮아짐
        loss_d = loss(tf.ones_like(real_output), real_output) + loss(tf.zeros_like(fake_output), fake_output)

        # 생성자,판별자에 대해서 역전파 수행
        grads_g = tape_g.gradient(loss_g, generator.trainable_variables) 
        grads_d = tape_d.gradient(loss_d, discriminator.trainable_variables) 
        optimizer_g.apply_gradients(zip(grads_g, generator.trainable_variables))
        optimizer_d.apply_gradients(zip(grads_d, discriminator.trainable_variables))
    
    return loss_g, loss_d

4.4 GAN 학습

# Training
for epoch in range(1, n_epochs + 1):    
    total_loss_g, total_loss_d = 0, 0
    
    # 직접 배치별로 trin_step을 수행해서 loss를 받아옴
    for x in train_data:
        loss_g, loss_d = train_step(x)
        total_loss_g += loss_g
        total_loss_d += loss_d
    
    # 학습성능 출력간격별로 출력
    if epoch % log_interval == 0:
        print(f'Epoch {epoch:3d}: Generator loss {total_loss_g:.2f}, \
              Discriminator loss {total_loss_d:.2f}')

5. 데이터 생성(Generation) 및 확인

# 랜덤한 한개의 random_z를 생성하여 이미지가 어떻게 생성되는지 확인
random_z = tf.random.normal([1, latent_dim])

# 생성자에 random_z입력하여서 fake data생성
g_z = generator(random_z)

# 생성된 fake data 시각화
plt.imshow(g_z[0, :, :, 0], cmap='gray')
plt.show()

# Generator로부터 Image Sampling
def plot_latent_images(n, digit_size=28):
    noises = tf.random.normal([n, n, latent_dim])
    image_width = digit_size * n
    image_height = image_width
    image = np.zeros((image_height, image_width))

    for i in range(n):
        for j in range(n):
            z = tf.random.normal([1, latent_dim])
            g_z = generator(z)
            image[i * digit_size : (i+1)*digit_size, j * digit_size : (j+1) * digit_size] = g_z[0, :, :, 0]
    
    plt.figure(figsize=(10,10))
    plt.imshow(image, cmap='Greys_r')
    plt.axis('Off')
    plt.show()

반응형