소스 검색

feat: Add a FastAPI backend for model inference and management, update the demo app to use a YOLOv8-Sawit benchmark model and display its capabilities via the new API.

Dr-Swopt 2일 전
부모
커밋
793eb2495e
3개의 변경된 파일에 44개의 추가 그리고 10개의 삭제
  1. 19 4
      demo_app.py
  2. 23 4
      src/api/main.py
  3. 2 2
      test_benchmark.py

+ 19 - 4
demo_app.py

@@ -186,7 +186,7 @@ st.sidebar.markdown("---")
 # Inference Engine
 engine_choice = st.sidebar.selectbox(
     "Select Model Engine:",
-    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)", "Sawit-TBS (Benchmark)"],
+    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)", "YOLOv8-Sawit (Benchmark)"],
     index=0,
     on_change=reset_all_analysis # Clear canvas on engine switch
 )
@@ -195,7 +195,7 @@ engine_choice = st.sidebar.selectbox(
 engine_map = {
     "YOLO26 (ONNX - High Speed)": "onnx",
     "YOLO26 (PyTorch - Native)": "pytorch",
-    "Sawit-TBS (Benchmark)": "benchmark"
+    "YOLOv8-Sawit (Benchmark)": "yolov8_sawit"
 }
 
 st.sidebar.markdown("---")
@@ -205,6 +205,21 @@ model_type = engine_map[engine_choice]
 if st.sidebar.button("❓ How to read results?", icon="📘", width='stretch'):
     show_tech_guide()
 
+st.sidebar.markdown("---")
+st.sidebar.subheader("🏗️ Model Capabilities")
+try:
+    # Fetch engine metadata (description + detection categories) from the API.
+    # A timeout keeps the sidebar responsive if the backend is unreachable.
+    info_res = requests.get(
+        f"{API_BASE_URL}/get_model_info",
+        params={"model_type": model_type},
+        timeout=5,
+    )
+    if info_res.status_code == 200:
+        m_info = info_res.json()
+        st.sidebar.caption(m_info['description'])
+        st.sidebar.write("**Detected Categories:**")
+        # Display the categories as a two-column list of tags
+        cols = st.sidebar.columns(2)
+        for i, cat in enumerate(m_info['detections_categories']):
+            cols[i % 2].markdown(f"- `{cat}`")
+except (requests.RequestException, KeyError, ValueError):
+    # Narrow catch: network/HTTP failures, malformed or incomplete JSON
+    # payloads. A bare `except:` would also swallow KeyboardInterrupt
+    # and SystemExit, which must propagate.
+    st.sidebar.error("Failed to load model metadata.")
 # Function definitions moved to top
 
 def display_interactive_results(image, detections, key=None):
@@ -229,7 +244,7 @@ def display_interactive_results(image, detections, key=None):
         y_top, y_bottom = img_height - y1, img_height - y2
         
         color = get_color(det['class'])
-        is_bench = (st.session_state.get('engine_choice') == "Sawit-TBS (Benchmark)")
+        is_bench = (st.session_state.get('engine_choice') == "YOLOv8-Sawit (Benchmark)")
 
         # The 'Hover' shape
         bunch_id = det.get('bunch_id', i+1)
@@ -275,7 +290,7 @@ def annotate_image(image, detections):
         conf = det['confidence']
         bunch_id = det.get('bunch_id', '?')
         color = get_color(cls)
-        is_bench = (st.session_state.get('engine_choice') == "Sawit-TBS (Benchmark)")
+        is_bench = (st.session_state.get('engine_choice') == "YOLOv8-Sawit (Benchmark)")
         
         # 2. Draw Heavy-Duty Bounding Box
         line_width = max(6 if is_bench else 4, image.width // (80 if is_bench else 150))

+ 23 - 4
src/api/main.py

@@ -124,7 +124,7 @@ class ModelManager:
         import time
         start_inf = time.perf_counter()
         
-        # Selection Logic for Third Engine
+        # Selection Logic for Third Engine (YOLOv8-Sawit)
         model = self.pt_model if engine_type == "pytorch" else self.benchmark_model
         names = self.class_names if engine_type == "pytorch" else self.benchmark_class_names
         
@@ -211,8 +211,8 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     # Select Inference Engine
     if model_type == "pytorch":
         detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "pytorch")
-    elif model_type == "benchmark":
-        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "benchmark")
+    elif model_type == "yolov8_sawit":
+        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "yolov8_sawit")
     else:
         detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
     
@@ -221,7 +221,7 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     processing_ms = total_ms - inference_ms
     
     # Initialize summary
-    active_names = model_manager.class_names if model_type != "benchmark" else model_manager.benchmark_class_names
+    active_names = model_manager.class_names if model_type != "yolov8_sawit" else model_manager.benchmark_class_names
     summary = {name: 0 for name in active_names.values()}
     for det in detections:
         summary[det['class']] += 1
@@ -450,6 +450,25 @@ async def get_history():
     conn.close()
     return {"status": "success", "history": rows}
 
+@app.get("/get_model_info")
+async def get_model_info(model_type: str = "onnx"):
+    """Returns metadata and capabilities for the specified model engine."""
+    # "onnx" and "pytorch" both serve the standard YOLO26 weights and
+    # therefore share the same class-name map.
+    if model_type in ["onnx", "pytorch"]:
+        classes = list(model_manager.class_names.values())
+        description = "Standard YOLO26 Industrial Model."
+    elif model_type == "yolov8_sawit":
+        # The benchmark engine carries its own (externally trained) class map.
+        classes = list(model_manager.benchmark_class_names.values())
+        description = "YOLOv8-Sawit (Benchmark) - External Architecture."
+    else:
+        # NOTE(review): this error path still answers HTTP 200 with an error
+        # payload; consider raising fastapi.HTTPException(status_code=400)
+        # so HTTP clients can branch on the status code — confirm whether
+        # existing callers rely on the 200-with-"status" convention first.
+        return {"status": "error", "message": "Unknown model type"}
+        
+    return {
+        "status": "success",
+        "model_type": model_type,
+        "description": description,
+        "detections_categories": classes
+    }
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)

+ 2 - 2
test_benchmark.py

@@ -21,8 +21,8 @@ def test_inference():
     detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="pytorch")
     print(f"Detections: {len(detections)}, Inference: {ms:.2f}ms")
     
-    print("\nTesting Benchmark inference (Sawit-TBS)...")
-    detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="benchmark")
+    print("\nTesting Benchmark inference (YOLOv8-Sawit)...")
+    detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="yolov8_sawit")
     print(f"Detections: {len(detections)}, Inference: {ms:.2f}ms")
     print(f"Benchmark Class Names: {manager.benchmark_class_names}")