Dr-Swopt 2 днів тому
батько
коміт
0e965994ee
3 змінених файлів з 154 додано та 125 видалено
  1. 1 0
      .gitignore
  2. 69 40
      demo_app.py
  3. 84 85
      src/api/main.py

+ 1 - 0
.gitignore

@@ -37,6 +37,7 @@ Thumbs.db
 unified_dataset
 datasets
 runs
+batch_outputs
 best_saved_model
 history_archive
 palm_history.db

+ 69 - 40
demo_app.py

@@ -84,6 +84,50 @@ def show_tech_guide():
     | **Post-Processing** | None (NMS-Free) | Standard NMS |
     """)
 
+@st.dialog("📋 Batch Metadata Configuration")
+def configure_batch_metadata(uploaded_files):
+    st.write(f"Preparing to process **{len(uploaded_files)}** images.")
+    
+    col1, col2 = st.columns(2)
+    with col1:
+        estate = st.text_input("Estate / Venue", value="Estate Alpha")
+        block = st.text_input("Block ID", placeholder="e.g., B12")
+    with col2:
+        harvester = st.text_input("Harvester ID / Name")
+        priority = st.selectbox("Job Priority", ["Normal", "High", "Urgent"])
+
+    if st.button("🚀 Start Production Batch", type="primary", width='stretch'):
+        metadata = {
+            "estate": estate,
+            "block": block,
+            "harvester": harvester,
+            "priority": priority
+        }
+        
+        with st.spinner("Building Production Bundle..."):
+            files_payload = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
+            # Use engine_choice from session state to get the correct model_type
+            engine_map_rev = {
+                "YOLO26 (ONNX - High Speed)": "onnx",
+                "YOLO26 (PyTorch - Native)": "pytorch",
+                "YOLOv8-Sawit (Benchmark)": "yolov8_sawit"
+            }
+            selected_engine = st.session_state.get('engine_choice', "YOLO26 (ONNX - High Speed)")
+            data_payload = {
+                "model_type": engine_map_rev.get(selected_engine, "onnx"),
+                "metadata": json.dumps(metadata)
+            }
+            
+            try:
+                res = requests.post(f"{API_BASE_URL}/process_batch", files=files_payload, data=data_payload)
+                if res.status_code == 200:
+                    st.session_state.last_batch_results = res.json()
+                    st.rerun()
+                else:
+                    st.error(f"Batch Hand-off Failed: {res.text}")
+            except Exception as e:
+                st.error(f"Connection Error: {e}")
+
 # --- 1. Global Backend Check ---
 API_BASE_URL = "http://localhost:8000"
 
@@ -630,7 +674,31 @@ with tab2:
     if "last_batch_results" not in st.session_state:
         st.session_state.last_batch_results = None
 
-    # 2. Display Persisted Results (if any)
+    # 2. Uploader UI (Must be at top to avoid NameError during result persistence)
+    col_batch1, col_batch2 = st.columns([4, 1])
+    with col_batch1:
+        uploaded_files = st.file_uploader(
+            "Upload multiple images...", 
+            type=["jpg", "jpeg", "png"], 
+            accept_multiple_files=True, 
+            key=f"batch_{st.session_state.batch_uploader_key}",
+            on_change=reset_batch_results
+        )
+    
+    with col_batch2:
+        st.write("##") # Alignment
+        if st.session_state.last_batch_results is None and uploaded_files:
+            if st.button("🔍 Configure & Process Batch", type="primary", width='stretch'):
+                configure_batch_metadata(uploaded_files)
+
+        if st.button("🗑️ Reset Uploader"):
+            st.session_state.batch_uploader_key += 1
+            st.session_state.last_batch_results = None
+            st.rerun()
+
+    st.divider()
+
+    # 3. Display Persisted Results (if any)
     if st.session_state.last_batch_results:
         res_data = st.session_state.last_batch_results
         with st.container(border=True):
@@ -703,45 +771,6 @@ with tab2:
 
         st.divider()
 
-    # 3. Uploader UI
-    col_batch1, col_batch2 = st.columns([4, 1])
-    with col_batch1:
-        uploaded_files = st.file_uploader(
-            "Upload multiple images...", 
-            type=["jpg", "jpeg", "png"], 
-            accept_multiple_files=True, 
-            key=f"batch_{st.session_state.batch_uploader_key}",
-            on_change=reset_batch_results
-        )
-    
-    with col_batch2:
-        st.write("##") # Alignment
-        if st.session_state.last_batch_results is None and uploaded_files:
-            if st.button("🔍 Process Batch", type="primary", width='stretch'):
-                with st.spinner(f"Analyzing {len(uploaded_files)} images with {model_type.upper()}..."):
-                    files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
-                    payload = {"model_type": model_type}
-                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files, data=payload)
-                    
-                    if res.status_code == 200:
-                        data = res.json()
-                        if data["status"] == "success":
-                            st.session_state.last_batch_results = data
-                            st.session_state.batch_uploader_key += 1
-                            st.rerun()
-                        elif data["status"] == "partial_success":
-                            st.warning(data["message"])
-                            st.info(f"Successfully detected {data['detections_count']} bunches locally.")
-                        else:
-                            st.error(f"Batch Error: {data['message']}")
-                    else:
-                        st.error(f"Batch Processing Failed: {res.text}")
-
-        if st.button("🗑️ Reset Uploader"):
-            st.session_state.batch_uploader_key += 1
-            st.session_state.last_batch_results = None
-            st.rerun()
-
 # --- Tab 3: Similarity Search ---
 with tab3:
     st.subheader("Hybrid Semantic Search")

+ 84 - 85
src/api/main.py

@@ -16,6 +16,8 @@ from src.infrastructure.repository import MongoPalmOilRepository
 from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
 import sqlite3
 import json
+import pandas as pd
+from datetime import datetime
 
 DB_PATH = "palm_history.db"
 ARCHIVE_DIR = "history_archive"
@@ -291,95 +293,92 @@ async def vectorize_and_store(file: UploadFile = File(...), detection_data: str
             os.remove(temp_path)
 
 @app.post("/process_batch")
-async def process_batch(files: List[UploadFile] = File(...), model_type: str = Form("onnx")):
-    """Handles multiple images: Detect -> Vectorize -> Store."""
-    if not db_connected:
-        # We could still do detection locally, but the prompt says 'Detect -> Vectorize -> Store'
-        # For simplicity in this demo, we'll block it if DB is offline.
-        return {"status": "error", "message": "Batch Processing (Cloud Archival) is currently unavailable (Database Offline)."}
-    batch_results = []
-    temp_files = []
-
-    try:
-        for file in files:
-            # 1. Save Temp
-            unique_id = uuid.uuid4().hex
-            path = f"temp_batch_{unique_id}_{file.filename}"
-            with open(path, "wb") as f_out:
-                shutil.copyfileobj(file.file, f_out)
-            temp_files.append(path)
-
-            import time
-            # 2. Detect
-            img = Image.open(path)
-            # FORCE PYTORCH for Batch
-            start_total = time.perf_counter()
-            detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
-            end_total = time.perf_counter()
-            
-            total_ms = (end_total - start_total) * 1000
-            processing_ms = total_ms - inference_ms
-            
-            # 3. Process all detections in the image
-            for det in detections:
-                batch_results.append({
-                    "path": path,
-                    "yolo": det,
-                    "engine": model_type, # Track engine
-                    "inference_ms": inference_ms,
-                    "raw_array_sample": raw_sample
-                })
-
-
-        if not batch_results:
-            return {"status": "no_detection", "message": "No bunches detected in batch"}
-
-        # Calculate Total Industrial Summary for the Batch
-        total_summary = {name: 0 for name in class_names.values()}
-        for item in batch_results:
-            total_summary[item['yolo']['class']] += 1
-
-
-        # 4. Process Batch Use Case with error handling for cloud services
-        detailed_detections = []
-        for item in batch_results:
-            detailed_detections.append({
-                "filename": os.path.basename(item['path']),
-                "detection": item['yolo'],
-                "inference_ms": item['inference_ms'],
-                "raw_array_sample": item['raw_array_sample']
+async def process_batch(
+    files: List[UploadFile] = File(...), 
+    model_type: str = Form("onnx"),
+    metadata: str = Form("{}") # JSON-encoded metadata object sent by the frontend dialog
+):
+    batch_id = f"BATCH_{uuid.uuid4().hex[:8].upper()}"
+    output_dir = os.path.join("batch_outputs", batch_id)
+    os.makedirs(os.path.join(output_dir, "raw"), exist_ok=True)
+    
+    meta_dict = json.loads(metadata)
+    batch_records = []
+    
+    for file in files:
+        unique_id = uuid.uuid4().hex[:6]
+        filename = f"{unique_id}_{file.filename}"
+        save_path = os.path.join(output_dir, "raw", filename)
+        
+        # 1. Save Raw Image to Bundle
+        image_bytes = await file.read()
+        with open(save_path, "wb") as f:
+            f.write(image_bytes)
+        
+        # 2. Run Inference
+        img = Image.open(io.BytesIO(image_bytes))
+        
+        # Selection logic based on existing API pattern
+        if model_type == "pytorch":
+            detections, raw_sample, inf_ms = model_manager.run_pytorch_inference(img, current_conf, "pytorch")
+        elif model_type == "yolov8_sawit":
+            detections, raw_sample, inf_ms = model_manager.run_pytorch_inference(img, current_conf, "yolov8_sawit")
+        else:
+            detections, raw_sample, inf_ms = model_manager.run_onnx_inference(img, current_conf)
+        
+        # 3. Normalize Coordinates for the Contract
+        # Downstream consumers should not need to know the original input resolution
+        w, h = img.size
+        normalized_dets = []
+        for d in detections:
+            x1, y1, x2, y2 = d['box']
+            normalized_dets.append({
+                **d,
+                "norm_box": [x1/w, y1/h, x2/w, y2/h] 
             })
 
-        try:
-            record_ids = analyze_batch_use_case.execute(batch_results)
-            total_records = len(record_ids)
-            return {
-                "status": "success", 
-                "processed_count": total_records, 
-                "total_count": sum(total_summary.values()),
-                "record_ids": record_ids,
-                "industrial_summary": total_summary,
-                "detailed_results": detailed_detections,
-                "message": f"Successfully processed {total_records} images and identified {sum(total_summary.values())} bunches"
-            }
-
-        except RuntimeError as e:
-            return {
-                "status": "partial_success",
-                "message": f"Detections completed, but cloud archival failed: {str(e)}",
-                "detections_count": len(batch_results),
-                "detailed_results": detailed_detections
-            }
-
+        batch_records.append({
+            "image_id": unique_id,
+            "filename": filename,
+            "detections": normalized_dets,
+            "inference_ms": inf_ms
+        })
+
+    # 4. Generate the Manifest (The Contract)
+    manifest = {
+        "job_id": batch_id,
+        "timestamp": datetime.now().isoformat(),
+        "source_context": meta_dict,
+        "engine": {
+            "name": "YOLO26",
+            "type": model_type,
+            "threshold": current_conf
+        },
+        "inventory": batch_records
+    }
+    
+    with open(os.path.join(output_dir, "manifest.json"), "w") as f:
+        json.dump(manifest, f, indent=4)
 
-    except Exception as e:
-        return {"status": "error", "message": f"Batch processing failed: {str(e)}"}
+    # Note: Maintaining compatibility with the frontend's expectation of 'industrial_summary'
+    # and 'processed_count' for immediate UI feedback.
+    active_names = model_manager.class_names if model_type != "yolov8_sawit" else model_manager.benchmark_class_names
+    total_summary = {name: 0 for name in active_names.values()}
+    for record in batch_records:
+        for det in record['detections']:
+            total_summary[det['class']] += 1
 
-    finally:
-        # 5. Clean up all temp files
-        for path in temp_files:
-            if os.path.exists(path):
-                os.remove(path)
+    return {
+        "status": "success",
+        "batch_id": batch_id,
+        "bundle_path": output_dir,
+        "processed_count": len(files),
+        "total_count": sum(total_summary.values()),
+        "industrial_summary": total_summary,
+        "record_ids": [r['image_id'] for r in batch_records], # Backward compatibility
+        "manifest_preview": manifest,
+        "detailed_results": [{"filename": r['filename'], "detection": d} for r in batch_records for d in r['detections']] # Backward compatibility
+    }
 @app.post("/search_hybrid")
 async def search_hybrid(
     file: Optional[UploadFile] = File(None),