# main.py — Palm Oil Ripeness Service (FastAPI)
import io
import json
import os
import shutil
import sqlite3
import time
import uuid
from datetime import datetime
from typing import List, Optional

import numpy as np
import onnxruntime as ort
import pandas as pd
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
from PIL import Image
from ultralytics import YOLO

from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
from src.infrastructure.repository import MongoPalmOilRepository
from src.infrastructure.vision_service import VertexVisionService
# Local persistence configuration: a SQLite file stores the analysis history
# and a directory archives a copy of every uploaded image.
DB_PATH = "palm_history.db"
ARCHIVE_DIR = "history_archive"
os.makedirs(ARCHIVE_DIR, exist_ok=True)  # idempotent; runs at import time
  22. def init_local_db():
  23. print(f"Initializing Local DB at {DB_PATH}...")
  24. conn = sqlite3.connect(DB_PATH)
  25. cursor = conn.cursor()
  26. cursor.execute('''
  27. CREATE TABLE IF NOT EXISTS history (
  28. id INTEGER PRIMARY KEY AUTOINCREMENT,
  29. filename TEXT,
  30. archive_path TEXT,
  31. detections TEXT,
  32. summary TEXT,
  33. engine TEXT,
  34. inference_ms REAL,
  35. processing_ms REAL,
  36. raw_tensor TEXT,
  37. timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
  38. )
  39. ''')
  40. # Migration: Add engine column if it doesn't exist (for existing DBs)
  41. try:
  42. cursor.execute("ALTER TABLE history ADD COLUMN engine TEXT")
  43. print("Migrated History table: Added 'engine' column.")
  44. except sqlite3.OperationalError:
  45. # Column already exists
  46. pass
  47. conn.commit()
  48. conn.close()
  49. print("Local DB Initialized.")
# Module-import side effects: prepare local DB, env, credentials, and the app.
init_local_db()
# Load environment variables from .env (if present)
load_dotenv()
# Set Google Cloud credentials globally.
# NOTE(review): hardcoded key filename overrides anything loaded from .env —
# confirm this file exists in the working directory at deploy time.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"
app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
  56. class ModelManager:
  57. def __init__(self, onnx_path: str, pt_path: str, benchmark_path: str = 'sawit_tbs.pt'):
  58. self.onnx_session = ort.InferenceSession(onnx_path)
  59. self.onnx_input_name = self.onnx_session.get_inputs()[0].name
  60. self.pt_model = YOLO(pt_path)
  61. self.class_names = self.pt_model.names
  62. self.benchmark_model = YOLO(benchmark_path)
  63. self.benchmark_class_names = self.benchmark_model.names
  64. def preprocess_onnx(self, img: Image.Image):
  65. img = img.convert("RGB")
  66. orig_w, orig_h = img.size
  67. img_resized = img.resize((640, 640))
  68. img_array = np.array(img_resized) / 255.0
  69. img_array = img_array.transpose(2, 0, 1)
  70. img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
  71. return img_array, orig_w, orig_h
  72. def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
  73. img_array, orig_w, orig_h = self.preprocess_onnx(img)
  74. import time
  75. start_inf = time.perf_counter()
  76. outputs = self.onnx_session.run(None, {self.onnx_input_name: img_array})
  77. end_inf = time.perf_counter()
  78. inference_ms = (end_inf - start_inf) * 1000
  79. # ONNX Output: [batch, num_boxes, 6] (Where 6: x1, y1, x2, y2, conf, cls)
  80. # Note: YOLOv8 endpoints often produce normalized coordinates (0.0 to 1.0)
  81. detections_batch = outputs[0]
  82. detections = []
  83. valid_count = 0
  84. for i in range(detections_batch.shape[1]):
  85. det = detections_batch[0, i]
  86. conf = float(det[4])
  87. if conf >= conf_threshold:
  88. valid_count += 1
  89. # 1. Coordinate Scaling: Convert normalized (0.0-1.0) to absolute pixels
  90. x1, y1, x2, y2 = det[:4]
  91. abs_x1 = x1 * orig_w
  92. abs_y1 = y1 * orig_h
  93. abs_x2 = x2 * orig_w
  94. abs_y2 = y2 * orig_h
  95. class_id = int(det[5])
  96. class_name = self.class_names.get(class_id, "Unknown")
  97. detections.append({
  98. "bunch_id": valid_count,
  99. "class": class_name,
  100. "confidence": round(conf, 2),
  101. "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
  102. "box": [float(abs_x1), float(abs_y1), float(abs_x2), float(abs_y2)]
  103. })
  104. # Capture a raw tensor sample (first 5 detections) for technical evidence
  105. raw_sample = detections_batch[0, :5].tolist()
  106. return detections, raw_sample, inference_ms
  107. def run_pytorch_inference(self, img: Image.Image, conf_threshold: float, engine_type: str = "pytorch"):
  108. import time
  109. start_inf = time.perf_counter()
  110. # Selection Logic for Third Engine (YOLOv8-Sawit)
  111. model = self.pt_model if engine_type == "pytorch" else self.benchmark_model
  112. names = self.class_names if engine_type == "pytorch" else self.benchmark_class_names
  113. results = model(img, conf=conf_threshold, verbose=False)
  114. end_inf = time.perf_counter()
  115. inference_ms = (end_inf - start_inf) * 1000
  116. detections = []
  117. for i, box in enumerate(results[0].boxes):
  118. class_id = int(box.cls)
  119. class_name = names.get(class_id, "Unknown")
  120. detections.append({
  121. "bunch_id": i + 1,
  122. "class": class_name,
  123. "confidence": round(float(box.conf), 2),
  124. "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
  125. "box": box.xyxy.tolist()[0]
  126. })
  127. # Extract snippet from results (simulating raw output)
  128. raw_snippet = results[0].boxes.data[:5].tolist() if len(results[0].boxes) > 0 else []
  129. return detections, raw_snippet, inference_ms
# Engine registry: loads all model weights once at import time.
model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')
# Global state for the confidence threshold (mutated by /set_confidence)
current_conf = 0.25
# Initialize DDD components (infrastructure + application layers)
vision_service = VertexVisionService(
project_id=os.getenv("PROJECT_ID", "your-project-id"),
location=os.getenv("LOCATION", "us-central1")
)
repo = MongoPalmOilRepository(
uri=os.getenv("MONGO_URI"),
db_name=os.getenv("DB_NAME", "palm_oil_db"),
collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
)
# Cloud archival is best-effort: local inference endpoints keep working when
# MongoDB Atlas is unreachable; cloud-dependent endpoints check db_connected.
db_connected = False
try:
    print("Connecting to MongoDB Atlas...")
    repo.ensure_indexes()
    db_connected = True
    print("MongoDB Atlas Connected.")
except Exception as e:
    print(f"Warning: Could not connect to MongoDB Atlas (Timeout). Cloud archival will be disabled. Details: {e}")
# Application-layer use cases share the vision service and repository.
analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
search_use_case = SearchSimilarUseCase(vision_service, repo)
  154. @app.get("/get_confidence")
  155. # ... (rest of the code remains same until analyze)
  156. async def get_confidence():
  157. """Returns the current confidence threshold used by the model."""
  158. return {
  159. "status": "success",
  160. "current_confidence": current_conf,
  161. "model_version": "best.pt"
  162. }
  163. @app.post("/set_confidence")
  164. async def set_confidence(threshold: float = Body(..., embed=True)):
  165. """Updates the confidence threshold globally."""
  166. global current_conf
  167. if 0.0 <= threshold <= 1.0:
  168. current_conf = threshold
  169. return {"status": "success", "new_confidence": current_conf}
  170. else:
  171. return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
  172. @app.post("/analyze")
  173. async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
  174. """Industry-grade analysis with health metrics and summary."""
  175. image_bytes = await file.read()
  176. img = Image.open(io.BytesIO(image_bytes))
  177. import time
  178. start_total = time.perf_counter()
  179. # Select Inference Engine
  180. if model_type == "pytorch":
  181. detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "pytorch")
  182. elif model_type == "yolov8_sawit":
  183. detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "yolov8_sawit")
  184. else:
  185. detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
  186. end_total = time.perf_counter()
  187. total_ms = (end_total - start_total) * 1000
  188. processing_ms = total_ms - inference_ms
  189. # Initialize summary
  190. active_names = model_manager.class_names if model_type != "yolov8_sawit" else model_manager.benchmark_class_names
  191. summary = {name: 0 for name in active_names.values()}
  192. for det in detections:
  193. summary[det['class']] += 1
  194. # AUTO-ARCHIVE to Local History Vault
  195. unique_id = uuid.uuid4().hex
  196. archive_filename = f"{unique_id}_{file.filename}"
  197. archive_path = os.path.join(ARCHIVE_DIR, archive_filename)
  198. # Save image copy
  199. with open(archive_path, "wb") as buffer:
  200. buffer.write(image_bytes)
  201. # Save to SQLite
  202. conn = sqlite3.connect(DB_PATH)
  203. cursor = conn.cursor()
  204. cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, engine, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  205. (file.filename, archive_path, json.dumps(detections), json.dumps(summary), model_type, inference_ms, processing_ms, json.dumps(raw_sample)))
  206. conn.commit()
  207. conn.close()
  208. return {
  209. "status": "success",
  210. "current_threshold": current_conf,
  211. "total_count": len(detections),
  212. "industrial_summary": summary,
  213. "detections": detections,
  214. "inference_ms": inference_ms,
  215. "processing_ms": processing_ms,
  216. "raw_array_sample": raw_sample,
  217. "archive_id": unique_id
  218. }
  219. @app.post("/vectorize_and_store")
  220. async def vectorize_and_store(file: UploadFile = File(...), detection_data: str = Form(...)):
  221. """Cloud-dependent. Requires active billing."""
  222. if not db_connected:
  223. return {"status": "error", "message": "Cloud Archival is currently unavailable (Database Offline)."}
  224. import json
  225. try:
  226. primary_detection = json.loads(detection_data)
  227. except Exception:
  228. return {"status": "error", "message": "Invalid detection_data format"}
  229. unique_id = uuid.uuid4().hex
  230. temp_path = f"temp_vec_{unique_id}_{file.filename}"
  231. # Reset file pointer since it might have been read (though here it's a new request)
  232. # Actually, in a new request, we read it for the first time.
  233. with open(temp_path, "wb") as buffer:
  234. shutil.copyfileobj(file.file, buffer)
  235. try:
  236. record_id = analyze_use_case.execute(temp_path, primary_detection)
  237. return {
  238. "status": "success",
  239. "record_id": record_id,
  240. "message": "FFB vectorized and archived successfully"
  241. }
  242. except RuntimeError as e:
  243. return {"status": "error", "message": str(e)}
  244. except Exception as e:
  245. return {"status": "error", "message": f"An unexpected error occurred: {str(e)}"}
  246. finally:
  247. if os.path.exists(temp_path):
  248. os.remove(temp_path)
  249. @app.post("/process_batch")
  250. async def process_batch(
  251. files: List[UploadFile] = File(...),
  252. model_type: str = Form("onnx"),
  253. metadata: str = Form("{}") # JSON string from Frontend
  254. ):
  255. batch_id = f"BATCH_{uuid.uuid4().hex[:8].upper()}"
  256. output_dir = os.path.join("batch_outputs", batch_id)
  257. os.makedirs(os.path.join(output_dir, "raw"), exist_ok=True)
  258. start_time = datetime.now()
  259. meta_dict = json.loads(metadata)
  260. batch_records = []
  261. for file in files:
  262. unique_id = uuid.uuid4().hex[:6]
  263. filename = f"{unique_id}_{file.filename}"
  264. save_path = os.path.join(output_dir, "raw", filename)
  265. # 1. Save Raw Image to Bundle
  266. image_bytes = await file.read()
  267. with open(save_path, "wb") as f:
  268. f.write(image_bytes)
  269. # 2. Run Inference
  270. img = Image.open(io.BytesIO(image_bytes))
  271. # Selection logic based on existing API pattern
  272. if model_type == "pytorch":
  273. detections, raw_sample, inf_ms = model_manager.run_pytorch_inference(img, current_conf, "pytorch")
  274. elif model_type == "yolov8_sawit":
  275. detections, raw_sample, inf_ms = model_manager.run_pytorch_inference(img, current_conf, "yolov8_sawit")
  276. else:
  277. detections, raw_sample, inf_ms = model_manager.run_onnx_inference(img, current_conf)
  278. # 3. Normalize Coordinates for the Contract
  279. # Downstream processes shouldn't care about your input resolution
  280. w, h = img.size
  281. normalized_dets = []
  282. for d in detections:
  283. x1, y1, x2, y2 = d['box']
  284. normalized_dets.append({
  285. **d,
  286. "norm_box": [x1/w, y1/h, x2/w, y2/h]
  287. })
  288. batch_records.append({
  289. "image_id": unique_id,
  290. "filename": filename,
  291. "detections": normalized_dets,
  292. "inference_ms": inf_ms,
  293. "raw_tensor": raw_sample # Added for technical evidence/contract
  294. })
  295. end_time = datetime.now()
  296. duration = (end_time - start_time).total_seconds()
  297. # 4. Generate the Summary (For Manifest and immediate UI feedback)
  298. active_names = model_manager.class_names if model_type != "yolov8_sawit" else model_manager.benchmark_class_names
  299. total_summary = {name: 0 for name in active_names.values()}
  300. for record in batch_records:
  301. for det in record['detections']:
  302. total_summary[det['class']] += 1
  303. # 5. Generate the Manifest (The Contract)
  304. performance_metrics = {
  305. "start_time": start_time.isoformat(),
  306. "end_time": end_time.isoformat(),
  307. "duration_seconds": round(duration, 2)
  308. }
  309. manifest = {
  310. "job_id": batch_id,
  311. "timestamp": end_time.isoformat(),
  312. "source_context": meta_dict,
  313. "engine": {
  314. "name": "YOLO26",
  315. "type": model_type,
  316. "threshold": current_conf
  317. },
  318. "performance": performance_metrics, # Added performance metrics
  319. "industrial_summary": total_summary, # Added for subscribers
  320. "inventory": batch_records
  321. }
  322. with open(os.path.join(output_dir, "manifest.json"), "w") as f:
  323. json.dump(manifest, f, indent=4)
  324. # Note: Maintaining compatibility with the frontend's expectation of 'industrial_summary'
  325. # and 'processed_count' for immediate UI feedback.
  326. return {
  327. "status": "success",
  328. "batch_id": batch_id,
  329. "bundle_path": output_dir,
  330. "processed_count": len(files),
  331. "total_count": sum(total_summary.values()),
  332. "industrial_summary": total_summary,
  333. "performance": performance_metrics,
  334. "record_ids": [r['image_id'] for r in batch_records], # Backward compatibility
  335. "manifest_preview": manifest,
  336. "detailed_results": [{"filename": r['filename'], "detection": d} for r in batch_records for d in r['detections']] # Backward compatibility
  337. }
  338. @app.post("/search_hybrid")
  339. async def search_hybrid(
  340. file: Optional[UploadFile] = File(None),
  341. text_query: Optional[str] = Form(None),
  342. limit: int = Form(3)
  343. ):
  344. """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
  345. if not db_connected:
  346. return {"status": "error", "message": "Semantic Search is currently unavailable (Database Offline)."}
  347. temp_path = None
  348. try:
  349. try:
  350. if file:
  351. unique_id = uuid.uuid4().hex
  352. temp_path = f"temp_search_{unique_id}_{file.filename}"
  353. with open(temp_path, "wb") as buffer:
  354. shutil.copyfileobj(file.file, buffer)
  355. results = search_use_case.execute(image_path=temp_path, limit=limit)
  356. elif text_query:
  357. results = search_use_case.execute(text_query=text_query, limit=limit)
  358. else:
  359. return {"status": "error", "message": "No search input provided"}
  360. return {"status": "success", "results": results}
  361. except RuntimeError as e:
  362. return {"status": "error", "message": f"Search unavailable: {str(e)}"}
  363. finally:
  364. if temp_path and os.path.exists(temp_path):
  365. os.remove(temp_path)
  366. @app.get("/get_image/{record_id}")
  367. async def get_image(record_id: str):
  368. """Retrieve the Base64 image data for a specific record."""
  369. record = repo.get_by_id(record_id)
  370. if not record:
  371. return {"status": "error", "message": "Record not found"}
  372. return {
  373. "status": "success",
  374. "image_data": record.get("image_data")
  375. }
  376. @app.post("/save_to_history")
  377. async def save_to_history(file: UploadFile = File(...), detections: str = Form(...), summary: str = Form(...)):
  378. unique_id = uuid.uuid4().hex
  379. filename = f"{unique_id}_{file.filename}"
  380. archive_path = os.path.join(ARCHIVE_DIR, filename)
  381. with open(archive_path, "wb") as buffer:
  382. shutil.copyfileobj(file.file, buffer)
  383. conn = sqlite3.connect(DB_PATH)
  384. cursor = conn.cursor()
  385. cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
  386. (file.filename, archive_path, detections, summary, 0.0, 0.0, ""))
  387. conn.commit()
  388. conn.close()
  389. return {"status": "success", "message": "Saved to local vault"}
  390. @app.get("/get_history")
  391. async def get_history():
  392. conn = sqlite3.connect(DB_PATH)
  393. conn.row_factory = sqlite3.Row
  394. cursor = conn.cursor()
  395. cursor.execute("SELECT * FROM history ORDER BY timestamp DESC")
  396. rows = [dict(row) for row in cursor.fetchall()]
  397. conn.close()
  398. return {"status": "success", "history": rows}
  399. @app.get("/get_model_info")
  400. async def get_model_info(model_type: str = "onnx"):
  401. """Returns metadata and capabilities for the specified model engine."""
  402. if model_type in ["onnx", "pytorch"]:
  403. classes = list(model_manager.class_names.values())
  404. description = "Standard YOLO26 Industrial Model."
  405. elif model_type == "yolov8_sawit":
  406. classes = list(model_manager.benchmark_class_names.values())
  407. description = "YOLOv8-Sawit (Benchmark) - External Architecture."
  408. else:
  409. return {"status": "error", "message": "Unknown model type"}
  410. return {
  411. "status": "success",
  412. "model_type": model_type,
  413. "description": description,
  414. "detections_categories": classes
  415. }
if __name__ == "__main__":
    # Dev entry point; in production run via an ASGI server (e.g. `uvicorn main:app`).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)