demo_app.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806
  1. import streamlit as st
  2. import requests
  3. from ultralytics import YOLO
  4. import numpy as np
  5. from PIL import Image
  6. import io
  7. import base64
  8. import pandas as pd
  9. import plotly.express as px
  10. import plotly.graph_objects as go
  11. import json
  12. import os
  13. from datetime import datetime
  14. from fpdf import FPDF
@st.dialog("📘 AI Interpretation Guide")
def show_tech_guide():
    """Modal dialog explaining how to read the AI output.

    Covers: the meaning of the confidence score, the layout of the raw
    [1, 300, 6] detection tensor, and the difference between inference
    and total processing time. Opened from sidebar/tab "Guide" buttons.
    """
    st.write("### 🎯 What does 'Confidence' mean?")
    st.write("""
This is a probability score from **0.0 to 1.0**.
- **0.90+**: The AI is nearly certain this is a bunch of this grade.
- **0.25 (Threshold)**: We ignore anything below this to filter out 'ghost' detections or background noise.
""")
    st.write("### 🛠️ The Raw Mathematical Tensor")
    st.write("The AI returns a raw array of shape `[1, 300, 6]`. Here is the key:")
    # Per-detection row layout: 4 box coords, 1 score, 1 class id.
    st.table({
        "Index": ["0-3", "4", "5"],
        "Meaning": ["Coordinates (x1, y1, x2, y2)", "Confidence Score", "Class ID (0-5)"],
        "Reality": ["The 'Box' in the image.", "The AI's certainty.", "The Ripeness Grade."]
    })
    st.write("### ⚡ Inference vs. Processing Time")
    st.write("""
- **Inference Speed**: The time the AI model took to 'think' about the pixels.
- **Total Time**: Includes image uploading and database saving overhead.
""")
    st.info("💡 **Engine Note**: ONNX is optimized for latency (~39ms), while PyTorch offers native indicator flexibility.")
  36. # --- 1. Global Backend Check ---
  37. API_BASE_URL = "http://localhost:8000"
  38. def check_backend():
  39. try:
  40. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  41. return res.status_code == 200
  42. except:
  43. return False
  44. backend_active = check_backend()
# LOCAL MODEL LOADING REMOVED (YOLO26 Clean Sweep)
# UI now relies entirely on Backend API for NMS-Free inference.
if not backend_active:
    # Hard gate: without the API there is nothing the UI can do, so render
    # only the error + retry button and halt the script run.
    st.error("⚠️ Backend API is offline!")
    st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
    if st.button("🔄 Retry Connection"):
        st.rerun()
    st.stop() # Stops execution here, effectively disabling the app
# --- 2. Main Page Config (Only rendered if backend is active) ---
# NOTE(review): st.set_page_config is documented to be the first Streamlit
# command of a run; the offline branch above calls st.error() first when the
# backend is down — confirm this ordering is acceptable for the target
# Streamlit version.
st.set_page_config(page_title="Palm Oil Ripeness AI (YOLO26)", layout="wide")
st.title("🌴 Palm Oil FFB Management System")
st.markdown("### Production-Ready AI Analysis & Archival")
  57. # --- Sidebar ---
  58. st.sidebar.header("Backend Controls")
  59. def update_confidence():
  60. new_conf = st.session_state.conf_slider
  61. try:
  62. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  63. st.toast(f"Threshold updated to {new_conf}")
  64. except:
  65. st.sidebar.error("Failed to update threshold")
# We already know backend is up here (check_backend gate above), so this
# unguarded GET is intentional — a failure will surface as a visible exception.
response = requests.get(f"{API_BASE_URL}/get_confidence")
current_conf = response.json().get("current_confidence", 0.25)
st.sidebar.success(f"Connected to API")
st.sidebar.info("Engine: YOLO26 NMS-Free (Inference: ~39ms)")
# Synchronized Slider: seeded from the backend's current threshold and pushing
# changes back through the update_confidence callback.
st.sidebar.slider(
    "Confidence Threshold",
    0.1, 1.0,
    value=float(current_conf),
    key="conf_slider",
    on_change=update_confidence
)
st.sidebar.markdown("---")
st.sidebar.subheader("Inference Engine")
engine_choice = st.sidebar.selectbox(
    "Select Model Engine",
    ["YOLO26 (PyTorch - Native)", "YOLO26 (ONNX - High Speed)"],
    index=0,
    help="ONNX is optimized for latency. PyTorch provides native object handling."
)
st.sidebar.markdown("---")
st.sidebar.subheader("🛠️ Technical Controls")
show_trace = st.sidebar.toggle("🔬 Show Technical Trace", value=False, help="Enable to see raw mathematical tensor data alongside AI labels.")
# Persisted so Tab 1 can switch between the technical and regular views.
st.session_state.tech_trace = show_trace
# Map the human-readable engine label to the backend's model_type parameter.
model_type = "onnx" if "ONNX" in engine_choice else "pytorch"
if model_type == "pytorch":
    st.sidebar.warning("PyTorch Engine: Higher Memory Usage")
else:
    st.sidebar.info("ONNX Engine: ~39ms Latency")
st.sidebar.markdown("---")
if st.sidebar.button("❓ How to read results?", icon="📘", width='stretch'):
    show_tech_guide()
  99. # Helper to reset results when files change
  100. def reset_single_results():
  101. st.session_state.last_detection = None
  102. def reset_batch_results():
  103. st.session_state.last_batch_results = None
# MPOB Color Map for Overlays (Global for consistency)
# Keys match the ripeness class labels returned by the backend; values are
# hex colors shared by the Plotly hover-boxes and the PIL annotation plates.
overlay_colors = {
    'Ripe': '#22c55e', # Industrial Green
    'Underripe': '#fbbf24', # Industrial Orange
    'Unripe': '#3b82f6', # Industrial Blue
    'Abnormal': '#dc2626', # Critical Red
    'Empty_Bunch': '#64748b',# Waste Gray
    'Overripe': '#7c2d12' # Dark Brown/Orange
}
def display_interactive_results(image, detections, key=None):
    """Renders image with interactive hover-boxes using Plotly.

    Args:
        image: PIL.Image whose pixel size defines the plot coordinate system.
        detections: list of dicts with keys 'box' ([x1, y1, x2, y2] in image
            pixels), 'class', 'confidence', 'is_health_alert' and optionally
            'bunch_id'.
        key: optional unique Streamlit element key (needed when the chart is
            rendered more than once per page).
    """
    img_width, img_height = image.size
    fig = go.Figure()
    # Add the palm image as the background layer of the figure.
    fig.add_layout_image(
        dict(source=image, x=0, y=img_height, sizex=img_width, sizey=img_height,
             sizing="stretch", opacity=1, layer="below", xref="x", yref="y")
    )
    # Configure axes to match image dimensions; scaleanchor keeps 1:1 aspect.
    fig.update_xaxes(showgrid=False, range=(0, img_width), zeroline=False, visible=False)
    fig.update_yaxes(showgrid=False, range=(0, img_height), zeroline=False, visible=False, scaleanchor="x")
    # Add one interactive box trace per detection.
    for i, det in enumerate(detections):
        x1, y1, x2, y2 = det['box']
        # Plotly y-axis is inverted relative to PIL, so we flip y
        y_top, y_bottom = img_height - y1, img_height - y2
        color = overlay_colors.get(det['class'], "#ffeb3b")
        # The 'Hover' shape; falls back to the 1-based loop index when the
        # backend did not assign a bunch_id.
        bunch_id = det.get('bunch_id', i+1)
        fig.add_trace(go.Scatter(
            x=[x1, x2, x2, x1, x1],
            y=[y_top, y_top, y_bottom, y_bottom, y_top],
            fill="toself",
            fillcolor=color,
            opacity=0.3, # Semi-transparent until hover
            mode='lines',
            line=dict(color=color, width=3),
            name=f"ID: #{bunch_id}", # Unified ID Tag
            text=f"<b>ID: #{bunch_id}</b><br>Grade: {det['class']}<br>Score: {det['confidence']:.2f}<br>Alert: {det['is_health_alert']}",
            hoverinfo="text"
        ))
    fig.update_layout(width=800, height=600, margin=dict(l=0, r=0, b=0, t=0), showlegend=False)
    st.plotly_chart(fig, use_container_width=True, key=key)
  147. def annotate_image(image, detections):
  148. """Draws high-visibility 'Plated Labels' and boxes on the image."""
  149. from PIL import ImageDraw, ImageFont
  150. draw = ImageDraw.Draw(image)
  151. # 1. Dynamic Font Scaling (width // 40 as requested)
  152. font_size = max(20, image.width // 40)
  153. try:
  154. # standard Windows font paths for agent environment
  155. font_path = "C:\\Windows\\Fonts\\arialbd.ttf" # Bold for higher visibility
  156. if not os.path.exists(font_path):
  157. font_path = "C:\\Windows\\Fonts\\arial.ttf"
  158. if os.path.exists(font_path):
  159. font = ImageFont.truetype(font_path, font_size)
  160. else:
  161. font = ImageFont.load_default()
  162. except:
  163. font = ImageFont.load_default()
  164. for det in detections:
  165. box = det['box'] # [x1, y1, x2, y2]
  166. cls = det['class']
  167. conf = det['confidence']
  168. bunch_id = det.get('bunch_id', '?')
  169. color = overlay_colors.get(cls, '#ffffff')
  170. # 2. Draw Heavy-Duty Bounding Box
  171. line_width = max(4, image.width // 150)
  172. draw.rectangle(box, outline=color, width=line_width)
  173. # 3. Draw 'Plated Label' (Background Shaded)
  174. label = f"#{bunch_id} {cls} {conf:.2f}"
  175. try:
  176. # Precise background calculation using textbbox
  177. l, t, r, b = draw.textbbox((box[0], box[1]), label, font=font)
  178. # Shift background up so it doesn't obscure the fruit
  179. bg_rect = [l - 2, t - (b - t) - 10, r + 2, t - 6]
  180. draw.rectangle(bg_rect, fill=color)
  181. # Draw text inside the plate
  182. draw.text((l, t - (b - t) - 8), label, fill="white", font=font)
  183. except:
  184. # Simple fallback
  185. draw.text((box[0], box[1] - font_size), label, fill=color)
  186. return image
  187. def generate_batch_report(data, uploaded_files_map=None):
  188. """Generates a professional PDF report for batch results with visual evidence."""
  189. from PIL import ImageDraw
  190. pdf = FPDF()
  191. pdf.add_page()
  192. pdf.set_font("Arial", "B", 16)
  193. pdf.cell(190, 10, "Palm Oil FFB Harvest Quality Report", ln=True, align="C")
  194. pdf.set_font("Arial", "", 12)
  195. pdf.cell(190, 10, f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align="C")
  196. pdf.ln(10)
  197. # 1. Summary Table
  198. pdf.set_font("Arial", "B", 14)
  199. pdf.cell(190, 10, "1. Batch Summary", ln=True)
  200. pdf.set_font("Arial", "", 12)
  201. summary = data.get('industrial_summary', {})
  202. total_bunches = data.get('total_count', 0)
  203. pdf.cell(95, 10, "Metric", border=1)
  204. pdf.cell(95, 10, "Value", border=1, ln=True)
  205. pdf.cell(95, 10, "Total Bunches Detected", border=1)
  206. pdf.cell(95, 10, str(total_bunches), border=1, ln=True)
  207. for grade, count in summary.items():
  208. if count > 0:
  209. pdf.cell(95, 10, f"Grade: {grade}", border=1)
  210. pdf.cell(95, 10, str(count), border=1, ln=True)
  211. pdf.ln(10)
  212. # 2. Strategic Insights
  213. pdf.set_font("Arial", "B", 14)
  214. pdf.cell(190, 10, "2. Strategic Yield Insights", ln=True)
  215. pdf.set_font("Arial", "", 12)
  216. unripe = summary.get('Unripe', 0)
  217. underripe = summary.get('Underripe', 0)
  218. loss = unripe + underripe
  219. if loss > 0:
  220. pdf.multi_cell(190, 10, f"WARNING: {loss} bunches were harvested before peak ripeness. "
  221. "This directly impacts the Oil Extraction Rate (OER) and results in potential yield loss.")
  222. else:
  223. pdf.multi_cell(190, 10, "EXCELLENT: All detected bunches meet prime ripeness standards. Harvest efficiency is 100%.")
  224. # Critical Alerts
  225. abnormal = summary.get('Abnormal', 0)
  226. empty = summary.get('Empty_Bunch', 0)
  227. if abnormal > 0 or empty > 0:
  228. pdf.ln(5)
  229. pdf.set_text_color(220, 0, 0)
  230. pdf.set_font("Arial", "B", 12)
  231. pdf.cell(190, 10, "CRITICAL HEALTH ALERTS:", ln=True)
  232. pdf.set_font("Arial", "", 12)
  233. if abnormal > 0:
  234. pdf.cell(190, 10, f"- {abnormal} Abnormal Bunches detected (Requires immediate field inspection).", ln=True)
  235. if empty > 0:
  236. pdf.cell(190, 10, f"- {empty} Empty Bunches detected (Waste reduction needed).", ln=True)
  237. pdf.set_text_color(0, 0, 0)
  238. # 3. Visual Evidence Section
  239. if 'detailed_results' in data and uploaded_files_map:
  240. pdf.add_page()
  241. pdf.set_font("Arial", "B", 14)
  242. pdf.cell(190, 10, "3. Visual Batch Evidence (AI Overlay)", ln=True)
  243. pdf.ln(5)
  244. # Group detections by filename
  245. results_by_file = {}
  246. for res in data['detailed_results']:
  247. fname = res['filename']
  248. if fname not in results_by_file:
  249. results_by_file[fname] = []
  250. results_by_file[fname].append(res['detection'])
  251. for fname, detections in results_by_file.items():
  252. if fname in uploaded_files_map:
  253. img_bytes = uploaded_files_map[fname]
  254. img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
  255. draw = ImageDraw.Draw(img)
  256. # Drawing annotated boxes for PDF using high-visibility utility
  257. annotate_image(img, detections)
  258. # Save to temp file for PDF
  259. temp_img_path = f"temp_report_{fname}"
  260. img.save(temp_img_path)
  261. # Check if we need a new page based on image height (rough estimate)
  262. if pdf.get_y() > 200:
  263. pdf.add_page()
  264. pdf.image(temp_img_path, x=10, w=150)
  265. pdf.set_font("Arial", "I", 10)
  266. pdf.cell(190, 10, f"Annotated: {fname}", ln=True)
  267. pdf.ln(5)
  268. os.remove(temp_img_path)
  269. # Footer
  270. pdf.set_y(-15)
  271. pdf.set_font("Arial", "I", 8)
  272. pdf.cell(190, 10, "Generated by Palm Oil AI Desktop PoC - YOLO26 Engine", align="C")
  273. return pdf.output(dest='S')
# --- Tabs ---
# Top-level navigation: each tab body below renders independently per rerun.
tab1, tab2, tab3, tab4 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search", "History Vault"])
# --- Tab 1: Single Analysis ---
with tab1:
    st.subheader("Analyze Single Bunch")
    uploaded_file = st.file_uploader(
        "Upload a bunch image...",
        type=["jpg", "jpeg", "png"],
        key="single",
        on_change=reset_single_results
    )
    if uploaded_file:
        # State initialization
        if "last_detection" not in st.session_state:
            st.session_state.last_detection = None
        # 1. Auto-Detection Trigger: runs once per new upload; the result is
        # cached in session_state so subsequent reruns skip the API call.
        if uploaded_file and st.session_state.last_detection is None:
            with st.spinner(f"Processing with {model_type.upper()} Engine..."):
                files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                payload = {"model_type": model_type}
                res = requests.post(f"{API_BASE_URL}/analyze", files=files, data=payload)
                if res.status_code == 200:
                    st.session_state.last_detection = res.json()
                    st.rerun() # Refresh to show results immediately
                else:
                    st.error(f"Detection Failed: {res.text}")
        # 2. Results Layout
        if st.session_state.last_detection:
            data = st.session_state.last_detection
            st.divider()
            st.write("### 📈 Manager's Dashboard")
            m_col1, m_col2, m_col3, m_col4 = st.columns(4)
            with m_col1:
                st.metric("Total Bunches", data.get('total_count', 0))
            with m_col2:
                st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
            with m_col3:
                # Refined speed label based on engine
                speed_label = "Raw Speed (Unlabeled)" if model_type == "onnx" else "Wrapped Speed (Auto-Labeled)"
                st.metric("Inference Speed", f"{data.get('inference_ms', 0):.1f} ms", help=speed_label)
            with m_col4:
                st.metric("Post-Processing", f"{data.get('processing_ms', 0):.1f} ms", help="Labeling/Scaling overhead")
            st.divider()
            # Side-by-Side View (Technical Trace)
            img = Image.open(uploaded_file).convert("RGB")
            if st.session_state.get('tech_trace', False):
                t_col1, t_col2 = st.columns(2)
                with t_col1:
                    st.subheader("🔢 Raw Output Tensor (The Math)")
                    st.caption("First 5 rows of the 1x300x6 detection tensor.")
                    st.json(data.get('raw_array_sample', []))
                with t_col2:
                    st.subheader("🎨 AI Interpretation")
                    img_annotated = annotate_image(img.copy(), data['detections'])
                    st.image(img_annotated, width='stretch')
            else:
                # Regular View
                st.write("### 🔍 AI Analytical View")
                display_interactive_results(img, data['detections'], key="main_viewer")
            # NOTE(review): col2 is created but never entered below; all of the
            # following content renders under col1 — confirm that is intended.
            col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
            with col1:
                col_tech_h1, col_tech_h2 = st.columns([4, 1])
                with col_tech_h1:
                    st.write("#### 🛠️ Technical Evidence")
                with col_tech_h2:
                    if st.button("❓ Guide", key="guide_tab1"):
                        show_tech_guide()
                with st.expander("Raw Output Tensor (NMS-Free)", expanded=False):
                    st.caption("See the Interpretation Guide for a breakdown of these numbers.")
                    st.json(data.get('raw_array_sample', []))
                with st.container(border=True):
                    st.write("### 🏷️ Detection Results")
                    if not data['detections']:
                        st.warning("No Fresh Fruit Bunches detected.")
                    else:
                        for det in data['detections']:
                            st.info(f"### Bunch #{det['bunch_id']}: {det['class']} ({det['confidence']:.2%})")
                st.write("### 📊 Harvest Quality Mix")
                # Convert industrial_summary dictionary to a DataFrame for charting
                summary_df = pd.DataFrame(
                    list(data['industrial_summary'].items()),
                    columns=['Grade', 'Count']
                )
                # Filter out classes with 0 count for a cleaner chart
                summary_df = summary_df[summary_df['Count'] > 0]
                if not summary_df.empty:
                    # Create a Pie Chart to show the proportion of each grade
                    fig = px.pie(summary_df, values='Count', names='Grade',
                                 color='Grade',
                                 color_discrete_map={
                                     'Ripe': '#22c55e', # Industrial Green
                                     'Underripe': '#fbbf24', # Industrial Orange
                                     'Unripe': '#3b82f6', # Industrial Blue
                                     'Abnormal': '#dc2626', # Critical Red
                                     'Empty_Bunch': '#64748b' # Waste Gray
                                 },
                                 hole=0.4)
                    fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
                    st.plotly_chart(fig, width='stretch', key="single_pie")
                # 💡 Strategic R&D Insight: Harvest Efficiency
                st.write("---")
                st.write("#### 💡 Strategic R&D Insight")
                unripe_count = data['industrial_summary'].get('Unripe', 0)
                underripe_count = data['industrial_summary'].get('Underripe', 0)
                total_non_prime = unripe_count + underripe_count
                st.write(f"🌑 **Unripe (Mentah):** {unripe_count}")
                st.write(f"🌗 **Underripe (Kurang Masak):** {underripe_count}")
                if total_non_prime > 0:
                    st.warning(f"🚨 **Potential Yield Loss:** {total_non_prime} bunches harvested too early. This will reduce OER (Oil Extraction Rate).")
                else:
                    st.success("✅ **Harvest Efficiency:** 100% Prime Ripeness detected.")
                # High-Priority Health Alert
                if data['industrial_summary'].get('Abnormal', 0) > 0:
                    st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
                if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
                    st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
                # 3. Cloud Actions (Only if detections found)
                st.write("---")
                st.write("#### ✨ Cloud Archive")
                if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
                    with st.spinner("Archiving..."):
                        import json  # shadows the module-level import; redundant but harmless
                        # NOTE(review): indexing [0] assumes at least one detection
                        # exists — confirm this button is unreachable when the
                        # detections list is empty.
                        primary_det = data['detections'][0]
                        payload = {"detection_data": json.dumps(primary_det)}
                        files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                        res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
                        if res_cloud.status_code == 200:
                            res_json = res_cloud.json()
                            if res_json["status"] == "success":
                                st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
                            else:
                                st.error(f"Cloud Error: {res_json['message']}")
                        else:
                            st.error("Failed to connect to cloud service")
                if st.button("🚩 Flag Misclassification", width='stretch', type="secondary"):
                    # Save to local feedback folder
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    feedback_id = f"fb_{timestamp}"
                    img_path = f"feedback/{feedback_id}.jpg"
                    json_path = f"feedback/{feedback_id}.json"
                    # Save image
                    # NOTE(review): assumes the 'feedback/' directory already exists — verify.
                    Image.open(uploaded_file).save(img_path)
                    # Save metadata
                    feedback_data = {
                        "original_filename": uploaded_file.name,
                        "timestamp": timestamp,
                        "detections": data['detections'],
                        "threshold_used": data['current_threshold']
                    }
                    with open(json_path, "w") as f:
                        json.dump(feedback_data, f, indent=4)
                    st.toast("✅ Feedback saved to local vault!", icon="🚩")
                # Disabled placeholder button: archiving happens automatically server-side.
                if st.button("💾 Local History Vault (Auto-Saved)", width='stretch', type="secondary", disabled=True):
                    pass
                st.caption("✅ This analysis was automatically archived to the local vault.")
# --- Tab 2: Batch Processing ---
with tab2:
    st.subheader("Bulk Analysis")
    # 1. Initialize Session State
    if "batch_uploader_key" not in st.session_state:
        st.session_state.batch_uploader_key = 0
    if "last_batch_results" not in st.session_state:
        st.session_state.last_batch_results = None
    # 2. Display Persisted Results (if any)
    if st.session_state.last_batch_results:
        res_data = st.session_state.last_batch_results
        with st.container(border=True):
            st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
            # Batch Summary Dashboard
            st.write("### 📈 Batch Quality Overview")
            batch_summary = res_data.get('industrial_summary', {})
            if batch_summary:
                sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
                sum_df = sum_df[sum_df['Count'] > 0]
                b_col1, b_col2 = st.columns([1, 1])
                with b_col1:
                    st.dataframe(sum_df, hide_index=True, width='stretch')
                with b_col2:
                    if not sum_df.empty:
                        fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
                                           color_discrete_map={
                                               'Ripe': '#22c55e',
                                               'Underripe': '#fbbf24',
                                               'Unripe': '#3b82f6',
                                               'Abnormal': '#dc2626',
                                               'Empty_Bunch': '#64748b'
                                           })
                        fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
                        st.plotly_chart(fig_batch, width='stretch', key="batch_bar")
                if batch_summary.get('Abnormal', 0) > 0:
                    st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
            st.write("Generated Record IDs:")
            st.code(res_data['record_ids'])
            # --- 4. Batch Evidence Gallery ---
            st.write("### 🖼️ Detailed Detection Evidence")
            # NOTE(review): this section reads 'uploaded_files', which is only
            # assigned by the uploader widget further down, and the uploader key
            # is bumped after processing (clearing the widget). Confirm
            # 'uploaded_files' is defined and non-empty on the rerun that
            # displays results, otherwise this raises NameError / shows nothing.
            if 'detailed_results' in res_data:
                # Group results by filename for gallery
                gallery_map = {}
                for res in res_data['detailed_results']:
                    fname = res['filename']
                    if fname not in gallery_map:
                        gallery_map[fname] = []
                    gallery_map[fname].append(res['detection'])
                # Show images with overlays using consistent utility
                for up_file in uploaded_files:
                    if up_file.name in gallery_map:
                        with st.container(border=True):
                            g_img = Image.open(up_file).convert("RGB")
                            g_annotated = annotate_image(g_img, gallery_map[up_file.name])
                            st.image(g_annotated, caption=f"Evidence: {up_file.name}", width='stretch')
            # PDF Export Button (Pass images map)
            files_map = {f.name: f.getvalue() for f in uploaded_files}
            pdf_bytes = generate_batch_report(res_data, files_map)
            st.download_button(
                label="📄 Download Executive Batch Report (PDF)",
                data=pdf_bytes,
                file_name=f"PalmOil_BatchReport_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf",
                mime="application/pdf",
                width='stretch'
            )
            if st.button("Clear Results & Start New Batch", width='stretch'):
                st.session_state.last_batch_results = None
                st.rerun()
        st.divider()
    # 3. Uploader UI
    col_batch1, col_batch2 = st.columns([4, 1])
    with col_batch1:
        uploaded_files = st.file_uploader(
            "Upload multiple images...",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
            # Key bump forces Streamlit to create a fresh (empty) uploader.
            key=f"batch_{st.session_state.batch_uploader_key}",
            on_change=reset_batch_results
        )
    with col_batch2:
        st.write("##") # Alignment
        if st.session_state.last_batch_results is None and uploaded_files:
            if st.button("🔍 Process Batch", type="primary", width='stretch'):
                with st.spinner(f"Analyzing {len(uploaded_files)} images with {model_type.upper()}..."):
                    files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                    payload = {"model_type": model_type}
                    res = requests.post(f"{API_BASE_URL}/process_batch", files=files, data=payload)
                    if res.status_code == 200:
                        data = res.json()
                        if data["status"] == "success":
                            st.session_state.last_batch_results = data
                            st.session_state.batch_uploader_key += 1
                            st.rerun()
                        elif data["status"] == "partial_success":
                            st.warning(data["message"])
                            st.info(f"Successfully detected {data['detections_count']} bunches locally.")
                        else:
                            st.error(f"Batch Error: {data['message']}")
                    else:
                        st.error(f"Batch Processing Failed: {res.text}")
    if st.button("🗑️ Reset Uploader"):
        st.session_state.batch_uploader_key += 1
        st.session_state.last_batch_results = None
        st.rerun()
# --- Tab 3: Similarity Search ---
with tab3:
    st.subheader("Hybrid Semantic Search")
    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
    with st.form("hybrid_search_form"):
        col_input1, col_input2 = st.columns(2)
        with col_input1:
            search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
        with col_input2:
            text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
            top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
        submit_search = st.form_submit_button("Run Semantic Search")
    if submit_search:
        if not search_file and not text_query:
            st.warning("Please provide either an image or a text query.")
        else:
            with st.spinner("Searching Vector Index..."):
                payload = {"limit": top_k}
                # If an image is uploaded, it takes precedence for visual search
                if search_file:
                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                    # Pass top_k as part of the data
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
                # Otherwise, use text query
                elif text_query:
                    payload["text_query"] = text_query
                    # Send as form-data (data=) to match FastAPI's Form(None)
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
                if res.status_code == 200:
                    results = res.json().get("results", [])
                    if not results:
                        st.warning("No similar records found.")
                    else:
                        st.success(f"Found {len(results)} matches.")
                        for item in results:
                            with st.container(border=True):
                                c1, c2 = st.columns([1, 2])
                                # Fetch the stored image for this result by record id.
                                rec_id = item["_id"]
                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
                                with c1:
                                    if img_res.status_code == 200:
                                        img_b64 = img_res.json().get("image_data")
                                        if img_b64:
                                            st.image(base64.b64decode(img_b64), width=250)
                                        else:
                                            st.write("No image data found.")
                                    else:
                                        st.write("Failed to load image.")
                                with c2:
                                    st.write(f"**Class:** {item['ripeness_class']}")
                                    st.write(f"**Similarity Score:** {item['score']:.4f}")
                                    st.write(f"**Timestamp:** {item['timestamp']}")
                                    st.write(f"**ID:** `{rec_id}`")
                else:
                    st.error(f"Search failed: {res.text}")
# --- Tab 4: History Vault ---
with tab4:
    st.subheader("📜 Local History Vault")
    st.caption("Industrial-grade audit log of all past AI harvest scans.")
    if "selected_history_id" not in st.session_state:
        st.session_state.selected_history_id = None
    try:
        res = requests.get(f"{API_BASE_URL}/get_history")
        if res.status_code == 200:
            history_data = res.json().get("history", [])
            if not history_data:
                st.info("No saved records found in the vault.")
            else:
                if st.session_state.selected_history_id is None:
                    # --- 1. ListView Mode (Management Dashboard) ---
                    st.write("### 📋 Audit Log")
                    # Prepare searchable dataframe
                    df_history = pd.DataFrame(history_data)
                    # Clean up for display
                    display_df = df_history[['id', 'timestamp', 'filename', 'inference_ms']].copy()
                    display_df.columns = ['ID', 'Date/Time', 'Filename', 'Inference (ms)']
                    st.dataframe(
                        display_df,
                        hide_index=True,
                        use_container_width=True,
                        column_config={
                            "ID": st.column_config.NumberColumn(width="small"),
                            "Inference (ms)": st.column_config.NumberColumn(format="%.1f ms")
                        }
                    )
                    # Industrial Selection UI
                    hist_col1, hist_col2 = st.columns([3, 1])
                    with hist_col1:
                        target_id = st.selectbox(
                            "Select Record for Deep Dive Analysis",
                            options=df_history['id'].tolist(),
                            format_func=lambda x: f"Record #{x} - {df_history[df_history['id']==x]['filename'].values[0]}"
                        )
                    with hist_col2:
                        st.write("##") # Alignment
                        if st.button("🔬 Start Deep Dive", type="primary", use_container_width=True):
                            st.session_state.selected_history_id = target_id
                            st.rerun()
                else:
                    # --- 2. Detail View Mode (Technical Auditor) ---
                    record = next((item for item in history_data if item["id"] == st.session_state.selected_history_id), None)
                    if not record:
                        st.error("Audit record not found.")
                        if st.button("Back to List"):
                            st.session_state.selected_history_id = None
                            st.rerun()
                    else:
                        # on_click callback clears the selection before the next rerun.
                        st.button("⬅️ Back to Audit Log", on_click=lambda: st.session_state.update({"selected_history_id": None}))
                        st.divider()
                        st.write(f"## 🔍 Deep Dive: Record #{record['id']}")
                        st.caption(f"Original Filename: `{record['filename']}` | Processed: `{record['timestamp']}`")
                        # Stored as JSON strings in the vault; decode for display.
                        detections = json.loads(record['detections'])
                        summary = json.loads(record['summary'])
                        # Metrics Executive Summary
                        h_col1, h_col2, h_col3, h_col4 = st.columns(4)
                        with h_col1:
                            st.metric("Total Bunches", sum(summary.values()))
                        with h_col2:
                            st.metric("Healthy (Ripe)", summary.get('Ripe', 0))
                        with h_col3:
                            # `or 0` guards legacy records where the field is None.
                            st.metric("Engine Performance", f"{record.get('inference_ms', 0) or 0:.1f} ms")
                        with h_col4:
                            st.metric("Labeling Overhead", f"{record.get('processing_ms', 0) or 0:.1f} ms")
                        # Re-Annotate Archived Image
                        if os.path.exists(record['archive_path']):
                            with open(record['archive_path'], "rb") as f:
                                # NOTE(review): PIL loads lazily; the image is used
                                # after this file handle closes — confirm .convert()
                                # forces a full load here (it appears to).
                                hist_img = Image.open(f).convert("RGB")
                            # Side-by-Side: Interactive vs Static Plate
                            v_tab1, v_tab2 = st.tabs(["Interactive Plotly View", "Static Annotated Evidence"])
                            with v_tab1:
                                display_interactive_results(hist_img, detections, key=f"hist_plotly_{record['id']}")
                            with v_tab2:
                                img_plate = annotate_image(hist_img.copy(), detections)
                                st.image(img_plate, use_container_width=True, caption="Point-of-Harvest AI Interpretation")
                        else:
                            st.warning(f"Technical Error: Archive file missing at `{record['archive_path']}`")
                        # Technical Evidence Expander (Mathematical Audit)
                        st.divider()
                        st.write("### 🛠️ Technical Audit Trail")
                        with st.expander("🔬 View Raw Mathematical Tensor", expanded=False):
                            st.info("This is the exact numerical output from the AI engine prior to human-readable transformation.")
                            raw_data = record.get('raw_tensor')
                            if raw_data:
                                try:
                                    st.json(json.loads(raw_data))
                                except:
                                    # Not valid JSON (legacy record): show as plain text.
                                    st.code(raw_data)
                            else:
                                st.warning("No raw tensor trace was archived for this legacy record.")
        else:
            st.error(f"Vault Connection Failed: {res.text}")
    except Exception as e:
        st.error(f"Audit System Error: {str(e)}")