// tflite_service.dart
  1. import 'dart:io';
  2. import 'dart:math';
  3. import 'dart:ui';
  4. import 'package:flutter/services.dart';
  5. import 'package:flutter/foundation.dart';
  6. import 'package:image/image.dart' as img;
  7. import 'package:image_picker/image_picker.dart';
  8. import 'package:tflite_flutter/tflite_flutter.dart';
  9. import 'package:camera/camera.dart';
  10. /// A detection result parsed from the model's end-to-end output.
  11. class DetectionResult {
  12. final String className;
  13. final int classIndex;
  14. final double confidence;
  15. /// Normalized bounding box (0.0 - 1.0)
  16. final Rect normalizedBox;
  17. const DetectionResult({
  18. required this.className,
  19. required this.classIndex,
  20. required this.confidence,
  21. required this.normalizedBox,
  22. });
  23. Color getStatusColor() {
  24. if (className == 'Empty_Bunch' || className == 'Abnormal') return const Color(0xFFF44336); // Colors.red
  25. if (className == 'Ripe' || className == 'Overripe') return const Color(0xFF4CAF50); // Colors.green
  26. return const Color(0xFFFF9800); // Colors.orange
  27. }
  28. }
/// Custom TFLite inference service that correctly decodes the end-to-end
/// YOLO model output format [1, N, 6] = [batch, detections, (x1,y1,x2,y2,conf,class_id)].
class TfliteService {
  // Asset file names, resolved under 'assets/'.
  static const _modelAsset = 'best.tflite';
  static const _labelsAsset = 'labels.txt';
  // Model expects a square RGB input of this side length.
  static const int _inputSize = 640;
  // Detections scoring below this are discarded by _decodeDetections.
  static const double _confidenceThreshold = 0.25;
  // Created lazily in initModel(); null until then.
  Interpreter? _interpreter;
  // Class labels, one per non-empty line of the labels asset.
  List<String> _labels = [];
  final ImagePicker _picker = ImagePicker();
  bool _isInitialized = false;
  /// Whether [initModel] has completed successfully.
  bool get isInitialized => _isInitialized;
  /// Loads the label file and the TFLite model from assets.
  ///
  /// Idempotent: returns immediately if already initialized, which also
  /// prevents leaking a previously created native interpreter when callers
  /// (e.g. [runInference]) race to initialize. Rethrows any load failure.
  Future<void> initModel() async {
    if (_isInitialized) return; // already loaded — don't reload / leak
    try {
      // Load labels
      final labelData = await rootBundle.loadString('assets/$_labelsAsset');
      _labels = labelData
          .split('\n')
          .where((l) => l.trim().isNotEmpty)
          .map((l) => l.trim())
          .toList();
      // Load model
      final interpreterOptions = InterpreterOptions()..threads = 4;
      _interpreter = await Interpreter.fromAsset(
        'assets/$_modelAsset',
        options: interpreterOptions,
      );
      _isInitialized = true;
      // debugPrint is throttled and stripped appropriately in release builds.
      debugPrint('TfliteService: Model loaded. Labels: $_labels');
      debugPrint('TfliteService: Input: ${_interpreter!.getInputTensors().map((t) => t.shape)}');
      debugPrint('TfliteService: Output: ${_interpreter!.getOutputTensors().map((t) => t.shape)}');
    } catch (e) {
      debugPrint('TfliteService init error: $e');
      rethrow;
    }
  }
  /// Opens the gallery picker, asking the platform to downscale the chosen
  /// image to at most the model's input size on each side.
  Future<XFile?> pickImage() async {
    final maxSide = _inputSize.toDouble();
    return _picker.pickImage(
      source: ImageSource.gallery,
      maxWidth: maxSide,
      maxHeight: maxSide,
    );
  }
  /// Run inference on the image at [imagePath].
  /// Returns a list of [DetectionResult] sorted by confidence descending.
  /// Offloaded to a background isolate to keep UI smooth.
  Future<List<DetectionResult>> runInference(String imagePath) async {
    if (!_isInitialized) {
      await initModel();
    }
    // Only the raw file bytes cross the isolate boundary; decoding,
    // resizing and inference all happen off the UI thread.
    final bytes = await File(imagePath).readAsBytes();
    return _runInferenceInIsolate(bytes);
  }
  /// Run inference on a [CameraImage] from the stream.
  /// Throttled by the caller.
  Future<List<DetectionResult>> runInferenceOnStream(CameraImage image) async {
    if (!_isInitialized) {
      await initModel();
    }
    // Planes are flattened to plain maps because CameraImage itself cannot
    // be sent across an isolate boundary.
    final planeMaps = [
      for (final p in image.planes)
        {
          'bytes': p.bytes,
          'bytesPerRow': p.bytesPerRow,
          'bytesPerPixel': p.bytesPerPixel,
        }
    ];
    // NOTE(review): model bytes and label data are re-read from the asset
    // bundle on every frame; consider caching them after the first call if
    // stream inference becomes a hotspot.
    final modelBytes = (await rootBundle.load('assets/$_modelAsset')).buffer.asUint8List();
    final labelData = await rootBundle.loadString('assets/$_labelsAsset');
    return compute(_inferenceStreamTaskWrapper, {
      'planes': planeMaps,
      'width': image.width,
      'height': image.height,
      'format': image.format.group,
      'modelBytes': modelBytes,
      'labelData': labelData,
    });
  }
  /// [compute] entry point for camera-stream frames.
  ///
  /// Rebuilds an [Interpreter] from the transferred model bytes (isolates
  /// share no memory), converts the frame's centre square to RGB, resizes it
  /// to the model input, runs inference and maps boxes back to the full
  /// frame. Returns an empty list for unsupported pixel formats.
  static List<DetectionResult> _inferenceStreamTaskWrapper(Map<String, dynamic> args) {
    final modelBytes = args['modelBytes'] as Uint8List;
    final labelData = args['labelData'] as String;
    final planes = args['planes'] as List<dynamic>;
    final width = args['width'] as int;
    final height = args['height'] as int;
    final interpreter = Interpreter.fromBuffer(modelBytes);
    final labels = labelData.split('\n').where((l) => l.trim().isNotEmpty).map((l) => l.trim()).toList();
    try {
      // Centre-square crop geometry: largest square that fits the frame.
      final size = width < height ? width : height;
      final offsetX = (width - size) ~/ 2;
      final offsetY = (height - size) ~/ 2;
      img.Image? image;
      if (args['format'] == ImageFormatGroup.yuv420) {
        // Android camera default: planar YUV, converted + cropped in one pass.
        image = _convertYUV420ToImage(
          planes: planes,
          width: width,
          height: height,
          cropSize: size,
          offsetX: offsetX,
          offsetY: offsetY,
        );
      } else if (args['format'] == ImageFormatGroup.bgra8888) {
        // iOS camera default: single interleaved BGRA plane.
        final fullImage = img.Image.fromBytes(
          width: width,
          height: height,
          bytes: planes[0]['bytes'].buffer,
          format: img.Format.uint8,
          numChannels: 4,
          order: img.ChannelOrder.bgra,
        );
        image = img.copyCrop(fullImage, x: offsetX, y: offsetY, width: size, height: size);
      }
      // Unsupported format: no detections rather than a crash.
      if (image == null) return [];
      // Resize and Run
      final resized = img.copyResize(image, width: _inputSize, height: _inputSize);
      // [1, H, W, 3] nested list with RGB normalized to [0, 1].
      final inputTensor = List.generate(1, (_) =>
        List.generate(_inputSize, (y) =>
          List.generate(_inputSize, (x) {
            final pixel = resized.getPixel(x, y);
            return [pixel.r / 255.0, pixel.g / 255.0, pixel.b / 255.0];
          })
        )
      );
      // Output buffer shaped from the model's own reported output tensor.
      final outputShape = interpreter.getOutputTensors()[0].shape;
      final outputTensor = List.generate(1, (_) =>
        List.generate(outputShape[1], (_) =>
          List<double>.filled(outputShape[2], 0.0)
        )
      );
      interpreter.run(inputTensor, outputTensor);
      // Map detections back to full frame
      return _decodeDetections(
        outputTensor[0],
        labels,
        cropSize: size,
        offsetX: offsetX,
        offsetY: offsetY,
        fullWidth: width,
        fullHeight: height
      );
    } finally {
      // Release native resources even if conversion or inference throws.
      interpreter.close();
    }
  }
  /// Converts the centre square of a YUV420 camera frame to an RGB image.
  ///
  /// [planes] holds Y, U, V plane maps (keys 'bytes', 'bytesPerRow',
  /// 'bytesPerPixel') as packed by [runInferenceOnStream]. Only the
  /// [cropSize] x [cropSize] square starting at ([offsetX], [offsetY]) in
  /// the full [width] x [height] frame is converted; out-of-bounds pixels
  /// are skipped and remain black.
  static img.Image _convertYUV420ToImage({
    required List<dynamic> planes,
    required int width,
    required int height,
    required int cropSize,
    required int offsetX,
    required int offsetY,
  }) {
    final yPlane = planes[0];
    final uPlane = planes[1];
    final vPlane = planes[2];
    final yBytes = yPlane['bytes'] as Uint8List;
    final uBytes = uPlane['bytes'] as Uint8List;
    final vBytes = vPlane['bytes'] as Uint8List;
    final yRowStride = yPlane['bytesPerRow'] as int;
    final uvRowStride = uPlane['bytesPerRow'] as int;
    // NOTE(review): CameraImage's Plane.bytesPerPixel is nullable; this cast
    // assumes the YUV420 path always supplies it — confirm against callers.
    final uvPixelStride = uPlane['bytesPerPixel'] as int;
    final image = img.Image(width: cropSize, height: cropSize);
    for (int y = 0; y < cropSize; y++) {
      for (int x = 0; x < cropSize; x++) {
        final int actualX = x + offsetX;
        final int actualY = y + offsetY;
        // U/V are subsampled 2x2. Integer division (~/ 2) replaces the
        // original (v / 2).floor() — identical for the non-negative
        // coordinates used here, without a per-pixel double round-trip.
        final int uvIndex = (uvRowStride * (actualY ~/ 2)) + (uvPixelStride * (actualX ~/ 2));
        final int yIndex = (actualY * yRowStride) + actualX;
        // Ensure we don't go out of bounds
        if (yIndex >= yBytes.length || uvIndex >= uBytes.length || uvIndex >= vBytes.length) continue;
        final int yp = yBytes[yIndex];
        final int up = uBytes[uvIndex];
        final int vp = vBytes[uvIndex];
        // Standard YUV to RGB conversion
        int r = (yp + (1.370705 * (vp - 128))).toInt().clamp(0, 255);
        int g = (yp - (0.337633 * (up - 128)) - (0.698001 * (vp - 128))).toInt().clamp(0, 255);
        int b = (yp + (1.732446 * (up - 128))).toInt().clamp(0, 255);
        image.setPixelRgb(x, y, r, g, b);
      }
    }
    return image;
  }
  /// Decodes raw [N, 6] rows (x1, y1, x2, y2, conf, classId) into
  /// [DetectionResult]s, sorted by confidence descending.
  ///
  /// When all crop parameters are supplied, box coordinates (relative to the
  /// centre crop) are re-expressed relative to the full frame. Rows that are
  /// too short, below [_confidenceThreshold], or degenerate after clamping
  /// are dropped.
  static List<DetectionResult> _decodeDetections(
    List<List<double>> rawDetections,
    List<String> labels, {
    int? cropSize,
    int? offsetX,
    int? offsetY,
    int? fullWidth,
    int? fullHeight,
  }) {
    final results = <DetectionResult>[];
    for (final row in rawDetections) {
      if (row.length < 6) continue;
      final score = row[4];
      if (score < _confidenceThreshold) continue;
      var left = row[0].clamp(0.0, 1.0);
      var top = row[1].clamp(0.0, 1.0);
      var right = row[2].clamp(0.0, 1.0);
      var bottom = row[3].clamp(0.0, 1.0);
      // If crop info is provided, map back to full frame.
      if (cropSize != null && offsetX != null && offsetY != null && fullWidth != null && fullHeight != null) {
        left = (left * cropSize + offsetX) / fullWidth;
        right = (right * cropSize + offsetX) / fullWidth;
        top = (top * cropSize + offsetY) / fullHeight;
        bottom = (bottom * cropSize + offsetY) / fullHeight;
      }
      final classId = row[5].round();
      if (right <= left || bottom <= top) continue;
      final inRange = classId >= 0 && classId < labels.length;
      results.add(DetectionResult(
        className: inRange ? labels[classId] : 'Unknown',
        classIndex: classId,
        confidence: score,
        normalizedBox: Rect.fromLTRB(left, top, right, bottom),
      ));
    }
    results.sort((a, b) => b.confidence.compareTo(a.confidence));
    return results;
  }
  /// Loads the model and label assets, then delegates decoding and
  /// inference to a background isolate via [compute].
  Future<List<DetectionResult>> _runInferenceInIsolate(Uint8List imageBytes) async {
    // Isolates share no memory, so model and labels travel as plain data.
    final modelData = await rootBundle.load('assets/$_modelAsset');
    final labelData = await rootBundle.loadString('assets/$_labelsAsset');
    final payload = <String, dynamic>{
      'imageBytes': imageBytes,
      'modelBytes': modelData.buffer.asUint8List(),
      'labelData': labelData,
    };
    return compute(_inferenceTaskWrapper, payload);
  }
  /// [compute] entry point: unpacks the argument map for [_inferenceTask].
  static List<DetectionResult> _inferenceTaskWrapper(Map<String, dynamic> args) =>
      _inferenceTask(
        args['imageBytes'] as Uint8List,
        args['modelBytes'] as Uint8List,
        args['labelData'] as String,
      );
  /// The static task that runs in the background isolate.
  ///
  /// Decodes [imageBytes], centre-crops to a square, resizes to the model
  /// input size, runs an interpreter rebuilt from [modelBytes], and decodes
  /// detections using [labelData] (one label per non-empty line). Throws if
  /// the image bytes cannot be decoded.
  static List<DetectionResult> _inferenceTask(Uint8List imageBytes, Uint8List modelBytes, String labelData) {
    // 1. Initialize Interpreter inside the isolate
    final interpreter = Interpreter.fromBuffer(modelBytes);
    final labels = labelData.split('\n').where((l) => l.trim().isNotEmpty).map((l) => l.trim()).toList();
    try {
      // 2. Preprocess image
      final decoded = img.decodeImage(imageBytes);
      if (decoded == null) throw Exception('Could not decode image');
      // Center-Square Crop
      final int width = decoded.width;
      final int height = decoded.height;
      final int size = width < height ? width : height;
      final int offsetX = (width - size) ~/ 2;
      final int offsetY = (height - size) ~/ 2;
      final cropped = img.copyCrop(decoded, x: offsetX, y: offsetY, width: size, height: size);
      final resized = img.copyResize(cropped, width: _inputSize, height: _inputSize, interpolation: img.Interpolation.linear);
      // [1, H, W, 3] nested list with RGB normalized to [0, 1].
      final inputTensor = List.generate(1, (_) =>
        List.generate(_inputSize, (y) =>
          List.generate(_inputSize, (x) {
            final pixel = resized.getPixel(x, y);
            return [pixel.r / 255.0, pixel.g / 255.0, pixel.b / 255.0];
          })
        )
      );
      // 3. Prepare output
      // Buffer shaped from the model's own reported output tensor.
      final outputShape = interpreter.getOutputTensors()[0].shape;
      final outputTensor = List.generate(1, (_) =>
        List.generate(outputShape[1], (_) =>
          List<double>.filled(outputShape[2], 0.0)
        )
      );
      // 4. Run
      interpreter.run(inputTensor, outputTensor);
      // Map detections back to full frame
      return _decodeDetections(
        outputTensor[0],
        labels,
        cropSize: size,
        offsetX: offsetX,
        offsetY: offsetY,
        fullWidth: width,
        fullHeight: height
      );
    } finally {
      // Release native resources even if decoding or inference throws.
      interpreter.close();
    }
  }
  /// Releases the native interpreter and resets initialization state so
  /// [initModel] can be called again later.
  void dispose() {
    final interpreter = _interpreter;
    _interpreter = null;
    _isInitialized = false;
    interpreter?.close();
  }
}