|
|
@@ -4,6 +4,7 @@ import os
|
|
|
import shutil
|
|
|
from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
|
|
|
import onnxruntime as ort
|
|
|
+from ultralytics import YOLO
|
|
|
import numpy as np
|
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
@@ -46,64 +47,67 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.jso
|
|
|
|
|
|
app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
|
|
|
|
|
|
-# Initialize ONNX model
|
|
|
-onnx_path = 'best.onnx'
|
|
|
-ort_session = ort.InferenceSession(onnx_path)
|
|
|
-input_name = ort_session.get_inputs()[0].name
|
|
|
-class_names = {
|
|
|
- 0: 'Empty_Bunch',
|
|
|
- 1: 'Underripe',
|
|
|
- 2: 'Abnormal',
|
|
|
- 3: 'Ripe',
|
|
|
- 4: 'Unripe',
|
|
|
- 5: 'Overripe'
|
|
|
-}
|
|
|
-
|
|
|
-def preprocess(img: Image.Image):
|
|
|
- """Preprocess image for YOLO ONNX input [1, 3, 640, 640]."""
|
|
|
- img = img.convert("RGB")
|
|
|
- # Store original size for scaling
|
|
|
- orig_w, orig_h = img.size
|
|
|
- img_resized = img.resize((640, 640))
|
|
|
- img_array = np.array(img_resized) / 255.0
|
|
|
- img_array = img_array.transpose(2, 0, 1) # HWC to CHW
|
|
|
- img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
|
|
|
- return img_array, orig_w, orig_h
|
|
|
-
|
|
|
-def run_inference(img: Image.Image, conf_threshold: float):
|
|
|
- """Run ONNX inference and return list of detections."""
|
|
|
- img_array, orig_w, orig_h = preprocess(img)
|
|
|
- outputs = ort_session.run(None, {input_name: img_array})
|
|
|
- # Output shape: [1, 300, 6] -> [x1, y1, x2, y2, conf, class_id]
|
|
|
- detections_batch = outputs[0]
|
|
|
-
|
|
|
- scale_w = orig_w / 640.0
|
|
|
- scale_h = orig_h / 640.0
|
|
|
-
|
|
|
- detections = []
|
|
|
- valid_count = 0
|
|
|
- for i in range(detections_batch.shape[1]):
|
|
|
- det = detections_batch[0, i]
|
|
|
- conf = float(det[4])
|
|
|
- if conf >= conf_threshold:
|
|
|
- valid_count += 1
|
|
|
- x1, y1, x2, y2 = det[:4]
|
|
|
- # Rescale
|
|
|
- x1 *= scale_w
|
|
|
- y1 *= scale_h
|
|
|
- x2 *= scale_w
|
|
|
- y2 *= scale_h
|
|
|
- class_id = int(det[5])
|
|
|
- class_name = class_names.get(class_id, "Unknown")
|
|
|
-
|
|
|
+class ModelManager:
|
|
|
+ def __init__(self, onnx_path: str, pt_path: str):
|
|
|
+ self.onnx_session = ort.InferenceSession(onnx_path)
|
|
|
+ self.onnx_input_name = self.onnx_session.get_inputs()[0].name
|
|
|
+ self.pt_model = YOLO(pt_path)
|
|
|
+ self.class_names = self.pt_model.names
|
|
|
+
|
|
|
+ def preprocess_onnx(self, img: Image.Image):
|
|
|
+ img = img.convert("RGB")
|
|
|
+ orig_w, orig_h = img.size
|
|
|
+ img_resized = img.resize((640, 640))
|
|
|
+ img_array = np.array(img_resized) / 255.0
|
|
|
+ img_array = img_array.transpose(2, 0, 1)
|
|
|
+ img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
|
|
|
+ return img_array, orig_w, orig_h
|
|
|
+
|
|
|
+ def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
|
|
|
+ img_array, orig_w, orig_h = self.preprocess_onnx(img)
|
|
|
+ outputs = self.onnx_session.run(None, {self.onnx_input_name: img_array})
|
|
|
+ detections_batch = outputs[0]
|
|
|
+
|
|
|
+ scale_w = orig_w / 640.0
|
|
|
+ scale_h = orig_h / 640.0
|
|
|
+
|
|
|
+ detections = []
|
|
|
+ valid_count = 0
|
|
|
+ for i in range(detections_batch.shape[1]):
|
|
|
+ det = detections_batch[0, i]
|
|
|
+ conf = float(det[4])
|
|
|
+ if conf >= conf_threshold:
|
|
|
+ valid_count += 1
|
|
|
+ x1, y1, x2, y2 = det[:4]
|
|
|
+ x1 *= scale_w; y1 *= scale_h; x2 *= scale_w; y2 *= scale_h
|
|
|
+ class_id = int(det[5])
|
|
|
+ class_name = self.class_names.get(class_id, "Unknown")
|
|
|
+
|
|
|
+ detections.append({
|
|
|
+ "bunch_id": valid_count,
|
|
|
+ "class": class_name,
|
|
|
+ "confidence": round(conf, 2),
|
|
|
+ "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
|
|
|
+ "box": [float(x1), float(y1), float(x2), float(y2)]
|
|
|
+ })
|
|
|
+ return detections
|
|
|
+
|
|
|
+ def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
|
|
|
+ results = self.pt_model(img, conf=conf_threshold, verbose=False)
|
|
|
+ detections = []
|
|
|
+ for i, box in enumerate(results[0].boxes):
|
|
|
+ class_id = int(box.cls)
|
|
|
+ class_name = self.class_names.get(class_id, "Unknown")
|
|
|
detections.append({
|
|
|
- "bunch_id": valid_count,
|
|
|
+ "bunch_id": i + 1,
|
|
|
"class": class_name,
|
|
|
- "confidence": round(conf, 2),
|
|
|
+ "confidence": round(float(box.conf), 2),
|
|
|
"is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
|
|
|
- "box": [float(x1), float(y1), float(x2), float(y2)]
|
|
|
+ "box": box.xyxy.tolist()[0]
|
|
|
})
|
|
|
- return detections
|
|
|
+ return detections
|
|
|
+
|
|
|
+model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')
|
|
|
|
|
|
|
|
|
# Global state for the confidence threshold
|
|
|
@@ -146,17 +150,19 @@ async def set_confidence(threshold: float = Body(..., embed=True)):
|
|
|
|
|
|
|
|
|
@app.post("/analyze")
|
|
|
-async def analyze_with_health_metrics(file: UploadFile = File(...)):
|
|
|
+async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
|
|
|
"""Industry-grade analysis with health metrics and summary."""
|
|
|
image_bytes = await file.read()
|
|
|
img = Image.open(io.BytesIO(image_bytes))
|
|
|
|
|
|
- # Run ONNX inference (natively NMS-free)
|
|
|
- detections = run_inference(img, current_conf)
|
|
|
-
|
|
|
- # Initialize summary for all known classes
|
|
|
- summary = {name: 0 for name in class_names.values()}
|
|
|
+ # Select Inference Engine
|
|
|
+ if model_type == "pytorch":
|
|
|
+ detections = model_manager.run_pytorch_inference(img, current_conf)
|
|
|
+ else:
|
|
|
+ detections = model_manager.run_onnx_inference(img, current_conf)
|
|
|
|
|
|
+ # Initialize summary
|
|
|
+ summary = {name: 0 for name in model_manager.class_names.values()}
|
|
|
for det in detections:
|
|
|
summary[det['class']] += 1
|
|
|
|
|
|
@@ -220,8 +226,8 @@ async def vectorize_and_store(file: UploadFile = File(...), detection_data: str
|
|
|
os.remove(temp_path)
|
|
|
|
|
|
@app.post("/process_batch")
|
|
|
-async def process_batch(files: List[UploadFile] = File(...)):
|
|
|
- """Handles multiple images: Detect -> Vectorize -> Store. Graceful handling of cloud errors."""
|
|
|
+async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
|
|
|
+ """Handles multiple images: Detect -> Vectorize -> Store."""
|
|
|
batch_results = []
|
|
|
temp_files = []
|
|
|
|
|
|
@@ -234,9 +240,12 @@ async def process_batch(files: List[UploadFile] = File(...)):
|
|
|
shutil.copyfileobj(file.file, f_out)
|
|
|
temp_files.append(path)
|
|
|
|
|
|
- # 2. ONNX Detect (natively NMS-free)
|
|
|
+ # 2. Detect
|
|
|
img = Image.open(path)
|
|
|
- detections = run_inference(img, current_conf)
|
|
|
+ if model_type == "pytorch":
|
|
|
+ detections = model_manager.run_pytorch_inference(img, current_conf)
|
|
|
+ else:
|
|
|
+ detections = model_manager.run_onnx_inference(img, current_conf)
|
|
|
|
|
|
# 3. Process all detections in the image
|
|
|
for det in detections:
|