71 lines
2.0 KiB
TypeScript
71 lines
2.0 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {
|
|
CpuPinnedConstructorParameters,
|
|
GpuBufferConstructorParameters,
|
|
MLTensorConstructorParameters,
|
|
TextureConstructorParameters,
|
|
} from './tensor-factory.js';
|
|
import { Tensor } from './tensor-impl.js';
|
|
|
|
/**
|
|
* calculate size from dims.
|
|
*
|
|
* @param dims the dims array. May be an illegal input.
|
|
*/
|
|
export const calculateSize = (dims: readonly unknown[]): number => {
|
|
let size = 1;
|
|
for (let i = 0; i < dims.length; i++) {
|
|
const dim = dims[i];
|
|
if (typeof dim !== 'number' || !Number.isSafeInteger(dim)) {
|
|
throw new TypeError(`dims[${i}] must be an integer, got: ${dim}`);
|
|
}
|
|
if (dim < 0) {
|
|
throw new RangeError(`dims[${i}] must be a non-negative integer, got: ${dim}`);
|
|
}
|
|
size *= dim;
|
|
}
|
|
return size;
|
|
};
|
|
|
|
/**
|
|
* implementation of Tensor.reshape()
|
|
*/
|
|
export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor => {
|
|
switch (tensor.location) {
|
|
case 'cpu':
|
|
return new Tensor(tensor.type, tensor.data, dims);
|
|
case 'cpu-pinned':
|
|
return new Tensor({
|
|
location: 'cpu-pinned',
|
|
data: tensor.data as CpuPinnedConstructorParameters['data'],
|
|
type: tensor.type as CpuPinnedConstructorParameters['type'],
|
|
dims,
|
|
});
|
|
case 'texture':
|
|
return new Tensor({
|
|
location: 'texture',
|
|
texture: tensor.texture,
|
|
type: tensor.type as TextureConstructorParameters['type'],
|
|
dims,
|
|
});
|
|
case 'gpu-buffer':
|
|
return new Tensor({
|
|
location: 'gpu-buffer',
|
|
gpuBuffer: tensor.gpuBuffer,
|
|
type: tensor.type as GpuBufferConstructorParameters['type'],
|
|
dims,
|
|
});
|
|
case 'ml-tensor':
|
|
return new Tensor({
|
|
location: 'ml-tensor',
|
|
mlTensor: tensor.mlTensor,
|
|
type: tensor.type as MLTensorConstructorParameters['type'],
|
|
dims,
|
|
});
|
|
default:
|
|
throw new Error(`tensorReshape: tensor location ${tensor.location} is not supported`);
|
|
}
|
|
};
|