Parcourir la source

semantic search fixes and enhancements

Dr-Swopt il y a 3 jours
Parent
commit
54f96e9219

+ 119 - 39
demo_app.py

@@ -4,10 +4,26 @@ from PIL import Image
 import io
 import base64
 
-# --- API Configuration ---
+# --- 1. Global Backend Check ---
 API_BASE_URL = "http://localhost:8000"
 
-# --- Page Config ---
+def check_backend():
+    try:
+        res = requests.get(f"{API_BASE_URL}/get_confidence", timeout=2)
+        return res.status_code == 200
+    except:
+        return False
+
+backend_active = check_backend()
+
+if not backend_active:
+    st.error("⚠️ Backend API is offline!")
+    st.info("Please start the backend server first (e.g., `python main.py`) to unlock AI features.")
+    if st.button("🔄 Retry Connection"):
+        st.rerun()
+    st.stop() # Halts script execution here; none of the UI below renders until the backend is reachable
+
+# --- 2. Main Page Config (Only rendered if backend is active) ---
 st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
 st.title("🌴 Palm Oil FFB Management System")
 st.markdown("### Production-Ready AI Analysis & Archival")
@@ -23,24 +39,19 @@ def update_confidence():
     except:
         st.sidebar.error("Failed to update threshold")
 
-try:
-    response = requests.get(f"{API_BASE_URL}/get_confidence")
-    if response.status_code == 200:
-        current_conf = response.json().get("current_confidence", 0.25)
-        st.sidebar.success(f"Connected to API")
-        
-        # Synchronized Slider
-        st.sidebar.slider(
-            "Confidence Threshold", 
-            0.1, 1.0, 
-            value=float(current_conf),
-            key="conf_slider",
-            on_change=update_confidence
-        )
-    else:
-        st.sidebar.error("API Error")
-except:
-    st.sidebar.error("Could not connect to Backend API. Please ensure it is running.")
+# Backend was reachable during the startup check above; NOTE(review): this request is unguarded and will raise if the backend has gone down since that check
+response = requests.get(f"{API_BASE_URL}/get_confidence")
+current_conf = response.json().get("current_confidence", 0.25)
+st.sidebar.success(f"Connected to API")
+
+# Synchronized Slider
+st.sidebar.slider(
+    "Confidence Threshold", 
+    0.1, 1.0, 
+    value=float(current_conf),
+    key="conf_slider",
+    on_change=update_confidence
+)
 
 # --- Tabs ---
 tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
@@ -53,7 +64,7 @@ with tab1:
     if uploaded_file:
         col1, col2 = st.columns(2)
         with col1:
-            st.image(uploaded_file, caption="Input", use_container_width=True)
+            st.image(uploaded_file, caption="Input", width=500)
         
         with col2:
             if st.button("Run Full Analysis"):
@@ -73,48 +84,117 @@ with tab1:
 # --- Tab 2: Batch Processing ---
 with tab2:
     st.subheader("Bulk Analysis")
-    uploaded_files = st.file_uploader("Upload multiple images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True, key="batch")
+    
+    # 1. Initialize Session State
+    if "batch_uploader_key" not in st.session_state:
+        st.session_state.batch_uploader_key = 0
+    if "last_batch_results" not in st.session_state:
+        st.session_state.last_batch_results = None
+
+    # 2. Display Persisted Results (if any)
+    if st.session_state.last_batch_results:
+        res_data = st.session_state.last_batch_results
+        with st.container(border=True):
+            st.success(f"✅ Successfully processed {res_data['processed_count']} images.")
+            st.write("Generated Record IDs:")
+            st.code(res_data['record_ids'])
+            if st.button("Clear Results & Start New Batch"):
+                st.session_state.last_batch_results = None
+                st.rerun()
+        st.divider()
+
+    # 3. Uploader UI
+    col_batch1, col_batch2 = st.columns([4, 1])
+    with col_batch1:
+        uploaded_files = st.file_uploader(
+            "Upload multiple images...", 
+            type=["jpg", "jpeg", "png"], 
+            accept_multiple_files=True, 
+            key=f"batch_{st.session_state.batch_uploader_key}"
+        )
+    
+    with col_batch2:
+        st.write("##") # Spacer to vertically align the reset button with the uploader widget
+        if st.button("🗑️ Reset Uploader"):
+            st.session_state.batch_uploader_key += 1
+            st.rerun()
     
     if uploaded_files:
-        if st.button(f"Process {len(uploaded_files)} Images"):
+        if st.button(f"🚀 Process {len(uploaded_files)} Images"):
             with st.spinner("Batch Processing in progress..."):
                 files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                 res = requests.post(f"{API_BASE_URL}/analyze_batch", files=files)
                 
                 if res.status_code == 200:
-                    data = res.json()
-                    st.success(f"Successfully processed {data['processed_count']} images.")
-                    st.write("Generated Record IDs:")
-                    st.code(data['record_ids'])
+                    # 4. On success: persist results in session state and bump the uploader key to clear the widget
+                    st.session_state.last_batch_results = res.json()
+                    st.session_state.batch_uploader_key += 1
+                    st.rerun()
                 else:
-                    st.error("Batch Failed")
+                    st.error(f"Batch Failed: {res.text}")
 
 # --- Tab 3: Similarity Search ---
 with tab3:
-    st.subheader("Atlas Vector Search")
-    st.markdown("Upload an image to find the most similar historical records in the database.")
-    search_file = st.file_uploader("Search Image...", type=["jpg", "jpeg", "png"], key="search")
+    st.subheader("Hybrid Semantic Search")
+    st.markdown("Search records by either **Image Similarity** or **Natural Language Query**.")
     
-    if search_file:
-        st.image(search_file, width=300)
-        if st.button("Find Similar Bunches"):
+    with st.form("hybrid_search_form"):
+        col_input1, col_input2 = st.columns(2)
+        
+        with col_input1:
+            search_file = st.file_uploader("Option A: Search Image...", type=["jpg", "jpeg", "png"], key="search")
+        
+        with col_input2:
+            text_query = st.text_input("Option B: Natural Language Query", placeholder="e.g., 'ripe bunches with dark spots' or 'unripe fruit'")
+            top_k = st.slider("Results Limit (Top K)", 1, 20, 3)
+
+        submit_search = st.form_submit_button("Run Semantic Search")
+
+    if submit_search:
+        if not search_file and not text_query:
+            st.warning("Please provide either an image or a text query.")
+        else:
             with st.spinner("Searching Vector Index..."):
-                files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
-                res = requests.post(f"{API_BASE_URL}/search", files=files)
+                payload = {"limit": top_k}
+                
+                # If an image is uploaded, it takes precedence for visual search
+                if search_file:
+                    files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
+                    # Pass the result limit (chosen via the top_k slider) as form fields alongside the file
+                    res = requests.post(f"{API_BASE_URL}/search_hybrid", files=files, data=payload)
+                # Otherwise, use text query
+                elif text_query:
+                    payload["text_query"] = text_query
+                    # Send as form-data (data=) to match FastAPI's Form(None)
+                    res = requests.post(f"{API_BASE_URL}/search_hybrid", data=payload)
                 
                 if res.status_code == 200:
                     results = res.json().get("results", [])
                     if not results:
                         st.warning("No similar records found.")
                     else:
+                        st.success(f"Found {len(results)} matches.")
                         for item in results:
                             with st.container(border=True):
                                 c1, c2 = st.columns([1, 2])
-                                # Note: Actual prod app would fetch image_data by id here
-                                # For demo, we show the textual metadata
+                                # Fetch the image for this result
+                                rec_id = item["_id"]
+                                img_res = requests.get(f"{API_BASE_URL}/get_image/{rec_id}")
+                                
+                                with c1:
+                                    if img_res.status_code == 200:
+                                        img_b64 = img_res.json().get("image_data")
+                                        if img_b64:
+                                            st.image(base64.b64decode(img_b64), width=250)
+                                        else:
+                                            st.write("No image data found.")
+                                    else:
+                                        st.write("Failed to load image.")
+
                                 with c2:
                                     st.write(f"**Class:** {item['ripeness_class']}")
                                     st.write(f"**Similarity Score:** {item['score']:.4f}")
                                     st.write(f"**Timestamp:** {item['timestamp']}")
+                                    st.write(f"**ID:** `{rec_id}`")
                 else:
-                    st.error("Search failed")
+                    st.error(f"Search failed: {res.text}")

+ 35 - 17
src/api/main.py

@@ -2,7 +2,7 @@ from typing import List, Optional
 import uuid
 import os
 import shutil
-from fastapi import FastAPI, File, UploadFile, Body
+from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
 from ultralytics import YOLO
 from dotenv import load_dotenv
 import io
@@ -170,26 +170,44 @@ async def analyze_batch(files: List[UploadFile] = File(...)):
         for path in temp_files:
             if os.path.exists(path):
                 os.remove(path)
-
-@app.post("/search")
-async def search_similar(file: UploadFile = File(...), limit: int = 3):
-    """Atlas Vector Search: Find similar palm oil bunches by image."""
-    unique_id = uuid.uuid4().hex
-    temp_path = f"temp_search_{unique_id}_{file.filename}"
-    
-    with open(temp_path, "wb") as buffer:
-        shutil.copyfileobj(file.file, buffer)
-
+@app.post("/search_hybrid")
+async def search_hybrid(
+    file: Optional[UploadFile] = File(None), 
+    text_query: Optional[str] = Form(None), 
+    limit: int = Form(3)
+):
+    """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
+    temp_path = None
     try:
-        results = search_use_case.execute(temp_path, limit=limit)
-        return {
-            "status": "success",
-            "results": results
-        }
+        if file:
+            unique_id = uuid.uuid4().hex
+            temp_path = f"temp_search_{unique_id}_{file.filename}"
+            with open(temp_path, "wb") as buffer:
+                shutil.copyfileobj(file.file, buffer)
+            
+            results = search_use_case.execute(image_path=temp_path, limit=limit)
+        elif text_query:
+            results = search_use_case.execute(text_query=text_query, limit=limit)
+        else:
+            return {"status": "error", "message": "No search input provided"}
+
+        return {"status": "success", "results": results}
+        
     finally:
-        if os.path.exists(temp_path):
+        if temp_path and os.path.exists(temp_path):
             os.remove(temp_path)
 
+@app.get("/get_image/{record_id}")
+async def get_image(record_id: str):
+    """Retrieve the Base64 image data for a specific record."""
+    record = repo.get_by_id(record_id)
+    if not record:
+        return {"status": "error", "message": "Record not found"}
+    return {
+        "status": "success",
+        "image_data": record.get("image_data")
+    }
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)

+ 11 - 4
src/application/analyze_bunch.py

@@ -62,9 +62,16 @@ class SearchSimilarUseCase:
         self.vision_service = vision_service
         self.repo = repo
 
-    def execute(self, image_path: str, limit: int = 3):
-        # 1. Vectorize the query image
-        query_vector = self.vision_service.get_image_embedding(image_path)
+    def execute(self, image_path: str = None, text_query: str = None, limit: int = 3):
+        """Supports both visual similarity and natural language search."""
+        query_vector = None
         
-        # 2. Perform vector search in repository
+        if image_path:
+            query_vector = self.vision_service.get_image_embedding(image_path)
+        elif text_query:
+            query_vector = self.vision_service.get_text_embedding(text_query)
+        
+        if not query_vector:
+            raise ValueError("Must provide either an image or a text query.")
+
         return self.repo.vector_search(query_vector, limit)

+ 10 - 2
src/infrastructure/repository.py

@@ -1,3 +1,4 @@
+from bson import ObjectId
 from pymongo import MongoClient
 from src.domain.models import PalmOilBunch
 
@@ -6,6 +7,10 @@ class MongoPalmOilRepository:
         self.client = MongoClient(uri)
         self.collection = self.client[db_name][collection_name]
 
+    def get_by_id(self, record_id: str):
+        """Retrieve a specific record by its ID."""
+        return self.collection.find_one({"_id": ObjectId(record_id)})
+
     def save(self, bunch: PalmOilBunch):
         # Convert dataclass to dict for MongoDB
         doc = bunch.__dict__.copy()
@@ -45,9 +50,12 @@ class MongoPalmOilRepository:
             {
                 "$project": {
                     "embedding": 0, 
-                    "image_data": 0, # Exclude for speed; fetch by ID if needed
+                    "image_data": 0,
                     "score": {"$meta": "vectorSearchScore"}
                 }
             }
         ]
-        return list(self.collection.aggregate(pipeline))
+        results = list(self.collection.aggregate(pipeline))
+        for res in results:
+            res["_id"] = str(res["_id"])
+        return results

+ 8 - 0
src/infrastructure/vision_service.py

@@ -21,6 +21,14 @@ class VertexVisionService:
         embeddings = self.model.get_embeddings(image=image, dimension=1408)
         return embeddings.image_embedding
 
+    def get_text_embedding(self, text: str) -> List[float]:
+        """Converts text query into a 1408-d vector."""
+        embeddings = self.model.get_embeddings(
+            contextual_text=text, 
+            dimension=1408
+        )
+        return embeddings.text_embedding
+
     def encode_image_to_base64(self, image_path: str) -> str:
         """Resizes image to 640x640 and encodes to Base64."""
         with PILImage.open(image_path) as img: