Bläddra i källkod

Include 3rd-party model

Dr-Swopt 2 dagar sedan
förälder
incheckning
4bc1b9ab71
6 ändrade filer med 61 tillägg och 11 borttagningar
  1. BIN
      calibration_image_sample_data_20x128x128x3_float32.npy
  2. +13 −5
      demo_app.py
  3. BIN
      palm_history.db
  4. BIN
      sawit_tbs.pt
  5. +16 −6
      src/api/main.py
  6. +32 −0
      test_benchmark.py

BIN
calibration_image_sample_data_20x128x128x3_float32.npy


+ 13 - 5
demo_app.py

@@ -171,7 +171,7 @@ st.sidebar.markdown("---")
 # Inference Engine
 engine_choice = st.sidebar.selectbox(
     "Select Model Engine:",
-    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)"],
+    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)", "Sawit-TBS (Benchmark)"],
     index=0,
     on_change=reset_all_analysis # Clear canvas on engine switch
 )
@@ -179,7 +179,8 @@ engine_choice = st.sidebar.selectbox(
 # Map selection to internal labels
 engine_map = {
     "YOLO26 (ONNX - High Speed)": "onnx",
-    "YOLO26 (PyTorch - Native)": "pytorch"
+    "YOLO26 (PyTorch - Native)": "pytorch",
+    "Sawit-TBS (Benchmark)": "benchmark"
 }
 
 st.sidebar.markdown("---")
@@ -211,7 +212,7 @@ def display_interactive_results(image, detections, key=None):
         x1, y1, x2, y2 = det['box']
         # Plotly y-axis is inverted relative to PIL, so we flip y
         y_top, y_bottom = img_height - y1, img_height - y2
-        color = overlay_colors.get(det['class'], "#ffeb3b")
+        color = overlay_colors.get(det['class'], "#9ca3af") # Fallback to neutral gray
 
         # The 'Hover' shape
         bunch_id = det.get('bunch_id', i+1)
@@ -256,7 +257,7 @@ def annotate_image(image, detections):
         cls = det['class']
         conf = det['confidence']
         bunch_id = det.get('bunch_id', '?')
-        color = overlay_colors.get(cls, '#ffffff')
+        color = overlay_colors.get(cls, '#9ca3af') # Fallback to neutral gray
         
         # 2. Draw Heavy-Duty Bounding Box
         line_width = max(4, image.width // 150)
@@ -436,7 +437,14 @@ with tab1:
             with m_col1:
                 st.metric("Total Bunches", data.get('total_count', 0))
             with m_col2:
-                st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
+                if model_type == "benchmark":
+                    # For benchmark model, show the top detected class instead of 'Healthy'
+                    top_class = "None"
+                    if data.get('industrial_summary'):
+                        top_class = max(data['industrial_summary'], key=data['industrial_summary'].get)
+                    st.metric("Top Detected Class", top_class)
+                else:
+                    st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
             with m_col3:
                 # Refined speed label based on engine
                 speed_label = "Raw Speed (Unlabeled)" if model_type == "onnx" else "Wrapped Speed (Auto-Labeled)"

BIN
palm_history.db


BIN
sawit_tbs.pt


+ 16 - 6
src/api/main.py

@@ -61,11 +61,13 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.jso
 app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 
 class ModelManager:
-    def __init__(self, onnx_path: str, pt_path: str):
+    def __init__(self, onnx_path: str, pt_path: str, benchmark_path: str = 'sawit_tbs.pt'):
         self.onnx_session = ort.InferenceSession(onnx_path)
         self.onnx_input_name = self.onnx_session.get_inputs()[0].name
         self.pt_model = YOLO(pt_path)
         self.class_names = self.pt_model.names
+        self.benchmark_model = YOLO(benchmark_path)
+        self.benchmark_class_names = self.benchmark_model.names
 
     def preprocess_onnx(self, img: Image.Image):
         img = img.convert("RGB")
@@ -118,17 +120,22 @@ class ModelManager:
         raw_sample = detections_batch[0, :5].tolist()
         return detections, raw_sample, inference_ms
 
-    def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
+    def run_pytorch_inference(self, img: Image.Image, conf_threshold: float, engine_type: str = "pytorch"):
         import time
         start_inf = time.perf_counter()
-        results = self.pt_model(img, conf=conf_threshold, verbose=False)
+        
+        # Selection Logic for Third Engine
+        model = self.pt_model if engine_type == "pytorch" else self.benchmark_model
+        names = self.class_names if engine_type == "pytorch" else self.benchmark_class_names
+        
+        results = model(img, conf=conf_threshold, verbose=False)
         end_inf = time.perf_counter()
         inference_ms = (end_inf - start_inf) * 1000
 
         detections = []
         for i, box in enumerate(results[0].boxes):
             class_id = int(box.cls)
-            class_name = self.class_names.get(class_id, "Unknown")
+            class_name = names.get(class_id, "Unknown")
             detections.append({
                 "bunch_id": i + 1,
                 "class": class_name,
@@ -203,7 +210,9 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     start_total = time.perf_counter()
     # Select Inference Engine
     if model_type == "pytorch":
-        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
+        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "pytorch")
+    elif model_type == "benchmark":
+        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "benchmark")
     else:
         detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
     
@@ -212,7 +221,8 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     processing_ms = total_ms - inference_ms
     
     # Initialize summary
-    summary = {name: 0 for name in model_manager.class_names.values()}
+    active_names = model_manager.class_names if model_type != "benchmark" else model_manager.benchmark_class_names
+    summary = {name: 0 for name in active_names.values()}
     for det in detections:
         summary[det['class']] += 1
     

+ 32 - 0
test_benchmark.py

@@ -0,0 +1,32 @@
+import os
+import sys
+from PIL import Image
+import io
+import torch
+
+# Add the project root to sys.path to import src
+sys.path.append(os.getcwd())
+
+from src.api.main import ModelManager
+
+def test_inference():
+    print("Testing ModelManager initialization...")
+    manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt', benchmark_path='sawit_tbs.pt')
+    print("ModelManager initialized successfully.")
+
+    # Create a dummy image for testing
+    img = Image.new('RGB', (640, 640), color = (73, 109, 137))
+    
+    print("\nTesting PyTorch inference (Native)...")
+    detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="pytorch")
+    print(f"Detections: {len(detections)}, Inference: {ms:.2f}ms")
+    
+    print("\nTesting Benchmark inference (Sawit-TBS)...")
+    detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="benchmark")
+    print(f"Detections: {len(detections)}, Inference: {ms:.2f}ms")
+    print(f"Benchmark Class Names: {manager.benchmark_class_names}")
+
+    print("\nVerification Complete.")
+
+if __name__ == "__main__":
+    test_inference()