Dr-Swopt před 3 dny
rodič
revize
8bacc37ce7
2 změnil soubory, kde provedl 88 přidání a 35 odebrání
  1. 62 23
      demo_app.py
  2. 26 12
      src/api/main.py

+ 62 - 23
demo_app.py

@@ -335,7 +335,7 @@ with tab1:
             st.divider()
 
             st.write("### 📈 Manager's Dashboard")
-            m_col1, m_col2, m_col3 = st.columns(3)
+            m_col1, m_col2, m_col3, m_col4 = st.columns(4)
             with m_col1:
                 st.metric("Total Bunches", data.get('total_count', 0))
             with m_col2:
@@ -343,10 +343,15 @@ with tab1:
             with m_col3:
                 abnormal = data['industrial_summary'].get('Abnormal', 0)
                 st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
+            with m_col4:
+                st.metric("Inference Speed", f"{data.get('inference_ms', 0):.1f} ms")
 
             col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
             
-            with col2:
+            with col1:
+                with st.expander("🛠️ Technical Evidence: Raw Output Tensor", expanded=False):
+                    st.write("First 5 detections from raw output tensor:")
+                    st.json(data.get('raw_array_sample', []))
                 with st.container(border=True):
                     st.write("### 🏷️ Detection Results")
                     if not data['detections']:
@@ -638,6 +643,9 @@ with tab3:
 # --- Tab 4: History Vault ---
 with tab4:
     st.subheader("📜 Local History Vault")
+    if "selected_history_id" not in st.session_state:
+        st.session_state.selected_history_id = None
+
     try:
         res = requests.get(f"{API_BASE_URL}/get_history")
         if res.status_code == 200:
@@ -645,32 +653,63 @@ with tab4:
             if not history_data:
                 st.info("No saved records found.")
             else:
-                # Selection table
-                df_history = pd.DataFrame(history_data)[['id', 'filename', 'timestamp']]
-                selected_id = st.selectbox("Select a record to review:", df_history['id'])
-                
-                if selected_id:
-                    record = next(item for item in history_data if item["id"] == selected_id)
-                    detections = json.loads(record['detections'])
+                if st.session_state.selected_history_id is None:
+                    # ListView Mode
+                    st.write("### 📋 Record List")
+                    df_history = pd.DataFrame(history_data)[['id', 'filename', 'timestamp', 'inference_ms']]
+                    st.dataframe(df_history, hide_index=True, use_container_width=True)
                     
-                    # Display Interactive Hover View
-                    if os.path.exists(record['archive_path']):
-                        with open(record['archive_path'], "rb") as f:
-                            hist_img = Image.open(f).convert("RGB")
-                            display_interactive_results(hist_img, detections, key=f"hist_{record['id']}")
+                    id_to_select = st.number_input("Enter Record ID to view details:", min_value=int(df_history['id'].min()), max_value=int(df_history['id'].max()), step=1)
+                    if st.button("Deep Dive Analysis", type="primary"):
+                        st.session_state.selected_history_id = id_to_select
+                        st.rerun()
+                else:
+                    # Detail View Mode
+                    record = next((item for item in history_data if item["id"] == st.session_state.selected_history_id), None)
+                    if not record:
+                        st.error("Record not found.")
+                        if st.button("Back to List"):
+                            st.session_state.selected_history_id = None
+                            st.rerun()
+                    else:
+                        if st.button("⬅️ Back to History List"):
+                            st.session_state.selected_history_id = None
+                            st.rerun()
                         
-                        st.write("### 📈 Archived Summary")
+                        st.divider()
+                        st.write(f"## 🔍 Deep Dive: Record #{record['id']} ({record['filename']})")
+                        detections = json.loads(record['detections'])
                         summary = json.loads(record['summary'])
-                        s_col1, s_col2, s_col3 = st.columns(3)
-                        with s_col1:
+                        
+                        # Metrics Row
+                        h_col1, h_col2, h_col3, h_col4 = st.columns(4)
+                        with h_col1:
                             st.metric("Total Bunches", sum(summary.values()))
-                        with s_col2:
+                        with h_col2:
                             st.metric("Healthy (Ripe)", summary.get('Ripe', 0))
-                        with s_col3:
-                            abnormal = summary.get('Abnormal', 0)
-                            st.metric("Abnormal Alerts", abnormal)
-                    else:
-                        st.error(f"Archive file not found: {record['archive_path']}")
+                        with h_col3:
+                            st.metric("Abnormal Alerts", summary.get('Abnormal', 0))
+                        with h_col4:
+                            st.metric("Inference Speed", f"{record.get('inference_ms', 0) or 0:.1f} ms")
+
+                        # Image View
+                        if os.path.exists(record['archive_path']):
+                            with open(record['archive_path'], "rb") as f:
+                                hist_img = Image.open(f).convert("RGB")
+                                display_interactive_results(hist_img, detections, key=f"hist_{record['id']}")
+                        else:
+                            st.error(f"Archive file not found: {record['archive_path']}")
+                        
+                        # Technical Evidence Expander
+                        with st.expander("🛠️ Technical Evidence: Raw Output Tensor"):
+                            raw_data = record.get('raw_array_sample')
+                            if raw_data:
+                                try:
+                                    st.json(json.loads(raw_data))
+                                except:
+                                    st.text(raw_data)
+                            else:
+                                st.info("No raw tensor data available for this record.")
         else:
             st.error(f"Failed to fetch history: {res.text}")
     except Exception as e:

+ 26 - 12
src/api/main.py

@@ -31,6 +31,8 @@ def init_local_db():
             archive_path TEXT,
             detections TEXT,
             summary TEXT,
+            inference_ms REAL,
+            raw_array_sample TEXT,
             timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
         )
     ''')
@@ -90,7 +92,7 @@ class ModelManager:
                     "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                     "box": [float(x1), float(y1), float(x2), float(y2)]
                 })
-        return detections
+        return detections, detections_batch[0, :5].tolist()
 
     def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
         results = self.pt_model(img, conf=conf_threshold, verbose=False)
@@ -105,7 +107,7 @@ class ModelManager:
                 "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                 "box": box.xyxy.tolist()[0]
             })
-        return detections
+        return detections, results[0].boxes.data[:5].tolist()
 
 model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')
 
@@ -155,11 +157,15 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     image_bytes = await file.read()
     img = Image.open(io.BytesIO(image_bytes))
     
+    import time
+    start_time = time.perf_counter()
     # Select Inference Engine
     if model_type == "pytorch":
-        detections = model_manager.run_pytorch_inference(img, current_conf)
+        detections, raw_sample = model_manager.run_pytorch_inference(img, current_conf)
     else:
-        detections = model_manager.run_onnx_inference(img, current_conf)
+        detections, raw_sample = model_manager.run_onnx_inference(img, current_conf)
+    end_time = time.perf_counter()
+    inference_ms = (end_time - start_time) * 1000
     
     # Initialize summary
     summary = {name: 0 for name in model_manager.class_names.values()}
@@ -178,8 +184,8 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
     # Save to SQLite
     conn = sqlite3.connect(DB_PATH)
     cursor = conn.cursor()
-    cursor.execute("INSERT INTO history (filename, archive_path, detections, summary) VALUES (?, ?, ?, ?)",
-                   (file.filename, archive_path, json.dumps(detections), json.dumps(summary)))
+    cursor.execute("INSERT INTO history (filename, archive_path, detections, summary, inference_ms, raw_array_sample) VALUES (?, ?, ?, ?, ?, ?)",
+                   (file.filename, archive_path, json.dumps(detections), json.dumps(summary), inference_ms, json.dumps(raw_sample)))
     conn.commit()
     conn.close()
             
@@ -189,6 +195,8 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
         "total_count": len(detections),
         "industrial_summary": summary,
         "detections": detections,
+        "inference_ms": inference_ms,
+        "raw_array_sample": raw_sample,
         "archive_id": unique_id
     }
 
@@ -240,18 +248,22 @@ async def process_batch(files: List[UploadFile] = File(...), model_type: str = F
                 shutil.copyfileobj(file.file, f_out)
             temp_files.append(path)
 
+            import time
             # 2. Detect
             img = Image.open(path)
-            if model_type == "pytorch":
-                detections = model_manager.run_pytorch_inference(img, current_conf)
-            else:
-                detections = model_manager.run_onnx_inference(img, current_conf)
+            # FORCE PYTORCH for Batch
+            start_time = time.perf_counter()
+            detections, raw_sample = model_manager.run_pytorch_inference(img, current_conf)
+            end_time = time.perf_counter()
+            inference_ms = (end_time - start_time) * 1000
             
             # 3. Process all detections in the image
             for det in detections:
                 batch_results.append({
                     "path": path,
-                    "yolo": det
+                    "yolo": det,
+                    "inference_ms": inference_ms,
+                    "raw_array_sample": raw_sample
                 })
 
 
@@ -269,7 +281,9 @@ async def process_batch(files: List[UploadFile] = File(...), model_type: str = F
         for item in batch_results:
             detailed_detections.append({
                 "filename": os.path.basename(item['path']),
-                "detection": item['yolo']
+                "detection": item['yolo'],
+                "inference_ms": item['inference_ms'],
+                "raw_array_sample": item['raw_array_sample']
             })
 
         try: