| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
import base64
import io
import logging
import os
from typing import List

import vertexai
from PIL import Image as PILImage
from vertexai.vision_models import Image, MultiModalEmbeddingModel
class VertexVisionService:
    """Thin wrapper around the Vertex AI multimodal embedding model.

    Produces image and text embeddings of one fixed size (so a text query can
    be compared against indexed image vectors) and offers a helper that
    base64-encodes a downscaled JPEG of a local image. The Vertex AI SDK is
    initialised lazily on first use; constructing the service only stores
    configuration.
    """

    # Model identifier passed to ``MultiModalEmbeddingModel.from_pretrained``.
    MODEL_NAME = "multimodalembedding@001"
    # Image and text embeddings are requested at the same size for consistency,
    # so vectors from either method can live in the same search index.
    EMBEDDING_DIMENSION = 1408
    # Bounding box (width, height) used by ``encode_image_to_base64``.
    THUMBNAIL_SIZE = (640, 640)
    # JPEG quality used by ``encode_image_to_base64``.
    JPEG_QUALITY = 85

    def __init__(self, project_id: str, location: str):
        """Store Vertex AI settings; the model itself is loaded lazily.

        Args:
            project_id: Google Cloud project passed to ``vertexai.init``.
            location: Region passed to ``vertexai.init``.
        """
        self.project_id = project_id
        self.location = location
        self._model = None  # populated on first successful _get_model()

    def _get_model(self):
        """Return the embedding model, initialising Vertex AI on first call.

        Credentials must already be configured before this runs (the original
        note says this is handled globally in main.py). On failure ``_model``
        stays ``None``, so a later call will retry initialisation.

        Raises:
            RuntimeError: if initialisation or model loading fails (e.g.
                billing disabled or connection problems). The underlying
                exception is chained as ``__cause__``.
        """
        if self._model is None:
            try:
                vertexai.init(project=self.project_id, location=self.location)
                self._model = MultiModalEmbeddingModel.from_pretrained(
                    self.MODEL_NAME
                )
            except Exception as e:
                # Broad catch is deliberate: this is the SDK/app boundary and
                # every failure maps to one "unavailable" error. Log the
                # specific cause and chain it so it isn't lost to callers.
                logging.getLogger(__name__).exception(
                    "Vertex AI Initialization Failed: %s", e
                )
                raise RuntimeError(
                    "Cloud services (Vectorization/Search) are currently unavailable."
                ) from e
        return self._model

    def get_image_embedding(self, image_path: str) -> List[float]:
        """Embed the image at *image_path* as an EMBEDDING_DIMENSION vector.

        Raises:
            FileNotFoundError: if *image_path* is not an existing regular file.
            RuntimeError: if the Vertex AI model cannot be loaded.
        """
        # Validate the path before touching the model so a bad path fails
        # fast even when cloud services are down. ``isfile`` (not ``exists``)
        # also rejects directories.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        model = self._get_model()
        image = Image.load_from_file(image_path)
        embeddings = model.get_embeddings(
            image=image,
            dimension=self.EMBEDDING_DIMENSION,
        )
        return embeddings.image_embedding

    def get_text_embedding(self, text: str) -> List[float]:
        """Converts text query into a 1408-d (EMBEDDING_DIMENSION) vector.

        Raises:
            RuntimeError: if the Vertex AI model cannot be loaded.
        """
        model = self._get_model()
        embeddings = model.get_embeddings(
            contextual_text=text,
            dimension=self.EMBEDDING_DIMENSION,
        )
        return embeddings.text_embedding

    def encode_image_to_base64(self, image_path: str) -> str:
        """Downscale the image at *image_path* and return it as base64 JPEG.

        The image is resized in memory to fit within THUMBNAIL_SIZE while
        keeping its aspect ratio (smaller images are not enlarged), then
        re-encoded as JPEG at JPEG_QUALITY.

        Returns:
            The JPEG bytes as an ASCII base64 string.
        """
        # NOTE(review): EXIF orientation is not applied (no
        # ImageOps.exif_transpose) and EXIF is not passed to save(), so
        # camera photos relying on the orientation tag may come out rotated.
        with PILImage.open(image_path) as img:
            img.thumbnail(self.THUMBNAIL_SIZE)
            # JPEG cannot hold alpha or palette data; saving e.g. an RGBA PNG
            # directly would raise OSError, so normalise to RGB first.
            encodable = img if img.mode in ("RGB", "L") else img.convert("RGB")
            buffered = io.BytesIO()
            encodable.save(buffered, format="JPEG", quality=self.JPEG_QUALITY)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
|