// onnx-wasm.provider.ts

import { Injectable, OnModuleInit } from '@nestjs/common';
import * as ort from 'onnxruntime-web';
import { Jimp } from 'jimp';
import * as path from 'path';
import { IScannerProvider, InferenceTensor, ScanResult } from './scanner.interface';
import { postprocessShared } from './onnx-native.provider';

// Single-threaded WASM: safer on low-resource / ARM environments (Android/Termux).
ort.env.wasm.numThreads = 1;

@Injectable()
export class OnnxWasmProvider implements IScannerProvider, OnModuleInit {
  private session!: ort.InferenceSession;
  private readonly modelPath = path.join(process.cwd(), 'best.onnx');
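
  /**
   * Load the ONNX model once at application startup. Failing fast here keeps
   * a missing or corrupt model file from surfacing later as a runtime error
   * on the first scan.
   */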
  async onModuleInit() {
    try {
      this.session = await ort.InferenceSession.create(this.modelPath, {
        executionProviders: ['wasm'],
      });
      console.log('✅ [onnx-wasm] Inference session initialized:', this.modelPath);
    } catch (error) {
      console.error('❌ [onnx-wasm] Failed to initialize:', error);
      throw error;
    }
  }
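
  /**
   * Resize the image to 640x640 and convert Jimp's interleaved RGBA bitmap
   * into the normalized NCHW float tensor the model expects. This is a plain
   * resize (no letterboxing), so the aspect ratio is not preserved.
   */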
  async preprocess(imageBuffer: Buffer): Promise<InferenceTensor> {
    const img = await Jimp.read(imageBuffer);
    img.resize({ w: 640, h: 640 });
    const pixels = img.bitmap.data;
    const imageSize = 640 * 640;
    const floatData = new Float32Array(3 * imageSize);
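    // RGBA is interleaved (4 bytes per pixel); split it into contiguous
    // R, G, B planes and scale each channel to [0, 1]. Alpha is discarded.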
    for (let i = 0; i < imageSize; i++) {
      floatData[i] = pixels[i * 4] / 255.0;
      floatData[i + imageSize] = pixels[i * 4 + 1] / 255.0;
      floatData[i + 2 * imageSize] = pixels[i * 4 + 2] / 255.0;
    }
    return { data: floatData, dims: [1, 3, 640, 640] };
  }
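
  /**
   * Run the session on a preprocessed tensor. The input name 'images' and the
   * single raw output are consistent with a YOLO-style ONNX export (an
   * assumption; the model architecture is not stated in this file), and the
   * output is decoded separately in postprocess().
   */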
  async inference(tensor: InferenceTensor): Promise<InferenceTensor> {
    const ortTensor = new ort.Tensor('float32', tensor.data, [1, 3, 640, 640]);
    const outputs = await this.session.run({ images: ortTensor });
    const out = outputs[Object.keys(outputs)[0]];
    return { data: out.data as Float32Array, dims: out.dims };
  }
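
  /**
   * Decode the raw output into a ScanResult. Delegates to the shared decoder
   * exported by the native provider so both backends share one decoding path.
   */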
  async postprocess(
    tensor: InferenceTensor,
    originalWidth: number,
    originalHeight: number,
    threshold = 0.25,
  ): Promise<ScanResult> {
    return postprocessShared(tensor, originalWidth, originalHeight, threshold);
  }
}
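
/*
 * Usage sketch (illustrative; the DI token string and module wiring below
 * are assumptions, not taken from this file). The pipeline calls, however,
 * mirror the methods above.
 *
 *   // scanner.module.ts
 *   @Module({
 *     providers: [{ provide: 'SCANNER_PROVIDER', useClass: OnnxWasmProvider }],
 *     exports: ['SCANNER_PROVIDER'],
 *   })
 *   export class ScannerModule {}
 *
 *   // Full pipeline, given an image Buffer and its original dimensions:
 *   const tensor = await provider.preprocess(imageBuffer);
 *   const raw = await provider.inference(tensor);
 *   const result = await provider.postprocess(raw, width, height);
 */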