// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { resolveBackend, SessionHandlerType } from './backend';
|
|
import { ExecutionPlan } from './execution-plan';
|
|
import { Graph } from './graph';
|
|
import { Profiler } from './instrument';
|
|
import { Model } from './model';
|
|
import { Operator } from './operators';
|
|
import { Tensor } from './tensor';
|
|
|
|
/**
 * Type-only declarations that merge with the `Session` class declared in this module.
 */
export declare namespace Session {
  /**
   * Options accepted by the `Session` constructor.
   */
  export interface Config {
    /** Profiler configuration, handed to `Profiler.create` by the session constructor. */
    profiler?: Profiler.Config;
    /** Preferred backend name, handed to `resolveBackend` when a model is loaded. */
    backendHint?: string;
  }

  /**
   * State the session shares with the backend session handler it creates.
   */
  export interface Context {
    /** Input shapes recorded on the first successful run; later runs are checked against them. */
    graphInputDims?: (readonly number[])[];
    /** Element types of the graph inputs, recorded on the first run. */
    graphInputTypes?: Array<Tensor.DataType>;
    /** The session's profiler instance. */
    profiler: Readonly<Profiler>;
  }
}
|
|
|
|
/**
 * An inference session: loads a model, obtains a session handler from the resolved backend,
 * builds an execution plan for the model graph and runs it on caller-supplied input tensors.
 *
 * Lifecycle: construct -> `loadModel(...)` exactly once -> `run(...)` any number of times.
 */
export class Session {
  /**
   * @param config optional backend hint and profiler configuration.
   */
  constructor(config: Session.Config = {}) {
    this._initialized = false;
    this.backendHint = config.backendHint;
    this.profiler = Profiler.create(config.profiler);
    // The graph input caches start empty and are filled by the first successful `run`.
    this.context = { profiler: this.profiler, graphInputTypes: [], graphInputDims: [] };
  }

  /** Names of the model's graph inputs. Requires a loaded model. */
  get inputNames(): readonly string[] {
    return this._model.graph.getInputNames();
  }

  /** Names of the model's graph outputs. Requires a loaded model. */
  get outputNames(): readonly string[] {
    return this._model.graph.getOutputNames();
  }

  /** Starts the session profiler. */
  startProfiling() {
    this.profiler.start();
  }

  /** Stops the session profiler. */
  endProfiling() {
    this.profiler.stop();
  }

  /**
   * Loads and initializes a model. The model can come from:
   * - a path or URL string: read with `fs/promises` under Node, otherwise fetched. A `.ort`
   *   suffix marks the ORT format;
   * - an `ArrayBuffer`, optionally restricted to `[byteOffset, byteOffset + length)`. Without
   *   `length`, everything from `byteOffset` to the end of the buffer is used;
   * - a `Uint8Array`.
   *
   * @throws Error('already initialized') if a model was already loaded into this session.
   * @throws Error if fetching the model over HTTP returns a non-OK status.
   */
  async loadModel(uri: string): Promise<void>;
  async loadModel(buffer: ArrayBuffer, byteOffset?: number, length?: number): Promise<void>;
  async loadModel(buffer: Uint8Array): Promise<void>;
  async loadModel(arg: string | ArrayBuffer | Uint8Array, byteOffset?: number, length?: number): Promise<void> {
    await this.profiler.event('session', 'Session.loadModel', async () => {
      // Fail fast on a repeated load. Otherwise the backend handler and `_model` would be
      // replaced before `initialize` rejects the call, leaving the live session with an empty model.
      if (this._initialized) {
        throw new Error('already initialized');
      }

      // resolve backend and session handler
      const backend = await resolveBackend(this.backendHint);
      this.sessionHandler = backend.createSessionHandler(this.context);

      this._model = new Model();
      if (typeof arg === 'string') {
        const isOrtFormat = arg.endsWith('.ort');
        if (typeof process !== 'undefined' && process.versions && process.versions.node) {
          // node
          const { readFile } = require('node:fs/promises');
          const buf = await readFile(arg);
          this.initialize(buf, isOrtFormat);
        } else {
          // browser
          const response = await fetch(arg);
          // fetch() resolves on HTTP errors too; don't try to parse an error page as a model.
          if (!response.ok) {
            throw new Error(`failed to fetch model from '${arg}': ${response.status} ${response.statusText}`);
          }
          const buf = await response.arrayBuffer();
          this.initialize(new Uint8Array(buf), isOrtFormat);
        }
      } else if (!ArrayBuffer.isView(arg)) {
        // load model from ArrayBuffer. The default length must be measured from the offset;
        // using the full byteLength would overrun the buffer whenever byteOffset > 0.
        const offset = byteOffset || 0;
        const viewLength = length || arg.byteLength - offset;
        this.initialize(new Uint8Array(arg, offset, viewLength));
      } else {
        // load model from Uint8Array
        this.initialize(arg);
      }
    });
  }

  /**
   * Parses the model bytes and prepares the session for execution. This gives the session handler
   * a chance to transform the graph and be notified once it is ready, resolves every node to an
   * operator and builds the execution plan.
   *
   * @throws Error('already initialized') when called on an initialized session.
   */
  private initialize(modelProtoBlob: Uint8Array, isOrtFormat?: boolean): void {
    if (this._initialized) {
      throw new Error('already initialized');
    }

    this.profiler.event('session', 'Session.initialize', () => {
      // load graph; only handlers that implement `transformGraph` act as graph initializers
      const graphInitializer = this.sessionHandler.transformGraph
        ? (this.sessionHandler as Graph.Initializer)
        : undefined;
      this._model.load(modelProtoBlob, graphInitializer, isOrtFormat);

      // graph is completely initialized at this stage, let the interested handlers know
      if (this.sessionHandler.onGraphInitialized) {
        this.sessionHandler.onGraphInitialized(this._model.graph);
      }

      // initialize each operator in the graph
      this.initializeOps(this._model.graph);

      // instantiate an ExecutionPlan object to be used by the Session object
      this._executionPlan = new ExecutionPlan(this._model.graph, this._ops, this.profiler);
    });

    this._initialized = true;
  }

  /**
   * Runs the model.
   *
   * @param inputs either an array ordered like `inputNames`, or a map from input name to tensor.
   * @returns a map from output name to output tensor.
   * @throws Error if the session is not initialized or the inputs fail validation.
   */
  async run(inputs: Map<string, Tensor> | Tensor[]): Promise<Map<string, Tensor>> {
    if (!this._initialized) {
      throw new Error('session not initialized yet');
    }

    return this.profiler.event('session', 'Session.run', async () => {
      const inputTensors = this.normalizeAndValidateInputs(inputs);
      const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);
      return this.createOutput(outputTensors);
    });
  }

  /**
   * Converts `inputs` to an array ordered like the graph inputs and validates shapes and types.
   *
   * On the first run, shapes are checked against the model's declared dims, where a declared `0`
   * accepts any size. If validation passes, the graph input types and the actual input dims are
   * cached in `context`. Later runs must match the cached dims exactly.
   */
  private normalizeAndValidateInputs(inputs: Map<string, Tensor> | Tensor[]): Tensor[] {
    const modelInputNames = this._model.graph.getInputNames();

    if (Array.isArray(inputs)) {
      // inputs: Tensor[]
      if (inputs.length !== modelInputNames.length) {
        throw new Error(`incorrect input array length: expected ${modelInputNames.length} but got ${inputs.length}`);
      }
    } else {
      // inputs: Map<string, Tensor> -> array ordered like the graph inputs
      if (inputs.size !== modelInputNames.length) {
        throw new Error(`incorrect input map size: expected ${modelInputNames.length} but got ${inputs.size}`);
      }

      const sortedInputs = new Array<Tensor>(inputs.size);
      for (let i = 0; i < modelInputNames.length; ++i) {
        const inputName = modelInputNames[i];
        const tensor = inputs.get(inputName);
        if (!tensor) {
          throw new Error(`missing input tensor for: '${inputName}'`);
        }
        sortedInputs[i] = tensor;
      }

      inputs = sortedInputs;
    }

    const cachedTypes = this.context.graphInputTypes;
    const cachedDims = this.context.graphInputDims;
    if (!cachedTypes || cachedTypes.length === 0 || !cachedDims || cachedDims.length === 0) {
      // First session run - graph input data is not cached for the session yet
      const modelInputIndices = this._model.graph.getInputIndices();
      const modelValues = this._model.graph.getValues();

      const graphInputDims = new Array<readonly number[]>(modelInputIndices.length);
      const graphInputTypes = new Array<Tensor.DataType>(modelInputIndices.length);
      for (let i = 0; i < modelInputIndices.length; ++i) {
        const graphInput = modelValues[modelInputIndices[i]];
        graphInputDims[i] = graphInput.type!.shape.dims;
        graphInputTypes[i] = graphInput.type!.tensorType;
      }

      // Validate BEFORE caching: if a rejected first run were cached, later runs would be
      // strictly checked against its invalid shapes.
      this.validateInputTensorDims(graphInputDims, inputs, true);
      this.validateInputTensorTypes(graphInputTypes, inputs);

      // Cache for second and subsequent runs. Some parts of the framework assume that the graph
      // types and shapes are static. The arrays are mutated in place because the session
      // handler received this same context object.
      const typeCache = this.context.graphInputTypes ?? (this.context.graphInputTypes = []);
      const dimsCache = this.context.graphInputDims ?? (this.context.graphInputDims = []);
      for (let i = 0; i < modelInputIndices.length; ++i) {
        typeCache.push(graphInputTypes[i]);
        dimsCache.push(inputs[i].dims);
      }
    } else {
      // Second and subsequent session runs - graph input data is cached for the session
      this.validateInputTensorDims(cachedDims, inputs, false);
      this.validateInputTensorTypes(cachedTypes, inputs);
    }

    return inputs;
  }

  /** Throws if any given input's type differs from the expected type at the same index. */
  private validateInputTensorTypes(graphInputTypes: Tensor.DataType[], givenInputs: Tensor[]) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedType = graphInputTypes[i];
      const actualType = givenInputs[i].type;
      if (expectedType !== actualType) {
        throw new Error(`input tensor[${i}] check failed: expected type '${expectedType}' but got ${actualType}`);
      }
    }
  }

  /** Throws if any given input's dims do not match the expected dims (see `compareTensorDims`). */
  private validateInputTensorDims(
    graphInputDims: Array<readonly number[]>,
    givenInputs: Tensor[],
    noneDimSupported: boolean,
  ) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedDims = graphInputDims[i];
      const actualDims = givenInputs[i].dims;
      if (!this.compareTensorDims(expectedDims, actualDims, noneDimSupported)) {
        throw new Error(
          `input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${actualDims.join(
            ',',
          )}]`,
        );
      }
    }
  }

  /**
   * Returns true when both shapes have the same rank and equal sizes. When `noneDimSupported` is
   * set, an expected size of `0` (a 'None' dimension) matches any actual size.
   */
  private compareTensorDims(
    expectedDims: readonly number[],
    actualDims: readonly number[],
    noneDimSupported: boolean,
  ): boolean {
    if (expectedDims.length !== actualDims.length) {
      return false;
    }

    for (let i = 0; i < expectedDims.length; ++i) {
      if (expectedDims[i] !== actualDims[i] && (!noneDimSupported || expectedDims[i] !== 0)) {
        // data shape mis-match AND not a 'None' dimension.
        return false;
      }
    }

    return true;
  }

  /** Pairs the produced tensors with the graph output names, in order. */
  private createOutput(outputTensors: Tensor[]): Map<string, Tensor> {
    const modelOutputNames = this._model.graph.getOutputNames();
    if (outputTensors.length !== modelOutputNames.length) {
      throw new Error('expected number of outputs do not match number of generated outputs');
    }

    const output = new Map<string, Tensor>();
    for (let i = 0; i < modelOutputNames.length; ++i) {
      output.set(modelOutputNames[i], outputTensors[i]);
    }

    return output;
  }

  /** Resolves every graph node to an operator via the session handler, in node order. */
  private initializeOps(graph: Graph): void {
    const nodes = graph.getNodes();
    this._ops = new Array(nodes.length);

    for (let i = 0; i < nodes.length; i++) {
      this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets, graph);
    }
  }

  private _model: Model;
  private _initialized: boolean;

  private _ops: Operator[];
  private _executionPlan: ExecutionPlan;

  private backendHint?: string;

  private sessionHandler: SessionHandlerType;
  private context: Session.Context;
  private profiler: Readonly<Profiler>;
}
|