Quellcode durchsuchen

Further streamlit UI enhancements

Dr-Swopt vor 2 Wochen
Ursprung
Commit
68dd501a74
2 geänderte Dateien mit 151 neuen und 114 gelöschten Zeilen
  1. 127 114
      demo_app.py
  2. 24 0
      export_mobile.py

+ 127 - 114
demo_app.py

@@ -89,114 +89,126 @@ with tab1:
         if "last_detection" not in st.session_state:
             st.session_state.last_detection = None
 
-        # 1. Action Button (Centered and Prominent)
-        st.write("##")
-        _, col_btn, _ = st.columns([1, 2, 1])
-        if col_btn.button("🔍 Run Ripeness Detection", type="primary", width='stretch'):
+        # 1. Auto-Detection Trigger
+        if uploaded_file and st.session_state.last_detection is None:
             with st.spinner("Processing Detections Locally..."):
                 files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                 res = requests.post(f"{API_BASE_URL}/analyze", files=files)
-                
                 if res.status_code == 200:
                     st.session_state.last_detection = res.json()
+                    st.rerun() # Refresh to show results immediately
                 else:
                     st.error(f"Detection Failed: {res.text}")
 
-            # 2. Results Layout
-            if st.session_state.last_detection:
-                st.divider()
-                
-                # SIDE-BY-SIDE ANALYTICAL VIEW
-                col_left, col_right = st.columns(2)
-                
-                # Fetch data once
-                data = st.session_state.last_detection
+        # 2. Results Layout
+        if st.session_state.last_detection:
+            st.divider()
+            
+            # SIDE-BY-SIDE ANALYTICAL VIEW
+            col_left, col_right = st.columns(2)
+            
+            # Fetch data once
+            data = st.session_state.last_detection
 
-                with col_left:
-                    st.image(uploaded_file, caption="Original Photo", width='stretch')
+            with col_left:
+                st.image(uploaded_file, caption="Original Photo", width='stretch')
+            
+            with col_right:
+                # Use the local model to plot the boxes directly
+                img = Image.open(uploaded_file)
+                results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
+                annotated_img = results[0].plot() # Draws boxes/labels
                 
-                with col_right:
-                    # Use the local model to plot the boxes directly
-                    img = Image.open(uploaded_file)
-                    results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
-                    annotated_img = results[0].plot() # Draws boxes/labels
-                    
-                    # Convert BGR (OpenCV format) to RGB for Streamlit
-                    annotated_img_rgb = annotated_img[:, :, ::-1] 
-                    st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')
+                # Convert BGR (OpenCV format) to RGB for Streamlit
+                annotated_img_rgb = annotated_img[:, :, ::-1] 
+                st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')
 
-                st.write("### 📈 Manager's Dashboard")
-                m_col1, m_col2, m_col3 = st.columns(3)
-                with m_col1:
-                    st.metric("Total Bunches", data.get('total_count', 0))
-                with m_col2:
-                    st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
-                with m_col3:
-                    abnormal = data['industrial_summary'].get('Abnormal', 0)
-                    st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
+            st.write("### 📈 Manager's Dashboard")
+            m_col1, m_col2, m_col3 = st.columns(3)
+            with m_col1:
+                st.metric("Total Bunches", data.get('total_count', 0))
+            with m_col2:
+                st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
+            with m_col3:
+                abnormal = data['industrial_summary'].get('Abnormal', 0)
+                st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
 
-                col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
-                
-                with col2:
-                    with st.container(border=True):
-                        st.write("### 🏷️ Detection Results")
-                        if not data['detections']:
-                            st.warning("No Fresh Fruit Bunches detected.")
+            col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
+            
+            with col2:
+                with st.container(border=True):
+                    st.write("### 🏷️ Detection Results")
+                    if not data['detections']:
+                        st.warning("No Fresh Fruit Bunches detected.")
+                    else:
+                        for det in data['detections']:
+                            st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
+                        
+                        st.write("### 📊 Harvest Quality Mix")
+                        # Convert industrial_summary dictionary to a DataFrame for charting
+                        summary_df = pd.DataFrame(
+                            list(data['industrial_summary'].items()), 
+                            columns=['Grade', 'Count']
+                        )
+                        # Filter out classes with 0 count for a cleaner chart
+                        summary_df = summary_df[summary_df['Count'] > 0]
+                        if not summary_df.empty:
+                            # Create a Pie Chart to show the proportion of each grade
+                            fig = px.pie(summary_df, values='Count', names='Grade', 
+                                         color='Grade',
+                                         color_discrete_map={
+                                             'Abnormal': '#ef4444', # Red
+                                             'Empty_Bunch': '#94a3b8', # Gray
+                                             'Ripe': '#22c55e', # Green
+                                             'Underripe': '#eab308', # Yellow
+                                             'Unripe': '#3b82f6', # Blue
+                                             'Overripe': '#a855f7' # Purple
+                                         },
+                                         hole=0.4)
+                            fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
+                            st.plotly_chart(fig, width='stretch')
+
+                        # 💡 Strategic R&D Insight: Harvest Efficiency
+                        st.write("---")
+                        st.write("#### 💡 Strategic R&D Insight")
+                        unripe_count = data['industrial_summary'].get('Unripe', 0)
+                        underripe_count = data['industrial_summary'].get('Underripe', 0)
+                        total_non_prime = unripe_count + underripe_count
+                        
+                        st.write(f"🌑 **Unripe (Mentah):** {unripe_count}")
+                        st.write(f"🌗 **Underripe (Kurang Masak):** {underripe_count}")
+                        
+                        if total_non_prime > 0:
+                            st.warning(f"🚨 **Potential Yield Loss:** {total_non_prime} bunches harvested too early. This will reduce OER (Oil Extraction Rate).")
                         else:
-                            for det in data['detections']:
-                                st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
-                            
-                            st.write("### 📊 Harvest Quality Mix")
-                            # Convert industrial_summary dictionary to a DataFrame for charting
-                            summary_df = pd.DataFrame(
-                                list(data['industrial_summary'].items()), 
-                                columns=['Grade', 'Count']
-                            )
-                            # Filter out classes with 0 count for a cleaner chart
-                            summary_df = summary_df[summary_df['Count'] > 0]
-                            
-                            if not summary_df.empty:
-                                # Create a Pie Chart to show the proportion of each grade
-                                fig = px.pie(summary_df, values='Count', names='Grade', 
-                                             color='Grade',
-                                             color_discrete_map={
-                                                 'Abnormal': '#ef4444', # Red
-                                                 'Empty_Bunch': '#94a3b8', # Gray
-                                                 'Ripe': '#22c55e', # Green
-                                                 'Underripe': '#eab308', # Yellow
-                                                 'Unripe': '#3b82f6', # Blue
-                                                 'Overripe': '#a855f7' # Purple
-                                             },
-                                             hole=0.4)
-                                fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
-                                st.plotly_chart(fig, width='stretch')
-                            
-                            # High-Priority Health Alert
-                            if data['industrial_summary'].get('Abnormal', 0) > 0:
-                                st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
-                            if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
-                                st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
-                            
-                            # 3. Cloud Actions (Only if detections found)
-                            st.write("---")
-                            st.write("#### ✨ Cloud Archive")
-                            if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
-                                with st.spinner("Archiving..."):
-                                    import json
-                                    primary_det = data['detections'][0]
-                                    payload = {"detection_data": json.dumps(primary_det)}
-                                    files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
-                                    
-                                    res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
-                                    
-                                    if res_cloud.status_code == 200:
-                                        res_json = res_cloud.json()
-                                        if res_json["status"] == "success":
-                                            st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
-                                        else:
-                                            st.error(f"Cloud Error: {res_json['message']}")
+                            st.success("✅ **Harvest Efficiency:** 100% Prime Ripeness detected.")
+                        
+                        # High-Priority Health Alert
+                        if data['industrial_summary'].get('Abnormal', 0) > 0:
+                            st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
+                        if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
+                            st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
+                        
+                        # 3. Cloud Actions (Only if detections found)
+                        st.write("---")
+                        st.write("#### ✨ Cloud Archive")
+                        if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
+                            with st.spinner("Archiving..."):
+                                import json
+                                primary_det = data['detections'][0]
+                                payload = {"detection_data": json.dumps(primary_det)}
+                                files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
+                                
+                                res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
+                                
+                                if res_cloud.status_code == 200:
+                                    res_json = res_cloud.json()
+                                    if res_json["status"] == "success":
+                                        st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
                                     else:
-                                        st.error("Failed to connect to cloud service")
+                                        st.error(f"Cloud Error: {res_json['message']}")
+                                else:
+                                    st.error("Failed to connect to cloud service")
 
 # --- Tab 2: Batch Processing ---
 with tab2:
@@ -258,29 +270,30 @@ with tab2:
     
     with col_batch2:
         st.write("##") # Alignment
+        if st.session_state.last_batch_results is None and uploaded_files:
+            if st.button("🔍 Process Batch", type="primary", width='stretch'):
+                with st.spinner(f"Analyzing {len(uploaded_files)} images..."):
+                    files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
+                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
+                    
+                    if res.status_code == 200:
+                        data = res.json()
+                        if data["status"] == "success":
+                            st.session_state.last_batch_results = data
+                            st.session_state.batch_uploader_key += 1
+                            st.rerun()
+                        elif data["status"] == "partial_success":
+                            st.warning(data["message"])
+                            st.info(f"Successfully detected {data['detections_count']} bunches locally.")
+                        else:
+                            st.error(f"Batch Error: {data['message']}")
+                    else:
+                        st.error(f"Batch Processing Failed: {res.text}")
+
         if st.button("🗑️ Reset Uploader"):
             st.session_state.batch_uploader_key += 1
+            st.session_state.last_batch_results = None
             st.rerun()
-    
-    if uploaded_files:
-        if st.button(f"🚀 Process {len(uploaded_files)} Images"):
-            with st.spinner("Batch Processing in progress..."):
-                files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
-                res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
-                
-                if res.status_code == 200:
-                    data = res.json()
-                    if data["status"] == "success":
-                        st.session_state.last_batch_results = data
-                        st.session_state.batch_uploader_key += 1
-                        st.rerun()
-                    elif data["status"] == "partial_success":
-                        st.warning(data["message"])
-                        st.info(f"Successfully detected {data['detections_count']} bunches locally.")
-                    else:
-                        st.error(f"Batch Error: {data['message']}")
-                else:
-                    st.error(f"Batch Failed: {res.text}")
 
 # --- Tab 3: Similarity Search ---
 with tab3:

+ 24 - 0
export_mobile.py

@@ -0,0 +1,24 @@
from ultralytics import YOLO
import os

# Export the trained desktop model to a mobile-friendly TFLite asset.
# Run from the project root; expects 'best.pt' and the calibration dataset yaml.

MODEL_PATH = 'best.pt'
# Dataset yaml used by Ultralytics for int8 quantization calibration.
DATA_YAML = 'unified_dataset/data.yaml'

# 1. Load the high-accuracy PC model.
if not os.path.exists(MODEL_PATH):
    print(f"Error: {MODEL_PATH} not found.")
else:
    model = YOLO(MODEL_PATH)

    # Surface a missing calibration dataset up front instead of failing
    # (or silently falling back to dynamic-range quantization) mid-export.
    if not os.path.exists(DATA_YAML):
        print(f"Warning: {DATA_YAML} not found; int8 export may fall back "
              "to dynamic-range quantization or fail during calibration.")

    # 2. Export to TFLite with NMS and quantization.
    # - int8=True lets the model leverage mobile NPUs.
    # - nms=True bakes overlapping-box suppression natively into the graph.
    # - Ultralytics handles calibration when 'data' is provided; without it,
    #   dynamic-range quantization is used instead.
    model.export(
        format='tflite',
        int8=True,
        nms=True,
        imgsz=640,
        data=DATA_YAML,  # for quantization calibration
    )

    print("Mobile assets generated in: best_saved_model/")