Source code for torch_geometric.utils._lexsort

from typing import List

from torch import Tensor


def lexsort(
    keys: List[Tensor],
    dim: int = -1,
    descending: bool = False,
) -> Tensor:
    r"""Performs an indirect stable sort using a sequence of keys.

    Given multiple sorting keys, returns an array of integer indices that
    describe their sort order.
    The last key in the sequence is used for the primary sort order, the
    second-to-last key for the secondary sort order, and so on.

    Args:
        keys ([torch.Tensor]): The :math:`k` different columns to be sorted.
            The last key is the primary sort key.
        dim (int, optional): The dimension to sort along. (default: :obj:`-1`)
        descending (bool, optional): Controls the sorting order (ascending or
            descending). (default: :obj:`False`)

    Returns:
        torch.Tensor: The indices along :obj:`dim` that sort the keys
        lexicographically. The result has the shape of :obj:`keys[0]`.

    Raises:
        AssertionError: If :obj:`keys` is empty.

    Example:
        >>> secondary = torch.tensor([1, 0, 1, 0])
        >>> primary = torch.tensor([0, 0, 1, 1])
        >>> lexsort([secondary, primary])
        tensor([1, 0, 3, 2])
    """
    assert len(keys) >= 1, "'lexsort' requires at least one sorting key"

    # Start with the least significant key. Every sort below is stable, so
    # the ordering established here survives as the tie-breaker later on.
    out = keys[0].argsort(dim=dim, descending=descending, stable=True)
    for k in keys[1:]:
        # Bring the next (more significant) key into the current order ...
        index = k.gather(dim, out)
        # ... stable-sort it, so equal entries keep their relative order ...
        index = index.argsort(dim=dim, descending=descending, stable=True)
        # ... and compose both permutations to get indices into the input.
        out = out.gather(dim, index)
    return out