본문 바로가기

Programming/Python

[Python][Error] ImportError: cannot import name 'plot_confusion_matrix' from 'sklearn.metrics' (/home/gil/anaconda3/envs/msb_pt/lib/python3.10/site-packages/sklearn/metrics/__init__.py)

이번에는 또 sklearn.metrics에서 plot_confusion_matrix를 찾을 수 없다는 에러가 뜹니다. 

 

 

이런 에러는 100%에 99.99%는 버전 에러라고 생각하시면 됩니다.

Scikit-learn 홈페이지에서 확인해보시면 plot_confusion_matrix 함수는 1.0.x 버전에서 사용되던 함수임을 알 수 있습니다. 

 

 

제 scikit-learn 함수 버전을 뽑아보면 아래와 같습니다. 

conda list scikit

 

이를 해결하기 위한 방법은 크게 두가지 방법이 있습니다. 


1. 버전 다운그레이드

너무 간단하게 해결 가능한 방법 입니다.

해당 함수가 사용가능한 버전으로 scikit-learn 버전을 다운그레이드 해주면 됩니다. 

pip install scikit-learn==1.0.2

 

 

당연하게도 정상적으로 작동이 되는 것을 확인하실 수 있습니다. 


2. (권장) 최신버전에 맞추어 코드 수정

최신 버전에 맞추어 코드를 수정해 주면 됩니다. 

최신 버전에서는 plot_confusion_matrix 함수를 대신하여 ConfusionMatrixDisplay 함수를 사용합니다. 

사용 방법은 아래와 같습니다.

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

LR = classlist[0]   # LR = LogisticRegression() fit 결과
lrpred = LR.predict(test_x)


# plot_confusion_matrix(LR, test_x, test_y, cmap='Blues')

cm = confusion_matrix(test_y, lrpred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')  # 색상 맵 지정 가능
plt.title("Confusion Matrix")
plt.show()

 

plot_confusion_matrix(LR, test_x, test_y, cmap='Blues')를 대신하여 아래 코드를 작성해주면, 정상적으로 confusion matirix가 출력되는 것을 확인할 수 있습니다.