from typing import List
from torch import Tensor
def lexsort(
    keys: List[Tensor],
    dim: int = -1,
    descending: bool = False,
) -> Tensor:
    r"""Performs an indirect stable sort using a sequence of keys.

    Given multiple sorting keys, returns a tensor of integer indices that
    describes their sort order.
    The last key in the sequence is used for the primary sort order, the
    second-to-last key for the secondary sort order, and so on.

    Args:
        keys ([torch.Tensor]): The :math:`k` different columns to be sorted.
            The last key is the primary sort key. Must contain at least one
            tensor.
        dim (int, optional): The dimension to sort along. (default: :obj:`-1`)
        descending (bool, optional): Controls the sorting order (ascending or
            descending). (default: :obj:`False`)

    Returns:
        torch.Tensor: The indices that sort the keys along :obj:`dim`.

    Raises:
        AssertionError: If :obj:`keys` is empty.

    Example:
        >>> a = torch.tensor([1, 5, 1, 4, 3, 4, 4])
        >>> b = torch.tensor([9, 4, 0, 4, 0, 2, 1])
        >>> lexsort([b, a])  # Sort by `a`, break ties by `b`.
        tensor([2, 0, 4, 6, 5, 3, 1])
    """
    assert len(keys) >= 1, "`lexsort` requires at least one key"

    # Start from the least significant key. Every subsequent sort is stable,
    # so ties in a more significant key keep the order established so far.
    out = keys[0].argsort(dim=dim, descending=descending, stable=True)
    for k in keys[1:]:
        # View the current key in the order given by the permutation so far.
        index = k.gather(dim, out)
        # Stable-sort that reordered key ...
        index = index.argsort(dim=dim, descending=descending, stable=True)
        # ... and compose the new permutation with the accumulated one.
        out = out.gather(dim, index)
    return out