Files

420 lines
14 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// The WebNN API does not currently ship a TypeScript definition file. The `webnn/webnn.d.ts` file referenced below is
// a workaround that contains types generated from the WebNN API specification.
// https://github.com/webmachinelearning/webnn/issues/677
/// <reference path="webnn/webnn.d.ts" />
import { Env, Tensor } from 'onnxruntime-common';
import { DataType, tensorDataTypeStringToEnum } from '../wasm-common';
import { getInstance } from '../wasm-factory';
import { createView } from './tensor-view';
import { TensorId, createTensorManager, convertDataToInt32 } from './webnn/tensor-manager';
import { configureLogger, LOG_DEBUG } from './log';
/*
 * TensorProto::data_type to WebNN OperandType mapping.
 *
 * The pairs are listed as [ONNX data type, WebNN operand data type]. ONNX `bool` is mapped to the same WebNN type
 * ('uint8') as ONNX `uint8`. ONNX data types that are not listed here have no entry in the map.
 */
const onnxToWebnnDataTypeEntries: ReadonlyArray<readonly [DataType, MLOperandDataType]> = [
  // Floating point types.
  [DataType.float, 'float32'],
  [DataType.float16, 'float16'],
  // 32-bit and 64-bit integer types.
  [DataType.int32, 'int32'],
  [DataType.uint32, 'uint32'],
  [DataType.int64, 'int64'],
  [DataType.uint64, 'uint64'],
  // 4-bit and 8-bit integer types.
  [DataType.int4, 'int4'],
  [DataType.uint4, 'uint4'],
  [DataType.int8, 'int8'],
  [DataType.uint8, 'uint8'],
  // Boolean shares the 'uint8' operand type.
  [DataType.bool, 'uint8'],
];
const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>(onnxToWebnnDataTypeEntries);
/**
 * One entry of the MLContext cache: a context together with the argument it was created from.
 *
 * At most one of `gpuDevice` and `options` is set; both are absent for a context created without any argument.
 */
interface MLContextEntry {
  /** The GPUDevice the context was created from, if it was created from a device. */
  gpuDevice?: GPUDevice;
  /** The options the context was created from, if it was created from options. */
  options?: MLContextOptions;
  /** The cached context itself. */
  mlContext: MLContext;
}
/**
 * Shallow equality check for two (possibly missing) MLContextOptions objects.
 *
 * @param a - first options object, or undefined.
 * @param b - second options object, or undefined.
 * @returns true when both arguments are the same reference (including both undefined), or when both are objects
 * with exactly the same set of own enumerable keys and strictly equal (`===`) values for every key.
 */
const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => {
  // Identical references, or both undefined.
  if (a === b) {
    return true;
  }
  // Exactly one side is missing.
  if (a === undefined || b === undefined) {
    return false;
  }

  const keysOfA = Object.keys(a) as Array<keyof typeof a>;
  const keysOfB = new Set<PropertyKey>(Object.keys(b));
  if (keysOfA.length !== keysOfB.size) {
    return false;
  }

  // Same number of keys on both sides, so it is enough to check that every key of `a` exists in `b` with an
  // identical value.
  for (const key of keysOfA) {
    if (!keysOfB.has(key) || a[key] !== b[key]) {
      return false;
    }
  }
  return true;
};
/**
 * WebNN backend implementation.
 *
 * This class keeps track of the MLTensors created by the backend (through its tensor manager), of the MLContext used
 * by each session, and of the session that is currently running. It also caches MLContexts so that requests made
 * with the same GPUDevice / equal options / no argument share one context.
 */
export class WebNNBackend {
  /**
   * Tensor manager used to reserve, create, register, upload, download and release tensors.
   */
  private readonly tensorManager = createTensorManager(this);
  /**
   * Maps from session id to MLContexts.
   */
  private readonly mlContextBySessionId = new Map<number, MLContext>();
  /**
   * Maps from MLContext to session ids.
   */
  private readonly sessionIdsByMLContext = new Map<MLContext, Set<number>>();
  /**
   * Cache of MLContexts.
   */
  private readonly mlContextCache: MLContextEntry[] = [];
  /**
   * Current session id. Set by `onRunStart` and cleared by `onRunEnd`.
   */
  private activeSessionId?: number;
  /**
   * Maps from session id to list of graph inputs.
   */
  private readonly sessionGraphInputs: Map<number, string[]> = new Map();
  /**
   * Maps from session id to list of graph outputs.
   */
  private readonly sessionGraphOutputs: Map<number, string[]> = new Map();
  /**
   * Temporary graph inputs for the current session.
   * These inputs will be registered when the session is created.
   */
  private temporaryGraphInputs: string[] = [];
  /**
   * Temporary graph outputs for the current session.
   * These outputs will be registered when the session is created.
   */
  private temporaryGraphOutputs: string[] = [];
  /**
   * Maps from session id to the ids of the temporary tensors created for that session. They are released in
   * `onRunEnd`.
   */
  private readonly temporarySessionTensorIds: Map<number, TensorId[]> = new Map();

  /**
   * @param env - environment whose `logLevel` and `debug` flags are used to configure the logger.
   */
  constructor(env: Env) {
    configureLogger(env.logLevel!, !!env.debug);
  }

  /**
   * Type guard for GPUDevice that is safe in environments where the `GPUDevice` global does not exist.
   *
   * A bare `value instanceof GPUDevice` throws a ReferenceError when the global is not defined, so the global is
   * probed with `typeof` first.
   */
  private static isGPUDevice(value: unknown): value is GPUDevice {
    return typeof GPUDevice !== 'undefined' && value instanceof GPUDevice;
  }

  /**
   * Looks up the WebNN operand data type for an ONNX data type.
   *
   * @throws Error when the ONNX data type has no WebNN counterpart in the mapping table.
   */
  private static toWebnnDataType(onnxDataType: DataType): MLOperandDataType {
    const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
    if (!webnnDataType) {
      throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
    }
    return webnnDataType;
  }

  /**
   * Id of the session that is currently running.
   *
   * @throws Error when no session is active.
   */
  public get currentSessionId(): number {
    if (this.activeSessionId === undefined) {
      throw new Error('No active session');
    }
    return this.activeSessionId;
  }

  /**
   * Marks `sessionId` as the active session.
   */
  public onRunStart(sessionId: number): void {
    LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`);
    this.activeSessionId = sessionId;
  }

  /**
   * Ends a run: releases the temporary tensors created for `sessionId` (if any) and clears the active session.
   */
  public onRunEnd(sessionId: number): void {
    LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`);
    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
    if (tensorIds) {
      for (const tensorId of tensorIds) {
        LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`);
        this.tensorManager.releaseTensorId(tensorId);
      }
      this.temporarySessionTensorIds.delete(sessionId);
    }
    // Always leave the "running" state, whether or not temporary tensors existed; otherwise a stale session id would
    // silently be used by later calls that rely on `currentSessionId`. The id is only cleared when it still belongs
    // to the session that is ending, so that this call cannot clobber an id set by a later `onRunStart`.
    if (this.activeSessionId === sessionId) {
      this.activeSessionId = undefined;
    }
  }

  /**
   * Returns an MLContext for the given argument, reusing a cached context when one exists.
   *
   * Cache lookup rules:
   * - GPUDevice: matched by device identity.
   * - undefined: matched against the entry created without device and without options.
   * - MLContextOptions: matched by shallow equality of the options.
   *
   * NOTE: the cache is only updated after `createContext` resolves, so concurrent calls with the same argument that
   * are issued before the first one resolves each create their own context.
   *
   * @param optionsOrDevice - context options, a GPUDevice, or nothing for a default context.
   * @returns the cached or newly created MLContext.
   */
  public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
    if (WebNNBackend.isGPUDevice(optionsOrDevice)) {
      const gpuDevice = optionsOrDevice;
      const cachedEntry = this.mlContextCache.find((entry) => entry.gpuDevice === gpuDevice);
      if (cachedEntry) {
        return cachedEntry.mlContext;
      }
      const mlContext = await navigator.ml.createContext(gpuDevice);
      this.mlContextCache.push({ gpuDevice, mlContext });
      return mlContext;
    }

    if (optionsOrDevice === undefined) {
      const cachedEntry = this.mlContextCache.find(
        (entry) => entry.options === undefined && entry.gpuDevice === undefined,
      );
      if (cachedEntry) {
        return cachedEntry.mlContext;
      }
      const mlContext = await navigator.ml.createContext();
      this.mlContextCache.push({ mlContext });
      return mlContext;
    }

    const options: MLContextOptions = optionsOrDevice;
    const cachedEntry = this.mlContextCache.find((entry) => compareMLContextOptions(entry.options, options));
    if (cachedEntry) {
      return cachedEntry.mlContext;
    }
    const mlContext = await navigator.ml.createContext(options);
    this.mlContextCache.push({ options, mlContext });
    return mlContext;
  }

  /**
   * Associates `mlContext` with `sessionId` and moves the graph inputs/outputs collected so far (through
   * `registerGraphInput` / `registerGraphOutput`) to that session.
   */
  public registerMLContext(sessionId: number, mlContext: MLContext): void {
    this.mlContextBySessionId.set(sessionId, mlContext);
    let sessionIds = this.sessionIdsByMLContext.get(mlContext);
    if (!sessionIds) {
      sessionIds = new Set();
      this.sessionIdsByMLContext.set(mlContext, sessionIds);
    }
    sessionIds.add(sessionId);

    if (this.temporaryGraphInputs.length > 0) {
      this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
      this.temporaryGraphInputs = [];
    }
    if (this.temporaryGraphOutputs.length > 0) {
      this.sessionGraphOutputs.set(sessionId, this.temporaryGraphOutputs);
      this.temporaryGraphOutputs = [];
    }
  }

  /**
   * Drops all bookkeeping for a released session. When the session was the last user of its MLContext, the context
   * is also removed from the cache.
   */
  public onReleaseSession(sessionId: number): void {
    this.sessionGraphInputs.delete(sessionId);
    this.sessionGraphOutputs.delete(sessionId);
    const mlContext = this.mlContextBySessionId.get(sessionId);
    if (!mlContext) {
      // Current session is not a WebNN session.
      return;
    }
    this.tensorManager.releaseTensorsForSession(sessionId);
    this.mlContextBySessionId.delete(sessionId);

    const sessionIds = this.sessionIdsByMLContext.get(mlContext);
    sessionIds?.delete(sessionId);
    if (!sessionIds || sessionIds.size === 0) {
      // No session uses this context any more: forget it everywhere.
      this.sessionIdsByMLContext.delete(mlContext);
      const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext);
      if (mlContextIndex !== -1) {
        this.mlContextCache.splice(mlContextIndex, 1);
      }
    }
  }

  /**
   * @returns the MLContext registered for `sessionId`, or undefined when there is none.
   */
  public getMLContext(sessionId: number): MLContext | undefined {
    return this.mlContextBySessionId.get(sessionId);
  }

  /**
   * Reserves a new tensor id from the tensor manager.
   */
  public reserveTensorId(): TensorId {
    return this.tensorManager.reserveTensorId();
  }

  /**
   * Releases a tensor id through the tensor manager.
   */
  public releaseTensorId(tensorId: TensorId): void {
    LOG_DEBUG('verbose', () => `[WebNN] releaseTensorId {tensorId: ${tensorId}}`);
    this.tensorManager.releaseTensorId(tensorId);
  }

  /**
   * Asks the tensor manager for the MLTensor behind `tensorId` with the given type and dimensions.
   *
   * @param sessionId - owning session; when undefined, the currently running session is used.
   * @param tensorId - id of the tensor.
   * @param onnxDataType - ONNX data type, translated to the WebNN operand data type.
   * @param dimensions - tensor dimensions.
   * @param copyOld - forwarded unchanged to the tensor manager (presumably whether existing contents are carried over
   * — confirm against the tensor manager).
   * @throws Error when the data type is unsupported, or when `sessionId` is undefined and no session is active.
   */
  public async ensureTensor(
    sessionId: number | undefined,
    tensorId: TensorId,
    onnxDataType: DataType,
    dimensions: number[],
    copyOld: boolean,
  ): Promise<MLTensor> {
    const webnnDataType = WebNNBackend.toWebnnDataType(onnxDataType);
    return this.tensorManager.ensureTensor(
      sessionId ?? this.currentSessionId,
      tensorId,
      webnnDataType,
      dimensions,
      copyOld,
    );
  }

  /**
   * Creates a tensor that lives until the end of the current run of `sessionId`; it is released by `onRunEnd`.
   *
   * @returns the id of the new tensor.
   * @throws Error when the data type is unsupported. Failures of the tensor manager are rethrown after the reserved
   * tensor id has been released.
   */
  public async createTemporaryTensor(
    sessionId: number,
    onnxDataType: DataType,
    shape: readonly number[],
  ): Promise<TensorId> {
    LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`);
    const dataType = WebNNBackend.toWebnnDataType(onnxDataType);
    const tensorId = this.tensorManager.reserveTensorId();
    try {
      await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false);
    } catch (e) {
      // The id is not tracked in `temporarySessionTensorIds` yet, so nobody else would ever release it.
      this.tensorManager.releaseTensorId(tensorId);
      throw e;
    }

    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
    if (tensorIds) {
      tensorIds.push(tensorId);
    } else {
      this.temporarySessionTensorIds.set(sessionId, [tensorId]);
    }
    return tensorId;
  }

  /**
   * Uploads `data` to the tensor identified by `tensorId`.
   *
   * @throws Error when the wasm instance reports that data should not be transferred to MLTensors.
   */
  public uploadTensor(tensorId: TensorId, data: Uint8Array): void {
    const wasm = getInstance();
    if (!wasm.shouldTransferToMLTensor) {
      throw new Error('Trying to upload to a MLTensor while shouldTransferToMLTensor is false');
    }
    LOG_DEBUG('verbose', () => `[WebNN] uploadTensor {tensorId: ${tensorId}, data: ${data.byteLength}}`);
    this.tensorManager.upload(tensorId, data);
  }

  /**
   * Downloads the tensor identified by `tensorId` into `dstBuffer`.
   */
  public async downloadTensor(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined> {
    return this.tensorManager.download(tensorId, dstBuffer);
  }

  /**
   * Creates a function that, when called, downloads the tensor and returns its data as a typed view for `type`.
   */
  public createMLTensorDownloader(tensorId: TensorId, type: Tensor.MLTensorDataTypes): () => Promise<Tensor.DataType> {
    return async () => {
      const data = await this.tensorManager.download(tensorId);
      return createView(data, type);
    };
  }

  /**
   * Registers an existing MLTensor with the tensor manager.
   *
   * @returns the tensor id assigned to the tensor.
   * @throws Error when the data type is unsupported.
   */
  public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
    const webnnDataType = WebNNBackend.toWebnnDataType(onnxDataType);
    const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions);
    LOG_DEBUG(
      'verbose',
      () =>
        `[WebNN] registerMLTensor {tensor: ${tensor}, dataType: ${webnnDataType}, dimensions: ${
          dimensions
        }} -> {tensorId: ${id}}`,
    );
    return id;
  }

  /**
   * Register a WebNN Constant operand from external data.
   *
   * @param externalFilePath - name of the external file; a leading './' is stripped before the lookup.
   * @param dataOffset - byte offset of the constant data inside the file.
   * @param dataLength - byte length of the constant data.
   * @param builder - graph builder used to create the constant.
   * @param desc - operand descriptor. NOTE: when int64 data is converted to int32, `desc.dataType` is overwritten
   * with 'int32' on the object passed by the caller.
   * @param mountedFiles - map of preloaded files ("Module.MountedFiles"), keyed by file name.
   * @param shouldConvertInt64ToInt32 - when true, int64 data is converted to int32 before creating the constant.
   * @returns the constant operand created by `builder.constant`.
   * @throws Error when the mounted files are unavailable, the file is not found, the requested range does not fit in
   * the file, or the data type is unsupported.
   */
  public registerMLConstant(
    externalFilePath: string,
    dataOffset: number,
    dataLength: number,
    builder: MLGraphBuilder,
    desc: MLOperandDescriptor,
    mountedFiles: Map<string, Uint8Array> | undefined,
    shouldConvertInt64ToInt32 = false,
  ): MLOperand {
    // If available, "Module.MountedFiles" is a Map for all preloaded files.
    if (!mountedFiles) {
      throw new Error('External mounted files are not available.');
    }

    const filePath = externalFilePath.startsWith('./') ? externalFilePath.substring(2) : externalFilePath;
    const fileData = mountedFiles.get(filePath);
    if (!fileData) {
      throw new Error(`File with name ${filePath} not found in preloaded files.`);
    }

    // Negative values must be rejected explicitly: `slice` would interpret them as offsets from the end of the file.
    if (dataOffset < 0 || dataLength < 0 || dataOffset + dataLength > fileData.byteLength) {
      throw new Error('Out of bounds: data offset and length exceed the external file data size.');
    }

    // `slice` copies the range, so `buffer` contains exactly the requested bytes.
    const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
    let bufferView: ArrayBufferView;
    let convertedInt64ToInt32 = false;
    switch (desc.dataType) {
      case 'float32':
        bufferView = new Float32Array(buffer);
        break;
      case 'float16':
        // Use Float16Array when the runtime provides it; otherwise expose the same bytes as 16-bit integers.
        bufferView =
          typeof Float16Array !== 'undefined' && Float16Array.from ? new Float16Array(buffer) : new Uint16Array(buffer);
        break;
      case 'int32':
        bufferView = new Int32Array(buffer);
        break;
      case 'uint32':
        bufferView = new Uint32Array(buffer);
        break;
      case 'int64':
        if (shouldConvertInt64ToInt32) {
          // Int64 is not supported by current context, use int32 instead.
          const int32Buffer = convertDataToInt32(new Uint8Array(buffer), 'int64');
          bufferView = new Int32Array(int32Buffer.buffer);
          desc.dataType = 'int32';
          convertedInt64ToInt32 = true;
        } else {
          bufferView = new BigInt64Array(buffer);
        }
        break;
      case 'uint64':
        bufferView = new BigUint64Array(buffer);
        break;
      case 'int8':
        bufferView = new Int8Array(buffer);
        break;
      case 'int4':
      case 'uint4':
      case 'uint8':
        bufferView = new Uint8Array(buffer);
        break;
      default:
        throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
    }

    LOG_DEBUG(
      'verbose',
      () =>
        `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}} ${
          convertedInt64ToInt32 ? '(Note: it was int64 data type and registered to int32 as workaround)' : ''
        }`,
    );
    return builder.constant(desc, bufferView);
  }

  /**
   * Records a graph input name; it is attached to a session by the next `registerMLContext` call.
   */
  public registerGraphInput(inputName: string): void {
    this.temporaryGraphInputs.push(inputName);
  }

  /**
   * Records a graph output name; it is attached to a session by the next `registerMLContext` call.
   */
  public registerGraphOutput(outputName: string): void {
    this.temporaryGraphOutputs.push(outputName);
  }

  /**
   * @returns true when `inputName` is a registered graph input of `sessionId`.
   */
  public isGraphInput(sessionId: number, inputName: string): boolean {
    return this.sessionGraphInputs.get(sessionId)?.includes(inputName) ?? false;
  }

  /**
   * @returns true when `outputName` is a registered graph output of `sessionId`.
   */
  public isGraphOutput(sessionId: number, outputName: string): boolean {
    return this.sessionGraphOutputs.get(sessionId)?.includes(outputName) ?? false;
  }

  /**
   * Checks whether the MLContext of `sessionId` accepts `type` for graph inputs (or graph outputs).
   *
   * @param sessionId - session whose MLContext is queried.
   * @param type - tensor type name.
   * @param isInput - true to check the input limits, false to check the output limits.
   * @returns false when the session has no MLContext, when the type has no WebNN counterpart, or when the context's
   * `opSupportLimits()` does not list the data type.
   */
  public isGraphInputOutputTypeSupported(sessionId: number, type: Tensor.Type, isInput = true): boolean {
    const context = this.mlContextBySessionId.get(sessionId);
    const dataType = onnxDataTypeToWebnnDataType.get(tensorDataTypeStringToEnum(type));
    if (dataType === undefined) {
      return false;
    }
    const limits = context?.opSupportLimits();
    const supportedDataTypes = isInput ? limits?.input.dataTypes : limits?.output.dataTypes;
    return !!supportedDataTypes?.includes(dataType);
  }

  public flush(): void {
    // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
  }
}