Parcourir la source

Refactor: Redirect root main.py to src.api.main and migrate confidence features

Dr-Swopt il y a 3 jours
Parent
commit
14213a73a9
2 fichiers modifiés avec 29 ajouts et 58 suppressions
  1. 5 56
      main.py
  2. 24 2
      src/api/main.py

+ 5 - 56
main.py

@@ -1,59 +1,8 @@
-from fastapi import FastAPI, File, UploadFile, Body
-from ultralytics import YOLO
-import io
-from PIL import Image
-
-app = FastAPI()
-
-# 1. Load your custom trained model
-model = YOLO('best.pt') 
-
-# 2. Global state for the confidence threshold
-# Defaulting to 0.25 (YOLO's internal default)
-current_conf = 0.25
-
-@app.get("/get_confidence")
-async def get_confidence():
-    """Returns the current confidence threshold used by the model."""
-    return {
-        "status": "success",
-        "current_confidence": current_conf,
-        "model_version": "best.pt"
-    }
-
-@app.post("/set_confidence")
-async def set_confidence(threshold: float = Body(..., embed=True)):
-    """Updates the confidence threshold globally."""
-    global current_conf
-    if 0.0 <= threshold <= 1.0:
-        current_conf = threshold
-        return {"status": "success", "new_confidence": current_conf}
-    else:
-        return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
-
-@app.post("/detect")
-async def detect_ripeness(file: UploadFile = File(...)):
-    image_bytes = await file.read()
-    img = Image.open(io.BytesIO(image_bytes))
-
-    # 3. Apply the dynamic threshold to the inference
-    results = model(img, conf=current_conf)
-
-    detections = []
-    for r in results:
-        for box in r.boxes:
-            detections.append({
-                "class": model.names[int(box.cls)],
-                "confidence": round(float(box.conf), 2),
-                "box": box.xyxy.tolist()[0]
-            })
-
-    return {
-        "status": "success", 
-        "current_threshold": current_conf,
-        "data": detections
-    }
+import uvicorn
+from src.api.main import app
 
 if __name__ == "__main__":
-    import uvicorn
+    # This file serves as a root-level wrapper for the DDD transition.
+    # It redirects execution to the new API entry point in src/api/main.py.
+    print("Redirecting to DDD Architecture Entry Point (src.api.main)...")
     uvicorn.run(app, host="0.0.0.0", port=8000)

+ 24 - 2
src/api/main.py

@@ -1,5 +1,5 @@
 import os
-from fastapi import FastAPI, File, UploadFile
+from fastapi import FastAPI, File, UploadFile, Body
 from ultralytics import YOLO
 from dotenv import load_dotenv
 import io
@@ -18,6 +18,9 @@ app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
 # Initialize YOLO model
 yolo_model = YOLO('best.pt')
 
+# Global state for the confidence threshold
+current_conf = 0.25
+
 # Initialize DDD Components
 vision_service = VertexVisionService(
     project_id=os.getenv("PROJECT_ID", "your-project-id"),
@@ -30,6 +33,25 @@ repo = MongoPalmOilRepository(
 )
 analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
 
+@app.get("/get_confidence")
+async def get_confidence():
+    """Returns the current confidence threshold used by the model."""
+    return {
+        "status": "success",
+        "current_confidence": current_conf,
+        "model_version": "best.pt"
+    }
+
+@app.post("/set_confidence")
+async def set_confidence(threshold: float = Body(..., embed=True)):
+    """Updates the confidence threshold globally."""
+    global current_conf
+    if 0.0 <= threshold <= 1.0:
+        current_conf = threshold
+        return {"status": "success", "new_confidence": current_conf}
+    else:
+        return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
+
 @app.post("/detect")
 async def detect_ripeness(file: UploadFile = File(...)):
     # 1. Save file temporarily for YOLO and Vertex
@@ -40,7 +62,7 @@ async def detect_ripeness(file: UploadFile = File(...)):
     try:
         # 2. Run YOLO detection
         img = Image.open(temp_path)
-        results = yolo_model(img)
+        results = yolo_model(img, conf=current_conf)
         
         detections = []
         for r in results: