174 lines
5.3 KiB
TypeScript
174 lines
5.3 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import { SessionHandler } from './backend';
|
|
import { Graph } from './graph';
|
|
import { Logger, Profiler } from './instrument';
|
|
import { Operator } from './operators';
|
|
import { Tensor } from './tensor';
|
|
|
|
class KernelOp {
|
|
constructor(
|
|
public op: Operator,
|
|
public node: Graph.Node,
|
|
) {}
|
|
}
|
|
|
|
export class ExecutionPlan {
|
|
constructor(
|
|
private graph: Graph,
|
|
ops: Operator[],
|
|
private profiler: Readonly<Profiler>,
|
|
) {
|
|
this.initialize(ops);
|
|
}
|
|
|
|
initialize(ops: Operator[]) {
|
|
this.profiler.event('session', 'ExecutionPlan.initialize', () => {
|
|
const graphNodes = this.graph.getNodes();
|
|
if (graphNodes.length !== ops.length) {
|
|
throw new Error('The size of nodes and OPs do not match.');
|
|
}
|
|
|
|
this._ops = ops.map((op, i) => new KernelOp(op, graphNodes[i]));
|
|
this.reset();
|
|
|
|
// look for starter node(s)
|
|
this._starter = [];
|
|
this._ops.forEach((op, i) => {
|
|
let resolved = true;
|
|
for (const input of op.node.inputs) {
|
|
if (
|
|
!this._values[input] && // not an initialized input
|
|
this.graph.getInputIndices().indexOf(input) === -1 // not model input
|
|
) {
|
|
resolved = false;
|
|
break;
|
|
}
|
|
}
|
|
if (resolved) {
|
|
this._starter.push(i);
|
|
}
|
|
});
|
|
});
|
|
}
|
|
|
|
reset() {
|
|
this._values = this.graph.getValues().map((i) => i.tensor);
|
|
}
|
|
|
|
async execute(sessionHandler: SessionHandler, modelInputs: Tensor[]): Promise<Tensor[]> {
|
|
return this.profiler.event('session', 'ExecutionPlan.execute', async () => {
|
|
// reset mediem result
|
|
this.reset();
|
|
|
|
// create inference handler
|
|
const inferenceHandler = sessionHandler.createInferenceHandler();
|
|
|
|
// populate inputs value
|
|
const graphInputs = this.graph.getInputIndices();
|
|
if (modelInputs.length !== graphInputs.length) {
|
|
throw new Error(
|
|
`number of input tensors don't match the number of inputs to the model: actual: ${
|
|
modelInputs.length
|
|
} expected: ${graphInputs.length}`,
|
|
);
|
|
}
|
|
|
|
modelInputs.forEach((input, i) => {
|
|
const index = graphInputs[i];
|
|
this._values[index] = input;
|
|
});
|
|
|
|
// prepare running sequence
|
|
const sequence: number[] = this._starter.slice(0);
|
|
|
|
// execution iterations
|
|
const graphValues = this.graph.getValues();
|
|
const graphNodes = this.graph.getNodes();
|
|
|
|
let rear = 0;
|
|
while (rear < sequence.length) {
|
|
const thisOpIndex = sequence[rear++];
|
|
const thisOp = this._ops[thisOpIndex];
|
|
|
|
// check input
|
|
const inputList = thisOp.node.inputs.map((i) => this._values[i]);
|
|
if (inputList.indexOf(undefined) !== -1) {
|
|
throw new Error(`unresolved input detected: op: ${thisOp.node}`);
|
|
}
|
|
|
|
// run
|
|
const inputTensors = inputList as Tensor[];
|
|
Logger.verbose(
|
|
'ExecPlan',
|
|
`Running op:${thisOp.node.name} (${inputTensors
|
|
.map((t, i) => `'${thisOp.node.inputs[i]}': ${t.type}[${t.dims.join(',')}]`)
|
|
.join(', ')})`,
|
|
);
|
|
|
|
const outputList = await this.profiler.event('node', thisOp.node.name, async () =>
|
|
thisOp.op.impl(inferenceHandler, inputTensors, thisOp.op.context),
|
|
);
|
|
|
|
// check output
|
|
if (outputList.length !== thisOp.node.outputs.length) {
|
|
throw new Error('the size of output does not match model definition.');
|
|
}
|
|
|
|
// fill value
|
|
outputList.forEach((output, i) => {
|
|
const j = thisOp.node.outputs[i];
|
|
if (this._values[j]) {
|
|
throw new Error(`output [${j}] already has value: op:${thisOp.node.name}`);
|
|
}
|
|
this._values[j] = output;
|
|
});
|
|
|
|
// resolve downstream nodes
|
|
const downstreamNodes = new Set<number>();
|
|
outputList.forEach((_output, i) => {
|
|
const j = thisOp.node.outputs[i];
|
|
for (const currentDownstreamNodeIndex of graphValues[j].to) {
|
|
const currentDownstreamNode = graphNodes[currentDownstreamNodeIndex];
|
|
let resolved = true;
|
|
for (const k of currentDownstreamNode.inputs) {
|
|
if (!this._values[k]) {
|
|
resolved = false;
|
|
break;
|
|
}
|
|
}
|
|
if (resolved) {
|
|
downstreamNodes.add(currentDownstreamNodeIndex);
|
|
}
|
|
}
|
|
});
|
|
sequence.push(...downstreamNodes);
|
|
}
|
|
|
|
const output: Tensor[] = [];
|
|
for (let i = 0; i < this.graph.getOutputIndices().length; i++) {
|
|
const outputIndex = this.graph.getOutputIndices()[i];
|
|
const outputTensor = this._values[outputIndex];
|
|
if (outputTensor === undefined) {
|
|
throw new Error(`required output [${outputIndex}] does not have value`);
|
|
}
|
|
if (outputIndex === 0) {
|
|
await outputTensor.getData();
|
|
} else {
|
|
// eslint-disable-next-line no-unused-expressions
|
|
outputTensor.data;
|
|
}
|
|
output.push(outputTensor);
|
|
}
|
|
Logger.verbose('ExecPlan', 'disposing of inferenceHandler');
|
|
inferenceHandler.dispose();
|
|
return output;
|
|
});
|
|
}
|
|
|
|
_values: Array<Tensor | undefined>;
|
|
_ops: KernelOp[];
|
|
_starter: number[];
|
|
}
|