import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'dart:ui';

import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart' as img;
import 'package:image_picker/image_picker.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

/// A detection result parsed from the model's end-to-end output.
class DetectionResult {
  /// Label for [classIndex], or `'Unknown'` when the index falls outside the
  /// loaded label list.
  final String className;

  /// Class id emitted by the model, rounded to the nearest integer.
  final int classIndex;

  /// Score reported by the model for this detection.
  final double confidence;

  /// Normalized bounding box (0.0 - 1.0).
  final Rect normalizedBox;

  const DetectionResult({
    required this.className,
    required this.classIndex,
    required this.confidence,
    required this.normalizedBox,
  });
}

/// Custom TFLite inference service that decodes the end-to-end YOLO model
/// output format `[1, N, 6]` =
/// `[batch, detections, (x1, y1, x2, y2, conf, class_id)]`.
class TfliteService {
  static const _modelAsset = 'best.tflite';
  static const _labelsAsset = 'labels.txt';
  static const int _inputSize = 640;
  static const double _confidenceThreshold = 0.25;

  /// Thread count for every [Interpreter] this service creates, so the
  /// main-isolate interpreter and the background-isolate one are configured
  /// identically.
  static const int _numThreads = 4;

  /// Interpreter created on the calling isolate by [initModel]. It validates
  /// the model and reports tensor shapes; inference itself runs on a separate
  /// interpreter inside a background isolate.
  Interpreter? _interpreter;
  List<String> _labels = [];

  /// Asset contents cached by [initModel] so [runInference] does not reload
  /// them from the bundle on every call.
  Uint8List? _modelBytes;
  String? _labelData;

  final ImagePicker _picker = ImagePicker();
  bool _isInitialized = false;

  /// Whether [initModel] has completed successfully and [dispose] has not been
  /// called since.
  bool get isInitialized => _isInitialized;

  /// Loads the labels and model from the app's assets.
  ///
  /// Creates an [Interpreter] on the calling isolate, logs its tensor shapes,
  /// and caches the raw asset data for [runInference]. Calling this again
  /// reloads everything and closes the previously created interpreter.
  /// Errors are logged and rethrown.
  Future<void> initModel() async {
    try {
      final labelData = await rootBundle.loadString('assets/$_labelsAsset');
      final labels = _parseLabels(labelData);

      final modelBytes =
          _bytesOf(await rootBundle.load('assets/$_modelAsset'));
      final interpreter = Interpreter.fromBuffer(
        modelBytes,
        options: _createOptions(),
      );

      // Release an interpreter from an earlier call instead of leaking it.
      _interpreter?.close();
      _interpreter = interpreter;
      _labels = labels;
      _labelData = labelData;
      _modelBytes = modelBytes;
      _isInitialized = true;

      debugPrint('TfliteService: Model loaded. Labels: $_labels');
      debugPrint(
        'TfliteService: Input: '
        '${interpreter.getInputTensors().map((t) => t.shape)}',
      );
      debugPrint(
        'TfliteService: Output: '
        '${interpreter.getOutputTensors().map((t) => t.shape)}',
      );
    } catch (e) {
      debugPrint('TfliteService init error: $e');
      rethrow;
    }
  }

  /// Opens the gallery picker, asking it to downscale the image to at most
  /// [_inputSize] pixels per side. The picker's nullable result is returned
  /// unchanged.
  Future<XFile?> pickImage() async {
    return _picker.pickImage(
      source: ImageSource.gallery,
      maxWidth: _inputSize.toDouble(),
      maxHeight: _inputSize.toDouble(),
    );
  }

  /// Runs inference on the image at [imagePath].
  ///
  /// Initializes the model first if needed. Decoding, resizing and inference
  /// run on a background isolate (via [compute]) to keep the UI smooth.
  /// Returns detections sorted by confidence, highest first.
  Future<List<DetectionResult>> runInference(String imagePath) async {
    if (!_isInitialized) await initModel();
    final imageBytes = await File(imagePath).readAsBytes();
    return _runInferenceInIsolate(imageBytes);
  }

  /// Ships the image bytes plus the cached model/label data to a background
  /// isolate and runs [_inferenceTask] there.
  Future<List<DetectionResult>> _runInferenceInIsolate(
    Uint8List imageBytes,
  ) async {
    // Fall back to the bundle only if the cache was cleared (e.g. by
    // [dispose] racing with an in-flight call).
    final modelBytes = _modelBytes ??
        _bytesOf(await rootBundle.load('assets/$_modelAsset'));
    final labelData =
        _labelData ?? await rootBundle.loadString('assets/$_labelsAsset');

    return compute(_inferenceTaskWrapper, <String, Object>{
      'imageBytes': imageBytes,
      'modelBytes': modelBytes,
      'labelData': labelData,
    });
  }

  /// Adapts the single-message [compute] callback to [_inferenceTask].
  static List<DetectionResult> _inferenceTaskWrapper(
    Map<String, Object> args,
  ) {
    return _inferenceTask(
      args['imageBytes'] as Uint8List,
      args['modelBytes'] as Uint8List,
      args['labelData'] as String,
    );
  }

  /// Runs on the background isolate: builds an interpreter from
  /// [modelBytes], preprocesses [imageBytes], runs the model and decodes the
  /// output. The interpreter is always closed before returning.
  static List<DetectionResult> _inferenceTask(
    Uint8List imageBytes,
    Uint8List modelBytes,
    String labelData,
  ) {
    final interpreter =
        Interpreter.fromBuffer(modelBytes, options: _createOptions());
    final labels = _parseLabels(labelData);
    try {
      final input = _preprocess(imageBytes);
      final output = _allocateOutput(interpreter);
      interpreter.run(input, output);
      return _decodeDetections(output[0], labels);
    } finally {
      interpreter.close();
    }
  }

  /// Decodes [imageBytes] and builds a `[1, _inputSize, _inputSize, 3]`
  /// nested-list input with RGB channels scaled to 0.0 - 1.0.
  ///
  /// NOTE(review): assumes the model takes float NHWC input; confirm against
  /// `getInputTensors()` if a quantized model is ever used.
  static List<List<List<List<double>>>> _preprocess(Uint8List imageBytes) {
    final decoded = img.decodeImage(imageBytes);
    if (decoded == null) throw Exception('Could not decode image');

    // Stretches to a square; aspect ratio is not preserved (no letterboxing).
    final resized = img.copyResize(
      decoded,
      width: _inputSize,
      height: _inputSize,
      interpolation: img.Interpolation.linear,
    );

    return [
      List.generate(
        _inputSize,
        (y) => List.generate(_inputSize, (x) {
          final pixel = resized.getPixel(x, y);
          // The *Normalized getters divide by the image's own channel maximum,
          // so 16-bit and float images are scaled correctly (8-bit: /255).
          return [
            pixel.rNormalized.toDouble(),
            pixel.gNormalized.toDouble(),
            pixel.bNormalized.toDouble(),
          ];
        }),
      ),
    ];
  }

  /// Allocates a zero-filled `[1, N, F]` output buffer matching the model's
  /// first output tensor.
  ///
  /// Throws a [StateError] if that tensor is not rank 3.
  static List<List<List<double>>> _allocateOutput(Interpreter interpreter) {
    final shape = interpreter.getOutputTensors()[0].shape;
    if (shape.length != 3) {
      throw StateError(
        'TfliteService: expected a rank-3 output tensor, got $shape',
      );
    }
    final numDetections = shape[1];
    final numFields = shape[2];
    return List.generate(
      1,
      (_) => List.generate(numDetections, (_) => List.filled(numFields, 0.0)),
    );
  }

  /// Converts raw `(x1, y1, x2, y2, conf, class_id)` rows into
  /// [DetectionResult]s.
  ///
  /// Drops rows that are too short, fall below [_confidenceThreshold], or have
  /// an empty box after clamping. Returns the results sorted by confidence,
  /// highest first.
  static List<DetectionResult> _decodeDetections(
    List<List<double>> rows,
    List<String> labels,
  ) {
    final detections = <DetectionResult>[];
    for (final det in rows) {
      if (det.length < 6) continue;
      final conf = det[4];
      if (conf < _confidenceThreshold) continue;

      // NOTE(review): assumes the model emits coordinates already normalized
      // to 0-1; pixel-space output would clamp to 1.0 and be dropped below.
      final x1 = det[0].clamp(0.0, 1.0);
      final y1 = det[1].clamp(0.0, 1.0);
      final x2 = det[2].clamp(0.0, 1.0);
      final y2 = det[3].clamp(0.0, 1.0);
      final classId = det[5].round();
      if (x2 <= x1 || y2 <= y1) continue;

      final label = (classId >= 0 && classId < labels.length)
          ? labels[classId]
          : 'Unknown';
      detections.add(DetectionResult(
        className: label,
        classIndex: classId,
        confidence: conf,
        normalizedBox: Rect.fromLTRB(x1, y1, x2, y2),
      ));
    }
    detections.sort((a, b) => b.confidence.compareTo(a.confidence));
    return detections;
  }

  /// Builds interpreter options shared by every interpreter this service
  /// creates.
  static InterpreterOptions _createOptions() =>
      InterpreterOptions()..threads = _numThreads;

  /// Splits a newline-separated labels file into trimmed, non-empty labels.
  static List<String> _parseLabels(String labelData) => labelData
      .split('\n')
      .map((l) => l.trim())
      .where((l) => l.isNotEmpty)
      .toList();

  /// Returns exactly the bytes covered by [data].
  ///
  /// `data.buffer.asUint8List()` would expose the whole backing buffer, which
  /// is larger than [data] when [data] is a view at a non-zero offset.
  static Uint8List _bytesOf(ByteData data) =>
      data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);

  /// Closes the main-isolate interpreter and drops cached asset data; a later
  /// [runInference] call re-initializes the model.
  void dispose() {
    _interpreter?.close();
    _interpreter = null;
    _modelBytes = null;
    _labelData = null;
    _isInitialized = false;
  }
}