export_raw_tflite.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
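"""
Export best.pt to TFLite with RAW output format (no end-to-end NMS).
The ultralytics_yolo plugin's JNI postprocessor expects [batch, C+4, N] format
(C = number of classes, N = number of candidate boxes),
where rows 0-3 are cx,cy,w,h and rows 4+ are class scores.
This script:
1. Exports best.pt -> best.onnx (without NMS)
2. Prints the ONNX graph's output shapes as a quick sanity check
3. Converts ONNX -> TensorFlow SavedModel using onnx2tf
4. Converts the SavedModel -> best_raw_float32.tflite and prints its I/O shapes
"""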
  9. import subprocess
  10. import sys
  11. import os
  12. def run(cmd):
  13. print(f"Running: {cmd}")
  14. result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
  15. print(result.stdout)
  16. if result.returncode != 0:
  17. print("STDERR:", result.stderr)
  18. return result.returncode
  19. # Step 1: Export ONNX without NMS using ultralytics
  20. print("=== Step 1: Export ONNX (no NMS) ===")
  21. code = """
  22. from ultralytics import YOLO
  23. model = YOLO('best.pt')
  24. model.export(format='onnx', imgsz=640, simplify=True, nms=False, opset=12)
  25. print('ONNX exported')
  26. """
  27. result = subprocess.run([sys.executable, '-c', code], capture_output=True, text=True, cwd=os.getcwd())
  28. print(result.stdout)
  29. if result.returncode != 0:
  30. print("Error:", result.stderr)
  31. sys.exit(1)
  32. # Check ONNX output shape
  33. print("=== Step 2: Verify ONNX output shape ===")
  34. check_code = """
  35. import onnx
  36. m = onnx.load('best.onnx')
  37. for out in m.graph.output:
  38. print('Output:', out.name, [d.dim_value for d in out.type.tensor_type.shape.dim])
  39. """
  40. subprocess.run([sys.executable, '-c', check_code], cwd=os.getcwd())
  41. # Step 3: Convert ONNX to TFLite using onnx2tf with disable_strict_mode
  42. print("=== Step 3: Convert ONNX -> TFLite ===")
  43. ret = run(
  44. f'python -m onnx2tf -i best.onnx -o best_raw_saved_model '
  45. f'--not_use_onnx_optimization '
  46. f'--output_tfv1_signaturedefs '
  47. f'--non_verbose '
  48. f'--disable_strict_mode'
  49. )
  50. # Step 4: Convert SavedModel -> TFLite (float32)
  51. print("=== Step 4: Convert SavedModel -> TFLite Float32 ===")
  52. tflite_code = """
  53. import tensorflow as tf
  54. import os
  55. # Find saved_model
  56. base = 'best_raw_saved_model'
  57. for root, dirs, files in os.walk(base):
  58. if 'saved_model.pb' in files:
  59. saved_model_dir = root
  60. break
  61. print(f'Using SavedModel: {saved_model_dir}')
  62. converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  63. converter.target_spec.supported_types = [tf.float32]
  64. tflite_model = converter.convert()
  65. with open('best_raw_float32.tflite', 'wb') as f:
  66. f.write(tflite_model)
  67. print('Saved: best_raw_float32.tflite')
  68. # Verify output shape
  69. interp = tf.lite.Interpreter(model_path='best_raw_float32.tflite')
  70. interp.allocate_tensors()
  71. inp = interp.get_input_details()
  72. out = interp.get_output_details()
  73. print('Input shape:', inp[0]['shape'])
  74. for i, o in enumerate(out):
  75. print(f'Output[{i}]:', o['shape'], o['dtype'])
  76. """
  77. subprocess.run([sys.executable, '-c', tflite_code], cwd=os.getcwd())
  78. print("Done. Check best_raw_float32.tflite")