はじめに
機械学習モデルの説明可能性(Explainable AI, XAI)は、モデルの意思決定プロセスを理解し、ステークホルダーとの信頼関係を構築する上で重要な役割を果たしています。本記事では、代表的な2つの手法であるSHAP(SHapley Additive exPlanations)とLIME(Local Interpretable Model-agnostic Explanations)の説明の一貫性を比較します。
対象読者:
- 機械学習エンジニア
- データサイエンティスト
- AIの導入を検討しているビジネスパーソン
記事のポイント:
- SHAPとLIMEのローカルな説明手法としての特徴を理解する
- 実験結果に基づき、各手法の長所と短所を把握する
- 実務での使い分けの指針を得る
検証仮説
本記事では、SHAPとLIMEのローカルな説明に関して以下の仮説を検証します:
- ローカルな説明の一貫性
- SHAPは理論的背景により、同じデータポイントに対して一貫した説明を提供する
- LIMEはサンプリングベースの手法のため、説明が実行ごとに変動する可能性がある
- 次元数の影響
- SHAPは高次元データでも安定したローカルな説明を提供する
- LIMEは高次元データでは局所的な近似が困難になり、説明の安定性が低下する
- LIMEはサンプリングベースのため、サンプル数を増やすことで、高次元でも説明の一貫性が改善する
説明可能性手法の理論的背景
SHAP値のローカルな説明
SHAPは、個々の予測に対して各特徴量の貢献度を計算します:
\phi_i(x) = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!(n-|S|-1)!}{n!}[f(S \cup \{i\}, x) - f(S, x)]
ここで:
\phi_i(x)
は特定のデータポイント x における特徴量 i のSHAP値f(S, x)
は特徴量集合 S のみを使用した予測値
SHAPはゲーム理論に基づいたShapley値を利用しており、一貫した特徴量の寄与の評価が可能です。
LIMEのローカルな説明
LIMEは、予測対象の周辺で局所的に解釈可能なモデルを作成します:
\text{explanation}(x) = \argmin_{g \in G} L(f, g, \pi_x) + \Omega(g)
ここで:
\pi_x
は x の周辺での重み(距離が近いほど大きな重み)g
は局所的な線形モデル
LIMEの近似モデルはデータのサンプリングに依存するため、異なる実行ごとに結果が変化します。
実験設計
データセット生成
次元数の影響を検証するため、以下の5次元、100次元のデータセットを作成しました。それぞれ、60%の独立な特徴量と、20%の冗長な特徴量、20%の繰り返し特徴量からなります。データサイズは、5,000サンプルです。
評価指標
ローカルな説明の一貫性スコア
- 同じデータポイントに対する5回の説明結果の相関係数
- 値が1に近いほど一貫性が高い
- 低次元と高次元それぞれで評価
- LIMEでは、サンプル数を1000, 5000, 20000と変化させて評価
- テストデータから100サンプルを選択して評価
実験結果と考察
ローカルな説明の一貫性評価
実験の結果、以下の知見が得られました:
- 低次元(5次元)データでの結果
- SHAP:一貫性スコア約1.0
- LIME:サンプル数によらず一貫性スコアは約1.0
- 両手法とも極めて安定した説明を提供
- 高次元(100次元)データでの結果
- SHAP:一貫性スコア約1.0を維持
- SHAPは実行ごとに結果が変わりえないので当然の結果ではあります
- LIME:サンプル数による大きな違い
- 1000サンプル:0.72
- 5000サンプル:0.93
- 20000サンプル:0.97
- SHAP:一貫性スコア約1.0を維持
まとめ
本実験を通じて、SHAP値とLIMEのローカルな説明能力について、以下の重要な知見が得られました:
- 低次元データでの性能
- SHAPとLIMEの両手法とも極めて高い一貫性を示す
- LIMEは少ないサンプル数でも安定した説明が可能
- 高次元データでの特性
- SHAPは一貫して高い安定性を維持
- LIMEはサンプル数を適切に設定することで、SHAP値と同等の一貫性を達成可能
- 特に、今回使用したLIMEの関数では、サンプル数のデフォルト値が5000であるため、次元数に応じて、事前に有効なサンプル数が設定できているか、検証が必要
これらの特性を理解し、ユースケースに応じて適切な手法とパラメータを選択することで、より信頼性の高いモデル説明が可能となります。
コード
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import japanize_matplotlib
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
import shap
import lime
import lime.lime_tabular
import traceback
# 再現性のために乱数シードを設定
np.random.seed(42)
def generate_synthetic_data(n_samples=5000, n_features=5):
"""合成データセットを生成する関数"""
try:
n_informative = int(n_features * 0.6)
n_redundant = int(n_features * 0.2)
n_repeated = n_features - n_informative - n_redundant
X, y = make_classification(
n_samples=n_samples,
n_features=n_features,
n_informative=n_informative,
n_redundant=n_redundant,
n_repeated=n_repeated,
random_state=42,
class_sep=1.5,
weights=[0.7, 0.3]
)
feature_names = [f'Feature{i+1}' for i in range(X.shape[1])]
X = pd.DataFrame(X, columns=feature_names)
return X, y
except Exception as e:
print(f"Error in generate_synthetic_data: {str(e)}")
print(traceback.format_exc())
raise
def evaluate_local_consistency(model, X, n_runs=5, lime_num_samples_list=[1000, 5000, 20000]):
"""ローカルな説明の一貫性を評価する関数"""
n_features = len(X.columns)
shap_values_runs = []
lime_values_runs = defaultdict(list)
for _ in range(n_runs):
# SHAP値の計算(ローカル)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)[:, :, 1]
shap_values_runs.append(shap_values)
# LIME説明の計算(ローカル)
explainer = lime.lime_tabular.LimeTabularExplainer(
X.values,
feature_names=X.columns,
class_names=['Class0', 'Class1'],
mode='classification'
)
for lime_num_samples in lime_num_samples_list:
lime_values = np.zeros((len(X), n_features))
for i in range(len(X)):
exp = explainer.explain_instance(
X.iloc[i].values,
model.predict_proba,
num_features=n_features,
num_samples=lime_num_samples
)
importance_dict = dict(exp.local_exp[1])
for feat_id in range(n_features):
lime_values[i, feat_id] = importance_dict.get(feat_id, 0)
lime_values_runs[lime_num_samples].append(lime_values)
# 一貫性スコアの計算
shap_consistency = np.mean([
np.corrcoef(
shap_values_runs[0].flatten(),
shap_values_runs[i].flatten()
)[0, 1]
for i in range(1, n_runs)
])
lime_consistency = {}
for lime_num_samples in lime_num_samples_list:
lime_consistency[lime_num_samples] = np.mean([
np.corrcoef(
lime_values_runs[lime_num_samples][0].flatten(),
lime_values_runs[lime_num_samples][i].flatten()
)[0, 1]
for i in range(1, n_runs)
])
return shap_consistency, lime_consistency
def main():
# 低次元、高次元のデータについて、SHAP, LIME(num_sampling=1000, 5000, 20000)での比較
for n_features in [5, 100]:
X_low, y_low = generate_synthetic_data(n_samples=5000, n_features=n_features)
X_train_low, X_test_low, y_train_low, y_test_low = train_test_split(X_low, y_low, test_size=0.2, random_state=42)
model_low = RandomForestClassifier(n_estimators=100, random_state=42)
model_low.fit(X_train_low, y_train_low)
shap_consistency_low, lime_consistency_low = evaluate_local_consistency(model_low, X_test_low.head(100), lime_num_samples_list=[1000, 5000, 20000])
# plot
plt.figure(figsize=(10, 5))
labels = ['SHAP', 'LIME(1000)', 'LIME(5000)', 'LIME(20000)']
values = [shap_consistency_low, lime_consistency_low[1000], lime_consistency_low[5000], lime_consistency_low[20000]]
plt.bar(labels, values)
plt.title(f'Local Consistency Comparison (n_features={n_features})')
plt.xlabel('Explanation Method')
plt.ylabel('Consistency Score')
plt.savefig(f'local_consistency_comparison_n_features_{n_features}.png')
if __name__ == "__main__":
main()
お気軽にご相談ください!
弊社のデジタル技術を活用し、貴社のDXをサポートします。
基本的な設計や実装支援はもちろん、機械学習・データ分析・3D計算などの高度な技術が求められる場面でも、最適なソリューションをご提案します。
初回相談は無料で受け付けております。困っていること、是非お聞かせください。