# main.py — Palm Oil Ripeness Service (FastAPI entry point)
import io
import json
import os
import shutil
import sqlite3
import time
import uuid
from typing import List, Optional

import numpy as np
import onnxruntime as ort
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
from PIL import Image
from ultralytics import YOLO

from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
from src.infrastructure.repository import MongoPalmOilRepository
from src.infrastructure.vision_service import VertexVisionService
  17. DB_PATH = "palm_history.db"
  18. ARCHIVE_DIR = "history_archive"
  19. os.makedirs(ARCHIVE_DIR, exist_ok=True)
  20. def init_local_db():
  21. print(f"Initializing Local DB at {DB_PATH}...")
  22. conn = sqlite3.connect(DB_PATH)
  23. cursor = conn.cursor()
  24. cursor.execute('''
  25. CREATE TABLE IF NOT EXISTS history (
  26. id INTEGER PRIMARY KEY AUTOINCREMENT,
  27. filename TEXT,
  28. archive_path TEXT,
  29. detections TEXT,
  30. summary TEXT,
  31. inference_ms REAL,
  32. processing_ms REAL,
  33. raw_tensor TEXT,
  34. timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
  35. )
  36. ''')
  37. conn.commit()
  38. conn.close()
  39. print("Local DB Initialized.")
  40. init_local_db()
# Load environment variables from .env (MONGO_URI, PROJECT_ID, ...).
load_dotenv()

# Point Google client libraries at the service-account key file.
# NOTE(review): this unconditionally overwrites any credentials already set
# in the environment — confirm that is intended.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"

app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
  46. class ModelManager:
  47. def __init__(self, onnx_path: str, pt_path: str):
  48. self.onnx_session = ort.InferenceSession(onnx_path)
  49. self.onnx_input_name = self.onnx_session.get_inputs()[0].name
  50. self.pt_model = YOLO(pt_path)
  51. self.class_names = self.pt_model.names
  52. def preprocess_onnx(self, img: Image.Image):
  53. img = img.convert("RGB")
  54. orig_w, orig_h = img.size
  55. img_resized = img.resize((640, 640))
  56. img_array = np.array(img_resized) / 255.0
  57. img_array = img_array.transpose(2, 0, 1)
  58. img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
  59. return img_array, orig_w, orig_h
  60. def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
  61. img_array, orig_w, orig_h = self.preprocess_onnx(img)
  62. import time
  63. start_inf = time.perf_counter()
  64. outputs = self.onnx_session.run(None, {self.onnx_input_name: img_array})
  65. end_inf = time.perf_counter()
  66. inference_ms = (end_inf - start_inf) * 1000
  67. # ONNX Output: [batch, num_boxes, 6] (Where 6: x1, y1, x2, y2, conf, cls)
  68. # Note: YOLOv8 endpoints often produce normalized coordinates (0.0 to 1.0)
  69. detections_batch = outputs[0]
  70. detections = []
  71. valid_count = 0
  72. for i in range(detections_batch.shape[1]):
  73. det = detections_batch[0, i]
  74. conf = float(det[4])
  75. if conf >= conf_threshold:
  76. valid_count += 1
  77. # 1. Coordinate Scaling: Convert normalized (0.0-1.0) to absolute pixels
  78. x1, y1, x2, y2 = det[:4]
  79. abs_x1 = x1 * orig_w
  80. abs_y1 = y1 * orig_h
  81. abs_x2 = x2 * orig_w
  82. abs_y2 = y2 * orig_h
  83. class_id = int(det[5])
  84. class_name = self.class_names.get(class_id, "Unknown")
  85. detections.append({
  86. "bunch_id": valid_count,
  87. "class": class_name,
  88. "confidence": round(conf, 2),
  89. "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
  90. "box": [float(abs_x1), float(abs_y1), float(abs_x2), float(abs_y2)]
  91. })
  92. # Capture a raw tensor sample (first 5 detections) for technical evidence
  93. raw_sample = detections_batch[0, :5].tolist()
  94. return detections, raw_sample, inference_ms
  95. def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
  96. import time
  97. start_inf = time.perf_counter()
  98. results = self.pt_model(img, conf=conf_threshold, verbose=False)
  99. end_inf = time.perf_counter()
  100. inference_ms = (end_inf - start_inf) * 1000
  101. detections = []
  102. for i, box in enumerate(results[0].boxes):
  103. class_id = int(box.cls)
  104. class_name = self.class_names.get(class_id, "Unknown")
  105. detections.append({
  106. "bunch_id": i + 1,
  107. "class": class_name,
  108. "confidence": round(float(box.conf), 2),
  109. "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
  110. "box": box.xyxy.tolist()[0]
  111. })
  112. # Extract snippet from results (simulating raw output)
  113. raw_snippet = results[0].boxes.data[:5].tolist() if len(results[0].boxes) > 0 else []
  114. return detections, raw_snippet, inference_ms
# Shared model manager holding both inference backends (loads weights at import).
model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')

# Global state for the confidence threshold (mutable via /set_confidence).
current_conf = 0.25

# Initialize DDD components: embedding service, Mongo repository, use cases.
vision_service = VertexVisionService(
    project_id=os.getenv("PROJECT_ID", "your-project-id"),
    location=os.getenv("LOCATION", "us-central1")
)
repo = MongoPalmOilRepository(
    uri=os.getenv("MONGO_URI"),
    db_name=os.getenv("DB_NAME", "palm_oil_db"),
    collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
)
repo.ensure_indexes()

analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
search_use_case = SearchSimilarUseCase(vision_service, repo)
@app.get("/get_confidence")
async def get_confidence():
    """Returns the current confidence threshold used by the model."""
    # current_conf is module-level state, mutable via /set_confidence.
    return {
        "status": "success",
        "current_confidence": current_conf,
        "model_version": "best.pt"
    }
  141. @app.post("/set_confidence")
  142. async def set_confidence(threshold: float = Body(..., embed=True)):
  143. """Updates the confidence threshold globally."""
  144. global current_conf
  145. if 0.0 <= threshold <= 1.0:
  146. current_conf = threshold
  147. return {"status": "success", "new_confidence": current_conf}
  148. else:
  149. return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
  150. @app.post("/analyze")
  151. async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
  152. """Industry-grade analysis with health metrics and summary."""
  153. image_bytes = await file.read()
  154. img = Image.open(io.BytesIO(image_bytes))
  155. import time
  156. start_total = time.perf_counter()
  157. # Select Inference Engine
  158. if model_type == "pytorch":
  159. detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
  160. else:
  161. detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
  162. end_total = time.perf_counter()
  163. total_ms = (end_total - start_total) * 1000
  164. processing_ms = total_ms - inference_ms
  165. # Initialize summary
  166. summary = {name: 0 for name in model_manager.class_names.values()}
  167. for det in detections:
  168. summary[det['class']] += 1
  169. # AUTO-ARCHIVE to Local History Vault
  170. unique_id = uuid.uuid4().hex
  171. archive_filename = f"{unique_id}_{file.filename}"
  172. archive_path = os.path.join(ARCHIVE_DIR, archive_filename)
  173. # Save image copy
  174. with open(archive_path, "wb") as buffer:
  175. buffer.write(image_bytes)
  176. # Save to SQLite
  177. conn = sqlite3.connect(DB_PATH)
  178. cursor = conn.cursor()
  179. cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
  180. (file.filename, archive_path, json.dumps(detections), json.dumps(summary), inference_ms, processing_ms, json.dumps(raw_sample)))
  181. conn.commit()
  182. conn.close()
  183. return {
  184. "status": "success",
  185. "current_threshold": current_conf,
  186. "total_count": len(detections),
  187. "industrial_summary": summary,
  188. "detections": detections,
  189. "inference_ms": inference_ms,
  190. "processing_ms": processing_ms,
  191. "raw_array_sample": raw_sample,
  192. "archive_id": unique_id
  193. }
  194. @app.post("/vectorize_and_store")
  195. async def vectorize_and_store(file: UploadFile = File(...), detection_data: str = Form(...)):
  196. """Cloud-dependent. Requires active billing."""
  197. import json
  198. try:
  199. primary_detection = json.loads(detection_data)
  200. except Exception:
  201. return {"status": "error", "message": "Invalid detection_data format"}
  202. unique_id = uuid.uuid4().hex
  203. temp_path = f"temp_vec_{unique_id}_{file.filename}"
  204. # Reset file pointer since it might have been read (though here it's a new request)
  205. # Actually, in a new request, we read it for the first time.
  206. with open(temp_path, "wb") as buffer:
  207. shutil.copyfileobj(file.file, buffer)
  208. try:
  209. record_id = analyze_use_case.execute(temp_path, primary_detection)
  210. return {
  211. "status": "success",
  212. "record_id": record_id,
  213. "message": "FFB vectorized and archived successfully"
  214. }
  215. except RuntimeError as e:
  216. return {"status": "error", "message": str(e)}
  217. except Exception as e:
  218. return {"status": "error", "message": f"An unexpected error occurred: {str(e)}"}
  219. finally:
  220. if os.path.exists(temp_path):
  221. os.remove(temp_path)
  222. @app.post("/process_batch")
  223. async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
  224. """Handles multiple images: Detect -> Vectorize -> Store."""
  225. batch_results = []
  226. temp_files = []
  227. try:
  228. for file in files:
  229. # 1. Save Temp
  230. unique_id = uuid.uuid4().hex
  231. path = f"temp_batch_{unique_id}_{file.filename}"
  232. with open(path, "wb") as f_out:
  233. shutil.copyfileobj(file.file, f_out)
  234. temp_files.append(path)
  235. import time
  236. # 2. Detect
  237. img = Image.open(path)
  238. # FORCE PYTORCH for Batch
  239. start_total = time.perf_counter()
  240. detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
  241. end_total = time.perf_counter()
  242. total_ms = (end_total - start_total) * 1000
  243. processing_ms = total_ms - inference_ms
  244. # 3. Process all detections in the image
  245. for det in detections:
  246. batch_results.append({
  247. "path": path,
  248. "yolo": det,
  249. "inference_ms": inference_ms,
  250. "raw_array_sample": raw_sample
  251. })
  252. if not batch_results:
  253. return {"status": "no_detection", "message": "No bunches detected in batch"}
  254. # Calculate Total Industrial Summary for the Batch
  255. total_summary = {name: 0 for name in class_names.values()}
  256. for item in batch_results:
  257. total_summary[item['yolo']['class']] += 1
  258. # 4. Process Batch Use Case with error handling for cloud services
  259. detailed_detections = []
  260. for item in batch_results:
  261. detailed_detections.append({
  262. "filename": os.path.basename(item['path']),
  263. "detection": item['yolo'],
  264. "inference_ms": item['inference_ms'],
  265. "raw_array_sample": item['raw_array_sample']
  266. })
  267. try:
  268. record_ids = analyze_batch_use_case.execute(batch_results)
  269. total_records = len(record_ids)
  270. return {
  271. "status": "success",
  272. "processed_count": total_records,
  273. "total_count": sum(total_summary.values()),
  274. "record_ids": record_ids,
  275. "industrial_summary": total_summary,
  276. "detailed_results": detailed_detections,
  277. "message": f"Successfully processed {total_records} images and identified {sum(total_summary.values())} bunches"
  278. }
  279. except RuntimeError as e:
  280. return {
  281. "status": "partial_success",
  282. "message": f"Detections completed, but cloud archival failed: {str(e)}",
  283. "detections_count": len(batch_results),
  284. "detailed_results": detailed_detections
  285. }
  286. except Exception as e:
  287. return {"status": "error", "message": f"Batch processing failed: {str(e)}"}
  288. finally:
  289. # 5. Clean up all temp files
  290. for path in temp_files:
  291. if os.path.exists(path):
  292. os.remove(path)
  293. @app.post("/search_hybrid")
  294. async def search_hybrid(
  295. file: Optional[UploadFile] = File(None),
  296. text_query: Optional[str] = Form(None),
  297. limit: int = Form(3)
  298. ):
  299. """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
  300. temp_path = None
  301. try:
  302. try:
  303. if file:
  304. unique_id = uuid.uuid4().hex
  305. temp_path = f"temp_search_{unique_id}_{file.filename}"
  306. with open(temp_path, "wb") as buffer:
  307. shutil.copyfileobj(file.file, buffer)
  308. results = search_use_case.execute(image_path=temp_path, limit=limit)
  309. elif text_query:
  310. results = search_use_case.execute(text_query=text_query, limit=limit)
  311. else:
  312. return {"status": "error", "message": "No search input provided"}
  313. return {"status": "success", "results": results}
  314. except RuntimeError as e:
  315. return {"status": "error", "message": f"Search unavailable: {str(e)}"}
  316. finally:
  317. if temp_path and os.path.exists(temp_path):
  318. os.remove(temp_path)
  319. @app.get("/get_image/{record_id}")
  320. async def get_image(record_id: str):
  321. """Retrieve the Base64 image data for a specific record."""
  322. record = repo.get_by_id(record_id)
  323. if not record:
  324. return {"status": "error", "message": "Record not found"}
  325. return {
  326. "status": "success",
  327. "image_data": record.get("image_data")
  328. }
  329. @app.post("/save_to_history")
  330. async def save_to_history(file: UploadFile = File(...), detections: str = Form(...), summary: str = Form(...)):
  331. unique_id = uuid.uuid4().hex
  332. filename = f"{unique_id}_{file.filename}"
  333. archive_path = os.path.join(ARCHIVE_DIR, filename)
  334. with open(archive_path, "wb") as buffer:
  335. shutil.copyfileobj(file.file, buffer)
  336. conn = sqlite3.connect(DB_PATH)
  337. cursor = conn.cursor()
  338. cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
  339. (file.filename, archive_path, detections, summary, 0.0, 0.0, ""))
  340. conn.commit()
  341. conn.close()
  342. return {"status": "success", "message": "Saved to local vault"}
  343. @app.get("/get_history")
  344. async def get_history():
  345. conn = sqlite3.connect(DB_PATH)
  346. conn.row_factory = sqlite3.Row
  347. cursor = conn.cursor()
  348. cursor.execute("SELECT * FROM history ORDER BY timestamp DESC")
  349. rows = [dict(row) for row in cursor.fetchall()]
  350. conn.close()
  351. return {"status": "success", "history": rows}
# Dev entry point: run the service directly with uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)