Source code for torch_geometric.llm.models.vision_transformer

from typing import Optional, Union

import torch
from torch import Tensor


class VisionTransformer(torch.nn.Module):
    r"""A wrapper around a Swin vision transformer from HuggingFace.

    The wrapped model is built with :obj:`transformers.SwinConfig` and
    :obj:`transformers.SwinModel`, so :obj:`model_name` has to resolve to a
    Swin configuration.

    Args:
        model_name (str): The HuggingFace model name or path handed to
            :obj:`SwinConfig.from_pretrained`, *e.g.*,
            :obj:`"microsoft/swin-base-patch4-window7-224"`.
    """
    def __init__(
        self,
        model_name: str,
    ) -> None:
        super().__init__()
        self.model_name = model_name

        # Imported lazily so that `transformers` is only required once this
        # class is actually instantiated.
        from transformers import SwinConfig, SwinModel

        self.config = SwinConfig.from_pretrained(model_name)
        # NOTE(review): the model is constructed from the configuration
        # object alone; no `SwinModel.from_pretrained` call happens here, so
        # pretrained weights do not appear to be loaded -- confirm that this
        # is intended.
        self.model = SwinModel(self.config)

    @torch.no_grad()
    def forward(
        self,
        images: Tensor,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Encodes :obj:`images` with gradient tracking disabled.

        Args:
            images (torch.Tensor): The input passed unchanged to the wrapped
                model (presumably pixel values of shape
                :obj:`[batch, channels, height, width]` -- TODO confirm).
                It is not moved to the model's device by this method.
            output_device (torch.device or str, optional): The device to
                move the result to. If :obj:`None`, the result is returned
                on whatever device the model produced it.
                (default: :obj:`None`)

        Returns:
            torch.Tensor: The :obj:`last_hidden_state` of the model output.
        """
        hidden = self.model(images).last_hidden_state
        if output_device is None:
            # Nothing to move; spelled out rather than relying on
            # `Tensor.to(None)` being a no-op.
            return hidden
        return hidden.to(output_device)

    @property
    def device(self) -> torch.device:
        r"""The device of the first parameter of the wrapped model."""
        # `parameters()` already yields an iterator, so `next` applies
        # directly.
        return next(self.model.parameters()).device

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'