Source code for torch_geometric.hash_tensor

import functools
import warnings
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import numpy as np
import torch
import torch.utils._pytree as pytree
import xxhash
from torch import Tensor

import torch_geometric.typing
from torch_geometric.typing import CPUHashMap, CUDAHashMap

aten = torch.ops.aten

HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}


def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override."""
    @functools.wraps(torch_function)
    def decorator(my_function: Callable) -> Callable:
        HANDLED_FUNCTIONS[torch_function] = my_function
        return my_function

    return decorator
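
# For example, `@implements(aten.alias.default)` further below registers
# `_alias` as the handler that `HashTensor.__torch_dispatch__` invokes for
# `aten.alias.default`.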


def as_key_tensor(
    key: Any,
    *,
    device: Optional[torch.device] = None,
) -> Tensor:
    try:
        key = torch.as_tensor(key, device=device)
    except Exception:
        device = device or torch.get_default_device()
        key = torch.tensor(
            [xxhash.xxh64(x).intdigest() & 0x7FFFFFFFFFFFFFFF for x in key],
            dtype=torch.int64, device=device)

    if key.element_size() == 1:
        key = key.view(torch.uint8)
    elif key.element_size() == 2:
        key = key.view(torch.int16)
    elif key.element_size() == 4:
        key = key.view(torch.int32)
    elif key.element_size() == 8:
        key = key.view(torch.int64)
    else:
        raise ValueError(f"Received invalid dtype '{key.dtype}' with "
                         f"{key.element_size()} bytes")

    return key
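
# Example (illustrative only): keys that `torch.as_tensor` can handle, such as
# integer lists, are converted directly and merely re-viewed as an integer
# dtype of matching byte width, while keys it cannot handle (e.g., strings)
# are hashed into 63-bit integer codes via `xxhash`:
#
#   as_key_tensor([1000, 100, 10000])     # -> torch.int64 tensor of the keys
#   as_key_tensor(['Comedy', 'Romance'])  # -> torch.int64 tensor of hashes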


def get_hash_map(key: Tensor) -> Union[CPUHashMap, CUDAHashMap]:
    if torch_geometric.typing.WITH_CUDA_HASH_MAP and key.is_cuda:
        return CUDAHashMap(key, 0.5)

    if key.is_cuda:
        warnings.warn(
            "Fallback to CPU-based mapping algorithm which may "
            "cause slowdowns and device synchronization. Please "
            "install 'pyg-lib' for an accelerated 'HashTensor' "
            "implementation.", stacklevel=2)

    if torch_geometric.typing.WITH_CPU_HASH_MAP:
        return CPUHashMap(key.cpu(), -1)

    import pandas as pd

    return pd.CategoricalDtype(
        categories=key.cpu().numpy(),
        ordered=True,
    )
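
# Note: `get_hash_map` picks the fastest available backend: a `CUDAHashMap`
# for CUDA keys when available, otherwise a `CPUHashMap` when available, and a
# `pandas.CategoricalDtype` as a final fallback. In the fallback case (a
# sketch, assuming neither hash map extension is installed):
#
#   _map = get_hash_map(torch.tensor([1000, 100, 10000]))
#   # `_map.categories` then holds the keys, and lookups go through
#   # `pd.Series(query, dtype=_map).cat.codes` in `HashTensor._get`.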


class HashTensor(Tensor):
    r"""A :pytorch:`null` :class:`torch.Tensor` that can be referenced by
    arbitrary keys rather than indices in the first dimension.

    :class:`HashTensor` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, and extends it with CPU- and GPU-accelerated
    mapping routines. This allows for fast and efficient access to
    non-contiguous indices/keys while the underlying data is stored in a
    compact format.

    This representation is ideal for scenarios where one needs a fast mapping
    routine without relying on CPU-based external packages, and can be used,
    *e.g.*, to perform mapping of global indices to local indices during
    subgraph creation, or in data-processing pipelines to map non-contiguous
    input data into a contiguous space, such as

    * mapping of hashed node IDs to range :obj:`[0, num_nodes - 1]`
    * mapping of raw input data, *e.g.*, categorical data to range
      :obj:`[0, num_categories - 1]`

    Specifically, :class:`HashTensor` supports *any* keys of *any* type,
    *e.g.*, strings, timestamps, etc.

    .. code-block:: python

        from torch_geometric import HashTensor

        key = torch.tensor([1000, 100, 10000])
        value = torch.randn(3, 4)

        tensor = HashTensor(key, value)
        assert tensor.size() == (3, 4)

        # Filtering:
        query = torch.tensor([10000, 1000])
        out = tensor[query]
        assert out.equal(value[[2, 0]])

        # Accessing non-existing keys:
        out = tensor[[10000, 0]]
        out.isnan()
        >>> tensor([[False, False, False, False],
        ...         [True, True, True, True]])

        # If `value` is not given, indexing returns the position of `query`
        # in `key`, and `-1` otherwise:
        key = ['Animation', 'Comedy', 'Fantasy']
        tensor = HashTensor(key)

        out = tensor[['Comedy', 'Romance']]
        >>> tensor([1, -1])

    Args:
        key: The keys in the first dimension.
        value: The values to hold.
        dtype: The desired data type of the values of the returned tensor.
        device: The device of the returned tensor.
    """
    _map: Union[Tensor, CPUHashMap, CUDAHashMap]
    _value: Optional[Tensor]
    _min_key: Tensor
    _max_key: Tensor

    @staticmethod
    def __new__(
        cls: Type,
        key: Any,
        value: Optional[Any] = None,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> 'HashTensor':
        if value is not None:
            value = torch.as_tensor(value, dtype=dtype, device=device)
            device = value.device

        key = as_key_tensor(key, device=device)

        if key.dim() != 1:
            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
                             f"one-dimensional (got {key.dim()} dimensions)")

        if not key.is_contiguous():
            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
                             f"contiguous")

        if value is not None:
            if key.device != value.device:
                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
                                 f"are expected to be on the same device (got "
                                 f"'{key.device}' and '{value.device}')")

            if key.numel() != value.size(0):
                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
                                 f"are expected to have the same size in the "
                                 f"first dimension (got {key.size(0)} and "
                                 f"{value.size(0)})")

        min_key = key.min() if key.numel() > 0 else key.new_zeros(())
        max_key = key.max() if key.numel() > 0 else key.new_zeros(())

        _range = max_key - min_key
        # TODO Expose fixed threshold as argument.
        if (key.dtype in {torch.uint8, torch.int16} or _range <= 1_000_000
                or _range <= 2 * key.numel()):
            _map = torch.full(
                size=(_range + 3, ),
                fill_value=-1,
                dtype=torch.int64,
                device=key.device,
            )
            _map[key.long() - (min_key.long() - 1)] = torch.arange(
                key.numel(),
                dtype=_map.dtype,
                device=_map.device,
            )
        else:
            _map = get_hash_map(key)

        return cls._from_data(
            _map,
            value,
            min_key,
            max_key,
            num_keys=key.numel(),
            dtype=dtype,
        )

    # Private Methods #########################################################

    @classmethod
    def _from_data(
        cls,
        _map: Union[Tensor, CPUHashMap, CUDAHashMap],
        value: Optional[Tensor],
        min_key: Tensor,
        max_key: Tensor,
        *,
        num_keys: int,
        dtype: Optional[torch.dtype],
    ) -> 'HashTensor':
        if value is not None:
            dtype = value.dtype
            size = value.size()
            stride = value.stride()
            layout = value.layout
            requires_grad = value.requires_grad
        else:
            dtype = dtype or torch.int64
            size = torch.Size([num_keys])
            stride = (1, )
            layout = torch.strided
            requires_grad = False

        out = Tensor._make_wrapper_subclass(
            cls,
            size=size,
            strides=stride,
            dtype=dtype,
            device=min_key.device,
            layout=layout,
            requires_grad=requires_grad,
        )
        assert isinstance(out, HashTensor)

        out._map = _map
        out._value = value
        out._min_key = min_key
        out._max_key = max_key

        return out

    @property
    def _key(self) -> Tensor:
        if isinstance(self._map, Tensor):
            mask = self._map >= 0
            key = mask.nonzero().view(-1) - 1
            key = key[self._map[mask]]
        elif (torch_geometric.typing.WITH_CUDA_HASH_MAP
              or torch_geometric.typing.WITH_CPU_HASH_MAP):
            key = self._map.keys().to(self.device)
        else:
            key = torch.from_numpy(self._map.categories.to_numpy())
        return key.to(self.device)

    def _shallow_copy(self) -> 'HashTensor':
        return self._from_data(
            self._map,
            self._value,
            self._min_key,
            self._max_key,
            num_keys=self.size(0),
            dtype=self.dtype,
        )

    def _get(self, query: Tensor) -> Tensor:
        if isinstance(self._map, Tensor):
            index = query.long() - (self._min_key.long() - 1)
            index = self._map[index.clamp_(min=0, max=self._map.numel() - 1)]
        elif torch_geometric.typing.WITH_CUDA_HASH_MAP and query.is_cuda:
            index = self._map.get(query)
        elif torch_geometric.typing.WITH_CPU_HASH_MAP:
            index = self._map.get(query.cpu())
        else:
            import pandas as pd

            ser = pd.Series(query.cpu().numpy(), dtype=self._map)
            index = torch.from_numpy(ser.cat.codes.to_numpy().copy()).long()

        index = index.to(self.device)

        if self._value is None:
            return index.to(self.dtype)

        out = self._value[index]

        mask = index != -1
        mask = mask.view([-1] + [1] * (out.dim() - 1))
        fill_value = float('NaN') if out.is_floating_point() else -1
        if torch_geometric.typing.WITH_PT20:
            other: Union[int, float, Tensor] = fill_value
        else:
            other = torch.full_like(out, fill_value)

        return out.where(mask, other)

    # Methods #################################################################
    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`HashTensor` representation back to a
        :class:`torch.Tensor` representation.
        """
        if self._value is not None:
            return self._value
        return torch.arange(self.size(0), dtype=self.dtype,
                            device=self.device)
    # PyTorch/Python builtins #################################################

    # Prevent auto-wrapping outputs back into the proper subclass type:
    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore

    @classmethod
    def __torch_dispatch__(  # type: ignore
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # Dispatch to one of the `HANDLED_FUNCTIONS` that implement specific
        # functions for valid `HashTensor` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(),
                                          kwargs)
        return func(*args, **(kwargs or {}))

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        attrs = ['_map', '_min_key', '_max_key']
        if self._value is not None:
            attrs.append('_value')
        ctx = (self.size(0), self.dtype)
        return attrs, ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, Any],
        ctx: Tuple[Any, ...],
        outer_size: Tuple[int, ...],
        outer_stride: Tuple[int, ...],
    ) -> 'HashTensor':
        return HashTensor._from_data(
            inner_tensors['_map'],
            inner_tensors.get('_value', None),
            inner_tensors['_min_key'],
            inner_tensors['_max_key'],
            num_keys=ctx[0],
            dtype=ctx[1],
        )

    def __repr__(self) -> str:  # type: ignore
        indent = len(f'{self.__class__.__name__}(')
        tensor_str = torch._tensor_str._tensor_str(self.as_tensor(), indent)
        return torch._tensor_str._str_intern(self, tensor_contents=tensor_str)

    def tolist(self) -> List[Any]:
        """"""  # noqa: D419
        return self.as_tensor().tolist()

    def numpy(self, *, force: bool = False) -> np.ndarray:
        """"""  # noqa: D419
        return self.as_tensor().numpy(force=force)

    def index_select(  # type: ignore
        self,
        dim: int,
        index: Any,
    ) -> Union['HashTensor', Tensor]:
        """"""  # noqa: D419
        return torch.index_select(self, dim, index)

    def select(  # type: ignore
        self,
        dim: int,
        index: Any,
    ) -> Union['HashTensor', Tensor]:
        """"""  # noqa: D419
        return torch.select(self, dim, index)

    def share_memory_(self) -> 'HashTensor':
        """"""  # noqa: D419
        if isinstance(self._map, Tensor):
            self._map.share_memory_()
        if self._value is not None:
            self._value.share_memory_()
        self._min_key.share_memory_()
        self._max_key.share_memory_()
        return self

    def is_shared(self) -> bool:
        """"""  # noqa: D419
        return self._min_key.is_shared()

    def detach_(self) -> 'HashTensor':
        """"""  # noqa: D419
        if self._value is not None:
            self._value.detach_()
        return super().detach_()  # type: ignore

    def __getitem__(self, indices: Any) -> Union['HashTensor', Tensor]:
        if not isinstance(indices, tuple):
            indices = (indices, )
        assert len(indices) > 0

        # We convert any index in the first dimension into a tensor. This
        # means that downstream handling (i.e. in `aten.index.Tensor`) needs
        # to take this pre-conversion into account. However, detecting
        # whether the first dimension is indexed can be tricky at times:
        # * We need to take into account `Ellipsis`
        # * We need to take any unsqueezing into account
        if indices[0] is Ellipsis and len(indices) > 1:
            nonempty_indices = [i for i in indices[1:] if i is not None]
            if len(nonempty_indices) == self.dim():
                indices = indices[1:]

        if isinstance(indices[0], (int, bool)):
            index: Union[int, Tensor] = int(as_key_tensor([indices[0]]))
            indices = (index, ) + indices[1:]
        elif isinstance(indices[0], (Tensor, list, np.ndarray)):
            index = as_key_tensor(indices[0], device=self.device)
            indices = (index, ) + indices[1:]

        indices = indices[0] if len(indices) == 1 else indices

        return super().__getitem__(indices)
@implements(aten.alias.default)
def _alias(tensor: HashTensor) -> HashTensor:
    return tensor._shallow_copy()


@implements(aten.clone.default)
def _clone(
    tensor: HashTensor,
    *,
    memory_format: torch.memory_format = torch.preserve_format,
) -> HashTensor:
    value = tensor._value
    if value is not None:
        value = aten.clone.default(value, memory_format=memory_format)

    return tensor._from_data(
        tensor._map,  # NOTE No reason to clone since it is read-only.
        value,
        tensor._min_key,  # NOTE No reason to clone since it is read-only.
        tensor._max_key,  # NOTE No reason to clone since it is read-only.
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.detach.default)
def _detach(tensor: HashTensor) -> HashTensor:
    value = tensor._value
    if value is not None:
        value = aten.detach.default(value)

    return tensor._from_data(
        tensor._map,
        value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten._to_copy.default)
def _to_copy(
    tensor: HashTensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> HashTensor:
    value = tensor._value
    if value is not None:
        value = aten._to_copy.default(
            value,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=memory_format,
        )

    min_key = aten._to_copy.default(tensor._min_key, device=device)
    max_key = aten._to_copy.default(tensor._max_key, device=device)

    _map = tensor._map
    if isinstance(_map, Tensor):
        _map = aten._to_copy.default(_map, device=device)
    # Only convert `_map` in case `CUDAHashMap` exists - otherwise we use
    # CPU-based mapping anyway and there is no need for a copy.
    elif (torch_geometric.typing.WITH_CUDA_HASH_MAP and tensor.is_cuda
          and tensor.device != min_key.device):
        key = _map.keys()
        key = aten._to_copy.default(key, device=device)
        _map = get_hash_map(key)

    return tensor._from_data(
        _map,
        value,
        min_key,
        max_key,
        num_keys=tensor.size(0),
        dtype=dtype or tensor.dtype,
    )


@implements(aten._pin_memory.default)
def _pin_memory(tensor: HashTensor) -> HashTensor:
    _map = tensor._map
    if isinstance(_map, Tensor):
        _map = aten._pin_memory.default(_map)

    value = tensor._value
    if value is not None:
        value = aten._pin_memory.default(value)

    return tensor._from_data(
        _map,
        value,
        aten._pin_memory.default(tensor._min_key),
        aten._pin_memory.default(tensor._max_key),
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.unsqueeze.default)
def _unsqueeze(tensor: HashTensor, dim: int) -> HashTensor:
    if dim == 0 or dim == -(tensor.dim() + 1):
        raise IndexError(f"Cannot unsqueeze '{tensor.__class__.__name__}' in "
                         f"the first dimension. Please call `as_tensor()` "
                         f"beforehand")

    return tensor._from_data(
        tensor._map,
        aten.unsqueeze.default(tensor.as_tensor(), dim),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.squeeze.default)
def _squeeze_default(tensor: HashTensor) -> HashTensor:
    if tensor._value is None:
        return tensor._shallow_copy()

    value = tensor.as_tensor()
    for d in range(tensor.dim() - 1, 0, -1):
        value = value.squeeze(d)

    return tensor._from_data(
        tensor._map,
        value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.squeeze.dim)
@implements(getattr(aten.squeeze, 'dims', aten.squeeze.dim))
def _squeeze_dim(
    tensor: HashTensor,
    dim: Union[int, List[int]],
) -> HashTensor:
    if isinstance(dim, int):
        dim = [dim]

    for d in dim:
        if d < -tensor.dim() or d >= tensor.dim():
            raise IndexError(f"Dimension out of range (expected to be in "
                             f"range of [{-tensor.dim()}, {tensor.dim()-1}], "
                             f"but got {d})")

    if tensor._value is None:
        return tensor._shallow_copy()

    value = tensor.as_tensor()
    for d in dim[::-1]:
        if d != 0 and d != -tensor.dim():
            value = value.squeeze(d)

    return tensor._from_data(
        tensor._map,
        value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.slice.Tensor)
def _slice(
    tensor: HashTensor,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> HashTensor:
    if dim == 0 or dim == -tensor.dim():
        copy = start is None or start == 0 or start <= -tensor.size(0)
        copy &= end is None or end > tensor.size(0)
        copy &= step == 1
        if copy:
            return tensor._shallow_copy()

        key = aten.slice.Tensor(tensor._key, 0, start, end, step)
        value = aten.slice.Tensor(tensor.as_tensor(), 0, start, end, step)
        return tensor.__class__(key, value)

    return tensor._from_data(
        tensor._map,
        aten.slice.Tensor(tensor.as_tensor(), dim, start, end, step),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


# Since PyTorch only allows PyTorch tensors as indices in `index_select`, we
# need to create a wrapper function and monkey-patch `index_select` :(
_old_index_select = torch.index_select


def _new_index_select(
    input: Tensor,
    dim: Union[int, str],
    index: Tensor,
    out: Optional[Tensor] = None,
) -> Tensor:
    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index in the first dimension into a tensor. This means
    # that downstream handling (i.e. in `aten.index_select.default`) needs to
    # take this pre-conversion into account.
    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
        index = as_key_tensor(index, device=input.device)
    if isinstance(dim, int):  # Type narrowing...
        if out is None:
            return _old_index_select(input, dim, index)
        else:
            return _old_index_select(input, dim, index, out=out)
    else:
        if out is None:
            return _old_index_select(input, dim, index)
        else:
            return _old_index_select(input, dim, index, out=out)


torch.index_select = _new_index_select  # type: ignore


@implements(aten.index_select.default)
def _index_select(
    tensor: HashTensor,
    dim: int,
    index: Tensor,
) -> Union[HashTensor, Tensor]:
    if dim == 0 or dim == -tensor.dim():
        return tensor._get(index)

    return tensor._from_data(
        tensor._map,
        aten.index_select.default(tensor.as_tensor(), dim, index),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


# Since PyTorch only allows PyTorch tensors as indices in `select`, we need to
# create a wrapper function and monkey-patch `select` :(
_old_select = torch.select


def _new_select(
    input: Tensor,
    dim: Union[int, str],
    index: int,
) -> Tensor:
    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index in the first dimension into an integer. This means
    # that downstream handling (i.e. in `aten.select.int`) needs to take this
    # pre-conversion into account.
    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
        index = int(as_key_tensor([index]))
    if isinstance(dim, int):  # Type narrowing...
        return _old_select(input, dim, index)
    else:
        return _old_select(input, dim, index)


torch.select = _new_select  # type: ignore


@implements(aten.select.int)
def _select(
    tensor: HashTensor,
    dim: int,
    index: int,
) -> Union[HashTensor, Tensor]:
    if dim == 0 or dim == -tensor.dim():
        key = torch.tensor(
            [index],
            dtype=tensor._min_key.dtype,
            device=tensor.device,
        )
        return tensor._get(key).squeeze(0)

    return tensor._from_data(
        tensor._map,
        aten.select.int(tensor.as_tensor(), dim, index),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.index.Tensor)
def _index(
    tensor: HashTensor,
    indices: List[Optional[Tensor]],
) -> Union[HashTensor, Tensor]:
    assert len(indices) > 0

    if indices[0] is not None:
        out = tensor._get(indices[0])
        if len(indices) > 1:
            out = aten.index.Tensor(out, [None] + indices[1:])
        return out

    return tensor._from_data(
        tensor._map,
        aten.index.Tensor(tensor.as_tensor(), indices),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
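

# Example (illustrative sketch): thanks to the monkey-patched
# `torch.index_select`, key-based lookups also work through the standard
# indexing API. Mirroring the class docstring:
#
#   key = torch.tensor([1000, 100, 10000])
#   value = torch.randn(3, 4)
#   tensor = HashTensor(key, value)
#   out = torch.index_select(tensor, 0, torch.tensor([10000, 1000]))
#   # `out` is expected to equal `value[[2, 0]]`, just like `tensor[query]`.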