| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- """
- Export best.pt to TFLite with RAW output format (no end-to-end NMS).
- The ultralytics_yolo plugin's JNI postprocessor expects [batch, C+4, N] format
- where rows 0-3 are cx,cy,w,h and rows 4+ are class scores.
- This script:
- 1. Exports best.pt -> best.onnx (without NMS)
- 2. Converts ONNX -> TFLite using onnx_tf
- """
import os
import shlex
import subprocess
import sys
def run(cmd):
    """Run a shell-style command string and echo its stdout.

    The command is tokenized with ``shlex.split`` and executed with
    ``shell=False`` so quoting is handled predictably and no shell is
    involved (avoids injection/quoting pitfalls of ``shell=True``).
    On a nonzero exit code the captured stderr is printed as well.

    Returns the process exit code (0 on success).
    """
    print(f"Running: {cmd}")
    # shlex.split honours quotes, so quoted arguments survive intact.
    result = subprocess.run(shlex.split(cmd), capture_output=True, text=True)
    print(result.stdout)
    if result.returncode != 0:
        print("STDERR:", result.stderr)
    return result.returncode
# --- Step 1: dump best.pt to ONNX with the raw (NMS-free) detection head ---
print("=== Step 1: Export ONNX (no NMS) ===")
code = """
from ultralytics import YOLO
model = YOLO('best.pt')
model.export(format='onnx', imgsz=640, simplify=True, nms=False, opset=12)
print('ONNX exported')
"""
# Run the export in a child interpreter so ultralytics' heavyweight imports
# stay out of this process.
result = subprocess.run(
    [sys.executable, '-c', code],
    capture_output=True,
    text=True,
    cwd=os.getcwd(),
)
print(result.stdout)
if result.returncode != 0:
    print("Error:", result.stderr)
    sys.exit(1)
# --- Step 2: print the exported model's output tensor shapes ---------------
print("=== Step 2: Verify ONNX output shape ===")
check_code = """
import onnx
m = onnx.load('best.onnx')
for out in m.graph.output:
    print('Output:', out.name, [d.dim_value for d in out.type.tensor_type.shape.dim])
"""
# Informational only: the shapes are printed for the operator to eyeball.
subprocess.run(
    [sys.executable, '-c', check_code],
    cwd=os.getcwd(),
)
# Step 3: Convert ONNX -> TF SavedModel via onnx2tf (disable_strict_mode
# tolerates minor accuracy drift in exchange for a successful conversion).
print("=== Step 3: Convert ONNX -> TFLite ===")
ret = run(
    f'python -m onnx2tf -i best.onnx -o best_raw_saved_model '
    f'--not_use_onnx_optimization '
    f'--output_tfv1_signaturedefs '
    f'--non_verbose '
    f'--disable_strict_mode'
)
# Bug fix: the return code was previously ignored, so a failed conversion
# silently fell through to step 4. Abort here instead.
if ret != 0:
    print("onnx2tf conversion failed; aborting.")
    sys.exit(1)
# Step 4: Convert the SavedModel to a float32 TFLite flatbuffer, then reload
# it with the TFLite interpreter to confirm the raw [batch, C+4, N] output.
print("=== Step 4: Convert SavedModel -> TFLite Float32 ===")
tflite_code = """
import os
import sys
import tensorflow as tf
# Locate the saved_model.pb produced by onnx2tf (it may sit in a subdir).
base = 'best_raw_saved_model'
saved_model_dir = None
for root, dirs, files in os.walk(base):
    if 'saved_model.pb' in files:
        saved_model_dir = root
        break
# Bug fix: previously saved_model_dir was undefined (NameError) when no
# saved_model.pb existed; fail with a clear message instead.
if saved_model_dir is None:
    sys.exit(f'No saved_model.pb found under {base!r}')
print(f'Using SavedModel: {saved_model_dir}')
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_types = [tf.float32]
tflite_model = converter.convert()
with open('best_raw_float32.tflite', 'wb') as f:
    f.write(tflite_model)
print('Saved: best_raw_float32.tflite')
# Verify output shape
interp = tf.lite.Interpreter(model_path='best_raw_float32.tflite')
interp.allocate_tensors()
inp = interp.get_input_details()
out = interp.get_output_details()
print('Input shape:', inp[0]['shape'])
for i, o in enumerate(out):
    print(f'Output[{i}]:', o['shape'], o['dtype'])
"""
result = subprocess.run([sys.executable, '-c', tflite_code], cwd=os.getcwd())
# Bug fix: the child's exit status was previously ignored; propagate failure
# so the script does not report success after a failed conversion.
if result.returncode != 0:
    sys.exit(result.returncode)
print("Done. Check best_raw_float32.tflite")
|