| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- from bson import ObjectId
- from pymongo import MongoClient
- from src.domain.models import PalmOilBunch
class MongoPalmOilRepository:
    """MongoDB-backed repository for palm oil bunch records.

    Wraps a single collection and exposes insert, lookup-by-id and
    Atlas ``$vectorSearch`` operations on it.
    """

    # Length every query vector must have; enforced in ``vector_search``.
    EMBEDDING_DIM = 1408
    # Name of the Atlas vector index queried by ``vector_search``.
    VECTOR_INDEX_NAME = "vector_index"
    # Atlas rejects ``numCandidates`` values above this bound.
    MAX_NUM_CANDIDATES = 10000

    def __init__(self, uri: str, db_name: str, collection_name: str):
        """Connect to *uri* and bind to ``db_name.collection_name``."""
        self.client = MongoClient(uri)
        self.collection = self.client[db_name][collection_name]

    @staticmethod
    def _to_document(bunch) -> dict:
        """Return a shallow dict copy of *bunch*'s attributes for insertion.

        An ``id`` entry that is ``None`` (or absent) is dropped. A
        non-``None`` ``id`` is kept under the ``id`` key as-is.
        """
        doc = bunch.__dict__.copy()
        if doc.get("id") is None:
            # Default avoids a KeyError when the object has no ``id`` at all.
            doc.pop("id", None)
        return doc

    def ensure_indexes(self):
        """Create single-field indexes on the commonly filtered fields.

        Indexes ``is_abnormal``, ``ripeness_class`` and ``timestamp``.
        The vector index used by ``vector_search`` is not created here.
        """
        for field in ("is_abnormal", "ripeness_class", "timestamp"):
            self.collection.create_index(field)
        print("MongoDB Indexes Ensured.")

    def get_by_id(self, record_id: str):
        """Retrieve a specific record by its ID.

        Returns whatever ``find_one`` yields for ``_id == ObjectId(record_id)``.
        The ``ObjectId`` conversion is not guarded, so an error it raises
        for a malformed *record_id* propagates to the caller.
        """
        return self.collection.find_one({"_id": ObjectId(record_id)})

    def save(self, bunch: "PalmOilBunch"):
        """Insert one bunch and return the inserted id as a string."""
        result = self.collection.insert_one(self._to_document(bunch))
        return str(result.inserted_id)

    def save_many(self, bunches: list):
        """Bulk insert palm oil records.

        Returns the inserted ids as strings, in insertion order. An empty
        input returns ``[]`` without calling the collection.
        """
        docs = [self._to_document(bunch) for bunch in bunches]
        if not docs:
            return []
        result = self.collection.insert_many(docs)
        return [str(inserted_id) for inserted_id in result.inserted_ids]

    def vector_search(self, query_vector: list, limit: int = 3):
        """Run an Atlas Vector Search against the ``embedding`` field.

        Args:
            query_vector: Query embedding; must have ``EMBEDDING_DIM`` items.
            limit: Maximum number of hits to return; must be at least 1.

        Returns:
            A list of result documents without ``embedding`` and
            ``image_data``, each with a ``score`` field and with ``_id``
            converted to ``str``.

        Raises:
            ValueError: If *query_vector* has the wrong length or *limit*
                is smaller than 1.
        """
        if len(query_vector) != self.EMBEDDING_DIM:
            raise ValueError(
                f"Query vector must be {self.EMBEDDING_DIM}-dimensional, "
                f"got {len(query_vector)}"
            )
        if limit < 1:
            raise ValueError(f"limit must be a positive integer, got {limit}")

        pipeline = [
            {
                "$vectorSearch": {
                    "index": self.VECTOR_INDEX_NAME,
                    "path": "embedding",
                    "queryVector": query_vector,
                    # Oversample 10x for recall, but stay within the bound
                    # the server accepts.
                    "numCandidates": min(limit * 10, self.MAX_NUM_CANDIDATES),
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "embedding": 0,
                    "image_data": 0,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
        results = list(self.collection.aggregate(pipeline))
        for res in results:
            # Stringify the id so each hit is plain-serializable.
            res["_id"] = str(res["_id"])
        return results
|