| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- import 'dart:io';
- import 'dart:math';
- import 'dart:ui';
- import 'package:flutter/services.dart';
- import 'package:flutter/foundation.dart';
- import 'package:image/image.dart' as img;
- import 'package:image_picker/image_picker.dart';
- import 'package:tflite_flutter/tflite_flutter.dart';
- import 'package:camera/camera.dart';
- /// A detection result parsed from the model's end-to-end output.
- class DetectionResult {
- final String className;
- final int classIndex;
- final double confidence;
- /// Normalized bounding box (0.0 - 1.0)
- final Rect normalizedBox;
- const DetectionResult({
- required this.className,
- required this.classIndex,
- required this.confidence,
- required this.normalizedBox,
- });
- Color getStatusColor() {
- if (className == 'Empty_Bunch' || className == 'Abnormal') return const Color(0xFFF44336); // Colors.red
- if (className == 'Ripe' || className == 'Overripe') return const Color(0xFF4CAF50); // Colors.green
- return const Color(0xFFFF9800); // Colors.orange
- }
- }
/// Custom TFLite inference service that correctly decodes the end-to-end
/// YOLO model output format [1, N, 6] = [batch, detections, (x1,y1,x2,y2,conf,class_id)].
class TfliteService {
  static const _modelAsset = 'best.tflite';
  static const _labelsAsset = 'labels.txt';
  static const int _inputSize = 640;
  static const double _confidenceThreshold = 0.25;

  Interpreter? _interpreter;
  List<String> _labels = [];
  final ImagePicker _picker = ImagePicker();
  bool _isInitialized = false;

  // Raw assets cached at init time so the isolate entry points can reuse
  // them. Previously runInferenceOnStream re-read both assets from the
  // bundle on every streamed frame.
  Uint8List? _modelBytes;
  String? _labelData;

  /// Whether [initModel] has completed successfully.
  bool get isInitialized => _isInitialized;

  /// Loads the label list and the TFLite model from the asset bundle and
  /// creates the main-isolate interpreter.
  ///
  /// Rethrows any asset/interpreter error after logging it.
  Future<void> initModel() async {
    try {
      // Load and cache labels.
      final labelData = await rootBundle.loadString('assets/$_labelsAsset');
      _labelData = labelData;
      _labels = _parseLabels(labelData);

      // Load the model bytes once and build the interpreter from them
      // (equivalent to Interpreter.fromAsset, but the bytes are kept so
      // background isolates never re-read the bundle).
      final modelBytes =
          (await rootBundle.load('assets/$_modelAsset')).buffer.asUint8List();
      _modelBytes = modelBytes;
      final interpreterOptions = InterpreterOptions()..threads = 4;
      _interpreter = Interpreter.fromBuffer(modelBytes, options: interpreterOptions);

      _isInitialized = true;
      debugPrint('TfliteService: Model loaded. Labels: $_labels');
      debugPrint('TfliteService: Input: ${_interpreter!.getInputTensors().map((t) => t.shape)}');
      debugPrint('TfliteService: Output: ${_interpreter!.getOutputTensors().map((t) => t.shape)}');
    } catch (e) {
      debugPrint('TfliteService init error: $e');
      rethrow;
    }
  }

  /// Splits raw label-file text into trimmed, non-empty lines.
  static List<String> _parseLabels(String raw) =>
      raw.split('\n').map((l) => l.trim()).where((l) => l.isNotEmpty).toList();

  /// Opens the gallery picker, downscaling the result to the model input size.
  Future<XFile?> pickImage() {
    return _picker.pickImage(
      source: ImageSource.gallery,
      maxWidth: _inputSize.toDouble(),
      maxHeight: _inputSize.toDouble(),
    );
  }

  /// Runs inference on the image at [imagePath].
  ///
  /// Returns a list of [DetectionResult] sorted by confidence descending.
  /// Decoding, resizing and inference run in a background isolate to keep
  /// the UI smooth.
  Future<List<DetectionResult>> runInference(String imagePath) async {
    if (!_isInitialized) await initModel();
    final imageBytes = await File(imagePath).readAsBytes();
    // Isolates share no memory, so the model/labels travel as plain data.
    return compute(_inferenceTaskWrapper, {
      'imageBytes': imageBytes,
      'modelBytes': _modelBytes!,
      'labelData': _labelData!,
    });
  }

  /// Runs inference on a [CameraImage] from the stream (caller throttles).
  ///
  /// Plane data is copied into plain maps because [CameraImage] itself is
  /// not sendable across isolates.
  Future<List<DetectionResult>> runInferenceOnStream(CameraImage image) async {
    if (!_isInitialized) await initModel();
    return compute(_inferenceStreamTaskWrapper, {
      'planes': [
        for (final p in image.planes)
          {
            'bytes': p.bytes,
            'bytesPerRow': p.bytesPerRow,
            'bytesPerPixel': p.bytesPerPixel,
          },
      ],
      'width': image.width,
      'height': image.height,
      'format': image.format.group,
      // Cached at initModel — no per-frame asset bundle reads.
      'modelBytes': _modelBytes!,
      'labelData': _labelData!,
    });
  }

  /// Isolate entry point for camera-stream frames: converts the raw planes
  /// to RGB, then runs the shared inference pipeline.
  static List<DetectionResult> _inferenceStreamTaskWrapper(Map<String, dynamic> args) {
    final interpreter = Interpreter.fromBuffer(args['modelBytes'] as Uint8List);
    final labels = _parseLabels(args['labelData'] as String);
    try {
      final planes = args['planes'] as List<dynamic>;
      final width = args['width'] as int;
      final height = args['height'] as int;

      img.Image? image;
      if (args['format'] == ImageFormatGroup.yuv420) {
        // Dart-side YUV420 -> RGB is slow, but acceptable off the UI isolate.
        image = _convertYUV420ToImage(planes, width, height);
      } else if (args['format'] == ImageFormatGroup.bgra8888) {
        image = img.Image.fromBytes(
          width: width,
          height: height,
          bytes: (planes[0]['bytes'] as Uint8List).buffer,
          format: img.Format.uint8,
          numChannels: 4,
          order: img.ChannelOrder.bgra,
        );
      }
      // Unsupported pixel formats yield no detections rather than crashing.
      if (image == null) return [];
      return _runOnImage(interpreter, image, labels);
    } finally {
      // Always release the native interpreter created in this isolate.
      interpreter.close();
    }
  }

  /// Isolate entry point for still images: decodes the bytes, then runs the
  /// shared inference pipeline.
  static List<DetectionResult> _inferenceTaskWrapper(Map<String, dynamic> args) {
    final interpreter = Interpreter.fromBuffer(args['modelBytes'] as Uint8List);
    final labels = _parseLabels(args['labelData'] as String);
    try {
      final decoded = img.decodeImage(args['imageBytes'] as Uint8List);
      if (decoded == null) throw Exception('Could not decode image');
      return _runOnImage(interpreter, decoded, labels);
    } finally {
      interpreter.close();
    }
  }

  /// Shared pipeline: resize -> normalize -> run -> decode.
  ///
  /// Previously duplicated in both isolate entry points; the stream path
  /// now also gets linear interpolation on resize, matching the gallery path.
  static List<DetectionResult> _runOnImage(
      Interpreter interpreter, img.Image image, List<String> labels) {
    final resized = img.copyResize(
      image,
      width: _inputSize,
      height: _inputSize,
      interpolation: img.Interpolation.linear,
    );

    // [1, H, W, 3] input, RGB scaled to 0.0 - 1.0.
    final inputTensor = [
      List.generate(_inputSize, (y) => List.generate(_inputSize, (x) {
            final pixel = resized.getPixel(x, y);
            return [pixel.r / 255.0, pixel.g / 255.0, pixel.b / 255.0];
          })),
    ];

    // Output buffer shaped from the model itself ([1, N, 6] expected).
    final outputShape = interpreter.getOutputTensors()[0].shape;
    final outputTensor = [
      List.generate(outputShape[1], (_) => List<double>.filled(outputShape[2], 0.0)),
    ];

    interpreter.run(inputTensor, outputTensor);
    return _decodeDetections(outputTensor[0], labels);
  }

  /// Converts YUV420 camera planes to an RGB [img.Image].
  ///
  /// Honors row and pixel strides so padded camera buffers are read correctly.
  static img.Image _convertYUV420ToImage(List<dynamic> planes, int width, int height) {
    final yBytes = planes[0]['bytes'] as Uint8List;
    final uBytes = planes[1]['bytes'] as Uint8List;
    final vBytes = planes[2]['bytes'] as Uint8List;
    final yRowStride = planes[0]['bytesPerRow'] as int;
    final uvRowStride = planes[1]['bytesPerRow'] as int;
    final uvPixelStride = planes[1]['bytesPerPixel'] as int;

    final image = img.Image(width: width, height: height);
    for (int y = 0; y < height; y++) {
      for (int x = 0; x < width; x++) {
        // Chroma planes are subsampled 2x in both dimensions.
        final uvIndex = uvRowStride * (y >> 1) + uvPixelStride * (x >> 1);
        final yp = yBytes[y * yRowStride + x];
        final up = uBytes[uvIndex];
        final vp = vBytes[uvIndex];
        // Standard YUV -> RGB conversion.
        final r = (yp + 1.370705 * (vp - 128)).toInt().clamp(0, 255);
        final g = (yp - 0.337633 * (up - 128) - 0.698001 * (vp - 128)).toInt().clamp(0, 255);
        final b = (yp + 1.732446 * (up - 128)).toInt().clamp(0, 255);
        image.setPixelRgb(x, y, r, g, b);
      }
    }
    return image;
  }

  /// Decodes raw [1, N, 6] rows into [DetectionResult]s.
  ///
  /// Each row is (x1, y1, x2, y2, confidence, classId). Coordinates are
  /// clamped to 0.0-1.0 (assumes the exported model emits normalized
  /// boxes — TODO confirm against the export config). Degenerate or
  /// low-confidence boxes are dropped; the result is sorted by confidence
  /// descending.
  static List<DetectionResult> _decodeDetections(
      List<List<double>> rawDetections, List<String> labels) {
    final detections = <DetectionResult>[];
    for (final det in rawDetections) {
      if (det.length < 6) continue; // malformed row
      final conf = det[4];
      if (conf < _confidenceThreshold) continue;
      final x1 = det[0].clamp(0.0, 1.0);
      final y1 = det[1].clamp(0.0, 1.0);
      final x2 = det[2].clamp(0.0, 1.0);
      final y2 = det[3].clamp(0.0, 1.0);
      if (x2 <= x1 || y2 <= y1) continue; // zero/negative area
      final classId = det[5].round();
      // Out-of-range class ids map to 'Unknown' instead of throwing.
      final label =
          (classId >= 0 && classId < labels.length) ? labels[classId] : 'Unknown';
      detections.add(DetectionResult(
        className: label,
        classIndex: classId,
        confidence: conf,
        normalizedBox: Rect.fromLTRB(x1, y1, x2, y2),
      ));
    }
    detections.sort((a, b) => b.confidence.compareTo(a.confidence));
    return detections;
  }

  /// Releases the interpreter and cached assets.
  void dispose() {
    _interpreter?.close();
    _interpreter = null;
    _modelBytes = null;
    _labelData = null;
    _isInitialized = false;
  }
}
|