Graph Neural Networksを用いた異常検知入門

はじめに

近年、グラフニューラルネットワーク(GNN)は、グラフ構造を持つデータの分析において大きな注目を集めています。本記事では、GNNを用いた異常検知の基本的な実装方法について、架空の取引ネットワークデータを例に解説します。

この記事を書いたひと

デジタルリアクタ合同会社 代表
機械学習・統計、数値計算などの領域を軸としたソフトウェアエンジニアリングを専門としています。スタートアップからグローバル企業まで、さまざまなスケールの企業にて、事業価値に直結する計算システムを設計・構築してきました。主に機械学習の応用分野において、出版・発表、特許取得等の実績があり、また、IT戦略やデータベース技術等の情報処理に関する専門領域の国家資格を複数保有しています。九州大学理学部卒業。日本ITストラテジスト協会正会員。

対象読者:

  • グラフ構造データに対する機械学習に興味がある方
  • 金融取引における不正検知に応用できる技術を探している方
  • GNNの実装を通して、その動作原理を理解したい方

記事のポイント:

  • GNN、特にGraph Convolutional Network (GCN) の基礎と、異常検知への応用を解説
  • 架空の取引ネットワークデータを生成し、GCNによる異常検知モデルを構築・評価
  • エッジ情報(取引関係)の有無による性能比較(GCN vs MLP)を行い、GCNの有効性を示唆
  • モデルの実装にはPyTorch Geometricライブラリを使用

グラフニューラルネットワークの基礎

GNNとは

グラフニューラルネットワーク(Graph Neural Network, GNN)は、グラフ構造を持つデータを直接処理できるニューラルネットワークです。従来のニューラルネットワークでは扱うことのできなかった、ノード間の関係性やグラフ全体の構造を考慮した学習が可能です。

GNNは、各ノードの特徴量を、そのノードの近傍情報を使って更新します。これにより、グラフの構造を考慮した学習が可能になります。応用例としては、ソーシャルネットワーク分析、推薦システム、化学構造解析などがあります。

GCNの動作原理

GCN(Graph Convolutional Network)は、GNNの一種であり、基本的な動作原理は以下の3つのステップで構成されています。

  1. 特徴量の集約: 各ノードは、エッジで接続された隣接ノードの特徴量を収集します。これにより、局所的な構造を学習します。
  2. 特徴量の更新: 集約された情報を基に、自身の特徴量を更新します。非線形変換を適用することで、より豊かな表現を獲得します。
  3. 多層での処理: 上記の処理を複数層重ねることで、より広範な構造を学習します。本実装では2層のGCNを使用し、2ホップ先までの情報を考慮します。

これらの処理は以下の数式で表されます。

H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})

この数式は、GCNの核心部分を表しています。各記号の意味は以下のテーブルの通りです。

記号 説明
H^{(l)} l 層目のノードの特徴量を集めた行列。各行が各ノードの特徴ベクトルに対応します。
\tilde{A} 自己ループを加えた隣接行列。グラフの接続構造を表します。自己ループは、各ノードが自分自身の情報も考慮することを意味します。
\tilde{D} \tilde{A} の次数行列。各ノードに接続するエッジの数を表します(自己ループも含む)。
W^{(l)} 学習可能な重み行列。l 層目の学習パラメータであり、特徴量の変換を行います。
\sigma 非線形活性化関数。ReLUなどが用いられ、ニューラルネットワークに非線形性をもたらします。

この式は、各ノードの新しい特徴量を、隣接ノードの特徴量と自身の特徴量を組み合わせて計算することを示しています。\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}} の部分は、隣接ノードからの情報を集約し、正規化する役割を果たします。

取引ネットワークデータの生成

今回は仮想データとして、次の画像のような構造のグラフを生成します。

ノードの特徴量設計

取引ネットワークにおける各ノードは取引主体を表し、以下の4次元の特徴量を持ちます。

  1. 取引金額の平均
  2. 取引頻度
  3. 取引先数
  4. 取引の時間的分散

これらの特徴量は、各取引主体の取引行動を数値化したものです。これらを組み合わせることで、各取引主体の行動パターンを表現し、異常検知に利用します。

正常ノードと異常ノードの特徴

正常ノードの特徴

正常ノードの特徴量は、一般的な取引パターンを反映するように設計されています。取引金額は平均的で、取引頻度も一定範囲内、取引先も限定的で、取引の時間的なばらつきも少ないという特徴があります。

# 正常ノードの特徴量生成
normal_features = torch.zeros(n_normal, 4)
normal_features[:, 0] = torch.normal(mean=1000.0, std=200.0, size=(n_normal,))  # 取引金額
normal_features[:, 1] = torch.normal(mean=10.0, std=2.0, size=(n_normal,))      # 取引頻度
normal_features[:, 2] = torch.normal(mean=5.0, std=1.0, size=(n_normal,))       # 取引先数
normal_features[:, 3] = torch.normal(mean=2.0, std=0.5, size=(n_normal,))       # 時間的分散

上記のコードでは、各特徴量に対して平均値と標準偏差を指定し、正規分布に従う乱数を生成しています。各特徴量の具体的な設定は以下のテーブルの通りです。

特徴量 平均値 標準偏差 説明
取引金額 1000.0 200.0 一般的な取引規模、適度なばらつき
取引頻度 10.0 2.0 1日あたりの平均取引回数
取引先数 5.0 1.0 定常的な取引関係の数
時間的分散 2.0 0.5 取引タイミングの規則性

異常ノードの特徴

異常ノードは、3つの異なるパターンを実装しました。各パターンは、特徴的な不正取引のシナリオを表現しています。

# 異常ノードの特徴量生成(抜粋)
anomaly_features = torch.zeros(n_anomaly, 4)
n_pattern1 = n_anomaly // 3

# パターン1: 大口取引パターン
anomaly_features[:n_pattern1, 0] = torch.normal(mean=5000.0, std=1000.0, size=(n_pattern1,))

各異常パターンの特徴と、想定される不正シナリオは以下のテーブルの通りです。

パターン 特徴量の変更 想定シナリオ
大口取引 取引金額: 5000.0±1000.0 マネーロンダリング、不正な資金移動
多数取引 取引頻度: 30.0±5.0
取引先数: 15.0±3.0
分散型の不正送金、取引分割による規制回避
不規則取引 時間的分散: 8.0±2.0 自動化された不正取引、営業時間外取引

エッジの生成ロジック

エッジは、取引主体間の関係性を表します。正常な取引関係と異常な取引関係を区別するために、エッジの生成ロジックにも工夫を凝らしました。

正常ノードのエッジ生成

正常ノード間のエッジは、コミュニティ構造を形成するように生成されます。つまり、似たような取引パターンを持つノード同士が繋がりやすくなっています。

# 正常ノード間のエッジ生成(抜粋)
for i in range(n_normal):
    n_edges = int(torch.normal(float(base_edges_per_node), 1.0, size=(1,)).item())
    for _ in range(n_edges):
        j = int(torch.normal(mean=torch.tensor(float(i)), std=torch.tensor(float(n_normal/5))).item()) % n_normal

異常ノードのエッジ生成

異常ノードのエッジ生成は、各異常パターンに応じて異なるロジックを適用しています。

パターン エッジ数 接続先の選択方法
大口取引 固定2本 正常ノードからランダム
多数取引 固定10本 正常ノードからランダム
不規則取引 1-7本 全ノードからランダム

GCNによる異常検知モデル

モデルの実装

本記事では、PyTorch Geometricライブラリを用いてGCNを実装します。GCNの層を2層重ねることで、ノードの特徴量をより効果的に集約し、異常検知の精度を高めます。

class GCNAnomalyDetector(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

異常検知の仕組み

GCNは、ノードの特徴量とエッジの情報(接続関係)を組み合わせて、異常を検出します。

  1. 構造的特徴の学習: エッジパターンから取引関係の構造を学習し、正常な取引ネットワークの特徴を把握します。
  2. 複合的な判断: ノードの特徴量(取引金額、頻度など)とエッジパターン(接続数、接続先の特性)の両方の情報を組み合わせて総合的に判断します。

エッジ情報の重要性

エッジ情報の有効性を検証するため、同じデータセットに対してGCN(エッジ情報あり)とMLP(エッジ情報なし)の2つのモデルで比較実験を行いました。

実験設定

項目 説明
ノード数 300(訓練・テストそれぞれ)
異常ノード比率 10%

実験結果

評価指標 GCN(エッジ情報あり) MLP(エッジ情報なし)
テストデータAUC 0.8048 0.7004


考察

実験結果から、今回の仮想データにおいては、GCNがMLPよりも高い性能を示すことがわかりました。この結果は、エッジ情報が異常検知において重要な役割を果たしていることを示唆しています。

  1. 汎化性能の向上: GCNモデルはテストデータでより高いAUCスコア(0.8048)を達成しました。エッジ情報を活用することで、未知のデータに対する異常パターンの検出能力が向上したと考えられます。
  2. エッジパターンの有効性: 特徴量だけでは判別が難しい異常を、エッジパターンの違いから検出できます。
  3. コミュニティ構造の利用: コミュニティ構造を活用することで、より高度な異常検知が可能になります。

まとめ

本記事では、GCNを用いて取引ネットワークにおける異常検知を実装しました。実験の結果から、以下の知見が得られました。

  • エッジ情報を活用するGCNモデルは、特徴量のみを使用するMLPモデルと比較して、より高い汎化性能を示しました。
  • コミュニティ構造を考慮することで、より豊かな文脈情報を活用した異常検知が可能になっていると考えられます。

コード

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, roc_auc_score
import seaborn as sns
import sklearn.metrics
import argparse

# 引数パーサーの設定
parser = argparse.ArgumentParser(description='GCNを用いた異常検知')
parser.add_argument('--no-edge-info', action='store_true',
                   help='エッジ情報を使用しない(MLPモードで実行)')
args = parser.parse_args()

# モード名の設定(ファイル名用)
mode = "mlp" if args.no_edge_info else "gcn"

# 乱数シードを固定
torch.manual_seed(102)
np.random.seed(102)

def generate_transaction_network(n_nodes=100, base_edges_per_node=3, anomaly_ratio=0.1, use_edge_info=True):
    """取引ネットワークを生成する関数

    Parameters:
    -----------
    n_nodes : int
        ノードの総数
    base_edges_per_node : int
        1ノードあたりの基本エッジ数
    anomaly_ratio : float
        異常ノードの割合
    use_edge_info : bool
        エッジ情報を使用するかどうか
    """
    n_normal = int(n_nodes * (1 - anomaly_ratio))
    n_anomaly = n_nodes - n_normal

    # 正常ノードの特徴量生成(取引パターンを表現)
    # 特徴量の差を小さくするため、標準偏差を大きくする
    normal_features = torch.zeros(n_normal, 4)
    normal_features[:, 0] = torch.normal(mean=1000.0, std=500.0, size=(n_normal,))  # 取引金額
    normal_features[:, 1] = torch.normal(mean=10.0, std=5.0, size=(n_normal,))      # 取引頻度
    normal_features[:, 2] = torch.normal(mean=5.0, std=3.0, size=(n_normal,))       # 取引先数
    normal_features[:, 3] = torch.normal(mean=2.0, std=1.0, size=(n_normal,))       # 時間的分散

    # 異常ノードの特徴量生成(異常パターンを表現)
    # 特徴量を正常ノードに近づける
    anomaly_features = torch.zeros(n_anomaly, 4)
    # パターン1: やや大きな取引金額
    n_pattern1 = n_anomaly // 3
    anomaly_features[:n_pattern1, 0] = torch.normal(mean=2000.0, std=500.0, size=(n_pattern1,))
    anomaly_features[:n_pattern1, 1:] = normal_features[:n_pattern1, 1:]

    # パターン2: やや多い取引頻度と取引先
    n_pattern2 = n_anomaly // 3
    anomaly_features[n_pattern1:n_pattern1+n_pattern2, 0] = normal_features[:n_pattern2, 0]
    anomaly_features[n_pattern1:n_pattern1+n_pattern2, 1] = torch.normal(mean=15.0, std=5.0, size=(n_pattern2,))
    anomaly_features[n_pattern1:n_pattern1+n_pattern2, 2] = torch.normal(mean=8.0, std=3.0, size=(n_pattern2,))
    anomaly_features[n_pattern1:n_pattern1+n_pattern2, 3] = normal_features[:n_pattern2, 3]

    # パターン3: やや大きな時間的分散
    n_pattern3 = n_anomaly - n_pattern1 - n_pattern2
    anomaly_features[n_pattern1+n_pattern2:, 0:3] = normal_features[:n_pattern3, 0:3]
    anomaly_features[n_pattern1+n_pattern2:, 3] = torch.normal(mean=4.0, std=1.0, size=(n_pattern3,))

    # 特徴量を結合
    x = torch.cat([normal_features, anomaly_features], dim=0)

    # 特徴量の正規化
    mean = x.mean(dim=0, keepdim=True)
    std = x.std(dim=0, keepdim=True)
    x = (x - mean) / std

    # エッジの生成(取引関係を表現)
    if use_edge_info:
        edge_list = []

        # 正常ノード間のエッジ生成(コミュニティ構造を強化)
        communities = [[] for _ in range(5)]  # 5つのコミュニティを作成
        nodes_per_community = n_normal // 5

        # ノードをコミュニティに割り当て
        for i in range(n_normal):
            community_idx = i // nodes_per_community
            if community_idx < 5:  # 余りのノードは最後のコミュニティに
                communities[community_idx].append(i)
            else:
                communities[4].append(i)

        # コミュニティ内のエッジを生成(密な接続)
        for community in communities:
            for i in community:
                # コミュニティ内で密な接続
                n_edges = int(torch.normal(float(base_edges_per_node * 2), 1.0, size=(1,)).item())
                for _ in range(n_edges):
                    j = np.random.choice(community)
                    if i != j:
                        edge_list.append([i, j])
                        edge_list.append([j, i])

                # コミュニティ間の疎な接続
                if torch.rand(1).item() < 0.3:  # 30%の確率で他のコミュニティとも接続
                    other_community = communities[np.random.randint(0, 5)]
                    j = np.random.choice(other_community)
                    if i != j:
                        edge_list.append([i, j])
                        edge_list.append([j, i])

        # 異常ノード間のエッジ生成(より特徴的なパターン)
        for i in range(n_normal, n_nodes):
            pattern_type = (i - n_normal) // (n_anomaly // 3)
            if pattern_type == 0:  # 大きな取引金額のパターン
                # 複数のコミュニティと接続
                target_communities = np.random.choice(5, 2, replace=False)  # 2つのコミュニティを選択
                for comm_idx in target_communities:
                    for _ in range(2):  # 各コミュニティから2つのノードと接続
                        j = np.random.choice(communities[comm_idx])
                        edge_list.append([i, j])
                        edge_list.append([j, i])

            elif pattern_type == 1:  # 多数の取引先パターン
                # 1つのコミュニティに集中して接続
                target_community = np.random.randint(0, 5)
                n_edges = 8
                connected_nodes = set()
                for _ in range(n_edges):
                    j = np.random.choice(communities[target_community])
                    if j not in connected_nodes:
                        connected_nodes.add(j)
                        edge_list.append([i, j])
                        edge_list.append([j, i])

            else:  # 不規則な取引パターン
                # ランダムなコミュニティと接続
                n_edges = np.random.randint(1, 4)  # より少ない接続数
                for _ in range(n_edges):
                    target_community = np.random.randint(0, 5)
                    j = np.random.choice(communities[target_community])
                    edge_list.append([i, j])
                    edge_list.append([j, i])

        edge_index = torch.tensor(edge_list).t()
    else:
        # エッジ情報を使用しない場合は、自己ループのみ設定
        edge_index = torch.stack([torch.arange(n_nodes), torch.arange(n_nodes)], dim=0)

    # ラベルの生成
    y = torch.zeros(n_nodes)
    y[n_normal:] = 1

    return Data(x=x, edge_index=edge_index, y=y)

class GCNAnomalyDetector(torch.nn.Module):
    """GCNを用いた異常検知モデル"""
    def __init__(self, in_channels, hidden_channels, out_channels, use_edge_info=True):
        super().__init__()
        if use_edge_info:
            # GCNモード:エッジ情報を使用
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
        else:
            # MLPモード:エッジ情報を使用しない
            self.conv1 = torch.nn.Linear(in_channels, hidden_channels)
            self.conv2 = torch.nn.Linear(hidden_channels, out_channels)
        self.use_edge_info = use_edge_info

    def forward(self, x, edge_index):
        if self.use_edge_info:
            # GCNモード
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.conv2(x, edge_index)
        else:
            # MLPモード
            x = self.conv1(x)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.conv2(x)
        return torch.sigmoid(x)

def evaluate_model(model, data):
    """モデルの評価を行う関数"""
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index)
        pred_probs = pred.squeeze()  # 予測確率
        pred_labels = (pred_probs > 0.5).float()  # 予測ラベル

        # 評価指標の計算
        accuracy = accuracy_score(data.y.numpy(), pred_labels.numpy())
        precision = precision_score(data.y.numpy(), pred_labels.numpy())
        recall = recall_score(data.y.numpy(), pred_labels.numpy())
        auc = roc_auc_score(data.y.numpy(), pred_probs.numpy())
        conf_matrix = confusion_matrix(data.y.numpy(), pred_labels.numpy())

        return accuracy, precision, recall, auc, conf_matrix, pred_labels

# トレーニングデータとテストデータの生成
train_data = generate_transaction_network(n_nodes=300, base_edges_per_node=3, use_edge_info=not args.no_edge_info)
test_data = generate_transaction_network(n_nodes=300, base_edges_per_node=3, use_edge_info=not args.no_edge_info)

# モデルの初期化と学習
model = GCNAnomalyDetector(in_channels=4, hidden_channels=16, out_channels=1, use_edge_info=not args.no_edge_info)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print(f"\nモデルモード: {'MLP(エッジ情報なし)' if args.no_edge_info else 'GCN(エッジ情報あり)'}")
print("モデルの学習を開始します...")

# 学習ループ
train_losses = []
test_losses = []
for epoch in range(200):
    # 訓練
    model.train()
    optimizer.zero_grad()
    out = model(train_data.x, train_data.edge_index)
    train_loss = F.binary_cross_entropy(out.squeeze(), train_data.y)
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # テストデータでの損失計算
    model.eval()
    with torch.no_grad():
        test_out = model(test_data.x, test_data.edge_index)
        test_loss = F.binary_cross_entropy(test_out.squeeze(), test_data.y)
        test_losses.append(test_loss.item())

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}/200: Train Loss = {train_loss.item():.4f}, Test Loss = {test_loss.item():.4f}")

print("\n学習が完了しました。評価を開始します...")

# 学習曲線の可視化
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
plt.title(f'Training and Test Loss ({mode.upper()})')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()
plt.savefig(f'learning_curves_{mode}.png')
plt.close()

# モデルの評価
train_metrics = evaluate_model(model, train_data)
test_metrics = evaluate_model(model, test_data)

# 評価結果の表示と保存
def plot_confusion_matrix(conf_matrix, title, filename):
    plt.figure(figsize=(8, 6))
    if conf_matrix.size > 0:  # 混同行列が空でないことを確認
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
        plt.title(f'{title} ({mode.upper()})')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(f'generated/gnn-anomaly-detection/{filename}_{mode}.png')
    plt.close()

# 訓練データの混同行列
plot_confusion_matrix(train_metrics[4], 'Confusion Matrix (Training Data)', 'train_confusion_matrix')

# テストデータの混同行列
plot_confusion_matrix(test_metrics[4], 'Confusion Matrix (Test Data)', 'test_confusion_matrix')

# ネットワークの可視化(テストデータ)
def visualize_network(data, pred_labels, title, filename):
    if not args.no_edge_info:  # エッジ情報がある場合のみ可視化
        G = nx.Graph()

        # すべてのノードを追加
        for i in range(len(data.y)):
            G.add_node(i)

        # エッジを追加
        edge_list = data.edge_index.t().tolist()
        G.add_edges_from(edge_list)

        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(G)

        # データをNumPy配列に変換
        true_labels = data.y.numpy()
        pred_labels = (pred_labels > 0.5).numpy()  # 予測確率を二値ラベルに変換

        # 正常ノードと異常ノードを別々に描画(予測と実際のラベルを比較)
        true_normal = np.where((true_labels == 0) & (pred_labels == 0))[0]  # 正常を正常と予測
        true_anomaly = np.where((true_labels == 1) & (pred_labels == 1))[0]  # 異常を異常と予測
        false_normal = np.where((true_labels == 1) & (pred_labels == 0))[0]  # 異常を正常と予測
        false_anomaly = np.where((true_labels == 0) & (pred_labels == 1))[0]  # 正常を異常と予測

        nx.draw_networkx_nodes(G, pos, nodelist=true_normal, 
                              node_color='lightblue', node_size=100, label='True Normal')
        nx.draw_networkx_nodes(G, pos, nodelist=true_anomaly,
                              node_color='red', node_size=100, label='True Anomaly')
        nx.draw_networkx_nodes(G, pos, nodelist=false_normal,
                              node_color='orange', node_size=100, label='Missed Anomaly')
        nx.draw_networkx_nodes(G, pos, nodelist=false_anomaly,
                              node_color='purple', node_size=100, label='False Anomaly')
        nx.draw_networkx_edges(G, pos, alpha=0.2)

        plt.title(f'{title} ({mode.upper()})')
        plt.legend()
        plt.savefig(f'{filename}_{mode}.png')
        plt.close()

# トレーニングデータとテストデータのネットワーク可視化
visualize_network(train_data, train_metrics[5], 'Training Data Network', 'train_network')
visualize_network(test_data, test_metrics[5], 'Test Data Network', 'test_network')

# 評価結果をファイルに保存
with open(f'evaluation_results_{mode}.txt', 'w') as f:
    f.write(f'Model: {mode.upper()}\n\n')
    f.write('Training Data Metrics:\n')
    f.write(f'Accuracy: {train_metrics[0]:.4f}\n')
    f.write(f'Precision: {train_metrics[1]:.4f}\n')
    f.write(f'Recall: {train_metrics[2]:.4f}\n')
    f.write(f'AUC: {train_metrics[3]:.4f}\n\n')

    f.write('Test Data Metrics:\n')
    f.write(f'Accuracy: {test_metrics[0]:.4f}\n')
    f.write(f'Precision: {test_metrics[1]:.4f}\n')
    f.write(f'Recall: {test_metrics[2]:.4f}\n')
    f.write(f'AUC: {test_metrics[3]:.4f}\n')

# ROC曲線の描画
def plot_roc_curve(model, data, title, filename):
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index)
        pred_probs = pred.squeeze().numpy()
        true_labels = data.y.numpy()

        fpr, tpr, _ = sklearn.metrics.roc_curve(true_labels, pred_probs)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc_score(true_labels, pred_probs):.4f})')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'{title} ({mode.upper()})')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.savefig(f'{filename}_{mode}.png')
        plt.close()

# 訓練データとテストデータのROC曲線を描画
plot_roc_curve(model, train_data, 'ROC Curve (Training Data)', 'train_roc_curve')
plot_roc_curve(model, test_data, 'ROC Curve (Test Data)', 'test_roc_curve')

print("\n評価結果:")
print(f"モデル: {mode.upper()}")
print(f"訓練データ - 精度: {train_metrics[0]:.4f}, 適合率: {train_metrics[1]:.4f}, 再現率: {train_metrics[2]:.4f}, AUC: {train_metrics[3]:.4f}")
print(f"テストデータ - 精度: {test_metrics[0]:.4f}, 適合率: {test_metrics[1]:.4f}, 再現率: {test_metrics[2]:.4f}, AUC: {test_metrics[3]:.4f}")

お気軽にご相談ください!

弊社のデジタル技術を活用し、貴社のDXをサポートします。

基本的な設計や実装支援はもちろん、機械学習・データ分析・3D計算などの高度な技術が求められる場面でも、最適なソリューションをご提案します。

初回相談は無料で受け付けております。困っていること、是非お聞かせください。