# demo_app.py — Streamlit front-end for the Palm Oil FFB ripeness backend API
import base64
import io
import json

import requests
import streamlit as st
from PIL import Image
  6. # --- 1. Global Backend Check ---
  7. API_BASE_URL = "http://localhost:8000"
  8. def check_backend():
  9. try:
  10. res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
  11. return res.status_code == 200
  12. except:
  13. return False
  14. backend_active = check_backend()
  15. if not backend_active:
  16. st.error("⚠️ Backend API is offline!")
  17. st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
  18. if st.button("🔄 Retry Connection"):
  19. st.rerun()
  20. st.stop() # Stops execution here, effectively disabling the app
# --- 2. Main Page Config (Only rendered if backend is active) ---
# NOTE(review): Streamlit requires set_page_config() to be the FIRST
# Streamlit command in a script, but the offline gate above already calls
# st.error()/st.button() before this point — confirm this does not raise
# StreamlitAPIException at runtime (moving this call to the top of the
# file would be the conventional fix).
st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
st.title("🌴 Palm Oil FFB Management System")
st.markdown("### Production-Ready AI Analysis & Archival")
  25. # --- Sidebar ---
  26. st.sidebar.header("Backend Controls")
  27. def update_confidence():
  28. new_conf = st.session_state.conf_slider
  29. try:
  30. requests.post(f"{API_BASE_URL}/set_confidence", json={"threshold": new_conf})
  31. st.toast(f"Threshold updated to {new_conf}")
  32. except:
  33. st.sidebar.error("Failed to update threshold")
  34. # We already know backend is up here
  35. response = requests.get(f"{API_BASE_URL}/get_confidence")
  36. current_conf = response.json().get("current_confidence", 0.25)
  37. st.sidebar.success(f"Connected to API")
  38. # Synchronized Slider
  39. st.sidebar.slider(
  40. "Confidence Threshold",
  41. 0.1, 1.0,
  42. value=float(current_conf),
  43. key="conf_slider",
  44. on_change=update_confidence
  45. )
# --- Tabs ---
# Three workflows: one-off analysis, bulk processing, and vector search.
tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
  48. # --- Tab 1: Single Analysis ---
  49. with tab1:
  50. st.subheader("Analyze Single Bunch")
  51. uploaded_file = st.file_uploader("Upload a bunch image...", type=["jpg", "jpeg", "png"], key="single")
  52. if uploaded_file:
  53. # State initialization
  54. if "last_detection" not in st.session_state:
  55. st.session_state.last_detection = None
  56. # 1. Action Button (Centered and Prominent)
  57. st.write("##")
  58. _, col_btn, _ = st.columns([1, 2, 1])
  59. if col_btn.button("🔍 Run Ripeness Detection", type="primary", use_container_width=True):
  60. with st.spinner("Processing Detections Locally..."):
  61. files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  62. res = requests.post(f"{API_BASE_URL}/analyze", files=files)
  63. if res.status_code == 200:
  64. st.session_state.last_detection = res.json()
  65. else:
  66. st.error(f"Detection Failed: {res.text}")
  67. # 2. Results Layout
  68. if st.session_state.last_detection:
  69. st.divider()
  70. col1, col2 = st.columns([1.5, 1])
  71. with col1:
  72. st.image(uploaded_file, caption="Analyzed Image", use_container_width=True)
  73. with col2:
  74. data = st.session_state.last_detection
  75. with st.container(border=True):
  76. st.write("### 🏷️ Detection Results")
  77. if not data['detections']:
  78. st.warning("No Fresh Fruit Bunches detected.")
  79. else:
  80. for det in data['detections']:
  81. st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
  82. # 3. Cloud Actions (Only if detections found)
  83. st.write("---")
  84. st.write("#### ✨ Cloud Archive")
  85. if st.button("🚀 Save to Atlas (Vectorize)", use_container_width=True):
  86. with st.spinner("Archiving..."):
  87. import json
  88. primary_det = data['detections'][0]
  89. payload = {"detection_data": json.dumps(primary_det)}
  90. files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
  91. res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
  92. if res_cloud.status_code == 200:
  93. res_json = res_cloud.json()
  94. if res_json["status"] == "success":
  95. st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
  96. else:
  97. st.error(f"Cloud Error: {res_json['message']}")
  98. else:
  99. st.error("Failed to connect to cloud service")
  100. # --- Tab 2: Batch Processing ---
  101. with tab2:
  102. st.subheader("Bulk Analysis")
  103. # 1. Initialize Session State
  104. if "batch_uploader_key" not in st.session_state:
  105. st.session_state.batch_uploader_key = 0
  106. if "last_batch_results" not in st.session_state:
  107. st.session_state.last_batch_results = None
  108. # 2. Display Persisted Results (if any)
  109. if st.session_state.last_batch_results:
  110. res_data = st.session_state.last_batch_results
  111. with st.container(border=True):
  112. st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
  113. st.write("Generated Record IDs:")
  114. st.code(res_data['record_ids'])
  115. if st.button("Clear Results & Start New Batch"):
  116. st.session_state.last_batch_results = None
  117. st.rerun()
  118. st.divider()
  119. # 3. Uploader UI
  120. col_batch1, col_batch2 = st.columns([4, 1])
  121. with col_batch1:
  122. uploaded_files = st.file_uploader(
  123. "Upload multiple images...",
  124. type=["jpg", "jpeg", "png"],
  125. accept_multiple_files=True,
  126. key=f"batch_{st.session_state.batch_uploader_key}"
  127. )
  128. with col_batch2:
  129. st.write("##") # Alignment
  130. if st.button("🗑️ Reset Uploader"):
  131. st.session_state.batch_uploader_key += 1
  132. st.rerun()
  133. if uploaded_files:
  134. if st.button(f"🚀 Process {len(uploaded_files)} Images"):
  135. with st.spinner("Batch Processing in progress..."):
  136. files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
  137. res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
  138. if res.status_code == 200:
  139. data = res.json()
  140. if data["status"] == "success":
  141. st.session_state.last_batch_results = data
  142. st.session_state.batch_uploader_key += 1
  143. st.rerun()
  144. elif data["status"] == "partial_success":
  145. st.warning(data["message"])
  146. st.info(f"Successfully detected {data['detections_count']} bunches locally.")
  147. else:
  148. st.error(f"Batch Error: {data['message']}")
  149. else:
  150. st.error(f"Batch Failed: {res.text}")
# --- Tab 3: Similarity Search ---
# Hybrid search form: the user supplies an image OR a text query; the
# image (when present) wins. Results are rendered with a per-record
# image fetch from the backend.
with tab3:
    st.subheader("Hybrid Semantic Search")
    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
    with st.form("hybrid_search_form"):
        col_input1, col_input2 = st.columns(2)
        with col_input1:
            search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
        with col_input2:
            text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
        top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
        submit_search = st.form_submit_button("Run Semantic Search")
    if submit_search:
        # Require at least one of the two inputs before hitting the API.
        if not search_file and not text_query:
            st.warning("Please provide either an image or a text query.")
        else:
            with st.spinner("Searching Vector Index..."):
                payload = {"limit": top_k}
                # If an image is uploaded, it takes precedence for visual search
                if search_file:
                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                    # Pass top_k as part of the data
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
                # Otherwise, use text query
                elif text_query:
                    payload["text_query"] = text_query
                    # Send as form-data (data=) to match FastAPI's Form(None)
                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
                if res.status_code == 200:
                    results = res.json().get("results", [])
                    if not results:
                        st.warning("No similar records found.")
                    else:
                        st.success(f"Found {len(results)} matches.")
                        for item in results:
                            with st.container(border=True):
                                c1, c2 = st.columns([1, 2])
                                # Fetch the image for this result
                                # NOTE(review): one GET per result (N+1 pattern) —
                                # acceptable for small top_k; confirm if limit grows.
                                rec_id = item["_id"]
                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
                                with c1:
                                    if img_res.status_code == 200:
                                        # Backend returns the image as base64 in JSON.
                                        img_b64 = img_res.json().get("image_data")
                                        if img_b64:
                                            st.image(base64.b64decode(img_b64), width=250)
                                        else:
                                            st.write("No image data found.")
                                    else:
                                        st.write("Failed to load image.")
                                with c2:
                                    st.write(f"**Class:** {item['ripeness_class']}")
                                    st.write(f"**Similarity Score:** {item['score']:.4f}")
                                    st.write(f"**Timestamp:** {item['timestamp']}")
                                    st.write(f"**ID:** `{rec_id}`")
                else:
                    st.error(f"Search failed: {res.text}")