from typing import List, Optional
import uuid
import os
import shutil
import time
from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
import onnxruntime as ort
from ultralytics import YOLO
import numpy as np
from dotenv import load_dotenv
import io
from PIL import Image
from src.infrastructure.vision_service import VertexVisionService
from src.infrastructure.repository import MongoPalmOilRepository
from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
import sqlite3
import json

# Local persistence for the analysis history vault.
DB_PATH = "palm_history.db"
ARCHIVE_DIR = "history_archive"
os.makedirs(ARCHIVE_DIR, exist_ok=True)


def init_local_db():
    """Create the local SQLite `history` table if it does not exist yet.

    Idempotent: uses CREATE TABLE IF NOT EXISTS, so it is safe to call on
    every startup. The connection is closed even if table creation fails
    (the original leaked it on error).
    """
    print(f"Initializing Local DB at {DB_PATH}...")
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT,
                archive_path TEXT,
                detections TEXT,
                summary TEXT,
                inference_ms REAL,
                processing_ms REAL,
                raw_tensor TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
    finally:
        conn.close()
    print("Local DB Initialized.")


init_local_db()

# Load environment variables
load_dotenv()

# Set Google Cloud credentials globally
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"

app = FastAPI(title="Palm Oil Ripeness Service (DDD)")


class ModelManager:
    """Owns both inference engines for the same detector weights.

    - ONNX Runtime session (`best.onnx`) for fast CPU inference.
    - Ultralytics YOLO (`best.pt`) for the PyTorch path and the class-name map.
    """

    def __init__(self, onnx_path: str, pt_path: str):
        self.onnx_session = ort.InferenceSession(onnx_path)
        self.onnx_input_name = self.onnx_session.get_inputs()[0].name
        self.pt_model = YOLO(pt_path)
        # Class-id -> class-name mapping, shared by both engines.
        self.class_names = self.pt_model.names

    def preprocess_onnx(self, img: Image.Image):
        """Convert a PIL image into the (1, 3, 640, 640) float32 CHW tensor
        the ONNX model expects; also return the original width/height so
        detections can be scaled back to source-image pixels."""
        img = img.convert("RGB")
        orig_w, orig_h = img.size
        resized = img.resize((640, 640))
        # float32 from the start avoids a throwaway float64 intermediate.
        tensor = np.asarray(resized, dtype=np.float32) / 255.0
        tensor = tensor.transpose(2, 0, 1)[np.newaxis, ...]
        return tensor, orig_w, orig_h

    def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
        """Run the ONNX session and post-process detections.

        Returns (detections, raw_sample, inference_ms) where detections is a
        list of dicts with absolute-pixel boxes and raw_sample is the first
        five raw output rows kept as technical evidence.
        """
        tensor, orig_w, orig_h = self.preprocess_onnx(img)
        start_inf = time.perf_counter()
        outputs = self.onnx_session.run(None, {self.onnx_input_name: tensor})
        inference_ms = (time.perf_counter() - start_inf) * 1000

        # ONNX Output: [batch, num_boxes, 6] (6 = x1, y1, x2, y2, conf, cls)
        # NOTE(review): assumes the exported model emits NORMALIZED (0.0-1.0)
        # coordinates with NMS already applied — confirm against the export
        # settings used to produce best.onnx.
        detections_batch = outputs[0]
        detections = []
        for i in range(detections_batch.shape[1]):
            det = detections_batch[0, i]
            conf = float(det[4])
            if conf < conf_threshold:
                continue
            x1, y1, x2, y2 = det[:4]
            class_id = int(det[5])
            class_name = self.class_names.get(class_id, "Unknown")
            detections.append({
                "bunch_id": len(detections) + 1,
                "class": class_name,
                "confidence": round(conf, 2),
                "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                # Scale normalized coords up to absolute source-image pixels.
                "box": [float(x1 * orig_w), float(y1 * orig_h),
                        float(x2 * orig_w), float(y2 * orig_h)]
            })

        # Capture a raw tensor sample (first 5 detections) for technical evidence
        raw_sample = detections_batch[0, :5].tolist()
        return detections, raw_sample, inference_ms

    def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
        """Run the Ultralytics .pt model; confidence filtering is delegated
        to YOLO via conf=. Returns (detections, raw_snippet, inference_ms)."""
        start_inf = time.perf_counter()
        results = self.pt_model(img, conf=conf_threshold, verbose=False)
        inference_ms = (time.perf_counter() - start_inf) * 1000

        detections = []
        for i, box in enumerate(results[0].boxes):
            class_id = int(box.cls)
            class_name = self.class_names.get(class_id, "Unknown")
            detections.append({
                "bunch_id": i + 1,
                "class": class_name,
                "confidence": round(float(box.conf), 2),
                "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                "box": box.xyxy.tolist()[0]
            })

        # Extract snippet from results (simulating raw output)
        raw_snippet = results[0].boxes.data[:5].tolist() if len(results[0].boxes) > 0 else []
        return detections, raw_snippet, inference_ms
model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')

# Global state for the confidence threshold
current_conf = 0.25

# Initialize DDD Components
vision_service = VertexVisionService(
    project_id=os.getenv("PROJECT_ID", "your-project-id"),
    location=os.getenv("LOCATION", "us-central1")
)
repo = MongoPalmOilRepository(
    uri=os.getenv("MONGO_URI"),
    db_name=os.getenv("DB_NAME", "palm_oil_db"),
    collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
)
repo.ensure_indexes()
analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
search_use_case = SearchSimilarUseCase(vision_service, repo)


@app.get("/get_confidence")
async def get_confidence():
    """Returns the current confidence threshold used by the model."""
    return {
        "status": "success",
        "current_confidence": current_conf,
        "model_version": "best.pt"
    }


@app.post("/set_confidence")
async def set_confidence(threshold: float = Body(..., embed=True)):
    """Updates the confidence threshold globally (must be in [0.0, 1.0])."""
    global current_conf
    if 0.0 <= threshold <= 1.0:
        current_conf = threshold
        return {"status": "success", "new_confidence": current_conf}
    return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}


@app.post("/analyze")
async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
    """Industry-grade analysis with health metrics and summary.

    Runs detection (ONNX by default, PyTorch when model_type == "pytorch"),
    builds a per-class count summary, archives the image to the local vault,
    and records the result in the SQLite history table.
    """
    import time
    image_bytes = await file.read()
    img = Image.open(io.BytesIO(image_bytes))

    start_total = time.perf_counter()
    # Select Inference Engine
    if model_type == "pytorch":
        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
    else:
        detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
    total_ms = (time.perf_counter() - start_total) * 1000
    processing_ms = total_ms - inference_ms

    # Per-class tally. Detections may carry the "Unknown" fallback class,
    # which is absent from class_names, so guard with .get to avoid KeyError
    # (the original indexed blindly and could crash).
    summary = {name: 0 for name in model_manager.class_names.values()}
    for det in detections:
        summary[det['class']] = summary.get(det['class'], 0) + 1

    # AUTO-ARCHIVE to Local History Vault
    unique_id = uuid.uuid4().hex
    archive_filename = f"{unique_id}_{file.filename}"
    archive_path = os.path.join(ARCHIVE_DIR, archive_filename)
    with open(archive_path, "wb") as buffer:
        buffer.write(image_bytes)

    # Save to SQLite; close the connection even if the insert fails.
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (file.filename, archive_path, json.dumps(detections), json.dumps(summary),
             inference_ms, processing_ms, json.dumps(raw_sample))
        )
        conn.commit()
    finally:
        conn.close()

    return {
        "status": "success",
        "current_threshold": current_conf,
        "total_count": len(detections),
        "industrial_summary": summary,
        "detections": detections,
        "inference_ms": inference_ms,
        "processing_ms": processing_ms,
        "raw_array_sample": raw_sample,
        "archive_id": unique_id
    }


@app.post("/vectorize_and_store")
async def vectorize_and_store(file: UploadFile = File(...), detection_data: str = Form(...)):
    """Cloud-dependent. Requires active billing.

    Persists the upload to a temp file, delegates to the AnalyzeBunch use
    case, and always removes the temp file afterwards.
    """
    try:
        primary_detection = json.loads(detection_data)
    except Exception:
        return {"status": "error", "message": "Invalid detection_data format"}

    unique_id = uuid.uuid4().hex
    temp_path = f"temp_vec_{unique_id}_{file.filename}"
    # Fresh request: the upload stream has not been consumed yet, so no
    # pointer reset is needed before copying.
    with open(temp_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    try:
        record_id = analyze_use_case.execute(temp_path, primary_detection)
        return {
            "status": "success",
            "record_id": record_id,
            "message": "FFB vectorized and archived successfully"
        }
    except RuntimeError as e:
        return {"status": "error", "message": str(e)}
    except Exception as e:
        return {"status": "error", "message": f"An unexpected error occurred: {str(e)}"}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)


@app.post("/process_batch")
async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
    """Handles multiple images: Detect -> Vectorize -> Store.

    model_type is accepted for interface compatibility, but batch mode
    always uses the PyTorch engine.
    """
    import time
    batch_results = []
    temp_files = []
    try:
        for file in files:
            # 1. Save Temp
            unique_id = uuid.uuid4().hex
            path = f"temp_batch_{unique_id}_{file.filename}"
            with open(path, "wb") as f_out:
                shutil.copyfileobj(file.file, f_out)
            temp_files.append(path)

            # 2. Detect — FORCE PYTORCH for Batch
            img = Image.open(path)
            start_total = time.perf_counter()
            detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
            total_ms = (time.perf_counter() - start_total) * 1000
            processing_ms = total_ms - inference_ms

            # 3. Collect all detections in the image
            for det in detections:
                batch_results.append({
                    "path": path,
                    "yolo": det,
                    "inference_ms": inference_ms,
                    "raw_array_sample": raw_sample
                })

        if not batch_results:
            return {"status": "no_detection", "message": "No bunches detected in batch"}

        # Calculate Total Industrial Summary for the Batch.
        # BUG FIX: the original referenced bare `class_names` (NameError at
        # runtime) — the mapping lives on the ModelManager instance. Also
        # guard against the "Unknown" fallback class with .get.
        total_summary = {name: 0 for name in model_manager.class_names.values()}
        for item in batch_results:
            cls = item['yolo']['class']
            total_summary[cls] = total_summary.get(cls, 0) + 1

        # 4. Process Batch Use Case with error handling for cloud services
        detailed_detections = [
            {
                "filename": os.path.basename(item['path']),
                "detection": item['yolo'],
                "inference_ms": item['inference_ms'],
                "raw_array_sample": item['raw_array_sample']
            }
            for item in batch_results
        ]

        try:
            record_ids = analyze_batch_use_case.execute(batch_results)
            total_records = len(record_ids)
            return {
                "status": "success",
                "processed_count": total_records,
                "total_count": sum(total_summary.values()),
                "record_ids": record_ids,
                "industrial_summary": total_summary,
                "detailed_results": detailed_detections,
                "message": f"Successfully processed {total_records} images and identified {sum(total_summary.values())} bunches"
            }
        except RuntimeError as e:
            # Cloud archival failed but local detection succeeded.
            return {
                "status": "partial_success",
                "message": f"Detections completed, but cloud archival failed: {str(e)}",
                "detections_count": len(batch_results),
                "detailed_results": detailed_detections
            }
    except Exception as e:
        return {"status": "error", "message": f"Batch processing failed: {str(e)}"}
    finally:
        # 5. Clean up all temp files
        for path in temp_files:
            if os.path.exists(path):
                os.remove(path)


@app.post("/search_hybrid")
async def search_hybrid(
    file: Optional[UploadFile] = File(None),
    text_query: Optional[str] = Form(None),
    limit: int = Form(3)
):
    """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
    temp_path = None
    try:
        if file:
            unique_id = uuid.uuid4().hex
            temp_path = f"temp_search_{unique_id}_{file.filename}"
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            results = search_use_case.execute(image_path=temp_path, limit=limit)
        elif text_query:
            results = search_use_case.execute(text_query=text_query, limit=limit)
        else:
            return {"status": "error", "message": "No search input provided"}
        return {"status": "success", "results": results}
    except RuntimeError as e:
        return {"status": "error", "message": f"Search unavailable: {str(e)}"}
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)


@app.get("/get_image/{record_id}")
async def get_image(record_id: str):
    """Retrieve the Base64 image data for a specific record."""
    record = repo.get_by_id(record_id)
    if not record:
        return {"status": "error", "message": "Record not found"}
    return {
        "status": "success",
        "image_data": record.get("image_data")
    }


@app.post("/save_to_history")
async def save_to_history(file: UploadFile = File(...), detections: str = Form(...), summary: str = Form(...)):
    """Archive an already-analyzed image plus its (pre-serialized JSON)
    detections/summary into the local vault and SQLite history."""
    unique_id = uuid.uuid4().hex
    filename = f"{unique_id}_{file.filename}"
    archive_path = os.path.join(ARCHIVE_DIR, filename)
    with open(archive_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Timings are unknown on this path, so 0.0 placeholders are stored.
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (file.filename, archive_path, detections, summary, 0.0, 0.0, "")
        )
        conn.commit()
    finally:
        conn.close()
    return {"status": "success", "message": "Saved to local vault"}
@app.get("/get_history")
async def get_history():
    """Return every archived analysis row from the local vault, newest first.

    Rows are materialized as plain dicts (via sqlite3.Row) so FastAPI can
    serialize them to JSON. FIX: the connection is now closed even when the
    query raises (the original leaked it on error).
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        cursor = conn.execute("SELECT * FROM history ORDER BY timestamp DESC")
        rows = [dict(row) for row in cursor.fetchall()]
    finally:
        conn.close()
    return {"status": "success", "history": rows}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)