機械学習モデルの説明可能性手法の比較:SHAPとLIMEの出力の一貫性評価

はじめに

機械学習モデルの説明可能性(Explainable AI, XAI)は、モデルの意思決定プロセスを理解し、ステークホルダーとの信頼関係を構築する上で重要な役割を果たしています。本記事では、代表的な2つの手法であるSHAP(SHapley Additive exPlanations)とLIME(Local Interpretable Model-agnostic Explanations)の説明の一貫性を比較します。

この記事を書いたひと

デジタルリアクタ合同会社 代表
機械学習・統計、数値計算などの領域を軸としたソフトウェアエンジニアリングを専門としています。スタートアップからグローバル企業まで、さまざまなスケールの企業にて、事業価値に直結する計算システムを設計・構築してきました。主に機械学習の応用分野において、出版・発表、特許取得等の実績があり、また、IT戦略やデータベース技術等の情報処理に関する専門領域の国家資格を複数保有しています。九州大学理学部卒業。日本ITストラテジスト協会正会員。

対象読者:

  • 機械学習エンジニア
  • データサイエンティスト
  • AIの導入を検討しているビジネスパーソン

記事のポイント:

  • SHAPとLIMEのローカルな説明手法としての特徴を理解する
  • 実験結果に基づき、各手法の長所と短所を把握する
  • 実務での使い分けの指針を得る

検証仮説

本記事では、SHAPとLIMEのローカルな説明に関して以下の仮説を検証します:

  1. ローカルな説明の一貫性
    • SHAPは理論的背景により、同じデータポイントに対して一貫した説明を提供する
    • LIMEはサンプリングベースの手法のため、説明が実行ごとに変動する可能性がある
  2. 次元数の影響
    • SHAPは高次元データでも安定したローカルな説明を提供する
    • LIMEは高次元データでは局所的な近似が困難になり、説明の安定性が低下する
    • LIMEはサンプリングベースのため、サンプル数を増やすことで、高次元でも説明の一貫性が改善する

説明可能性手法の理論的背景

SHAP値のローカルな説明

SHAPは、個々の予測に対して各特徴量の貢献度を計算します:

\phi_i(x) = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!(n-|S|-1)!}{n!}[f(S \cup \{i\}, x) - f(S, x)]

ここで:

  • \phi_i(x) は特定のデータポイント x における特徴量 i のSHAP値
  • f(S, x) は特徴量集合 S のみを使用した予測値

SHAPはゲーム理論に基づいたShapley値を利用しており、一貫した特徴量の寄与の評価が可能です。

LIMEのローカルな説明

LIMEは、予測対象の周辺で局所的に解釈可能なモデルを作成します:

\text{explanation}(x) = \argmin_{g \in G} L(f, g, \pi_x) + \Omega(g)

ここで:

  • \pi_x は x の周辺での重み(距離が近いほど大きな重み)
  • g は局所的な線形モデル

LIMEの近似モデルはデータのサンプリングに依存するため、異なる実行ごとに結果が変化します。

実験設計

データセット生成

次元数の影響を検証するため、以下の5次元、100次元のデータセットを作成しました。それぞれ、60%の独立な特徴量と、20%の冗長な特徴量、20%の繰り返し特徴量からなります。データサイズは、5,000サンプルです。

評価指標

ローカルな説明の一貫性スコア

  • 同じデータポイントに対する5回の説明結果の相関係数
  • 値が1に近いほど一貫性が高い
  • 低次元と高次元それぞれで評価
  • LIMEでは、サンプル数を1000, 5000, 20000と変化させて評価
  • テストデータから100サンプルを選択して評価

実験結果と考察

ローカルな説明の一貫性評価

実験の結果、以下の知見が得られました:

  1. 低次元(5次元)データでの結果
    • SHAP:一貫性スコア約1.0
    • LIME:サンプル数によらず一貫性スコアは約1.0
    • 両手法とも極めて安定した説明を提供
  2. 高次元(100次元)データでの結果
    • SHAP:一貫性スコア約1.0を維持
      • SHAPは実行ごとに結果が変わりえないので当然の結果ではあります
    • LIME:サンプル数による大きな違い
      • 1000サンプル:0.72
      • 5000サンプル:0.93
      • 20000サンプル:0.97


まとめ

本実験を通じて、SHAP値とLIMEのローカルな説明能力について、以下の重要な知見が得られました:

  1. 低次元データでの性能
    • SHAPとLIMEの両手法とも極めて高い一貫性を示す
    • LIMEは少ないサンプル数でも安定した説明が可能
  2. 高次元データでの特性
    • SHAPは一貫して高い安定性を維持
    • LIMEはサンプル数を適切に設定することで、SHAP値と同等の一貫性を達成可能
    • 特に、今回使用したLIMEの関数では、サンプル数のデフォルト値が5000であるため、次元数に応じて、事前に有効なサンプル数が設定できているか、検証が必要

これらの特性を理解し、ユースケースに応じて適切な手法とパラメータを選択することで、より信頼性の高いモデル説明が可能となります。

コード

from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import japanize_matplotlib
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
import shap
import lime
import lime.lime_tabular
import traceback

# 再現性のために乱数シードを設定
np.random.seed(42)

def generate_synthetic_data(n_samples=5000, n_features=5):
    """合成データセットを生成する関数"""
    try:
        n_informative = int(n_features * 0.6)
        n_redundant = int(n_features * 0.2)
        n_repeated = n_features - n_informative - n_redundant

        X, y = make_classification(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=n_informative,
            n_redundant=n_redundant,
            n_repeated=n_repeated,
            random_state=42,
            class_sep=1.5,
            weights=[0.7, 0.3]
        )

        feature_names = [f'Feature{i+1}' for i in range(X.shape[1])]
        X = pd.DataFrame(X, columns=feature_names)
        return X, y
    except Exception as e:
        print(f"Error in generate_synthetic_data: {str(e)}")
        print(traceback.format_exc())
        raise

def evaluate_local_consistency(model, X, n_runs=5, lime_num_samples_list=[1000, 5000, 20000]):
    """ローカルな説明の一貫性を評価する関数"""
    n_features = len(X.columns)
    shap_values_runs = []
    lime_values_runs = defaultdict(list)

    for _ in range(n_runs):
        # SHAP値の計算(ローカル)
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X)[:, :, 1]
        shap_values_runs.append(shap_values)

        # LIME説明の計算(ローカル)
        explainer = lime.lime_tabular.LimeTabularExplainer(
            X.values,
            feature_names=X.columns,
            class_names=['Class0', 'Class1'],
            mode='classification'
        )

        for lime_num_samples in lime_num_samples_list:
            lime_values = np.zeros((len(X), n_features))
            for i in range(len(X)):
                exp = explainer.explain_instance(
                    X.iloc[i].values,
                    model.predict_proba,
                    num_features=n_features,
                    num_samples=lime_num_samples
                )
                importance_dict = dict(exp.local_exp[1])
                for feat_id in range(n_features):
                    lime_values[i, feat_id] = importance_dict.get(feat_id, 0)
            lime_values_runs[lime_num_samples].append(lime_values)

    # 一貫性スコアの計算
    shap_consistency = np.mean([
        np.corrcoef(
            shap_values_runs[0].flatten(), 
            shap_values_runs[i].flatten()
        )[0, 1]
        for i in range(1, n_runs)
    ])
    lime_consistency = {}
    for lime_num_samples in lime_num_samples_list:
        lime_consistency[lime_num_samples] = np.mean([
            np.corrcoef(
                lime_values_runs[lime_num_samples][0].flatten(), 
                lime_values_runs[lime_num_samples][i].flatten()
            )[0, 1] 
            for i in range(1, n_runs)
        ])

    return shap_consistency, lime_consistency

def main():
    # 低次元、高次元のデータについて、SHAP, LIME(num_sampling=1000, 5000, 20000)での比較
    for n_features in [5, 100]:
        X_low, y_low = generate_synthetic_data(n_samples=5000, n_features=n_features)

        X_train_low, X_test_low, y_train_low, y_test_low = train_test_split(X_low, y_low, test_size=0.2, random_state=42)
        model_low = RandomForestClassifier(n_estimators=100, random_state=42)
        model_low.fit(X_train_low, y_train_low)

        shap_consistency_low, lime_consistency_low = evaluate_local_consistency(model_low, X_test_low.head(100), lime_num_samples_list=[1000, 5000, 20000])

        # plot
        plt.figure(figsize=(10, 5))
        labels = ['SHAP', 'LIME(1000)', 'LIME(5000)', 'LIME(20000)']
        values = [shap_consistency_low, lime_consistency_low[1000], lime_consistency_low[5000], lime_consistency_low[20000]]
        plt.bar(labels, values)
        plt.title(f'Local Consistency Comparison (n_features={n_features})')
        plt.xlabel('Explanation Method')
        plt.ylabel('Consistency Score')
        plt.savefig(f'local_consistency_comparison_n_features_{n_features}.png')

if __name__ == "__main__":
    main() 

お気軽にご相談ください!

弊社のデジタル技術を活用し、貴社のDXをサポートします。

基本的な設計や実装支援はもちろん、機械学習・データ分析・3D計算などの高度な技術が求められる場面でも、最適なソリューションをご提案します。

初回相談は無料で受け付けております。困っていること、是非お聞かせください。