// scanner.provider.ts 3.7 KB
  1. import { Injectable, OnModuleInit } from '@nestjs/common';
  2. import * as onnx from 'onnxruntime-node';
  3. import sharp from 'sharp';
  4. import * as path from 'path';
  5. import { MPOB_CLASSES, HEALTH_ALERT_CLASSES } from '../constants/mpob-standards';
  6. import { DetectionResult } from '../interfaces/palm-analysis.interface';
  7. @Injectable()
  8. export class ScannerProvider implements OnModuleInit {
  9. private session!: onnx.InferenceSession;
  10. private readonly modelPath = path.join(process.cwd(), 'best.onnx');
  11. async onModuleInit() {
  12. try {
  13. this.session = await onnx.InferenceSession.create(this.modelPath);
  14. console.log('✅ ONNX Inference Session initialized from:', this.modelPath);
  15. } catch (error) {
  16. console.error('❌ Failed to initialize ONNX Inference Session:', error);
  17. throw error;
  18. }
  19. }
  20. /**
  21. * Preprocesses the image buffer: resize to 640x640, transpose HWC to CHW, and normalize.
  22. */
  23. async preprocess(imageBuffer: Buffer): Promise<onnx.Tensor> {
  24. // Proper Sharp RGB extraction
  25. const resized = await sharp(imageBuffer)
  26. .resize(640, 640, { fit: 'fill' })
  27. .removeAlpha()
  28. .raw()
  29. .toBuffer({ resolveWithObject: true });
  30. const { width, height, channels } = resized.info;
  31. const pixels = resized.data; // Uint8Array [R, G, B, R, G, B...]
  32. const imageSize = width * height;
  33. const floatData = new Float32Array(3 * imageSize);
  34. // HWC to CHW Transposition
  35. // pixels: [R1, G1, B1, R2, G2, B2...]
  36. // floatData: [R1, R2, ..., G1, G2, ..., B1, B2, ...]
  37. for (let i = 0; i < imageSize; i++) {
  38. floatData[i] = pixels[i * 3] / 255.0; // R
  39. floatData[i + imageSize] = pixels[i * 3 + 1] / 255.0; // G
  40. floatData[i + 2 * imageSize] = pixels[i * 3 + 2] / 255.0; // B
  41. }
  42. return new onnx.Tensor('float32', floatData, [1, 3, 640, 640]);
  43. }
  44. /**
  45. * Executes the ONNX session with the preprocessed tensor.
  46. */
  47. async inference(tensor: onnx.Tensor): Promise<onnx.Tensor> {
  48. const inputs = { images: tensor };
  49. const outputs = await this.session.run(inputs);
  50. // The model typically returns the output under a generic name like 'output0' or 'outputs'
  51. // We'll take the first output key available
  52. const outputKey = Object.keys(outputs)[0];
  53. return outputs[outputKey];
  54. }
  55. /**
  56. * Post-processes the model output: filtering, scaling, and mapping to MPOB standards.
  57. */
  58. async postprocess(
  59. outputTensor: onnx.Tensor,
  60. originalWidth: number,
  61. originalHeight: number,
  62. threshold: number = 0.25,
  63. ): Promise<DetectionResult[]> {
  64. const data = outputTensor.data as Float32Array;
  65. // Expected shape: [1, 300, 6]
  66. // Each candidate: [x1, y1, x2, y2, confidence, class_index]
  67. const results: DetectionResult[] = [];
  68. const numCandidates = outputTensor.dims[1];
  69. for (let i = 0; i < numCandidates; i++) {
  70. const offset = i * 6;
  71. const x1 = data[offset];
  72. const y1 = data[offset + 1];
  73. const x2 = data[offset + 2];
  74. const y2 = data[offset + 3];
  75. const confidence = data[offset + 4];
  76. const classIndex = data[offset + 5];
  77. if (confidence >= threshold) {
  78. const className = MPOB_CLASSES[Math.round(classIndex)] || 'Unknown';
  79. results.push({
  80. bunch_id: results.length + 1,
  81. class: className,
  82. confidence: parseFloat(confidence.toFixed(4)),
  83. is_health_alert: HEALTH_ALERT_CLASSES.includes(className),
  84. // HEAVY LIFTING: Multiply ratio (0.0-1.0) by original pixels
  85. box: [
  86. data[offset] * originalWidth, // x1
  87. data[offset + 1] * originalHeight, // y1
  88. data[offset + 2] * originalWidth, // x2
  89. data[offset + 3] * originalHeight // y2
  90. ],
  91. });
  92. }
  93. }
  94. return results;
  95. }
  96. }