timmライブラリ入門 - PyTorch Image Models (timm) ライブラリの紹介と使い方
2023/11/18
AI
目次
概要
timm (PyTorch Image Models) は Ross Wightman 氏によって開発された PyTorch ベースのコンピュータビジョンライブラリ。ImageNet クラスの画像分類モデルを中心に、物体検出、セグメンテーション、特徴抽出などのタスクをサポート。
主な機能
- 200+以上の事前学習済みモデル (ResNet, EfficientNet, ViT など)
- モデルアーキテクチャの柔軟なカスタマイズ
- 特徴抽出の簡易化
- 転移学習の効率化
- モデルアンサンブル機能
インストール
pip install timm
基本的な使い方
モデルのロード
import timm
# 事前学習済みモデルのロード
model = timm.create_model('resnet50', pretrained=True)
# カスタムクラス数に変更
model = timm.create_model('resnet50', pretrained=True, num_classes=10)
# 利用可能なモデル一覧表示
model_names = timm.list_models(pretrained=True)
データ前処理
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
画像分類の訓練例 (CIFAR-10)
# モデル構築
model = timm.create_model('resnet50', pretrained=True, num_classes=10)
# 損失関数と最適化手法
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 訓練ループ
for epoch in range(num_epochs):
for inputs, labels in trainloader:
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
高度な機能
特徴抽出
# 最終層直前の特徴抽出
model = timm.create_model('resnet50', pretrained=True)
features = model.forward_features(input)
# 多スケール特徴抽出
model = timm.create_model('resnet50', features_only=True, pretrained=True)
outputs = model(input) # 各スケールの特徴がリストで返る
モデルアンサンブル
class ModelEnsemble(torch.nn.Module):
def __init__(self, models):
super().__init__()
self.models = torch.nn.ModuleList(models)
def forward(self, x):
return torch.mean(torch.stack([m(x) for m in self.models]), dim=0)
# 使用例
model1 = timm.create_model('resnet18', pretrained=True)
model2 = timm.create_model('vgg16', pretrained=True)
ensemble = ModelEnsemble([model1, model2])
timm.create_model
timm.create_model は、指定したモデル名のモデルをロードし、そのモデルを返す関数。
### ソースコード
def create_model(
model_name: str, # モデル名(例: 'resnet50', 'vit_base_patch16_224')
pretrained: bool = False, # 事前学習済み重みを使用するか
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, # 外部のpretrained_cfgを指定
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, # pretrained_cfgの一部パラメータを上書き
checkpoint_path: Optional[Union[str, Path]] = None, # チェックポイントファイルのパス
cache_dir: Optional[Union[str, Path]] = None, # キャッシュディレクトリ
scriptable: Optional[bool] = None, # JITスクリプタブル設定
exportable: Optional[bool] = None, # ONNXエクスポート可能設定
no_jit: Optional[bool] = None, # JIT最適化を無効化
**kwargs, # その他のモデルパラメータ
):
"""Create a model.
モデル名に対応するエントリポイント関数を呼び出し、新しいモデルを生成します。
Tip:
**kwargs はエントリポイント関数 → ``timm.models.build_model_with_cfg()`` → モデルクラス __init__() へと渡されます。
None に設定された kwargs は渡す前に削除されます。
Args:
model_name: インスタンス化するモデル名(例: 'resnet50')
pretrained: True にすると ImageNet-1k の事前学習済み重みをロード
pretrained_cfg: 外部の pretrained_cfg を指定
pretrained_cfg_overlay: 基本 pretrained_cfg の一部を置き換え
checkpoint_path: モデル初期化後にロードするチェックポイントパス
cache_dir: Hugging Face Hub や Torch チェックポイントのキャッシュディレクトリを上書き
scriptable: モデルを JIT スクリプタブルに設定(一部モデル未対応)
exportable: モデルを ONNX エクスポート可能に設定(一部未実装)
no_jit: JIT スクリプト層を利用しない設定(活性化関数のみ)
キーワード引数:
drop_rate (float): 学習用の分類器ドロップアウト率
drop_path_rate (float): ストークastic depth ドロップ率
global_pool (str): 分類器のグローバルプーリングタイプ
"""
# すべてのモデルがサポートしないパラメータは None として扱い、デフォルト値を維持
kwargs = {k: v for k, v in kwargs.items() if v is not None}
model_source, model_id = parse_model_name(model_name) # モデルソース(HuggingFaceなど)を解析
if model_source:
assert not pretrained_cfg, 'Hugging Face Hub からモデルをロードする際は pretrained_cfg は指定不可'
if model_source == 'hf-hub':
# Hugging Face Hub からモデル設定をロード
pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_id, cache_dir=cache_dir)
elif model_source == 'local-dir':
# ローカルディレクトリからモデル設定をロード
pretrained_cfg, model_name, model_args = load_model_config_from_path(model_id)
else:
assert False, f'不明な model_source {model_source}'
if model_args:
for k, v in model_args.items():
kwargs.setdefault(k, v) # モデル引数を kwargs に反映
else:
model_name, pretrained_tag = split_model_name_tag(model_id)
if pretrained_tag and not pretrained_cfg:
# pretrained_cfg が未指定なら、モデル名のタグを使用
pretrained_cfg = pretrained_tag
if not is_model(model_name):
raise RuntimeError('不明なモデル (%s)' % model_name) # モデル名チェック
create_fn = model_entrypoint(model_name) # モデル生成関数を取得
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
cache_dir=cache_dir,
**kwargs,
)
if checkpoint_path:
load_checkpoint(model, checkpoint_path) # チェックポイントをロード
return model
主なパラメータの説明
| パラメータ | 説明 |
|---|---|
model_name |
モデル名(例: 'resnet50', 'vit_base_patch16_224') |
pretrained |
事前学習済み重みを使用するか(True/False) |
num_classes |
分類クラス数の変更(例: num_classes=10) |
pretrained_cfg |
外部の事前学習設定を指定 |
checkpoint_path |
事後チェックポイントファイルをロード |
drop_rate |
ドロップアウト率(過学習防止用) |
drop_path_rate |
ストーク astic depth のドロップ率 |
global_pool |
グローバルプーリングタイプ(例: 'avg', 'max') |
使い方
from timm import create_model
事前学習なしの MobileNetV3-Large モデル作成
>>> model = create_model('mobilenetv3_large_100')
事前学習ありの MobileNetV3-Large モデル作成
>>> model = create_model('mobilenetv3_large_100', pretrained=True)
>>> model.num_classes
1000
事前学習ありで分類層を 10 クラスに変更
>>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
>>> model.num_classes
10
Dinov2 モデルの重みをカスタムディレクトリに保存
# データ保存先: `/data/my-models/models--timm--vit_small_patch14_dinov2.lvd142m/`
model = create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, cache_dir="/data/my-models")
モデル作成と重みロード
m = timm.create_model(
'ecaresnet101d',
features_only=True,
output_stride=8,
out_indices=(0,1,2,3, 4),
pretrained=True
)
features_only
- 意味: 特徴抽出モードを有効にするオプション。
- 説明: このフラグを
Trueに設定すると、モデルの最終的な分類器(Classifier)を除いた部分のみが出力されます。つまり、特徴マップ(feature maps)だけを取得したい場合に使用します。 - 注意:
output_strideやout_indicesを指定する際には、必ずfeatures_only=Trueとする必要があります。
# 例: 特徴抽出モードでResNet50をロード
model = timm.create_model('resnet50', features_only=True, pretrained=True)
output_stride
- 意味: 出力ストライド(出力特徴マップのスケーリング率)
- 説明: 入力画像に対して、最終的な特徴マップが何倍のダウンサンプリングされているかを制御します。特にセマンティックセグメンテーションや物体検出などのタスクで重要です。
- サポート: 一部のネットワークでは
output_stride=32のみサポートされています。小さめの値(例: 8 や 16)に設定することで、より高解像度な特徴マップを得られます。
# 例: output_stride=8 で特徴抽出モデルを作成
model = timm.create_model('resnet50', features_only=True, output_stride=8, pretrained=True)
out_indices
- 意味: 取得したい特徴層のインデックス番号
- 説明: モデル内のどの層から特徴を出力するかを指定します。複数のスケール(マルチスケール)の特徴が必要な場合(例:FPN、U-Net など)に便利です。
- 例:
(0, 1, 2, 3, 4)→ 5 つの異なるレイヤーの特徴をリスト形式で返す
# 例: 0〜4層目の特徴を抽出
model = timm.create_model(
'resnet50',
features_only=True,
out_indices=(0, 1, 2, 3, 4),
pretrained=True
)
timm.list_models
timm.list_models は、利用可能なモデル名の一覧を取得する関数です。
ワイルドカードによるフィルタリングや、特定のサブモジュール(例:vision_transformer)に絞り込むことが可能です。
ソースコード
def list_models(
filter: Union[str, List[str]] = '',
module: Union[str, List[str]] = '',
pretrained: bool = False,
exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False,
include_tags: Optional[bool] = None,
) -> List[str]:
"""
利用可能なモデル名一覧を取得する関数。アルファベット順にソートして返す。
Args:
filter (str or List[str]): ワイルドカードによるフィルタリング(例: 'resnet*')
module (str or List[str]): 特定サブモジュール内のモデルのみ表示(例: 'vision_transformer')
pretrained (bool): True のとき、事前学習済みモデルのみ表示
exclude_filters (str or List[str]): フィルタ後に除外したいパターン
name_matches_cfg (bool): モデル名が設定ファイルと一致するもののみ表示
include_tags (Optional[bool]): モデル名に事前学習タグ(例:.in1k)を含めるか
Returns:
models (List[str]): ソート済みモデル名一覧
Example:
model_list('gluon_resnet*') -- 'gluon_resnet' から始まるすべてのモデルを取得
model_list('*resnext*', 'resnet') -- 'resnext' を含む 'resnet' モジュールのモデルを取得
"""
-
初期フィルタ適用:
filterに指定されたワイルドカード(例:'resnet*')でモデル名を絞り込み。
-
サブモジュール制限:
moduleが指定されていれば、そのサブモジュールに所属するモデルのみ選択。
-
非推奨モデル除去:
_deprecated_modelsに登録されている非推奨モデルは自動的に除外。
-
事前学習タグの追加(オプション):
include_tags=Trueの場合、モデル名に.in1kなどのタグも含めて表示。
-
除外フィルタ適用:
exclude_filtersで指定されたパターンに合致するモデルを最終的に除外。
-
事前学習ありのみ抽出:
pretrained=True時、_model_has_pretrainedに含まれるモデルのみ残す。
-
設定名とのマッチング:
name_matches_cfg=True時、_model_pretrained_cfgsと名前が一致するモデルのみ残す。
-
自然順ソートして返却:
- 数字も正しく並べ替わるよう、
_natural_key関数でソート。
- 数字も正しく並べ替わるよう、
パラメータの説明
| パラメータ | 型 | 説明 |
|---|---|---|
filter |
str or List[str] |
ワイルドカードでフィルタ(例:'resnet*') |
module |
str or List[str] |
特定のサブモジュールに限定(例:'vision_transformer') |
pretrained |
bool |
True のとき、事前学習済み重みがあるモデルのみ表示 |
exclude_filters |
str or List[str] |
フィルタ後に除外したいパターン(例:'*efficientnet*') |
name_matches_cfg |
bool |
モデル名が設定ファイルと一致するもののみ表示 |
include_tags |
Optional[bool] |
モデル名に事前学習タグ(例:.in1k)を含めるか |
使用例
基本的な使用法
# 'gluon_resnet' から始まるすべてのモデルを取得
model_list = timm.list_models('gluon_resnet*')
# 'resnext' を含む 'resnet' サブモジュールのモデルを取得
model_list = timm.list_models('*resnext*', 'resnet')
事前学習済みモデルのみ取得
# 事前学習済みのResNet系モデルのみ取得
model_list = timm.list_models('resnet*', pretrained=True)
特定のサブモジュールに絞る
# Vision Transformer 系モデルのみ取得
model_list = timm.list_models(module='vision_transformer')
フィルタ + 除外条件付き
# EfficientNet-B0系だが、lite版は除外
model_list = timm.list_models('tf_efficientnet_b0*', exclude_filters='*lite*')
load_pretrained
timm.models.load_pretrained は、事前学習済みモデルの重み(チェックポイント)を PyTorch モデルにロードするための関数です。
カスタム設定(入力チャネル数、クラス数など)に応じて柔軟に適応させることができます。
def load_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
num_classes: int = 1000,
in_chans: int = 3,
filter_fn: Optional[Callable] = None,
strict: bool = True,
cache_dir: Optional[Union[str, Path]] = None,
):
""" Load pretrained checkpoint
Args:
model: 重みをロードしたいPyTorchモデル(例:ResNet、Vision Transformerなど)
pretrained_cfg: 事前学習設定情報
num_classes: 対象モデルの分類クラス数。デフォルトはImageNetの1000クラス。異なる場合は自動調整される
in_chans: 入力画像のチャネル数(RGBなら3、グレースケールなら1)。異なる場合、重みを調整してロード
filter_fn: state_dict をロード前に加工する関数(例:特定層だけ除外、変更など)
strict: ロード時にstate_dictとモデル構造が一致していないときにエラーを出すかどうか(True: 一致必須 / False: 柔軟にロード)
cache_dir: チェックポイントファイルを保存するディレクトリ
"""
パラメータの説明
| パラメータ | 型 | 説明 |
|---|---|---|
model |
nn.Module |
重みをロードする対象の PyTorch モデル |
pretrained_cfg |
Optional[Dict] |
事前学習済み重みやデータセットに関する設定情報 |
num_classes |
int |
対象モデルの分類クラス数(デフォルト:ImageNet の 1000) |
in_chans |
int |
入力画像のチャネル数(デフォルト:RGB 画像の 3) |
filter_fn |
Optional[Callable] |
state_dict をロード前に加工するフィルタ関数 |
strict |
bool |
モデルと state_dict が一致しないときにエラーを出すか |
cache_dir |
str or Path |
チェックポイントファイルを保存するディレクトリ |
使用例
Vision Transformer をロードし、カスタムクラス数で事前学習済み重みをロード
import timm
from timm.models import create_model
# ViT Base モデルを作成
model = create_model('vit_base_patch16_224', num_classes=10)
# 事前学習済み重みをロード
timm.models.load_pretrained(model, pretrained_cfg=model.default_cfg, num_classes=10)
グレースケール画像用に調整してロード
# グレースケール画像(1チャネル)用に調整
model = create_model('resnet50', in_chans=1, num_classes=10)
timm.models.load_pretrained(model, pretrained_cfg=model.default_cfg, in_chans=1)
timm.data
ImageDataset
- 標準的な画像分類用データセットを構築するためのクラス。
- ファイル構造に基づいて画像とラベルを読み込む(例: ImageNet 形式のフォルダ構造)。
- シンプルな使い方で、学習・検証用のデータローダーを作成可能。
使用例
from timm.data import ImageDataset
from torch.utils.data import DataLoader
dataset = ImageDataset(
'path/to/train', # データセットのパス
reader=None, # データを読み込む関数
split='train', # 'train', 'valid', 'test' などの分割方法
class_map=None, # クラス名のマッピング
load_bytes=False, # バイト列として読み込むかどうか
input_img_mode='RGB', # 入力画像のモード
transform=None, # データを変換する関数
target_transform=None, # ラベルを変換する関数
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
IterableImageDataset
- 大規模データセットや ストリーミングデータ を扱うためのイテラブル型データセット。
- メモリ効率が高く、データ量が非常に多い場合に有効。
torch.utils.data.IterableDatasetを継承している。
主な特徴
- 全データをメモリに載せずに逐次読み込み
- 分散環境(DistributedDataParallel など)との親和性が高い
- サブセット(サブサンプリング)にも対応
使用例
from timm.data import IterableImageDataset
dataset = IterableImageDataset(
'path/to/train', # データセットのパス
reader=None, # データを読み込む関数
split='train', # 'train', 'valid', 'test' などの分割方法
class_map=None, # クラス名のマッピング
is_training=False, # 学習データか検証データかを指定
batch_size=1, # バッチサイズ
num_samples=None, # データセットのサンプル数
seed=42, # サンプリングのシード
repeats=0, # データセットの繰り返し回数
download=False, # データをダウンロードするかどうか
input_img_mode='RGB', # 入力画像のモード
input_key=None, # 入力画像のキー
target_key=None, # ターゲット画像のキー
transform=None, # 入力画像の変換
target_transform=None, # ターゲット画像の変換
max_steps=None, # データセットのステップ数
)
AugMixDataset
- AugMix というデータ拡張手法をサポートしたラッパーデータセット。
- AugMix は、複数の合成拡張経路を混ぜたデータ拡張を行い、モデルの堅牢性(robustness)を向上させます。
主な特徴
- 各バッチに対して、複数回の拡張(augmented copies)を生成し、それらを組み合わせて入力とする
__getitem__ではなく__iter__によってデータを返す(PyTorch のIterableDatasetと似た挙動)
使用例
from timm.data import ImageDataset, AugMixDataset
from torch.utils.data import DataLoader
# 通常の ImageDataset をベースにする
base_dataset = ImageDataset('path/to/train')
augmix_dataset = AugMixDataset(base_dataset, num_splits=3)
# データローダーに登録
loader = create_loader(augmix_dataset, batch_size=8, num_aug_splits=2)
source code
class AugMixDataset(torch.utils.data.Dataset):
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
def __init__(self, dataset, num_splits=2):
self.augmentation = None
self.normalize = None
self.dataset = dataset
if self.dataset.transform is not None:
self._set_transforms(self.dataset.transform)
self.num_splits = num_splits
def _set_transforms(self, x):
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
self.dataset.transform = x[0]
self.augmentation = x[1]
self.normalize = x[2]
@property
def transform(self):
return self.dataset.transform
@transform.setter
def transform(self, x):
self._set_transforms(x)
def _normalize(self, x):
return x if self.normalize is None else self.normalize(x)
def __getitem__(self, i):
x, y = self.dataset[i] # all splits share the same dataset base transform
x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
# run the full augmentation on the remaining splits
for _ in range(self.num_splits - 1):
x_list.append(self._normalize(self.augmentation(x)))
return tuple(x_list), y
def __len__(self):
return len(self.dataset)
num_splitsは AugMix の分割数を指定。通常 2 または 3 が使用される。
AugMix とは
AugMix は、ICML 2020 で提案された、複数のデータ拡張パスを混合して使用することで、ノイズ・汚れ・ぼかしなどに対するモデルの堅牢性を高める手法です。
特徴
- 拡張パスを複数(デフォルトで 2 つ)作成し、その結果をランダムウェイトで合成
- 合成画像 + 元画像 から損失関数で一貫性制約を課すことで、汎化性能を向上
AugMix を使った訓練コード
import timm
from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader
import torch.nn as nn
import torch.optim as optim
dataset = ImageDataset('../../imagenet1K/')
dataset = AugMixDataset(dataset, num_splits=2)
# AugMix を使うためのデータローダー
loader_train = create_loader(
dataset,
input_size=(3, 224, 224), # モデルの入力サイズ
batch_size=8, # バッチサイズ
is_training=True, # 学習モード
scale=[0.08, 1.], # 画像のスケール
ratio=[0.75, 1.33], # 画像のアスペクト比
num_aug_splits=2 # AugMixの分割数
)
# モデル・損失関数・オプティマイザ
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 訓練ループ
for inputs, targets in loader_train:
print(inputs.shape) # >> torch.Size([16, 3, 224, 224]), 16=batch_size*num_splits
output1 = model(inputs)
loss = criterion(output, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()