// tflite_service.dart
  1. import 'dart:io';
  2. import 'dart:math';
  3. import 'dart:ui';
  4. import 'dart:isolate';
  5. import 'dart:async';
  6. import 'package:flutter/services.dart';
  7. import 'package:flutter/foundation.dart';
  8. import 'package:image/image.dart' as img;
  9. import 'package:image_picker/image_picker.dart';
  10. import 'package:tflite_flutter/tflite_flutter.dart';
  11. import 'package:camera/camera.dart';
  12. class DetectionResult {
  13. final String className;
  14. final int classIndex;
  15. final double confidence;
  16. final Rect normalizedBox;
  17. const DetectionResult({
  18. required this.className,
  19. required this.classIndex,
  20. required this.confidence,
  21. required this.normalizedBox,
  22. });
  23. Color getStatusColor() {
  24. if (className == 'Empty_Bunch' || className == 'Abnormal') return const Color(0xFFF44336);
  25. if (className == 'Ripe' || className == 'Overripe') return const Color(0xFF4CAF50);
  26. return const Color(0xFFFF9800);
  27. }
  28. }
  29. class TfliteService {
  30. static const _modelAsset = 'best.tflite';
  31. static const _labelsAsset = 'labels.txt';
  32. static const int _inputSize = 640;
  33. static const double _confidenceThreshold = 0.25;
  34. Isolate? _isolate;
  35. SendPort? _sendPort;
  36. ReceivePort? _receivePort;
  37. List<String> _labels = [];
  38. final ImagePicker _picker = ImagePicker();
  39. bool _isInitialized = false;
  40. bool _isIsolateBusy = false;
  41. bool get isInitialized => _isInitialized;
  42. Future<void> initModel() async {
  43. try {
  44. final labelData = await rootBundle.loadString('assets/$_labelsAsset');
  45. _labels = labelData.split('\n').where((l) => l.trim().isNotEmpty).map((l) => l.trim()).toList();
  46. final modelData = await rootBundle.load('assets/$_modelAsset');
  47. final modelBytes = modelData.buffer.asUint8List();
  48. _receivePort = ReceivePort();
  49. _isolate = await Isolate.spawn(_isolateEntry, _receivePort!.sendPort);
  50. final completer = Completer<SendPort>();
  51. StreamSubscription? sub;
  52. sub = _receivePort!.listen((message) {
  53. if (message is SendPort) {
  54. completer.complete(message);
  55. sub?.cancel();
  56. }
  57. });
  58. _sendPort = await completer.future;
  59. final initCompleter = Completer<void>();
  60. final initReplyPort = ReceivePort();
  61. _sendPort!.send({
  62. 'command': 'init',
  63. 'modelBytes': modelBytes,
  64. 'labelData': labelData,
  65. 'replyPort': initReplyPort.sendPort,
  66. });
  67. StreamSubscription? initSub;
  68. initSub = initReplyPort.listen((message) {
  69. if (message == 'init_done') {
  70. initCompleter.complete();
  71. initSub?.cancel();
  72. initReplyPort.close();
  73. }
  74. });
  75. await initCompleter.future;
  76. _isInitialized = true;
  77. print('TfliteService: Model loaded via persistent isolate.');
  78. } catch (e) {
  79. print('TfliteService init error: $e');
  80. rethrow;
  81. }
  82. }
  83. Future<XFile?> pickImage() async {
  84. return await _picker.pickImage(
  85. source: ImageSource.gallery,
  86. maxWidth: _inputSize.toDouble(),
  87. maxHeight: _inputSize.toDouble(),
  88. );
  89. }
  90. Future<List<DetectionResult>> runInference(String imagePath) async {
  91. if (!_isInitialized) await initModel();
  92. final imageBytes = await File(imagePath).readAsBytes();
  93. final replyPort = ReceivePort();
  94. _sendPort!.send({
  95. 'command': 'inference_static',
  96. 'imageBytes': imageBytes,
  97. 'replyPort': replyPort.sendPort,
  98. });
  99. final detections = await replyPort.first;
  100. replyPort.close();
  101. return detections as List<DetectionResult>;
  102. }
  103. Future<List<DetectionResult>> runInferenceOnStream(CameraImage image) async {
  104. if (!_isInitialized) await initModel();
  105. if (_isIsolateBusy) return <DetectionResult>[];
  106. _isIsolateBusy = true;
  107. final replyPort = ReceivePort();
  108. _sendPort!.send({
  109. 'command': 'inference_stream',
  110. 'planes': image.planes.map((p) => {
  111. 'bytes': p.bytes,
  112. 'bytesPerRow': p.bytesPerRow,
  113. 'bytesPerPixel': p.bytesPerPixel,
  114. }).toList(),
  115. 'width': image.width,
  116. 'height': image.height,
  117. 'format': image.format.group,
  118. 'replyPort': replyPort.sendPort,
  119. });
  120. final detections = await replyPort.first;
  121. replyPort.close();
  122. _isIsolateBusy = false;
  123. return detections as List<DetectionResult>;
  124. }
  125. static void _isolateEntry(SendPort sendPort) {
  126. final receivePort = ReceivePort();
  127. sendPort.send(receivePort.sendPort);
  128. Interpreter? interpreter;
  129. List<String> labels = [];
  130. receivePort.listen((message) {
  131. if (message is Map) {
  132. final command = message['command'];
  133. final replyPort = message['replyPort'] as SendPort;
  134. if (command == 'init') {
  135. final modelBytes = message['modelBytes'] as Uint8List;
  136. final labelData = message['labelData'] as String;
  137. final interpreterOptions = InterpreterOptions()..threads = 4;
  138. interpreter = Interpreter.fromBuffer(modelBytes, options: interpreterOptions);
  139. labels = labelData.split('\n').where((l) => l.trim().isNotEmpty).map((l) => l.trim()).toList();
  140. replyPort.send('init_done');
  141. } else if (command == 'inference_static') {
  142. if (interpreter == null) {
  143. replyPort.send(<DetectionResult>[]);
  144. return;
  145. }
  146. final imageBytes = message['imageBytes'] as Uint8List;
  147. final results = _inferenceStaticTask(imageBytes, interpreter!, labels);
  148. replyPort.send(results);
  149. } else if (command == 'inference_stream') {
  150. if (interpreter == null) {
  151. replyPort.send(<DetectionResult>[]);
  152. return;
  153. }
  154. final planes = message['planes'] as List<dynamic>;
  155. final width = message['width'] as int;
  156. final height = message['height'] as int;
  157. final format = message['format'];
  158. final results = _inferenceStreamTask(planes, width, height, format, interpreter!, labels);
  159. replyPort.send(results);
  160. }
  161. }
  162. });
  163. }
  164. static List<DetectionResult> _inferenceStaticTask(Uint8List imageBytes, Interpreter interpreter, List<String> labels) {
  165. try {
  166. final decoded = img.decodeImage(imageBytes);
  167. if (decoded == null) throw Exception('Could not decode image');
  168. final int width = decoded.width;
  169. final int height = decoded.height;
  170. final int size = width < height ? width : height;
  171. final int offsetX = (width - size) ~/ 2;
  172. final int offsetY = (height - size) ~/ 2;
  173. final cropped = img.copyCrop(decoded, x: offsetX, y: offsetY, width: size, height: size);
  174. final resized = img.copyResize(cropped, width: _inputSize, height: _inputSize, interpolation: img.Interpolation.linear);
  175. final inputTensor = List.generate(1, (_) =>
  176. List.generate(_inputSize, (y) =>
  177. List.generate(_inputSize, (x) {
  178. final pixel = resized.getPixel(x, y);
  179. return [pixel.r / 255.0, pixel.g / 255.0, pixel.b / 255.0];
  180. })
  181. )
  182. );
  183. final outputShape = interpreter.getOutputTensors()[0].shape;
  184. final outputTensor = List.generate(1, (_) =>
  185. List.generate(outputShape[1], (_) =>
  186. List<double>.filled(outputShape[2], 0.0)
  187. )
  188. );
  189. interpreter.run(inputTensor, outputTensor);
  190. return _decodeDetections(
  191. outputTensor[0],
  192. labels,
  193. cropSize: size,
  194. offsetX: offsetX,
  195. offsetY: offsetY,
  196. fullWidth: width,
  197. fullHeight: height
  198. );
  199. } catch (e) {
  200. print('Isolate static inference error: $e');
  201. return <DetectionResult>[];
  202. }
  203. }
  204. static List<DetectionResult> _inferenceStreamTask(
  205. List<dynamic> planes, int width, int height, dynamic format,
  206. Interpreter interpreter, List<String> labels
  207. ) {
  208. try {
  209. final size = width < height ? width : height;
  210. final offsetX = (width - size) ~/ 2;
  211. final offsetY = (height - size) ~/ 2;
  212. img.Image? image;
  213. if (format == ImageFormatGroup.bgra8888) {
  214. final fullImage = img.Image.fromBytes(
  215. width: width,
  216. height: height,
  217. bytes: planes[0]['bytes'].buffer,
  218. format: img.Format.uint8,
  219. numChannels: 4,
  220. order: img.ChannelOrder.bgra,
  221. );
  222. image = img.copyCrop(fullImage, x: offsetX, y: offsetY, width: size, height: size);
  223. } else if (format == ImageFormatGroup.yuv420) {
  224. image = _convertYUV420ToImage(
  225. planes: planes,
  226. width: width,
  227. height: height,
  228. cropSize: size,
  229. offsetX: offsetX,
  230. offsetY: offsetY,
  231. );
  232. } else {
  233. print("TfliteService: Unsupported format: $format. Ensure platform correctly requests YUV420 or BGRA.");
  234. return <DetectionResult>[];
  235. }
  236. final resized = img.copyResize(image, width: _inputSize, height: _inputSize);
  237. final inputTensor = List.generate(1, (_) =>
  238. List.generate(_inputSize, (y) =>
  239. List.generate(_inputSize, (x) {
  240. final pixel = resized.getPixel(x, y);
  241. return [pixel.r / 255.0, pixel.g / 255.0, pixel.b / 255.0];
  242. })
  243. )
  244. );
  245. final outputShape = interpreter.getOutputTensors()[0].shape;
  246. final outputTensor = List.generate(1, (_) =>
  247. List.generate(outputShape[1], (_) =>
  248. List<double>.filled(outputShape[2], 0.0)
  249. )
  250. );
  251. interpreter.run(inputTensor, outputTensor);
  252. return _decodeDetections(
  253. outputTensor[0],
  254. labels,
  255. cropSize: size,
  256. offsetX: offsetX,
  257. offsetY: offsetY,
  258. fullWidth: width,
  259. fullHeight: height
  260. );
  261. } catch (e) {
  262. print('Isolate stream inference error: $e');
  263. return <DetectionResult>[];
  264. }
  265. }
  266. static img.Image _convertYUV420ToImage({
  267. required List<dynamic> planes,
  268. required int width,
  269. required int height,
  270. required int cropSize,
  271. required int offsetX,
  272. required int offsetY,
  273. }) {
  274. final yPlane = planes[0];
  275. final uPlane = planes[1];
  276. final vPlane = planes[2];
  277. final yBytes = yPlane['bytes'] as Uint8List;
  278. final uBytes = uPlane['bytes'] as Uint8List;
  279. final vBytes = vPlane['bytes'] as Uint8List;
  280. final yRowStride = yPlane['bytesPerRow'] as int;
  281. final uvRowStride = uPlane['bytesPerRow'] as int;
  282. final uvPixelStride = uPlane['bytesPerPixel'] as int;
  283. // Use a flat Uint8List buffer for fast native-style memory writing
  284. // 3 channels: R, G, B
  285. final Uint8List rgbBytes = Uint8List(cropSize * cropSize * 3);
  286. int bufferIndex = 0;
  287. for (int y = 0; y < cropSize; y++) {
  288. for (int x = 0; x < cropSize; x++) {
  289. final int actualX = x + offsetX;
  290. final int actualY = y + offsetY;
  291. // Mathematical offset matching
  292. final int uvIndex = (uvRowStride * (actualY >> 1)) + (uvPixelStride * (actualX >> 1));
  293. final int yIndex = (actualY * yRowStride) + actualX;
  294. // Skip if out of bounds (should not happen mathematically if offsets are valid,
  295. // but kept as safety check for corrupted frames)
  296. if (yIndex >= yBytes.length || uvIndex >= uBytes.length || uvIndex >= vBytes.length) {
  297. bufferIndex += 3;
  298. continue;
  299. }
  300. final int yp = yBytes[yIndex];
  301. final int up = uBytes[uvIndex];
  302. final int vp = vBytes[uvIndex];
  303. // Standard YUV to RGB conversion
  304. int r = (yp + (1.370705 * (vp - 128))).toInt();
  305. int g = (yp - (0.337633 * (up - 128)) - (0.698001 * (vp - 128))).toInt();
  306. int b = (yp + (1.732446 * (up - 128))).toInt();
  307. // Write directly to sequential memory with inline clamping
  308. rgbBytes[bufferIndex++] = r < 0 ? 0 : (r > 255 ? 255 : r);
  309. rgbBytes[bufferIndex++] = g < 0 ? 0 : (g > 255 ? 255 : g);
  310. rgbBytes[bufferIndex++] = b < 0 ? 0 : (b > 255 ? 255 : b);
  311. }
  312. }
  313. // Construct image mapping directly from the fast buffer
  314. return img.Image.fromBytes(
  315. width: cropSize,
  316. height: cropSize,
  317. bytes: rgbBytes.buffer,
  318. format: img.Format.uint8,
  319. numChannels: 3,
  320. order: img.ChannelOrder.rgb,
  321. );
  322. }
  323. static List<DetectionResult> _decodeDetections(
  324. List<List<double>> rawDetections,
  325. List<String> labels, {
  326. int? cropSize,
  327. int? offsetX,
  328. int? offsetY,
  329. int? fullWidth,
  330. int? fullHeight,
  331. }) {
  332. final detections = <DetectionResult>[];
  333. for (final det in rawDetections) {
  334. if (det.length < 6) continue;
  335. final conf = det[4];
  336. if (conf < _confidenceThreshold) continue;
  337. double x1 = det[0].clamp(0.0, 1.0);
  338. double y1 = det[1].clamp(0.0, 1.0);
  339. double x2 = det[2].clamp(0.0, 1.0);
  340. double y2 = det[3].clamp(0.0, 1.0);
  341. // If crop info is provided, map back to full frame
  342. if (cropSize != null && offsetX != null && offsetY != null && fullWidth != null && fullHeight != null) {
  343. x1 = (x1 * cropSize + offsetX) / fullWidth;
  344. x2 = (x2 * cropSize + offsetX) / fullWidth;
  345. y1 = (y1 * cropSize + offsetY) / fullHeight;
  346. y2 = (y2 * cropSize + offsetY) / fullHeight;
  347. }
  348. final classId = det[5].round();
  349. if (x2 <= x1 || y2 <= y1) continue;
  350. final label = (classId >= 0 && classId < labels.length) ? labels[classId] : 'Unknown';
  351. detections.add(DetectionResult(
  352. className: label,
  353. classIndex: classId,
  354. confidence: conf,
  355. normalizedBox: Rect.fromLTRB(x1, y1, x2, y2),
  356. ));
  357. }
  358. detections.sort((a, b) => b.confidence.compareTo(a.confidence));
  359. return detections;
  360. }
  361. void dispose() {
  362. _receivePort?.close();
  363. if (_isolate != null) {
  364. _isolate!.kill(priority: Isolate.immediate);
  365. _isolate = null;
  366. }
  367. _isInitialized = false;
  368. }
  369. }