BASE MODEL HEAD MODEL
- ํธ๋์คํผ ๋ฌ๋์, ํ์ต์ด ์ ๋ ๋ชจ๋ธ์ ๊ฐ์ ธ์์ ์ฐ๋ฆฌ์ ๋ฌธ์ ์ ๋ง๊ฒ ํ์ฉํ๋ ๊ฒ์ด๋ฏ๋ก
ํ์ต์ด ์ ๋ ๋ชจ๋ธ์ base model๋ง ๊ฐ์ ธ์จ๋ค ์ฆ, head๋ชจ๋ธ์ ๋นผ๊ณ ๊ฐ์ ธ์์ ์ฌ์ฉ์๊ฐ ์ง์ head๋ชจ๋ธ์ ์์ธกํ๋ ๊ฒ์ด๋ค.
ํธ๋์คํผ ๋ฌ๋์ ๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ๋ค
์ค์นํ๊ธฐ
!pip install tensorflow-gpu==2.0.0.alpha0
!pip install tqdm
Dogs vs Cats dataset ๋ค์ด๋ก๋๋ฐ๊ธฐ
!wget --no-check-certificate \
https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
-O ./cats_and_dogs_filtered.zip
์ํฌํธํ๊ธฐ
import os
import zipfile
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# ํ์ด์ฌ์ ์งํ์ํ๋ฅผ ํ์ํด ์ฃผ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ
from tqdm import tqdm_notebook
from tensorflow.keras.preprocessing.image import ImageDataGenerator
%matplotlib inline
์์ถํ์ผ ํ๊ธฐ
zipfile.ZipFile('/content/cats_and_dogs_filtered.zip').extractall()
์์ถํ์ผ์ ํ๋ฉด, ํ์ผ์ด ์ด๋ ๊ฒ ๋ฌ๋ค.
์์ถ์ ํผ ํด๋(๊ฐ์์ง, ๊ณ ์์ด ์ฌ์ง์ด ๋ค์ด์๋ ํด๋)๋ฅผ ๋ณ์๋ก ์ ์ฅํ๊ธฐ
train_dir = '/content/cats_and_dogs_filtered/train'
val_dir = '/content/cats_and_dogs_filtered/validation'
# ํ์ผ๋ก ์กด์ฌํ๋ ๋ฐ์ดํฐ๋ฅผ, ๋ฉ๋ชจ๋ฆฌ์ ์ค๋นํ๋ ๊ฒ์ด๋ค
์ด๋ฏธ ์ ๋ ๋ชจ๋ธ์ ๊ฐ์ ธ์ ํค๋ ๋ผ์ด๋ด๊ธฐ
# ์ฐ๋ฆฌ๊ฐ ๋ง๋ค๋ ค๋ ๋ชจ๋ธ์ ์ธํ ์ด๋ฏธ์ง๋ 128ํ 128์ด ์ปฌ๋ฌ์ด๋ฏธ์ง๋ก ํ๋ค
IMG_SHAPE = (128, 128, 3)
base_model = tf.keras.applications.MobileNetV2( input_shape = IMG_SHAPE
, include_top = False # ํค๋ ๋ผ๊ธฐ
, weights = 'imagenet')
#๋ฒ ์ด์ค๋ชจ๋ธ์ Freezing์ํจ๋ค
base_model.trainable = False
ํค๋ ๋ชจ๋ธ ๋ง๋ค๊ธฐ
from keras.layers import Flatten, Dense
head_model = base_model.output
head_model = Flatten()(head_model)
head_model = Dense(128, 'relu')(head_model)
head_model = Dense(1, 'sigmoid')(head_model)
๋ฒ ์ด์ค๋ชจ๋ธ๊ณผ ํค๋๋ชจ๋ธ ํฉ์น๊ธฐ
from keras.models import Model
model = Model(inputs = base_model.input ,outputs = head_model)
๋ชจ๋ธ ์ปดํ์ผ ํ๊ธฐ
from keras.optimizers import RMSprop
model.compile(RMSprop(0.0001), loss='binary_crossentropy', metrics=['accuracy'] )
๋ฐ์ดํฐ ์ฆ๊ฐ์ํค๊ธฐ(Generators ์ฌ์ฉ)
์ด๋ฏธ์ง๋ฅผ ์ฌ๋ฌ๊ฐ๋ก ๋ณํ์ํค๊ณ ์กฐ์์์ผ์ ์ด๋ฏธ์ง๋ฅผ ๋ ํ๋ถํ๊ฒ ์์ฑํด์ค๋ค (๋ฐ์ดํฐ ์ฆ๊ฐ)
# shrear : ๋ํ๋ค. ์ฌ๋์ด ์์ ์๋ค๊ณ ์๋ ์ฌ์ง์ ๋ํ ์ด๋ฏธ์ง๋ก ๋ณํ์์ผ์ ์ด๋ฏธ์ง ์ฆ๊ฐ ์ํจ๋ค
# zoom : ์ด๋ฏธ์ง๋ฅผ ์ค ์ํจ๋ค
# flip : ์ด๋ฏธ์ง ์ข์ฐ๋ฐ์ ์ํจ๋ค (horizontal_filp)์ ์์๋๋ก ๋ฐ์ ์ํจ๋ค vertical_flip์ ์์์ผ๋ก ๋ฐ์ ์ํจ๋ค
# shift : ์ผ์ชฝ ์ค๋ฅธ์ชฝ ์ ์๋๋ก ์ด๋์์ผ์ ์น์ฐ์ ธ์ง ์ด๋ฏธ์ง๋ฅผ ๋ง๋ ๋ค
# rotation : ์ด๋ฏธ์ง๋ฅผ ํ์ ์ํจ๋ค
# fill_mode : ํฝ์
์ด ๋น์์ ธ์์ผ๋ฉด ์ฑ์ฐ๋ ๊ฒ์ด๋ค
<์ด๋ฏธ์ง ์ฆ๊ฐ์ ์์>
train_datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
# width_shift_range๋ ์ด๋ฏธ์ง๋ฅผ ์ข์ฐ๋ก , height_shift_range๋ ์ด๋ฏธ์ง๋ฅผ ์ํ๋ก ์์ง์ด๋ ๊ฒ์ ์๋ฏธํ๋ค
train_datagen = ImageDataGenerator(rescale=1/255,
zoom_range=0.2,
width_shift_range=0.2,
height_shift_range=0.2)
# rescale ์ฝ๋๋ ํ์ผ์ ๋ํ์ด๋ก ๋ฐ๊ฟ์ ๋ฉ๋ชจ๋ฆฌ๋ก ๋ง๋๋ ์ฝ๋์ด๋ค. ์ด๋ฏธ์ง์ ํฝ์ ๊ฐ ๋ฒ์๋ 0์์ 255์ด๋ค. ๋ฐ๋ผ์ 255๋ก ๋๋์ด์ ํผ์ฒ์ค์ผ์ผ๋งํ๋ฉด ๋๋ค. ex ) ImageDataGenerator(rescale= 1/255)
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=(128,128),
class_mode='binary')
# target size๋ ์์์ ๋ง๋ ๋ชจ๋ธ์ input_shape์ ์ผ์นํด์ผ ํ๋ค.
# ๋ถ๋ฅ๋ฌธ์ ์์ ํด๋์ค๊ฐ ๋๊ฐ๋ฉด binary์ด๋ค.
# train_generator์ X_train, y_train๋ฅผ ๋ ๋ค ๊ฐ์ง๊ณ ์๋ค.(train_dir์ด๋ฏ๋ก)
val(๋ชจ์๊ณ ์ฌ๊ฐ์ ๊ฐ๋ ) ๋ํ ํผ์ณ์ค์ผ์ผ๋ง ํด์ค๋ค
val_datagen = ImageDataGenerator(rescale=1/255)
val_generator = val_datagen.flow_from_directory(val_dir,
target_size=(128,128),
class_mode='binary')
๋ชจ๋ธ ํ์ต์ํค๊ธฐ
epoch_history = model.fit(train_generator,
epochs=5,
validation_data=val_generator)
import numpy as np
from google.colab import files
from keras.preprocessing import image
uploaded = files.upload()
for fn in uploaded.keys():
# predicting images
path = '/content/' + fn
img = image.load_img(path, target_size=(128,128))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
# print(x)
images = np.vstack([x])
images = images / 255
# print(images)
classes = model.predict(images, batch_size=10)
print(classes[0])
if classes[0]>0.5:
print(fn + " is a dog")
else:
print(fn + " is a cat")
์ฌ์งํ์ผ์ ์ ํํ์ฌ ๋ฃ์ผ๋ฉด ๊ฐ์์ง์ธ์ง ๊ณ ์์ด์ธ์ง ์๋ณํ๋ ๋ชจ๋ธ์ด ์์ฑ๋๋ค.
ps.
Google์
Teachable Machine์ ์ฐ๋ฉด ๊ฐํธํ๋ค
https://teachablemachine.withgoogle.com/