|
|
@@ -61,11 +61,13 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gemini-embedding-service-key.jso
|
|
|
app = FastAPI(title="Palm Oil Ripeness Service (DDD)")
|
|
|
|
|
|
class ModelManager:
|
|
|
- def __init__(self, onnx_path: str, pt_path: str):
|
|
|
+ def __init__(self, onnx_path: str, pt_path: str, benchmark_path: str = 'sawit_tbs.pt'):
|
|
|
self.onnx_session = ort.InferenceSession(onnx_path)
|
|
|
self.onnx_input_name = self.onnx_session.get_inputs()[0].name
|
|
|
self.pt_model = YOLO(pt_path)
|
|
|
self.class_names = self.pt_model.names
|
|
|
+ self.benchmark_model = YOLO(benchmark_path)
|
|
|
+ self.benchmark_class_names = self.benchmark_model.names
|
|
|
|
|
|
def preprocess_onnx(self, img: Image.Image):
|
|
|
img = img.convert("RGB")
|
|
|
@@ -118,17 +120,22 @@ class ModelManager:
|
|
|
raw_sample = detections_batch[0, :5].tolist()
|
|
|
return detections, raw_sample, inference_ms
|
|
|
|
|
|
- def run_pytorch_inference(self, img: Image.Image, conf_threshold: float):
|
|
|
+ def run_pytorch_inference(self, img: Image.Image, conf_threshold: float, engine_type: str = "pytorch"):
|
|
|
import time
|
|
|
start_inf = time.perf_counter()
|
|
|
- results = self.pt_model(img, conf=conf_threshold, verbose=False)
|
|
|
+
|
|
|
+ # Pick model + class-name map by engine_type: "pytorch" -> primary model, anything else -> benchmark model
|
|
|
+ model = self.pt_model if engine_type == "pytorch" else self.benchmark_model
|
|
|
+ names = self.class_names if engine_type == "pytorch" else self.benchmark_class_names
|
|
|
+
|
|
|
+ results = model(img, conf=conf_threshold, verbose=False)
|
|
|
end_inf = time.perf_counter()
|
|
|
inference_ms = (end_inf - start_inf) * 1000
|
|
|
|
|
|
detections = []
|
|
|
for i, box in enumerate(results[0].boxes):
|
|
|
class_id = int(box.cls)
|
|
|
- class_name = self.class_names.get(class_id, "Unknown")
|
|
|
+ class_name = names.get(class_id, "Unknown")
|
|
|
detections.append({
|
|
|
"bunch_id": i + 1,
|
|
|
"class": class_name,
|
|
|
@@ -203,7 +210,9 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
|
|
|
start_total = time.perf_counter()
|
|
|
# Select Inference Engine
|
|
|
if model_type == "pytorch":
|
|
|
- detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf)
|
|
|
+ detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "pytorch")
|
|
|
+ elif model_type == "benchmark":
|
|
|
+ detections, raw_sample, inference_ms = model_manager.run_pytorch_inference(img, current_conf, "benchmark")
|
|
|
else:
|
|
|
detections, raw_sample, inference_ms = model_manager.run_onnx_inference(img, current_conf)
|
|
|
|
|
|
@@ -212,7 +221,8 @@ async def analyze_with_health_metrics(file: UploadFile = File(...), model_type:
|
|
|
processing_ms = total_ms - inference_ms
|
|
|
|
|
|
# Initialize summary
|
|
|
- summary = {name: 0 for name in model_manager.class_names.values()}
|
|
|
+ active_names = model_manager.class_names if model_type != "benchmark" else model_manager.benchmark_class_names
|
|
|
+ summary = {name: 0 for name in active_names.values()}
|
|
|
for det in detections:
|
|
|
summary[det['class']] += 1
|
|
|
|