Source code for xmodaler.utils.distributed

"""
From original at https://github.com/ChenRocks/UNITER/blob/master/utils/distributed.py
Original copyright of Microsoft code below, modifications by Jianjie Luo, Copyright 2021.
"""

import math
import pickle

import torch
import torch.distributed as dist


def broadcast_tensors(tensors, root_rank, buffer_size=10485760):
    """broadcast tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to broadcast
        root_rank: rank to broadcast from
        buffer_size: broadcast chunk size in bytes
    """
    # buffer size in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def broadcast_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # broadcast
        dist.broadcast(buffer_t[:offset], root_rank)

        # copy broadcast buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, broadcast directly
            dist.broadcast(t, root_rank)

        elif filled + sz > buffer_size:
            # buffer is full, broadcast and replace buffer with tensor
            broadcast_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        broadcast_buffer()
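
# Usage sketch for broadcast_tensors (a minimal illustration, assuming
# torch.distributed has been initialised via dist.init_process_group and that
# every rank makes the same call), e.g. syncing model parameters from rank 0
# after loading a checkpoint:
#
#     params = [p.data for p in model.parameters()]
#     broadcast_tensors(params, root_rank=0)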


def _encode(enc, max_size, use_max_size=False):
    enc_size = len(enc)
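    # number of base-256 digits needed to write max_size, i.e. the width of
    # the length header that _decode reads back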
    enc_byte = max(math.floor(math.log(max_size, 256)+1), 1)
    if use_max_size:
        # this is used for broadcasting
        buffer_ = torch.cuda.ByteTensor(max_size+enc_byte)
    else:
        buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte)
    remainder = enc_size
    for i in range(enc_byte):
        base = 256 ** (enc_byte-i-1)
        buffer_[i] = remainder // base
        remainder %= base
    buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc))
    return buffer_, enc_byte


def _decode(buffer_, enc_byte):
    size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() for i in range(enc_byte))
    bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist())
    shift = size + enc_byte
    return bytes_list, shift
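
# Round-trip sketch of the length-prefixed encoding above (a minimal
# illustration; it needs a CUDA device because _encode allocates a
# torch.cuda.ByteTensor):
#
#     payload = pickle.dumps({"step": 3, "loss": 0.5})
#     buf, header_len = _encode(payload, max_size=len(payload))
#     raw, consumed = _decode(buf, header_len)
#     assert pickle.loads(raw) == {"step": 3, "loss": 0.5}
#     assert consumed == header_len + len(payload)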


_BUFFER_SIZE = 4096


def all_gather_list(data):
    """Gathers arbitrary data from all nodes into a list."""
    n_gpu = torch.cuda.device_count()
    enc = pickle.dumps(data)

    tensor_list = [torch.zeros(1, dtype=torch.int64).cuda() for _ in range(n_gpu)]
    enc_size = len(enc)
    dist.all_gather(tensor_list, tensor=torch.tensor([enc_size]).cuda())
    max_size = torch.cat(tensor_list, dim=0).view(-1).max().item()

    in_buffer, enc_byte = _encode(enc, max_size)
    out_buffer = [in_buffer.new_zeros(in_buffer[:enc_byte+enc_size].shape)
                  for _ in range(n_gpu)]
    dist.all_gather(out_buffer, tensor=in_buffer[:enc_byte+enc_size])
    out_buffer = torch.cat(out_buffer, dim=0)

    results = []
    for _ in range(n_gpu):
        bytes_list, shift = _decode(out_buffer, enc_byte)
        out_buffer = out_buffer[shift:]
        result = pickle.loads(bytes_list)
        results.append(result)
    return results
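
# Usage sketch for all_gather_list (a minimal illustration, assuming an
# initialised process group with one process per GPU, since the number of
# participants is taken from torch.cuda.device_count()):
#
#     local_stats = {"rank": dist.get_rank(), "n_samples": 128}
#     all_stats = all_gather_list(local_stats)   # one entry per process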

def any_broadcast(data, root_rank, n_gpu=None):
    """Broadcast arbitrary data from root_rank to all nodes."""
    if n_gpu is None:
        n_gpu = torch.cuda.device_count()
    enc = pickle.dumps(data)

    tensor_list = [torch.zeros(1, dtype=torch.int64).cuda() for _ in range(n_gpu)]
    dist.all_gather(tensor_list, tensor=torch.tensor([len(enc)]).cuda())
    max_size = torch.cat(tensor_list, dim=0).view(-1).max().item()

    buffer_, enc_byte = _encode(enc, max_size, use_max_size=True)
    dist.broadcast(buffer_, root_rank)
    bytes_list, _ = _decode(buffer_, enc_byte)
    result = pickle.loads(bytes_list)
    return result
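
# Usage sketch for any_broadcast (a minimal illustration; build_vocab is a
# hypothetical stand-in for any computation done only on root_rank, and every
# rank must make the same call):
#
#     vocab = build_vocab() if dist.get_rank() == 0 else None
#     vocab = any_broadcast(vocab, root_rank=0)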