|
|
@@ -0,0 +1,89 @@
|
|
|
+"""
|
|
|
+Export best.pt to TFLite with RAW output format (no end-to-end NMS).
|
|
|
+The ultralytics_yolo plugin's JNI postprocessor expects [batch, C+4, N] format
|
|
|
+where rows 0-3 are cx,cy,w,h and rows 4+ are class scores.
|
|
|
+
|
|
|
+This script:
|
|
|
+1. Exports best.pt -> best.onnx (without NMS)
|
|
|
+2. Converts ONNX -> TFLite using onnx_tf
|
|
|
+"""
|
|
|
+import subprocess
|
|
|
+import sys
|
|
|
+import os
|
|
|
+
|
|
|
+def run(cmd):
|
|
|
+ print(f"Running: {cmd}")
|
|
|
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
|
|
+ print(result.stdout)
|
|
|
+ if result.returncode != 0:
|
|
|
+ print("STDERR:", result.stderr)
|
|
|
+ return result.returncode
|
|
|
+
|
|
|
+# Step 1: Export ONNX without NMS using ultralytics
|
|
|
+print("=== Step 1: Export ONNX (no NMS) ===")
|
|
|
+code = """
|
|
|
+from ultralytics import YOLO
|
|
|
+model = YOLO('best.pt')
|
|
|
+model.export(format='onnx', imgsz=640, simplify=True, nms=False, opset=12)
|
|
|
+print('ONNX exported')
|
|
|
+"""
|
|
|
+result = subprocess.run([sys.executable, '-c', code], capture_output=True, text=True, cwd=os.getcwd())
|
|
|
+print(result.stdout)
|
|
|
+if result.returncode != 0:
|
|
|
+ print("Error:", result.stderr)
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+# Check ONNX output shape
|
|
|
+print("=== Step 2: Verify ONNX output shape ===")
|
|
|
+check_code = """
|
|
|
+import onnx
|
|
|
+m = onnx.load('best.onnx')
|
|
|
+for out in m.graph.output:
|
|
|
+ print('Output:', out.name, [d.dim_value for d in out.type.tensor_type.shape.dim])
|
|
|
+"""
|
|
|
+subprocess.run([sys.executable, '-c', check_code], cwd=os.getcwd())
|
|
|
+
|
|
|
+# Step 3: Convert ONNX to TFLite using onnx2tf with disable_strict_mode
|
|
|
+print("=== Step 3: Convert ONNX -> TFLite ===")
|
|
|
+ret = run(
|
|
|
+ f'python -m onnx2tf -i best.onnx -o best_raw_saved_model '
|
|
|
+ f'--not_use_onnx_optimization '
|
|
|
+ f'--output_tfv1_signaturedefs '
|
|
|
+ f'--non_verbose '
|
|
|
+ f'--disable_strict_mode'
|
|
|
+)
|
|
|
+
|
|
|
+# Step 4: Convert SavedModel -> TFLite (float32)
|
|
|
+print("=== Step 4: Convert SavedModel -> TFLite Float32 ===")
|
|
|
+tflite_code = """
|
|
|
+import tensorflow as tf
|
|
|
+import os
|
|
|
+
|
|
|
+# Find saved_model
|
|
|
+base = 'best_raw_saved_model'
|
|
|
+for root, dirs, files in os.walk(base):
|
|
|
+ if 'saved_model.pb' in files:
|
|
|
+ saved_model_dir = root
|
|
|
+ break
|
|
|
+
|
|
|
+print(f'Using SavedModel: {saved_model_dir}')
|
|
|
+
|
|
|
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
|
|
+converter.target_spec.supported_types = [tf.float32]
|
|
|
+tflite_model = converter.convert()
|
|
|
+
|
|
|
+with open('best_raw_float32.tflite', 'wb') as f:
|
|
|
+ f.write(tflite_model)
|
|
|
+print('Saved: best_raw_float32.tflite')
|
|
|
+
|
|
|
+# Verify output shape
|
|
|
+interp = tf.lite.Interpreter(model_path='best_raw_float32.tflite')
|
|
|
+interp.allocate_tensors()
|
|
|
+inp = interp.get_input_details()
|
|
|
+out = interp.get_output_details()
|
|
|
+print('Input shape:', inp[0]['shape'])
|
|
|
+for i, o in enumerate(out):
|
|
|
+ print(f'Output[{i}]:', o['shape'], o['dtype'])
|
|
|
+"""
|
|
|
+subprocess.run([sys.executable, '-c', tflite_code], cwd=os.getcwd())
|
|
|
+print("Done. Check best_raw_float32.tflite")
|