Dr-Swopt 3 dienas atpakaļ
vecāks
revīzija
3adbaf7d70
2 mainīti faili ar 72 papildinājumiem un 37 dzēšanām
  1. 31 21
      demo_app.py
  2. 41 16
      src/api/main.py

+ 31 - 21
demo_app.py

@@ -96,10 +96,14 @@ st.sidebar.markdown("---")
 st.sidebar.subheader("Inference Engine")
 engine_choice = st.sidebar.selectbox(
     "Select Model Engine",
-    ["YOLO26 (ONNX - High Speed)", "YOLO26 (PyTorch - Native)"],
+    ["YOLO26 (PyTorch - Native)", "YOLO26 (ONNX - High Speed)"],
     index=0,
     help="ONNX is optimized for latency. PyTorch provides native object handling."
 )
+st.sidebar.markdown("---")
+st.sidebar.subheader("🛠️ Technical Controls")
+show_trace = st.sidebar.toggle("🔬 Show Technical Trace", value=False, help="Enable to see raw mathematical tensor data alongside AI labels.")
+st.session_state.tech_trace = show_trace
 model_type = "onnx" if "ONNX" in engine_choice else "pytorch"
 if model_type == "pytorch":
     st.sidebar.warning("PyTorch Engine: Higher Memory Usage")
@@ -346,23 +350,9 @@ with tab1:
 
         # 2. Results Layout
         if st.session_state.last_detection:
-            st.divider()
-            
-            # PRIMARY ANNOTATED VIEW
-            st.write("### 🔍 AI Analytical View")
             data = st.session_state.last_detection
-            img = Image.open(uploaded_file).convert("RGB")
-            display_interactive_results(img, data['detections'], key="main_viewer")
-
-            # Visual Legend
-            st.write("#### 🎨 Ripeness Legend")
-            l_cols = st.columns(len(overlay_colors))
-            for i, (grade, color) in enumerate(overlay_colors.items()):
-                with l_cols[i]:
-                    st.markdown(f'<div style="background-color:{color}; padding:10px; border-radius:5px; text-align:center; color:white; font-weight:bold;">{grade}</div>', unsafe_allow_html=True)
-
             st.divider()
-
+            
             st.write("### 📈 Manager's Dashboard")
             m_col1, m_col2, m_col3, m_col4 = st.columns(4)
             with m_col1:
@@ -370,10 +360,30 @@ with tab1:
             with m_col2:
                 st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
             with m_col3:
-                abnormal = data['industrial_summary'].get('Abnormal', 0)
-                st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
+                # Refined speed label based on engine
+                speed_label = "Raw Speed (Unlabeled)" if model_type == "onnx" else "Wrapped Speed (Auto-Labeled)"
+                st.metric("Inference Speed", f"{data.get('inference_ms', 0):.1f} ms", help=speed_label)
             with m_col4:
-                st.metric("Inference Speed", f"{data.get('inference_ms', 0):.1f} ms")
+                st.metric("Post-Processing", f"{data.get('processing_ms', 0):.1f} ms", help="Labeling/Scaling overhead")
+
+            st.divider()
+
+            # Side-by-Side View (Technical Trace)
+            img = Image.open(uploaded_file).convert("RGB")
+            if st.session_state.get('tech_trace', False):
+                t_col1, t_col2 = st.columns(2)
+                with t_col1:
+                    st.subheader("🔢 Raw Output Tensor (The Math)")
+                    st.caption("First 5 rows of the 1x300x6 detection tensor.")
+                    st.json(data.get('raw_array_sample', []))
+                with t_col2:
+                    st.subheader("🎨 AI Interpretation")
+                    img_annotated = annotate_image(img.copy(), data['detections'])
+                    st.image(img_annotated, width='stretch')
+            else:
+                # Regular View
+                st.write("### 🔍 AI Analytical View")
+                display_interactive_results(img, data['detections'], key="main_viewer")
 
             col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
             
@@ -724,9 +734,9 @@ with tab4:
                         with h_col2:
                             st.metric("Healthy (Ripe)", summary.get('Ripe', 0))
                         with h_col3:
-                            st.metric("Abnormal Alerts", summary.get('Abnormal', 0))
+                            st.metric("Inference Speed", f"{record.get('inference_ms', 0) or 0:.1f} ms", help="Raw model speed")
                         with h_col4:
-                            st.metric("Inference Speed", f"{record.get('inference_ms', 0) or 0:.1f} ms")
+                            st.metric("Post-Processing", f"{record.get('processing_ms', 0) or 0:.1f} ms", help="Labeling overhead")
 
                         # Image View
                         if os.path.exists(record['archive_path']):

+ 41 - 16
src/api/main.py

@@ -22,6 +22,7 @@ ARCHIVE_DIR = "history_archive"
 os.makedirs(ARCHIVE_DIR, exist_ok=True)
 
 def init_local_db():
+    print(f"Initializing Local DB at {DB_PATH}...")
     conn = sqlite3.connect(DB_PATH)
     cursor = conn.cursor()
     cursor.execute('''
@@ -32,12 +33,14 @@ def init_local_db():
             detections TEXT,
             summary TEXT,
             inference_ms REAL,
+            processing_ms REAL,
             raw_tensor TEXT,
             timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
         )
     ''')
     conn.commit()
     conn.close()
+    print("Local DB Initialized.")
 
 init_local_db()
 
@@ -67,9 +70,14 @@ class ModelManager:
 
     def run_onnx_inference(self, img: Image.Image, conf_threshold: float):
         img_array, orig_w, orig_h = self.preprocess_onnx(img)
+        
+        import time
+        start_inf = time.perf_counter()
         outputs = self.onnx_session.run(None, {self.onnx_input_name: img_array})
-        detections_batch = outputs[0]
+        end_inf = time.perf_counter()
+        inference_ms = (end_inf - start_inf) * 1000
         
+        detections_batch = outputs[0]
         scale_w = orig_w / 640.0
         scale_h = orig_h / 640.0
         
@@ -92,10 +100,17 @@ class ModelManager:
                     "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                     "box": [float(x1), float(y1), float(x2), float(y2)]
                 })
-        return detections, detections_batch[0, :5].tolist()
+        print(detections)
+        print(detections_batch)
+        return detections, detections_batch[0, :5].tolist(), inference_ms
 
     def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
+        import time
+        start_inf = time.perf_counter()
         results = self.pt_model(img, conf=conf_threshold, verbose=False)
+        end_inf = time.perf_counter()
+        inference_ms = (end_inf - start_inf) * 1000
+
         detections = []
         for i, box in enumerate(results[0].boxes):
             class_id = int(box.cls)
@@ -107,7 +122,12 @@ class ModelManager:
                 "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                 "box": box.xyxy.tolist()[0]
             })
-        return detections, results[0].boxes.data[:5].tolist()
+        
+        # Extract snippet from results (simulating raw output)
+        raw_snippet = results[0].boxes.data[:5].tolist() if len(results[0].boxes) > 0 else []
+        print(detections)
+        print(raw_snippet)
+        return detections, raw_snippet, inference_ms
 
 model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')
 
@@ -158,14 +178,16 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     img = Image.open(io.BytesIO(image_bytes))
     
     import time
-    start_time = time.perf_counter()
+    start_total = time.perf_counter()
     # Select Inference Engine
     if model_type == "pytorch":
-        detections, raw_sample = model_manager.run_pytorch_inference(img, current_conf)
+        detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
     else:
-        detections, raw_sample = model_manager.run_onnx_inference(img, current_conf)
-    end_time = time.perf_counter()
-    inference_ms = (end_time - start_time) * 1000
+        detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
+    
+    end_total = time.perf_counter()
+    total_ms = (end_total - start_total) * 1000
+    processing_ms = total_ms - inference_ms
     
     # Initialize summary
     summary = {name: 0 for name in model_manager.class_names.values()}
@@ -184,8 +206,8 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     # Save to SQLite
     conn = sqlite3.connect(DB_PATH)
     cursor = conn.cursor()
-    cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?)",
-                   (file.filename, archive_path, json.dumps(detections), json.dumps(summary), inference_ms, json.dumps(raw_sample)))
+    cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
+                   (file.filename, archive_path, json.dumps(detections), json.dumps(summary), inference_ms, processing_ms, json.dumps(raw_sample)))
     conn.commit()
     conn.close()
             
@@ -196,6 +218,7 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
         "industrial_summary": summary,
         "detections": detections,
         "inference_ms": inference_ms,
+        "processing_ms": processing_ms,
         "raw_array_sample": raw_sample,
         "archive_id": unique_id
     }
@@ -252,10 +275,12 @@ async def process_batch(files: List[UploadFile] = File(...), model_type: str = F
             # 2. Detect
             img = Image.open(path)
             # FORCE PYTORCH for Batch
-            start_time = time.perf_counter()
-            detections, raw_sample = model_manager.run_pytorch_inference(img, current_conf)
-            end_time = time.perf_counter()
-            inference_ms = (end_time - start_time) * 1000
+            start_total = time.perf_counter()
+            detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
+            end_total = time.perf_counter()
+            
+            total_ms = (end_total - start_total) * 1000
+            processing_ms = total_ms - inference_ms
             
             # 3. Process all detections in the image
             for det in detections:
@@ -368,8 +393,8 @@ async def save_to_history(file: UploadFile = File(...), detections: str = Form(.
         
     conn = sqlite3.connect(DB_PATH)
     cursor = conn.cursor()
-    cursor.execute("INSERT INTO history (filename, archive_path, detections, summary) VALUES (?, ?, ?, ?)",
-                   (file.filename, archive_path, detections, summary))
+    cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, processing_ms, raw_tensor) VALUES (?, ?, ?, ?, ?, ?, ?)",
+                   (file.filename, archive_path, detections, summary, 0.0, 0.0, ""))
     conn.commit()
     conn.close()
     return {"status": "success", "message": "Saved to local vault"}