Link

미리 학습된 ConvNet으로 부터 전이 학습

이 튜토리얼에서, 너는 미리 학습된 네트워크로부터 전이학습을 통하여 강아지와 고양이를 분류하는 방법을 배우게 될 것이다.

미리 학습된 모델은 매우 큰 데이터 셋으로 부터 미리 학습되어 저장되어 있다. 사전 훈련 된 모델을 그대로 사용하거나 전이학습을 통하여 모델을 우리가 원하는 방법으로 커스터마이징 할 수 있다.

이미지 classification을 위한 전이학습은 직관적으로 다음과 같이 표현할 수 있다.

만약 모델이 매우 크고 일반적인 데이터셋을 통하여 훈련되었다면, 이 모델은 실제 visual 세상의 일반적인 모델로서 효과적으로 작동한다.

너는 큰 데이터셋에서 큰 모델을 학습하여 처음부터 시작할 필요 없이, 이러한 학습 된 모델을 사용할 수 있다.

이 노트북에서 너는 사전 훈련된 모델을 커스터마이징 하는 두 가지 방법을 시도할 것이다.

  1. 특징 추출: 미리 학습된 네트워크에서 배운 표현을 사용하여 새로운 학습데이터에서 의미있는 특징을 추출할 것이다. 미리 학습된 모델에다가 새로운 classifier를 추가하기만 하면 미리 학습된 모델의 특징맵을 재사용 할 수 이싿.

  2. 미세 조정(fine-tuning): 새로 추가 할 classifier 레이어와 기존 모델의 마지막 레이어를 함께 훈련시킨다. 이를 통해 기존의 모델에서 고차원 특징 표현을 미세 조정하여 우리가 원하는 작업에 보다 적합하게 만들 수 있다.

너는 일반적 머신 러닝 작업플로우를 따를 것이다.

  1. 데이터 검사 및 이해
  2. 케라스 ImageDataGenerator를 통하여 input pipeline 빌드
  3. 모델 구성
    • 미리 학습된 모델 Load(and 미리 학습된 가중치)
    • classfication layer 쌓기
  4. 모델 학습
  5. 모델 검증
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    

Data preprocessing

Data download

강아지와 고양이 데이터셋을 로드하기 위해 TensorFlow Datasets를 사용한다

tfds 패키지는 사전 정의 된 데이터를 로드하는 가장 쉬운 방법이다.

만약 너의 고유한 데이터셋을 가지고 있고 사용하고 싶다면 loading image data 를 찹조해라

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

tfds.load 메서드는 데이터를 다운로드 및 캐슁하고, tf.data.Dataset 객체를 리턴한다.

이 객체는 데이터를 조작하고 너의 모델로 공급하는 강력하고 효율적인 방법을 제공한다

“cats_vs_dogs”는 표준 분할을 정희하지 않았기 때문에 80%(학습), 10%(검증), 10%(테스트) 비율로 데이터를 나눈다.

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

tf.data.Dataset 객체의 결과는 (image, label) 쌍이고, 이미지는 3개 채널의 변수 shape이고 label은 스칼라 형태이다.

print(raw_train)
print(raw_validation)
print(raw_test)
> <DatasetV1Adapter shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
> <DatasetV1Adapter shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
> <DatasetV1Adapter shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>

훈련 세트에서 처음 두 개의 이미지와 레이블을 보여줍니다.

get_label_name = metadata.features['label'].int2str

for image, label in raw_train.take(2):
  plt.figure()
  plt.imshow(image)
  plt.title(get_label_name(label))



데이터 구성 방식

데이터를 구성하기 위해 tf.image 모듈을 사용한다. 이미지를 고정된 인풋 사이즈로 resize하고, [-1, 1] 범위의 값으로 rescale한다

IMG_SIZE = 160 # All images will be resized to 160x160

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label

map 함수를 이용하여 각각의 데이터 셋에 적용한다

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

셔플과 배치 값을 정한다

BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

데이터의 배치를 검사한다

for image_batch, label_batch in train_batches.take(1):
   pass

image_batch.shape
> TensorShape([32, 160, 160, 3])

미리 학습된 convnets로 부터 base model 생성하기

너는 구글에서 개발 된 MobileNet V2 모델로부터 base model을 만들 수 있을 것이다. 140만개의 이미지와 1000개의 클래스로 이루어진 이미지넷 데이터 셋으로 부터 미리 훈련된 모델이다. ImageNet은 jackfruit와 주사기 같은 넓은 범위의 카테고리를 가진 연구용 학습 데이터 셋이다. 이 지식의 기반은 우리의 데이터셋으로 부터 고양이와 강아지를 구분하는데 도움을 줄 것이다.

먼저, 특징 추출을 위해 사용할 MobileNet V2의 레이어를 선택해야 한다. 대부분의 머신러닝 모델 다이어그램은 bottom에서 top으로 가기에 마지막 classification 레이어(top)은 유용하지 않다. 대신 너는 일반적인 관례에 따라 평탄화 작업(flatten operation) 이전의 마지막 레이어를 사용할 수 있다. 이 레이어는 “bottleneck layer”라고 불리운다. bottleneck 레이어는 final/top 레이어에 비해서 더 일반적인 특징들을 가지고 있고 분류할 수 있다.

일반적인 Convolution Network 구성도


MobileNet V2 구성도


먼저, ImageNet 데이터 셋으로 부터 학습 된 가중치를 가진 MobileNet V2 모델을 인스턴스화 한다. include_top=False 옵션값을 줌으로써, 너는 top에 classification 레이어를 포함하지 않는 특징 추출에 가장 최적화 된 네트워크를 불러올 것이다.

IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

이 특징 추출기는 1601603 이미지를 551280의 특징들의 블록으로 변환시킨다. 이미지들의 배치가 어떻게 되는지 밑의 예를 통해 확인해봐라

feature_batch = base_model(image_batch)
print(feature_batch.shape)
> (32, 5, 5, 1280)

특징 추출

이 스텝에서, 너는 이전 스텝으로부터 만들어진 합성곱(convolutional) base를 고정할 것이다. 추가적으로, 너는 가장 상단에 classifier(분류기)를 추가할 것이고, 높은 수준의 classifier를 학습할 것이다.

컨볼루션 베이스 고정

모델을 학습하고 컴파일 하기 전에 컨볼루션 베이스를 고정하는 것은 중요하다. layer.trainable = False 옵션을 통하여 트레이닝 도중에 이 레이어의 가중치가 업데이트 되는 것을 방지한다. MobileNet V2는 많은 레이어를 가지고 있으므로 모든 모델의 학습가능 플레그를 False로 세팅하여 모든 레이어를 고정할 수 있다.

base_model.trainable = False
# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 161, 161, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984

분류 헤드를 추가하기(Add a classification head)

특징 블록으로 부터 예측기를 생성하기 위해서 5*5 공간 위치를 평균화 한다. 평균화를 위하여 tf.keras.layers.GlobalAveragePooling2D를 사용하고 이 레이어는 feature들을 이미지 한장당 1280개의 요소를 가지는 벡터로 변환한다.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
> (32, 1280)

이 feature들을 이미지 하나당 하나의 예측으로 변환하기 위하여 tf.keras.layers.Dense 를 적용한다. 여기서 활성화 함수를 사용할 필요는 없다 왜냐하면 이 예측은 logit 또는 원시 예측 값을 가지고 있을 것이기 때문이다. 양수는 class 1을 예측하고 음수는 class 0을 예측한다

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
> (32, 1)

이제 특징 추출기와 앞의 2 layer를 tf.keras.Sequential 모델을 사용하여 쌓는다.

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

모델 컴파일하기

너는 트레이닝 작업 전에 반드시 컴파일을 해야한다. 이 모델은 linear output(선형 출력)을 제공하기 때문에 from_logits=True 옵션으로 이진 cross-entropy 손실함수를 사용해라.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
________________________________________________________________

MobileNet의 250만개의 파라메터는 고정되었지만, 1200개의 학습가능한 파라메터가 Dense 레이어에 남아있다. 또한 이것은 가중치와 편향으로 두개의 tf.Variable 객체로 나눌 수 있다.

len(model.trainable_variables)
> 2

모델 학습하기

10 에폭을 학습하면 너는 96% 이하의 정확도를 얻을 수 있다.

initial_epochs = 10
validation_steps=20

loss0,accuracy0 = model.evaluate(validation_batches, steps = validation_steps)
20/20 [==============================] - 1s 63ms/step - loss: 0.6461 - accuracy: 0.5469
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.65
initial accuracy: 0.55
history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches)
Epoch 1/10
582/582 [==============================] - 16s 27ms/step - loss: 0.3501 - accuracy: 0.8310 - val_loss: 0.1772 - val_accuracy: 0.8964
Epoch 2/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1947 - accuracy: 0.9200 - val_loss: 0.1358 - val_accuracy: 0.9243
Epoch 3/10
582/582 [==============================] - 13s 23ms/step - loss: 0.1623 - accuracy: 0.9320 - val_loss: 0.1295 - val_accuracy: 0.9304
Epoch 4/10
582/582 [==============================] - 13s 22ms/step - loss: 0.1457 - accuracy: 0.9393 - val_loss: 0.1162 - val_accuracy: 0.9368
Epoch 5/10
582/582 [==============================] - 13s 22ms/step - loss: 0.1366 - accuracy: 0.9428 - val_loss: 0.1105 - val_accuracy: 0.9424
Epoch 6/10
582/582 [==============================] - 13s 23ms/step - loss: 0.1343 - accuracy: 0.9429 - val_loss: 0.1065 - val_accuracy: 0.9445
Epoch 7/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1256 - accuracy: 0.9471 - val_loss: 0.1034 - val_accuracy: 0.9475
Epoch 8/10
582/582 [==============================] - 13s 23ms/step - loss: 0.1236 - accuracy: 0.9506 - val_loss: 0.1101 - val_accuracy: 0.9458
Epoch 9/10
582/582 [==============================] - 13s 22ms/step - loss: 0.1227 - accuracy: 0.9474 - val_loss: 0.0998 - val_accuracy: 0.9523
Epoch 10/10
582/582 [==============================] - 13s 23ms/step - loss: 0.1211 - accuracy: 0.9483 - val_loss: 0.1031 - val_accuracy: 0.9514

미세 조정 ( Fine tuning )

특징 추출 실험에서 너는 MobileNet V2 base model의 상단 레이어 몇개만 학습시켰다. 사전 훈련된 네트워크의 가중치는 학습되지 않았다.

성능을 더욱 향상시키는 한 가지 방법은 우리가 추가 한 분류기의 훈련과 함께 사전 훈련 된 모델의 최상위 레이어 가중치를 훈련 (또는 “미세 조정”)하는 것입니다. 훈련 과정을 통해 가중치는 일반적인(generic) 특징을 포함한 맵에서 훈련하려는 데이터 집합(cats and dogs)과 관련된 기능으로 강제 조정됩니다.

  • Note 사전 훈련 된 모델을 훈련 불가능으로 설정하여 최상위 분류기(classifier)를 훈련 한 뒤에만 ​시도해야합니다. 사전 훈련 된 모델 위에 무작위로 초기화 된 분류기를 추가하고 모든 레이어를 공동으로 훈련하려고하면 경사하강법 업데이트의 크기가 너무 클 수 있으며 (분류기-classifier 의 임의 가중치로 인해) 사전 훈련 된 모델은 배운 것을 잊어 버릴 수 있습니다.

또한 전체 MobileNet 모델이 아닌 소수의 최상위 계층을 미세 조정해야합니다. 대부분의 convolution 네트워크에서 계층이 높을수록 계층이 더 전문화됩니다. 처음 몇 층은 거의 모든 유형의 이미지로 일반화되는 매우 간단하고 일반적인 기능을 학습합니다. 더 높은 수준으로 올라가면 기능이 모델이 훈련 된 데이터 세트에 점점 더 구체적이됩니다. 미세 조정(fine tuning)의 목표는 이러한 특수 기능을 일반 학습을 덮어 쓰지 않고 새 데이터 세트와 함께 사용할 수 있도록 조정하는 것입니다.

모델의 상단 레이어 고정 해제

base_model 고정을 해제하고 bottom 레이어를 훈련 할 수 없도록 설정하기 만하면됩니다. 그런 다음 모델을 다시 컴파일하고 훈련을 다시 시작해야합니다.

base_model.trainable = True
# base model에 얼마나 많은 layer가 존재하는지 확인한다
print("Number of layers in the base model: ", len(base_model.layers))

# 미세조정을 시작할 레이어 위치를 정한다
fine_tune_at = 100

# 미세조정을 시작하기 전 레이어들은 다 고정한다
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
> Number of layers in the base model:  155

모델 컴파일하기

더 낮은 학습률(learning rate)를 적용하여 모델을 컴파일한다

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,863,873
Non-trainable params: 395,392
_________________________________________________________________
len(model.trainable_variables)
58

계속해서 모델 학습시키기

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_batches,
                         epochs=total_epochs,
                         initial_epoch =  history.epoch[-1],
                         validation_data=validation_batches)

요약

  • 특징 추출을 위해 사전 훈련 된 모델 사용 : 작은 데이터 세트로 작업 할 때 동일한 도메인에서 더 큰 데이터 세트에 대해 훈련 된 모델에서 학습 한 기능을 활용하는 것이 일반적입니다. 사전 훈련 된 모델을 인스턴스화하고 완전히 연결된 분류기를 맨 위에 추가하면됩니다. 사전 훈련 된 모델은 “동결”되고 분류기의 가중치 만 훈련 중에 업데이트됩니다. 이 경우 컨벌루션베이스는 각 이미지와 관련된 모든 기능을 추출했으며 추출 된 기능 세트가 제공된 이미지 클래스를 결정하는 분류기를 훈련했습니다.

  • 사전 훈련 된 모델 미세 조정 : 성능을 더욱 향상시키기 위해 사전 훈련 된 모델의 최상위 계층을 미세 조정을 통해 새로운 데이터 세트로 재사용 할 수 있습니다. 이 경우 모델이 데이터 세트와 관련된 고급 기능을 학습 할 수 있도록 가중치를 조정했습니다. 이 기술은 일반적으로 훈련 데이터 세트가 크고 사전 훈련 된 모델이 훈련 된 원래 데이터 세트와 매우 유사한 경우에 권장됩니다.