Browse Source

separate out vectorization

Dr-Swopt 15 hours ago
parent
commit
a5840add94
4 changed files with 162 additions and 95 deletions
  1. 4 2
      README.md
  2. 64 21
      demo_app.py
  3. 72 66
      src/api/main.py
  4. 22 6
      src/infrastructure/vision_service.py

+ 4 - 2
README.md

@@ -89,8 +89,10 @@ streamlit run demo_app.py
 
 | Endpoint | Method | Description |
 | :--- | :--- | :--- |
-| `/detect` | `POST` | **Fast Inference**: Returns YOLO detection results (JSON). |
-| `/analyze` | `POST` | **Full Process**: Detection + Vertex AI Vectorization + MongoDB Archival. |
+| `/analyze` | `POST` | **Local Detection**: Returns YOLO results only. Guaranteed to work without Cloud Billing. |
+| `/vectorize_and_store` | `POST` | **Cloud Archival**: Vectorizes a detection and saves to MongoDB Atlas. Requires GCP Billing. |
+| `/process_batch` | `POST` | **Bulk Processor**: Handles multiple images. Detects locally; archives to cloud if available. |
+| `/search_hybrid` | `POST` | **Semantic Search**: Visual similarity or natural language search via Vertex AI. |
 | `/get_confidence` | `GET` | Retrieve the current AI confidence threshold. |
 | `/set_confidence` | `POST` | Update the AI confidence threshold globally. |
 

+ 64 - 21
demo_app.py

@@ -62,24 +62,61 @@ with tab1:
     uploaded_file = st.file_uploader("Upload a bunch image...", type=["jpg", "jpeg", "png"], key="single")
     
     if uploaded_file:
-        col1, col2 = st.columns(2)
-        with col1:
-            st.image(uploaded_file, caption="Input", width=500)
-        
-        with col2:
-            if st.button("Run Full Analysis"):
-                with st.spinner("Processing... (Detecting + Vectorizing + Archiving)"):
-                    files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
-                    res = requests.post(f"{API_BASE_URL}/analyze", files=files)
-                    
-                    if res.status_code == 200:
-                        data = res.json()
-                        st.success(f"✅ Record Archived! ID: {data['record_id']}")
-                        
+        # State initialization
+        if "last_detection" not in st.session_state:
+            st.session_state.last_detection = None
+
+        # 1. Action Button (Centered and Prominent)
+        st.write("##")
+        _, col_btn, _ = st.columns([1, 2, 1])
+        if col_btn.button("🔍 Run Ripeness Detection", type="primary", use_container_width=True):
+            with st.spinner("Processing Detections Locally..."):
+                files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
+                res = requests.post(f"{API_BASE_URL}/analyze", files=files)
+                
+                if res.status_code == 200:
+                    st.session_state.last_detection = res.json()
+                else:
+                    st.error(f"Detection Failed: {res.text}")
+
+        # 2. Results Layout
+        if st.session_state.last_detection:
+            st.divider()
+            col1, col2 = st.columns([1.5, 1])
+            
+            with col1:
+                st.image(uploaded_file, caption="Analyzed Image", use_container_width=True)
+            
+            with col2:
+                data = st.session_state.last_detection
+                with st.container(border=True):
+                    st.write("### 🏷️ Detection Results")
+                    if not data['detections']:
+                        st.warning("No Fresh Fruit Bunches detected.")
+                    else:
                         for det in data['detections']:
                             st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
-                    else:
-                        st.error(f"Analysis Failed: {res.text}")
+                        
+                        # 3. Cloud Actions (Only if detections found)
+                        st.write("---")
+                        st.write("#### ✨ Cloud Archive")
+                        if st.button("🚀 Save to Atlas (Vectorize)", use_container_width=True):
+                            with st.spinner("Archiving..."):
+                                import json
+                                primary_det = data['detections'][0]
+                                payload = {"detection_data": json.dumps(primary_det)}
+                                files_cloud = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
+                                
+                                res_cloud = requests.post(f"{API_BASE_URL}/vectorize_and_store", files=files_cloud, data=payload)
+                                
+                                if res_cloud.status_code == 200:
+                                    res_json = res_cloud.json()
+                                    if res_json["status"] == "success":
+                                        st.success(f"Archived! ID: `{res_json['record_id'][:8]}...`")
+                                    else:
+                                        st.error(f"Cloud Error: {res_json['message']}")
+                                else:
+                                    st.error("Failed to connect to cloud service")
 
 # --- Tab 2: Batch Processing ---
 with tab2:
@@ -123,13 +160,19 @@ with tab2:
         if st.button(f"🚀 Process {len(uploaded_files)} Images"):
             with st.spinner("Batch Processing in progress..."):
                 files = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
-                res = requests.post(f"{API_BASE_URL}/analyze_batch", files=files)
+                res = requests.post(f"{API_BASE_URL}/process_batch", files=files)
                 
                 if res.status_code == 200:
-                    # 4. Success: Store results and Clear Uploader automatically
-                    st.session_state.last_batch_results = res.json()
-                    st.session_state.batch_uploader_key += 1
-                    st.rerun()
+                    data = res.json()
+                    if data["status"] == "success":
+                        st.session_state.last_batch_results = data
+                        st.session_state.batch_uploader_key += 1
+                        st.rerun()
+                    elif data["status"] == "partial_success":
+                        st.warning(data["message"])
+                        st.info(f"Successfully detected {data['detections_count']} bunches locally.")
+                    else:
+                        st.error(f"Batch Error: {data['message']}")
                 else:
                     st.error(f"Batch Failed: {res.text}")
 

+ 72 - 66
src/api/main.py

@@ -60,12 +60,12 @@ async def set_confidence(threshold: float = Body(..., embed=True)):
     else:
         return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
 
-@app.post("/detect")
-async def detect_ripeness(file: UploadFile = File(...)):
-    """Simple YOLO detection only. No archival or vectorization."""
+
+@app.post("/analyze")
+async def analyze_only(file: UploadFile = File(...)):
+    """Local YOLO detection only. Guaranteed to work without Cloud Billing."""
     image_bytes = await file.read()
     img = Image.open(io.BytesIO(image_bytes))
-    
     results = yolo_model(img, conf=current_conf)
     
     detections = []
@@ -78,56 +78,46 @@ async def detect_ripeness(file: UploadFile = File(...)):
             })
             
     return {
-        "status": "success",
+        "status": "success", 
         "current_threshold": current_conf,
-        "data": detections
+        "detections": detections
     }
 
-@app.post("/analyze")
-async def analyze_ripeness(file: UploadFile = File(...)):
-    """Full analysis: Detection + Vertex Vectorization + MongoDB Archival."""
-    # 1. Save file temporarily for YOLO and Vertex
+@app.post("/vectorize_and_store")
+async def vectorize_and_store(file: UploadFile = File(...), detection_data: str = Form(...)):
+    """Cloud-dependent. Requires active billing."""
+    import json
+    try:
+        primary_detection = json.loads(detection_data)
+    except Exception:
+        return {"status": "error", "message": "Invalid detection_data format"}
+
     unique_id = uuid.uuid4().hex
-    temp_path = f"temp_{unique_id}_{file.filename}"
+    temp_path = f"temp_vec_{unique_id}_{file.filename}"
+    
+    # Fresh request: the uploaded file has not been read yet, so copy it straight to disk.
     with open(temp_path, "wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
 
     try:
-        # 2. Run YOLO detection
-        img = Image.open(temp_path)
-        results = yolo_model(img, conf=current_conf)
-        
-        detections = []
-        for r in results:
-            for box in r.boxes:
-                detections.append({
-                    "class": yolo_model.names[int(box.cls)],
-                    "confidence": round(float(box.conf), 2),
-                    "box": box.xyxy.tolist()[0]
-                })
-
-        # 3. If detections found, analyze the first one (primary) for deeper insights
-        if detections:
-            primary_detection = detections[0]
-            record_id = analyze_use_case.execute(temp_path, primary_detection)
-            
-            return {
-                "status": "success",
-                "record_id": record_id,
-                "detections": detections,
-                "message": "FFB analyzed, vectorized, and archived successfully"
-            }
-        
-        return {"status": "no_detection", "message": "No palm oil FFB detected"}
-
+        record_id = analyze_use_case.execute(temp_path, primary_detection)
+        return {
+            "status": "success",
+            "record_id": record_id,
+            "message": "FFB vectorized and archived successfully"
+        }
+    except RuntimeError as e:
+        return {"status": "error", "message": str(e)}
+    except Exception as e:
+        return {"status": "error", "message": f"An unexpected error occurred: {str(e)}"}
     finally:
-        # Clean up temp file
         if os.path.exists(temp_path):
             os.remove(temp_path)
 
-@app.post("/analyze_batch")
-async def analyze_batch(files: List[UploadFile] = File(...)):
-    """Handles multiple images: Detect -> Vectorize -> Store."""
+@app.post("/process_batch")
+async def process_batch(files: List[UploadFile] = File(...)):
+    """Handles multiple images: Detect -> Vectorize -> Store. Graceful handling of cloud errors."""
     batch_results = []
     temp_files = []
 
@@ -135,9 +125,9 @@ async def analyze_batch(files: List[UploadFile] = File(...)):
         for file in files:
             # 1. Save Temp
             unique_id = uuid.uuid4().hex
-            path = f"temp_{unique_id}_{file.filename}"
-            with open(path, "wb") as f:
-                shutil.copyfileobj(file.file, f)
+            path = f"temp_batch_{unique_id}_{file.filename}"
+            with open(path, "wb") as f_out:
+                shutil.copyfileobj(file.file, f_out)
             temp_files.append(path)
 
             # 2. YOLO Detect
@@ -156,14 +146,27 @@ async def analyze_batch(files: List[UploadFile] = File(...)):
                     }
                 })
 
-        # 4. Process Batch Use Case
-        record_ids = analyze_batch_use_case.execute(batch_results)
-        return {
-            "status": "success", 
-            "processed_count": len(record_ids), 
-            "record_ids": record_ids,
-            "message": f"Successfully processed {len(record_ids)} bunches"
-        }
+        if not batch_results:
+            return {"status": "no_detection", "message": "No bunches detected in batch"}
+
+        # 4. Process Batch Use Case with error handling for cloud services
+        try:
+            record_ids = analyze_batch_use_case.execute(batch_results)
+            return {
+                "status": "success", 
+                "processed_count": len(record_ids), 
+                "record_ids": record_ids,
+                "message": f"Successfully processed {len(record_ids)} bunches"
+            }
+        except RuntimeError as e:
+            return {
+                "status": "partial_success",
+                "message": f"Detections completed, but cloud archival failed: {str(e)}",
+                "detections_count": len(batch_results)
+            }
+
+    except Exception as e:
+        return {"status": "error", "message": f"Batch processing failed: {str(e)}"}
 
     finally:
         # 5. Clean up all temp files
@@ -179,19 +182,22 @@ async def search_hybrid(
     """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
     temp_path = None
     try:
-        if file:
-            unique_id = uuid.uuid4().hex
-            temp_path = f"temp_search_{unique_id}_{file.filename}"
-            with open(temp_path, "wb") as buffer:
-                shutil.copyfileobj(file.file, buffer)
-            
-            results = search_use_case.execute(image_path=temp_path, limit=limit)
-        elif text_query:
-            results = search_use_case.execute(text_query=text_query, limit=limit)
-        else:
-            return {"status": "error", "message": "No search input provided"}
-
-        return {"status": "success", "results": results}
+        try:
+            if file:
+                unique_id = uuid.uuid4().hex
+                temp_path = f"temp_search_{unique_id}_{file.filename}"
+                with open(temp_path, "wb") as buffer:
+                    shutil.copyfileobj(file.file, buffer)
+                
+                results = search_use_case.execute(image_path=temp_path, limit=limit)
+            elif text_query:
+                results = search_use_case.execute(text_query=text_query, limit=limit)
+            else:
+                return {"status": "error", "message": "No search input provided"}
+
+            return {"status": "success", "results": results}
+        except RuntimeError as e:
+            return {"status": "error", "message": f"Search unavailable: {str(e)}"}
         
     finally:
         if temp_path and os.path.exists(temp_path):

+ 22 - 6
src/infrastructure/vision_service.py

@@ -8,22 +8,38 @@ from typing import List
 
 class VertexVisionService:
     def __init__(self, project_id: str, location: str):
-        # Ensure credentials are set before init if using service account key
-        # (This is now handled globally in main.py)
-        vertexai.init(project=project_id, location=location)
-        self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
+        self.project_id = project_id
+        self.location = location
+        self._model = None
+
+    def _get_model(self):
+        """Lazy load the model and catch billing/connection errors."""
+        if self._model is None:
+            try:
+                # Ensure credentials are set before init if using service account key
+                # (This is now handled globally in main.py)
+                vertexai.init(project=self.project_id, location=self.location)
+                self._model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
+            except Exception as e:
+                # Log the specific error (e.g., Billing Disabled)
+                print(f"Vertex AI Initialization Failed: {e}")
+                raise RuntimeError("Cloud services (Vectorization/Search) are currently unavailable.")
+        return self._model
 
     def get_image_embedding(self, image_path: str) -> List[float]:
         if not os.path.exists(image_path):
             raise FileNotFoundError(f"Image not found at {image_path}")
+        
+        model = self._get_model() # This will raise the RuntimeError if billing is down
         image = Image.load_from_file(image_path)
         # Standardizing to 1408 dimensions for consistency
-        embeddings = self.model.get_embeddings(image=image, dimension=1408)
+        embeddings = model.get_embeddings(image=image, dimension=1408)
         return embeddings.image_embedding
 
     def get_text_embedding(self, text: str) -> List[float]:
         """Converts text query into a 1408-d vector."""
-        embeddings = self.model.get_embeddings(
+        model = self._get_model()
+        embeddings = model.get_embeddings(
             contextual_text=text, 
             dimension=1408
         )