機械学習

3分でわかる!PyTorch Lightningで機械学習をより簡単に実装!

一般的に機械学習モデルの実装や訓練は複雑になりますよね。

研究や業務に集中したいけど、

プログラムを書くのに多くの時間がかかってしまう

と感じる方も多いのではないでしょうか?

この記事では、効率的に機械学習モデルを実装する手助けをしてくれる強力なツールであるPyTorch Lightningの特徴や実装する方法を詳しく解説します。PyTorch Lightningを活用して、機械学習の実装を効率化しましょう

PyTorch Lightningとは

PyTorch Lightningは代表的な機械学習ライブラリであるPyTorchをベースとした、モデルの訓練や開発を効率的に行うためのライブラリです。

概要と特徴

機械学習のプログラムは一般的に複雑であり、短期間で業務や卒業研究で利用するのは困難です。一方、PyTorch Lightningでは機械学習モデルの構築をより簡単に実装でき、ビジネシスや研究に集中できるメリットがあります。

PyTorch Lightningは2019年5月、William Falcon氏によって初めて公開され、PyTorchの軽量ラッパーとして世界中のエンジニアから注目を集めました。

データサイエンスに集中するため、コーディングのコストを最小限に抑えることをコンセプトに開発され、指定されたコードの書き方があります。誰が書いても共通する点はPythonと同じく、メンテナンス性の向上が期待できますね。

モジュール設計

PyTorch Lightningはモジュール化された設計を採用しており、いくつかのコンポーネントに分割されます。例として、3つの主な部品を見てみましょう。

  • Trainer
  • LightningModule
  • LightningDataModule

名前から推測できるかもしれませんが、例えばモデルの訓練・検証・テストを制御するTrainerを用いると、訓練ループを全て記述する必要がなくなり、簡潔にモデルの学習を実装できます

PyTorchとは異なり、LightningModuleではモデルの定義のみならず、損失関数や最適化手法、さらにはトレーニングループまでもまとめて管理します。

また、LightningDataModuleを用いることで、データの読み込みや前処理、検証用とテスト用の分割などを効率的に行えるため、コード全体的がコンパクトになる点が特徴的です。

PyTorch LightningはベースとなるPyTorchと同じく、Define-by-run形式を採用しています。ここではDefine-and-runとの違いを抑えておきましょう。

Define-by-run

  • 実行中にモデルの計算グラフが構築される
  • 演算処理は実行時に計算グラフに追加される
  • 動的な制御や条件分岐が容易なため、柔軟なモデルの構築が可能
  • PyTorch, Chainerなどで採用

Define-and-run

  • 事前にモデルの計算グラフが定義される
  • 定義後は変更することができない
  • 処理の高速化が期待できる
  • TensorFlow 1.xやTheanoなどで採用

PyTorch Lightningの学習方法

PyTorch Lightningの解説やサンプルコードは他の代表的な機械学ライブラリを比較すると少ない傾向があるため、どのような方法で学習を進めるかが大切です

注意点

  • PyTorchの基本的な理解
  • 日本語の情報が少ない

PyTorchをベースとした設計であるため、PyTorchでの基本的な学習ループを理解する必要があります。さらに、PyTorch Lightningに関する日本語での情報は少ないことから、PyTorchをある程度学習することは理解を深めることに繋がるでしょう。

Lightningモジュールは他の機械学習ライブラリでは見かけない特有のtraining_stepconfigure_optimizersなどのコードを実装することがあります。幅広い情報を得るために、英語での文献も活用して学習することが大切です。

公式チュートリアル

  • 情報量が多い
  • 正確な情報がまとめられている
  • 英語で記載されている

PyTorch Lightningは公式ドキュメントが用意されており、正確なプログラムの記述法がまとめられています。ただし、最低限の情報で簡潔に記載するための情報ではないため、初学者にとっては理解が難しいかもしれません。

Webサイトの記事を検索し、気になる箇所について公式ドキュメントを活用すると効率的な学習ができるでしょう。英語での説明ですが、翻訳機を使って読み進めることがおすすめです。

PyTorch Lightningの公式チュートリアルはこちら↓

https://lightning.ai/docs/pytorch/latest/common/lightning_module.html

GitHubやQiitaでの記事

プログラムの倉庫とも言えるGitHubやQiitaの記事ではPyTorch Lightningでのサンプルコードが見つかる可能性が大きいです。実際にライブラリを使い方を見ると、「この場合は〇〇だな」のように幅広いパターンを理解できるため、積極的に活用しましょう。

GitHubやQiitaでの機械学習プログラムは必ずしも正確とは限りません。バージョンごとに記述方法の違い等があるため、実行時にエラーが出た場合は検索したり、公式ドキュメントを参考にして解決する必要があります。

PyTorch Lightningでの実装例

実際にCNNモデルの実装を通して、PyTorch Lightningでの記述がどれほど変化するのかを体験しましょう。基本的にモデルの構成はPyTorchと変わらないため、定義自体はPyTorchライブラリを参考にできます。

ライブラリのインストール

Windowsの方はコマンドプロンプト、Macの方はターミナルを開き、次のコマンドを実行してライブラリをインストールしましょう。PyTorch Lightningライブラリは、Pythonでインポートする際にはplとして読み込まれることがあります。

pip install pytorch_lightning

MNISTを学習するモデルの構築

以下は、手描き数字で有名なMNISTデータセットを学習するCNNモデルをPyTorch Lightningライブラリを用いて実装するプログラム例と実行結果です。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl

# モデルの定義
class SimpleCNN(pl.LightningModule):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

# データローダーの設定
class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        self.mnist_train = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
        self.mnist_val = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)

data_module = DataModule(batch_size=64)
model = SimpleCNN()

# CUDAが利用可能なら使う
if torch.cuda.is_available():
    model = model.cuda()

# モデルの訓練
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, datamodule=data_module)

「実行結果」

PyTorch Lightningでの実行結果

視覚的にモデルの訓練が進む過程が理解でき、層の構成やパラメータもまとめて表示されます。特にモデルの学習ループの記述が必要なく、全体的にスッキリとして印象を受けたのではないでしょうか?この簡潔さがPyTorch Lightningの魅力です

PyTorchとの違い

PyTorchをより簡潔に実装できるライブラリとのことですが、具体的にどれほどの差があるかを確認してみましょう。PyTorch Lightningでは省略できた訓練ループをPyTorchで実装した場合は以下のようになります。

# PyTorchでの訓練ループ
num_epochs = 5
for epoch in range(num_epochs):
    # モデルの訓練
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_loss:.4f}")

PyTorch Lightningでは上記のループを書く必要がなく、とてもコンパクトなコードでモデルの訓練や検証を実行できます。機械学習をより簡潔に実装するためにも、積極的に活用していきましょう。

まとめ

PyTorch Lightningライブラリを利用すると複雑な機械学習のモデルを比較的シンプルに実装できます。ある程度、機械学習の訓練ループやPyTorchに関する知識を持つ方や、短期間で卒業研究や業務に機械学習を活用したい方は是非、PyTorch Lightningを用いて深層学習を実装してみましょう。