KNN의 개념과 원리

개념

  • 지도 학습 중 하나로서 데이터 분류를 위한 알고리즘

  • 사례기반 학습법

  • 거리 측정 시 유클리디안 거리 계산법을 사용하나, 자료특성에 따라 다른 응용된 거리 계산법을 사용하기도 함

유클리디안 거리 (Euclidean distance)

두 점의 X와 Y의 값을 차를 제곱한 것의 합에 루트를 씌움

2차원 평면에 서로 다른 두 점 A(x1, y1)와 B(x2, y2)가 있을 때 이 둘의 거리 d는 유클리드 거리 계산법에 의해 다음과 같이 나온다.

knn1

최적의 k수

  • k수에 따라 예측도가 달라짐. 최적의 k를 찾는 것이 과제임

  • k수가 작을수록 훈련데이터의 정확도는 높아지나 validation/test data에는 부적합할 수 있음

  • k수가 클수록 일반화는 가능하나 정확도는 떨어질 수 있음

  • 결론적으로 최적의 k수는 훈련데이터와 테스트데이터의 예측오류차이가 없어야한다. 그리고 훈련데이터와 테스트데이터의 정확도가 모두 높아야한다.

파이썬에서 최적의 k수 찾는법

1
2
3
4
5
6
7
8
# 최적의 n_neighbors 찾기 예제 소스

for k in range(1, 11):
    knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
    knn.fit(x_train, y_train)
    score = knn.score(x_valid, y_valid)
    print('k: %d, accuracy: %.2f' % (k, score*100))

k: 1, accuracy: 95.60

k: 2, accuracy: 95.00

k: 3, accuracy: 94.60

k: 4, accuracy: 93.80

k: 5, accuracy: 93.80

k: 6, accuracy: 93.40

k: 7, accuracy: 93.40

k: 8, accuracy: 93.60

k: 9, accuracy: 92.80

k: 10, accuracy: 92.80

실습과 시각화

분석사례

  • 유권자의 인구학적 변인 + 정치관련 태도가 대선 투표참여 여부 및 지지정당에 미치는 분류의 문제임
컬럼 내용
gender 성별
region 출생지역
edu 학력
income 소득
age 연령
score_gov 정부지지도
score_progress 진보성향
score_intention 정치관심도

종속변수 컬럼

컬럼 내용
vote 선거투표여부
parties 지지정당
1
2
data<-read.csv("data/vote.csv", header=T)

1
2
data<-data[ , 1:9]

1
head(data)
genderregioneduincomeagescore_govscore_progressscore_intentionvote
1 4 3 3 3 2 2 4.01
1 5 2 3 3 2 4 3.00
1 3 1 2 4 1 3 2.81
2 1 2 1 3 5 4 2.61
1 1 1 2 4 4 3 2.41
1 1 1 2 4 1 4 3.81
1
str(data)
1
2
3
4
5
6
7
8
9
10
'data.frame':	211 obs. of  9 variables:
 $ gender         : int  1 1 1 2 1 1 1 1 1 1 ...
 $ region         : int  4 5 3 1 1 1 1 5 2 1 ...
 $ edu            : int  3 2 1 2 1 1 1 2 1 1 ...
 $ income         : int  3 3 2 1 2 2 2 4 2 2 ...
 $ age            : int  3 3 4 3 4 4 4 4 4 3 ...
 $ score_gov      : int  2 2 1 5 4 1 4 3 2 4 ...
 $ score_progress : int  2 4 3 4 3 4 4 4 2 2 ...
 $ score_intention: num  4 3 2.8 2.6 2.4 3.8 2 3.6 2 3 ...
 $ vote           : int  1 0 1 1 1 1 1 1 0 1 ...
1
summary(data)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
     gender          region           edu            income
 Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000
 1st Qu.:1.000   1st Qu.:1.000   1st Qu.:1.000   1st Qu.:1.000
 Median :1.000   Median :1.000   Median :2.000   Median :2.000
 Mean   :1.341   Mean   :2.052   Mean   :1.867   Mean   :2.209
 3rd Qu.:2.000   3rd Qu.:3.000   3rd Qu.:2.000   3rd Qu.:3.000
 Max.   :2.000   Max.   :5.000   Max.   :3.000   Max.   :4.000
      age          score_gov     score_progress  score_intention
 Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000
 1st Qu.:2.000   1st Qu.:3.000   1st Qu.:2.000   1st Qu.:2.400
 Median :3.000   Median :3.000   Median :3.000   Median :3.000
 Mean   :2.654   Mean   :3.057   Mean   :3.095   Mean   :2.911
 3rd Qu.:3.000   3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:3.400
 Max.   :4.000   Max.   :5.000   Max.   :5.000   Max.   :5.000
      vote
 Min.   :0.0000
 1st Qu.:0.0000
 Median :1.0000
 Mean   :0.7109
 3rd Qu.:1.0000
 Max.   :1.0000

train/test 데이터 나누기

1
install.packages("caret")
1
2
3
4
5
6
7
8
9
  There is a binary version available but the source version is later:
      binary source needs_compilation
caret 6.0-86 6.0-90              TRUE

  Binaries will be installed


Warning message:
"package 'caret' is in use and will not be installed"
1
install.packages("recipes", type = 'binary')
1
2
3
4
5
6
7
8
9
  There is a binary version available (and will be installed) but the
  source version is later:
        binary source
recipes 0.1.16 0.1.17

package 'recipes' successfully unpacked and MD5 sums checked

The downloaded binary packages are in
	C:\Users\MyCom\AppData\Local\Temp\RtmpEpSDhK\downloaded_packages
1
install.packages("dplyr")
1
2
3
4
5
6
7
8
9
  There is a binary version available but the source version is later:
      binary source needs_compilation
dplyr  1.0.6  1.0.7              TRUE

  Binaries will be installed
package 'dplyr' successfully unpacked and MD5 sums checked

The downloaded binary packages are in
	C:\Users\MyCom\AppData\Local\Temp\RtmpEpSDhK\downloaded_packages
1
2
library(caret)
library(dplyr)
1
2
3
4
5
6
7
8
9
10
11
12
Warning message:
"package 'caret' was built under R version 3.6.3"Warning message:
"package 'dplyr' was built under R version 3.6.3"
Attaching package: 'dplyr'

The following objects are masked from 'package:stats':

    filter, lag

The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union
1
set.seed(42) # 랜덤 추출
1
2
training.samples <- createDataPartition(data$vote, p = 0.7, list = FALSE) # 훈련데이터 70% 설정해줌
training.samples
Resample1
1
2
3
5
7
10
11
12
13
14
18
19
20
21
22
23
24
27
28
30
31
32
33
34
35
36
37
38
39
41
...
175
177
178
179
180
181
182
185
186
188
189
190
191
192
194
195
196
198
199
200
202
203
204
205
206
207
208
209
210
211
1
2
train  <- data[training.samples, ]
test <- data[-training.samples, ]

모델적용

1
2
install.packages("class")

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  There is a binary version available but the source version is later:
      binary source needs_compilation
class 7.3-19 7.3-20              TRUE

  Binaries will be installed
package 'class' successfully unpacked and MD5 sums checked


Warning message:
"cannot remove prior installation of package 'class'"Warning message in file.copy(savedcopy, lib, recursive = TRUE):
"C:\Users\MyCom\anaconda3\Lib\R\library\00LOCK\class\libs\x64\class.dll를 C:\Users\MyCom\anaconda3\Lib\R\library\class\libs\x64\class.dll로 복사하는데 문제가 발생했습니다: Permission denied"Warning message:
"restored 'class'"


The downloaded binary packages are in
	C:\Users\MyCom\AppData\Local\Temp\RtmpEpSDhK\downloaded_packages
1
2
library(class)

1
2
y_train_pred=knn(train, train, train$vote, k=3)
y_train_pred
1
2
3
1 1 1 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 1 1 1 1 1 1 0 1 0 1 1 1 0 0 0 1 1 0 1 0 1 0 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 1 0 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 1

Levels '0' '1'
1
2
y_test_pred=knn(train, test, train$vote, k=3)
y_test_pred
1
2
3
1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 0 1 1 0 0 0 1 0 1 1 1 1 0 1 1 1 0 1 1 1

Levels '0' '1'

모델평가

1) 훈련데이터

1
y_train_real = train$vote
1
y_train_real
1
1 0 1 1 1 1 0 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 0 1 1 1 1 0 1 1 1 0 1 0 1 0 1 1 0 1 1 1 1 0 1 1 1 1 1 1 0 0 0 1 1 1 1 1 0 1 1 1 1 1 1 1 0 1 1 1 0 0 1 1 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 0 1 1 1 1 1 0 1 0 1 1 1 0 0 0 1 1 0 1 0 1 0 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 1 0 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 1
1
y_train_pred
1
2
3
1 1 1 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 1 1 1 1 1 1 0 1 0 1 1 1 0 0 0 1 1 0 1 0 1 0 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 1 0 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 1

Levels '0' '1'
1
2
train_table = table(y_train_real, y_train_pred)
train_table
1
2
3
4
            y_train_pred
y_train_real   0   1
           0  35   7
           1   0 106
  • 141개를 맞춤 /상당히 좋은 모델이다.
1
2
y_train_accuracy=(train_table[1,1]+train_table[2,2])/sum(train_table,4)
y_train_accuracy   # 정확도가 높음

0.927631578947368

2) 테스트데이터

1
y_test_real = test$vote
1
2
test_table = table(y_test_real, y_test_pred)
test_table
1
2
3
4
           y_test_pred
y_test_real  0  1
          0 10  9
          1  1 43
1
2
y_test_accuracy=(test_table[1,1]+test_table[2,2])/sum(test_table,4)
y_test_accuracy

0.791044776119403

Meta Info

Categories:

Published At:

Modified At:

Leave a comment