AIが作る心電図解析の未来
2025/02/12 AI Python 活動報告
はじめに
私たちの体にとってなくてはならない臓器、心臓。それは血液を循環させ続けるべく、一定のリズムでドックンドックンと動き続けています。
さて、心臓には心房と心室という部分がある事はご存じの事かと思いますが、ではその内心拍のリズムを生み出しているのはどちらでしょうか?――正解は心房です。心房が収縮リズムの起点となっており、そこから発せられた電気信号が心房自身、そして心室を一定のリズムで動かしているのです。
心房細動とは
そんな心房が正しいリズムを刻めなくなる病気があります。それが「心房細動」と言う病気です。
心房細動とはその名の通り心房が細かく動いている病態を指します。通常の心臓は一分間に60回ほど拍動しているのですが、心房細動が起こると心房は一分間に500回ほどピクピクと動くようになります。(この時、心室はバラバラのリズムで一分間に120回程の回数脈打ちます)
心房細動が起こると心臓の機能が落ちて心不全になったり、心房で滞留した血液が固まり、それが脳に飛んで脳梗塞を起こしたりと、命に係わる疾患を引き起こす可能性があります。心房細動を早期発見・早期治療する事は非常に重要な目標なのです。
さて、心房細動の発見には心電図を用います。心房細動が起きている時の心電図では、心房から発せられる信号(P波)が無数に現れたり、心室から発せられる信号(QRS波)がバラバラのリズムになっていたりするので、少し勉強すればすぐに診断できるようになるでしょう。
しなしながら、このような異常な心電図を呈するのは「今まさに心房細動が起きている患者」です。「さっきまで心房細動が起きていたものの、今は正常(洞調律=sinus rhythm)に戻っている」というような状態で心電図を取った場合、なかなか診断するのは困難です。
そこで「洞調律の心電図でも、どうにかして心房細動を見つけられないか」という目標が生まれました。そしてそれをAIを使って実現した研究があります。
心房細動発見AI
論文紹介
Zachi I Attiaらは「健常者から得られた洞調律の心電図」と「心房細動が見つかった人から得られた洞調律の心電図」を入力データとして与え、それらを区別するAIの開発を行いました。[1]
彼らは180922人の患者から得られた649931個の心電図からなるデータセットを使って学習を行い、その結果非常に高い精度で心房細動を検知できるAIモデルの開発に成功しました。
このようなAIモデルを活用する事で、今までは見過ごされていた患者についても、心房細動のリスクがあると判断して精査を行えるようになるかもしれません。
自分も作ってみたい……!
このような論文を知って、自分も心電図を解析するAIを開発してみたいと思った私は、心電図解析AIを作ってみることにしました。
データセット
無料で利用できる12誘導心電図のデータセットを探していたところ、PTB-XL[2][3]というデータセットを見つけました。このデータセットは18885人の患者、21837個の12誘導心電図のデータからなるデータセットであり、無料でアクセス出来るデータセットの中では最大規模と言える巨大なデータセットです。
加えて、すべての心電図が一定の秒数だけ切り抜かれた物となっており、データの前処理をしなくて済む点も扱いやすいと感じました。
目標
このデータセットでは心電図ごとに「正常」「伝導障害」「心筋梗塞」「心肥大」「ST/Tの変化」の5種類のラベルが付けられており、これを推測するAIを作ってみることにしました。
基礎知識
どのようなAI(ニューラルネットワーク)を組んだのかお話する前に、まずは基礎知識をおさらいしましょう。なお、分かりやすさの為に正確性には欠ける説明になっていますので、ご注意ください。
さて、ニューラルネットワークは「沢山の関数を合成した巨大関数」と言えます。例えば関数f
, g
, h
に以下のような役割があるとします。
関数 | 役割 |
---|---|
f | データの一部を見る |
g | 結果を集計する |
h | データ全体を詳しく分析する |
これら三つの関数を合成した関数h(g(f(input)))
はどのような役割を持つでしょうか? そうです、「まずデータの一部を見て、次にそれを集計して、その結果を詳しく分析する」という関数にな訳です。
このように関数を組み合わせてAIを作るのですが、この関数一つ一つを「層=Layer」と呼びます。
実際にネットで「ニューラルネットワーク layer」「ディープラーニング layer」のように検索してみましょう。沢山の記事が出てくるはずです。
詳しい内容はここでは割愛しますが、今回のモデルで重要になってくるLayerだけ解説しようと思います。それが以下の三つのLayerです。
Layer名(keras) | Layer名(pytorch) | 役割 |
---|---|---|
Conv1D | Conv1d | データの一部を見る |
MaxPool1D | MaxPool1d | 結果を集計する |
Dense | Linear | データ全体を詳しく分析する |
Dropout | Dropout | データをランダムに隠す |
まさに先ほど説明したf
, g
, h
の正式名称が出てきましたね。なお、使うライブラリによって若干名称が異なっていることに注意して下さい。
そして新しくDropout
というよく分からないものが登場しました。これは一部のデータをランダムで隠す事によってAIに「応用的な理解」をさせることが目的の層です。
完成したモデル
実装にはPytorchというライブラリを使いました。必要な物をインポートしておきます。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
続いて「Conv1d=データの一部を見る」と「MaxPool1d=結果を集計する」を組み込んだ関数ConvPool
を用意しました。
ところで、ここにはBatchNorm1d
というLayerも組み込まれているのですが、これは学習時に偶然生まれた「変な偏り」と均す作用がある層です。詳しくは割愛します。
また leaky_relu
という関数は「活性化関数」と呼ばれます。ざっくり言うと「xが正か負かで挙動が変わる」関数です。ひとまず「ConvやLinear(Dense)に作用させると良い」と思っておけばいいかと思います。
class ConvPool(nn.Module):
def __init__(self, in_channel, out_channel):
super(ConvPool, self).__init__()
self.batch_norm = nn.BatchNorm1d(in_channel)
self.conv1 = nn.Conv1d(in_channel, in_channel, 7, padding='same')
self.conv2 = nn.Conv1d(in_channel, out_channel, 7, padding='same')
self.pooling = nn.MaxPool1d(kernel_size=2, stride=2)
def forward(self, x):
x = self.batch_norm(x)
x = F.leaky_relu(self.conv1(x))
x = F.leaky_relu(self.conv2(x))
x = self.pooling(x)
return x
続いて「Linear=データ全体を詳しく分析する」と「Dropout=データをランダムに隠す」を使った関数を作りました。
linear1とlinear2の結果に「leaky_relu
」を作用させているのが見て取れます。
しかしlinear3の直後は違った事を行っています。ひとまずreturn F.sigmoid(x)
の方を見て頂きたいのですが、ここで使われているsigmoid
は「情報を確率に変換する」関数です。最終出力は「この疾患にかかっている確率」ですので、sigmoid
を作用させているのです。
なお、学習中は諸事情によりsigmoid
の計算を後回しにしています。ですので学習中か否かで分岐を行っています。
class DenseLayer(nn.Module):
def __init__(self):
super(DenseLayer, self).__init__()
self.batch_norm = nn.BatchNorm1d(16)
self.flatten = nn.Flatten()
self.linear1 = nn.Linear(16 * 125, 2048)
self.dropout1 = nn.Dropout(0.25)
self.linear2 = nn.Linear(2048, 2048)
self.dropout2 = nn.Dropout(0.25)
self.linear3 = nn.Linear(2048, 5)
def forward(self, x):
x = self.batch_norm(x)
x = self.flatten(x)
x = F.leaky_relu(self.linear1(x))
x = self.dropout1(x)
x = F.leaky_relu(self.linear2(x))
x = self.dropout2(x)
x = self.linear3(x)
if self.training:
return x
else:
return F.sigmoid(x)
最後に上記二つの関数を組み合わせて、AIモデルを作成しました。
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.batch_norm = nn.BatchNorm1d(12)
self.conv_pools = nn.ModuleList([
ConvPool(12, 32),
ConvPool(32, 32),
ConvPool(32, 16)
])
self.dense = DenseLayer()
def forward(self, x):
x = self.batch_norm(x)
for conv_pool in self.conv_pools:
x = conv_pool(x)
x = self.dense(x)
return x
結果
上記モデルを学習させ、出来上がったモデルのROC曲線を描きました。非常にうまく学習が進んでいることが見て取れます。
最後に
実際に心電図を解析するAIを作ってみて、AIに対してより深い理解を出来たと思います。みなさんもぜひ作ってみてください!
参考文献
[1]
Z. I. Attia et al., “An artificial intelligence-enabled ECG algorithm for the identification of patients with atrial fibrillation during sinus rhythm: a retrospective analysis of outcome prediction,” The Lancet, vol. 394, no. 10201, pp. 861–867, Sep. 2019, doi: 10.1016/S0140-6736(19)31721-0.
[2]
Wagner, P., Strodthoff, N., Bousseljot, R., Samek, W. & Schaeffter, T. PTB-XL, a large publicly available electrocardiography dataset. PhysioNet. https://doi.org/10.13026/6sec-a640 (2020).
[3]
Wagner, P., Strodthoff, N., Bousseljot, RD. et al. PTB-XL, a large publicly available electrocardiography dataset. Sci Data 7, 154 (2020). https://doi.org/10.1038/s41597-020-0495-6