demo_app.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
import base64
import io
import json

import numpy as np
import pandas as pd
import plotly.express as px
import requests
import streamlit as st
from PIL import Image
from ultralytics import YOLO
  10. # --- 1. Global Backend Check ---
  11. API_BASE_URL = "http://localhost:8000"
  12. def check_backend():
  13. try:
  14. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  15. return res.status_code == 200
  16. except:
  17. return False
  18. backend_active = check_backend()
  19. # Load YOLO model locally for Analytical View
  20. @st.cache_resource
  21. def load_yolo():
  22. return YOLO('best.pt')
  23. yolo_model = load_yolo()
  24. if not backend_active:
  25. st.error("⚠️ Backend API is offline!")
  26. st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
  27. if st.button("🔄 Retry Connection"):
  28. st.rerun()
  29. st.stop() # Stops execution here, effectively disabling the app
  30. # --- 2. Main Page Config (Only rendered if backend is active) ---
  31. st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
  32. st.title("🌴 Palm Oil FFB Management System")
  33. st.markdown("### Production-Ready AI Analysis & Archival")
  34. # --- Sidebar ---
  35. st.sidebar.header("Backend Controls")
  36. def update_confidence():
  37. new_conf = st.session_state.conf_slider
  38. try:
  39. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  40. st.toast(f"Threshold updated to {new_conf}")
  41. except:
  42. st.sidebar.error("Failed to update threshold")
  43. # We already know backend is up here
  44. response = requests.get(f"{API_BASE_URL}/get_confidence")
  45. current_conf = response.json().get("current_confidence", 0.25)
  46. st.sidebar.success(f"Connected to API")
  47. # Synchronized Slider
  48. st.sidebar.slider(
  49. "Confidence Threshold",
  50. 0.1, 1.0,
  51. value=float(current_conf),
  52. key="conf_slider",
  53. on_change=update_confidence
  54. )
  55. # --- Tabs ---
  56. tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
  57. # --- Tab 1: Single Analysis ---
  58. with tab1:
  59. st.subheader("Analyze Single Bunch")
  60. uploaded_file = st.file_uploader("Upload a bunch image...", type=["jpg", "jpeg", "png"], key="single")
  61. if uploaded_file:
  62. # State initialization
  63. if "last_detection" not in st.session_state:
  64. st.session_state.last_detection = None
  65. # 1. Action Button (Centered and Prominent)
  66. st.write("##")
  67. _, col_btn, _ = st.columns([1, 2, 1])
  68. if col_btn.button("🔍 Run Ripeness Detection", type="primary", width='stretch'):
  69. with st.spinner("Processing Detections Locally..."):
  70. files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  71. res = requests.post(f"{API_BASE_URL}/analyze", files=files)
  72. if res.status_code == 200:
  73. st.session_state.last_detection = res.json()
  74. else:
  75. st.error(f"Detection Failed: {res.text}")
  76. # 2. Results Layout
  77. if st.session_state.last_detection:
  78. st.divider()
  79. # SIDE-BY-SIDE ANALYTICAL VIEW
  80. col_left, col_right = st.columns(2)
  81. with col_left:
  82. st.image(uploaded_file, caption="Original Photo", width='stretch')
  83. with col_right:
  84. # Use the local model to plot the boxes directly
  85. img = Image.open(uploaded_file)
  86. results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
  87. annotated_img = results[0].plot() # Draws boxes/labels
  88. # Convert BGR (OpenCV format) to RGB for Streamlit
  89. annotated_img_rgb = annotated_img[:, :, ::-1]
  90. st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')
  91. col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
  92. with col2:
  93. data = st.session_state.last_detection
  94. with st.container(border=True):
  95. st.write("### 🏷️ Detection Results")
  96. if not data['detections']:
  97. st.warning("No Fresh Fruit Bunches detected.")
  98. else:
  99. for det in data['detections']:
  100. st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
  101. st.write("### 📊 Harvest Quality Mix")
  102. # Convert industrial_summary dictionary to a DataFrame for charting
  103. summary_df = pd.DataFrame(
  104. list(data['industrial_summary'].items()),
  105. columns=['Grade', 'Count']
  106. )
  107. # Filter out classes with 0 count for a cleaner chart
  108. summary_df = summary_df[summary_df['Count'] > 0]
  109. if not summary_df.empty:
  110. # Create a Pie Chart to show the proportion of each grade
  111. fig = px.pie(summary_df, values='Count', names='Grade',
  112. color='Grade',
  113. color_discrete_map={
  114. 'Abnormal': '#ef4444', # Red
  115. 'Empty_Bunch': '#94a3b8', # Gray
  116. 'Ripe': '#22c55e', # Green
  117. 'Underripe': '#eab308', # Yellow
  118. 'Unripe': '#3b82f6', # Blue
  119. 'Overripe': '#a855f7' # Purple
  120. },
  121. hole=0.4)
  122. fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
  123. st.plotly_chart(fig, width='stretch')
  124. # High-Priority Health Alert
  125. if data['industrial_summary'].get('Abnormal', 0) > 0:
  126. st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
  127. if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
  128. st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
  129. # 3. Cloud Actions (Only if detections found)
  130. st.write("---")
  131. st.write("#### ✨ Cloud Archive")
  132. if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
  133. with st.spinner("Archiving..."):
  134. import json
  135. primary_det = data['detections'][0]
  136. payload = {"detection_data": json.dumps(primary_det)}
  137. files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  138. res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
  139. if res_cloud.status_code == 200:
  140. res_json = res_cloud.json()
  141. if res_json["status"] == "success":
  142. st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
  143. else:
  144. st.error(f"Cloud Error: {res_json['message']}")
  145. else:
  146. st.error("Failed to connect to cloud service")
  147. # --- Tab 2: Batch Processing ---
  148. with tab2:
  149. st.subheader("Bulk Analysis")
  150. # 1. Initialize Session State
  151. if "batch_uploader_key" not in st.session_state:
  152. st.session_state.batch_uploader_key = 0
  153. if "last_batch_results" not in st.session_state:
  154. st.session_state.last_batch_results = None
  155. # 2. Display Persisted Results (if any)
  156. if st.session_state.last_batch_results:
  157. res_data = st.session_state.last_batch_results
  158. with st.container(border=True):
  159. st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
  160. # Batch Summary Dashboard
  161. st.write("### 📈 Batch Quality Overview")
  162. batch_summary = res_data.get('industrial_summary', {})
  163. if batch_summary:
  164. sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
  165. sum_df = sum_df[sum_df['Count'] > 0]
  166. b_col1, b_col2 = st.columns([1, 1])
  167. with b_col1:
  168. st.dataframe(sum_df, hide_index=True, width='stretch')
  169. with b_col2:
  170. if not sum_df.empty:
  171. fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
  172. color_discrete_map={
  173. 'Abnormal': '#ef4444',
  174. 'Empty_Bunch': '#94a3b8',
  175. 'Ripe': '#22c55e'
  176. })
  177. fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
  178. st.plotly_chart(fig_batch, width='stretch')
  179. if batch_summary.get('Abnormal', 0) > 0:
  180. st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
  181. st.write("Generated Record IDs:")
  182. st.code(res_data['record_ids'])
  183. if st.button("Clear Results & Start New Batch"):
  184. st.session_state.last_batch_results = None
  185. st.rerun()
  186. st.divider()
  187. # 3. Uploader UI
  188. col_batch1, col_batch2 = st.columns([4, 1])
  189. with col_batch1:
  190. uploaded_files = st.file_uploader(
  191. "Upload multiple images...",
  192. type=["jpg", "jpeg", "png"],
  193. accept_multiple_files=True,
  194. key=f"batch_{st.session_state.batch_uploader_key}"
  195. )
  196. with col_batch2:
  197. st.write("##") # Alignment
  198. if st.button("🗑️ Reset Uploader"):
  199. st.session_state.batch_uploader_key += 1
  200. st.rerun()
  201. if uploaded_files:
  202. if st.button(f"🚀 Process {len(uploaded_files)} Images"):
  203. with st.spinner("Batch Processing in progress..."):
  204. files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
  205. res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
  206. if res.status_code == 200:
  207. data = res.json()
  208. if data["status"] == "success":
  209. st.session_state.last_batch_results = data
  210. st.session_state.batch_uploader_key += 1
  211. st.rerun()
  212. elif data["status"] == "partial_success":
  213. st.warning(data["message"])
  214. st.info(f"Successfully detected {data['detections_count']} bunches locally.")
  215. else:
  216. st.error(f"Batch Error: {data['message']}")
  217. else:
  218. st.error(f"Batch Failed: {res.text}")
  219. # --- Tab 3: Similarity Search ---
  220. with tab3:
  221. st.subheader("Hybrid Semantic Search")
  222. st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
  223. with st.form("hybrid_search_form"):
  224. col_input1, col_input2 = st.columns(2)
  225. with col_input1:
  226. search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
  227. with col_input2:
  228. text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
  229. top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
  230. submit_search = st.form_submit_button("Run Semantic Search")
  231. if submit_search:
  232. if not search_file and not text_query:
  233. st.warning("Please provide either an image or a text query.")
  234. else:
  235. with st.spinner("Searching Vector Index..."):
  236. payload = {"limit": top_k}
  237. # If an image is uploaded, it takes precedence for visual search
  238. if search_file:
  239. files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
  240. # Pass top_k as part of the data
  241. res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
  242. # Otherwise, use text query
  243. elif text_query:
  244. payload["text_query"] = text_query
  245. # Send as form-data (data=) to match FastAPI's Form(None)
  246. res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
  247. if res.status_code == 200:
  248. results = res.json().get("results", [])
  249. if not results:
  250. st.warning("No similar records found.")
  251. else:
  252. st.success(f"Found {len(results)} matches.")
  253. for item in results:
  254. with st.container(border=True):
  255. c1, c2 = st.columns([1, 2])
  256. # Fetch the image for this result
  257. rec_id = item["_id"]
  258. img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
  259. with c1:
  260. if img_res.status_code == 200:
  261. img_b64 = img_res.json().get("image_data")
  262. if img_b64:
  263. st.image(base64.b64decode(img_b64), width=250)
  264. else:
  265. st.write("No image data found.")
  266. else:
  267. st.write("Failed to load image.")
  268. with c2:
  269. st.write(f"**Class:** {item['ripeness_class']}")
  270. st.write(f"**Similarity Score:** {item['score']:.4f}")
  271. st.write(f"**Timestamp:** {item['timestamp']}")
  272. st.write(f"**ID:** `{rec_id}`")
  273. else:
  274. st.error(f"Search failed: {res.text}")