data는 피처의 데이터 세트를 가리킵니다.
target은 분류 시 레이블 값, 회귀일 때는 숫자 결괏 값 데이터 세트입니다.
target_names는 개별 레이블의 이름을 나타냅니다.
feature_names는 피처의 이름을 나타냅니다.
DESCR은 데이터 세트에 대한 설명과 각 피처의 설명을 나타냅니다.
#학습/테스트 데이터 세트 분리 - train_test_split()
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
dt_clf = DecisionTreeClassifier()
iris_data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.3, random_state=121)
#학습 데이터 기반으로 학습하고, 이 모델을 이용해 예측 정확도 측정
dt_clf.fit(X_train,y_train)
pred = dt_clf.predict(X_test)
print('예측 정확도: {0:.4f}'.format(accuracy_score(y_test,pred)))
-> 예측 정확도: 0.9556
모델이 과적합되면 다른 테스트용 데이터가 들어올 경우에 성능이 저하된다.
-> 이러한 문제점 개선위해 교차검증을 이용해 더 다양한 학습과 평가를 수행한다.
#K폴드 교차 검증
K 폴드 교차 검증은 가장 보편적으로 사용되는 교차 검증 기법이다.
먼저, K개의 데이터 폴드 세트를 만들어서 K번만큼 각 폴드 세트에 학습과 검증 평가를 반복적으로 수행한다.
예측 평가를 구했으면 이를 평균해서 K 폴드 평가 결과로 반영하면 된다.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold
import numpy as np
iris = load_iris()
features = iris.data
label = iris.target
dt_clf = DecisionTreeClassifier(random_state=156)
#5개의 폴드 세트로 분리하는 KFold객체와 폴드 세트별 정확도를 담을 리스트 객체 생성
kfold = KFold(n_splits = 5)
cv_accuracy = []
print('붓꽃 데이터 세트 크기:',features.shape[0])
#kfold객체의 split()를 호출하면 폴드 별 학습용, 검증용 테스트의 로우 인덱스를 array로 반환
n_iter = 0
for train_index, test_index in kfold.split(features):
#kfold.split()으로 반환된 인덱스를 이용해 학습용, 검증용 테스트 데이터 추출
X_train,X_test = features[train_index],features[test_index]
y_train,y_test = label[train_index],label[test_index]
#학습 및 예측
dt_clf.fit(X_train,y_train)
pred = dt_clf.predict(X_test)
n_iter = 1
#반복 시마다 정확도 측정
accuracy = np.round(accuracy_score(y_test,pred),4)
train_size = X_train.shape[0]
test_size = X_test.shape[0]
print('\n#{0} 교차 검증 정확도 :{1}, 학습 데이터 크기: {2}, 검증 데이터 크기:{3}'
.format(n_iter,accuracy,train_size,test_size))
print('#{0}검증 세트 인덱스:{1}'.format(n_iter,test_index))
cv_accuracy.append(accuracy)
#개별 iteration별 정확도를 합하여 평균 정확도 계산
print('\n## 평균 검증 정확도:', np.mean(cv_accuracy))
->평균 검증 정확도: 0.9
#Stratified K 폴드
Stratified K 폴드는 k폴드로 분할된 레이블 데이터 세트가 전체 레이블 값의 분포도를 반영하지 못하는 문제를 해결해 준다.
#붓꽃데이터 세트 간단하게 데이터프레임으로 생성하고 레이블 값의 분포도 확인
iris = load_iris()
iris_df = pd.DataFrame(data=iris_data, columns=iris.feature_names)
iris_df['label']=iris.target
iris_df['label'].value_counts()
kfold = KFold(n_splits=3)
n_iter=0
for train_index, test_indext in kfold.split(iris_df):
n_iter += 1
label_train = iris_df['label'].iloc[train_index]
label_test = iris_df['label'].iloc[test_index]
print('##교차검증: {0}'.format(n_iter))
print('학습 레이블 데이터 분포:\n',label_train.value_counts())
print('검증 레이블 데이터 분포:\n',label_test.value_counts())iris_df['label'].value_counts()
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=3)
n_iter=0
for train_index, test_index in skf.split(iris_df, iris_df['label']):
n_iter += 1
label_train = iris_df['label'].iloc[train_index]
label_test = iris_df['label'].iloc[test_index]
print('##교차 검증: {0}'.format (n_iter))
print('학습 레이블 데이터 분포:\n', label_train.value_counts())
print('검증 레이블 데이터 분포:\n', label_test.value_counts())
학습 결과를 보면 학습 레이블과 검증 레이블 데이터 값의 분포도가 동일하게 할당됐음을 알 수 있다.
dt_clf = DecisionTreeClassifier(random_state=156)
skfold = StratifiedKFold(n_splits=3)
n_iter = 0
cv_accuracy=[]
for train_index, test_index in skfold.split(features, label):
X_train, X_test= features[train_index], features[test_index]
y_train, y_test = label[train_index], label[test_index]
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
n_iter += 1
accuracy = np.round(accuracy_score(y_test,pred),4)
train_size = X_train.shape[0]
test_size = X_test.shape[0]
print('\n#{0} 교차 검증 정확도 :{1}, 학습 데이터 크기: {2}, 검증 데이터 크기:{3}'
.format(n_iter,accuracy,train_size,test_size))
print('\n#{0} 검증 세트 인덱스:{1}'.format(n_iter, test_indext))
cv_accuracy.append(accuracy)
#교차검증별 정확도 및 평균 정확도 계산
print('\n## 교차 검증별 정확도:',np.round(cv_accuracy,4)
print('## 평균 검증 정확도:',np.mean(cv_accuracy))
'Machine Learning > 파이썬 머신러닝 완벽가이드' 카테고리의 다른 글
03. 데이터 전처리 (0) | 2021.11.10 |
---|---|
GridSearchCV (0) | 2021.11.08 |
사이킷런으로 시작하는 머신러닝 02.붓꽃품종 예측하기 (0) | 2021.09.06 |
9/6 데이터 핸들링-판다스 (0) | 2021.09.06 |