Source code for torch_geometric.llm.models.sentence_transformer
from enum import Enum
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
class PoolingStrategy(Enum):
MEAN = 'mean'
LAST = 'last'
CLS = 'cls'
LAST_HIDDEN_STATE = 'last_hidden_state'
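# Note: PoolingStrategy('mean') is PoolingStrategy.MEAN by standard Enum
# value lookup, which is what lets the constructor below accept either an
# enum member or its plain string value.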
class SentenceTransformer(torch.nn.Module):
r"""A wrapper around a Sentence-Transformer from HuggingFace.
Args:
model_name (str): The HuggingFace model name, *e.g.*, :obj:`"BERT"`.
pooling_strategy (str, optional): The pooling strategy to use
for generating node embeddings. (default: :obj:`"mean"`)
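
    Example (the model name below is illustrative, not required):

    .. code-block:: python

        model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        emb = model.encode(['hello world', 'graph learning'])
        # emb has shape [2, hidden_size]; each row is L2-normalized.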
"""
def __init__(
self,
model_name: str,
pooling_strategy: Union[PoolingStrategy, str] = 'mean',
) -> None:
super().__init__()
self.model_name = model_name
self.pooling_strategy = PoolingStrategy(pooling_strategy)
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Maximum sequence length from the model configuration (e.g. 8192 for
# models like ModernBERT)
self.max_seq_length = self.model.config.max_position_embeddings
"""
Some models define a max sequence length in their configuration. Others
only in the tokenizer. This is a hacky heuristic to find the max
sequence length that works for the model.
"""
probe_tokens = self.tokenizer("hacky heuristic", padding='max_length',
return_tensors='pt')
self.max_seq_length = min(self.max_seq_length,
probe_tokens.input_ids.shape[1])
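        # Illustrative outcome: a config reporting max_position_embeddings of
        # 8192 combined with a tokenizer that pads to 512 tokens yields a
        # final self.max_seq_length of 512.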
    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
out = self.model(input_ids=input_ids, attention_mask=attention_mask)
emb = out[0] # First element contains all token embeddings.
if self.pooling_strategy == PoolingStrategy.MEAN:
emb = mean_pooling(emb, attention_mask)
elif self.pooling_strategy == PoolingStrategy.LAST:
emb = last_pooling(emb, attention_mask)
elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
emb = out.last_hidden_state
else:
assert self.pooling_strategy == PoolingStrategy.CLS
emb = emb[:, 0, :]
emb = F.normalize(emb, p=2, dim=1)
return emb
def get_input_ids(
self,
text: List[str],
batch_size: Optional[int] = None,
output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tuple[Tensor, Tensor]:
is_empty = len(text) == 0
text = ['dummy'] if is_empty else text
batch_size = len(text) if batch_size is None else batch_size
input_ids: List[Tensor] = []
attention_masks: List[Tensor] = []
for start in range(0, len(text), batch_size):
token = self.tokenizer(
text[start:start + batch_size],
padding=True,
truncation=True,
return_tensors='pt',
max_length=self.max_seq_length,
)
input_ids.append(token.input_ids.to(self.device))
attention_masks.append(token.attention_mask.to(self.device))
def _out(x: List[Tensor]) -> Tensor:
out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
out = out[:0] if is_empty else out
return out.to(output_device)
return _out(input_ids), _out(attention_masks)
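    # Usage sketch (hedged): tokenize first, then call the module directly
    # instead of going through encode():
    #
    #   input_ids, mask = model.get_input_ids(['a sentence'])
    #   emb = model(input_ids=input_ids, attention_mask=mask)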
@property
def device(self) -> torch.device:
return next(iter(self.model.parameters())).device
@torch.no_grad()
def encode(
self,
text: List[str],
batch_size: Optional[int] = None,
output_device: Optional[Union[torch.device, str]] = None,
        verbose: bool = False,
) -> Tensor:
is_empty = len(text) == 0
text = ['dummy'] if is_empty else text
batch_size = len(text) if batch_size is None else batch_size
embs: List[Tensor] = []
loader = range(0, len(text), batch_size)
        if verbose:
            loader = tqdm(loader, desc=f"Encoding {len(text)} strings "
                          f"w/ SentenceTransformer")
for start in loader:
token = self.tokenizer(
text[start:start + batch_size],
padding=True,
truncation=True,
return_tensors='pt',
max_length=self.max_seq_length,
)
try:
emb = self(
input_ids=token.input_ids.to(self.device),
attention_mask=token.attention_mask.to(self.device),
).to(output_device)
embs.append(emb)
            except RuntimeError:  # e.g., CUDA out-of-memory.
                # Fall back to CPU for huge strings that cause OOMs:
                print("SentenceTransformer failed on CUDA, retrying on CPU...")
previous_device = self.device
self.model = self.model.to("cpu")
emb = self(
input_ids=token.input_ids.to(self.device),
attention_mask=token.attention_mask.to(self.device),
).to(output_device)
embs.append(emb)
self.model = self.model.to(previous_device)
out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
out = out[:0] if is_empty else out
return out
def __repr__(self) -> str:
return f'{self.__class__.__name__}(model_name={self.model_name})'
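# Encoding sketch (illustrative parameters): batching plus an explicit output
# device keeps GPU memory bounded while the embeddings accumulate on CPU:
#
#   model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
#   emb = model.encode(texts, batch_size=256, output_device='cpu')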
def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    # Average token embeddings, counting only attended (non-padding) tokens:
    mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
    return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
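# Worked sketch for mean_pooling (illustrative values):
#
#   emb = torch.ones(1, 3, 4)
#   emb[0, 2] = 100.0                     # Embedding of a padding token.
#   mask = torch.tensor([[1, 1, 0]])
#   mean_pooling(emb, mask)               # -> all ones; padding is excluded.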
def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    # Check whether the language model uses left padding (the convention for
    # decoder-only LLMs), in which case the last position of every sequence
    # holds a real token:
    left_padding = attention_mask[:, -1].sum() == attention_mask.size(0)
if left_padding:
return emb[:, -1]
seq_indices = attention_mask.sum(dim=1) - 1
return emb[torch.arange(emb.size(0), device=emb.device), seq_indices]
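# Worked sketch for last_pooling with right padding (illustrative mask):
#
#   attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
#   # seq_indices becomes tensor([2, 1]): row 0 pools its token at index 2,
#   # row 1 pools index 1, its last non-padding token.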