자기개발/데이터분석

[머신러닝] k-NN(k-Nearest Neighbors) 알고리즘 학습 - 1)손글씨 분류

코대장 2021. 9. 4. 22:57
반응형

[머신러닝] k-NN(k-Nearest Neighbors) 알고리즘 학습 - 1) 손글씨 분류

서론

k-NN 알고리즘(지도학습)은 가장 간단한 분류(Classification) 머신러닝 알고리즘입니다. 새로운 데이터(input) 예측 할때, train 데이터셋에서 가장 가까운 데이터 포인트(최근접 이웃)을 찾는 방식이죠.

본론

k-NN 알고리즘은 간단한 알고리즘이지만 이미지(영상)에서 글자인식, 영화,음악 등 상품추천에 대한 선호도 예측, 유전자 데이터의 패턴 인식등 여러 분야에 활동될 수 있습니다. 이번 글에서는 k-NN 알고리즘으로 손글씨 이미지 분류하는 예제를 다뤄보려고 합니다.

1. 데이터셋 설명

손글씨 이미지 분류를 위해 MNIST 데이셋을 사용하려고 합니다. 데이터 범위는 0 ~ 9까지이며, 총 70,000장 이미지로 구성되어 있습니다. 60,000장은 학습용(train), 10,000장은 검증용(test) 입니다. 각 이미지 크기는 28*28 pixel 사이즈네요.

MNIST 데이터셋

이미지를 print 로 출력해보면 아래와 같이 '0' 값인 걸 알 수 있고요.

0 에 해당하는 이미지 1장 데이터

2. 코드 구현

tensorflow 케라스에서 제공하는 mnist 데이터셋을 활용하자.

from tensorflow.keras.datasets import mnist
# 학습용 60,000개 / 검증용 10,000개
# 손글씨 한장 이미지는 28X28 = 784개 pixel
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape) #(60000, 28, 28)
print(y_train.shape) #(60000,)

# 0번째 이미지 출력
for x in x_train[0]:
    for i in x:
        print('{:3}'.format(i), end='')
    print()

# 28X28 배열을 1차원 784개로 재배열
# 5000개로 테스트 수행
x_train = x_train.reshape(-1, 28*28)[:5000]
y_train = y_train[:5000]
x_test = x_test.reshape(-1, 28*28)[:5000]
y_test = y_test[:5000]
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
# 총 10개 이미지 2X5 그리드 형태로 그려보기
w, h = 2, 5
fig, axes = plt.subplots(w, h)
fig.set_size_inches(12, 6)
for i in range(w*h):
    axes[i//h, i%h].imshow(x_train[i].reshape(-1, 28))
    axes[i//h, i%h].set_title(y_train[i], fontsize=20)
    axes[i//h, i%h].axis('off')
plt.tight_layout()
plt.show()
from sklearn.neighbors import KNeighborsClassifier

# 모델 생성
knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)

# 학습
knn.fit(x_train, y_train)

# 예측
prediction = knn.predict(x_test)
# 모델평가 - 방법1
print((prediction == y_test).mean())
# 모델평가 - 방법2
print(knn.score(x_test, y_test))
# 최적의 이웃수 찾기
for k in range(1, 11):
    knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
    knn.fit(x_train, y_train)
    score = knn.score(x_test, y_test)
    print('k: %d, accuracy: %.2f' % (k, score*100))

결론

가장 간단한 알고리즘이지만 데이터셋 전처리부터 배우고 익혀야 할 부분이 많네요. 예를들어 numpy reshape 그리고 모델을 평가하는데 쓰이는 그리드 서치(위 내용에서는 아직 안담은..) 등등

그냥 데이터셋을 fit(학습) 하고 predict(예측)하는 학습법보다는 어떤 데이터가 주어지고, 어떤 예측을 해야하는지에 따라 여러 알고리즘을 학습해 나가는데 중점을 두려고 해요. 그래야 공부도 재밌고요.

 

참고자료