Просмотр исходного кода

feat: Add a new FastAPI service for palm oil ripeness analysis, integrating ONNX/YOLO models, local history archiving, and cloud-based vectorization and storage.

Dr-Swopt 3 дня назад
Родитель
Commit
26730944ea
2 изменённых файла: 93 добавления и 68 удалений
  1. 20 4
      demo_app.py
  2. 73 64
      src/api/main.py

+ 20 - 4
demo_app.py

@@ -67,6 +67,20 @@ st.sidebar.slider(
     on_change=update_confidence
     on_change=update_confidence
 )
 )
 
 
+st.sidebar.markdown("---")
+st.sidebar.subheader("Inference Engine")
+engine_choice = st.sidebar.selectbox(
+    "Select Model Engine",
+    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)"],
+    index=0,
+    help="ONNX is optimized for latency. PyTorch provides native object handling."
+)
+model_type = "onnx" if "ONNX" in engine_choice else "pytorch"
+if model_type == "pytorch":
+    st.sidebar.warning("PyTorch Engine: Higher Memory Usage")
+else:
+    st.sidebar.info("ONNX Engine: ~39ms Latency")
+
 # Helper to reset results when files change
 # Helper to reset results when files change
 def reset_single_results():
 def reset_single_results():
     st.session_state.last_detection = None
     st.session_state.last_detection = None
@@ -291,9 +305,10 @@ with tab1:
 
 
         # 1. Auto-Detection Trigger
         # 1. Auto-Detection Trigger
         if uploaded_file and st.session_state.last_detection is None:
         if uploaded_file and st.session_state.last_detection is None:
-            with st.spinner("Processing Detections Locally..."):
+            with st.spinner(f"Processing with {model_type.upper()} Engine..."):
                 files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                 files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
-                res = requests.post(f"{API_BASE_URL}/analyze", files=files)
+                payload = {"model_type": model_type}
+                res = requests.post(f"{API_BASE_URL}/analyze", files=files, data=payload)
                 if res.status_code == 200:
                 if res.status_code == 200:
                     st.session_state.last_detection = res.json()
                     st.session_state.last_detection = res.json()
                     st.rerun() # Refresh to show results immediately
                     st.rerun() # Refresh to show results immediately
@@ -530,9 +545,10 @@ with tab2:
         st.write("##") # Alignment
         st.write("##") # Alignment
         if st.session_state.last_batch_results is None and uploaded_files:
         if st.session_state.last_batch_results is None and uploaded_files:
             if st.button("🔍 Process Batch", type="primary", width='stretch'):
             if st.button("🔍 Process Batch", type="primary", width='stretch'):
-                with st.spinner(f"Analyzing {len(uploaded_files)} images..."):
+                with st.spinner(f"Analyzing {len(uploaded_files)} images with {model_type.upper()}..."):
                     files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                     files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
-                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
+                    payload = {"model_type": model_type}
+                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files, data=payload)
                     
                     
                     if res.status_code == 200:
                     if res.status_code == 200:
                         data = res.json()
                         data = res.json()

+ 73 - 64
src/api/main.py

@@ -4,6 +4,7 @@ import os
 import shutil
 import shutil
 from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
 from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
 import onnxruntime as ort
 import onnxruntime as ort
+from ultralytics import YOLO
 import numpy as np
 import numpy as np
 
 
 from dotenv import load_dotenv
 from dotenv import load_dotenv
@@ -46,64 +47,67 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.jso
 
 
 app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 
 
-# Initialize ONNX model
-onnx_path = 'best.onnx'
-ort_session = ort.InferenceSession(onnx_path)
-input_name = ort_session.get_inputs()[0].name
-class_names = {
-    0: 'Empty_Bunch',
-    1: 'Underripe',
-    2: 'Abnormal',
-    3: 'Ripe',
-    4: 'Unripe',
-    5: 'Overripe'
-}
-
-def preprocess(img: Image.Image):
-    """Preprocess image for YOLO ONNX input [1, 3, 640, 640]."""
-    img = img.convert("RGB")
-    # Store original size for scaling
-    orig_w, orig_h = img.size
-    img_resized = img.resize((640, 640))
-    img_array = np.array(img_resized) / 255.0
-    img_array = img_array.transpose(2, 0, 1) # HWC to CHW
-    img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
-    return img_array, orig_w, orig_h
-
-def run_inference(img: Image.Image, conf_threshold: float):
-    """Run ONNX inference and return list of detections."""
-    img_array, orig_w, orig_h = preprocess(img)
-    outputs = ort_session.run(None, {input_name: img_array})
-    # Output shape: [1, 300, 6] -> [x1, y1, x2, y2, conf, class_id]
-    detections_batch = outputs[0]
-    
-    scale_w = orig_w / 640.0
-    scale_h = orig_h / 640.0
-    
-    detections = []
-    valid_count = 0
-    for i in range(detections_batch.shape[1]):
-        det = detections_batch[0, i]
-        conf = float(det[4])
-        if conf >= conf_threshold:
-            valid_count += 1
-            x1, y1, x2, y2 = det[:4]
-            # Rescale
-            x1 *= scale_w
-            y1 *= scale_h
-            x2 *= scale_w
-            y2 *= scale_h
-            class_id = int(det[5])
-            class_name = class_names.get(class_id, "Unknown")
-            
+class ModelManager:
+    def __init__(self, onnx_path: str, pt_path: str):
+        self.onnx_session = ort.InferenceSession(onnx_path)
+        self.onnx_input_name = self.onnx_session.get_inputs()[0].name
+        self.pt_model = YOLO(pt_path)
+        self.class_names = self.pt_model.names
+
+    def preprocess_onnx(self, img: Image.Image):
+        img = img.convert("RGB")
+        orig_w, orig_h = img.size
+        img_resized = img.resize((640, 640))
+        img_array = np.array(img_resized) / 255.0
+        img_array = img_array.transpose(2, 0, 1)
+        img_array = img_array.reshape(1, 3, 640, 640).astype(np.float32)
+        return img_array, orig_w, orig_h
+
+    def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
+        img_array, orig_w, orig_h = self.preprocess_onnx(img)
+        outputs = self.onnx_session.run(None, {self.onnx_input_name: img_array})
+        detections_batch = outputs[0]
+        
+        scale_w = orig_w / 640.0
+        scale_h = orig_h / 640.0
+        
+        detections = []
+        valid_count = 0
+        for i in range(detections_batch.shape[1]):
+            det = detections_batch[0, i]
+            conf = float(det[4])
+            if conf >= conf_threshold:
+                valid_count += 1
+                x1, y1, x2, y2 = det[:4]
+                x1 *= scale_w; y1 *= scale_h; x2 *= scale_w; y2 *= scale_h
+                class_id = int(det[5])
+                class_name = self.class_names.get(class_id, "Unknown")
+                
+                detections.append({
+                    "bunch_id": valid_count,
+                    "class": class_name,
+                    "confidence": round(conf, 2),
+                    "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
+                    "box": [float(x1), float(y1), float(x2), float(y2)]
+                })
+        return detections
+
+    def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
+        results = self.pt_model(img, conf=conf_threshold, verbose=False)
+        detections = []
+        for i, box in enumerate(results[0].boxes):
+            class_id = int(box.cls)
+            class_name = self.class_names.get(class_id, "Unknown")
             detections.append({
             detections.append({
-                "bunch_id": valid_count,
+                "bunch_id": i + 1,
                 "class": class_name,
                 "class": class_name,
-                "confidence": round(conf, 2),
+                "confidence": round(float(box.conf), 2),
                 "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                 "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
-                "box": [float(x1), float(y1), float(x2), float(y2)]
+                "box": box.xyxy.tolist()[0]
             })
             })
-    return detections
+        return detections
+
+model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')
 
 
 
 
 # Global state for the confidence threshold
 # Global state for the confidence threshold
@@ -146,17 +150,19 @@ async def set_confidence(threshold: float = Body(..., embed=True)):
 
 
 
 
 @app.post("/analyze")
 @app.post("/analyze")
-async def analyze_with_health_metrics(file: UploadFile = File(...)):
+async def analyze_with_health_metrics(file: UploadFile = File(...), model_type: str = Form("onnx")):
     """Industry-grade analysis with health metrics and summary."""
     """Industry-grade analysis with health metrics and summary."""
     image_bytes = await file.read()
     image_bytes = await file.read()
     img = Image.open(io.BytesIO(image_bytes))
     img = Image.open(io.BytesIO(image_bytes))
     
     
-    # Run ONNX inference (natively NMS-free)
-    detections = run_inference(img, current_conf)
-    
-    # Initialize summary for all known classes
-    summary = {name: 0 for name in class_names.values()}
+    # Select Inference Engine
+    if model_type == "pytorch":
+        detections = model_manager.run_pytorch_inference(img, current_conf)
+    else:
+        detections = model_manager.run_onnx_inference(img, current_conf)
     
     
+    # Initialize summary
+    summary = {name: 0 for name in model_manager.class_names.values()}
     for det in detections:
     for det in detections:
         summary[det['class']] += 1
         summary[det['class']] += 1
     
     
@@ -220,8 +226,8 @@ async def vectorize_and_store(file: UploadFile = File(...), detection_data: str
             os.remove(temp_path)
             os.remove(temp_path)
 
 
 @app.post("/process_batch")
 @app.post("/process_batch")
-async def process_batch(files: List[UploadFile] = File(...)):
-    """Handles multiple images: Detect -> Vectorize -> Store. Graceful handling of cloud errors."""
+async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
+    """Handles multiple images: Detect -> Vectorize -> Store."""
     batch_results = []
     batch_results = []
     temp_files = []
     temp_files = []
 
 
@@ -234,9 +240,12 @@ async def process_batch(files: List[UploadFile] = File(...)):
                 shutil.copyfileobj(file.file, f_out)
                 shutil.copyfileobj(file.file, f_out)
             temp_files.append(path)
             temp_files.append(path)
 
 
-            # 2. ONNX Detect (natively NMS-free)
+            # 2. Detect
             img = Image.open(path)
             img = Image.open(path)
-            detections = run_inference(img, current_conf)
+            if model_type == "pytorch":
+                detections = model_manager.run_pytorch_inference(img, current_conf)
+            else:
+                detections = model_manager.run_onnx_inference(img, current_conf)
             
             
             # 3. Process all detections in the image
             # 3. Process all detections in the image
             for det in detections:
             for det in detections: