Python AutoML 라이브러리 중 하나인 Lazypredict를 이용해 여러 ML 모델들을 동시에 학습하고, 예측 성능을 비교해보자.


Lazypredict

Lazy Predict는 인도의 어느 시니어 데이터 사이언티스트인 Shankar Rao Pandala라는 개인이 개발한 오픈소스 머신러닝 자동화 관련 파이썬 오픈소스 프로젝트이다. 현재는 Classification과 Regression에 대한 기능만 제공되고 있다. Lazypredict를 이용하면 코드 한 줄로 여러 ML 모델을 불러와 학습시킬 수 있고, 추론 결과도 확인할 수 있다. 여러 모델들의 성능 지표도 비교할 수 있어 성능이 더 좋은 모델을 가려낼 수도 있다. 다만, 파라미터를 조정하는 기능은 따로 제공되지 않는다는 한계가 있다.


예제 데이터 로드

예제 데이터를 아래 링크에서 다운로드받아 Pandas DataFrame으로 로드한다. 예제 데이터는 심장병 해당 여부(target)가 포함된 데이터로 age, sex, cp 등의 특성을 통해 심장병 해당여부를 예측해보자.

Heart Disease Dataset

import tqdm
import pandas as pd

df = pd.read_csv("./data/heart.csv")
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1025 entries, 0 to 1024
Data columns (total 14 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       1025 non-null   int64  
 1   sex       1025 non-null   int64  
 2   cp        1025 non-null   int64  
 3   trestbps  1025 non-null   int64  
 4   chol      1025 non-null   int64  
 5   fbs       1025 non-null   int64  
 6   restecg   1025 non-null   int64  
 7   thalach   1025 non-null   int64  
 8   exang     1025 non-null   int64  
 9   oldpeak   1025 non-null   float64
 10  slope     1025 non-null   int64  
 11  ca        1025 non-null   int64  
 12  thal      1025 non-null   int64  
 13  target    1025 non-null   int64  
dtypes: float64(1), int64(13)
memory usage: 112.2 KB
df.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 52 1 0 125 212 0 1 168 0 1.0 2 2 3 0
1 53 1 0 140 203 1 0 155 1 3.1 0 0 3 0
2 70 1 0 145 174 0 1 125 1 2.6 0 0 3 0
3 61 1 0 148 203 0 1 161 0 0.0 2 1 3 0
4 62 0 0 138 294 1 1 106 0 1.9 1 3 2 0
df.tail()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
1020 59 1 1 140 221 0 1 164 1 0.0 2 0 2 1
1021 60 1 0 125 258 0 0 141 1 2.8 1 1 3 0
1022 47 1 0 110 275 0 0 118 1 1.0 1 1 2 0
1023 50 0 0 110 254 0 0 159 0 0.0 2 0 2 1
1024 54 1 0 120 188 0 1 113 0 1.4 1 1 3 0


target 컬럼 분할

y_data = df.pop("target")
x_data = df

print("Shape of X: ", x_data.shape)
print("Shape of Y: ", y_data.shape)
Shape of X:  (1025, 13)
Shape of Y:  (1025,)
# 분할 전 데이터셋 중에서 target 값이 True인 비율
raw_true_ratio = y_data.sum() / len(y_data)
print("> Target True 비율: ", raw_true_ratio)
> Target True 비율:  0.5131707317073171


학습/검증 데이터 분할

참고로 train_test_splitstratify 옵션은 target 즉 라벨값으로 지정하면 된다. stratify를 target으로 지정하게 되면, target의 클래스 비율을 유지한 상태로 데이터셋을 분할하게 된다. 이 옵션을 지정해주지 않으면 분할 전/후 데이터셋의 클래스 비율이 불균형 상태를 이룰 수 있다. 이러한 불균형 상태에서 모델을 학습할 경우 모델 성능에 영향을 줄 수 있다.

무작위 분할

stratify에 None 값을 지정해 데이터셋을 분할한 경우를 살펴보자. 위에서 데이터 분할 전 target 값이 True인 비율은 약 51%인 것을 확인했다. 데이터 분할 후 학습 데이터는 약 52%, 검증 데이터는 약 47%로 비율에 약간의 차이가 있는 것을 알 수 있다.

from sklearn.model_selection import train_test_split

# 학습/검증 데이터 분할 (stratify 옵션 미적용)
x_train, x_test, y_train, y_test = train_test_split(
    x_data,
    y_data,
    test_size=0.2,
    random_state=2022,
    stratify=None
)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
(820, 13) (820,)
(205, 13) (205,)
# 학습 데이터 중에서 target 값이 True인 비율
train_true_ratio = y_train.sum() / len(y_train)
# 검증 데이터 중에서 target 값이 True인 비율
test_true_ratio = y_test.sum() / len(y_test)

print("데이터셋 분할 시 stratify 옵션 미적용 결과")
print("> 학습 데이터 Target True 비율: ", train_true_ratio)
print("> 검증 데이터 Target True 비율: ", test_true_ratio)
데이터셋 분할 시 stratify 옵션 미적용 결과
> 학습 데이터 Target True 비율:  0.5231707317073171
> 검증 데이터 Target True 비율:  0.47317073170731705


계층적 분할

이번에는 stratifyy_data 값을 지정해 데이터셋을 분할한 경우를 살펴보자. 데이터 분할 후 학습 데이터는 약 51%, 검증 데이터는 약 51%로 분할 전 클래스 비율을 유지한 상태로 분할된 것을 알 수 있다.

from sklearn.model_selection import train_test_split

# 학습/검증 데이터 분할 (stratify 옵션 적용)
x_train, x_test, y_train, y_test = train_test_split(
    x_data,
    y_data,
    test_size=0.2,
    random_state=2022,
    stratify=y_data
)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
(820, 13) (820,)
(205, 13) (205,)
# 학습 데이터 중에서 target 값이 True인 비율
train_true_ratio = y_train.sum() / len(y_train)
# 검증 데이터 중에서 target 값이 True인 비율
test_true_ratio = y_test.sum() / len(y_test)

print("데이터셋 분할 시 stratify 옵션 적용 결과")
print("> 학습 데이터 Target True 비율: ", train_true_ratio)
print("> 검증 데이터 Target True 비율: ", test_true_ratio)
데이터셋 분할 시 stratify 옵션 적용 결과
> 학습 데이터 Target True 비율:  0.5134146341463415
> 검증 데이터 Target True 비율:  0.5121951219512195


Lazypredict를 통한 자동 모델 학습

LazyClassifier를 이용해 Lazypredict의 지도학습 분류기 인스턴스 clf를 만든다. 이렇게 만든 clffit 메서드의 인자료 위에서 분할한 학습 데이터와 검증 데이터를 입력한다. 그 결과로 modelspredictions가 반환된다. models는 Scikit-Learn의 여러 분류 모델을 적용해 학습한 결과 데이터프레임이고, predictions은 각 모델 별 예측값을 모아둔 데이터프레임이다.

from lazypredict.Supervised import LazyClassifier

clf = LazyClassifier(verbose=0, predictions=True)

models, predictions = clf.fit(x_train, x_test, y_train, y_test)
100%|███████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 36.09it/s]
models
Accuracy Balanced Accuracy ROC AUC F1 Score Time Taken
Model
LGBMClassifier 1.00 1.00 1.00 1.00 0.04
LabelPropagation 1.00 1.00 1.00 1.00 0.03
XGBClassifier 1.00 1.00 1.00 1.00 0.08
DecisionTreeClassifier 1.00 1.00 1.00 1.00 0.01
RandomForestClassifier 1.00 1.00 1.00 1.00 0.13
ExtraTreeClassifier 1.00 1.00 1.00 1.00 0.01
ExtraTreesClassifier 1.00 1.00 1.00 1.00 0.10
BaggingClassifier 1.00 1.00 1.00 1.00 0.03
LabelSpreading 1.00 1.00 1.00 1.00 0.03
SVC 0.94 0.94 0.94 0.94 0.02
AdaBoostClassifier 0.92 0.92 0.92 0.92 0.07
NuSVC 0.90 0.90 0.90 0.90 0.02
QuadraticDiscriminantAnalysis 0.89 0.89 0.89 0.89 0.01
GaussianNB 0.88 0.88 0.88 0.88 0.01
SGDClassifier 0.87 0.87 0.87 0.87 0.01
CalibratedClassifierCV 0.87 0.87 0.87 0.87 0.08
LogisticRegression 0.86 0.86 0.86 0.86 0.01
LinearSVC 0.86 0.86 0.86 0.86 0.03
NearestCentroid 0.86 0.86 0.86 0.86 0.01
LinearDiscriminantAnalysis 0.86 0.86 0.86 0.86 0.01
RidgeClassifier 0.86 0.86 0.86 0.86 0.01
RidgeClassifierCV 0.86 0.86 0.86 0.86 0.01
BernoulliNB 0.85 0.85 0.85 0.85 0.01
KNeighborsClassifier 0.83 0.83 0.83 0.83 0.02
PassiveAggressiveClassifier 0.78 0.78 0.78 0.78 0.01
Perceptron 0.76 0.76 0.76 0.76 0.01
DummyClassifier 0.48 0.48 0.48 0.48 0.01
predictions.head()
AdaBoostClassifier BaggingClassifier BernoulliNB CalibratedClassifierCV DecisionTreeClassifier DummyClassifier ExtraTreeClassifier ExtraTreesClassifier GaussianNB KNeighborsClassifier ... PassiveAggressiveClassifier Perceptron QuadraticDiscriminantAnalysis RandomForestClassifier RidgeClassifier RidgeClassifierCV SGDClassifier SVC XGBClassifier LGBMClassifier
0 0 0 0 0 0 1 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 1 1 0 1 1 0 1 1 1 0 ... 1 1 1 1 1 1 1 1 1 1
2 1 1 1 1 1 0 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1
3 1 1 1 1 1 0 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1
4 1 0 1 1 0 1 0 0 1 1 ... 1 1 0 0 1 1 1 0 0 0

5 rows × 27 columns


모델 별 분류 리포트 출력

모델 별 분류 성능 지표를 출력할 수도 있다.

from sklearn.metrics import classification_report

for model_name in predictions.columns.tolist():
    print("="*60)
    print(f'{model_name}')
    print("="*60)
    print(classification_report(y_test, predictions[model_name]))
============================================================
AdaBoostClassifier
============================================================
              precision    recall  f1-score   support

           0       0.95      0.89      0.92       100
           1       0.90      0.95      0.93       105

    accuracy                           0.92       205
   macro avg       0.92      0.92      0.92       205
weighted avg       0.92      0.92      0.92       205

============================================================
BaggingClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
BernoulliNB
============================================================
              precision    recall  f1-score   support

           0       0.88      0.80      0.84       100
           1       0.82      0.90      0.86       105

    accuracy                           0.85       205
   macro avg       0.85      0.85      0.85       205
weighted avg       0.85      0.85      0.85       205

============================================================
CalibratedClassifierCV
============================================================
              precision    recall  f1-score   support

           0       0.95      0.78      0.86       100
           1       0.82      0.96      0.89       105

    accuracy                           0.87       205
   macro avg       0.89      0.87      0.87       205
weighted avg       0.88      0.87      0.87       205

============================================================
DecisionTreeClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
DummyClassifier
============================================================
              precision    recall  f1-score   support

           0       0.46      0.45      0.46       100
           1       0.49      0.50      0.50       105

    accuracy                           0.48       205
   macro avg       0.48      0.48      0.48       205
weighted avg       0.48      0.48      0.48       205

============================================================
ExtraTreeClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
ExtraTreesClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
GaussianNB
============================================================
              precision    recall  f1-score   support

           0       0.89      0.85      0.87       100
           1       0.86      0.90      0.88       105

    accuracy                           0.88       205
   macro avg       0.88      0.88      0.88       205
weighted avg       0.88      0.88      0.88       205

============================================================
KNeighborsClassifier
============================================================
              precision    recall  f1-score   support

           0       0.85      0.80      0.82       100
           1       0.82      0.87      0.84       105

    accuracy                           0.83       205
   macro avg       0.84      0.83      0.83       205
weighted avg       0.84      0.83      0.83       205

============================================================
LabelPropagation
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
LabelSpreading
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
LinearDiscriminantAnalysis
============================================================
              precision    recall  f1-score   support

           0       0.96      0.74      0.84       100
           1       0.80      0.97      0.88       105

    accuracy                           0.86       205
   macro avg       0.88      0.86      0.86       205
weighted avg       0.88      0.86      0.86       205

============================================================
LinearSVC
============================================================
              precision    recall  f1-score   support

           0       0.95      0.76      0.84       100
           1       0.81      0.96      0.88       105

    accuracy                           0.86       205
   macro avg       0.88      0.86      0.86       205
weighted avg       0.88      0.86      0.86       205

============================================================
LogisticRegression
============================================================
              precision    recall  f1-score   support

           0       0.93      0.78      0.85       100
           1       0.82      0.94      0.88       105

    accuracy                           0.86       205
   macro avg       0.87      0.86      0.86       205
weighted avg       0.87      0.86      0.86       205

============================================================
NearestCentroid
============================================================
              precision    recall  f1-score   support

           0       0.94      0.76      0.84       100
           1       0.81      0.95      0.87       105

    accuracy                           0.86       205
   macro avg       0.87      0.86      0.86       205
weighted avg       0.87      0.86      0.86       205

============================================================
NuSVC
============================================================
              precision    recall  f1-score   support

           0       0.94      0.85      0.89       100
           1       0.87      0.95      0.91       105

    accuracy                           0.90       205
   macro avg       0.91      0.90      0.90       205
weighted avg       0.91      0.90      0.90       205

============================================================
PassiveAggressiveClassifier
============================================================
              precision    recall  f1-score   support

           0       0.77      0.77      0.77       100
           1       0.78      0.78      0.78       105

    accuracy                           0.78       205
   macro avg       0.78      0.78      0.78       205
weighted avg       0.78      0.78      0.78       205

============================================================
Perceptron
============================================================
              precision    recall  f1-score   support

           0       0.76      0.74      0.75       100
           1       0.76      0.77      0.76       105

    accuracy                           0.76       205
   macro avg       0.76      0.76      0.76       205
weighted avg       0.76      0.76      0.76       205

============================================================
QuadraticDiscriminantAnalysis
============================================================
              precision    recall  f1-score   support

           0       0.91      0.87      0.89       100
           1       0.88      0.91      0.90       105

    accuracy                           0.89       205
   macro avg       0.89      0.89      0.89       205
weighted avg       0.89      0.89      0.89       205

============================================================
RandomForestClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
RidgeClassifier
============================================================
              precision    recall  f1-score   support

           0       0.96      0.74      0.84       100
           1       0.80      0.97      0.88       105

    accuracy                           0.86       205
   macro avg       0.88      0.86      0.86       205
weighted avg       0.88      0.86      0.86       205

============================================================
RidgeClassifierCV
============================================================
              precision    recall  f1-score   support

           0       0.96      0.74      0.84       100
           1       0.80      0.97      0.88       105

    accuracy                           0.86       205
   macro avg       0.88      0.86      0.86       205
weighted avg       0.88      0.86      0.86       205

============================================================
SGDClassifier
============================================================
              precision    recall  f1-score   support

           0       0.91      0.82      0.86       100
           1       0.84      0.92      0.88       105

    accuracy                           0.87       205
   macro avg       0.88      0.87      0.87       205
weighted avg       0.88      0.87      0.87       205

============================================================
SVC
============================================================
              precision    recall  f1-score   support

           0       0.97      0.90      0.93       100
           1       0.91      0.97      0.94       105

    accuracy                           0.94       205
   macro avg       0.94      0.94      0.94       205
weighted avg       0.94      0.94      0.94       205

============================================================
XGBClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205

============================================================
LGBMClassifier
============================================================
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       100
           1       1.00      1.00      1.00       105

    accuracy                           1.00       205
   macro avg       1.00      1.00      1.00       205
weighted avg       1.00      1.00      1.00       205
models
Accuracy Balanced Accuracy ROC AUC F1 Score Time Taken
Model
LGBMClassifier 1.00 1.00 1.00 1.00 0.04
LabelPropagation 1.00 1.00 1.00 1.00 0.03
XGBClassifier 1.00 1.00 1.00 1.00 0.08
DecisionTreeClassifier 1.00 1.00 1.00 1.00 0.01
RandomForestClassifier 1.00 1.00 1.00 1.00 0.13
ExtraTreeClassifier 1.00 1.00 1.00 1.00 0.01
ExtraTreesClassifier 1.00 1.00 1.00 1.00 0.10
BaggingClassifier 1.00 1.00 1.00 1.00 0.03
LabelSpreading 1.00 1.00 1.00 1.00 0.03
SVC 0.94 0.94 0.94 0.94 0.02
AdaBoostClassifier 0.92 0.92 0.92 0.92 0.07
NuSVC 0.90 0.90 0.90 0.90 0.02
QuadraticDiscriminantAnalysis 0.89 0.89 0.89 0.89 0.01
GaussianNB 0.88 0.88 0.88 0.88 0.01
SGDClassifier 0.87 0.87 0.87 0.87 0.01
CalibratedClassifierCV 0.87 0.87 0.87 0.87 0.08
LogisticRegression 0.86 0.86 0.86 0.86 0.01
LinearSVC 0.86 0.86 0.86 0.86 0.03
NearestCentroid 0.86 0.86 0.86 0.86 0.01
LinearDiscriminantAnalysis 0.86 0.86 0.86 0.86 0.01
RidgeClassifier 0.86 0.86 0.86 0.86 0.01
RidgeClassifierCV 0.86 0.86 0.86 0.86 0.01
BernoulliNB 0.85 0.85 0.85 0.85 0.01
KNeighborsClassifier 0.83 0.83 0.83 0.83 0.02
PassiveAggressiveClassifier 0.78 0.78 0.78 0.78 0.01
Perceptron 0.76 0.76 0.76 0.76 0.01
DummyClassifier 0.48 0.48 0.48 0.48 0.01


단일 모델 선택

여러 성능 지표 중 Balanced Accuracy 값이 가장 큰 모델 중 하나를 고르면, LGBMClassifier 모델이 선택된다.

models.loc[models['Balanced Accuracy'] == models['Balanced Accuracy'].max()].index[0]
'LGBMClassifier'

해당 모델을 직접 불러와 학습 후 예측한 결과는 역시 위의 결과와 동일한 것을 확인할 수 있다.

from lightgbm import LGBMClassifier
from sklearn.metrics import balanced_accuracy_score

lgbm = LGBMClassifier()
lgbm.fit(x_train, y_train)
y_pred = lgbm.predict(x_test)
balanced_accuracy_score(y_pred, y_pred)
1.0


참고

태그:

카테고리:

업데이트:

댓글남기기