"""Streamlit front-end for the Palm Oil FFB (Fresh Fruit Bunch) management system.

Talks to a FastAPI backend at API_BASE_URL for detection, batch processing,
cloud archival, and hybrid (image / text) similarity search, and runs a local
YOLO model only to render the annotated "analytical view" image.
"""

import streamlit as st
import requests
from ultralytics import YOLO
import numpy as np
from PIL import Image
import io
import base64
import json
import pandas as pd
import plotly.express as px

# NOTE: set_page_config MUST be the first Streamlit command executed.
# Calling it after st.error/st.info (as the offline branch below may do)
# raises StreamlitAPIException, so it is pinned to the top of the script.
st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")

# --- 1. Global Backend Check ---
API_BASE_URL = "http://localhost:8000"


def check_backend():
    """Return True if the backend API answers within 2 seconds."""
    try:
        res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
        return res.status_code == 200
    except requests.RequestException:
        # Connection refused / timeout / DNS failure -> backend is offline.
        return False


backend_active = check_backend()


# Load YOLO model locally for the Analytical View (cached across reruns).
@st.cache_resource
def load_yolo():
    """Load the local YOLO weights once per server process."""
    return YOLO('best.pt')


yolo_model = load_yolo()

if not backend_active:
    st.error("⚠️ Backend API is offline!")
    st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
    if st.button("🔄 Retry Connection"):
        st.rerun()
    st.stop()  # Stops execution here, effectively disabling the app

# --- 2. Main Page (only rendered if backend is active) ---
st.title("🌴 Palm Oil FFB Management System")
st.markdown("### Production-Ready AI Analysis & Archival")

# --- Sidebar ---
st.sidebar.header("Backend Controls")


def update_confidence():
    """Push the slider's new confidence threshold to the backend (on_change callback)."""
    new_conf = st.session_state.conf_slider
    try:
        requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
        st.toast(f"Threshold updated to {new_conf}")
    except requests.RequestException:
        st.sidebar.error("Failed to update threshold")


# Backend liveness was verified above, so this fetch is expected to succeed.
response = requests.get(f"{API_BASE_URL}/get_confidence")
current_conf = response.json().get("current_confidence", 0.25)
st.sidebar.success("Connected to API")

# Slider is synchronized with the backend's stored threshold.
st.sidebar.slider(
    "Confidence Threshold", 0.1, 1.0,
    value=float(current_conf),
    key="conf_slider",
    on_change=update_confidence
)

# --- Tabs ---
tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])

# --- Tab 1: Single Analysis ---
with tab1:
    st.subheader("Analyze Single Bunch")
    uploaded_file = st.file_uploader("Upload a bunch image...", type=["jpg", "jpeg", "png"],
                                     key="single")
    if uploaded_file:
        # State initialization
        if "last_detection" not in st.session_state:
            st.session_state.last_detection = None

        # 1. Action Button (Centered and Prominent)
        st.write("##")
        _, col_btn, _ = st.columns([1, 2, 1])
        if col_btn.button("🔍 Run Ripeness Detection", type="primary", width='stretch'):
            with st.spinner("Processing Detections Locally..."):
                files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                res = requests.post(f"{API_BASE_URL}/analyze", files=files)
                if res.status_code == 200:
                    st.session_state.last_detection = res.json()
                else:
                    st.error(f"Detection Failed: {res.text}")

        # 2. Results Layout
        if st.session_state.last_detection:
            st.divider()

            # SIDE-BY-SIDE ANALYTICAL VIEW
            col_left, col_right = st.columns(2)
            with col_left:
                st.image(uploaded_file, caption="Original Photo", width='stretch')
            with col_right:
                # Use the local model to plot the boxes directly
                img = Image.open(uploaded_file)
                results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
                annotated_img = results[0].plot()  # Draws boxes/labels
                # Convert BGR (OpenCV format) to RGB for Streamlit
                annotated_img_rgb = annotated_img[:, :, ::-1]
                st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')

            # Keep original col structure for summary below
            col1, col2 = st.columns([1.5, 1])
            with col2:
                data = st.session_state.last_detection
                with st.container(border=True):
                    st.write("### 🏷️ Detection Results")
                    if not data['detections']:
                        st.warning("No Fresh Fruit Bunches detected.")
                    else:
                        for det in data['detections']:
                            st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")

                    st.write("### 📊 Harvest Quality Mix")
                    # Convert industrial_summary dictionary to a DataFrame for charting
                    summary_df = pd.DataFrame(
                        list(data['industrial_summary'].items()),
                        columns=['Grade', 'Count']
                    )
                    # Filter out classes with 0 count for a cleaner chart
                    summary_df = summary_df[summary_df['Count'] > 0]
                    if not summary_df.empty:
                        # Create a Pie Chart to show the proportion of each grade
                        fig = px.pie(summary_df, values='Count', names='Grade',
                                     color='Grade',
                                     color_discrete_map={
                                         'Abnormal': '#ef4444',     # Red
                                         'Empty_Bunch': '#94a3b8',  # Gray
                                         'Ripe': '#22c55e',         # Green
                                         'Underripe': '#eab308',    # Yellow
                                         'Unripe': '#3b82f6',       # Blue
                                         'Overripe': '#a855f7'      # Purple
                                     },
                                     hole=0.4)
                        fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
                        st.plotly_chart(fig, width='stretch')

                    # High-Priority Health Alerts
                    if data['industrial_summary'].get('Abnormal', 0) > 0:
                        st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
                    if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
                        st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")

                # 3. Cloud Actions (Only if detections found)
                st.write("---")
                st.write("#### ✨ Cloud Archive")
                if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
                    # Guard: archiving requires at least one detection, otherwise
                    # indexing detections[0] would raise IndexError.
                    if not data['detections']:
                        st.warning("No Fresh Fruit Bunches detected.")
                    else:
                        with st.spinner("Archiving..."):
                            primary_det = data['detections'][0]
                            payload = {"detection_data": json.dumps(primary_det)}
                            files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                            res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store",
                                                      files=files_cloud, data=payload)
                            if res_cloud.status_code == 200:
                                res_json = res_cloud.json()
                                if res_json["status"] == "success":
                                    st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
                                else:
                                    st.error(f"Cloud Error: {res_json['message']}")
                            else:
                                st.error("Failed to connect to cloud service")

# --- Tab 2: Batch Processing ---
with tab2:
    st.subheader("Bulk Analysis")

    # 1. Initialize Session State
    if "batch_uploader_key" not in st.session_state:
        st.session_state.batch_uploader_key = 0
    if "last_batch_results" not in st.session_state:
        st.session_state.last_batch_results = None

    # 2. Display Persisted Results (if any)
    if st.session_state.last_batch_results:
        res_data = st.session_state.last_batch_results
        with st.container(border=True):
            st.success(f"✅ Successfully processed {res_data['processed_count']} images.")

            # Batch Summary Dashboard
            st.write("### 📈 Batch Quality Overview")
            batch_summary = res_data.get('industrial_summary', {})
            if batch_summary:
                sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
                sum_df = sum_df[sum_df['Count'] > 0]
                b_col1, b_col2 = st.columns([1, 1])
                with b_col1:
                    st.dataframe(sum_df, hide_index=True, width='stretch')
                with b_col2:
                    if not sum_df.empty:
                        fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
                                           color_discrete_map={
                                               'Abnormal': '#ef4444',
                                               'Empty_Bunch': '#94a3b8',
                                               'Ripe': '#22c55e'
                                           })
                        fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0),
                                                height=200, showlegend=False)
                        st.plotly_chart(fig_batch, width='stretch')

                if batch_summary.get('Abnormal', 0) > 0:
                    st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")

            st.write("Generated Record IDs:")
            st.code(res_data['record_ids'])

            if st.button("Clear Results & Start New Batch"):
                st.session_state.last_batch_results = None
                st.rerun()
        st.divider()

    # 3. Uploader UI
    col_batch1, col_batch2 = st.columns([4, 1])
    with col_batch1:
        uploaded_files = st.file_uploader(
            "Upload multiple images...",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
            # Bumping the key forces Streamlit to create a fresh (empty) uploader.
            key=f"batch_{st.session_state.batch_uploader_key}"
        )
    with col_batch2:
        st.write("##")  # Alignment
        if st.button("🗑️ Reset Uploader"):
            st.session_state.batch_uploader_key += 1
            st.rerun()

    if uploaded_files:
        if st.button(f"🚀 Process {len(uploaded_files)} Images"):
            with st.spinner("Batch Processing in progress..."):
                files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
                if res.status_code == 200:
                    data = res.json()
                    if data["status"] == "success":
                        st.session_state.last_batch_results = data
                        st.session_state.batch_uploader_key += 1
                        st.rerun()
                    elif data["status"] == "partial_success":
                        st.warning(data["message"])
                        st.info(f"Successfully detected {data['detections_count']} bunches locally.")
                    else:
                        st.error(f"Batch Error: {data['message']}")
                else:
                    st.error(f"Batch Failed: {res.text}")

# --- Tab 3: Similarity Search ---
with tab3:
    st.subheader("Hybrid Semantic Search")
    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")

    with st.form("hybrid_search_form"):
        col_input1, col_input2 = st.columns(2)
        with col_input1:
            search_file = st.file_uploader("Option A: Search Image...",
                                           type=["jpg", "jpeg", "png"], key="search")
        with col_input2:
            text_query = st.text_input("Option B: Natural Language Query",
                                       placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
        top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
        submit_search = st.form_submit_button("Run Semantic Search")

    if submit_search:
        if not search_file and not text_query:
            st.warning("Please provide either an image or a text query.")
        else:
            with st.spinner("Searching Vector Index..."):
                payload = {"limit": top_k}
                # If an image is uploaded, it takes precedence for visual search
                if search_file:
                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                    # Pass top_k as part of the data
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
                # Otherwise, use text query
                elif text_query:
                    payload["text_query"] = text_query
                    # Send as form-data (data=) to match FastAPI's Form(None)
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)

                if res.status_code == 200:
                    results = res.json().get("results", [])
                    if not results:
                        st.warning("No similar records found.")
                    else:
                        st.success(f"Found {len(results)} matches.")
                        for item in results:
                            with st.container(border=True):
                                c1, c2 = st.columns([1, 2])
                                # Fetch the image for this result
                                rec_id = item["_id"]
                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
                                with c1:
                                    if img_res.status_code == 200:
                                        img_b64 = img_res.json().get("image_data")
                                        if img_b64:
                                            st.image(base64.b64decode(img_b64), width=250)
                                        else:
                                            st.write("No image data found.")
                                    else:
                                        st.write("Failed to load image.")
                                with c2:
                                    st.write(f"**Class:** {item['ripeness_class']}")
                                    st.write(f"**Similarity Score:** {item['score']:.4f}")
                                    st.write(f"**Timestamp:** {item['timestamp']}")
                                    st.write(f"**ID:** `{rec_id}`")
                else:
                    st.error(f"Search failed: {res.text}")