| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- from typing import List, Optional
- import uuid
- import os
- import shutil
- from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
- from ultralytics import YOLO
- from dotenv import load_dotenv
- import io
- from PIL import Image
- from src.infrastructure.vision_service import VertexVisionService
- from src.infrastructure.repository import MongoPalmOilRepository
- from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
# Load environment variables from a local .env file into os.environ.
load_dotenv()

# Set Google Cloud credentials globally.
# NOTE(review): hard-coding the key path overwrites any externally supplied
# GOOGLE_APPLICATION_CREDENTIALS — confirm this is intentional for deployment.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"

app = FastAPI(title="Palm Oil Ripeness Service (DDD)")

# Initialize YOLO model (weights are loaded once, at import time).
yolo_model = YOLO('best.pt')

# Global state for the confidence threshold (read by /detect and /analyze,
# mutated by /set_confidence).
current_conf = 0.25

# Initialize DDD Components: the Vertex vision service (embeddings) and the
# Mongo repository (archival), wired into the three application use cases.
vision_service = VertexVisionService(
    project_id=os.getenv("PROJECT_ID", "your-project-id"),
    location=os.getenv("LOCATION", "us-central1")
)
repo = MongoPalmOilRepository(
    uri=os.getenv("MONGO_URI"),
    db_name=os.getenv("DB_NAME", "palm_oil_db"),
    collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
)
analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
search_use_case = SearchSimilarUseCase(vision_service, repo)
@app.get("/get_confidence")
async def get_confidence():
    """Report the confidence threshold currently applied to YOLO inference."""
    payload = {
        "status": "success",
        "current_confidence": current_conf,
        "model_version": "best.pt",
    }
    return payload
@app.post("/set_confidence")
async def set_confidence(threshold: float = Body(..., embed=True)):
    """Updates the confidence threshold globally."""
    global current_conf
    # Guard clause: reject anything outside the valid probability range.
    if not (0.0 <= threshold <= 1.0):
        return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
    current_conf = threshold
    return {"status": "success", "new_confidence": current_conf}
@app.post("/detect")
async def detect_ripeness(file: UploadFile = File(...)):
    """Simple YOLO detection only. No archival or vectorization."""
    raw = await file.read()
    image = Image.open(io.BytesIO(raw))

    # Flatten every box from every result into one response list.
    detections = [
        {
            "class": yolo_model.names[int(box.cls)],
            "confidence": round(float(box.conf), 2),
            "box": box.xyxy.tolist()[0],
        }
        for result in yolo_model(image, conf=current_conf)
        for box in result.boxes
    ]

    return {
        "status": "success",
        "current_threshold": current_conf,
        "data": detections,
    }
@app.post("/analyze")
async def analyze_ripeness(file: UploadFile = File(...)):
    """Full analysis: Detection + Vertex Vectorization + MongoDB Archival.

    Saves the upload to a temp file (YOLO and Vertex both read from a path),
    runs detection at the current global threshold, and — if anything was
    detected — archives the primary (first) detection via the analyze use
    case. The temp file is always removed, even on error.
    """
    # 1. Save file temporarily for YOLO and Vertex.
    # basename() strips client-supplied directory components so a malicious
    # filename (e.g. "../../etc/cron.d/x") cannot escape the working
    # directory; "upload" covers a missing/None filename.
    unique_id = uuid.uuid4().hex
    safe_name = os.path.basename(file.filename or "upload")
    temp_path = f"temp_{unique_id}_{safe_name}"
    with open(temp_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    try:
        # 2. Run YOLO detection
        img = Image.open(temp_path)
        results = yolo_model(img, conf=current_conf)

        detections = []
        for r in results:
            for box in r.boxes:
                detections.append({
                    "class": yolo_model.names[int(box.cls)],
                    "confidence": round(float(box.conf), 2),
                    "box": box.xyxy.tolist()[0]
                })

        # 3. If detections found, analyze the first one (primary) for deeper insights
        if detections:
            primary_detection = detections[0]
            record_id = analyze_use_case.execute(temp_path, primary_detection)

            return {
                "status": "success",
                "record_id": record_id,
                "detections": detections,
                "message": "FFB analyzed, vectorized, and archived successfully"
            }

        return {"status": "no_detection", "message": "No palm oil FFB detected"}
    finally:
        # Clean up temp file even if detection or archival raised.
        if os.path.exists(temp_path):
            os.remove(temp_path)
@app.post("/analyze_batch")
async def analyze_batch(files: List[UploadFile] = File(...)):
    """Handles multiple images: Detect -> Vectorize -> Store.

    Each upload is written to a temp file, run through YOLO, and only the
    primary (first) detection per image is forwarded to the batch use case.
    All temp files are removed afterwards, even on error.
    """
    batch_results = []
    temp_files = []
    try:
        for file in files:
            # 1. Save Temp.
            # basename() strips client-supplied directory components so a
            # malicious filename cannot escape the working directory;
            # "upload" covers a missing/None filename.
            unique_id = uuid.uuid4().hex
            safe_name = os.path.basename(file.filename or "upload")
            path = f"temp_{unique_id}_{safe_name}"
            with open(path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            temp_files.append(path)

            # 2. YOLO Detect
            img = Image.open(path)
            yolo_res = yolo_model(img, conf=current_conf)

            # 3. Take the primary detection per image; images with no
            # detection are silently skipped (not counted as processed).
            if yolo_res and yolo_res[0].boxes:
                box = yolo_res[0].boxes[0]
                batch_results.append({
                    "path": path,
                    "yolo": {
                        "class": yolo_model.names[int(box.cls)],
                        "confidence": float(box.conf),
                        "box": box.xyxy.tolist()[0]
                    }
                })

        # 4. Process Batch Use Case
        record_ids = analyze_batch_use_case.execute(batch_results)
        return {
            "status": "success",
            "processed_count": len(record_ids),
            "record_ids": record_ids,
            "message": f"Successfully processed {len(record_ids)} bunches"
        }
    finally:
        # 5. Clean up all temp files
        for path in temp_files:
            if os.path.exists(path):
                os.remove(path)
@app.post("/search_hybrid")
async def search_hybrid(
    file: Optional[UploadFile] = File(None),
    text_query: Optional[str] = Form(None),
    limit: int = Form(3)
):
    """Hybrid Search: Supports Visual Similarity and Natural Language Search.

    An uploaded image takes precedence over a text query; if neither is
    supplied an error payload is returned. Any temp file is always removed.
    """
    temp_path = None
    try:
        if file:
            # basename() strips client-supplied directory components so a
            # malicious filename cannot escape the working directory;
            # "upload" covers a missing/None filename.
            unique_id = uuid.uuid4().hex
            safe_name = os.path.basename(file.filename or "upload")
            temp_path = f"temp_search_{unique_id}_{safe_name}"
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            results = search_use_case.execute(image_path=temp_path, limit=limit)
        elif text_query:
            results = search_use_case.execute(text_query=text_query, limit=limit)
        else:
            return {"status": "error", "message": "No search input provided"}
        return {"status": "success", "results": results}
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
@app.get("/get_image/{record_id}")
async def get_image(record_id: str):
    """Retrieve the Base64 image data for a specific record."""
    record = repo.get_by_id(record_id)
    if record:
        return {"status": "success", "image_data": record.get("image_data")}
    return {"status": "error", "message": "Record not found"}
if __name__ == "__main__":
    # Local development entry point; deferred import keeps uvicorn optional
    # when the app is served by an external ASGI runner.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|