Przeglądaj źródła

update to include multiple endpoints

Dr-Swopt 3 dni temu
rodzic
commit
6702db105a

+ 106 - 46
demo_app.py

@@ -1,60 +1,120 @@
 import streamlit as st
 import streamlit as st
-from ultralytics import YOLO
+import requests
 from PIL import Image
 from PIL import Image
-import numpy as np
 import io
 import io
+import base64
+
+# --- API Configuration ---
+API_BASE_URL = "http://localhost:8000"
 
 
 # --- Page Config ---
 # --- Page Config ---
 st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
 st.set_page_config(page_title="Palm Oil Ripeness AI", layout="wide")
-st.title("🌴 Palm Oil FFB Ripeness Detector")
-st.markdown("### R&D Proof of Concept: Automated Maturity Grading")
+st.title("🌴 Palm Oil FFB Management System")
+st.markdown("### Production-Ready AI Analysis & Archival")
 
 
-# --- Load Model (Cached for performance) ---
-@st.cache_resource
-def load_model():
-    return YOLO("best.pt")
+# --- Sidebar ---
+st.sidebar.header("Backend Controls")
 
 
-model = load_model()
def update_confidence():
    """Push the slider's new confidence threshold to the backend API.

    Runs as the ``on_change`` callback of the sidebar slider: reads the new
    value from session state and POSTs it to ``/set_confidence``.
    """
    new_conf = st.session_state.conf_slider
    try:
        resp = requests.post(
            f"{API_BASE_URL}/set_confidence",
            json={"threshold": new_conf},
            timeout=5,
        )
        resp.raise_for_status()  # surface HTTP errors instead of toasting success
        st.toast(f"Threshold updated to {new_conf}")
    except requests.RequestException:  # narrow: only network/HTTP failures
        st.sidebar.error("Failed to update threshold")
 
 
-# --- Sidebar Controls ---
-st.sidebar.header("Settings")
-conf_threshold = st.sidebar.slider("Confidence Threshold", 0.1, 1.0, 0.5)
# --- Sidebar: sync the confidence slider with the backend's current value ---
try:
    response = requests.get(f"{API_BASE_URL}/get_confidence", timeout=5)
    if response.status_code == 200:
        current_conf = response.json().get("current_confidence", 0.25)
        st.sidebar.success("Connected to API")

        # Slider is initialised from the backend value so UI and API stay in
        # sync; update_confidence pushes changes back on interaction.
        st.sidebar.slider(
            "Confidence Threshold",
            0.1, 1.0,
            value=float(current_conf),
            key="conf_slider",
            on_change=update_confidence,
        )
    else:
        st.sidebar.error("API Error")
except requests.RequestException:  # narrow except: connection/timeout issues only
    st.sidebar.error("Could not connect to Backend API. Please ensure it is running.")
 
 
-# --- Image Upload ---
-uploaded_file = st.file_uploader("Drag and drop a Palm Oil FFB image here...", type=["jpg", "jpeg", "png"])
+# --- Tabs ---
+tab1, tab2, tab3 = st.tabs(["Single Analysis", "Batch Processing", "Similarity Search"])
 
 
-if uploaded_file is not None:
-    # Convert uploaded file to PIL Image
-    image = Image.open(uploaded_file)
-    
-    # Layout: Original vs Predicted
-    col1, col2 = st.columns(2)
# --- Tab 1: Single Analysis ---
with tab1:
    st.subheader("Analyze Single Bunch")
    uploaded_file = st.file_uploader("Upload a bunch image...", type=["jpg", "jpeg", "png"], key="single")

    if uploaded_file:
        col1, col2 = st.columns(2)
        with col1:
            st.image(uploaded_file, caption="Input", use_container_width=True)

        with col2:
            if st.button("Run Full Analysis"):
                with st.spinner("Processing... (Detecting + Vectorizing + Archiving)"):
                    files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                    try:
                        # Generous timeout: detection + Vertex embedding is slow.
                        res = requests.post(f"{API_BASE_URL}/analyze", files=files, timeout=120)
                    except requests.RequestException as exc:
                        # Without this, a dead backend crashes the whole script run.
                        st.error(f"Request failed: {exc}")
                    else:
                        if res.status_code == 200:
                            data = res.json()
                            st.success(f"✅ Record Archived! ID: {data['record_id']}")

                            # 'detections' may be absent/empty if nothing was found.
                            for det in data.get('detections', []):
                                st.info(f"**{det['class']}** - {det['confidence']:.2%} confidence")
                        else:
                            st.error(f"Analysis Failed: {res.text}")
+
# --- Tab 2: Batch Processing ---
with tab2:
    st.subheader("Bulk Analysis")
    uploaded_files = st.file_uploader("Upload multiple images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True, key="batch")

    if uploaded_files:
        if st.button(f"Process {len(uploaded_files)} Images"):
            with st.spinner("Batch Processing in progress..."):
                payload = [("files", (f.name, f.getvalue(), f.type)) for f in uploaded_files]
                try:
                    # Long timeout: every image is detected + vectorized server-side.
                    res = requests.post(f"{API_BASE_URL}/analyze_batch", files=payload, timeout=300)
                except requests.RequestException as exc:
                    # Without this, a dead backend crashes the whole script run.
                    st.error(f"Request failed: {exc}")
                else:
                    if res.status_code == 200:
                        data = res.json()
                        st.success(f"Successfully processed {data['processed_count']} images.")
                        st.write("Generated Record IDs:")
                        st.code(data['record_ids'])
                    else:
                        st.error("Batch Failed")
+
# --- Tab 3: Similarity Search ---
with tab3:
    st.subheader("Atlas Vector Search")
    st.markdown("Upload an image to find the most similar historical records in the database.")
    search_file = st.file_uploader("Search Image...", type=["jpg", "jpeg", "png"], key="search")

    if search_file:
        st.image(search_file, width=300)
        if st.button("Find Similar Bunches"):
            with st.spinner("Searching Vector Index..."):
                files = {"file": (search_file.name, search_file.getvalue(), search_file.type)}
                try:
                    res = requests.post(f"{API_BASE_URL}/search", files=files, timeout=60)
                except requests.RequestException as exc:
                    # Without this, a dead backend crashes the whole script run.
                    st.error(f"Request failed: {exc}")
                else:
                    if res.status_code == 200:
                        results = res.json().get("results", [])
                        if not results:
                            st.warning("No similar records found.")
                        else:
                            for item in results:
                                with st.container(border=True):
                                    # Note: a production app would fetch image_data
                                    # by id for the left column; the demo shows only
                                    # the textual metadata.
                                    _, meta_col = st.columns([1, 2])
                                    with meta_col:
                                        st.write(f"**Class:** {item['ripeness_class']}")
                                        st.write(f"**Similarity Score:** {item['score']:.4f}")
                                        st.write(f"**Timestamp:** {item['timestamp']}")
                    else:
                        st.error("Search failed")

+ 1 - 0
main.py

@@ -5,4 +5,5 @@ if __name__ == "__main__":
     # This file serves as a root-level wrapper for the DDD transition.
     # This file serves as a root-level wrapper for the DDD transition.
     # It redirects execution to the new API entry point in src/api/main.py.
     # It redirects execution to the new API entry point in src/api/main.py.
     print("Redirecting to DDD Architecture Entry Point (src.api.main)...")
     print("Redirecting to DDD Architecture Entry Point (src.api.main)...")
+    print("Starting server... http://localhost:8000/docs")
     uvicorn.run(app, host="0.0.0.0", port=8000)
     uvicorn.run(app, host="0.0.0.0", port=8000)

+ 77 - 3
src/api/main.py

@@ -1,18 +1,23 @@
+from typing import List, Optional
+import uuid
 import os
 import os
+import shutil
 from fastapi import FastAPI, File, UploadFile, Body
 from fastapi import FastAPI, File, UploadFile, Body
 from ultralytics import YOLO
 from ultralytics import YOLO
 from dotenv import load_dotenv
 from dotenv import load_dotenv
 import io
 import io
-import shutil
 from PIL import Image
 from PIL import Image
 
 
 from src.infrastructure.vision_service import VertexVisionService
 from src.infrastructure.vision_service import VertexVisionService
 from src.infrastructure.repository import MongoPalmOilRepository
 from src.infrastructure.repository import MongoPalmOilRepository
-from src.application.analyze_bunch import AnalyzeBunchUseCase
+from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
 
 
 # Load environment variables
 # Load environment variables
 load_dotenv()
 load_dotenv()
 
 
+# Set Google Cloud credentials globally
+os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"
+
 app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 
 
 # Initialize YOLO model
 # Initialize YOLO model
@@ -32,8 +37,11 @@ repo = MongoPalmOilRepository(
     collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
     collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
 )
 )
 analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
 analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
+analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
+search_use_case = SearchSimilarUseCase(vision_service, repo)
 
 
 @app.get("/get_confidence")
 @app.get("/get_confidence")
+# ... (rest of the code remains same until analyze)
 async def get_confidence():
 async def get_confidence():
     """Returns the current confidence threshold used by the model."""
     """Returns the current confidence threshold used by the model."""
     return {
     return {
@@ -79,7 +87,8 @@ async def detect_ripeness(file: UploadFile = File(...)):
 async def analyze_ripeness(file: UploadFile = File(...)):
 async def analyze_ripeness(file: UploadFile = File(...)):
     """Full analysis: Detection + Vertex Vectorization + MongoDB Archival."""
     """Full analysis: Detection + Vertex Vectorization + MongoDB Archival."""
     # 1. Save file temporarily for YOLO and Vertex
     # 1. Save file temporarily for YOLO and Vertex
-    temp_path = f"temp_{file.filename}"
+    unique_id = uuid.uuid4().hex
+    temp_path = f"temp_{unique_id}_{file.filename}"
     with open(temp_path, "wb") as buffer:
     with open(temp_path, "wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
         shutil.copyfileobj(file.file, buffer)
 
 
@@ -116,6 +125,71 @@ async def analyze_ripeness(file: UploadFile = File(...)):
         if os.path.exists(temp_path):
         if os.path.exists(temp_path):
             os.remove(temp_path)
             os.remove(temp_path)
 
 
@app.post("/analyze_batch")
async def analyze_batch(files: List[UploadFile] = File(...)):
    """Handles multiple images: Detect -> Vectorize -> Store.

    Images in which YOLO finds no bunch are skipped; the response reports
    both processed and skipped counts so callers can reconcile the totals.
    """
    batch_results = []
    temp_files = []

    try:
        for file in files:
            # 1. Save to a uniquely named temp file so concurrent requests
            #    with identical filenames cannot clobber each other.
            unique_id = uuid.uuid4().hex
            path = f"temp_{unique_id}_{file.filename}"
            with open(path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            temp_files.append(path)

            # 2. YOLO detect (context manager releases the image file handle).
            with Image.open(path) as img:
                yolo_res = yolo_model(img, conf=current_conf)

            # 3. Keep only the primary (first) detection per image.
            if yolo_res and yolo_res[0].boxes:
                box = yolo_res[0].boxes[0]
                batch_results.append({
                    "path": path,
                    "yolo": {
                        "class": yolo_model.names[int(box.cls)],
                        "confidence": float(box.conf),
                        "box": box.xyxy.tolist()[0]
                    }
                })

        # 4. Vectorize + archive everything that had a detection.
        record_ids = analyze_batch_use_case.execute(batch_results)
        return {
            "status": "success",
            "processed_count": len(record_ids),
            "skipped_count": len(files) - len(record_ids),  # images with no detection
            "record_ids": record_ids,
            "message": f"Successfully processed {len(record_ids)} bunches"
        }

    finally:
        # 5. Clean up all temp files even if detection or archival raised.
        for path in temp_files:
            if os.path.exists(path):
                os.remove(path)
+
@app.post("/search")
async def search_similar(file: UploadFile = File(...), limit: int = 3):
    """Atlas Vector Search: Find similar palm oil bunches by image."""
    # Unique temp name prevents collisions between concurrent uploads.
    temp_path = f"temp_search_{uuid.uuid4().hex}_{file.filename}"

    with open(temp_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # Delegate vectorization + Atlas lookup to the application layer.
        similar = search_use_case.execute(temp_path, limit=limit)
        return {
            "status": "success",
            "results": similar
        }
    finally:
        # Always remove the temp file, even if the search raised.
        if os.path.exists(temp_path):
            os.remove(temp_path)
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     import uvicorn
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
     uvicorn.run(app, host="0.0.0.0", port=8000)

+ 49 - 3
src/application/analyze_bunch.py

@@ -11,14 +11,60 @@ class AnalyzeBunchUseCase:
         # 1. Get the visual fingerprint
         # 1. Get the visual fingerprint
         vector = self.vision_service.get_image_embedding(image_path)
         vector = self.vision_service.get_image_embedding(image_path)
 
 
-        # 2. Create the Domain Entity
+        # 2. Resize and Convert Image to Base64 String via Vision Service
+        encoded_string = self.vision_service.encode_image_to_base64(image_path)
+
+        # 3. Create the Domain Entity with Image Data
         bunch = PalmOilBunch(
         bunch = PalmOilBunch(
             ripeness_class=yolo_result['class'],
             ripeness_class=yolo_result['class'],
             confidence=yolo_result['confidence'],
             confidence=yolo_result['confidence'],
             embedding=vector,
             embedding=vector,
-            box=yolo_result['box']
+            box=yolo_result['box'],
+            image_data=encoded_string
         )
         )
 
 
-        # 3. Persist to "Memory"
+        # 4. Persist to MongoDB
         record_id = self.repo.save(bunch)
         record_id = self.repo.save(bunch)
         return record_id
         return record_id
+
class AnalyzeBatchUseCase:
    """Vectorize, encode and bulk-archive a batch of detected bunches."""

    def __init__(self, vision_service: VertexVisionService, repo: MongoPalmOilRepository):
        self.vision_service = vision_service
        self.repo = repo

    def execute(self, image_results: list):
        """
        image_results: List of dicts {'path': str, 'yolo': dict}
        """
        bunches = []
        for entry in image_results:
            detection = entry['yolo']
            # Visual fingerprint + compact archival copy of the image.
            embedding = self.vision_service.get_image_embedding(entry['path'])
            encoded = self.vision_service.encode_image_to_base64(entry['path'])

            bunches.append(PalmOilBunch(
                ripeness_class=detection['class'],
                confidence=detection['confidence'],
                embedding=embedding,
                box=detection['box'],
                image_data=encoded,
            ))

        # Single bulk insert keeps round-trips to MongoDB down.
        return self.repo.save_many(bunches)
+
class SearchSimilarUseCase:
    """Find archived bunches visually similar to a query image."""

    def __init__(self, vision_service: VertexVisionService, repo: MongoPalmOilRepository):
        self.vision_service = vision_service
        self.repo = repo

    def execute(self, image_path: str, limit: int = 3):
        # Embed the query image, then let the repository run the Atlas search.
        query_vector = self.vision_service.get_image_embedding(image_path)
        return self.repo.vector_search(query_vector, limit)

+ 1 - 0
src/domain/models.py

@@ -8,5 +8,6 @@ class PalmOilBunch:
     confidence: float
     confidence: float
     embedding: List[float]
     embedding: List[float]
     box: List[float]
     box: List[float]
+    image_data: str
     timestamp: datetime = field(default_factory=datetime.now)
     timestamp: datetime = field(default_factory=datetime.now)
     id: Optional[str] = None
     id: Optional[str] = None

+ 36 - 0
src/infrastructure/repository.py

@@ -15,3 +15,39 @@ class MongoPalmOilRepository:
         
         
         result = self.collection.insert_one(doc)
         result = self.collection.insert_one(doc)
         return str(result.inserted_id)
         return str(result.inserted_id)
+
+    def save_many(self, bunches: list):
+        """Bulk insert palm oil records."""
+        docs = []
+        for bunch in bunches:
+            doc = bunch.__dict__.copy()
+            if doc.get('id') is None:
+                doc.pop('id')
+            docs.append(doc)
+        
+        if docs:
+            result = self.collection.insert_many(docs)
+            return [str(i) for i in result.inserted_ids]
+        return []
+
+    def vector_search(self, query_vector: list, limit: int = 3):
+        """Atlas Vector Search using the 1408-D index."""
+        pipeline = [
+            {
+                "$vectorSearch": {
+                    "index": "vector_index", 
+                    "path": "embedding",
+                    "queryVector": query_vector,
+                    "numCandidates": limit * 10,
+                    "limit": limit
+                }
+            },
+            {
+                "$project": {
+                    "embedding": 0, 
+                    "image_data": 0, # Exclude for speed; fetch by ID if needed
+                    "score": {"$meta": "vectorSearchScore"}
+                }
+            }
+        ]
+        return list(self.collection.aggregate(pipeline))

+ 14 - 1
src/infrastructure/vision_service.py

@@ -1,17 +1,30 @@
 import os
 import os
 import vertexai
 import vertexai
 from vertexai.vision_models import Image, MultiModalEmbeddingModel
 from vertexai.vision_models import Image, MultiModalEmbeddingModel
+import base64
+import io
+from PIL import Image as PILImage
 from typing import List
 from typing import List
 
 
 class VertexVisionService:
 class VertexVisionService:
     def __init__(self, project_id: str, location: str):
     def __init__(self, project_id: str, location: str):
         # Ensure credentials are set before init if using service account key
         # Ensure credentials are set before init if using service account key
-        # os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-service-key.json" 
+        # (This is now handled globally in main.py)
         vertexai.init(project=project_id, location=location)
         vertexai.init(project=project_id, location=location)
         self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
         self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
 
 
     def get_image_embedding(self, image_path: str) -> List[float]:
     def get_image_embedding(self, image_path: str) -> List[float]:
+        if not os.path.exists(image_path):
+            raise FileNotFoundError(f"Image not found at {image_path}")
         image = Image.load_from_file(image_path)
         image = Image.load_from_file(image_path)
         # Standardizing to 1408 dimensions for consistency
         # Standardizing to 1408 dimensions for consistency
         embeddings = self.model.get_embeddings(image=image, dimension=1408)
         embeddings = self.model.get_embeddings(image=image, dimension=1408)
         return embeddings.image_embedding
         return embeddings.image_embedding
+
+    def encode_image_to_base64(self, image_path: str) -> str:
+        """Resizes image to 640x640 and encodes to Base64."""
+        with PILImage.open(image_path) as img:
+            img.thumbnail((640, 640))
+            buffered = io.BytesIO()
+            img.save(buffered, format="JPEG", quality=85)
+            return base64.b64encode(buffered.getvalue()).decode('utf-8')