# main.py
  1. from typing import List, Optional
  2. import uuid
  3. import os
  4. import shutil
  5. from fastapi import FastAPI, File, UploadFile, Body, Form, BackgroundTasks
  6. from ultralytics import YOLO
  7. from dotenv import load_dotenv
  8. import io
  9. from PIL import Image
  10. from src.infrastructure.vision_service import VertexVisionService
  11. from src.infrastructure.repository import MongoPalmOilRepository
  12. from src.application.analyze_bunch import AnalyzeBunchUseCase, AnalyzeBatchUseCase, SearchSimilarUseCase
  13. # Load environment variables
  14. load_dotenv()
  15. # Set Google Cloud credentials globally
  16. os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.json"
  17. app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
  18. # Initialize YOLO model
  19. yolo_model = YOLO('best.pt')
  20. # Global state for the confidence threshold
  21. current_conf = 0.25
  22. # Initialize DDD Components
  23. vision_service = VertexVisionService(
  24. project_id=os.getenv("PROJECT_ID", "your-project-id"),
  25. location=os.getenv("LOCATION", "us-central1")
  26. )
  27. repo = MongoPalmOilRepository(
  28. uri=os.getenv("MONGO_URI"),
  29. db_name=os.getenv("DB_NAME", "palm_oil_db"),
  30. collection_name=os.getenv("COLLECTION_NAME", "ffb_records")
  31. )
  32. analyze_use_case = AnalyzeBunchUseCase(vision_service, repo)
  33. analyze_batch_use_case = AnalyzeBatchUseCase(vision_service, repo)
  34. search_use_case = SearchSimilarUseCase(vision_service, repo)
  35. @app.get("/get_confidence")
  36. # ... (rest of the code remains same until analyze)
  37. async def get_confidence():
  38. """Returns the current confidence threshold used by the model."""
  39. return {
  40. "status": "success",
  41. "current_confidence": current_conf,
  42. "model_version": "best.pt"
  43. }
  44. @app.post("/set_confidence")
  45. async def set_confidence(threshold: float = Body(..., embed=True)):
  46. """Updates the confidence threshold globally."""
  47. global current_conf
  48. if 0.0 <= threshold <= 1.0:
  49. current_conf = threshold
  50. return {"status": "success", "new_confidence": current_conf}
  51. else:
  52. return {"status": "error", "message": "Threshold must be between 0.0 and 1.0"}
  53. @app.post("/detect")
  54. async def detect_ripeness(file: UploadFile = File(...)):
  55. """Simple YOLO detection only. No archival or vectorization."""
  56. image_bytes = await file.read()
  57. img = Image.open(io.BytesIO(image_bytes))
  58. results = yolo_model(img, conf=current_conf)
  59. detections = []
  60. for r in results:
  61. for box in r.boxes:
  62. detections.append({
  63. "class": yolo_model.names[int(box.cls)],
  64. "confidence": round(float(box.conf), 2),
  65. "box": box.xyxy.tolist()[0]
  66. })
  67. return {
  68. "status": "success",
  69. "current_threshold": current_conf,
  70. "data": detections
  71. }
  72. @app.post("/analyze")
  73. async def analyze_ripeness(file: UploadFile = File(...)):
  74. """Full analysis: Detection + Vertex Vectorization + MongoDB Archival."""
  75. # 1. Save file temporarily for YOLO and Vertex
  76. unique_id = uuid.uuid4().hex
  77. temp_path = f"temp_{unique_id}_{file.filename}"
  78. with open(temp_path, "wb") as buffer:
  79. shutil.copyfileobj(file.file, buffer)
  80. try:
  81. # 2. Run YOLO detection
  82. img = Image.open(temp_path)
  83. results = yolo_model(img, conf=current_conf)
  84. detections = []
  85. for r in results:
  86. for box in r.boxes:
  87. detections.append({
  88. "class": yolo_model.names[int(box.cls)],
  89. "confidence": round(float(box.conf), 2),
  90. "box": box.xyxy.tolist()[0]
  91. })
  92. # 3. If detections found, analyze the first one (primary) for deeper insights
  93. if detections:
  94. primary_detection = detections[0]
  95. record_id = analyze_use_case.execute(temp_path, primary_detection)
  96. return {
  97. "status": "success",
  98. "record_id": record_id,
  99. "detections": detections,
  100. "message": "FFB analyzed, vectorized, and archived successfully"
  101. }
  102. return {"status": "no_detection", "message": "No palm oil FFB detected"}
  103. finally:
  104. # Clean up temp file
  105. if os.path.exists(temp_path):
  106. os.remove(temp_path)
  107. @app.post("/analyze_batch")
  108. async def analyze_batch(files: List[UploadFile] = File(...)):
  109. """Handles multiple images: Detect -> Vectorize -> Store."""
  110. batch_results = []
  111. temp_files = []
  112. try:
  113. for file in files:
  114. # 1. Save Temp
  115. unique_id = uuid.uuid4().hex
  116. path = f"temp_{unique_id}_{file.filename}"
  117. with open(path, "wb") as f:
  118. shutil.copyfileobj(file.file, f)
  119. temp_files.append(path)
  120. # 2. YOLO Detect
  121. img = Image.open(path)
  122. yolo_res = yolo_model(img, conf=current_conf)
  123. # 3. Take the primary detection per image
  124. if yolo_res and yolo_res[0].boxes:
  125. box = yolo_res[0].boxes[0]
  126. batch_results.append({
  127. "path": path,
  128. "yolo": {
  129. "class": yolo_model.names[int(box.cls)],
  130. "confidence": float(box.conf),
  131. "box": box.xyxy.tolist()[0]
  132. }
  133. })
  134. # 4. Process Batch Use Case
  135. record_ids = analyze_batch_use_case.execute(batch_results)
  136. return {
  137. "status": "success",
  138. "processed_count": len(record_ids),
  139. "record_ids": record_ids,
  140. "message": f"Successfully processed {len(record_ids)} bunches"
  141. }
  142. finally:
  143. # 5. Clean up all temp files
  144. for path in temp_files:
  145. if os.path.exists(path):
  146. os.remove(path)
  147. @app.post("/search_hybrid")
  148. async def search_hybrid(
  149. file: Optional[UploadFile] = File(None),
  150. text_query: Optional[str] = Form(None),
  151. limit: int = Form(3)
  152. ):
  153. """Hybrid Search: Supports Visual Similarity and Natural Language Search."""
  154. temp_path = None
  155. try:
  156. if file:
  157. unique_id = uuid.uuid4().hex
  158. temp_path = f"temp_search_{unique_id}_{file.filename}"
  159. with open(temp_path, "wb") as buffer:
  160. shutil.copyfileobj(file.file, buffer)
  161. results = search_use_case.execute(image_path=temp_path, limit=limit)
  162. elif text_query:
  163. results = search_use_case.execute(text_query=text_query, limit=limit)
  164. else:
  165. return {"status": "error", "message": "No search input provided"}
  166. return {"status": "success", "results": results}
  167. finally:
  168. if temp_path and os.path.exists(temp_path):
  169. os.remove(temp_path)
  170. @app.get("/get_image/{record_id}")
  171. async def get_image(record_id: str):
  172. """Retrieve the Base64 image data for a specific record."""
  173. record = repo.get_by_id(record_id)
  174. if not record:
  175. return {"status": "error", "message": "Record not found"}
  176. return {
  177. "status": "success",
  178. "image_data": record.get("image_data")
  179. }
  180. if __name__ == "__main__":
  181. import uvicorn
  182. uvicorn.run(app, host="0.0.0.0", port=8000)