# repository.py

  1. from bson import ObjectId
  2. from pymongo import MongoClient
  3. from src.domain.models import PalmOilBunch
  4. class MongoPalmOilRepository:
  5. def __init__(self, uri: str, db_name: str, collection_name: str):
  6. self.client = MongoClient(uri)
  7. self.collection = self.client[db_name][collection_name]
  8. def ensure_indexes(self):
  9. """Create indexes for health alerts and vector search."""
  10. self.collection.create_index("is_abnormal")
  11. self.collection.create_index("ripeness_class")
  12. self.collection.create_index("timestamp")
  13. print("MongoDB Indexes Ensured.")
  14. def get_by_id(self, record_id: str):
  15. """Retrieve a specific record by its ID."""
  16. return self.collection.find_one({"_id": ObjectId(record_id)})
  17. def save(self, bunch: PalmOilBunch):
  18. # Convert dataclass to dict for MongoDB
  19. doc = bunch.__dict__.copy()
  20. # Remove id if it's None to let Mongo generate it
  21. if doc.get('id') is None:
  22. doc.pop('id')
  23. result = self.collection.insert_one(doc)
  24. return str(result.inserted_id)
  25. def save_many(self, bunches: list):
  26. """Bulk insert palm oil records."""
  27. docs = []
  28. for bunch in bunches:
  29. doc = bunch.__dict__.copy()
  30. if doc.get('id') is None:
  31. doc.pop('id')
  32. docs.append(doc)
  33. if docs:
  34. result = self.collection.insert_many(docs)
  35. return [str(i) for i in result.inserted_ids]
  36. return []
  37. def vector_search(self, query_vector: list, limit: int = 3):
  38. """Atlas Vector Search using the 1408-D index."""
  39. if len(query_vector) != 1408:
  40. raise ValueError(f"Query vector must be 1408-dimensional, got {len(query_vector)}")
  41. pipeline = [
  42. {
  43. "$vectorSearch": {
  44. "index": "vector_index",
  45. "path": "embedding",
  46. "queryVector": query_vector,
  47. "numCandidates": limit * 10,
  48. "limit": limit
  49. }
  50. },
  51. {
  52. "$project": {
  53. "embedding": 0,
  54. "image_data": 0,
  55. "score": {"$meta": "vectorSearchScore"}
  56. }
  57. }
  58. ]
  59. results = list(self.collection.aggregate(pipeline))
  60. for res in results:
  61. res["_id"] = str(res["_id"])
  62. return results