# test_benchmark.py

  1. import os
  2. import sys
  3. from PIL import Image
  4. import io
  5. import torch
  6. # Add the project root to sys.path to import src
  7. sys.path.append(os.getcwd())
  8. from src.api.main import ModelManager
  9. def test_inference():
  10. print("Testing ModelManager initialization...")
  11. manager = ModelManager(onnx_path='best.onnx', pt_path='best.pt', benchmark_path='sawit_tbs.pt')
  12. print("ModelManager initialized successfully.")
  13. # Create a dummy image for testing
  14. img = Image.new('RGB', (640, 640), color = (73, 109, 137))
  15. print("\nTesting PyTorch inference (Native)...")
  16. detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="pytorch")
  17. print(f"Detections: {len(detections)}, Inference: {ms:.2f}ms")
  18. print("\nTesting Benchmark inference (YOLOv8-Sawit)...")
  19. detections, raw, ms = manager.run_pytorch_inference(img, 0.25, engine_type="yolov8_sawit")
  20. print(f"Detections: {len(detections)}, Inference: {ms:.2f}ms")
  21. print(f"Benchmark Class Names: {manager.benchmark_class_names}")
  22. print("\nVerification Complete.")
  23. if __name__ == "__main__":
  24. test_inference()