|
|
@@ -1,5 +1,5 @@
|
|
|
import os
|
|
|
-from fastapi import FastAPI, File, UploadFile
|
|
|
+from fastapi import FastAPI, File, UploadFile, Body
|
|
|
from ultralytics import YOLO
|
|
|
from dotenv import load_dotenv
|
|
|
import io
|
|
|
@@ -18,6 +18,9 @@ app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
|
|
|
# Initialize YOLO model
|
|
|
yolo_model = YOLO('best.pt')
|
|
|
|
|
|
+# Global state for the confidence threshold
|
|
|
+current_conf = 0.25
|
|
|
+
|
|
|
# Initialize DDD Components
|
|
|
vision_service = VertexVisionService(
|
|
|
project_id=os.getenv("PROJECT_ID", "your-project-id"),
|
|
|
@@ -30,6 +33,25 @@ repo = MongoPalmOilRepository(
|
|
|
)
|
|
|
analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
|
|
|
|
|
|
+@app.get("/get_confidence")
|
|
|
+async def get_confidence():
|
|
|
+ """Returns the current confidence threshold used by the model."""
|
|
|
+ return {
|
|
|
+ "status": "success",
|
|
|
+ "current_confidence": current_conf,
|
|
|
+ "model_version": "best.pt"
|
|
|
+ }
|
|
|
+
|
|
|
+@app.post("/set_confidence")
|
|
|
+async def set_confidence(threshold: float = Body(..., embed=True)):
|
|
|
+ """Updates the confidence threshold globally."""
|
|
|
+ global current_conf
|
|
|
+ if 0.0 <= threshold <= 1.0:
|
|
|
+ current_conf = threshold
|
|
|
+ return {"status": "success", "new_confidence": current_conf}
|
|
|
+ else:
|
|
|
+ return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
|
|
|
+
|
|
|
@app.post("/detect")
|
|
|
async def detect_ripeness(file: UploadFile = File(...)):
|
|
|
# 1. Save file temporarily for YOLO and Vertex
|
|
|
@@ -40,7 +62,7 @@ async def detect_ripeness(file: UploadFile = File(...)):
|
|
|
try:
|
|
|
# 2. Run YOLO detection
|
|
|
img = Image.open(temp_path)
|
|
|
- results = yolo_model(img)
|
|
|
+ results = yolo_model(img, conf=current_conf)
|
|
|
|
|
|
detections = []
|
|
|
for r in results:
|