# demo_app.py — Streamlit front-end for the Palm Oil FFB (YOLO26) backend API.
  1. import streamlit as st
  2. import requests
  3. from ultralytics import YOLO
  4. import numpy as np
  5. from PIL import Image
  6. import io
  7. import base64
  8. import pandas as pd
  9. import plotly.express as px
  10. import plotly.graph_objects as go
  11. import json
  12. import os
  13. from datetime import datetime
  14. from fpdf import FPDF
  15. # --- 1. Global Backend Check ---
  16. API_BASE_URL = "http://localhost:8000"
  17. def check_backend():
  18. try:
  19. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  20. return res.status_code == 200
  21. except:
  22. return False
  23. backend_active = check_backend()
  24. # LOCAL MODEL LOADING REMOVED (YOLO26 Clean Sweep)
  25. # UI now relies entirely on Backend API for NMS-Free inference.
  26. if not backend_active:
  27. st.error("⚠️ Backend API is offline!")
  28. st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
  29. if st.button("🔄 Retry Connection"):
  30. st.rerun()
  31. st.stop() # Stops execution here, effectively disabling the app
  32. # --- 2. Main Page Config (Only rendered if backend is active) ---
  33. st.set_page_config(page_title="Palm Oil Ripeness AI (YOLO26)", layout="wide")
  34. st.title("🌴 Palm Oil FFB Management System")
  35. st.markdown("### Production-Ready AI Analysis & Archival")
  36. # --- Sidebar ---
  37. st.sidebar.header("Backend Controls")
  38. def update_confidence():
  39. new_conf = st.session_state.conf_slider
  40. try:
  41. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  42. st.toast(f"Threshold updated to {new_conf}")
  43. except:
  44. st.sidebar.error("Failed to update threshold")
  45. # We already know backend is up here
  46. response = requests.get(f"{API_BASE_URL}/get_confidence")
  47. current_conf = response.json().get("current_confidence", 0.25)
  48. st.sidebar.success(f"Connected to API")
  49. st.sidebar.info("Engine: YOLO26 NMS-Free (Inference: ~39ms)")
  50. # Synchronized Slider
  51. st.sidebar.slider(
  52. "Confidence Threshold",
  53. 0.1, 1.0,
  54. value=float(current_conf),
  55. key="conf_slider",
  56. on_change=update_confidence
  57. )
  58. # Helper to reset results when files change
  59. def reset_single_results():
  60. st.session_state.last_detection = None
  61. def reset_batch_results():
  62. st.session_state.last_batch_results = None
  63. # MPOB Color Map for Overlays (Global for consistency)
  64. overlay_colors = {
  65. 'Ripe': '#22c55e', # Industrial Green
  66. 'Underripe': '#fbbf24', # Industrial Orange
  67. 'Unripe': '#3b82f6', # Industrial Blue
  68. 'Abnormal': '#dc2626', # Critical Red
  69. 'Empty_Bunch': '#64748b',# Waste Gray
  70. 'Overripe': '#7c2d12' # Dark Brown/Orange
  71. }
  72. def display_interactive_results(image, detections, key=None):
  73. """Renders image with interactive hover-boxes using Plotly."""
  74. img_width, img_height = image.size
  75. fig = go.Figure()
  76. # Add the palm image as the background
  77. fig.add_layout_image(
  78. dict(source=image, x=0, y=img_height, sizex=img_width, sizey=img_height,
  79. sizing="stretch", opacity=1, layer="below", xref="x", yref="y")
  80. )
  81. # Configure axes to match image dimensions
  82. fig.update_xaxes(showgrid=False, range=(0, img_width), zeroline=False, visible=False)
  83. fig.update_yaxes(showgrid=False, range=(0, img_height), zeroline=False, visible=False, scaleanchor="x")
  84. # Add interactive boxes
  85. for i, det in enumerate(detections):
  86. x1, y1, x2, y2 = det['box']
  87. # Plotly y-axis is inverted relative to PIL, so we flip y
  88. y_top, y_bottom = img_height - y1, img_height - y2
  89. color = overlay_colors.get(det['class'], "#ffeb3b")
  90. # The 'Hover' shape
  91. bunch_id = det.get('bunch_id', i+1)
  92. fig.add_trace(go.Scatter(
  93. x=[x1, x2, x2, x1, x1],
  94. y=[y_top, y_top, y_bottom, y_bottom, y_top],
  95. fill="toself",
  96. fillcolor=color,
  97. opacity=0.3, # Semi-transparent until hover
  98. mode='lines',
  99. line=dict(color=color, width=3),
  100. name=f"Bunch #{bunch_id}",
  101. text=f"<b>ID: #{bunch_id}</b><br>Grade: {det['class']}<br>Score: {det['confidence']:.2f}<br>Alert: {det['is_health_alert']}",
  102. hoverinfo="text"
  103. ))
  104. fig.update_layout(width=800, height=600, margin=dict(l=0, r=0, b=0, t=0), showlegend=False)
  105. st.plotly_chart(fig, use_container_width=True, key=key)
  106. def annotate_image(image, detections):
  107. """Draws high-visibility boxes and background-shaded labels."""
  108. from PIL import ImageDraw, ImageFont
  109. draw = ImageDraw.Draw(image)
  110. # Dynamic font size based on image resolution
  111. font_size = max(20, image.width // 40)
  112. try:
  113. font_path = "C:\\Windows\\Fonts\\arial.ttf"
  114. if os.path.exists(font_path):
  115. font = ImageFont.truetype(font_path, font_size)
  116. else:
  117. font = ImageFont.load_default()
  118. except:
  119. font = ImageFont.load_default()
  120. for det in detections:
  121. box = det['box'] # [x1, y1, x2, y2]
  122. cls = det['class']
  123. conf = det['confidence']
  124. bunch_id = det.get('bunch_id', '?')
  125. color = overlay_colors.get(cls, '#ffffff')
  126. # 1. Draw Bold Bounding Box
  127. draw.rectangle(box, outline=color, width=max(4, image.width // 200))
  128. # 2. Draw Label Background (High Contrast)
  129. label = f"#{bunch_id} {cls} {conf:.2f}"
  130. try:
  131. # textbbox provides precise coordinates for background rectangle
  132. l, t, r, b = draw.textbbox((box[0], box[1] - font_size - 10), label, font=font)
  133. draw.rectangle([l-5, t-5, r+5, b+5], fill=color)
  134. draw.text((l, t), label, fill="white", font=font)
  135. except:
  136. # Fallback for basic text drawing
  137. draw.text((box[0], box[1] - 25), label, fill=color)
  138. return image
  139. def generate_batch_report(data, uploaded_files_map=None):
  140. """Generates a professional PDF report for batch results with visual evidence."""
  141. from PIL import ImageDraw
  142. pdf = FPDF()
  143. pdf.add_page()
  144. pdf.set_font("Arial", "B", 16)
  145. pdf.cell(190, 10, "Palm Oil FFB Harvest Quality Report", ln=True, align="C")
  146. pdf.set_font("Arial", "", 12)
  147. pdf.cell(190, 10, f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align="C")
  148. pdf.ln(10)
  149. # 1. Summary Table
  150. pdf.set_font("Arial", "B", 14)
  151. pdf.cell(190, 10, "1. Batch Summary", ln=True)
  152. pdf.set_font("Arial", "", 12)
  153. summary = data.get('industrial_summary', {})
  154. total_bunches = data.get('total_count', 0)
  155. pdf.cell(95, 10, "Metric", border=1)
  156. pdf.cell(95, 10, "Value", border=1, ln=True)
  157. pdf.cell(95, 10, "Total Bunches Detected", border=1)
  158. pdf.cell(95, 10, str(total_bunches), border=1, ln=True)
  159. for grade, count in summary.items():
  160. if count > 0:
  161. pdf.cell(95, 10, f"Grade: {grade}", border=1)
  162. pdf.cell(95, 10, str(count), border=1, ln=True)
  163. pdf.ln(10)
  164. # 2. Strategic Insights
  165. pdf.set_font("Arial", "B", 14)
  166. pdf.cell(190, 10, "2. Strategic Yield Insights", ln=True)
  167. pdf.set_font("Arial", "", 12)
  168. unripe = summary.get('Unripe', 0)
  169. underripe = summary.get('Underripe', 0)
  170. loss = unripe + underripe
  171. if loss > 0:
  172. pdf.multi_cell(190, 10, f"WARNING: {loss} bunches were harvested before peak ripeness. "
  173. "This directly impacts the Oil Extraction Rate (OER) and results in potential yield loss.")
  174. else:
  175. pdf.multi_cell(190, 10, "EXCELLENT: All detected bunches meet prime ripeness standards. Harvest efficiency is 100%.")
  176. # Critical Alerts
  177. abnormal = summary.get('Abnormal', 0)
  178. empty = summary.get('Empty_Bunch', 0)
  179. if abnormal > 0 or empty > 0:
  180. pdf.ln(5)
  181. pdf.set_text_color(220, 0, 0)
  182. pdf.set_font("Arial", "B", 12)
  183. pdf.cell(190, 10, "CRITICAL HEALTH ALERTS:", ln=True)
  184. pdf.set_font("Arial", "", 12)
  185. if abnormal > 0:
  186. pdf.cell(190, 10, f"- {abnormal} Abnormal Bunches detected (Requires immediate field inspection).", ln=True)
  187. if empty > 0:
  188. pdf.cell(190, 10, f"- {empty} Empty Bunches detected (Waste reduction needed).", ln=True)
  189. pdf.set_text_color(0, 0, 0)
  190. # 3. Visual Evidence Section
  191. if 'detailed_results' in data and uploaded_files_map:
  192. pdf.add_page()
  193. pdf.set_font("Arial", "B", 14)
  194. pdf.cell(190, 10, "3. Visual Batch Evidence (AI Overlay)", ln=True)
  195. pdf.ln(5)
  196. # Group detections by filename
  197. results_by_file = {}
  198. for res in data['detailed_results']:
  199. fname = res['filename']
  200. if fname not in results_by_file:
  201. results_by_file[fname] = []
  202. results_by_file[fname].append(res['detection'])
  203. for fname, detections in results_by_file.items():
  204. if fname in uploaded_files_map:
  205. img_bytes = uploaded_files_map[fname]
  206. img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
  207. draw = ImageDraw.Draw(img)
  208. # Drawing annotated boxes for PDF using high-visibility utility
  209. annotate_image(img, detections)
  210. # Save to temp file for PDF
  211. temp_img_path = f"temp_report_{fname}"
  212. img.save(temp_img_path)
  213. # Check if we need a new page based on image height (rough estimate)
  214. if pdf.get_y() > 200:
  215. pdf.add_page()
  216. pdf.image(temp_img_path, x=10, w=150)
  217. pdf.set_font("Arial", "I", 10)
  218. pdf.cell(190, 10, f"Annotated: {fname}", ln=True)
  219. pdf.ln(5)
  220. os.remove(temp_img_path)
  221. # Footer
  222. pdf.set_y(-15)
  223. pdf.set_font("Arial", "I", 8)
  224. pdf.cell(190, 10, "Generated by Palm Oil AI Desktop PoC - YOLO26 Engine", align="C")
  225. return pdf.output(dest='S')
# --- Tabs ---
tab1, tab2, tab3, tab4 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search", "History Vault"])

# --- Tab 1: Single Analysis ---
# Flow: upload -> auto-detect via backend /analyze -> interactive viewer +
# manager dashboard -> optional cloud archive / feedback actions.
with tab1:
    st.subheader("Analyze Single Bunch")
    uploaded_file = st.file_uploader(
        "Upload a bunch image...",
        type=["jpg", "jpeg", "png"],
        key="single",
        on_change=reset_single_results
    )
    if uploaded_file:
        # State initialization
        if "last_detection" not in st.session_state:
            st.session_state.last_detection = None
        # 1. Auto-Detection Trigger: call the backend once per new upload
        # (last_detection is cleared by the uploader's on_change callback).
        if uploaded_file and st.session_state.last_detection is None:
            with st.spinner("Processing Detections Locally..."):
                files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                res = requests.post(f"{API_BASE_URL}/analyze", files=files)
                if res.status_code == 200:
                    st.session_state.last_detection = res.json()
                    st.rerun()  # Refresh to show results immediately
                else:
                    st.error(f"Detection Failed: {res.text}")
        # 2. Results Layout
        if st.session_state.last_detection:
            st.divider()
            # PRIMARY ANNOTATED VIEW
            st.write("### 🔍 AI Analytical View")
            data = st.session_state.last_detection
            img = Image.open(uploaded_file).convert("RGB")
            display_interactive_results(img, data['detections'], key="main_viewer")
            # Visual Legend: one colored chip per ripeness grade.
            st.write("#### 🎨 Ripeness Legend")
            l_cols = st.columns(len(overlay_colors))
            for i, (grade, color) in enumerate(overlay_colors.items()):
                with l_cols[i]:
                    st.markdown(f'<div style="background-color:{color}; padding:10px; border-radius:5px; text-align:center; color:white; font-weight:bold;">{grade}</div>', unsafe_allow_html=True)
            st.divider()
            st.write("### 📈 Manager's Dashboard")
            m_col1, m_col2, m_col3 = st.columns(3)
            with m_col1:
                st.metric("Total Bunches", data.get('total_count', 0))
            with m_col2:
                st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
            with m_col3:
                abnormal = data['industrial_summary'].get('Abnormal', 0)
                # Negative delta + "inverse" renders the alert count in red.
                st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
            col1, col2 = st.columns([1.5, 1])  # Keep original col structure for summary below
            with col2:
                with st.container(border=True):
                    st.write("### 🏷️ Detection Results")
                    if not data['detections']:
                        st.warning("No Fresh Fruit Bunches detected.")
                    else:
                        for det in data['detections']:
                            st.info(f"### Bunch #{det['bunch_id']}: {det['class']} ({det['confidence']:.2%})")
                        st.write("### 📊 Harvest Quality Mix")
                        # Convert industrial_summary dictionary to a DataFrame for charting
                        summary_df = pd.DataFrame(
                            list(data['industrial_summary'].items()),
                            columns=['Grade', 'Count']
                        )
                        # Filter out classes with 0 count for a cleaner chart
                        summary_df = summary_df[summary_df['Count'] > 0]
                        if not summary_df.empty:
                            # Create a Pie Chart to show the proportion of each grade
                            fig = px.pie(summary_df, values='Count', names='Grade',
                                         color='Grade',
                                         color_discrete_map={
                                             'Ripe': '#22c55e',        # Industrial Green
                                             'Underripe': '#fbbf24',   # Industrial Orange
                                             'Unripe': '#3b82f6',      # Industrial Blue
                                             'Abnormal': '#dc2626',    # Critical Red
                                             'Empty_Bunch': '#64748b'  # Waste Gray
                                         },
                                         hole=0.4)
                            fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
                            st.plotly_chart(fig, width='stretch', key="single_pie")
                        # 💡 Strategic R&D Insight: Harvest Efficiency
                        st.write("---")
                        st.write("#### 💡 Strategic R&D Insight")
                        unripe_count = data['industrial_summary'].get('Unripe', 0)
                        underripe_count = data['industrial_summary'].get('Underripe', 0)
                        total_non_prime = unripe_count + underripe_count
                        st.write(f"🌑 **Unripe (Mentah):** {unripe_count}")
                        st.write(f"🌗 **Underripe (Kurang Masak):** {underripe_count}")
                        if total_non_prime > 0:
                            st.warning(f"🚨 **Potential Yield Loss:** {total_non_prime} bunches harvested too early. This will reduce OER (Oil Extraction Rate).")
                        else:
                            st.success("✅ **Harvest Efficiency:** 100% Prime Ripeness detected.")
                        # High-Priority Health Alert
                        if data['industrial_summary'].get('Abnormal', 0) > 0:
                            st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
                        if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
                            st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
                        # 3. Cloud Actions (Only if detections found)
                        st.write("---")
                        st.write("#### ✨ Cloud Archive")
                        if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
                            with st.spinner("Archiving..."):
                                import json
                                # Only the primary (first) detection is vectorized.
                                primary_det = data['detections'][0]
                                payload = {"detection_data": json.dumps(primary_det)}
                                files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                                res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
                                if res_cloud.status_code == 200:
                                    res_json = res_cloud.json()
                                    if res_json["status"] == "success":
                                        st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
                                    else:
                                        st.error(f"Cloud Error: {res_json['message']}")
                                else:
                                    st.error("Failed to connect to cloud service")
                        if st.button("🚩 Flag Misclassification", width='stretch', type="secondary"):
                            # Save to local feedback folder
                            # NOTE(review): assumes a "feedback/" directory already
                            # exists next to the app — confirm, otherwise save raises.
                            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                            feedback_id = f"fb_{timestamp}"
                            img_path = f"feedback/{feedback_id}.jpg"
                            json_path = f"feedback/{feedback_id}.json"
                            # Save image
                            Image.open(uploaded_file).save(img_path)
                            # Save metadata
                            feedback_data = {
                                "original_filename": uploaded_file.name,
                                "timestamp": timestamp,
                                "detections": data['detections'],
                                "threshold_used": data['current_threshold']
                            }
                            with open(json_path, "w") as f:
                                json.dump(feedback_data, f, indent=4)
                            st.toast("✅ Feedback saved to local vault!", icon="🚩")
                        # Disabled placeholder: archiving happens automatically server-side.
                        if st.button("💾 Local History Vault (Auto-Saved)", width='stretch', type="secondary", disabled=True):
                            pass
                        st.caption("✅ This analysis was automatically archived to the local vault.")
# --- Tab 2: Batch Processing ---
with tab2:
    st.subheader("Bulk Analysis")
    # 1. Initialize Session State
    # batch_uploader_key is bumped to force a fresh (empty) file_uploader
    # widget on the next rerun; last_batch_results caches the API response.
    if "batch_uploader_key" not in st.session_state:
        st.session_state.batch_uploader_key = 0
    if "last_batch_results" not in st.session_state:
        st.session_state.last_batch_results = None
    # 2. Display Persisted Results (if any)
    # NOTE(review): this section reads `uploaded_files`, which is only
    # assigned further down in the script — verify the rerun path with
    # results present does not hit a NameError.
    if st.session_state.last_batch_results:
        res_data = st.session_state.last_batch_results
        with st.container(border=True):
            st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
            # Batch Summary Dashboard
            st.write("### 📈 Batch Quality Overview")
            batch_summary = res_data.get('industrial_summary', {})
            if batch_summary:
                sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
                sum_df = sum_df[sum_df['Count'] > 0]
                b_col1, b_col2 = st.columns([1, 1])
                with b_col1:
                    st.dataframe(sum_df, hide_index=True, width='stretch')
                with b_col2:
                    if not sum_df.empty:
                        fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
                                           color_discrete_map={
                                               'Ripe': '#22c55e',
                                               'Underripe': '#fbbf24',
                                               'Unripe': '#3b82f6',
                                               'Abnormal': '#dc2626',
                                               'Empty_Bunch': '#64748b'
                                           })
                        fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
                        st.plotly_chart(fig_batch, width='stretch', key="batch_bar")
                if batch_summary.get('Abnormal', 0) > 0:
                    st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
            st.write("Generated Record IDs:")
            st.code(res_data['record_ids'])
            # --- 4. Batch Evidence Gallery ---
            st.write("### 🖼️ Detailed Detection Evidence")
            if 'detailed_results' in res_data:
                # Group results by filename for gallery
                gallery_map = {}
                for res in res_data['detailed_results']:
                    fname = res['filename']
                    if fname not in gallery_map:
                        gallery_map[fname] = []
                    gallery_map[fname].append(res['detection'])
                # Show images with overlays using consistent utility
                for up_file in uploaded_files:
                    if up_file.name in gallery_map:
                        with st.container(border=True):
                            g_img = Image.open(up_file).convert("RGB")
                            g_annotated = annotate_image(g_img, gallery_map[up_file.name])
                            st.image(g_annotated, caption=f"Evidence: {up_file.name}", use_container_width=True)
                # PDF Export Button (Pass images map)
                files_map = {f.name: f.getvalue() for f in uploaded_files}
                pdf_bytes = generate_batch_report(res_data, files_map)
                st.download_button(
                    label="📄 Download Executive Batch Report (PDF)",
                    data=pdf_bytes,
                    file_name=f"PalmOil_BatchReport_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf",
                    mime="application/pdf",
                    width='stretch'
                )
            if st.button("Clear Results & Start New Batch", width='stretch'):
                st.session_state.last_batch_results = None
                st.rerun()
        st.divider()
    # 3. Uploader UI
    col_batch1, col_batch2 = st.columns([4, 1])
    with col_batch1:
        uploaded_files = st.file_uploader(
            "Upload multiple images...",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
            key=f"batch_{st.session_state.batch_uploader_key}",
            on_change=reset_batch_results
        )
    with col_batch2:
        st.write("##")  # Alignment
        if st.session_state.last_batch_results is None and uploaded_files:
            if st.button("🔍 Process Batch", type="primary", width='stretch'):
                with st.spinner(f"Analyzing {len(uploaded_files)} images..."):
                    files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
                    if res.status_code == 200:
                        data = res.json()
                        if data["status"] == "success":
                            # Persist results and bump the key so the uploader
                            # is cleared on the next run.
                            st.session_state.last_batch_results = data
                            st.session_state.batch_uploader_key += 1
                            st.rerun()
                        elif data["status"] == "partial_success":
                            st.warning(data["message"])
                            st.info(f"Successfully detected {data['detections_count']} bunches locally.")
                        else:
                            st.error(f"Batch Error: {data['message']}")
                    else:
                        st.error(f"Batch Processing Failed: {res.text}")
        if st.button("🗑️ Reset Uploader"):
            st.session_state.batch_uploader_key += 1
            st.session_state.last_batch_results = None
            st.rerun()
# --- Tab 3: Similarity Search ---
with tab3:
    st.subheader("Hybrid Semantic Search")
    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
    # Form batches the inputs so the search only fires on submit.
    with st.form("hybrid_search_form"):
        col_input1, col_input2 = st.columns(2)
        with col_input1:
            search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
        with col_input2:
            text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
        top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
        submit_search = st.form_submit_button("Run Semantic Search")
    if submit_search:
        if not search_file and not text_query:
            st.warning("Please provide either an image or a text query.")
        else:
            with st.spinner("Searching Vector Index..."):
                payload = {"limit": top_k}
                # If an image is uploaded, it takes precedence for visual search
                if search_file:
                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                    # Pass top_k as part of the data
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
                # Otherwise, use text query
                elif text_query:
                    payload["text_query"] = text_query
                    # Send as form-data (data=) to match FastAPI's Form(None)
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
                if res.status_code == 200:
                    results = res.json().get("results", [])
                    if not results:
                        st.warning("No similar records found.")
                    else:
                        st.success(f"Found {len(results)} matches.")
                        # One bordered card per match: thumbnail + metadata.
                        for item in results:
                            with st.container(border=True):
                                c1, c2 = st.columns([1, 2])
                                # Fetch the image for this result
                                rec_id = item["_id"]
                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
                                with c1:
                                    if img_res.status_code == 200:
                                        # Backend returns the image base64-encoded in JSON.
                                        img_b64 = img_res.json().get("image_data")
                                        if img_b64:
                                            st.image(base64.b64decode(img_b64), width=250)
                                        else:
                                            st.write("No image data found.")
                                    else:
                                        st.write("Failed to load image.")
                                with c2:
                                    st.write(f"**Class:** {item['ripeness_class']}")
                                    st.write(f"**Similarity Score:** {item['score']:.4f}")
                                    st.write(f"**Timestamp:** {item['timestamp']}")
                                    st.write(f"**ID:** `{rec_id}`")
                else:
                    st.error(f"Search failed: {res.text}")
# --- Tab 4: History Vault ---
with tab4:
    st.subheader("📜 Local History Vault")
    # Broad try/except is a deliberate top-level boundary for this tab:
    # any backend/parse failure is surfaced as a single error message.
    try:
        res = requests.get(f"{API_BASE_URL}/get_history")
        if res.status_code == 200:
            history_data = res.json().get("history", [])
            if not history_data:
                st.info("No saved records found.")
            else:
                # Selection table: id/filename/timestamp columns only.
                df_history = pd.DataFrame(history_data)[['id', 'filename', 'timestamp']]
                selected_id = st.selectbox("Select a record to review:", df_history['id'])
                if selected_id:
                    record = next(item for item in history_data if item["id"] == selected_id)
                    # 'detections' is stored as a JSON string in the record.
                    detections = json.loads(record['detections'])
                    # Display Interactive Hover View (needs the archived image
                    # file to still exist on disk).
                    if os.path.exists(record['archive_path']):
                        with open(record['archive_path'], "rb") as f:
                            hist_img = Image.open(f).convert("RGB")
                        display_interactive_results(hist_img, detections, key=f"hist_{record['id']}")
                        st.write("### 📈 Archived Summary")
                        summary = json.loads(record['summary'])
                        s_col1, s_col2, s_col3 = st.columns(3)
                        with s_col1:
                            st.metric("Total Bunches", sum(summary.values()))
                        with s_col2:
                            st.metric("Healthy (Ripe)", summary.get('Ripe', 0))
                        with s_col3:
                            abnormal = summary.get('Abnormal', 0)
                            st.metric("Abnormal Alerts", abnormal)
                    else:
                        st.error(f"Archive file not found: {record['archive_path']}")
        else:
            st.error(f"Failed to fetch history: {res.text}")
    except Exception as e:
        st.error(f"Error loading history: {str(e)}")