Files
voice_recognition/whispervad/node_modules/onnxruntime-web/lib/onnxjs/graph.ts

812 lines
27 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Attribute } from './attribute';
import * as ortFbs from './ort-schema/flatbuffers/ort-generated';
import { onnx } from './ort-schema/protobuf/onnx';
import { Tensor } from './tensor';
import { LongUtil, MAX_CLIP, MIN_CLIP, ProtoUtil } from './util';
/**
 * Type-only declarations for the read-only graph model exposed through {@link Graph}.
 */
export declare namespace Graph {
/** A tensor shape: one entry per dimension. */
export interface Shape {
readonly dims: readonly number[];
}
/** Element data type plus shape of a value. */
export interface ValueType {
readonly tensorType: Tensor.DataType;
readonly shape: Shape;
}
/**
 * An edge of the graph. Nodes and values refer to each other by index into the graph's node / value arrays.
 */
export interface Value {
// the tensor data. empty for non-initialized inputs and for values computed by nodes
readonly tensor?: Tensor;
// index of the Node where the value comes from. -1 for initializers (and for outputs of folded 'Constant' nodes).
// NOTE: graph inputs have no producer; GraphImpl leaves this undefined at runtime for them.
readonly from: number;
// indices of the Nodes where the value goes to.
readonly to: readonly number[];
// value type specification. empty when the loader has no type information
// (e.g. intermediate values created while scanning node outputs).
readonly type?: ValueType;
}
/** An operator instance in the graph. */
export interface Node {
// name of the node; GraphImpl generates `unnamed_<opType>_<n>` when the model leaves it empty
readonly name: string;
// the operator type
readonly opType: string;
// indices of the Values where the inputs come from.
readonly inputs: readonly number[];
// indices of the Values where the outputs go to.
readonly outputs: readonly number[];
// the attributes used by the operator
readonly attributes: Attribute;
}
/**
 * a Transformer exposes the transformation operations that can be applied to a graph
 */
export interface Transformer {
removeAllIdentityNodes(): void;
removeAllDropoutNodes(): void;
fuseConvActivationNodes(): void;
// TODO: add generic functions to manipulate the graph
}
// an initializer can use a transformer to transform the graph; GraphImpl invokes it after its
// built-in transformations and before the graph is finalized
export interface Initializer {
transformGraph(transformer: Transformer): void;
}
}
/**
 * Read-only view of a model graph. Values and nodes reference each other by index into
 * getValues() / getNodes().
 */
// eslint-disable-next-line @typescript-eslint/no-redeclare
export interface Graph {
// indices into getValues() of the graph inputs that are not backed by an initializer
getInputIndices(): readonly number[];
// names of those inputs, in the same order as getInputIndices()
getInputNames(): readonly string[];
// indices into getValues() of the graph outputs
getOutputIndices(): readonly number[];
// names of the graph outputs, in the same order as getOutputIndices()
getOutputNames(): readonly string[];
// all values (edges) of the graph
getValues(): readonly Graph.Value[];
// all nodes of the graph
getNodes(): readonly Graph.Node[];
}
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-redeclare
export const Graph = {
  /**
   * Build a graph from an ONNX protobuf graph or an ORT flatbuffers graph.
   *
   * @param graphProto the serialized graph to load
   * @param initializer optional hook that is handed a transformer while the graph is being constructed
   * @returns the built, transformed and validated graph
   */
  from(graphProto: onnx.IGraphProto | ortFbs.Graph, initializer?: Graph.Initializer) {
    return new GraphImpl(graphProto, initializer);
  },
};
/**
 * Mutable implementation of {@link Graph.Value} used while the graph is built and transformed.
 *
 * The graph builder writes `_from` / `_to` directly; the public `from` / `to` getters expose them.
 */
class Value implements Graph.Value {
  /** Index of the producing node; -1 for initializers, left undefined for graph inputs. */
  _from?: number = undefined;
  /** Indices of the nodes consuming this value. */
  _to: number[] = [];
  /** Element type and shape, when the loader knows them. */
  type?: Graph.ValueType = undefined;
  /** Constant data for initializers / folded constants; undefined otherwise. */
  tensor?: Tensor = undefined;

  /**
   * @param valueInfo optional ONNX value info; when given, `type` is derived from its tensor type
   */
  constructor(valueInfo?: onnx.IValueInfoProto) {
    if (!valueInfo) {
      return;
    }
    this.type = ProtoUtil.tensorValueTypeFromProto(valueInfo.type!.tensorType!);
  }

  get from() {
    return this._from!;
  }

  get to() {
    return this._to;
  }
}
/**
 * Mutable implementation of {@link Graph.Node}, built from an ONNX `NodeProto` or an ORT flatbuffers `Node`.
 *
 * `inputs` / `outputs` start empty and are filled in by the graph builder. `executeNode` is cleared when a
 * transformation removes the node; such nodes are dropped when the graph is finalized.
 */
class Node implements Graph.Node {
  /**
   * @param _nodeProto the serialized node
   * @param name name to use for an ORT node (e.g. an auto-generated one); ignored for ONNX nodes,
   *   whose own `name` field is used
   * @throws TypeError if `_nodeProto` is neither an `onnx.NodeProto` nor an `ortFbs.Node` instance
   */
  constructor(_nodeProto: onnx.INodeProto | ortFbs.Node, name?: string) {
    if (_nodeProto instanceof onnx.NodeProto) {
      this.name = _nodeProto.name;
      this.opType = _nodeProto.opType;
      this.attributes = new Attribute(_nodeProto.attribute);
    } else if (_nodeProto instanceof ortFbs.Node) {
      this.name = name ?? _nodeProto.name()!;
      this.opType = _nodeProto.opType()!;
      this.attributes = new Attribute(ProtoUtil.tensorAttributesFromORTFormat(_nodeProto));
    } else {
      // fail fast: otherwise name / opType / attributes would silently stay undefined and
      // the node would break much later with an unrelated error
      throw new TypeError('Node type is not supported.');
    }
    this.inputs = [];
    this.outputs = [];
    this.executeNode = true;
  }

  name: string;
  opType: string;
  inputs: number[];
  outputs: number[];
  attributes: Attribute;
  executeNode: boolean;
}
/**
 * Concrete {@link Graph} that also implements {@link Graph.Transformer}.
 *
 * Construction runs three phases, each of which throws on fatal inconsistencies:
 * 1. build: index every value and node of an ONNX protobuf or ORT flatbuffers graph and link producers/consumers;
 * 2. transform: remove Identity / Dropout nodes, fuse Conv with a following activation, run the optional
 *    initializer hook, then compact the node / value arrays (`finalizeGraph`);
 * 3. validate: walk the nodes reachable from graph inputs and reject cycles and inconsistent links.
 */
class GraphImpl implements Graph, Graph.Transformer {
  private _allData: Value[];
  private _allInputIndices: number[];
  private _allInputNames: string[];
  private _allOutputIndices: number[];
  private _allOutputNames: string[];
  private _nodes: Node[];

  constructor(graph: onnx.IGraphProto | ortFbs.Graph, graphInitializer?: Graph.Initializer) {
    if (!graph) {
      throw new TypeError('graph is empty');
    }

    // build the graph - will throw exceptions if something fatal is detected
    this.buildGraph(graph);

    // execute any transformation logic for the graph (if applicable)
    this.transformGraph(graphInitializer);

    // check for cycles and other inconsistencies - will throw exceptions if something fatal is detected
    this.checkIsAcyclic();
  }

  getInputIndices(): readonly number[] {
    return this._allInputIndices;
  }

  getInputNames(): readonly string[] {
    return this._allInputNames;
  }

  getOutputIndices(): readonly number[] {
    return this._allOutputIndices;
  }

  getOutputNames(): readonly string[] {
    return this._allOutputNames;
  }

  getValues(): readonly Graph.Value[] {
    return this._allData;
  }

  getNodes(): readonly Graph.Node[] {
    return this._nodes;
  }

  /** Dispatch to the format-specific builder; throws TypeError for any other graph type. */
  private buildGraph(graph: onnx.IGraphProto | ortFbs.Graph) {
    if (graph instanceof onnx.GraphProto) {
      this.buildGraphFromOnnxFormat(graph);
    } else if (graph instanceof ortFbs.Graph) {
      this.buildGraphFromOrtFormat(graph);
    } else {
      throw new TypeError('Graph type is not supported.');
    }
  }

  /**
   * Populate values and nodes from an ONNX protobuf graph.
   *
   * Values are appended to `_allData` in this order: graph inputs, initializers that are not also inputs,
   * graph outputs, then intermediate values in the order nodes produce them.
   */
  private buildGraphFromOnnxFormat(graph: onnx.IGraphProto) {
    const dataIndices = new Map<string, number>();
    this._allData = [];
    this._allInputIndices = [];
    this._allInputNames = [];
    this._allOutputIndices = [];
    this._allOutputNames = [];
    this._nodes = [];

    const nodesIndices = new Map<string, number>();

    // scan all inputs
    if (!graph.input) {
      throw new Error('missing information in graph: input');
    }
    // inputs are pushed first, so inputValueNames[k] names _allData[k] for every graph input
    const inputValueNames = [];
    for (const i of graph.input) {
      if (dataIndices.has(i.name!)) {
        throw new Error(`duplicated input name: ${i.name}`);
      }
      const currentIndex = this._allData.push(new Value(i)) - 1;
      dataIndices.set(i.name!, currentIndex);
      inputValueNames.push(i.name!);
    }

    // scan all initializers
    if (!graph.initializer) {
      throw new Error('missing information in graph: initializer');
    }
    for (const i of graph.initializer) {
      let index = dataIndices.get(i.name!);
      if (index === undefined) {
        const value = new Value();
        value.type = {
          shape: { dims: ProtoUtil.tensorDimsFromProto(i.dims!) },
          tensorType: ProtoUtil.tensorDataTypeFromProto(i.dataType!),
        };
        index = this._allData.push(value) - 1;
        dataIndices.set(i.name!, index);
      }
      this._allData[index]._from = -1;
      this._allData[index].tensor = Tensor.fromProto(i);
    }

    // filter out input indices: keep only the graph inputs that are not backed by an initializer
    for (let i = 0; i < this._allData.length; i++) {
      if (!this._allData[i].tensor) {
        this._allInputIndices.push(i);
        this._allInputNames.push(inputValueNames[i]);
      }
    }

    // scan all outputs
    if (!graph.output) {
      throw new Error('missing information in graph: output');
    }
    for (const i of graph.output) {
      if (dataIndices.has(i.name!)) {
        throw new Error(`duplicated output name: ${i.name}`);
      }
      const currentIndex = this._allData.push(new Value(i)) - 1;
      dataIndices.set(i.name!, currentIndex);
      this._allOutputIndices.push(currentIndex);
      this._allOutputNames.push(i.name!);
    }

    // scan all nodes
    if (!graph.node) {
      throw new Error('missing information in graph: node');
    }
    for (const nodeProto of graph.node) {
      if (!nodeProto.name) {
        // assign a name to the node if it doesn't have one
        for (let pick = 0; ; pick++) {
          const name = `unnamed_${nodeProto.opType}_${pick}`;
          if (!nodesIndices.has(name)) {
            nodeProto.name = name;
            break;
          }
        }
      }

      if (nodesIndices.has(nodeProto.name)) {
        throw new Error(`duplicated node name: ${nodeProto.name}`);
      }
      const currentIndex = this._nodes.push(new Node(nodeProto)) - 1;
      nodesIndices.set(nodeProto.name, currentIndex);
    }

    // scan node's outputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.node[i];
      if (!nodeProto.output) {
        throw new Error(`missing output for node: ${nodeProto.name}`);
      }
      for (const output of nodeProto.output) {
        let dataIndex = dataIndices.get(output);
        if (typeof dataIndex === 'undefined') {
          dataIndex = this._allData.push(new Value()) - 1;
          dataIndices.set(output, dataIndex);
        }
        node.outputs.push(dataIndex);

        if (this._allData[dataIndex]._from !== undefined) {
          throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
        }
        this._allData[dataIndex]._from = i;

        // for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
        // operator and ignore the node from the graph
        if (nodeProto.opType === 'Constant') {
          if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) {
            throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
          }
          if (!nodeProto.output || nodeProto.output.length !== 1) {
            throw new Error('missing output or incorrect number of outputs for this Constant operator');
          }
          node.outputs.pop();
          node.executeNode = false;

          this._allData[dataIndex]._from = -1;
          this._allData[dataIndex].tensor = Tensor.fromProto(nodeProto.attribute[0].t);
        }
      }
    }

    // scan node's inputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.node[i];

      if (!nodeProto.input) {
        throw new Error(`missing input for node: ${nodeProto.name}`);
      }
      for (const input of nodeProto.input) {
        const dataIndex = dataIndices.get(input);
        if (typeof dataIndex === 'undefined') {
          // handle exception when opset > 9 and roi / scales not given
          if (
            input === '' &&
            (nodeProto.input.length === 3 || nodeProto.input.length === 4) &&
            nodeProto.opType === 'Resize'
          ) {
            continue;
          }
          throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name}`);
        }
        node.inputs.push(dataIndex);

        this._allData[dataIndex]._to.push(i);
      }
    }
  }

  /**
   * Populate values and nodes from an ORT flatbuffers graph, appending values in the same order as the
   * ONNX builder (inputs, extra initializers, outputs, intermediates).
   */
  private buildGraphFromOrtFormat(graph: ortFbs.Graph) {
    const dataIndices = new Map<string, number>();
    this._allData = [];
    this._allInputIndices = [];
    this._allInputNames = [];
    this._allOutputIndices = [];
    this._allOutputNames = [];
    this._nodes = [];

    const nodesIndices = new Map<string, number>();

    // scan all inputs
    const inputValueNames = [];
    for (let i = 0; i < graph.inputsLength(); i++) {
      const inputName = graph.inputs(i);
      if (dataIndices.has(inputName)) {
        throw new Error(`duplicated input name: ${inputName}`);
      }
      // Find the input typeInfo from nodeargs
      for (let j = 0; j < graph.nodeArgsLength(); j++) {
        if (graph.nodeArgs(j)?.name() === inputName) {
          const value = new Value();
          const valueType = graph.nodeArgs(j)?.type()?.valueType();
          if (valueType !== ortFbs.TypeInfoValue.tensor_type) {
            throw new Error('Unexpected value type for the nodeArg.');
          }
          const valueInfo = graph.nodeArgs(j)!.type()!.value(new ortFbs.TensorTypeAndShape())!;
          const type = ProtoUtil.tensorDataTypeFromProto(valueInfo.elemType());
          const shape = valueInfo.shape()!;
          const dims = [];
          for (let k = 0; k < shape.dimLength()!; k++) {
            dims.push(LongUtil.longToNumber(shape.dim(k)!.value()!.dimValue()!));
          }
          value.type = { shape: { dims }, tensorType: type };
          const currentIndex = this._allData.push(value) - 1;
          dataIndices.set(inputName, currentIndex);
          inputValueNames.push(inputName);
          // stop at the first matching node arg so a single input never yields two values
          break;
        }
      }
    }

    // check initializers
    for (let i = 0; i < graph.initializersLength(); i++) {
      const initializer = graph.initializers(i)!;
      let index = dataIndices.get(initializer.name()!);
      if (index === undefined) {
        const value = new Value();
        const dims = ProtoUtil.tensorDimsFromORTFormat(initializer);
        const type = ProtoUtil.tensorDataTypeFromProto(initializer.dataType());
        value.type = { shape: { dims }, tensorType: type };
        index = this._allData.push(value) - 1;
        dataIndices.set(initializer.name()!, index);
      }
      this._allData[index]._from = -1;
      this._allData[index].tensor = Tensor.fromOrtTensor(initializer);
    }

    // filter out input indices: keep only the graph inputs that are not backed by an initializer
    for (let i = 0; i < this._allData.length; i++) {
      if (!this._allData[i].tensor) {
        this._allInputIndices.push(i);
        this._allInputNames.push(inputValueNames[i]);
      }
    }

    // scan all outputs
    for (let i = 0; i < graph.outputsLength(); i++) {
      const outputName = graph.outputs(i);
      if (dataIndices.has(outputName)) {
        throw new Error(`duplicated output name: ${outputName}`);
      }
      const currentIndex = this._allData.push(new Value()) - 1;
      dataIndices.set(outputName, currentIndex);
      this._allOutputIndices.push(currentIndex);
      this._allOutputNames.push(outputName);
    }

    // scan all nodes
    if (!graph.nodes) {
      throw new Error('missing information in graph: node');
    }
    for (let i = 0; i < graph.nodesLength(); i++) {
      const nodeProto = graph.nodes(i);
      let name = nodeProto!.name();
      if (!name) {
        // assign a name to the node if it doesn't have one
        for (let pick = 0; ; pick++) {
          name = `unnamed_${nodeProto!.opType()}_${pick}`;
          if (!nodesIndices.has(name)) {
            // an unique name is found. break.
            break;
          }
        }
      }

      if (nodesIndices.has(name)) {
        throw new Error(`duplicated node name: ${name}`);
      }
      const currentIndex = this._nodes.push(new Node(nodeProto!, name)) - 1;
      nodesIndices.set(name, currentIndex);
    }

    // scan node's outputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.nodes(i);
      if (nodeProto == null) {
        throw new Error(`No node exists at index ${i}`);
      }
      if (nodeProto.outputsLength() === 0) {
        // name() must be called: interpolating the bare method would print its source text
        throw new Error(`missing output for node: ${nodeProto.name()}`);
      }
      for (let j = 0; j < nodeProto.outputsLength(); j++) {
        const output = nodeProto.outputs(j);
        let dataIndex = dataIndices.get(output);
        if (typeof dataIndex === 'undefined') {
          dataIndex = this._allData.push(new Value()) - 1;
          dataIndices.set(output, dataIndex);
        }
        node.outputs.push(dataIndex);

        if (this._allData[dataIndex]._from !== undefined) {
          throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
        }
        this._allData[dataIndex]._from = i;

        // for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
        // operator and ignore the node from the graph
        if (nodeProto.opType() === 'Constant') {
          if (nodeProto.attributesLength() !== 1 || !nodeProto.attributes(0)!.t()) {
            throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
          }
          if (nodeProto.outputsLength() !== 1) {
            throw new Error('missing output or incorrect number of outputs for this Constant operator');
          }
          node.outputs.pop();
          node.executeNode = false;

          this._allData[dataIndex]._from = -1;
          this._allData[dataIndex].tensor = Tensor.fromOrtTensor(nodeProto.attributes(0)!.t()!);
        }
      }
    }

    // scan node's inputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.nodes(i)!;

      if (nodeProto.inputsLength() === 0) {
        throw new Error(`missing input for node: ${nodeProto.name()}`);
      }
      for (let j = 0; j < nodeProto.inputsLength(); j++) {
        const input = nodeProto.inputs(j)!;
        const dataIndex = dataIndices.get(input);
        if (typeof dataIndex === 'undefined') {
          throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name()}`);
        }
        node.inputs.push(dataIndex);

        this._allData[dataIndex]._to.push(i);
      }
    }
  }

  /**
   * Iterative three-color DFS starting from the nodes that consume graph inputs.
   *
   * Throws if a back edge is found (cycle), if a node output carries tensor data, or if a value's `_from`
   * does not point back at the node listing it as an output. Nodes not reachable from a graph input are
   * not visited.
   */
  private checkIsAcyclic() {
    // go through the graph and check for cycles or other fatal inconsistencies
    const starters: Set<number> = new Set<number>();
    this._allInputIndices.forEach((i) => {
      const data = this._allData[i];
      data._to.forEach((j) => {
        starters.add(j);
      });
    });

    // Iterative DFS to check for cycles
    const nodesStack = Array.from(starters);
    const nodesState = new Array<string>(this._nodes.length).fill('white');

    while (nodesStack.length > 0) {
      const nodeIndex = nodesStack.pop()!;

      if (nodesState[nodeIndex] === 'black') {
        // stale stack entry: the node was already fully processed through another path
        continue;
      }

      // this node has now been processed completely. Mark this node 'black' to denote this.
      if (nodesState[nodeIndex] === 'gray') {
        nodesState[nodeIndex] = 'black';
        continue;
      }

      // this node is under processing stage. mark this node 'gray' to denote this.
      // It is pushed back so it is popped again (and marked 'black') after all its descendants.
      nodesStack.push(nodeIndex);
      nodesState[nodeIndex] = 'gray';

      this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => {
        const data = this._allData[outgoingEdgeIndex];
        if (typeof data.tensor !== 'undefined') {
          throw new Error('node outputs should not be initialized');
        }
        if (data._from !== nodeIndex) {
          throw new Error("from property of the Value object doesn't match index of Node being processed");
        }
        data._to.forEach((downstreamNodeIndex) => {
          // back edge found - cyclic
          if (nodesState[downstreamNodeIndex] === 'gray') {
            throw new Error('model graph is cyclic');
          }
          // tree edge found - continue processing by adding it to stack
          if (nodesState[downstreamNodeIndex] === 'white') {
            nodesStack.push(downstreamNodeIndex);
          }
        });
      });
    }
  }

  /** Apply the built-in transformations, then the optional initializer hook, then compact the graph. */
  private transformGraph(graphInitializer?: Graph.Initializer): void {
    // apply common transform
    this.removeAllIdentityNodes();
    this.removeAllDropoutNodes();
    this.fuseConvActivationNodes();

    // apply initializer specific transform
    if (graphInitializer) {
      graphInitializer.transformGraph(this);
    }

    // finalize graph
    this.finalizeGraph();
  }

  /**
   * finalize the graph.
   *
   * this function should be called after all the transformation completed.
   * this function removes all unnecessary nodes and values from the graph
   */
  finalizeGraph() {
    // The graph is represented using these two arrays:
    // this._nodes   - the kernels to execute; each refers to its input/output values by index into this._allData
    // this._allData - the values; `_from` / `_to` hold the indices of the producing / consuming nodes
    // newIndices    - maps an old node index to its index once nodes with `executeNode === false` are removed
    const newIndices = new Array<number>(this._nodes.length).fill(0);
    let nodePosition = 0;

    for (let i = 0; i < this._nodes.length; i++) {
      // giving new indexes to the nodes based on execution flag
      newIndices[i] = nodePosition;
      if (this._nodes[i].executeNode) {
        if (nodePosition !== i) {
          this._nodes[nodePosition] = this._nodes[i];
        }
        nodePosition++;
      } else {
        // mark the outputs of a removed node with -2 so the value-deletion pass below can drop them
        this._nodes[i].outputs.forEach((ind) => {
          this._allData[ind]._from = -2;
        });
      }
    }

    // removing the unused nodes
    this._nodes.splice(nodePosition, this._nodes.length - nodePosition);

    // Updating this._allData according to the new this._nodes
    for (let i = 0; i < this._allData.length; i++) {
      const currentData = this._allData[i];
      if (currentData._from !== undefined && currentData._from !== -1 && currentData._from !== -2) {
        currentData._from = newIndices[currentData._from];
      }

      for (let j = 0; j < currentData._to.length; j++) {
        if (currentData._to[j] >= 0) {
          currentData._to[j] = newIndices[currentData._to[j]];
        } else {
          throw new Error('Trying to update a removed node');
        }
      }
    }

    // delete all values that are not being referenced.
    // `offset` counts the values removed so far: the value now at index i was at index i + offset.
    let offset = 0;
    for (let i = 0; i < this._allData.length; i++) {
      const value = this._allData[i];
      const oldIndex = i + offset;

      // if current value is neither linked to next node, nor an output value, remove it.
      if (value.from === -2 && this._allOutputIndices.indexOf(oldIndex) === -1) {
        offset++;
        this._allData.splice(i, 1);
        i--;
        continue;
      }

      if (offset > 0) {
        // the value moved from oldIndex to i: rewrite every reference to it.
        // Only a real producer (index >= 0) has an outputs list to patch. -1 (initializer / constant),
        // -2 (orphaned value kept because it is a graph output) and undefined (graph input) have none.
        const from = value.from;
        if (from !== undefined && from >= 0) {
          this.replaceAllOccurrences(this._nodes[from].outputs, oldIndex, i);
        }
        this.replaceAllOccurrences(this._allInputIndices, oldIndex, i);

        // find the nodes that the current value is linking to and update their input references
        value.to.forEach((node) => {
          this.replaceAllOccurrences(this._nodes[node].inputs, oldIndex, i);
        });

        // a graph output may also feed other nodes, so this must not depend on `to` being empty
        this.replaceAllOccurrences(this._allOutputIndices, oldIndex, i);
      }
    }
  }

  /** Replace, in place, every occurrence of `oldIndex` in `indices` with `newIndex`. */
  private replaceAllOccurrences(indices: number[], oldIndex: number, newIndex: number): void {
    for (let k = 0; k < indices.length; k++) {
      if (indices[k] === oldIndex) {
        indices[k] = newIndex;
      }
    }
  }

  /**
   * Delete the specified node. Assume the node has one incoming input and the first output connected to other nodes.
   * An input validation must be done before calling this function.
   *
   * Consumers of the node's first output (and any graph output referring to it) are rewired to the node's
   * first input; the node is only flagged `executeNode = false` and is physically removed by finalizeGraph.
   * @param nodeIndex The index of node to be deleted
   */
  private deleteNode(nodeIndex: number) {
    const node = this._nodes[nodeIndex];
    if (node.outputs.length > 1) {
      for (let i = 1; i < node.outputs.length; i++) {
        if (this._allData[node.outputs[i]].to.length > 0) {
          throw new Error('Node deletion with more than one output connected to other nodes is not supported. ');
        }
      }
    }

    // this node will not be executed
    node.executeNode = false;
    const inputValueIndex = node.inputs[0];
    const outputValueIndex = node.outputs[0];
    // keep the old consumer list; the value's `_to` is replaced with a fresh array below
    const nodesConsumingOutput = this._allData[outputValueIndex].to;

    // remove this node from the to property of the input Value
    for (let i = 0; i < node.inputs.length; i++) {
      const delIndex = this._allData[node.inputs[i]].to.indexOf(nodeIndex);
      // should not happen
      if (delIndex === -1) {
        throw new Error("The Value object doesn't have the current Node in it's 'to' property ");
      }
      this._allData[node.inputs[i]].to.splice(delIndex, 1);
    }

    // clear node indices consuming this output Value
    this._allData[outputValueIndex]._to = [];

    // every graph output that referred to this node's output now refers to its input
    this.replaceAllOccurrences(this._allOutputIndices, outputValueIndex, inputValueIndex);

    // override the inputs for nodes consuming this node's output with the input to this node
    if (nodesConsumingOutput && nodesConsumingOutput.length > 0) {
      for (const consumerIndex of nodesConsumingOutput) {
        const replaceIndex = this._nodes[consumerIndex].inputs.indexOf(outputValueIndex);
        // should not happen
        if (replaceIndex === -1) {
          throw new Error("The Node object doesn't have the output Value in it's 'inputs' property ");
        }
        this._nodes[consumerIndex].inputs[replaceIndex] = inputValueIndex;
        this._allData[inputValueIndex].to.push(consumerIndex);
      }
    }
  }

  /** Remove every 'Dropout' node after checking it has 1 input and 1-2 outputs with an unused second output. */
  removeAllDropoutNodes() {
    let nodeIndex = 0;
    for (const node of this._nodes) {
      // weed out 'Dropout' nodes so that no time is wasted in execution
      if (node.opType === 'Dropout') {
        // the node should have exactly 1 input and 1 or 2 outputs
        if (node.inputs.length !== 1) {
          throw new Error('Dropout nodes should only contain one input. ');
        }
        if (node.outputs.length !== 1 && node.outputs.length !== 2) {
          throw new Error('Dropout nodes should contain either 1 or 2 output(s)');
        }
        // the second output should not be referenced by any other node
        if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) {
          throw new Error("Dropout nodes's second output should not be referenced by other nodes");
        }
        this.deleteNode(nodeIndex);
      }
      nodeIndex++;
    }
  }

  /** Remove every 'Identity' node, wiring its consumers directly to its first input. */
  removeAllIdentityNodes() {
    let nodeIndex = 0;
    for (const node of this._nodes) {
      // weed out 'Identity' nodes so that no time is wasted in execution
      if (node.opType === 'Identity') {
        this.deleteNode(nodeIndex);
      }
      nodeIndex++;
    }
  }

  /** Whether `n` is an activation that fuseConvActivationNodes can fold into a preceding Conv. */
  isActivation(n: Node): boolean {
    switch (n.opType) {
      // TODO: add other activation methods
      case 'Relu':
      case 'Sigmoid':
      case 'Clip':
        return true;
      default:
        return false;
    }
  }

  /**
   * Fold an activation into the preceding Conv when the Conv output feeds exactly one activation node.
   * The Conv receives an 'activation' attribute (plus 'activation_params' for Clip), and the activation
   * node is deleted.
   */
  fuseConvActivationNodes() {
    for (const node of this._nodes) {
      if (node.opType !== 'Conv') {
        continue;
      }
      const convOutput = node.outputs[0];
      // the raw Conv result is observable as a graph output; fusing would apply the activation to it
      if (this._allOutputIndices.indexOf(convOutput) !== -1) {
        continue;
      }
      const next = this._allData[convOutput]._to;
      if (next.length !== 1 || !this.isActivation(this._nodes[next[0]])) {
        continue;
      }

      const child = this._nodes[next[0]];
      if (child.opType === 'Clip') {
        if (child.inputs.length === 1) {
          // bounds come from the optional 'min' / 'max' attributes; each missing bound
          // falls back to its own default so a present bound is never discarded
          const readBound = (key: string, fallback: number) => {
            try {
              return child.attributes.getFloat(key);
            } catch {
              return fallback;
            }
          };
          node.attributes.set('activation_params', 'floats', [
            readBound('min', MIN_CLIP),
            readBound('max', MAX_CLIP),
          ]);
        } else if (
          child.inputs.length >= 3 &&
          this._allData[child.inputs[1]].tensor !== undefined &&
          this._allData[child.inputs[2]].tensor !== undefined
        ) {
          node.attributes.set('activation_params', 'floats', [
            this._allData[child.inputs[1]].tensor!.floatData[0],
            this._allData[child.inputs[2]].tensor!.floatData[0],
          ]);
        } else {
          // Skip fusion with clip node since clip min and clip max are not coming from initializer
          continue;
        }
      }
      node.attributes.set('activation', 'string', child.opType);
      this.deleteNode(next[0]);
    }
  }
}