from bson import ObjectId
from pymongo import MongoClient

from src.domain.models import PalmOilBunch


class MongoPalmOilRepository:
    """MongoDB-backed repository for palm oil bunch records.

    Wraps a single collection and exposes insert, lookup-by-id and
    Atlas Vector Search operations on it.
    """

    # Length every query vector passed to ``vector_search`` must have.
    EMBEDDING_DIMENSIONS = 1408
    # Atlas Vector Search documents 10000 as the upper bound for
    # ``numCandidates``; larger values are rejected by the server.
    MAX_NUM_CANDIDATES = 10000

    def __init__(self, uri: str, db_name: str, collection_name: str):
        """Open a client for ``uri`` and bind to ``db_name.collection_name``."""
        self.client = MongoClient(uri)
        self.collection = self.client[db_name][collection_name]

    def ensure_indexes(self):
        """Create the single-field indexes used for filtering records.

        Indexes ``is_abnormal``, ``ripeness_class`` and ``timestamp``.
        The ``vector_index`` used by ``vector_search`` is NOT created here.
        """
        for field in ("is_abnormal", "ripeness_class", "timestamp"):
            self.collection.create_index(field)
        print("MongoDB Indexes Ensured.")

    def get_by_id(self, record_id: str):
        """Return the document whose ``_id`` equals ``record_id``, or ``None``.

        ``record_id`` is converted with ``ObjectId``; a string that is not a
        valid ObjectId makes that conversion raise.
        """
        return self.collection.find_one({"_id": ObjectId(record_id)})

    @staticmethod
    def _to_document(bunch: PalmOilBunch) -> dict:
        """Shallow-copy a bunch's attributes into a dict ready for insertion."""
        doc = bunch.__dict__.copy()
        # Drop a null id so it is not stored as a field. The default
        # argument keeps this safe when the object has no ``id`` at all
        # (a plain ``pop('id')`` raised KeyError in that case).
        if doc.get("id") is None:
            doc.pop("id", None)
        return doc

    def save(self, bunch: PalmOilBunch):
        """Insert one bunch and return the new document id as a string."""
        result = self.collection.insert_one(self._to_document(bunch))
        return str(result.inserted_id)

    def save_many(self, bunches: list):
        """Bulk insert palm oil records.

        Returns the inserted ids as strings, in insertion order. An empty
        input returns ``[]`` without calling the database.
        """
        docs = [self._to_document(bunch) for bunch in bunches]
        if not docs:
            return []
        result = self.collection.insert_many(docs)
        return [str(inserted_id) for inserted_id in result.inserted_ids]

    def vector_search(self, query_vector: list, limit: int = 3):
        """Run an Atlas Vector Search over the ``embedding`` field.

        Args:
            query_vector: Query embedding; must have exactly
                ``EMBEDDING_DIMENSIONS`` (1408) elements.
            limit: Maximum number of documents to return; must be >= 1.

        Returns:
            A list of matching documents without ``embedding`` and
            ``image_data``, each with a ``score`` field and a string ``_id``.

        Raises:
            ValueError: If ``query_vector`` has the wrong length or
                ``limit`` is less than 1.
        """
        if len(query_vector) != self.EMBEDDING_DIMENSIONS:
            raise ValueError(f"Query vector must be 1408-dimensional, got {len(query_vector)}")
        if limit < 1:
            # Fail fast instead of sending a request the server will reject.
            raise ValueError(f"limit must be a positive integer, got {limit}")

        # Oversample 10x for recall, but stay within the server-side cap.
        num_candidates = min(limit * 10, self.MAX_NUM_CANDIDATES)

        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_index",
                    "path": "embedding",
                    "queryVector": query_vector,
                    "numCandidates": num_candidates,
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "embedding": 0,
                    "image_data": 0,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        results = list(self.collection.aggregate(pipeline))
        for res in results:
            # Stringify the ObjectId in each returned document.
            res["_id"] = str(res["_id"])
        return results