Welcome to torch-dist-utils’s documentation!
Utilities for PyTorch distributed.
Module torch_dist_utils
Before using torch_dist_utils functions, you must either call torch_dist_utils.init_distributed() or initialize the default process group yourself. torch_dist_utils.init_distributed() can be called even if you did not start the script with torchrun: if you did not, it will assume it is the only process and create a process group with a single member.
- torch_dist_utils.get_local_group()[source]
Get the process group containing only the local processes.
- Returns:
The process group containing only the local processes.
- Return type:
ProcessGroup
- torch_dist_utils.get_device()[source]
Get the device of the current process.
- Returns:
The device of the current process.
- Return type:
device
- torch_dist_utils.init_distributed()[source]
Initialize distributed communication. If the process is not launched with torchrun, then assume it is the only process.
- torch_dist_utils.do_in_order(group=None)[source]
A context manager that ensures that all processes execute the block in order.
- Parameters:
group (ProcessGroup | None) – The process group. If None, use the default group.
- torch_dist_utils.do_in_local_order()[source]
A context manager that ensures that all local processes execute the block in order.
- torch_dist_utils.on_rank_0(group=None)[source]
A decorator that ensures that only process 0 executes the function.
- Parameters:
group (ProcessGroup | None) – The process group. If None, use the default group.
- Returns:
A decorator.
- Return type:
Callable
- torch_dist_utils.on_local_rank_0()[source]
A decorator that ensures that only the local process 0 executes the function.
- Returns:
A decorator.
- Return type:
Callable
- torch_dist_utils.print0(*args, sep=' ', end='\n', file=None, flush=False)
A version of print() that only prints on process 0.
- torch_dist_utils.printl0(*args, sep=' ', end='\n', file=None, flush=False)
A version of print() that only prints on local process 0.
- torch_dist_utils.all_gather_object(obj, group=None)[source]
Gather an object from each process and return a list of gathered objects in all processes.
- Parameters:
obj (Any) – The object to gather.
group (ProcessGroup | None) – The process group. If None, use the default group.
- Returns:
A list of gathered objects.
- Return type:
List[Any]
- torch_dist_utils.broadcast_object(obj=None, src=0, group=None)[source]
Broadcast an object from the source process and return the object in all processes.
- Parameters:
obj (Any | None) – The object to broadcast. Ignored in processes other than src.
src (int) – The source process.
group (ProcessGroup | None) – The process group. If None, use the default group.
- Returns:
The object broadcasted from the source process.
- Return type:
Any
- torch_dist_utils.gather_object(obj, dst=0, group=None)[source]
Gather an object from each process and return a list of gathered objects in the destination process.
- Parameters:
obj (Any) – The object to gather.
dst (int) – The destination process.
group (ProcessGroup | None) – The process group. If None, use the default group.
- Returns:
A list of gathered objects in the destination process, None in all other processes.
- Return type:
List[Any] | None
- torch_dist_utils.scatter_objects(objs=None, src=0, group=None)[source]
Scatter a list of objects from the source process and return each object in each process.
- Parameters:
objs (List[Any] | None) – The list of objects to scatter. Ignored in processes other than src.
src (int) – The source process.
group (ProcessGroup | None) – The process group. If None, use the default group.
- Returns:
The object scattered to the current process.
- Return type:
Any
- torch_dist_utils.all_gather_into_new(tensor, group=None)[source]
Gather a tensor from each process and return a list of the gathered tensors in all processes. The tensors may have different shapes, but must all be on CPU or all be on GPU.
- Parameters:
tensor (Tensor) – The tensor to gather.
group (ProcessGroup | None) – The process group. If None, use the default group.
- Returns:
A list of gathered tensors.
- Return type:
List[Tensor]
- torch_dist_utils.broadcast_tensors(tensors, src=0, group=None)[source]
Broadcast an iterable of tensors from the given source process to all other processes.
To synchronize a model’s parameters in all processes to the versions in process 0:
broadcast_tensors(model.parameters())
- Parameters:
tensors (Iterable[Tensor]) – The tensors to broadcast.
src (int) – The source process.
group (ProcessGroup | None) – The process group. If None, use the default group.