# Palm Oil Ripeness Service — FastAPI entry point.
# Standard library
import io
import json
import os
import shutil
import sqlite3
import time
import uuid
from typing import List, Optional

# Third-party
import numpy as np
import onnxruntime as ort
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
from PIL import Image
from ultralytics import YOLO

# Local application
from src.infrastructure.vision_service import VertexVisionService
from src.infrastructure.repository import MongoPalmOilRepository
from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
# Local "history vault" configuration: SQLite file for result rows and a
# folder for archived image copies.
DB_PATH = "palm_history.db"
ARCHIVE_DIR = "history_archive"
# Idempotent: safe to run on every import.
os.makedirs(ARCHIVE_DIR, exist_ok=True)
def init_local_db(db_path: Optional[str] = None) -> None:
    """Create the local SQLite ``history`` table if it does not exist.

    Args:
        db_path: Path of the SQLite database file. Defaults to the
            module-level ``DB_PATH`` so existing callers are unaffected.
    """
    path = DB_PATH if db_path is None else db_path
    print(f"Initializing Local DB at {path}...")
    conn = sqlite3.connect(path)
    # try/finally guarantees the handle is released even if the DDL fails.
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT,
                archive_path TEXT,
                detections TEXT,
                summary TEXT,
                inference_ms REAL,
                processing_ms REAL,
                raw_tensor TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
    finally:
        conn.close()
    print("Local DB Initialized.")
# Create the local history table at import so the service can log immediately.
init_local_db()

# Load environment variables from .env (PROJECT_ID, MONGO_URI, etc.).
load_dotenv()

# Point the Google SDK at the service-account key used for embeddings.
# NOTE(review): hard-coded relative path — assumes the key file sits in the
# process working directory; confirm for deployment.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"

app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
class ModelManager:
    """Wraps the ONNX and PyTorch (ultralytics) ripeness detectors.

    Both engines share the class-name mapping taken from the ``.pt`` model
    and emit detections in the same dict shape, so the API layer can treat
    them interchangeably.
    """

    # Class names that indicate an unhealthy bunch and trigger an alert flag.
    HEALTH_ALERT_CLASSES = ("Abnormal", "Empty_Bunch")

    # Fixed square input resolution expected by the exported ONNX graph.
    INPUT_SIZE = 640

    def __init__(self, onnx_path: str, pt_path: str):
        """Load both inference engines.

        Args:
            onnx_path: Path to the exported ONNX model.
            pt_path: Path to the ultralytics ``.pt`` checkpoint; it also
                supplies the id -> class-name mapping used by both engines.
        """
        self.onnx_session = ort.InferenceSession(onnx_path)
        self.onnx_input_name = self.onnx_session.get_inputs()[0].name
        self.pt_model = YOLO(pt_path)
        self.class_names = self.pt_model.names

    @staticmethod
    def _build_detection(bunch_id: int, class_name: str, conf: float, box):
        """Assemble the detection dict shared by both inference engines."""
        return {
            "bunch_id": bunch_id,
            "class": class_name,
            "confidence": round(conf, 2),
            "is_health_alert": class_name in ModelManager.HEALTH_ALERT_CLASSES,
            "box": box,
        }

    def preprocess_onnx(self, img: Image.Image):
        """Convert a PIL image into a (1, 3, 640, 640) float32 tensor.

        Returns:
            Tuple of (tensor, original_width, original_height); the original
            size is needed later to scale normalized boxes back to pixels.
        """
        img = img.convert("RGB")
        orig_w, orig_h = img.size
        size = self.INPUT_SIZE
        resized = img.resize((size, size))
        tensor = np.array(resized) / 255.0          # scale to [0, 1]
        tensor = tensor.transpose(2, 0, 1)          # HWC -> CHW
        tensor = tensor.reshape(1, 3, size, size).astype(np.float32)
        return tensor, orig_w, orig_h

    def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
        """Run the ONNX engine on one image.

        Args:
            img: Input image (any mode; converted to RGB internally).
            conf_threshold: Minimum confidence for a detection to be kept.

        Returns:
            (detections, raw_sample, inference_ms) where ``raw_sample`` is
            the first 5 raw output rows, kept as technical evidence.
        """
        tensor, orig_w, orig_h = self.preprocess_onnx(img)

        start = time.perf_counter()
        outputs = self.onnx_session.run(None, {self.onnx_input_name: tensor})
        inference_ms = (time.perf_counter() - start) * 1000

        # ONNX output: [batch, num_boxes, 6] where 6 = x1, y1, x2, y2, conf, cls.
        # NOTE: this export produces normalized coordinates (0.0 to 1.0).
        detections_batch = outputs[0]

        detections = []
        for i in range(detections_batch.shape[1]):
            det = detections_batch[0, i]
            conf = float(det[4])
            if conf < conf_threshold:
                continue
            # Scale normalized coords to absolute pixels of the source image.
            x1, y1, x2, y2 = det[:4]
            box = [float(x1 * orig_w), float(y1 * orig_h),
                   float(x2 * orig_w), float(y2 * orig_h)]
            class_name = self.class_names.get(int(det[5]), "Unknown")
            detections.append(
                self._build_detection(len(detections) + 1, class_name, conf, box))

        # Raw tensor sample (first 5 rows) for auditing/debugging.
        raw_sample = detections_batch[0, :5].tolist()
        return detections, raw_sample, inference_ms

    def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
        """Run the ultralytics (PyTorch) engine on one image.

        Returns the same (detections, raw_snippet, inference_ms) triple as
        :meth:`run_onnx_inference`.
        """
        start = time.perf_counter()
        results = self.pt_model(img, conf=conf_threshold, verbose=False)
        inference_ms = (time.perf_counter() - start) * 1000

        detections = []
        for i, box in enumerate(results[0].boxes):
            class_name = self.class_names.get(int(box.cls), "Unknown")
            detections.append(
                self._build_detection(i + 1, class_name, float(box.conf),
                                      box.xyxy.tolist()[0]))

        # Snippet of the raw output tensor (first 5 rows) as evidence.
        raw_snippet = results[0].boxes.data[:5].tolist() if len(results[0].boxes) > 0 else []
        return detections, raw_snippet, inference_ms
# Shared model manager — loads both inference engines once at startup.
model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')

# Global confidence threshold, adjustable at runtime via /set_confidence.
current_conf = 0.25

# --- DDD wiring: infrastructure services and application use cases ---------
vision_service = VertexVisionService(
    project_id=os.getenv("PROJECT_ID", "your-project-id"),
    location=os.getenv("LOCATION", "us-central1")
)
repo = MongoPalmOilRepository(
    uri=os.getenv("MONGO_URI"),
    db_name=os.getenv("DB_NAME", "palm_oil_db"),
    collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
)
# Ensure Mongo indexes exist before the first request.
repo.ensure_indexes()

analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
search_use_case = SearchSimilarUseCase(vision_service, repo)
@app.get("/get_confidence")
async def get_confidence():
    """Report the confidence threshold currently applied to detections."""
    payload = {
        "status": "success",
        "current_confidence": current_conf,
        "model_version": "best.pt",
    }
    return payload
@app.post("/set_confidence")
async def set_confidence(threshold: float = Body(..., embed=True)):
    """Update the global detection confidence threshold.

    Values outside [0.0, 1.0] are rejected without changing the global.
    """
    global current_conf
    # Guard clause: reject out-of-range values up front.
    if threshold < 0.0 or threshold > 1.0:
        return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
    current_conf = threshold
    return {"status": "success", "new_confidence": current_conf}
@app.post("/analyze")
async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
    """Industry-grade analysis with health metrics and summary.

    Runs the selected engine ("pytorch", or ONNX for any other value),
    builds a per-class count summary, auto-archives the image and result
    row into the local history vault, and returns the full payload.
    """
    image_bytes = await file.read()
    img = Image.open(io.BytesIO(image_bytes))

    start_total = time.perf_counter()
    # Select inference engine.
    if model_type == "pytorch":
        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
    else:
        detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
    total_ms = (time.perf_counter() - start_total) * 1000
    processing_ms = total_ms - inference_ms

    # Per-class counts. Seed every known class at 0 but tolerate classes
    # outside the model map (e.g. "Unknown") — the previous version raised
    # KeyError for those.
    summary = {name: 0 for name in model_manager.class_names.values()}
    for det in detections:
        summary[det['class']] = summary.get(det['class'], 0) + 1

    # AUTO-ARCHIVE the image to the local history vault.
    unique_id = uuid.uuid4().hex
    archive_filename = f"{unique_id}_{file.filename}"
    archive_path = os.path.join(ARCHIVE_DIR, archive_filename)
    with open(archive_path, "wb") as buffer:
        buffer.write(image_bytes)

    # Persist the result row; try/finally guarantees the connection is
    # closed even if the INSERT fails (previous version leaked it).
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (file.filename, archive_path, json.dumps(detections), json.dumps(summary),
             inference_ms, processing_ms, json.dumps(raw_sample)))
        conn.commit()
    finally:
        conn.close()

    return {
        "status": "success",
        "current_threshold": current_conf,
        "total_count": len(detections),
        "industrial_summary": summary,
        "detections": detections,
        "inference_ms": inference_ms,
        "processing_ms": processing_ms,
        "raw_array_sample": raw_sample,
        "archive_id": unique_id
    }
@app.post("/vectorize_and_store")
async def vectorize_and_store(file: UploadFile = File(...), detection_data: str = Form(...)):
    """Vectorize one FFB image via the cloud vision service and archive it.

    Cloud-dependent: requires active billing / a reachable Vertex endpoint.
    """
    try:
        primary_detection = json.loads(detection_data)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError; anything else is a
        # genuine bug and should surface instead of being swallowed.
        return {"status": "error", "message": "Invalid detection_data format"}

    # Stage the upload to a unique temp file for the use case; always
    # remove it afterwards.
    unique_id = uuid.uuid4().hex
    temp_path = f"temp_vec_{unique_id}_{file.filename}"
    with open(temp_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    try:
        record_id = analyze_use_case.execute(temp_path, primary_detection)
        return {
            "status": "success",
            "record_id": record_id,
            "message": "FFB vectorized and archived successfully"
        }
    except RuntimeError as e:
        # Expected failure mode when the cloud service is unavailable.
        return {"status": "error", "message": str(e)}
    except Exception as e:
        return {"status": "error", "message": f"An unexpected error occurred: {str(e)}"}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
@app.post("/process_batch")
async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
    """Handle multiple images: detect -> vectorize -> store.

    NOTE: batch detection always uses the PyTorch engine regardless of
    ``model_type`` (parameter kept for interface compatibility).
    """
    batch_results = []
    temp_files = []
    try:
        for file in files:
            # 1. Stage the upload to a unique temp file.
            unique_id = uuid.uuid4().hex
            path = f"temp_batch_{unique_id}_{file.filename}"
            with open(path, "wb") as f_out:
                shutil.copyfileobj(file.file, f_out)
            temp_files.append(path)

            # 2. Detect (PyTorch forced for batch processing).
            img = Image.open(path)
            start_total = time.perf_counter()
            detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
            total_ms = (time.perf_counter() - start_total) * 1000
            processing_ms = total_ms - inference_ms

            # 3. Collect every detection found in this image.
            for det in detections:
                batch_results.append({
                    "path": path,
                    "yolo": det,
                    "inference_ms": inference_ms,
                    "raw_array_sample": raw_sample
                })

        if not batch_results:
            return {"status": "no_detection", "message": "No bunches detected in batch"}

        # Batch-wide industrial summary. BUGFIX: the previous version read a
        # nonexistent global `class_names` (NameError at runtime); use the
        # model manager's map and tolerate classes outside it ("Unknown").
        total_summary = {name: 0 for name in model_manager.class_names.values()}
        for item in batch_results:
            cls = item['yolo']['class']
            total_summary[cls] = total_summary.get(cls, 0) + 1

        # 4. Flatten per-detection details for the response payload.
        detailed_detections = [{
            "filename": os.path.basename(item['path']),
            "detection": item['yolo'],
            "inference_ms": item['inference_ms'],
            "raw_array_sample": item['raw_array_sample']
        } for item in batch_results]

        try:
            record_ids = analyze_batch_use_case.execute(batch_results)
            total_records = len(record_ids)
            return {
                "status": "success",
                "processed_count": total_records,
                "total_count": sum(total_summary.values()),
                "record_ids": record_ids,
                "industrial_summary": total_summary,
                "detailed_results": detailed_detections,
                "message": f"Successfully processed {total_records} images and identified {sum(total_summary.values())} bunches"
            }
        except RuntimeError as e:
            # Cloud archival failed but local detections are still useful.
            return {
                "status": "partial_success",
                "message": f"Detections completed, but cloud archival failed: {str(e)}",
                "detections_count": len(batch_results),
                "detailed_results": detailed_detections
            }
    except Exception as e:
        return {"status": "error", "message": f"Batch processing failed: {str(e)}"}
    finally:
        # 5. Always clean up every temp file that was created.
        for path in temp_files:
            if os.path.exists(path):
                os.remove(path)
@app.post("/search_hybrid")
async def search_hybrid(
    file: Optional[UploadFile] = File(None),
    text_query: Optional[str] = Form(None),
    limit: int = Form(3)
):
    """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
    temp_path = None
    try:
        if file:
            # Stage the query image to a unique temp file for the use case.
            unique_id = uuid.uuid4().hex
            temp_path = f"temp_search_{unique_id}_{file.filename}"
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            results = search_use_case.execute(image_path=temp_path, limit=limit)
        elif text_query:
            results = search_use_case.execute(text_query=text_query, limit=limit)
        else:
            return {"status": "error", "message": "No search input provided"}
        return {"status": "success", "results": results}
    except RuntimeError as e:
        return {"status": "error", "message": f"Search unavailable: {str(e)}"}
    finally:
        # Remove the staged query image regardless of outcome.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
@app.get("/get_image/{record_id}")
async def get_image(record_id: str):
    """Fetch the Base64 image payload stored for a given record id."""
    record = repo.get_by_id(record_id)
    if record:
        return {"status": "success", "image_data": record.get("image_data")}
    return {"status": "error", "message": "Record not found"}
@app.post("/save_to_history")
async def save_to_history(file: UploadFile = File(...), detections: str = Form(...), summary: str = Form(...)):
    """Archive an uploaded image plus client-supplied result JSON locally.

    ``detections`` and ``summary`` are stored verbatim (the client is
    expected to send JSON strings); timing columns are zeroed because no
    inference ran on the server for this row.
    """
    unique_id = uuid.uuid4().hex
    filename = f"{unique_id}_{file.filename}"
    archive_path = os.path.join(ARCHIVE_DIR, filename)

    with open(archive_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # try/finally guarantees the SQLite handle is released even if the
    # INSERT raises (previous version leaked the connection on error).
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (file.filename, archive_path, detections, summary, 0.0, 0.0, ""))
        conn.commit()
    finally:
        conn.close()
    return {"status": "success", "message": "Saved to local vault"}
@app.get("/get_history")
async def get_history():
    """Return every archived analysis row, newest first."""
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row  # rows become dict-convertible
    # try/finally guarantees the connection is closed even if the query
    # raises (previous version leaked it on error).
    try:
        cursor = conn.execute("SELECT * FROM history ORDER BY timestamp DESC")
        rows = [dict(row) for row in cursor.fetchall()]
    finally:
        conn.close()
    return {"status": "success", "history": rows}
if __name__ == "__main__":
    # Dev entry point: serve on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)