|
|
@@ -1,104 +1,93 @@
|
|
|
import { Injectable } from '@angular/core';
|
|
|
import * as ort from 'onnxruntime-web';
|
|
|
-import * as tflite from '@tensorflow/tfjs-tflite/dist/tf-tflite.fesm.js';
|
|
|
+import * as tf from '@tensorflow/tfjs';
|
|
|
|
|
|
@Injectable({ providedIn: 'root' })
|
|
|
export class LocalInferenceService {
|
|
|
private onnxSession: ort.InferenceSession | null = null;
|
|
|
private tfliteModel: any | null = null;
|
|
|
|
|
|
- // Standardized Industrial Color Scheme
|
|
|
private readonly GRADE_COLORS: { [key: string]: string } = {
|
|
|
- 'Empty_Bunch': '#6C757D', // Gray
|
|
|
- 'Underripe': '#F9A825', // Amber
|
|
|
- 'Abnormal': '#DC3545', // Red
|
|
|
- 'Ripe': '#00A651', // Green
|
|
|
- 'Unripe': '#9E9D24', // Lime
|
|
|
- 'Overripe': '#5D4037' // Brown
|
|
|
+ 'Empty_Bunch': '#6C757D', 'Underripe': '#F9A825', 'Abnormal': '#DC3545',
|
|
|
+ 'Ripe': '#00A651', 'Unripe': '#9E9D24', 'Overripe': '#5D4037'
|
|
|
};
|
|
|
|
|
|
async loadModel(modelPath: string) {
|
|
|
+ await tf.ready(); // Ensure TFJS core is initialized
|
|
|
+
|
|
|
if (modelPath.endsWith('.onnx')) {
|
|
|
- // Explicitly set the WASM path for ONNX
|
|
|
ort.env.wasm.wasmPaths = '/assets/wasm/';
|
|
|
-
|
|
|
- this.onnxSession = await ort.InferenceSession.create(modelPath, {
|
|
|
- executionProviders: ['wasm'], // Start with WASM for stability
|
|
|
- graphOptimizationLevel: 'all'
|
|
|
- });
|
|
|
+ this.onnxSession = await ort.InferenceSession.create(modelPath, { executionProviders: ['wasm'] });
|
|
|
this.tfliteModel = null;
|
|
|
- console.log('ONNX Model loaded successfully');
|
|
|
+ console.log('ONNX Engine Ready');
|
|
|
} else {
|
|
|
- // CRITICAL: Set this BEFORE loading the model
|
|
|
- tflite.setWasmPath('/assets/tflite-wasm/');
|
|
|
- this.tfliteModel = await tflite.loadTFLiteModel(modelPath);
|
|
|
- this.onnxSession = null;
|
|
|
- console.log('TFLite Model loaded successfully');
|
|
|
+ // DYNAMIC IMPORT: This fixes the "setWasmPath undefined" error
|
|
|
+ const tflite = await import('@tensorflow/tfjs-tflite');
|
|
|
+
|
|
|
+ if (tflite && tflite.setWasmPath) {
|
|
|
+ tflite.setWasmPath('/assets/tflite-wasm/');
|
|
|
+ this.tfliteModel = await tflite.loadTFLiteModel(modelPath);
|
|
|
+ this.onnxSession = null;
|
|
|
+ console.log('TFLite Engine Ready');
|
|
|
+ } else {
|
|
|
+ throw new Error('TFLite module resolution failed');
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- async runInference(input: Float32Array): Promise<any> {
|
|
|
- if (this.onnxSession) {
|
|
|
- // Create the tensor: [BatchSize, Channels, Height, Width]
|
|
|
- const tensor = new ort.Tensor('float32', input, [1, 3, 640, 640]);
|
|
|
-
|
|
|
- // Execute the model logic
|
|
|
- const feeds: any = {};
|
|
|
- feeds[this.onnxSession.inputNames[0]] = tensor;
|
|
|
- const output = await this.onnxSession.run(feeds);
|
|
|
-
|
|
|
- return output[this.onnxSession.outputNames[0]].data;
|
|
|
- } else if (this.tfliteModel) {
|
|
|
- // TFLite expects [1, 3, 640, 640] usually in this PoC export
|
|
|
- const output = await this.tfliteModel.predict(input);
|
|
|
- return output.data();
|
|
|
+ async runInference(input: Float32Array): Promise<Float32Array | null> {
|
|
|
+ try {
|
|
|
+ if (this.onnxSession) {
|
|
|
+ const tensor = new ort.Tensor('float32', input, [1, 3, 640, 640]);
|
|
|
+ const output = await this.onnxSession.run({ [this.onnxSession.inputNames[0]]: tensor });
|
|
|
+ return output[this.onnxSession.outputNames[0]].data as Float32Array;
|
|
|
+ } else if (this.tfliteModel) {
|
|
|
+ // Create tensor and predict
|
|
|
+ const inputTensor = tf.tensor(input, [1, 3, 640, 640]);
|
|
|
+ const result = await this.tfliteModel.predict(inputTensor);
|
|
|
+ const data = await result.data();
|
|
|
+ tf.dispose([inputTensor, result]); // Cleanup memory
|
|
|
+ return data as Float32Array;
|
|
|
+ }
|
|
|
+ } catch (err) {
|
|
|
+ console.error('Inference Error:', err);
|
|
|
}
|
|
|
- throw new Error('No model loaded');
|
|
|
+ return null;
|
|
|
}
|
|
|
|
|
|
parseDetections(rawData: any, threshold: number, imgWidth: number, imgHeight: number): any[] {
|
|
|
if (!rawData) return [];
|
|
|
-
|
|
|
- // Ensure we are working with a Float32Array
|
|
|
const tensorData = rawData instanceof Float32Array ? rawData : new Float32Array(rawData);
|
|
|
const detections = [];
|
|
|
- const numBoxes = tensorData.length / 6; // Dynamically calculate based on data size
|
|
|
+
|
|
|
+ // Safety check for data length
|
|
|
+ const numBoxes = Math.floor(tensorData.length / 6);
|
|
|
|
|
|
for (let i = 0; i < numBoxes; i++) {
|
|
|
- const offset = i * 6;
|
|
|
- const confidence = tensorData[offset + 4];
|
|
|
-
|
|
|
- if (confidence >= threshold) {
|
|
|
- const classId = Math.round(tensorData[offset + 5]);
|
|
|
- const className = this.getClassName(classId);
|
|
|
+ const offset = i * 6;
|
|
|
+ const confidence = tensorData[offset + 4];
|
|
|
|
|
|
- detections.push({
|
|
|
- box: [
|
|
|
- tensorData[offset + 0] * imgWidth,
|
|
|
- tensorData[offset + 1] * imgHeight,
|
|
|
- tensorData[offset + 2] * imgWidth,
|
|
|
- tensorData[offset + 3] * imgHeight
|
|
|
- ],
|
|
|
- confidence: confidence,
|
|
|
- class: className,
|
|
|
- color: this.GRADE_COLORS[className] || '#000000',
|
|
|
- is_health_alert: className === 'Abnormal' || className === 'Empty_Bunch'
|
|
|
- });
|
|
|
- }
|
|
|
+ if (confidence >= threshold) {
|
|
|
+ const classId = Math.round(tensorData[offset + 5]);
|
|
|
+ const className = this.getClassName(classId);
|
|
|
+ detections.push({
|
|
|
+ box: [
|
|
|
+ tensorData[offset + 0] * imgWidth,
|
|
|
+ tensorData[offset + 1] * imgHeight,
|
|
|
+ tensorData[offset + 2] * imgWidth,
|
|
|
+ tensorData[offset + 3] * imgHeight
|
|
|
+ ],
|
|
|
+ confidence,
|
|
|
+ class: className,
|
|
|
+ color: this.GRADE_COLORS[className] || '#000000'
|
|
|
+ });
|
|
|
+ }
|
|
|
}
|
|
|
return detections;
|
|
|
}
|
|
|
|
|
|
private getClassName(id: number): string {
|
|
|
- // STRICT ORDER matching your data.yaml
|
|
|
- const classes = [
|
|
|
- 'Empty_Bunch', // 0
|
|
|
- 'Underripe', // 1
|
|
|
- 'Abnormal', // 2
|
|
|
- 'Ripe', // 3
|
|
|
- 'Unripe', // 4
|
|
|
- 'Overripe' // 5
|
|
|
- ];
|
|
|
+ const classes = ['Empty_Bunch', 'Underripe', 'Abnormal', 'Ripe', 'Unripe', 'Overripe'];
|
|
|
return classes[id] || 'Unknown';
|
|
|
}
|
|
|
}
|