vision_service.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import os
  2. import vertexai
  3. from vertexai.vision_models import Image, MultiModalEmbeddingModel
  4. import base64
  5. import io
  6. from PIL import Image as PILImage
  7. from typing import List
  # --- Vertex AI multimodal embedding service (image/text vectors, JPEG preview) ---
class VertexVisionService:
    """Wrapper around Vertex AI's multimodal embedding model.

    The model is created lazily on first use, so constructing the service makes
    no cloud calls. Image and text embeddings are requested with the same
    dimensionality (``EMBEDDING_DIMENSION``) so vectors from both modalities can
    be compared against each other.
    """

    # Model id passed to ``MultiModalEmbeddingModel.from_pretrained``.
    MODEL_NAME = "multimodalembedding@001"
    # Shared output size for image AND text vectors; both must match for
    # cross-modal similarity search.
    EMBEDDING_DIMENSION = 1408
    # Pillow modes the JPEG encoder can write directly. Anything else (e.g. RGBA,
    # P, LA) makes ``Image.save(format="JPEG")`` raise OSError, so it is
    # converted to RGB first.
    _JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    def __init__(self, project_id: str, location: str):
        """Store the GCP project and region; the model is not loaded yet.

        Args:
            project_id: Google Cloud project passed to ``vertexai.init``.
            location: Vertex AI region passed to ``vertexai.init``.
        """
        self.project_id = project_id
        self.location = location
        self._model = None  # set on the first successful _get_model() call

    def _get_model(self):
        """Return the embedding model, initialising Vertex AI on first call.

        Returns:
            The cached ``MultiModalEmbeddingModel`` instance.

        Raises:
            RuntimeError: if ``vertexai.init`` or model loading fails (e.g.
                billing disabled, bad credentials, network error). The original
                exception is chained as ``__cause__``. A failed load is not
                cached, so a later call retries.
        """
        if self._model is None:
            try:
                # Credentials must already be configured at this point
                # (per the original note, this is handled globally in main.py).
                vertexai.init(project=self.project_id, location=self.location)
                self._model = MultiModalEmbeddingModel.from_pretrained(self.MODEL_NAME)
            except Exception as e:
                import logging

                # Keep the full traceback for operators (e.g. "Billing Disabled"),
                # then surface a stable, caller-facing error.
                logging.getLogger(__name__).exception(
                    "Vertex AI Initialization Failed: %s", e
                )
                raise RuntimeError(
                    "Cloud services (Vectorization/Search) are currently unavailable."
                ) from e
        return self._model

    def get_image_embedding(self, image_path: str) -> List[float]:
        """Return the ``EMBEDDING_DIMENSION``-d embedding of the image file.

        Args:
            image_path: Path to an image file on local disk.

        Raises:
            FileNotFoundError: if ``image_path`` is not an existing regular file.
                Checked before the model is loaded, so no cloud call is made.
            RuntimeError: if the Vertex AI model cannot be loaded.
        """
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")
        model = self._get_model()
        image = Image.load_from_file(image_path)
        embeddings = model.get_embeddings(
            image=image,
            dimension=self.EMBEDDING_DIMENSION,
        )
        return embeddings.image_embedding

    def get_text_embedding(self, text: str) -> List[float]:
        """Convert a text query into an ``EMBEDDING_DIMENSION``-d (1408) vector.

        Raises:
            RuntimeError: if the Vertex AI model cannot be loaded.
        """
        model = self._get_model()
        embeddings = model.get_embeddings(
            contextual_text=text,
            dimension=self.EMBEDDING_DIMENSION,
        )
        return embeddings.text_embedding

    def encode_image_to_base64(self, image_path: str) -> str:
        """Downscale an image to fit within 640x640 and return it as base64 JPEG.

        ``thumbnail`` preserves the aspect ratio and never upscales, so the result
        is at most 640 px on its longest side (not necessarily 640x640). Images
        in modes JPEG cannot store (e.g. RGBA or palette PNGs) are converted to
        RGB, which discards any alpha channel.

        Returns:
            The JPEG bytes (quality 85) base64-encoded as an ASCII ``str``.
        """
        with PILImage.open(image_path) as img:
            img.thumbnail((640, 640))
            if img.mode in self._JPEG_WRITABLE_MODES:
                out = img
            else:
                out = img.convert("RGB")
            buffered = io.BytesIO()
            out.save(buffered, format="JPEG", quality=85)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")