demo_app.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
import base64
import io
import json

import numpy as np
import pandas as pd
import plotly.express as px
import requests
import streamlit as st
from PIL import Image
from ultralytics import YOLO
  10. # --- 1. Global Backend Check ---
  11. API_BASE_URL = "http://localhost:8000"
  12. def check_backend():
  13. try:
  14. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  15. return res.status_code == 200
  16. except:
  17. return False
  18. backend_active = check_backend()
  19. # Load YOLO model locally for Analytical View
  20. @st.cache_resource
  21. def load_yolo():
  22. return YOLO('best.pt')
  23. yolo_model = load_yolo()
  24. if not backend_active:
  25. st.error("⚠️ Backend API is offline!")
  26. st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
  27. if st.button("🔄 Retry Connection"):
  28. st.rerun()
  29. st.stop() # Stops execution here, effectively disabling the app
  30. # --- 2. Main Page Config (Only rendered if backend is active) ---
  31. st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
  32. st.title("🌴 Palm Oil FFB Management System")
  33. st.markdown("### Production-Ready AI Analysis & Archival")
  34. # --- Sidebar ---
  35. st.sidebar.header("Backend Controls")
  36. def update_confidence():
  37. new_conf = st.session_state.conf_slider
  38. try:
  39. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  40. st.toast(f"Threshold updated to {new_conf}")
  41. except:
  42. st.sidebar.error("Failed to update threshold")
  43. # We already know backend is up here
  44. response = requests.get(f"{API_BASE_URL}/get_confidence")
  45. current_conf = response.json().get("current_confidence", 0.25)
  46. st.sidebar.success(f"Connected to API")
  47. # Synchronized Slider
  48. st.sidebar.slider(
  49. "Confidence Threshold",
  50. 0.1, 1.0,
  51. value=float(current_conf),
  52. key="conf_slider",
  53. on_change=update_confidence
  54. )
  55. # Helper to reset results when files change
  56. def reset_single_results():
  57. st.session_state.last_detection = None
  58. def reset_batch_results():
  59. st.session_state.last_batch_results = None
  60. # --- Tabs ---
  61. tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
  62. # --- Tab 1: Single Analysis ---
  63. with tab1:
  64. st.subheader("Analyze Single Bunch")
  65. uploaded_file = st.file_uploader(
  66. "Upload a bunch image...",
  67. type=["jpg", "jpeg", "png"],
  68. key="single",
  69. on_change=reset_single_results
  70. )
  71. if uploaded_file:
  72. # State initialization
  73. if "last_detection" not in st.session_state:
  74. st.session_state.last_detection = None
  75. # 1. Auto-Detection Trigger
  76. if uploaded_file and st.session_state.last_detection is None:
  77. with st.spinner("Processing Detections Locally..."):
  78. files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  79. res = requests.post(f"{API_BASE_URL}/analyze", files=files)
  80. if res.status_code == 200:
  81. st.session_state.last_detection = res.json()
  82. st.rerun() # Refresh to show results immediately
  83. else:
  84. st.error(f"Detection Failed: {res.text}")
  85. # 2. Results Layout
  86. if st.session_state.last_detection:
  87. st.divider()
  88. # SIDE-BY-SIDE ANALYTICAL VIEW
  89. col_left, col_right = st.columns(2)
  90. # Fetch data once
  91. data = st.session_state.last_detection
  92. with col_left:
  93. st.image(uploaded_file, caption="Original Photo", width='stretch')
  94. with col_right:
  95. # Use the local model to plot the boxes directly
  96. img = Image.open(uploaded_file)
  97. results = yolo_model(img, conf=current_conf, agnostic_nms=True, iou=0.4)
  98. annotated_img = results[0].plot() # Draws boxes/labels
  99. # Convert BGR (OpenCV format) to RGB for Streamlit
  100. annotated_img_rgb = annotated_img[:, :, ::-1]
  101. st.image(annotated_img_rgb, caption="AI Analytical View (X-Ray)", width='stretch')
  102. st.write("### 📈 Manager's Dashboard")
  103. m_col1, m_col2, m_col3 = st.columns(3)
  104. with m_col1:
  105. st.metric("Total Bunches", data.get('total_count', 0))
  106. with m_col2:
  107. st.metric("Healthy (Ripe)", data['industrial_summary'].get('Ripe', 0))
  108. with m_col3:
  109. abnormal = data['industrial_summary'].get('Abnormal', 0)
  110. st.metric("Abnormal Alerts", abnormal, delta=-abnormal, delta_color="inverse")
  111. col1, col2 = st.columns([1.5, 1]) # Keep original col structure for summary below
  112. with col2:
  113. with st.container(border=True):
  114. st.write("### 🏷️ Detection Results")
  115. if not data['detections']:
  116. st.warning("No Fresh Fruit Bunches detected.")
  117. else:
  118. for det in data['detections']:
  119. st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
  120. st.write("### 📊 Harvest Quality Mix")
  121. # Convert industrial_summary dictionary to a DataFrame for charting
  122. summary_df = pd.DataFrame(
  123. list(data['industrial_summary'].items()),
  124. columns=['Grade', 'Count']
  125. )
  126. # Filter out classes with 0 count for a cleaner chart
  127. summary_df = summary_df[summary_df['Count'] > 0]
  128. if not summary_df.empty:
  129. # Create a Pie Chart to show the proportion of each grade
  130. fig = px.pie(summary_df, values='Count', names='Grade',
  131. color='Grade',
  132. color_discrete_map={
  133. 'Abnormal': '#ef4444', # Red
  134. 'Empty_Bunch': '#94a3b8', # Gray
  135. 'Ripe': '#22c55e', # Green
  136. 'Underripe': '#eab308', # Yellow
  137. 'Unripe': '#3b82f6', # Blue
  138. 'Overripe': '#a855f7' # Purple
  139. },
  140. hole=0.4)
  141. fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
  142. st.plotly_chart(fig, width='stretch')
  143. # 💡 Strategic R&D Insight: Harvest Efficiency
  144. st.write("---")
  145. st.write("#### 💡 Strategic R&D Insight")
  146. unripe_count = data['industrial_summary'].get('Unripe', 0)
  147. underripe_count = data['industrial_summary'].get('Underripe', 0)
  148. total_non_prime = unripe_count + underripe_count
  149. st.write(f"🌑 **Unripe (Mentah):** {unripe_count}")
  150. st.write(f"🌗 **Underripe (Kurang Masak):** {underripe_count}")
  151. if total_non_prime > 0:
  152. st.warning(f"🚨 **Potential Yield Loss:** {total_non_prime} bunches harvested too early. This will reduce OER (Oil Extraction Rate).")
  153. else:
  154. st.success("✅ **Harvest Efficiency:** 100% Prime Ripeness detected.")
  155. # High-Priority Health Alert
  156. if data['industrial_summary'].get('Abnormal', 0) > 0:
  157. st.error(f"🚨 CRITICAL: {data['industrial_summary']['Abnormal']} Abnormal Bunches Detected!")
  158. if data['industrial_summary'].get('Empty_Bunch', 0) > 0:
  159. st.warning(f"⚠️ ALERT: {data['industrial_summary']['Empty_Bunch']} Empty Bunches Detected.")
  160. # 3. Cloud Actions (Only if detections found)
  161. st.write("---")
  162. st.write("#### ✨ Cloud Archive")
  163. if st.button("🚀 Save to Atlas (Vectorize)", width='stretch'):
  164. with st.spinner("Archiving..."):
  165. import json
  166. primary_det = data['detections'][0]
  167. payload = {"detection_data": json.dumps(primary_det)}
  168. files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  169. res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
  170. if res_cloud.status_code == 200:
  171. res_json = res_cloud.json()
  172. if res_json["status"] == "success":
  173. st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
  174. else:
  175. st.error(f"Cloud Error: {res_json['message']}")
  176. else:
  177. st.error("Failed to connect to cloud service")
  178. # --- Tab 2: Batch Processing ---
  179. with tab2:
  180. st.subheader("Bulk Analysis")
  181. # 1. Initialize Session State
  182. if "batch_uploader_key" not in st.session_state:
  183. st.session_state.batch_uploader_key = 0
  184. if "last_batch_results" not in st.session_state:
  185. st.session_state.last_batch_results = None
  186. # 2. Display Persisted Results (if any)
  187. if st.session_state.last_batch_results:
  188. res_data = st.session_state.last_batch_results
  189. with st.container(border=True):
  190. st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
  191. # Batch Summary Dashboard
  192. st.write("### 📈 Batch Quality Overview")
  193. batch_summary = res_data.get('industrial_summary', {})
  194. if batch_summary:
  195. sum_df = pd.DataFrame(list(batch_summary.items()), columns=['Grade', 'Count'])
  196. sum_df = sum_df[sum_df['Count'] > 0]
  197. b_col1, b_col2 = st.columns([1, 1])
  198. with b_col1:
  199. st.dataframe(sum_df, hide_index=True, width='stretch')
  200. with b_col2:
  201. if not sum_df.empty:
  202. fig_batch = px.bar(sum_df, x='Grade', y='Count', color='Grade',
  203. color_discrete_map={
  204. 'Abnormal': '#ef4444',
  205. 'Empty_Bunch': '#94a3b8',
  206. 'Ripe': '#22c55e'
  207. })
  208. fig_batch.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=200, showlegend=False)
  209. st.plotly_chart(fig_batch, width='stretch')
  210. if batch_summary.get('Abnormal', 0) > 0:
  211. st.error(f"🚨 BATCH CRITICAL: {batch_summary['Abnormal']} Abnormal Bunches found in this batch!")
  212. st.write("Generated Record IDs:")
  213. st.code(res_data['record_ids'])
  214. if st.button("Clear Results & Start New Batch"):
  215. st.session_state.last_batch_results = None
  216. st.rerun()
  217. st.divider()
  218. # 3. Uploader UI
  219. col_batch1, col_batch2 = st.columns([4, 1])
  220. with col_batch1:
  221. uploaded_files = st.file_uploader(
  222. "Upload multiple images...",
  223. type=["jpg", "jpeg", "png"],
  224. accept_multiple_files=True,
  225. key=f"batch_{st.session_state.batch_uploader_key}",
  226. on_change=reset_batch_results
  227. )
  228. with col_batch2:
  229. st.write("##") # Alignment
  230. if st.session_state.last_batch_results is None and uploaded_files:
  231. if st.button("🔍 Process Batch", type="primary", width='stretch'):
  232. with st.spinner(f"Analyzing {len(uploaded_files)} images..."):
  233. files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
  234. res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
  235. if res.status_code == 200:
  236. data = res.json()
  237. if data["status"] == "success":
  238. st.session_state.last_batch_results = data
  239. st.session_state.batch_uploader_key += 1
  240. st.rerun()
  241. elif data["status"] == "partial_success":
  242. st.warning(data["message"])
  243. st.info(f"Successfully detected {data['detections_count']} bunches locally.")
  244. else:
  245. st.error(f"Batch Error: {data['message']}")
  246. else:
  247. st.error(f"Batch Processing Failed: {res.text}")
  248. if st.button("🗑️ Reset Uploader"):
  249. st.session_state.batch_uploader_key += 1
  250. st.session_state.last_batch_results = None
  251. st.rerun()
  252. # --- Tab 3: Similarity Search ---
  253. with tab3:
  254. st.subheader("Hybrid Semantic Search")
  255. st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
  256. with st.form("hybrid_search_form"):
  257. col_input1, col_input2 = st.columns(2)
  258. with col_input1:
  259. search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
  260. with col_input2:
  261. text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
  262. top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
  263. submit_search = st.form_submit_button("Run Semantic Search")
  264. if submit_search:
  265. if not search_file and not text_query:
  266. st.warning("Please provide either an image or a text query.")
  267. else:
  268. with st.spinner("Searching Vector Index..."):
  269. payload = {"limit": top_k}
  270. # If an image is uploaded, it takes precedence for visual search
  271. if search_file:
  272. files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
  273. # Pass top_k as part of the data
  274. res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
  275. # Otherwise, use text query
  276. elif text_query:
  277. payload["text_query"] = text_query
  278. # Send as form-data (data=) to match FastAPI's Form(None)
  279. res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
  280. if res.status_code == 200:
  281. results = res.json().get("results", [])
  282. if not results:
  283. st.warning("No similar records found.")
  284. else:
  285. st.success(f"Found {len(results)} matches.")
  286. for item in results:
  287. with st.container(border=True):
  288. c1, c2 = st.columns([1, 2])
  289. # Fetch the image for this result
  290. rec_id = item["_id"]
  291. img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
  292. with c1:
  293. if img_res.status_code == 200:
  294. img_b64 = img_res.json().get("image_data")
  295. if img_b64:
  296. st.image(base64.b64decode(img_b64), width=250)
  297. else:
  298. st.write("No image data found.")
  299. else:
  300. st.write("Failed to load image.")
  301. with c2:
  302. st.write(f"**Class:** {item['ripeness_class']}")
  303. st.write(f"**Similarity Score:** {item['score']:.4f}")
  304. st.write(f"**Timestamp:** {item['timestamp']}")
  305. st.write(f"**ID:** `{rec_id}`")
  306. else:
  307. st.error(f"Search failed: {res.text}")