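# 파이썬 딥러닝 ai 스쿨 기초/lecture04
# lecture04 2교시 인공지능 텐서플로우 실습 필기체 인식기3 CNN
# (Python deep-learning AI school basics, lecture 04, period 2:
#  TensorFlow practice - handwriting recognizer 3, CNN)
# Author: junny1997
# Date: 2021. 3. 25. 12:32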
import tensorflow as tf
import time
import os
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the MNIST_data/ directory with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Flattened input images: each row holds 784 (= 28 * 28) pixel values.
X = tf.placeholder(tf.float32, [None, 784], name="X")
X_img = tf.reshape(X, [-1, 28, 28, 1])
# Reshape each 784-vector into a 28x28 image with 1 channel
# ([batch, height, width, channels]); -1 lets TF infer the batch size.
Y = tf.placeholder(tf.float32, [None, 10], name="Y")
# One-hot labels for the 10 classes.
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
# Dropout keep probability; the training loop feeds 0.8 for training
# batches and 1.0 when evaluating on the validation set.
# Convolution + pooling block 1.
# Filter shape is [filter_height, filter_width, in_channels, out_channels]:
# 3x3 filters over 1 input channel, 32 filters, drawn from a normal
# distribution with standard deviation 0.01.
W1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
L1 = tf.nn.conv2d(X_img, W1, strides=[1, 1, 1, 1], padding='SAME')
# Convolve X_img with W1.
# strides = [batch, height, width, channels]; stride 1 on every axis means
# every image, every spatial position and every channel is visited.
# padding='SAME' zero-pads the border so that, with stride 1, the spatial
# size is preserved (padding='VALID' would shrink it).
# Output shape: (?, 28, 28, 32)
L1 = tf.nn.relu(L1)
L1 = tf.nn.max_pool(L1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# Take the max over each 2x2 window (tf.nn.avg_pool would average instead)
# and move the window 2 pixels at a time, halving height and width:
# 28x28 -> 14x14. With padding='SAME' the output size is ceil(input / stride),
# so only stride 1 keeps the size unchanged.
# Output shape: (?, 14, 14, 32)
L1 = tf.nn.dropout(L1, keep_prob=keep_prob)
# Convolution + pooling block 2.
# 3x3 filters over the 32 channels produced by block 1, 64 output filters.
W2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
L2 = tf.nn.conv2d(L1, W2, strides=[1, 1, 1, 1], padding='SAME')
# Output shape: (?, 14, 14, 64)
L2 = tf.nn.relu(L2)
L2 = tf.nn.max_pool(L2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# 2x2 max-pooling with stride 2 halves height and width: 14x14 -> 7x7.
# Output shape: (?, 7, 7, 64)
L2 = tf.nn.dropout(L2, keep_prob=keep_prob)
# Flatten the pooled feature maps and classify them with a small MLP.
L2_flat = tf.reshape(L2, [-1, 7 * 7 * 64])
# 7 * 7 * 64 = 3136 features per image.
# Hidden fully connected layer: 3136 -> 300, Xavier-initialized weights.
W3 = tf.get_variable("W3", shape=[7 * 7 * 64, 300], initializer=tf.contrib.layers.xavier_initializer())
b3 = tf.Variable(tf.random_normal([300]))
L3 = tf.nn.relu(tf.matmul(L2_flat, W3) + b3)
L3 = tf.nn.dropout(L3, keep_prob=keep_prob)
# Output layer: 300 -> 10 logits. The bias gets its own name (b4) rather
# than rebinding b3, so the hidden-layer bias stays reachable from Python.
W4 = tf.get_variable("W4", shape=[300, 10], initializer=tf.contrib.layers.xavier_initializer())
b4 = tf.Variable(tf.random_normal([10]))
# Raw (unnormalized) logits; the softmax is applied inside the loss.
hypothesis = tf.nn.xw_plus_b(L3, W4, b4, name="hypothesis")
# A prediction is correct when the index of the largest logit equals the
# index of the 1 in the one-hot label; accuracy is the mean over the batch.
correct_prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Mean softmax cross-entropy between the logits and the one-hot labels.
# NOTE(review): softmax_cross_entropy_with_logits is deprecated in later
# TF 1.x releases in favor of softmax_cross_entropy_with_logits_v2.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=hypothesis, labels=Y))
# Scalar TensorBoard summary of the loss, tagged "loss".
summary_op = tf.summary.scalar("loss", cost)
learning_rate = 0.001
# Running this op performs one Adam update step that minimizes `cost`.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# Training hyperparameters: number of passes over the training set and
# number of examples per optimizer step.
training_epochs, batch_size = 30, 100
# Open a session and give every global variable its initial value.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)
# TensorBoard summaries and checkpoints are written under
# ./runs/<MM-DD-HH-MM-SS>/, e.g. runs/03-25-12-32-00/checkpoints/model-<n>.
timestamp = time.strftime('%m-%d-%H-%M-%S', time.localtime())
out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
# Separate writers so train and validation curves show up as two runs.
train_summary_dir = os.path.join(out_dir, "summaries", "train")
train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
val_summary_dir = os.path.join(out_dir, "summaries", "valid")
val_summary_writer = tf.summary.FileWriter(val_summary_dir, sess.graph)
# out_dir is already absolute, so no further abspath is needed here.
checkpoint_dir = os.path.join(out_dir, "checkpoints")
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
# exist_ok=True replaces the racy "check, then create" pattern.
os.makedirs(checkpoint_dir, exist_ok=True)
# Save all global variables, keeping only the 3 most recent checkpoints.
saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
# Training loop with periodic validation and best-model checkpointing.
# (No early stopping is performed; all training_epochs are run.)
best_val_accuracy = 0.0  # best validation accuracy so far (was `max`, which shadowed the builtin)
best_epoch = 0           # 1-based epoch in which best_val_accuracy was reached
global_step = 0          # number of optimizer updates performed so far
total_batch = mnist.train.num_examples // batch_size
print(f"total_batch:{total_batch}")
for epoch in range(training_epochs):
    cost_sum = 0.0
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feed_dict = {X: batch_xs, Y: batch_ys, keep_prob: 0.8}
        c, train_summaries, _ = sess.run([cost, summary_op, optimizer], feed_dict=feed_dict)
        global_step += 1
        # Previously the train writer was created but never written to.
        train_summary_writer.add_summary(train_summaries, global_step)
        cost_sum += c
        # Every 20 iterations: report the running mean training cost, evaluate
        # on the full validation set (dropout disabled), and checkpoint when
        # validation accuracy improves.
        if i % 20 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'Iter:', '%04d' % i,
                  'training cost =', '{:.9f}'.format(cost_sum / (i + 1)))
            val_accuracy, summaries = sess.run(
                [accuracy, summary_op],
                feed_dict={X: mnist.validation.images, Y: mnist.validation.labels, keep_prob: 1.0})
            # Use the real step so validation points are spread along the x-axis
            # instead of all being logged at step 0.
            val_summary_writer.add_summary(summaries, global_step)
            print(f'{global_step} step', 'Validation Accuracy:', val_accuracy)
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                best_epoch = epoch + 1
                # Checkpoint files stay numbered by epoch (model-<epoch>).
                saver.save(sess, checkpoint_prefix, global_step=best_epoch)
    avg_cost = cost_sum / total_batch if total_batch else 0.0
    print('Epoch:', '%04d' % (epoch + 1), 'average training cost =', '{:.9f}'.format(avg_cost))
# Make sure buffered summary events reach disk.
train_summary_writer.flush()
val_summary_writer.flush()