Jelajahi Sumber

ONNX can now label

Dr-Swopt 3 hari lalu
induk
melakukan
8192a3c84e
2 mengubah file dengan 99 tambahan dan 59 penghapusan
  1. 86 50
      demo_app.py
  2. 13 9
      src/api/main.py

+ 86 - 50
demo_app.py

@@ -163,22 +163,27 @@ def display_interactive_results(image, detections, key=None):
             opacity=0.3, # Semi-transparent until hover
             mode='lines',
             line=dict(color=color, width=3),
-            name=f"Bunch #{bunch_id}",
+            name=f"ID: #{bunch_id}", # Unified ID Tag
             text=f"<b>ID: #{bunch_id}</b><br>Grade: {det['class']}<br>Score: {det['confidence']:.2f}<br>Alert: {det['is_health_alert']}",
             hoverinfo="text"
         ))
 
     fig.update_layout(width=800, height=600, margin=dict(l=0, r=0, b=0, t=0), showlegend=False)
-    st.plotly_chart(fig, width='stretch', key=key)
+    st.plotly_chart(fig, use_container_width=True, key=key)
 
 def annotate_image(image, detections):
-    """Draws high-visibility boxes and background-shaded labels."""
+    """Draws high-visibility 'Plated Labels' and boxes on the image."""
     from PIL import ImageDraw, ImageFont
     draw = ImageDraw.Draw(image)
-    # Dynamic font size based on image resolution
+    
+    # 1. Dynamic Font Scaling (width // 40 as requested)
     font_size = max(20, image.width // 40)
     try:
-        font_path = "C:\\Windows\\Fonts\\arial.ttf"
+        # standard Windows font paths for agent environment
+        font_path = "C:\\Windows\\Fonts\\arialbd.ttf" # Bold for higher visibility
+        if not os.path.exists(font_path):
+            font_path = "C:\\Windows\\Fonts\\arial.ttf"
+            
         if os.path.exists(font_path):
             font = ImageFont.truetype(font_path, font_size)
         else:
@@ -192,20 +197,24 @@ def annotate_image(image, detections):
         conf = det['confidence']
         bunch_id = det.get('bunch_id', '?')
         color = overlay_colors.get(cls, '#ffffff')
+        
+        # 2. Draw Heavy-Duty Bounding Box
+        line_width = max(4, image.width // 150)
+        draw.rectangle(box, outline=color, width=line_width)
 
-        # 1. Draw Bold Bounding Box
-        draw.rectangle(box, outline=color, width=max(4, image.width // 200)) 
-
-        # 2. Draw Label Background (High Contrast)
+        # 3. Draw 'Plated Label' (Background Shaded)
         label = f"#{bunch_id} {cls} {conf:.2f}"
         try:
-            # textbbox provides precise coordinates for background rectangle
-            l, t, r, b = draw.textbbox((box[0], box[1] - font_size - 10), label, font=font)
-            draw.rectangle([l-5, t-5, r+5, b+5], fill=color)
-            draw.text((l, t), label, fill="white", font=font)
+            # Precise background calculation using textbbox
+            l, t, r, b = draw.textbbox((box[0], box[1]), label, font=font)
+            # Shift background up so it doesn't obscure the fruit
+            bg_rect = [l - 2, t - (b - t) - 10, r + 2, t - 6]
+            draw.rectangle(bg_rect, fill=color)
+            # Draw text inside the plate
+            draw.text((l, t - (b - t) - 8), label, fill="white", font=font)
         except:
-            # Fallback for basic text drawing
-            draw.text((box[0], box[1] - 25), label, fill=color)
+            # Simple fallback
+            draw.text((box[0], box[1] - font_size), label, fill=color)
     
     return image
 
@@ -689,6 +698,8 @@ with tab3:
 # --- Tab 4: History Vault ---
 with tab4:
     st.subheader("📜 Local History Vault")
+    st.caption("Industrial-grade audit log of all past AI harvest scans.")
+    
     if "selected_history_id" not in st.session_state:
         st.session_state.selected_history_id = None
 
@@ -697,74 +708,99 @@ with tab4:
         if res.status_code == 200:
             history_data = res.json().get("history", [])
             if not history_data:
-                st.info("No saved records found.")
+                st.info("No saved records found in the vault.")
             else:
                 if st.session_state.selected_history_id is None:
-                    # ListView Mode
-                    st.write("### 📋 Record List")
-                    df_history = pd.DataFrame(history_data)[['id', 'filename', 'timestamp', 'inference_ms']]
-                    st.dataframe(df_history, hide_index=True, width='stretch')
+                    # --- 1. ListView Mode (Management Dashboard) ---
+                    st.write("### 📋 Audit Log")
+                    
+                    # Prepare searchable dataframe
+                    df_history = pd.DataFrame(history_data)
+                    # Clean up for display
+                    display_df = df_history[['id', 'timestamp', 'filename', 'inference_ms']].copy()
+                    display_df.columns = ['ID', 'Date/Time', 'Filename', 'Inference (ms)']
+                    
+                    st.dataframe(
+                        display_df, 
+                        hide_index=True, 
+                        use_container_width=True,
+                        column_config={
+                            "ID": st.column_config.NumberColumn(width="small"),
+                            "Inference (ms)": st.column_config.NumberColumn(format="%.1f ms")
+                        }
+                    )
                     
-                    id_to_select = st.number_input("Enter Record ID to view details:", min_value=int(df_history['id'].min()), max_value=int(df_history['id'].max()), step=1)
-                    if st.button("Deep Dive Analysis", type="primary"):
-                        st.session_state.selected_history_id = id_to_select
-                        st.rerun()
+                    # Industrial Selection UI
+                    hist_col1, hist_col2 = st.columns([3, 1])
+                    with hist_col1:
+                        target_id = st.selectbox(
+                            "Select Record for Deep Dive Analysis",
+                            options=df_history['id'].tolist(),
+                            format_func=lambda x: f"Record #{x} - {df_history[df_history['id']==x]['filename'].values[0]}"
+                        )
+                    with hist_col2:
+                        st.write("##") # Alignment
+                        if st.button("🔬 Start Deep Dive", type="primary", use_container_width=True):
+                            st.session_state.selected_history_id = target_id
+                            st.rerun()
                 else:
-                    # Detail View Mode
+                    # --- 2. Detail View Mode (Technical Auditor) ---
                     record = next((item for item in history_data if item["id"] == st.session_state.selected_history_id), None)
                     if not record:
-                        st.error("Record not found.")
+                        st.error("Audit record not found.")
                         if st.button("Back to List"):
                             st.session_state.selected_history_id = None
                             st.rerun()
                     else:
-                        if st.button("⬅️ Back to History List"):
-                            st.session_state.selected_history_id = None
-                            st.rerun()
+                        st.button("⬅️ Back to Audit Log", on_click=lambda: st.session_state.update({"selected_history_id": None}))
                         
                         st.divider()
-                        st.write(f"## 🔍 Deep Dive: Record #{record['id']} ({record['filename']})")
+                        st.write(f"## 🔍 Deep Dive: Record #{record['id']}")
+                        st.caption(f"Original Filename: `{record['filename']}` | Processed: `{record['timestamp']}`")
+                        
                         detections = json.loads(record['detections'])
                         summary = json.loads(record['summary'])
                         
-                        # Metrics Row
+                        # Metrics Executive Summary
                         h_col1, h_col2, h_col3, h_col4 = st.columns(4)
                         with h_col1:
                             st.metric("Total Bunches", sum(summary.values()))
                         with h_col2:
                             st.metric("Healthy (Ripe)", summary.get('Ripe', 0))
                         with h_col3:
-                            st.metric("Inference Speed", f"{record.get('inference_ms', 0) or 0:.1f} ms", help="Raw model speed")
+                            st.metric("Engine Performance", f"{record.get('inference_ms', 0) or 0:.1f} ms")
                         with h_col4:
-                            st.metric("Post-Processing", f"{record.get('processing_ms', 0) or 0:.1f} ms", help="Labeling overhead")
+                            st.metric("Labeling Overhead", f"{record.get('processing_ms', 0) or 0:.1f} ms")
 
-                        # Image View
+                        # Re-Annotate Archived Image
                         if os.path.exists(record['archive_path']):
                             with open(record['archive_path'], "rb") as f:
                                 hist_img = Image.open(f).convert("RGB")
-                                display_interactive_results(hist_img, detections, key=f"hist_{record['id']}")
+                                
+                            # Side-by-Side: Interactive vs Static Plate
+                            v_tab1, v_tab2 = st.tabs(["Interactive Plotly View", "Static Annotated Evidence"])
+                            with v_tab1:
+                                display_interactive_results(hist_img, detections, key=f"hist_plotly_{record['id']}")
+                            with v_tab2:
+                                img_plate = annotate_image(hist_img.copy(), detections)
+                                st.image(img_plate, use_container_width=True, caption="Point-of-Harvest AI Interpretation")
                         else:
-                            st.error(f"Archive file not found: {record['archive_path']}")
+                            st.warning(f"Technical Error: Archive file missing at `{record['archive_path']}`")
                         
-                        # Technical Evidence Expander
-                        col_hist_tech1, col_hist_tech2 = st.columns([4, 1])
-                        with col_hist_tech1:
-                            st.write("#### 🛠️ Technical Evidence")
-                        with col_hist_tech2:
-                            if st.button("❓ Guide", key="guide_hist"):
-                                show_tech_guide()
-
-                        with st.expander("Raw Output Tensor (Archive)", expanded=False):
-                            st.caption("See the Interpretation Guide for a breakdown of these numbers.")
+                        # Technical Evidence Expander (Mathematical Audit)
+                        st.divider()
+                        st.write("### 🛠️ Technical Audit Trail")
+                        with st.expander("🔬 View Raw Mathematical Tensor", expanded=False):
+                            st.info("This is the exact numerical output from the AI engine prior to human-readable transformation.")
                             raw_data = record.get('raw_tensor')
                             if raw_data:
                                 try:
                                     st.json(json.loads(raw_data))
                                 except:
-                                    st.text(raw_data)
+                                    st.code(raw_data)
                             else:
-                                st.info("No raw tensor data available for this record.")
+                                st.warning("No raw tensor trace was archived for this legacy record.")
         else:
-            st.error(f"Failed to fetch history: {res.text}")
+            st.error(f"Vault Connection Failed: {res.text}")
     except Exception as e:
-        st.error(f"Error loading history: {str(e)}")
+        st.error(f"Audit System Error: {str(e)}")

+ 13 - 9
src/api/main.py

@@ -77,9 +77,9 @@ class ModelManager:
         end_inf = time.perf_counter()
         inference_ms = (end_inf - start_inf) * 1000
         
+        # ONNX Output: [batch, num_boxes, 6] (Where 6: x1, y1, x2, y2, conf, cls)
+        # Note: YOLOv8 endpoints often produce normalized coordinates (0.0 to 1.0)
         detections_batch = outputs[0]
-        scale_w = orig_w / 640.0
-        scale_h = orig_h / 640.0
         
         detections = []
         valid_count = 0
@@ -88,8 +88,13 @@ class ModelManager:
             conf = float(det[4])
             if conf >= conf_threshold:
                 valid_count += 1
+                # 1. Coordinate Scaling: Convert normalized (0.0-1.0) to absolute pixels
                 x1, y1, x2, y2 = det[:4]
-                x1 *= scale_w; y1 *= scale_h; x2 *= scale_w; y2 *= scale_h
+                abs_x1 = x1 * orig_w
+                abs_y1 = y1 * orig_h
+                abs_x2 = x2 * orig_w
+                abs_y2 = y2 * orig_h
+                
                 class_id = int(det[5])
                 class_name = self.class_names.get(class_id, "Unknown")
                 
@@ -98,11 +103,12 @@ class ModelManager:
                     "class": class_name,
                     "confidence": round(conf, 2),
                     "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
-                    "box": [float(x1), float(y1), float(x2), float(y2)]
+                    "box": [float(abs_x1), float(abs_y1), float(abs_x2), float(abs_y2)]
                 })
-        print(detections)
-        print(detections_batch)
-        return detections, detections_batch[0, :5].tolist(), inference_ms
+        
+        # Capture a raw tensor sample (first 5 detections) for technical evidence
+        raw_sample = detections_batch[0, :5].tolist()
+        return detections, raw_sample, inference_ms
 
     def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
         import time
@@ -125,8 +131,6 @@ class ModelManager:
         
         # Extract snippet from results (simulating raw output)
         raw_snippet = results[0].boxes.data[:5].tolist() if len(results[0].boxes) > 0 else []
-        print(detections)
-        print(raw_snippet)
         return detections, raw_snippet, inference_ms
 
 model_manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt')