# demo_app.py — Streamlit front-end for the Palm Oil FFB Management System.
import streamlit as st
import requests
from ultralytics import YOLO  # NOTE(review): unused since local model loading was removed — confirm before deleting
import numpy as np  # NOTE(review): unused in this file — confirm before deleting
from PIL import Image
import io  # NOTE(review): unused in this file — confirm before deleting
import base64
import pandas as pd
import plotly.express as px
  10. # --- 1. Global Backend Check ---
  11. API_BASE_URL = "http://localhost:8000"
  12. def check_backend():
  13. try:
  14. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  15. return res.status_code == 200
  16. except:
  17. return False
  18. backend_active = check_backend()
  19. # LOCAL MODEL LOADING REMOVED (YOLO26 Clean Sweep)
  20. # UI now relies entirely on Backend API for NMS-Free inference.
  21. if not backend_active:
  22. st.error("⚠️ Backend API is offline!")
  23. st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
  24. if st.button("🔄 Retry Connection"):
  25. st.rerun()
  26. st.stop() # Stops execution here, effectively disabling the app
  27. # --- 2. Main Page Config (Only rendered if backend is active) ---
  28. st.set_page_config(page_title="Palm Oil Ripeness AI (YOLO26)", layout="wide")
  29. st.title("🌴 Palm Oil FFB Management System")
  30. st.markdown("### Production-Ready AI Analysis & Archival")
  31. # --- Sidebar ---
  32. st.sidebar.header("Backend Controls")
  33. def update_confidence():
  34. new_conf = st.session_state.conf_slider
  35. try:
  36. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  37. st.toast(f"Threshold updated to {new_conf}")
  38. except:
  39. st.sidebar.error("Failed to update threshold")
  40. # We already know backend is up here
  41. response = requests.get(f"{API_BASE_URL}/get_confidence")
  42. current_conf = response.json().get("current_confidence", 0.25)
  43. st.sidebar.success(f"Connected to API")
  44. st.sidebar.info("Engine: YOLO26 NMS-Free (Inference: ~39ms)")
  45. # Synchronized Slider
  46. st.sidebar.slider(
  47. "Confidence Threshold",
  48. 0.1, 1.0,
  49. value=float(current_conf),
  50. key="conf_slider",
  51. on_change=update_confidence
  52. )
  53. # Helper to reset results when files change
  54. def reset_single_results():
  55. st.session_state.last_detection = None
  56. def reset_batch_results():
  57. st.session_state.last_batch_results = None
# --- Tabs ---
# Three workflows: per-image analysis, bulk processing, and vector search.
tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
  60. # --- Tab 1: Single Analysis ---
  61. with tab1:
  62. st.subheader("Analyze Single Bunch")
  63. uploaded_file = st.file_uploader(
  64. "Upload a bunch image...",
  65. type=["jpg", "jpeg", "png"],
  66. key="single",
  67. on_change=reset_single_results
  68. )
  69. if uploaded_file:
  70. # State initialization
  71. if "last_detection" not in st.session_state:
  72. st.session_state.last_detection = None
  73. # 1. Auto-Detection Trigger
  74. if uploaded_file and st.session_state.last_detection is None:
  75. with st.spinner("Processing Detections Locally..."):
  76. files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  77. res = requests.post(f"{API_BASE_URL}/analyze", files=files)
  78. if res.status_code == 200:
  79. st.session_state.last_detection = res.json()
  80. st.rerun() # Refresh to show results immediately
  81. else:
  82. st.error(f"Detection Failed: {res.text}")
  83. # 2. Results Layout
  84. if st.session_state.last_detection:
  85. st.divider()
  86. # SIDE-BY-SIDE ANALYTICAL VIEW
  87. col_left, col_right = st.columns(2)
  88. # Fetch data once
  89. data = st.session_state.last_detection
  90. with col_left:
  91. st.image(uploaded_file, caption="Original Photo", width='stretch')
  92. with col_right:
  93. # MANUAL OVERLAY DRAWING (NMS-Free Output from API)
  94. img = Image.open(uploaded_file).convert("RGB")
  95. from PIL import ImageDraw, ImageFont
  96. draw = ImageDraw.Draw(img)
  97. # MPOB Color Map for Overlays
  98. overlay_colors = {
  99. 'Ripe': '#22c55e', # Industrial Green
  100. 'Underripe': '#fbbf24', # Industrial Orange
  101. 'Unripe': '#3b82f6', # Industrial Blue
  102. 'Abnormal': '#dc2626', # Critical Red
  103. 'Empty_Bunch': '#64748b' # Waste Gray
  104. }
  105. for det in data['detections']:
  106. box = det['box'] # [x1, y1, x2, y2]
  107. cls = det['class']
  108. color = overlay_colors.get(cls, '#ffffff')
  109. # Draw Box
  110. draw.rectangle(box, outline=color, width=4)
  111. # Draw Label Background
  112. label = f"{cls} {det['confidence']:.2f}"
  113. draw.text((box[0], box[1] - 15), label, fill=color)
  114. st.image(img, caption="AI Analytical View (NMS-Free Native)", width='stretch')
  115. st.write("### 📈 Manager's Dashboard")
  116. m_col1, m_col2, m_col3 = st.columns(3)
  117. with m_col1:
  118. st.metric("Total Bunches", data.get('total_count', 0))
  119. with m_col2:
  120. st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
  121. with m_col3:
  122. abnormal = data['industrial_summary'].get('Abnormal', 0)
  123. st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
  124. col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
  125. with col2:
  126. with st.container(border=True):
  127. st.write("### 🏷️ Detection Results")
  128. if not data['detections']:
  129. st.warning("No Fresh Fruit Bunches detected.")
  130. else:
  131. for det in data['detections']:
  132. st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
  133. st.write("### 📊 Harvest Quality Mix")
  134. # Convert industrial_summary dictionary to a DataFrame for charting
  135. summary_df = pd.DataFrame(
  136. list(data['industrial_summary'].items()),
  137. columns=['Grade', 'Count']
  138. )
  139. # Filter out classes with 0 count for a cleaner chart
  140. summary_df = summary_df[summary_df['Count'] > 0]
  141. if not summary_df.empty:
  142. # Create a Pie Chart to show the proportion of each grade
  143. fig = px.pie(summary_df, values='Count', names='Grade',
  144. color='Grade',
  145. color_discrete_map={
  146. 'Ripe': '#22c55e', # Industrial Green
  147. 'Underripe': '#fbbf24', # Industrial Orange
  148. 'Unripe': '#3b82f6', # Industrial Blue
  149. 'Abnormal': '#dc2626', # Critical Red
  150. 'Empty_Bunch': '#64748b' # Waste Gray
  151. },
  152. hole=0.4)
  153. fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
  154. st.plotly_chart(fig, width='stretch')
  155. # 💡 Strategic R&D Insight: Harvest Efficiency
  156. st.write("---")
  157. st.write("#### 💡 Strategic R&D Insight")
  158. unripe_count = data['industrial_summary'].get('Unripe', 0)
  159. underripe_count = data['industrial_summary'].get('Underripe', 0)
  160. total_non_prime = unripe_count + underripe_count
  161. st.write(f"🌑 **Unripe (Mentah):** {unripe_count}")
  162. st.write(f"🌗 **Underripe (Kurang Masak):** {underripe_count}")
  163. if total_non_prime > 0:
  164. st.warning(f"🚨 **Potential Yield Loss:** {total_non_prime} bunches harvested too early. This will reduce OER (Oil Extraction Rate).")
  165. else:
  166. st.success("✅ **Harvest Efficiency:** 100% Prime Ripeness detected.")
  167. # High-Priority Health Alert
  168. if data['industrial_summary'].get('Abnormal', 0) > 0:
  169. st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
  170. if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
  171. st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
  172. # 3. Cloud Actions (Only if detections found)
  173. st.write("---")
  174. st.write("#### ✨ Cloud Archive")
  175. if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
  176. with st.spinner("Archiving..."):
  177. import json
  178. primary_det = data['detections'][0]
  179. payload = {"detection_data": json.dumps(primary_det)}
  180. files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  181. res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
  182. if res_cloud.status_code == 200:
  183. res_json = res_cloud.json()
  184. if res_json["status"] == "success":
  185. st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
  186. else:
  187. st.error(f"Cloud Error: {res_json['message']}")
  188. else:
  189. st.error("Failed to connect to cloud service")
# --- Tab 2: Batch Processing ---
with tab2:
    st.subheader("Bulk Analysis")
    # 1. Initialize Session State
    # batch_uploader_key is bumped to remount a fresh file_uploader widget
    # (the standard Streamlit "reset uploader" trick via a changing key).
    if "batch_uploader_key" not in st.session_state:
        st.session_state.batch_uploader_key = 0
    if "last_batch_results" not in st.session_state:
        st.session_state.last_batch_results = None
    # 2. Display Persisted Results (if any) — survives reruns until cleared
    if st.session_state.last_batch_results:
        res_data = st.session_state.last_batch_results
        with st.container(border=True):
            st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
            # Batch Summary Dashboard
            st.write("### 📈 Batch Quality Overview")
            batch_summary = res_data.get('industrial_summary', {})
            if batch_summary:
                sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
                # Drop zero-count grades so the table and chart stay readable
                sum_df = sum_df[sum_df['Count'] > 0]
                b_col1, b_col2 = st.columns([1, 1])
                with b_col1:
                    st.dataframe(sum_df, hide_index=True, width='stretch')
                with b_col2:
                    if not sum_df.empty:
                        fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
                                           color_discrete_map={
                                               'Ripe': '#22c55e',
                                               'Underripe': '#fbbf24',
                                               'Unripe': '#3b82f6',
                                               'Abnormal': '#dc2626',
                                               'Empty_Bunch': '#64748b'
                                           })
                        fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
                        st.plotly_chart(fig_batch, width='stretch')
                if batch_summary.get('Abnormal', 0) > 0:
                    st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
            st.write("Generated Record IDs:")
            st.code(res_data['record_ids'])
            if st.button("Clear Results & Start New Batch"):
                st.session_state.last_batch_results = None
                st.rerun()
        st.divider()
    # 3. Uploader UI
    col_batch1, col_batch2 = st.columns([4, 1])
    with col_batch1:
        uploaded_files = st.file_uploader(
            "Upload multiple images...",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
            key=f"batch_{st.session_state.batch_uploader_key}",
            on_change=reset_batch_results
        )
    with col_batch2:
        st.write("##")  # Alignment
    # Only offer processing when there are files and no results pending display
    if st.session_state.last_batch_results is None and uploaded_files:
        if st.button("🔍 Process Batch", type="primary", width='stretch'):
            with st.spinner(f"Analyzing {len(uploaded_files)} images..."):
                files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
                if res.status_code == 200:
                    data = res.json()
                    if data["status"] == "success":
                        st.session_state.last_batch_results = data
                        st.session_state.batch_uploader_key += 1  # remount uploader empty
                        st.rerun()
                    elif data["status"] == "partial_success":
                        # Backend reports detection succeeded but a later step didn't
                        st.warning(data["message"])
                        st.info(f"Successfully detected {data['detections_count']} bunches locally.")
                    else:
                        st.error(f"Batch Error: {data['message']}")
                else:
                    st.error(f"Batch Processing Failed: {res.text}")
    if st.button("🗑️ Reset Uploader"):
        st.session_state.batch_uploader_key += 1
        st.session_state.last_batch_results = None
        st.rerun()
# --- Tab 3: Similarity Search ---
with tab3:
    st.subheader("Hybrid Semantic Search")
    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
    # Form batches the inputs so the search only fires on explicit submit
    with st.form("hybrid_search_form"):
        col_input1, col_input2 = st.columns(2)
        with col_input1:
            search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
        with col_input2:
            text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
        top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
        submit_search = st.form_submit_button("Run Semantic Search")
    if submit_search:
        if not search_file and not text_query:
            st.warning("Please provide either an image or a text query.")
        else:
            # At least one input is set here, so `res` is always assigned below
            with st.spinner("Searching Vector Index..."):
                payload = {"limit": top_k}
                # If an image is uploaded, it takes precedence for visual search
                if search_file:
                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                    # Pass top_k as part of the data
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
                # Otherwise, use text query
                elif text_query:
                    payload["text_query"] = text_query
                    # Send as form-data (data=) to match FastAPI's Form(None)
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
                if res.status_code == 200:
                    results = res.json().get("results", [])
                    if not results:
                        st.warning("No similar records found.")
                    else:
                        st.success(f"Found {len(results)} matches.")
                        for item in results:
                            with st.container(border=True):
                                c1, c2 = st.columns([1, 2])
                                # Fetch the stored image for this result record
                                rec_id = item["_id"]
                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
                                with c1:
                                    if img_res.status_code == 200:
                                        img_b64 = img_res.json().get("image_data")
                                        if img_b64:
                                            st.image(base64.b64decode(img_b64), width=250)
                                        else:
                                            st.write("No image data found.")
                                    else:
                                        st.write("Failed to load image.")
                                with c2:
                                    st.write(f"**Class:** {item['ripeness_class']}")
                                    st.write(f"**Similarity Score:** {item['score']:.4f}")
                                    st.write(f"**Timestamp:** {item['timestamp']}")
                                    st.write(f"**ID:** `{rec_id}`")
                else:
                    st.error(f"Search failed: {res.text}")