Jelajahi Sumber

enhancements for different grades

Dr-Swopt 2 minggu lalu
induk
melakukan
f486848971
12 mengubah file dengan 194 tambahan dan 50 penghapusan
  1. 2 1
      .gitignore
  2. 1 1
      README.md
  3. TEMPAT SAMPAH
      best.onnx
  4. TEMPAT SAMPAH
      best.pt
  5. 119 38
      demo_app.py
  6. TEMPAT SAMPAH
      last.pt
  7. TEMPAT SAMPAH
      requirements.txt
  8. 27 8
      src/api/main.py
  9. 4 2
      src/application/analyze_bunch.py
  10. 1 0
      src/domain/models.py
  11. 7 0
      src/infrastructure/repository.py
  12. 33 0
      test_model.py

+ 2 - 1
.gitignore

@@ -40,4 +40,5 @@ Thumbs.db
 
 unified_dataset
 datasets
-runs
+runs
+best_saved_model

+ 1 - 1
README.md

@@ -45,7 +45,7 @@ pip install -r requirements.txt
 2. Extract into `/datasets`.
 3. **Train the model:**
 ```bash
-python train_script.py
+python train_p.py
 
 ```
 

TEMPAT SAMPAH
best.onnx


TEMPAT SAMPAH
best.pt


+ 119 - 38
demo_app.py

@@ -1,8 +1,12 @@
 import streamlit as st
 import requests
+from ultralytics import YOLO
+import numpy as np
 from PIL import Image
 import io
 import base64
+import pandas as pd
+import plotly.express as px
 
 # --- 1. Global Backend Check ---
 API_BASE_URL = "http://localhost:8000"
@@ -16,6 +20,13 @@ def check_backend():
 
 backend_active = check_backend()
 
+# Load YOLO model locally for Analytical View
+@st.cache_resource
+def load_yolo():
+    return YOLO('best.pt')
+
+yolo_model = load_yolo()
+
 if not backend_active:
     st.error("⚠️ Backend API is offline!")
     st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
@@ -69,7 +80,7 @@ with tab1:
         # 1. Action Button (Centered and Prominent)
         st.write("##")
         _, col_btn, _ = st.columns([1, 2, 1])
-        if col_btn.button("🔍 Run Ripeness Detection", type="primary", use_container_width=True):
+        if col_btn.button("🔍 Run Ripeness Detection", type="primary", width='stretch'):
             with st.spinner("Processing Detections Locally..."):
                 files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                 res = requests.post(f"{API_BASE_URL}/analyze", files=files)
@@ -79,44 +90,89 @@ with tab1:
                 else:
                     st.error(f"Detection Failed: {res.text}")
 
-        # 2. Results Layout
-        if st.session_state.last_detection:
-            st.divider()
-            col1, col2 = st.columns([1.5, 1])
-            
-            with col1:
-                st.image(uploaded_file, caption="Analyzed Image", use_container_width=True)
-            
-            with col2:
-                data = st.session_state.last_detection
-                with st.container(border=True):
-                    st.write("### 🏷️ Detection Results")
-                    if not data['detections']:
-                        st.warning("No Fresh Fruit Bunches detected.")
-                    else:
-                        for det in data['detections']:
-                            st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
-                        
-                        # 3. Cloud Actions (Only if detections found)
-                        st.write("---")
-                        st.write("#### ✨ Cloud Archive")
-                        if st.button("🚀 Save to Atlas (Vectorize)", use_container_width=True):
-                            with st.spinner("Archiving..."):
-                                import json
-                                primary_det = data['detections'][0]
-                                payload = {"detection_data": json.dumps(primary_det)}
-                                files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
-                                
-                                res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
-                                
-                                if res_cloud.status_code == 200:
-                                    res_json = res_cloud.json()
-                                    if res_json["status"] == "success":
-                                        st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
+            # 2. Results Layout
+            if st.session_state.last_detection:
+                st.divider()
+                
+                # SIDE-BY-SIDE ANALYTICAL VIEW
+                col_left, col_right = st.columns(2)
+                
+                with col_left:
+                    st.image(uploaded_file, caption="Original Photo", width='stretch')
+                
+                with col_right:
+                    # Use the local model to plot the boxes directly
+                    img = Image.open(uploaded_file)
+                    results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
+                    annotated_img = results[0].plot() # Draws boxes/labels
+                    
+                    # Convert BGR (OpenCV format) to RGB for Streamlit
+                    annotated_img_rgb = annotated_img[:, :, ::-1] 
+                    st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')
+
+                col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
+                
+                with col2:
+                    data = st.session_state.last_detection
+                    with st.container(border=True):
+                        st.write("### 🏷️ Detection Results")
+                        if not data['detections']:
+                            st.warning("No Fresh Fruit Bunches detected.")
+                        else:
+                            for det in data['detections']:
+                                st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
+                            
+                            st.write("### 📊 Harvest Quality Mix")
+                            # Convert industrial_summary dictionary to a DataFrame for charting
+                            summary_df = pd.DataFrame(
+                                list(data['industrial_summary'].items()), 
+                                columns=['Grade', 'Count']
+                            )
+                            # Filter out classes with 0 count for a cleaner chart
+                            summary_df = summary_df[summary_df['Count'] > 0]
+                            
+                            if not summary_df.empty:
+                                # Create a Pie Chart to show the proportion of each grade
+                                fig = px.pie(summary_df, values='Count', names='Grade', 
+                                             color='Grade',
+                                             color_discrete_map={
+                                                 'Abnormal': '#ef4444', # Red
+                                                 'Empty_Bunch': '#94a3b8', # Gray
+                                                 'Ripe': '#22c55e', # Green
+                                                 'Underripe': '#eab308', # Yellow
+                                                 'Unripe': '#3b82f6', # Blue
+                                                 'Overripe': '#a855f7' # Purple
+                                             },
+                                             hole=0.4)
+                                fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
+                                st.plotly_chart(fig, width='stretch')
+                            
+                            # High-Priority Health Alert
+                            if data['industrial_summary'].get('Abnormal', 0) > 0:
+                                st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
+                            if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
+                                st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
+                            
+                            # 3. Cloud Actions (Only if detections found)
+                            st.write("---")
+                            st.write("#### ✨ Cloud Archive")
+                            if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
+                                with st.spinner("Archiving..."):
+                                    import json
+                                    primary_det = data['detections'][0]
+                                    payload = {"detection_data": json.dumps(primary_det)}
+                                    files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
+                                    
+                                    res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
+                                    
+                                    if res_cloud.status_code == 200:
+                                        res_json = res_cloud.json()
+                                        if res_json["status"] == "success":
+                                            st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
+                                        else:
+                                            st.error(f"Cloud Error: {res_json['message']}")
                                     else:
-                                        st.error(f"Cloud Error: {res_json['message']}")
-                                else:
-                                    st.error("Failed to connect to cloud service")
+                                        st.error("Failed to connect to cloud service")
 
 # --- Tab 2: Batch Processing ---
 with tab2:
@@ -133,6 +189,31 @@ with tab2:
         res_data = st.session_state.last_batch_results
         with st.container(border=True):
             st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
+            
+            # Batch Summary Dashboard
+            st.write("### 📈 Batch Quality Overview")
+            batch_summary = res_data.get('industrial_summary', {})
+            if batch_summary:
+                sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
+                sum_df = sum_df[sum_df['Count'] > 0]
+                
+                b_col1, b_col2 = st.columns([1, 1])
+                with b_col1:
+                    st.dataframe(sum_df, hide_index=True, width='stretch')
+                with b_col2:
+                    if not sum_df.empty:
+                        fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
+                                          color_discrete_map={
+                                             'Abnormal': '#ef4444', 
+                                             'Empty_Bunch': '#94a3b8', 
+                                             'Ripe': '#22c55e'
+                                          })
+                        fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
+                        st.plotly_chart(fig_batch, width='stretch')
+
+            if batch_summary.get('Abnormal', 0) > 0:
+                st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
+
             st.write("Generated Record IDs:")
             st.code(res_data['record_ids'])
             if st.button("Clear Results & Start New Batch"):

TEMPAT SAMPAH
last.pt


TEMPAT SAMPAH
requirements.txt


+ 27 - 8
src/api/main.py

@@ -36,6 +36,7 @@ repo = MongoPalmOilRepository(
     db_name=os.getenv("DB_NAME", "palm_oil_db"),
     collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
 )
+repo.ensure_indexes()
 analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
 analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
 search_use_case = SearchSimilarUseCase(vision_service, repo)
@@ -62,24 +63,34 @@ async def set_confidence(threshold: float = Body(..., embed=True)):
 
 
 @app.post("/analyze")
-async def analyze_only(file: UploadFile = File(...)):
-    """Local YOLO detection only. Guaranteed to work without Billing."""
+async def analyze_with_health_metrics(file: UploadFile = File(...)):
+    """Industry-grade analysis with health metrics and summary."""
     image_bytes = await file.read()
     img = Image.open(io.BytesIO(image_bytes))
-    results = yolo_model(img, conf=current_conf)
+    
+    # Run yolov8 detection with agnostic NMS to merge overlapping detections
+    results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
     
     detections = []
+    # Initialize summary for all classes known by the model
+    summary = {name: 0 for name in yolo_model.names.values()}
+    
     for r in results:
         for box in r.boxes:
+            class_name = yolo_model.names[int(box.cls)]
+            summary[class_name] += 1
+            
             detections.append({
-                "class": yolo_model.names[int(box.cls)],
+                "class": class_name,
                 "confidence": round(float(box.conf), 2),
+                "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                 "box": box.xyxy.tolist()[0]
             })
             
     return {
-        "status": "success", 
+        "status": "success",
         "current_threshold": current_conf,
+        "industrial_summary": summary,
         "detections": detections
     }
 
@@ -130,18 +141,20 @@ async def process_batch(files: List[UploadFile] = File(...)):
                 shutil.copyfileobj(file.file, f_out)
             temp_files.append(path)
 
-            # 2. YOLO Detect
+            # 2. YOLO Detect with agnostic NMS
             img = Image.open(path)
-            yolo_res = yolo_model(img, conf=current_conf)
+            yolo_res = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
             
             # 3. Take the primary detection per image
             if yolo_res and yolo_res[0].boxes:
                 box = yolo_res[0].boxes[0]
+                class_name = yolo_model.names[int(box.cls)]
                 batch_results.append({
                     "path": path,
                     "yolo": {
-                        "class": yolo_model.names[int(box.cls)],
+                        "class": class_name,
                         "confidence": float(box.conf),
+                        "is_health_alert": class_name in ["Abnormal", "Empty_Bunch"],
                         "box": box.xyxy.tolist()[0]
                     }
                 })
@@ -149,6 +162,11 @@ async def process_batch(files: List[UploadFile] = File(...)):
         if not batch_results:
             return {"status": "no_detection", "message": "No bunches detected in batch"}
 
+        # Calculate Total Industrial Summary for the Batch
+        total_summary = {name: 0 for name in yolo_model.names.values()}
+        for item in batch_results:
+            total_summary[item['yolo']['class']] += 1
+
         # 4. Process Batch Use Case with error handling for cloud services
         try:
             record_ids = analyze_batch_use_case.execute(batch_results)
@@ -156,6 +174,7 @@ async def process_batch(files: List[UploadFile] = File(...)):
                 "status": "success", 
                 "processed_count": len(record_ids), 
                 "record_ids": record_ids,
+                "industrial_summary": total_summary,
                 "message": f"Successfully processed {len(record_ids)} bunches"
             }
         except RuntimeError as e:

+ 4 - 2
src/application/analyze_bunch.py

@@ -20,7 +20,8 @@ class AnalyzeBunchUseCase:
             confidence=yolo_result['confidence'],
             embedding=vector,
             box=yolo_result['box'],
-            image_data=encoded_string
+            image_data=encoded_string,
+            is_abnormal=yolo_result.get('is_health_alert', False)
         )
 
         # 4. Persist to MongoDB
@@ -50,7 +51,8 @@ class AnalyzeBatchUseCase:
                 confidence=item['yolo']['confidence'],
                 embedding=vector,
                 box=item['yolo']['box'],
-                image_data=img_b64
+                image_data=img_b64,
+                is_abnormal=item['yolo'].get('is_health_alert', False)
             )
             processed_bunches.append(bunch)
 

+ 1 - 0
src/domain/models.py

@@ -9,5 +9,6 @@ class PalmOilBunch:
     embedding: List[float]
     box: List[float]
     image_data: str
+    is_abnormal: bool = False
     timestamp: datetime = field(default_factory=datetime.now)
     id: Optional[str] = None

+ 7 - 0
src/infrastructure/repository.py

@@ -7,6 +7,13 @@ class MongoPalmOilRepository:
         self.client = MongoClient(uri)
         self.collection = self.client[db_name][collection_name]
 
+    def ensure_indexes(self):
+        """Create indexes for health alerts and vector search."""
+        self.collection.create_index("is_abnormal")
+        self.collection.create_index("ripeness_class")
+        self.collection.create_index("timestamp")
+        print("MongoDB Indexes Ensured.")
+
     def get_by_id(self, record_id: str):
         """Retrieve a specific record by its ID."""
         return self.collection.find_one({"_id": ObjectId(record_id)})

+ 33 - 0
test_model.py

@@ -0,0 +1,33 @@
+from ultralytics import YOLO
+import os
+
+def test():
+    # Load the custom trained model
+    model_path = "best.pt"
+    if not os.path.exists(model_path):
+        print(f"Error: {model_path} not found.")
+        return
+
+    model = YOLO(model_path)
+
+    # Path to data.yaml
+    data_yaml = "unified_dataset/data.yaml"
+    
+    # Run validation
+    print(f"Running validation on {data_yaml}...")
+    metrics = model.val(data=data_yaml, split='val')
+    
+    # Print results
+    print("\n--- Validation Results ---")
+    print(f"mAP50: {metrics.results_dict['metrics/mAP50(B)']:.4f}")
+    print(f"mAP50-95: {metrics.results_dict['metrics/mAP50-95(B)']:.4f}")
+    print(f"Fitness: {metrics.fitness:.4f}")
+    
+    # Check if mAP50 > 0.90
+    if metrics.results_dict['metrics/mAP50(B)'] > 0.90:
+        print("\nSUCCESS: mAP50 is greater than 0.90")
+    else:
+        print("\nWARNING: mAP50 is below 0.90")
+
+if __name__ == "__main__":
+    test()