LOADING

読み込みが遅い場合はキャッシュを有効にしてください。ブラウザはデフォルトで有効になっています

torch.distributedの通信

目次


torch.distributed

torch.distributed は PyTorch の分散学習を実装するためのコアライブラリです。
プロセス間通信(IPC)を効率化するための集合通信(AllReduce, AllGather, Broadcast)とポイントツーポイント通信(Send/Recv)を提供します。

初期化方法

import torch.distributed as dist

dist.init_process_group(
    backend='nccl',          # GPU なら NCCL 推奨
    init_method='env://',    # 環境変数で初期化
    world_size=4,            # 総プロセス数
    rank=0                   # 現在のプロセスID
)

集合通信操作

ブロードキャスト操作

dist.broadcast

  • 機能: 特定プロセス([src](file://d:\code\MYBLOG\themes\volantis\scripts\tags\media.js#L9-L9))のテンソルを全プロセスに配信

  • 用途: 初期重みやハイパーパラメータの共有。

  • :

    dist.broadcast(tensor, src=0)  # ランク0から全プロセスに配信
    

dist.broadcast_object_list

  • 機能: Python オブジェクトリストを全プロセスに配信。
  • 用途: 非テンソルデータ(文字列、辞書など)の共有。
  • :
    dist.broadcast_object_list(obj_list, src=0)  # ランク0のオブジェクトを全プロセスに送信
    

集約操作 (Reduce)

dist.all_reduce

  • 機能: 全プロセスのテンソルを集約(加算、平均など)し、結果を全プロセスに配布。
  • 用途: 勾配同期や損失関数の平均計算。
  • :
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # 全プロセスのテンソルを加算
    

dist.reduce

  • 機能: 全プロセスのテンソルを集約し、結果を特定プロセス(dst)に送信。
  • 用途: マスターノードでの結果収集。
  • :
    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)  # 全プロセスのデータをランク0に集約
    

非同期集約

  • 機能: 非同期処理で通信と計算を並列化。
  • :
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()  # 通信完了まで待機
    

データ収集 (Gather)

dist.all_gather

  • 機能: 全プロセスのテンソルを収集し、各プロセスに全データを複製。
  • 用途: 分散データの集約(例: 推論結果の統合)。
  • :
    dist.all_gather(gather_list, tensor)  # 全プロセスのデータを各プロセスに複製
    

dist.gather

  • 機能: 全プロセスのテンソルを特定プロセスdst)に集約。
  • 用途: マスターノードでの結果集約。
  • :
    dist.gather(tensor, gather_list, dst=0)  # 全プロセスのデータをランク0に集約
    

オブジェクト収集

  • 関数:
    • dist.all_gather_object(gather_list, obj)
    • dist.gather_object(obj, gather_list, dst=0)

データ配布 (Scatter)

dist.scatter

  • 機能: 源プロセス([src](file://d:\code\MYBLOG\themes\volantis\scripts\tags\media.js#L9-L9))のテンソルリストを全プロセスに分散
  • 用途: データ分割処理。
  • :
    dist.scatter(tensor, scatter_list, src=0)  # ランク0のリストを全プロセスに分散
    

dist.scatter_object_list

  • 機能: Python オブジェクトリストを全プロセスに分散。
  • :
    dist.scatter_object_list(obj, scatter_list, src=0)
    

複合操作

dist.reduce_scatter/dist.reduce_scatter_tensor

  • 機能: 各プロセスのデータを集約+分散
  • 用途: モデル並列化時の効率化。
  • :
    dist.reduce_scatter(output_tensor, [input_tensor], op=dist.ReduceOp.SUM)
    

dist.all_to_all/dist.all_to_all_single

  • 機能: 全プロセス間でテンソルの部分交換
  • 用途: パイプライン並列処理。
  • :
    dist.all_to_all_single(output_tensor, input_tensor)  # 全プロセス間でデータ交換
    

同期操作

dist.barrier()

  • 機能: 全プロセスの同期(全プロセスが到達するまで待機)。

dist.monitored_barrier(timeout=10)

  • 機能: タイムアウト付き同期(デバッグ時有効)。

まとめ

機能 対象 用途
ブロードキャスト broadcast, broadcast_object_list データ配信(全プロセス)
集約 all_reduce, reduce 勾配同期、結果集約
収集 all_gather, gather データ統合(全プロセス/特定プロセス)
配布 scatter, scatter_object_list データ分散(源プロセス → 全プロセス)
複合操作 reduce_scatter, all_to_all 高度な並列化(モデル/パイプライン並列)
同期 barrier, monitored_barrier プロセス同期

Tip

  1. torch.distributed.all_reduceブロッキング(阻害的)操作です

    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    

    この呼び出しはブロッキング処理として実行されます。

    つまり、通信が完了するまで関数は戻りません(待機状態)

参考

avatar
lijunjie2232
個人技術ブログ
My Github
目次0