Pārlūkot izejas kodu

feat: Add dual inference engine support (ONNX / PyTorch) to the palm oil ripeness FastAPI service — introduce a ModelManager that wraps both the ONNX Runtime session and the native Ultralytics YOLO model, accept a `model_type` form field on `/analyze` and `/process_batch`, and expose an engine selector in the Streamlit demo UI.

Dr-Swopt 3 dienas atpakaļ
vecāks
revīzija
26730944ea
2 mainīti faili ar 93 papildinājumiem un 68 dzēšanām
  1. 20 4
      demo_app.py
  2. 73 64
      src/api/main.py

+ 20 - 4
demo_app.py

@@ -67,6 +67,20 @@ st.sidebar.slider(
     on_change=update_confidence
 )
 
+st.sidebar.markdown("---")
+st.sidebar.subheader("Inference Engine")
+engine_choice = st.sidebar.selectbox(
+    "Select Model Engine",
+    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)"],
+    index=0,
+    help="ONNX is optimized for latency. PyTorch provides native object handling."
+)
+model_type = "onnx" if "ONNX" in engine_choice else "pytorch"
+if model_type == "pytorch":
+    st.sidebar.warning("PyTorch Engine: Higher Memory Usage")
+else:
+    st.sidebar.info("ONNX Engine: ~39ms Latency")
+
 # Helper to reset results when files change
 def reset_single_results():
     st.session_state.last_detection = None
@@ -291,9 +305,10 @@ with tab1:
 
         # 1. Auto-Detection Trigger
         if uploaded_file and st.session_state.last_detection is None:
-            with st.spinner("Processing Detections Locally..."):
+            with st.spinner(f"Processing with {model_type.upper()} Engine..."):
                 files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
-                res = requests.post(f"{API_BASE_URL}/analyze", files=files)
+                payload = {"model_type": model_type}
+                res = requests.post(f"{API_BASE_URL}/analyze", files=files, data=payload)
                 if res.status_code == 200:
                     st.session_state.last_detection = res.json()
                     st.rerun() # Refresh to show results immediately
@@ -530,9 +545,10 @@ with tab2:
         st.write("##") # Alignment
         if st.session_state.last_batch_results is None and uploaded_files:
             if st.button("🔍 Process Batch", type="primary", width='stretch'):
-                with st.spinner(f"Analyzing {len(uploaded_files)} images..."):
+                with st.spinner(f"Analyzing {len(uploaded_files)} images with {model_type.upper()}..."):
                     files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
-                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
+                    payload = {"model_type": model_type}
+                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files, data=payload)
                     
                     if res.status_code == 200:
                         data = res.json()

+ 73 - 64
src/api/main.py

@@ -4,6 +4,7 @@ import os
 import shutil
 from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
 import onnxruntime as ort
+from ultralytics import YOLO
 import numpy as np
 
 from dotenv import load_dotenv
@@ -46,64 +47,67 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.jso
 
 app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 
-# Initialize ONNX model
-onnx_path = 'best.onnx'
-ort_session = ort.InferenceSession(onnx_path)
-input_name = ort_session.get_inputs()[0].name
-class_names = {
-    0: 'Empty_Bunch',
-    1: 'Underripe',
-    2: 'Abnormal',
-    3: 'Ripe',
-    4: 'Unripe',
-    5: 'Overripe'
-}
-
-def preprocess(img: Image.Image):
-    """Preprocess image for YOLO ONNX input [1, 3, 640, 640]."""
-    img = img.convert("RGB")
-    # Store original size for scaling
-    orig_w, orig_h = img.size
-    img_resized = img.resize((640, 640))
-    img_array = np.array(img_resized) / 255.0
-    img_array = img_array.transpose(2, 0, 1) # HWC to CHW
-    img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
-    return img_array, orig_w, orig_h
-
-def run_inference(img: Image.Image, conf_threshold: float):
-    """Run ONNX inference and return list of detections."""
-    img_array, orig_w, orig_h = preprocess(img)
-    outputs = ort_session.run(None, {input_name: img_array})
-    # Output shape: [1, 300, 6] -> [x1, y1, x2, y2, conf, class_id]
-    detections_batch = outputs[0]
-    
-    scale_w = orig_w / 640.0
-    scale_h = orig_h / 640.0
-    
-    detections = []
-    valid_count = 0
-    for i in range(detections_batch.shape[1]):
-        det = detections_batch[0, i]
-        conf = float(det[4])
-        if conf >= conf_threshold:
-            valid_count += 1
-            x1, y1, x2, y2 = det[:4]
-            # Rescale
-            x1 *= scale_w
-            y1 *= scale_h
-            x2 *= scale_w
-            y2 *= scale_h
-            class_id = int(det[5])
-            class_name = class_names.get(class_id, "Unknown")
-            
+class ModelManager:
+    def __init__(self, onnx_path: str, pt_path: str):
+        self.onnx_session = ort.InferenceSession(onnx_path)
+        self.onnx_input_name = self.onnx_session.get_inputs()[0].name
+        self.pt_model = YOLO(pt_path)
+        self.class_names = self.pt_model.names
+
+    def preprocess_onnx(self, img: Image.Image):
+        img = img.convert("RGB")
+        orig_w, orig_h = img.size
+        img_resized = img.resize((640, 640))
+        img_array = np.array(img_resized) / 255.0
+        img_array = img_array.transpose(2, 0, 1)
+        img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
+        return img_array, orig_w, orig_h
+
+    def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
+        img_array, orig_w, orig_h = self.preprocess_onnx(img)
+        outputs = self.onnx_session.run(None, {self.onnx_input_name: img_array})
+        detections_batch = outputs[0]
+        
+        scale_w = orig_w / 640.0
+        scale_h = orig_h / 640.0
+        
+        detections = []
+        valid_count = 0
+        for i in range(detections_batch.shape[1]):
+            det = detections_batch[0, i]
+            conf = float(det[4])
+            if conf >= conf_threshold:
+                valid_count += 1
+                x1, y1, x2, y2 = det[:4]
+                x1 *= scale_w; y1 *= scale_h; x2 *= scale_w; y2 *= scale_h
+                class_id = int(det[5])
+                class_name = self.class_names.get(class_id, "Unknown")
+                
+                detections.append({
+                    "bunch_id": valid_count,
+                    "class": class_name,
+                    "confidence": round(conf, 2),
+                    "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
+                    "box": [float(x1), float(y1), float(x2), float(y2)]
+                })
+        return detections
+
+    def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
+        results = self.pt_model(img, conf=conf_threshold, verbose=False)
+        detections = []
+        for i, box in enumerate(results[0].boxes):
+            class_id = int(box.cls)
+            class_name = self.class_names.get(class_id, "Unknown")
             detections.append({
-                "bunch_id": valid_count,
+                "bunch_id": i + 1,
                 "class": class_name,
-                "confidence": round(conf, 2),
+                "confidence": round(float(box.conf), 2),
                 "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
-                "box": [float(x1), float(y1), float(x2), float(y2)]
+                "box": box.xyxy.tolist()[0]
             })
-    return detections
+        return detections
+
+model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')
 
 
 # Global state for the confidence threshold
@@ -146,17 +150,19 @@ async def set_confidence(threshold: float = Body(..., embed=True)):
 
 
 @app.post("/analyze")
-async def analyze_with_health_metrics(file: UploadFile = File(...)):
+async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
     """Industry-grade analysis with health metrics and summary."""
     image_bytes = await file.read()
     img = Image.open(io.BytesIO(image_bytes))
     
-    # Run ONNX inference (natively NMS-free)
-    detections = run_inference(img, current_conf)
-    
-    # Initialize summary for all known classes
-    summary = {name: 0 for name in class_names.values()}
+    # Select Inference Engine
+    if model_type == "pytorch":
+        detections = model_manager.run_pytorch_inference(img, current_conf)
+    else:
+        detections = model_manager.run_onnx_inference(img, current_conf)
     
+    # Initialize summary
+    summary = {name: 0 for name in model_manager.class_names.values()}
     for det in detections:
         summary[det['class']] += 1
     
@@ -220,8 +226,8 @@ async def vectorize_and_store(file: UploadFile = File(...), detection_data: str
             os.remove(temp_path)
 
 @app.post("/process_batch")
-async def process_batch(files: List[UploadFile] = File(...)):
-    """Handles multiple images: Detect -> Vectorize -> Store. Graceful handling of cloud errors."""
+async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
+    """Handles multiple images: Detect -> Vectorize -> Store."""
     batch_results = []
     temp_files = []
 
@@ -234,9 +240,12 @@ async def process_batch(files: List[UploadFile] = File(...)):
                 shutil.copyfileobj(file.file, f_out)
             temp_files.append(path)
 
-            # 2. ONNX Detect (natively NMS-free)
+            # 2. Detect
             img = Image.open(path)
-            detections = run_inference(img, current_conf)
+            if model_type == "pytorch":
+                detections = model_manager.run_pytorch_inference(img, current_conf)
+            else:
+                detections = model_manager.run_onnx_inference(img, current_conf)
             
             # 3. Process all detections in the image
             for det in detections: