113 lines
3.6 KiB
JavaScript
113 lines
3.6 KiB
JavaScript
'use strict';
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
/**
 * TypeScript interop helper: exposes property `k` of module object `m` on target
 * object `o`, under the name `k2` (defaults to `k`).
 *
 * When `Object.create` is available, the target gets a fresh enumerable getter
 * that reads `m[k]` on every access in these cases:
 * - the source property is not an own property;
 * - it is a writable or configurable data property;
 * - it is an accessor on a module not flagged `__esModule`.
 * Otherwise the source property descriptor is copied verbatim. Without
 * `Object.create`, the current value is simply assigned.
 * An existing `__createBinding` on `this` takes precedence.
 */
var __createBinding =
  (this && this.__createBinding) ||
  (Object.create
    ? function (o, m, k, k2) {
        const exportName = k2 === undefined ? k : k2;
        let descriptor = Object.getOwnPropertyDescriptor(m, k);
        let useLiveGetter;
        if (!descriptor) {
          useLiveGetter = true;
        } else if ('get' in descriptor) {
          // Accessors are reused only when the source is flagged as an ES module.
          useLiveGetter = !m.__esModule;
        } else {
          useLiveGetter = descriptor.writable || descriptor.configurable;
        }
        if (useLiveGetter) {
          descriptor = { enumerable: true, get: () => m[k] };
        }
        Object.defineProperty(o, exportName, descriptor);
      }
    : function (o, m, k, k2) {
        const exportName = k2 === undefined ? k : k2;
        o[exportName] = m[k];
      });
|
|
/**
 * TypeScript interop helper: attaches `v` to `o` as its `default` property.
 * When `Object.create` is available, the property is defined with
 * `Object.defineProperty` as enumerable; the unspecified attributes default to
 * non-writable and non-configurable. Otherwise a plain assignment is used.
 * An existing `__setModuleDefault` on `this` takes precedence.
 */
var __setModuleDefault =
  (this && this.__setModuleDefault) ||
  (Object.create
    ? function (target, value) {
        const descriptor = { value, enumerable: true };
        Object.defineProperty(target, 'default', descriptor);
      }
    : function (target, value) {
        target['default'] = value;
      });
|
|
/**
 * TypeScript interop helper behind `import * as ns from '...'` in CommonJS output.
 *
 * A value flagged with a truthy `__esModule` is returned unchanged. Anything else
 * is wrapped in a new namespace object that:
 * - binds every own enumerable string key except 'default' (via `__createBinding`);
 * - exposes the original value as `default` (via `__setModuleDefault`).
 * An existing `__importStar` on `this` takes precedence.
 */
var __importStar =
  (this && this.__importStar) ||
  function (mod) {
    if (mod && mod.__esModule) {
      return mod;
    }
    const namespace = {};
    if (mod != null) {
      for (const key in mod) {
        // Same check order as the TS helper: skip 'default' first, then inherited keys.
        if (key === 'default') {
          continue;
        }
        if (!Object.prototype.hasOwnProperty.call(mod, key)) {
          continue;
        }
        __createBinding(namespace, mod, key);
      }
    }
    __setModuleDefault(namespace, mod);
    return namespace;
  };
|
|
Object.defineProperty(exports, '__esModule', { value: true });
|
|
exports.Model = void 0;
|
|
const flatbuffers = __importStar(require('flatbuffers'));
|
|
const graph_1 = require('./graph');
|
|
const ortFbs = __importStar(require('./ort-schema/flatbuffers/ort-generated'));
|
|
const onnx_1 = require('./ort-schema/protobuf/onnx');
|
|
const util_1 = require('./util');
|
|
/**
 * In-memory model: the opset imports the model declares plus its graph.
 *
 * A model can be decoded from an ONNX protobuf buffer (`loadFromOnnxFormat`) or
 * from an ORT-format flatbuffers buffer (`loadFromOrtFormat`). `load` picks the
 * format or auto-detects it. State is only replaced after a loader has fully
 * succeeded, so a failed load never leaves a half-populated model.
 */
class Model {
  // empty model; populated by load()
  constructor() {}

  /**
   * Load the model from a serialized buffer.
   *
   * @param buf - serialized model bytes, passed to the format decoder.
   * @param graphInitializer - forwarded unchanged to `Graph.from`.
   * @param isOrtFormat - controls how the buffer is parsed:
   *   - `true`: parse as ORT format only;
   *   - `undefined`: try ONNX first, then fall back to ORT format;
   *   - any other falsy value (e.g. `false`): parse as ONNX only.
   * @throws {Error} the underlying decoder/validation error when a single format
   *   was requested. In auto-detect mode, an Error whose message contains both
   *   the ONNX and the ORT failure.
   */
  load(buf, graphInitializer, isOrtFormat) {
    let onnxError;
    if (!isOrtFormat) {
      try {
        this.loadFromOnnxFormat(buf, graphInitializer);
        return;
      } catch (e) {
        if (isOrtFormat !== undefined) {
          throw e;
        }
        // Auto-detect mode: remember why ONNX failed and try ORT format next.
        onnxError = e;
      }
    }
    try {
      this.loadFromOrtFormat(buf, graphInitializer);
    } catch (e) {
      if (isOrtFormat !== undefined) {
        throw e;
      }
      // Tried both formats and failed (when isOrtFormat === undefined)
      throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
    }
  }

  /**
   * Decode an ONNX protobuf buffer and populate opsets and graph.
   *
   * @throws {Error} if decoding fails, the IR version is below 3, or `Graph.from` throws.
   */
  loadFromOnnxFormat(buf, graphInitializer) {
    const modelProto = onnx_1.onnx.ModelProto.decode(buf);
    const irVersion = util_1.LongUtil.longToNumber(modelProto.irVersion);
    if (irVersion < 3) {
      throw new Error('only support ONNX model with IR_VERSION>=3');
    }
    const opsets = modelProto.opsetImport.map((i) => ({
      domain: i.domain,
      version: util_1.LongUtil.longToNumber(i.version),
    }));
    const graph = graph_1.Graph.from(modelProto.graph, graphInitializer);
    // Commit only once everything parsed, so a failure cannot leave stale/partial state.
    this._opsets = opsets;
    this._graph = graph;
  }

  /**
   * Decode an ORT-format (flatbuffers) buffer and populate opsets and graph.
   *
   * @throws {Error} if the buffer holds no model, the IR version is below 3,
   *   or decoding/`Graph.from` throws.
   */
  loadFromOrtFormat(buf, graphInitializer) {
    const fb = new flatbuffers.ByteBuffer(buf);
    const ortModel = ortFbs.InferenceSession.getRootAsInferenceSession(fb).model();
    if (!ortModel) {
      throw new Error('ORT format buffer does not contain a model');
    }
    const irVersion = util_1.LongUtil.longToNumber(ortModel.irVersion());
    if (irVersion < 3) {
      throw new Error('only support ONNX model with IR_VERSION>=3');
    }
    const opsets = [];
    const opsetCount = ortModel.opsetImportLength();
    for (let i = 0; i < opsetCount; i++) {
      // Index is within opsetImportLength(), so the entry is expected to exist.
      const opsetId = ortModel.opsetImport(i);
      opsets.push({ domain: opsetId.domain(), version: util_1.LongUtil.longToNumber(opsetId.version()) });
    }
    const graph = graph_1.Graph.from(ortModel.graph(), graphInitializer);
    // Commit only once everything parsed, so a failure cannot leave stale/partial state.
    this._opsets = opsets;
    this._graph = graph;
  }

  /** Graph produced by `Graph.from` during the last successful load (undefined before). */
  get graph() {
    return this._graph;
  }

  /** Opset imports as `{ domain, version }` objects from the last successful load (undefined before). */
  get opsets() {
    return this._opsets;
  }
}
|
|
exports.Model = Model;
|
|
//# sourceMappingURL=model.js.map
|