Prechádzať zdrojové kódy

feat: migrate inference engine to onnxruntime-web for ARM/Termux compatibility

Dr-Swopt 4 dní pred
rodič
commit
6e9a56ecb2

+ 0 - 1
package.json

@@ -30,7 +30,6 @@
     "class-validator": "^0.15.1",
     "find-process": "^2.1.1",
     "jimp": "^1.6.1",
-    "onnxruntime-node": "^1.24.3",
     "onnxruntime-web": "^1.25.1",
     "pidusage": "^4.0.1",
     "reflect-metadata": "^0.2.2",

+ 2 - 6
src/palm-oil/palm-oil.module.ts

@@ -6,13 +6,9 @@ import { VisionGateway } from './vision.gateway';
 import { History } from './entities/history.entity';
 import { SurveillanceModule } from '../surveillance/surveillance.module';
 import { SCANNER_TOKEN } from './providers/scanner.interface';
-import { OnnxNativeProvider } from './providers/onnx-native.provider';
 import { OnnxWasmProvider } from './providers/onnx-wasm.provider';
 
-const backend = process.env.INFERENCE_BACKEND ?? 'onnx-native';
-const ScannerImpl = backend === 'onnx-wasm' ? OnnxWasmProvider : OnnxNativeProvider;
-
-console.log(`🔧 Inference backend: ${backend} → ${ScannerImpl.name}`);
+console.log('🔧 Inference backend: onnx-wasm (Android/Termux)');
 
 @Module({
   imports: [TypeOrmModule.forFeature([History]), SurveillanceModule],
@@ -20,7 +16,7 @@ console.log(`🔧 Inference backend: ${backend} → ${ScannerImpl.name}`);
   providers: [
     PalmOilService,
     VisionGateway,
-    { provide: SCANNER_TOKEN, useClass: ScannerImpl },
+    { provide: SCANNER_TOKEN, useClass: OnnxWasmProvider },
   ],
   exports: [PalmOilService],
 })

+ 0 - 106
src/palm-oil/providers/onnx-native.provider.ts

@@ -1,106 +0,0 @@
-import { Injectable, OnModuleInit } from '@nestjs/common';
-import * as onnx from 'onnxruntime-node';
-import { Jimp } from 'jimp';
-import * as path from 'path';
-import { MPOB_CLASSES, HEALTH_ALERT_CLASSES } from '../constants/mpob-standards';
-import { DetectionResult } from '../interfaces/palm-analysis.interface';
-import { IScannerProvider, InferenceTensor, ScanResult } from './scanner.interface';
-
-@Injectable()
-export class OnnxNativeProvider implements IScannerProvider, OnModuleInit { // Scanner provider backed by onnxruntime-node (imported as `onnx`).
-  private session!: onnx.InferenceSession; // assigned in onModuleInit, hence the definite-assignment `!`
-  private readonly modelPath = path.join(process.cwd(), 'best.onnx'); // model file is resolved against the process working directory
-
-  async onModuleInit() { // OnModuleInit hook: creates the inference session from modelPath.
-    try {
-      this.session = await onnx.InferenceSession.create(this.modelPath);
-      console.log('✅ [onnx-native] Inference session initialized:', this.modelPath);
-    } catch (error) {
-      console.error('❌ [onnx-native] Failed to initialize:', error);
-      throw error; // logged, then rethrown to whoever invoked onModuleInit
-    }
-  }
-
-  async preprocess(imageBuffer: Buffer): Promise<InferenceTensor> { // Decodes the image, resizes it to 640x640 and returns it as a [1, 3, 640, 640] float32 tensor.
-    const img = await Jimp.read(imageBuffer);
-    img.resize({ w: 640, h: 640 }); // fixed 640x640 target size
-
-    const pixels = img.bitmap.data; // read below with a stride of 4 bytes per pixel
-    const imageSize = 640 * 640;
-    const floatData = new Float32Array(3 * imageSize); // three consecutive planes of imageSize values each
-
-    for (let i = 0; i < imageSize; i++) {
-      floatData[i] = pixels[i * 4] / 255.0; // plane 0 <- byte 0 of pixel i, divided by 255
-      floatData[i + imageSize] = pixels[i * 4 + 1] / 255.0; // plane 1 <- byte 1 of pixel i
-      floatData[i + 2 * imageSize] = pixels[i * 4 + 2] / 255.0; // plane 2 <- byte 2; byte 3 is never read (presumably alpha — TODO confirm)
-    }
-
-    const tensor = new onnx.Tensor('float32', floatData, [1, 3, 640, 640]);
-    return { data: tensor.data as Float32Array, dims: tensor.dims }; // only the plain data/dims pair is returned, not the onnx.Tensor itself
-  }
-
-  async inference(tensor: InferenceTensor): Promise<InferenceTensor> { // Runs the session on tensor.data and returns the first output as data/dims.
-    const onnxTensor = new onnx.Tensor('float32', tensor.data, [1, 3, 640, 640]); // shape is hard-coded; tensor.dims is not consulted
-    const outputs = await this.session.run({ images: onnxTensor }); // fed under the input name 'images'
-    const out = outputs[Object.keys(outputs)[0]]; // only the first output key is used
-    return { data: out.data as Float32Array, dims: out.dims };
-  }
-
-  async postprocess( // Thin wrapper: forwards all arguments to postprocessShared.
-    tensor: InferenceTensor,
-    originalWidth: number,
-    originalHeight: number,
-    threshold = 0.25, // default applied when the caller omits the threshold
-  ): Promise<ScanResult> {
-    return postprocessShared(tensor, originalWidth, originalHeight, threshold);
-  }
-}
-
-/** Shared postprocess logic — identical between native and WASM providers. */
-export function postprocessShared( // Converts a raw output tensor into thresholded detections plus a debug sample of its first rows.
-  outputTensor: InferenceTensor, // data is read as flat rows of 6 values; dims[1] is used as the row count
-  originalWidth: number, // multiplies row values 0 and 2
-  originalHeight: number, // multiplies row values 1 and 3
-  threshold: number, // rows whose value at index 4 is below this are skipped
-): ScanResult {
-  const data = outputTensor.data;
-
-  const sampleRows = Math.min(5, outputTensor.dims[1]); // sample at most the first 5 rows
-  const raw_tensor_sample: number[][] = [];
-  for (let i = 0; i < sampleRows; i++) {
-    const offset = i * 6; // start of row i in the flat array
-    raw_tensor_sample.push([ // each value is rounded to 6 decimals via toFixed and parsed back to a number
-      parseFloat(data[offset].toFixed(6)),
-      parseFloat(data[offset + 1].toFixed(6)),
-      parseFloat(data[offset + 2].toFixed(6)),
-      parseFloat(data[offset + 3].toFixed(6)),
-      parseFloat(data[offset + 4].toFixed(6)),
-      parseFloat(data[offset + 5].toFixed(6)),
-    ]);
-  }
-
-  const results: DetectionResult[] = [];
-  const numCandidates = outputTensor.dims[1];
-
-  for (let i = 0; i < numCandidates; i++) {
-    const offset = i * 6; // presumably the row is [x1, y1, x2, y2, confidence, classIndex] — TODO confirm against the model export
-    const confidence = data[offset + 4];
-    if (confidence < threshold) continue; // keep only rows at or above the threshold
-
-    const className = MPOB_CLASSES[Math.round(data[offset + 5])] || 'Unknown'; // value 5 is rounded and used as an index; a falsy lookup becomes 'Unknown'
-    results.push({
-      bunch_id: results.length + 1, // 1-based, counted over kept detections only
-      class: className,
-      confidence: parseFloat(confidence.toFixed(4)), // rounded to 4 decimals
-      is_health_alert: HEALTH_ALERT_CLASSES.includes(className),
-      box: [ // row values 0-3 scaled by the original image size (assumes normalized coordinates — TODO confirm)
-        data[offset] * originalWidth,
-        data[offset + 1] * originalHeight,
-        data[offset + 2] * originalWidth,
-        data[offset + 3] * originalHeight,
-      ],
-    });
-  }
-
-  return { detections: results, raw_tensor_sample };
-}

+ 51 - 2
src/palm-oil/providers/onnx-wasm.provider.ts

@@ -3,9 +3,10 @@ import * as ort from 'onnxruntime-web';
 import { Jimp } from 'jimp';
 import * as path from 'path';
 import { IScannerProvider, InferenceTensor, ScanResult } from './scanner.interface';
-import { postprocessShared } from './onnx-native.provider';
+import { MPOB_CLASSES, HEALTH_ALERT_CLASSES } from '../constants/mpob-standards';
+import { DetectionResult } from '../interfaces/palm-analysis.interface';
 
-// Single-threaded WASM — safer on low-resource / ARM environments (Android/Termux)
+// Single-threaded WASM — ARM/Termux safe
 ort.env.wasm.numThreads = 1;
 
 @Injectable()
@@ -58,3 +59,51 @@ export class OnnxWasmProvider implements IScannerProvider, OnModuleInit {
     return postprocessShared(tensor, originalWidth, originalHeight, threshold);
   }
 }
+
+export function postprocessShared( // Converts a raw output tensor into thresholded detections plus a debug sample of its first rows.
+  outputTensor: InferenceTensor, // data is read as flat rows of 6 values; dims[1] is used as the row count
+  originalWidth: number, // multiplies row values 0 and 2
+  originalHeight: number, // multiplies row values 1 and 3
+  threshold: number, // rows whose value at index 4 is below this are skipped
+): ScanResult {
+  const data = outputTensor.data;
+
+  const sampleRows = Math.min(5, outputTensor.dims[1]); // sample at most the first 5 rows
+  const raw_tensor_sample: number[][] = [];
+  for (let i = 0; i < sampleRows; i++) {
+    const offset = i * 6; // start of row i in the flat array
+    raw_tensor_sample.push([ // each value is rounded to 6 decimals via toFixed and parsed back to a number
+      parseFloat(data[offset].toFixed(6)),
+      parseFloat(data[offset + 1].toFixed(6)),
+      parseFloat(data[offset + 2].toFixed(6)),
+      parseFloat(data[offset + 3].toFixed(6)),
+      parseFloat(data[offset + 4].toFixed(6)),
+      parseFloat(data[offset + 5].toFixed(6)),
+    ]);
+  }
+
+  const results: DetectionResult[] = [];
+  const numCandidates = outputTensor.dims[1];
+
+  for (let i = 0; i < numCandidates; i++) {
+    const offset = i * 6; // presumably the row is [x1, y1, x2, y2, confidence, classIndex] — TODO confirm against the model export
+    const confidence = data[offset + 4];
+    if (confidence < threshold) continue; // keep only rows at or above the threshold
+
+    const className = MPOB_CLASSES[Math.round(data[offset + 5])] || 'Unknown'; // value 5 is rounded and used as an index; a falsy lookup becomes 'Unknown'
+    results.push({
+      bunch_id: results.length + 1, // 1-based, counted over kept detections only
+      class: className,
+      confidence: parseFloat(confidence.toFixed(4)), // rounded to 4 decimals
+      is_health_alert: HEALTH_ALERT_CLASSES.includes(className),
+      box: [ // row values 0-3 scaled by the original image size (assumes normalized coordinates — TODO confirm)
+        data[offset] * originalWidth,
+        data[offset + 1] * originalHeight,
+        data[offset + 2] * originalWidth,
+        data[offset + 3] * originalHeight,
+      ],
+    });
+  }
+
+  return { detections: results, raw_tensor_sample };
+}