Pytorch Lightning の使い方
2023/10/18
AI
目次
PyTorch Lightning
PyTorch Lightning は、PyTorch をより効率的に使えるようにする軽量ライブラリです。以下に基本的な使い方を説明します。
インストール
まず PyTorch Lightning をインストールします:
pip install lightning
conda install lightning -c conda-forge
モデルの定義
LightningModule を継承してモデルを定義します。以下は簡単な例です:
import pytorch_lightning as pl
# LightningModule を継承したクラス定義
class MyLightningModule(L.LightningModule):
def __init__(self, model, criterion, lr):
super().__init__()
self.model = model # モデル
self.lr = lr # 学習率
self.criterion = criterion # 損失関数
# 学習ステップの定義
def training_step(self, batch, batch_idx):
inputs, labels = batch
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
self.log("train_loss", loss, prog_bar=True) # 学習損失の記録
return loss
# 検証ステップの定義
def validation_step(self, batch, batch_idx):
inputs, labels = batch
outputs, _ = self.model(inputs) # モデルによる予測
loss = self.criterion(outputs, labels)
# 正解率(accuracy)を計算
acc = torch.sum(torch.argmax(outputs, dim=1) == labels).item() / len(labels)
self.log("val_loss", loss, prog_bar=True) # 検証損失の記録
self.log("val_acc", acc, prog_bar=True) # 検証精度の記録
return loss
# オプティマイザと学習率スケジューラーの定義
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.model.parameters(),
self.lr,
betas=(0.9, 0.999),
)
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=-1,
gamma=0.1,
)
return [optimizer], [scheduler]
- ここで、
self.log("val_acc", acc, prog_bar=True)などのself.logは、学習中に記録したい値を指定します。これにより、TensorBoard や CSVLogger などで確認できます。prog_barを True にすると、進捗バーで確認できます。 configure_optimizers()でオプティマイザと学習率スケジューラーを定義します。ここでは、AdamW を使用していますが、任意のオプティマイザや学習率スケジューラーを指定。
各メソッドの役割
training_step: 1 epoch あたりの学習ステップで呼ばれます。ここで、損失を計算して返します。validation_step: 1 epoch あたりの検証ステップで呼ばれます。ここで、損失と正解率を計算して返します。test_step: 1 epoch あたりのテストステップで呼ばれます。ここで、損失と正解率を計算して返します。configure_optimizers: オプティマイザと学習率スケジューラーを定義します。これは、学習開始時に一度だけ呼ばれます。
データセットとデータローダーの準備
pytorch の dataloader を使う場合
# ## build dataset
train_dataset = ImageFolder(
root=data_root / train_data_path,
transform=train_transformer,
)
val_dataset = ImageFolder(
root=data_root / val_data_path,
transform=val_transformer,
)
# ## build dataloader
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=False,
)
LightningDataModule を使う場合
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class MyDataModule(LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
self.transform = transforms.ToTensor()
def prepare_data(self):
# データのダウンロードなど、1回だけ実行される処理
MNIST(root='./data', train=True, download=True)
MNIST(root='./data', train=False, download=True)
def setup(self, stage=None):
# データセットの分割や前処理を定義
full_dataset = MNIST(root='./data', train=True, transform=self.transform)
self.train_dataset, self.val_dataset = random_split(full_dataset, [55000, 5000])
self.test_dataset = MNIST(root='./data', train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size)
各メソッドの役割
| メソッド名 | 説明 |
|---|---|
prepare_data() |
データのダウンロードなど、1 度だけ実行される処理(マルチ GPU 環境でも 1 回のみ) |
setup() |
学習/検証/テストデータセットの作成・分割を行う(分散環境では各ワーカーで呼ばれる) |
train_dataloader() |
学習用のデータローダーを返す |
val_dataloader() |
検証用のデータローダーを返す |
test_dataloader() |
テスト用のデータローダーを返す |
トレーナーの作成と学習の実行
trainer = pl.Trainer(
max_epochs=10,
accelerator="auto",
devices="auto",
logger=True,
log_every_n_steps=10,
default_root_dir="./logs"
)
主な引数一覧
| 引数名 | 説明 |
|---|---|
max_epochs |
学習する最大エポック数 |
accelerator |
使用するデバイス ("auto" / "cpu" / "gpu" / "tpu" など) |
devices |
使用するデバイス数 or 番号指定(例: 1, [0,1]) |
logger |
ロガーを使用するかどうか(True または TensorBoardLogger などのインスタンス) |
callbacks |
コールバック関数のリスト(例: EarlyStopping, ModelCheckpoint) |
log_every_n_steps |
何ステップごとにログ出力を行うか |
default_root_dir |
チェックポイント保存先ディレクトリ |
precision |
学習精度(32, 16(FP16), bf16 など) |
fast_dev_run |
実験用:True にすると 1 バッチだけ実行される |
主なメソッド
| メソッド名 | 説明 |
|---|---|
trainer.fit() |
学習+検証を実行 |
trainer.validate() |
検証のみを実行 |
trainer.test() |
テストのみを実行 |
trainer.predict() |
推論を実行(予測結果を得る) |
コードの例
import lightning as L
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# コールバックの追加例(早期終了、モデル保存など)
callbacks = [
EarlyStopping(monitor="val_acc", patience=10),
ModelCheckpoint(monitor="val_acc", save_top_k=5)
]
# Trainer のインスタンス化
trainer = pl.Trainer(
max_epochs=10,
accelerator="auto",
devices="auto",
logger=True,
callbacks=callbacks,
log_every_n_steps=10,
default_root_dir="./logs"
)
# 学習 + 検証
trainer.fit(model, train_dataloader, val_dataloader)
# 検証のみ
trainer.validate(model, val_dataloader)
# テストのみ
trainer.test(model, test_dataloader)
- ここで、EarlyStopping とは、モデルのパラメータを保存するためのコールバックです。EarlyStopping は、指定したエポック数以上にモデルのパラメータが更新されない場合、モデルのパラメータを保存します。つまり、EarlyStopping は、モデルのパラメータを保存するためのコールバックです。
- ここで、ModelCheckpoint とは、モデリングのパラメータを保存するためのコールバックです。ModelCheckpoint は、指定したエポック数以上にモデルのパラメータが更新される場合、モデルのパラメータを保存します。
monitor="val_acc"は、EarlyStopping と ModelCheckpoint が監視する指標を指定します。"val_acc"は、検証精度を指し、MyLightningModuleに定義したvalidation_step()にself.log("val_acc", acc, prog_bar=True)で記録した検証精度を監視します。trainer.fitは、モデルのtraining_stepを呼び、もしval_dataloaderが指定されていればvalidation_stepも呼びます。つまり、モデルのtraining_stepとvalidation_stepが定義べきです。trainer.validateは、モデルのvalidation_stepを呼びます。trainer.testは、モデルのtest_stepを呼びます。default_root_dirは、チェックポイントとログを保存するディレクトリです。
pretrained model を使う
torch.load を使う場合
チェックポイントをロードするには、torch.loadを使います。
import torch
model = EmotionNet(nc=len(train_dataset.classes))
model.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))
model = MyLightningModule(model, 1e-3, criterion)
LightningModule の関数を使う場合
- load_from_checkpoint
model = MyLightningModule.load_from_checkpoint( "./checkpoint.pth", map_location=torch.device("cpu"), hparams_file=None, model=EmotionNet(nc=len(train_dataset.classes)), lr=1e-3, criterion=criterion, ) - load_state_dict
model = MyLightningModule(EmotionNet(nc=len(train_dataset.classes)), 1e-3, criterion) model.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))