torch.distributedの通信
2023/10/8
AI
目次
- 目次
- torch.distributed
- 集合通信操作
- まとめ
- Tip
- 参考
torch.distributed
torch.distributed は PyTorch の分散学習を実装するためのコアライブラリです。
プロセス間通信(IPC)を効率化するための集合通信(AllReduce, AllGather, Broadcast)とポイントツーポイント通信(Send/Recv)を提供します。
初期化方法
import torch.distributed as dist
dist.init_process_group(
backend='nccl', # GPU なら NCCL 推奨
init_method='env://', # 環境変数で初期化
world_size=4, # 総プロセス数
rank=0 # 現在のプロセスID
)
集合通信操作
ブロードキャスト操作
dist.broadcast
-
機能: 特定プロセス([src](file://d:\code\MYBLOG\themes\volantis\scripts\tags\media.js#L9-L9))のテンソルを全プロセスに配信。
-
用途: 初期重みやハイパーパラメータの共有。
-
例:
dist.broadcast(tensor, src=0) # ランク0から全プロセスに配信
dist.broadcast_object_list
- 機能: Python オブジェクトリストを全プロセスに配信。
- 用途: 非テンソルデータ(文字列、辞書など)の共有。
- 例:
dist.broadcast_object_list(obj_list, src=0) # ランク0のオブジェクトを全プロセスに送信
集約操作 (Reduce)
dist.all_reduce
- 機能: 全プロセスのテンソルを集約(加算、平均など)し、結果を全プロセスに配布。
- 用途: 勾配同期や損失関数の平均計算。
- 例:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM) # 全プロセスのテンソルを加算
dist.reduce
- 機能: 全プロセスのテンソルを集約し、結果を特定プロセス(
dst)に送信。 - 用途: マスターノードでの結果収集。
- 例:
dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM) # 全プロセスのデータをランク0に集約
非同期集約
- 機能: 非同期処理で通信と計算を並列化。
- 例:
work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True) work.wait() # 通信完了まで待機
データ収集 (Gather)
dist.all_gather
- 機能: 全プロセスのテンソルを収集し、各プロセスに全データを複製。
- 用途: 分散データの集約(例: 推論結果の統合)。
- 例:
dist.all_gather(gather_list, tensor) # 全プロセスのデータを各プロセスに複製
dist.gather
- 機能: 全プロセスのテンソルを特定プロセス(
dst)に集約。 - 用途: マスターノードでの結果集約。
- 例:
dist.gather(tensor, gather_list, dst=0) # 全プロセスのデータをランク0に集約
オブジェクト収集
- 関数:
dist.all_gather_object(gather_list, obj)dist.gather_object(obj, gather_list, dst=0)
データ配布 (Scatter)
dist.scatter
- 機能: 源プロセス([src](file://d:\code\MYBLOG\themes\volantis\scripts\tags\media.js#L9-L9))のテンソルリストを全プロセスに分散。
- 用途: データ分割処理。
- 例:
dist.scatter(tensor, scatter_list, src=0) # ランク0のリストを全プロセスに分散
dist.scatter_object_list
- 機能: Python オブジェクトリストを全プロセスに分散。
- 例:
dist.scatter_object_list(obj, scatter_list, src=0)
複合操作
dist.reduce_scatter/dist.reduce_scatter_tensor
- 機能: 各プロセスのデータを集約+分散。
- 用途: モデル並列化時の効率化。
- 例:
dist.reduce_scatter(output_tensor, [input_tensor], op=dist.ReduceOp.SUM)
dist.all_to_all/dist.all_to_all_single
- 機能: 全プロセス間でテンソルの部分交換。
- 用途: パイプライン並列処理。
- 例:
dist.all_to_all_single(output_tensor, input_tensor) # 全プロセス間でデータ交換
同期操作
dist.barrier()
- 機能: 全プロセスの同期(全プロセスが到達するまで待機)。
dist.monitored_barrier(timeout=10)
- 機能: タイムアウト付き同期(デバッグ時有効)。
まとめ
| 機能 | 対象 | 用途 |
|---|---|---|
| ブロードキャスト | broadcast, broadcast_object_list |
データ配信(全プロセス) |
| 集約 | all_reduce, reduce |
勾配同期、結果集約 |
| 収集 | all_gather, gather |
データ統合(全プロセス/特定プロセス) |
| 配布 | scatter, scatter_object_list |
データ分散(源プロセス → 全プロセス) |
| 複合操作 | reduce_scatter, all_to_all |
高度な並列化(モデル/パイプライン並列) |
| 同期 | barrier, monitored_barrier |
プロセス同期 |
Tip
-
torch.distributed.all_reduceはブロッキング(阻害的)操作ですdist.all_reduce(tensor, op=dist.ReduceOp.SUM)この呼び出しはブロッキング処理として実行されます。
つまり、通信が完了するまで関数は戻りません(待機状態)。