sim0609 2023. 2. 6. 14:17

KNN(k-최근접 이웃) 알고리즘

KNN 알고리즘은 지도 학습 알고리즘 중 하나로 매우 직관적이고 간단하다. 어떤 데이터가 주어지면 그 주변의 데이터를 살펴본 뒤 더 많은 데이터가 포함되어 있는 범주로 분류하는 방식이며, n_neighbors 값(학습할 때 참고할 데이터 수)에 따라 분류가 달라질 수 있다. 또한, 거리를 측정할 땐 유클리드 거리(Euclidean distance)를 사용한다.

 

 

이러한 KNN 알고리즘은 훈련 데이터셋을 그냥 저장하는 것이 모델을 만드는 과정의 전부이기에 데이터만 있으면 쉽게 사용할 수 있다. 하지만, 데이터의 양이 너무 많을 경우 메모리가 많이 필요하고 직선 거리를 계산하는데도 오래 걸린다.

 

그러면 이제 KNN 알고리즘을 통해 도미와 빙어를 분류해보자

 

데이터 준비하는 과정

import matplotlib.pyplot as plt

# 도미 길이
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]

# 도미 무게
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

# 도미 길이와 무게 산점도로 살펴보기
plt.scatter(bream_length, bream_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

# 빙어 길이
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]

# 빙어 무게
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

# 빙어 길이와 무게 산점도로 살펴보기
plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

# 도미와 빙어의 길이, 무게를 산점도로 살펴보기
plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

머신러닝 프로그램

from sklearn.neighbors import KNeighborsClassifier

# 두 개의 리스트를 하나로 합치기
length = bream_length + smelt_length
weight = bream_weight + smelt_weight

# zip 함수로 각 생선의 length와 weight 짝지어주기
fish_data = [[l, w] for [l, w] in zip(length, weight)]

# fish_data labeling하기(데이터의 정답을 알려주는 부분)
# 도미: 1, 빙어: 0
fish_target = [1] * 35 + [0] * 14
print(fish_target)

# k-최근접 이웃 알고리즘
kn = KNeighborsClassifier()

# 생선 분류를 위해 학습시키는 과정
kn.fit(fish_data, fish_target)

# 모델 성능 평가(= 모델 정확도)
kn.score(fish_data, fish_target)

# 모델 예측
kn.predict([[30, 600]])

# 새로운 데이터가 등장했을 때 가장 가까운 데이터 n_neighbors개를 참고해 학습 가능(하이퍼 파라미터)
kn_ = KNeighborsClassifier(n_neighbors= 5)

kn_.fit(fish_data, fish_target)
kn_.score(fish_data, fish_target)