// tflite_service.dart
import 'dart:io';
import 'dart:math';
import 'dart:ui';

import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart' as img;
import 'package:image_picker/image_picker.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
  9. /// A detection result parsed from the model's end-to-end output.
  10. class DetectionResult {
  11. final String className;
  12. final int classIndex;
  13. final double confidence;
  14. /// Normalized bounding box (0.0 - 1.0)
  15. final Rect normalizedBox;
  16. const DetectionResult({
  17. required this.className,
  18. required this.classIndex,
  19. required this.confidence,
  20. required this.normalizedBox,
  21. });
  22. }
  23. /// Custom TFLite inference service that correctly decodes the end-to-end
  24. /// YOLO model output format [1, N, 6] = [batch, detections, (x1,y1,x2,y2,conf,class_id)].
  25. class TfliteService {
  26. static const _modelAsset = 'best.tflite';
  27. static const _labelsAsset = 'labels.txt';
  28. static const int _inputSize = 640;
  29. static const double _confidenceThreshold = 0.25;
  30. Interpreter? _interpreter;
  31. List<String> _labels = [];
  32. final ImagePicker _picker = ImagePicker();
  33. bool _isInitialized = false;
  34. bool get isInitialized => _isInitialized;
  35. Future<void> initModel() async {
  36. try {
  37. // Load labels
  38. final labelData = await rootBundle.loadString('assets/$_labelsAsset');
  39. _labels = labelData.split('\n').where((l) => l.trim().isNotEmpty).map((l) => l.trim()).toList();
  40. // Load model
  41. final interpreterOptions = InterpreterOptions()..threads = 4;
  42. _interpreter = await Interpreter.fromAsset(
  43. 'assets/$_modelAsset',
  44. options: interpreterOptions,
  45. );
  46. _isInitialized = true;
  47. print('TfliteService: Model loaded. Labels: $_labels');
  48. print('TfliteService: Input: ${_interpreter!.getInputTensors().map((t) => t.shape)}');
  49. print('TfliteService: Output: ${_interpreter!.getOutputTensors().map((t) => t.shape)}');
  50. } catch (e) {
  51. print('TfliteService init error: $e');
  52. rethrow;
  53. }
  54. }
  55. Future<XFile?> pickImage() async {
  56. return await _picker.pickImage(
  57. source: ImageSource.gallery,
  58. maxWidth: _inputSize.toDouble(),
  59. maxHeight: _inputSize.toDouble(),
  60. );
  61. }
  62. /// Run inference on the image at [imagePath].
  63. /// Returns a list of [DetectionResult] sorted by confidence descending.
  64. /// Offloaded to a background isolate to keep UI smooth.
  65. Future<List<DetectionResult>> runInference(String imagePath) async {
  66. if (!_isInitialized) await initModel();
  67. final imageBytes = await File(imagePath).readAsBytes();
  68. // We pass the raw bytes and asset paths to the isolate.
  69. // The isolate will handle decoding, resizing, and inference.
  70. return await _runInferenceInIsolate(imageBytes);
  71. }
  72. Future<List<DetectionResult>> _runInferenceInIsolate(Uint8List imageBytes) async {
  73. // We need the model and labels passed as data
  74. final modelData = await rootBundle.load('assets/$_modelAsset');
  75. final labelData = await rootBundle.loadString('assets/$_labelsAsset');
  76. // Use compute to run in a real isolate
  77. return await compute(_inferenceTaskWrapper, {
  78. 'imageBytes': imageBytes,
  79. 'modelBytes': modelData.buffer.asUint8List(),
  80. 'labelData': labelData,
  81. });
  82. }
  83. static List<DetectionResult> _inferenceTaskWrapper(Map<String, dynamic> args) {
  84. return _inferenceTask(
  85. args['imageBytes'] as Uint8List,
  86. args['modelBytes'] as Uint8List,
  87. args['labelData'] as String,
  88. );
  89. }
  90. /// The static task that runs in the background isolate
  91. static List<DetectionResult> _inferenceTask(Uint8List imageBytes, Uint8List modelBytes, String labelData) {
  92. // 1. Initialize Interpreter inside the isolate
  93. final interpreter = Interpreter.fromBuffer(modelBytes);
  94. final labels = labelData.split('\n').where((l) => l.trim().isNotEmpty).map((l) => l.trim()).toList();
  95. try {
  96. // 2. Preprocess image
  97. final decoded = img.decodeImage(imageBytes);
  98. if (decoded == null) throw Exception('Could not decode image');
  99. final resized = img.copyResize(decoded, width: _inputSize, height: _inputSize, interpolation: img.Interpolation.linear);
  100. final inputTensor = List.generate(1, (_) =>
  101. List.generate(_inputSize, (y) =>
  102. List.generate(_inputSize, (x) {
  103. final pixel = resized.getPixel(x, y);
  104. return [pixel.r / 255.0, pixel.g / 255.0, pixel.b / 255.0];
  105. })
  106. )
  107. );
  108. // 3. Prepare output
  109. final outputShape = interpreter.getOutputTensors()[0].shape;
  110. final numDetections = outputShape[1];
  111. final numFields = outputShape[2];
  112. final outputTensor = List.generate(1, (_) =>
  113. List.generate(numDetections, (_) =>
  114. List<double>.filled(numFields, 0.0)
  115. )
  116. );
  117. // 4. Run
  118. interpreter.run(inputTensor, outputTensor);
  119. // 5. Decode
  120. final detections = <DetectionResult>[];
  121. final rawDetections = outputTensor[0];
  122. for (final det in rawDetections) {
  123. if (det.length < 6) continue;
  124. final conf = det[4];
  125. if (conf < _confidenceThreshold) continue;
  126. final x1 = det[0].clamp(0.0, 1.0);
  127. final y1 = det[1].clamp(0.0, 1.0);
  128. final x2 = det[2].clamp(0.0, 1.0);
  129. final y2 = det[3].clamp(0.0, 1.0);
  130. final classId = det[5].round();
  131. if (x2 <= x1 || y2 <= y1) continue;
  132. final label = (classId >= 0 && classId < labels.length) ? labels[classId] : 'Unknown';
  133. detections.add(DetectionResult(
  134. className: label,
  135. classIndex: classId,
  136. confidence: conf,
  137. normalizedBox: Rect.fromLTRB(x1, y1, x2, y2),
  138. ));
  139. }
  140. detections.sort((a, b) => b.confidence.compareTo(a.confidence));
  141. return detections;
  142. } finally {
  143. interpreter.close();
  144. }
  145. }
  146. void dispose() {
  147. _interpreter?.close();
  148. _interpreter = null;
  149. _isInitialized = false;
  150. }
  151. }