Welcome to torch-dist-utils’s documentation!

Utilities for PyTorch distributed.

Module torch_dist_utils

Before using torch_dist_utils functions, you must either call torch_dist_utils.init_distributed() or initialize the default process group yourself. torch_dist_utils.init_distributed() can be called even if you did not start the script with torchrun: if you did not, it will assume it is the only process and create a process group with a single member.
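
A minimal script skeleton (a sketch; the training logic is elided, and the module is imported as du here and in the examples below):

    import torch_dist_utils as du

    def main():
        du.init_distributed()
        device = du.get_device()
        # ... build the model and data on `device`, train, etc. ...
        du.cleanup_distributed()

    if __name__ == "__main__":
        main()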

torch_dist_utils.get_local_group()[source]

Get the process group containing only the local processes (the processes on the same node as the current process).

Returns:

The process group containing only the local processes.

Return type:

ProcessGroup
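
For example, to gather an object from only the processes on the current node (a sketch; the PIDs are an illustrative per-process value):

    import os

    local_group = du.get_local_group()
    pids = du.all_gather_object(os.getpid(), group=local_group)  # PIDs on this node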

torch_dist_utils.get_device()[source]

Get the device of the current process.

Returns:

The device of the current process.

Return type:

device
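
A minimal sketch of placing a model on the current process's device (the model is illustrative):

    import torch

    model = torch.nn.Linear(10, 10).to(du.get_device())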

torch_dist_utils.init_distributed()[source]

Initialize distributed communication. If the process was not launched with torchrun, it is assumed to be the only process.

torch_dist_utils.cleanup_distributed()[source]

Clean up distributed communication.

torch_dist_utils.do_in_order(group=None)[source]

A context manager that ensures that all processes execute the block in order.

Parameters:

group (ProcessGroup | None) – The process group. If None, use the default group.
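
For example, to print one line per process without interleaved output (a sketch):

    import torch.distributed as dist

    with du.do_in_order():
        print(f"hello from rank {dist.get_rank()}")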

torch_dist_utils.do_in_local_order()[source]

A context manager that ensures that all local processes execute the block in order.
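
A sketch of letting the processes on each node write to a node-local file one at a time (the path is illustrative):

    with du.do_in_local_order():
        with open("/tmp/node_log.txt", "a") as f:
            f.write("hello from a local process\n")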

torch_dist_utils.on_rank_0(group=None)[source]

A decorator that ensures that only process 0 executes the function.

Parameters:

group (ProcessGroup | None) – The process group. If None, use the default group.

Returns:

A decorator.

Return type:

Callable
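
Since on_rank_0() returns a decorator, call it when decorating. A sketch of saving a checkpoint from process 0 only (save_checkpoint() is illustrative):

    import torch

    @du.on_rank_0()
    def save_checkpoint(model, path):
        torch.save(model.state_dict(), path)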

torch_dist_utils.on_local_rank_0()[source]

A decorator that ensures that only local process 0 (the first process on each node) executes the function.

Returns:

A decorator.

Return type:

Callable
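
A sketch of performing a once-per-node setup action (the path is illustrative):

    import os

    @du.on_local_rank_0()
    def make_scratch_dir():
        os.makedirs("/tmp/scratch", exist_ok=True)  # runs once per node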

torch_dist_utils.print0(*args, sep=' ', end='\n', file=None, flush=False)

A version of print() that only prints on process 0.

torch_dist_utils.printl0(*args, sep=' ', end='\n', file=None, flush=False)

A version of print() that only prints on local process 0.
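
A sketch of using both helpers (the messages are illustrative):

    du.print0("starting training")            # one line for the whole job
    du.printl0("preparing node-local cache")  # one line per node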

torch_dist_utils.all_gather_object(obj, group=None)[source]

Gather an object from each process and return a list of gathered objects in all processes.

Parameters:
  • obj (Any) – The object to gather.

  • group (ProcessGroup | None) – The process group. If None, use the default group.

Returns:

A list of gathered objects.

Return type:

List[Any]
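
For example, to average a per-process value everywhere (the rank stands in for a real metric):

    import torch.distributed as dist

    values = du.all_gather_object(float(dist.get_rank()))  # illustrative value
    mean_value = sum(values) / len(values)                 # same in all processes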

torch_dist_utils.broadcast_object(obj=None, src=0, group=None)[source]

Broadcast an object from the source process and return the object in all processes.

Parameters:
  • obj (Any | None) – The object to broadcast. Ignored in processes other than src.

  • src (int) – The source process.

  • group (ProcessGroup | None) – The process group. If None, use the default group.

Returns:

The object broadcast from the source process.

Return type:

Any
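
A sketch of having process 0 construct a config that all processes then share (the config contents are illustrative):

    import torch.distributed as dist

    config = {"lr": 1e-4, "batch_size": 64} if dist.get_rank() == 0 else None
    config = du.broadcast_object(config)  # now set in every process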

torch_dist_utils.gather_object(obj, dst=0, group=None)[source]

Gather an object from each process and return a list of gathered objects in the destination process.

Parameters:
  • obj (Any) – The object to gather.

  • dst (int) – The destination process.

  • group (ProcessGroup | None) – The process group. If None, use the default group.

Returns:

A list of gathered objects in the destination process, None in all other processes.

Return type:

List[Any] | None
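
A sketch of collecting per-process results on process 0 only (the strings are illustrative):

    import torch.distributed as dist

    results = du.gather_object(f"result from rank {dist.get_rank()}")
    if results is not None:  # true only in the destination process
        print(results)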

torch_dist_utils.scatter_objects(objs=None, src=0, group=None)[source]

Scatter a list of objects from the source process, giving each process the object at the index of its rank.

Parameters:
  • objs (List[Any] | None) – The list of objects to scatter. Ignored in processes other than src.

  • src (int) – The source process.

  • group (ProcessGroup | None) – The process group. If None, use the default group.

Returns:

The object scattered to the current process.

Return type:

Any
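
A sketch of handing each process its own work item (the shard names are illustrative):

    import torch.distributed as dist

    if dist.get_rank() == 0:
        shards = [f"shard_{i}.bin" for i in range(dist.get_world_size())]
    else:
        shards = None
    my_shard = du.scatter_objects(shards)  # shards[rank] in each process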

torch_dist_utils.all_gather_into_new(tensor, group=None)[source]

Gather a tensor from each process and return a list of gathered tensors in all processes. Tensors can have different shapes. Tensors must be all on CPU or all on GPU.

Parameters:
  • tensor (Tensor) – The tensor to gather.

  • group (ProcessGroup | None) – The process group. If None, use the default group.

Returns:

A list of gathered tensors.

Return type:

List[Tensor]
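
For example, to gather tensors whose first dimension differs across processes (the shapes are illustrative):

    import torch
    import torch.distributed as dist

    x = torch.randn(dist.get_rank() + 1, 8, device=du.get_device())
    xs = du.all_gather_into_new(x)  # list of tensors, one per process
    full = torch.cat(xs)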

torch_dist_utils.broadcast_tensors(tensors, src=0, group=None)[source]

Broadcast an iterable of tensors from the given source process to all other processes.

To synchronize a model’s parameters in all processes to the versions in process 0:

    broadcast_tensors(model.parameters())

Parameters:
  • tensors (Iterable[Tensor]) – The tensors to broadcast.

  • src (int) – The source process.

  • group (ProcessGroup | None) – The process group. If None, use the default group.
