# demo_app.py — Streamlit frontend for the Palm Oil FFB ripeness analysis system.
  1. import streamlit as st
  2. import requests
  3. from ultralytics import YOLO
  4. import numpy as np
  5. from PIL import Image
  6. import io
  7. import base64
  8. import pandas as pd
  9. import plotly.express as px
  10. # --- 1. Global Backend Check ---
  11. API_BASE_URL = "http://localhost:8000"
  12. def check_backend():
  13. try:
  14. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  15. return res.status_code == 200
  16. except:
  17. return False
  18. backend_active = check_backend()
  19. # Load YOLO model locally for Analytical View
  20. @st.cache_resource
  21. def load_yolo():
  22. return YOLO('best.pt')
  23. yolo_model = load_yolo()
  24. if not backend_active:
  25. st.error("⚠️ Backend API is offline!")
  26. st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
  27. if st.button("🔄 Retry Connection"):
  28. st.rerun()
  29. st.stop() # Stops execution here, effectively disabling the app
  30. # --- 2. Main Page Config (Only rendered if backend is active) ---
  31. st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
  32. st.title("🌴 Palm Oil FFB Management System")
  33. st.markdown("### Production-Ready AI Analysis & Archival")
  34. # --- Sidebar ---
  35. st.sidebar.header("Backend Controls")
  36. def update_confidence():
  37. new_conf = st.session_state.conf_slider
  38. try:
  39. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  40. st.toast(f"Threshold updated to {new_conf}")
  41. except:
  42. st.sidebar.error("Failed to update threshold")
  43. # We already know backend is up here
  44. response = requests.get(f"{API_BASE_URL}/get_confidence")
  45. current_conf = response.json().get("current_confidence", 0.25)
  46. st.sidebar.success(f"Connected to API")
  47. # Synchronized Slider
  48. st.sidebar.slider(
  49. "Confidence Threshold",
  50. 0.1, 1.0,
  51. value=float(current_conf),
  52. key="conf_slider",
  53. on_change=update_confidence
  54. )
  55. # Helper to reset results when files change
  56. def reset_single_results():
  57. st.session_state.last_detection = None
  58. def reset_batch_results():
  59. st.session_state.last_batch_results = None
  60. # --- Tabs ---
  61. tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
  62. # --- Tab 1: Single Analysis ---
  63. with tab1:
  64. st.subheader("Analyze Single Bunch")
  65. uploaded_file = st.file_uploader(
  66. "Upload a bunch image...",
  67. type=["jpg", "jpeg", "png"],
  68. key="single",
  69. on_change=reset_single_results
  70. )
  71. if uploaded_file:
  72. # State initialization
  73. if "last_detection" not in st.session_state:
  74. st.session_state.last_detection = None
  75. # 1. Action Button (Centered and Prominent)
  76. st.write("##")
  77. _, col_btn, _ = st.columns([1, 2, 1])
  78. if col_btn.button("🔍 Run Ripeness Detection", type="primary", width='stretch'):
  79. with st.spinner("Processing Detections Locally..."):
  80. files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  81. res = requests.post(f"{API_BASE_URL}/analyze", files=files)
  82. if res.status_code == 200:
  83. st.session_state.last_detection = res.json()
  84. else:
  85. st.error(f"Detection Failed: {res.text}")
  86. # 2. Results Layout
  87. if st.session_state.last_detection:
  88. st.divider()
  89. # SIDE-BY-SIDE ANALYTICAL VIEW
  90. col_left, col_right = st.columns(2)
  91. # Fetch data once
  92. data = st.session_state.last_detection
  93. with col_left:
  94. st.image(uploaded_file, caption="Original Photo", width='stretch')
  95. with col_right:
  96. # Use the local model to plot the boxes directly
  97. img = Image.open(uploaded_file)
  98. results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
  99. annotated_img = results[0].plot() # Draws boxes/labels
  100. # Convert BGR (OpenCV format) to RGB for Streamlit
  101. annotated_img_rgb = annotated_img[:, :, ::-1]
  102. st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')
  103. st.write("### 📈 Manager's Dashboard")
  104. m_col1, m_col2, m_col3 = st.columns(3)
  105. with m_col1:
  106. st.metric("Total Bunches", data.get('total_count', 0))
  107. with m_col2:
  108. st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
  109. with m_col3:
  110. abnormal = data['industrial_summary'].get('Abnormal', 0)
  111. st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
  112. col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
  113. with col2:
  114. with st.container(border=True):
  115. st.write("### 🏷️ Detection Results")
  116. if not data['detections']:
  117. st.warning("No Fresh Fruit Bunches detected.")
  118. else:
  119. for det in data['detections']:
  120. st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
  121. st.write("### 📊 Harvest Quality Mix")
  122. # Convert industrial_summary dictionary to a DataFrame for charting
  123. summary_df = pd.DataFrame(
  124. list(data['industrial_summary'].items()),
  125. columns=['Grade', 'Count']
  126. )
  127. # Filter out classes with 0 count for a cleaner chart
  128. summary_df = summary_df[summary_df['Count'] > 0]
  129. if not summary_df.empty:
  130. # Create a Pie Chart to show the proportion of each grade
  131. fig = px.pie(summary_df, values='Count', names='Grade',
  132. color='Grade',
  133. color_discrete_map={
  134. 'Abnormal': '#ef4444', # Red
  135. 'Empty_Bunch': '#94a3b8', # Gray
  136. 'Ripe': '#22c55e', # Green
  137. 'Underripe': '#eab308', # Yellow
  138. 'Unripe': '#3b82f6', # Blue
  139. 'Overripe': '#a855f7' # Purple
  140. },
  141. hole=0.4)
  142. fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
  143. st.plotly_chart(fig, width='stretch')
  144. # High-Priority Health Alert
  145. if data['industrial_summary'].get('Abnormal', 0) > 0:
  146. st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
  147. if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
  148. st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
  149. # 3. Cloud Actions (Only if detections found)
  150. st.write("---")
  151. st.write("#### ✨ Cloud Archive")
  152. if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
  153. with st.spinner("Archiving..."):
  154. import json
  155. primary_det = data['detections'][0]
  156. payload = {"detection_data": json.dumps(primary_det)}
  157. files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  158. res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
  159. if res_cloud.status_code == 200:
  160. res_json = res_cloud.json()
  161. if res_json["status"] == "success":
  162. st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
  163. else:
  164. st.error(f"Cloud Error: {res_json['message']}")
  165. else:
  166. st.error("Failed to connect to cloud service")
# --- Tab 2: Batch Processing ---
with tab2:
    st.subheader("Bulk Analysis")
    # 1. Initialize Session State
    # batch_uploader_key is incremented to force Streamlit to mount a brand-new
    # file_uploader widget (the standard "reset the uploader" trick).
    if "batch_uploader_key" not in st.session_state:
        st.session_state.batch_uploader_key = 0
    if "last_batch_results" not in st.session_state:
        st.session_state.last_batch_results = None
    # 2. Display Persisted Results (if any)
    if st.session_state.last_batch_results:
        res_data = st.session_state.last_batch_results
        with st.container(border=True):
            st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
            # Batch Summary Dashboard
            st.write("### 📈 Batch Quality Overview")
            batch_summary = res_data.get('industrial_summary', {})
            if batch_summary:
                sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
                # Hide zero-count grades for a cleaner table/chart.
                sum_df = sum_df[sum_df['Count'] > 0]
                b_col1, b_col2 = st.columns([1, 1])
                with b_col1:
                    st.dataframe(sum_df, hide_index=True, width='stretch')
                with b_col2:
                    if not sum_df.empty:
                        fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
                                           color_discrete_map={
                                               'Abnormal': '#ef4444',
                                               'Empty_Bunch': '#94a3b8',
                                               'Ripe': '#22c55e'
                                           })
                        fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
                        st.plotly_chart(fig_batch, width='stretch')
                if batch_summary.get('Abnormal', 0) > 0:
                    st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
            st.write("Generated Record IDs:")
            st.code(res_data['record_ids'])
            if st.button("Clear Results & Start New Batch"):
                st.session_state.last_batch_results = None
                st.rerun()
        st.divider()
    # 3. Uploader UI
    col_batch1, col_batch2 = st.columns([4, 1])
    with col_batch1:
        uploaded_files = st.file_uploader(
            "Upload multiple images...",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
            # Dynamic key: bumping batch_uploader_key remounts the widget empty.
            key=f"batch_{st.session_state.batch_uploader_key}",
            on_change=reset_batch_results
        )
    with col_batch2:
        st.write("##")  # Alignment
        if st.button("🗑️ Reset Uploader"):
            st.session_state.batch_uploader_key += 1
            st.rerun()
    if uploaded_files:
        if st.button(f"🚀 Process {len(uploaded_files)} Images"):
            with st.spinner("Batch Processing in progress..."):
                # Repeat the "files" field name once per upload (multipart list form).
                files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
                if res.status_code == 200:
                    data = res.json()
                    if data["status"] == "success":
                        # Persist results, then remount the uploader empty on rerun.
                        st.session_state.last_batch_results = data
                        st.session_state.batch_uploader_key += 1
                        st.rerun()
                    elif data["status"] == "partial_success":
                        # NOTE(review): partial results are shown but not persisted
                        # to session state — they vanish on the next rerun; confirm
                        # this is intended.
                        st.warning(data["message"])
                        st.info(f"Successfully detected {data['detections_count']} bunches locally.")
                    else:
                        st.error(f"Batch Error: {data['message']}")
                else:
                    st.error(f"Batch Failed: {res.text}")
# --- Tab 3: Similarity Search ---
with tab3:
    st.subheader("Hybrid Semantic Search")
    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
    # Form groups the inputs so the search only fires on explicit submit.
    with st.form("hybrid_search_form"):
        col_input1, col_input2 = st.columns(2)
        with col_input1:
            search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
        with col_input2:
            text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
        top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
        submit_search = st.form_submit_button("Run Semantic Search")
    if submit_search:
        if not search_file and not text_query:
            st.warning("Please provide either an image or a text query.")
        else:
            with st.spinner("Searching Vector Index..."):
                payload = {"limit": top_k}
                # If an image is uploaded, it takes precedence for visual search
                if search_file:
                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                    # Pass top_k as part of the data
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
                # Otherwise, use text query
                elif text_query:
                    payload["text_query"] = text_query
                    # Send as form-data (data=) to match FastAPI's Form(None)
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
                if res.status_code == 200:
                    results = res.json().get("results", [])
                    if not results:
                        st.warning("No similar records found.")
                    else:
                        st.success(f"Found {len(results)} matches.")
                        for item in results:
                            with st.container(border=True):
                                c1, c2 = st.columns([1, 2])
                                # Fetch the archived image for this record from the backend.
                                rec_id = item["_id"]
                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
                                with c1:
                                    if img_res.status_code == 200:
                                        # Image is returned base64-encoded in JSON.
                                        img_b64 = img_res.json().get("image_data")
                                        if img_b64:
                                            st.image(base64.b64decode(img_b64), width=250)
                                        else:
                                            st.write("No image data found.")
                                    else:
                                        st.write("Failed to load image.")
                                with c2:
                                    st.write(f"**Class:** {item['ripeness_class']}")
                                    st.write(f"**Similarity Score:** {item['score']:.4f}")
                                    st.write(f"**Timestamp:** {item['timestamp']}")
                                    st.write(f"**ID:** `{rec_id}`")
                else:
                    st.error(f"Search failed: {res.text}")