소스 검색

Refactor: Implement Domain-Driven Design (DDD) with Vertex AI and MongoDB integration

Dr-Swopt 3 일 전
부모
커밋
930fbc5999

+ 9 - 0
.env

@@ -0,0 +1,9 @@
+GOOGLE_API_KEY=AIzaSyDCnA7edktfcz77N-NobqVmV03XJvBRb1A
+MONGO_URI=mongodb+srv://devai_db_user:UBPOIDS6L6Baqe6o@cluster0.znkvqif.mongodb.net/
+PROJECT_ID=gen-lang-client-0904742181
+LOCATION=us-central1
+DB_NAME=VectorDB
+COLLECTION_NAME=Palm Oil
+GOOGLE_APPLICATION_CREDENTIALS=gemini-embedding-service-key.json
+GEMINI_EMBEDDING_MODEL=multimodalembedding@001
+GOOGLE_APPLICATION_CREDENTIALS=gemini-embedding-service-key.json

+ 11 - 17
README.md

@@ -46,29 +46,23 @@ python train_script.py
 
 4. Copy the resulting `best.pt` from `runs/detect/train/weights/` to the project root.
 
-### 3. Running the Demo (Streamlit)
+### Running the API Server (DDD Structure)
 
-To show the interactive dashboard to colleagues:
-
-```bash
-streamlit run demo_app.py
+The new architecture decouples the vision logic from the API entry point.
 
+```powershell
+# Run the FastAPI server from the src directory
+python -m src.api.main
 ```
+By default, the server runs on `http://localhost:8000`.
 
-* **Local URL:** `http://localhost:8501`
-
-### 4. Running the API for n8n
-
-To connect your AI to n8n workflows:
+### Running the Streamlit Dashboard
 
-```bash
-python main.py
-
-```
-
-* **Endpoint:** `POST http://localhost:8000/detect`
-* **Payload:** Form-data with key `file`.
+The Streamlit app still provides the user interface for manual testing.
 
+```powershell
+# Run the Streamlit app
+streamlit run demo_app.py
 ---
 
 ## 📂 Repository Structure

+ 0 - 0
src/__init__.py


+ 0 - 0
src/api/__init__.py


+ 76 - 0
src/api/main.py

@@ -0,0 +1,76 @@
+import os
+from fastapi import FastAPI, File, UploadFile
+from ultralytics import YOLO
+from dotenv import load_dotenv
+import io
+import shutil
+from PIL import Image
+
+from src.infrastructure.vision_service import VertexVisionService
+from src.infrastructure.repository import MongoPalmOilRepository
+from src.application.analyze_bunch import AnalyzeBunchUseCase
+
+# Load environment variables
+load_dotenv()
+
+app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
+
+# Initialize YOLO model
+yolo_model = YOLO('best.pt')
+
+# Initialize DDD Components
+vision_service = VertexVisionService(
+    project_id=os.getenv("PROJECT_ID", "your-project-id"),
+    location=os.getenv("LOCATION", "us-central1")
+)
+repo = MongoPalmOilRepository(
+    uri=os.getenv("MONGO_URI"),
+    db_name=os.getenv("DB_NAME", "palm_oil_db")
+)
+analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
+
+@app.post("/detect")
+async def detect_ripeness(file: UploadFile = File(...)):
+    # 1. Save file temporarily for YOLO and Vertex
+    temp_path = f"temp_{file.filename}"
+    with open(temp_path, "wb") as buffer:
+        shutil.copyfileobj(file.file, buffer)
+
+    try:
+        # 2. Run YOLO detection
+        img = Image.open(temp_path)
+        results = yolo_model(img)
+        
+        detections = []
+        for r in results:
+            for box in r.boxes:
+                detections.append({
+                    "class": yolo_model.names[int(box.cls)],
+                    "confidence": round(float(box.conf), 2),
+                    "box": box.xyxy.tolist()[0]
+                })
+
+        # 3. If detections found, analyze the first one (or all) for deeper insights
+        results_summary = []
+        if detections:
+            # For this MVP, we analyze the primary detection
+            primary_detection = detections[0]
+            record_id = analyze_use_case.execute(temp_path, primary_detection)
+            
+            return {
+                "status": "success",
+                "record_id": record_id,
+                "detections": detections,
+                "message": "FFB processed and archived successfully"
+            }
+        
+        return {"status": "no_detection", "message": "No palm oil FFB detected"}
+
+    finally:
+        # Clean up temp file
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)

+ 0 - 0
src/application/__init__.py


+ 24 - 0
src/application/analyze_bunch.py

@@ -0,0 +1,24 @@
+from src.infrastructure.vision_service import VertexVisionService
+from src.infrastructure.repository import MongoPalmOilRepository
+from src.domain.models import PalmOilBunch
+
+class AnalyzeBunchUseCase:
+    def __init__(self, vision_service: VertexVisionService, repo: MongoPalmOilRepository):
+        self.vision_service = vision_service
+        self.repo = repo
+
+    def execute(self, image_path: str, yolo_result: dict):
+        # 1. Get the visual fingerprint
+        vector = self.vision_service.get_image_embedding(image_path)
+
+        # 2. Create the Domain Entity
+        bunch = PalmOilBunch(
+            ripeness_class=yolo_result['class'],
+            confidence=yolo_result['confidence'],
+            embedding=vector,
+            box=yolo_result['box']
+        )
+
+        # 3. Persist to "Memory"
+        record_id = self.repo.save(bunch)
+        return record_id

+ 0 - 0
src/domain/__init__.py


+ 12 - 0
src/domain/models.py

@@ -0,0 +1,12 @@
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import List, Optional
+
+@dataclass
+class PalmOilBunch:
+    ripeness_class: str
+    confidence: float
+    embedding: List[float]
+    box: List[float]
+    timestamp: datetime = field(default_factory=datetime.now)
+    id: Optional[str] = None

+ 0 - 0
src/infrastructure/__init__.py


+ 17 - 0
src/infrastructure/repository.py

@@ -0,0 +1,17 @@
+from pymongo import MongoClient
+from src.domain.models import PalmOilBunch
+
+class MongoPalmOilRepository:
+    def __init__(self, uri: str, db_name: str):
+        self.client = MongoClient(uri)
+        self.collection = self.client[db_name]["ffb_records"]
+
+    def save(self, bunch: PalmOilBunch):
+        # Convert dataclass to dict for MongoDB
+        doc = bunch.__dict__.copy()
+        # Remove id if it's None to let Mongo generate it
+        if doc.get('id') is None:
+            doc.pop('id')
+        
+        result = self.collection.insert_one(doc)
+        return str(result.inserted_id)

+ 17 - 0
src/infrastructure/vision_service.py

@@ -0,0 +1,17 @@
+import os
+import vertexai
+from vertexai.vision_models import Image, MultiModalEmbeddingModel
+from typing import List
+
+class VertexVisionService:
+    def __init__(self, project_id: str, location: str):
+        # Ensure credentials are set before init if using service account key
+        # os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-service-key.json" 
+        vertexai.init(project=project_id, location=location)
+        self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
+
+    def get_image_embedding(self, image_path: str) -> List[float]:
+        image = Image.load_from_file(image_path)
+        # Standardizing to 1408 dimensions for consistency
+        embeddings = self.model.get_embeddings(image=image, dimension=1408)
+        return embeddings.image_embedding