|
|
@@ -0,0 +1,110 @@
|
|
|
+import { Injectable, OnModuleInit } from '@nestjs/common';
|
|
|
+import * as onnx from 'onnxruntime-node';
|
|
|
+import * as sharp from 'sharp';
|
|
|
+import * as path from 'path';
|
|
|
+import { MPOB_CLASSES, HEALTH_ALERT_CLASSES } from '../constants/mpob-standards';
|
|
|
+import { DetectionResult } from '../interfaces/palm-analysis.interface';
|
|
|
+
|
|
|
+@Injectable()
|
|
|
+export class ScannerProvider implements OnModuleInit {
|
|
|
+ private session: onnx.InferenceSession;
|
|
|
+ private readonly modelPath = path.join(process.cwd(), 'best.onnx');
|
|
|
+
|
|
|
+ async onModuleInit() {
|
|
|
+ try {
|
|
|
+ this.session = await onnx.InferenceSession.create(this.modelPath);
|
|
|
+ console.log('✅ ONNX Inference Session initialized from:', this.modelPath);
|
|
|
+ } catch (error) {
|
|
|
+ console.error('❌ Failed to initialize ONNX Inference Session:', error);
|
|
|
+ throw error;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Preprocesses the image buffer: resize to 640x640, transpose HWC to CHW, and normalize.
|
|
|
+ */
|
|
|
+ async preprocess(imageBuffer: Buffer): Promise<onnx.Tensor> {
|
|
|
+ // Proper Sharp RGB extraction
|
|
|
+ const resized = await sharp(imageBuffer)
|
|
|
+ .resize(640, 640, { fit: 'fill' })
|
|
|
+ .removeAlpha()
|
|
|
+ .raw()
|
|
|
+ .toBuffer({ resolveWithObject: true });
|
|
|
+
|
|
|
+ const { width, height, channels } = resized.info;
|
|
|
+ const pixels = resized.data; // Uint8Array [R, G, B, R, G, B...]
|
|
|
+
|
|
|
+ const imageSize = width * height;
|
|
|
+ const floatData = new Float32Array(3 * imageSize);
|
|
|
+
|
|
|
+ // HWC to CHW Transposition
|
|
|
+ // pixels: [R1, G1, B1, R2, G2, B2...]
|
|
|
+ // floatData: [R1, R2, ..., G1, G2, ..., B1, B2, ...]
|
|
|
+ for (let i = 0; i < imageSize; i++) {
|
|
|
+ floatData[i] = pixels[i * 3] / 255.0; // R
|
|
|
+ floatData[i + imageSize] = pixels[i * 3 + 1] / 255.0; // G
|
|
|
+ floatData[i + 2 * imageSize] = pixels[i * 3 + 2] / 255.0; // B
|
|
|
+ }
|
|
|
+
|
|
|
+ return new onnx.Tensor('float32', floatData, [1, 3, 640, 640]);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Executes the ONNX session with the preprocessed tensor.
|
|
|
+ */
|
|
|
+ async inference(tensor: onnx.Tensor): Promise<onnx.Tensor> {
|
|
|
+ const inputs = { images: tensor };
|
|
|
+ const outputs = await this.session.run(inputs);
|
|
|
+
|
|
|
+ // The model typically returns the output under a generic name like 'output0' or 'outputs'
|
|
|
+ // We'll take the first output key available
|
|
|
+ const outputKey = Object.keys(outputs)[0];
|
|
|
+ return outputs[outputKey];
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Post-processes the model output: filtering, scaling, and mapping to MPOB standards.
|
|
|
+ */
|
|
|
+ async postprocess(
|
|
|
+ outputTensor: onnx.Tensor,
|
|
|
+ originalWidth: number,
|
|
|
+ originalHeight: number,
|
|
|
+ threshold: number = 0.25,
|
|
|
+ ): Promise<DetectionResult[]> {
|
|
|
+ const data = outputTensor.data as Float32Array;
|
|
|
+ // Expected shape: [1, 300, 6]
|
|
|
+ // Each candidate: [x1, y1, x2, y2, confidence, class_index]
|
|
|
+
|
|
|
+ const results: DetectionResult[] = [];
|
|
|
+ const numCandidates = outputTensor.dims[1];
|
|
|
+
|
|
|
+ for (let i = 0; i < numCandidates; i++) {
|
|
|
+ const offset = i * 6;
|
|
|
+ const x1 = data[offset];
|
|
|
+ const y1 = data[offset + 1];
|
|
|
+ const x2 = data[offset + 2];
|
|
|
+ const y2 = data[offset + 3];
|
|
|
+ const confidence = data[offset + 4];
|
|
|
+ const classIndex = data[offset + 5];
|
|
|
+
|
|
|
+ if (confidence >= threshold) {
|
|
|
+ const className = MPOB_CLASSES[Math.round(classIndex)] || 'Unknown';
|
|
|
+ results.push({
|
|
|
+ bunch_id: results.length + 1,
|
|
|
+ class: className,
|
|
|
+ confidence: parseFloat(confidence.toFixed(4)),
|
|
|
+ is_health_alert: HEALTH_ALERT_CLASSES.includes(className),
|
|
|
+ // Normalize by dividing by 640 (the model input size)
|
|
|
+ box: [
|
|
|
+ x1 / 640,
|
|
|
+ y1 / 640,
|
|
|
+ x2 / 640,
|
|
|
+ y2 / 640
|
|
|
+ ],
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return results;
|
|
|
+ }
|
|
|
+}
|