# voice_recognition/whisper_server/deps/exla/lib/exla/mlir/value.ex

defmodule EXLA.MLIR.Value do
@moduledoc false
# Representation of an MLIR Value.
#
# MLIR values are in SSA form and are generally either operation
# results or block arguments. This module is used to construct most
# of the MLIR operations.
#
# See the full specification of the stablehlo MLIR dialect [1]. Note
# that the URL points to the exact stablehlo revision that we depend
# on via elixir-nx/xla.
#
# [1]: https://github.com/openxla/stablehlo/blob/04291aea6b50d9573e6f4de184938d83b9564cd0/docs/spec.md
defstruct [:ref, :function]
alias __MODULE__
alias EXLA.Typespec
alias EXLA.MLIR.Region
alias EXLA.MLIR.Function
@bin_ops %{
add: "stablehlo.add",
subtract: "stablehlo.subtract",
multiply: "stablehlo.multiply",
divide: "stablehlo.divide",
pow: "stablehlo.power",
min: "stablehlo.minimum",
max: "stablehlo.maximum",
remainder: "stablehlo.remainder",
atan2: "stablehlo.atan2",
bitwise_and: "stablehlo.and",
bitwise_or: "stablehlo.or",
bitwise_xor: "stablehlo.xor",
left_shift: "stablehlo.shift_left",
right_shift_arithmetic: "stablehlo.shift_right_arithmetic",
right_shift_logical: "stablehlo.shift_right_logical"
}
for {op, op_name} <- @bin_ops do
def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do
result_types = typespecs_to_mlir_types([typespec])
op(func, unquote(op_name), [lhs, rhs], result_types) |> one!()
end
end
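# The comprehension above defines one function per entry. For :add it
# expands to roughly:
#
#     def add(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do
#       result_types = typespecs_to_mlir_types([typespec])
#       op(func, "stablehlo.add", [lhs, rhs], result_types) |> one!()
#     end
#
# Matching `func` in both arguments ensures the operands belong to the
# same function.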
@bin_comparison_ops %{
equal: :eq,
less: :lt,
less_equal: :le,
greater: :gt,
greater_equal: :ge,
not_equal: :ne
}
for {op, direction} <- @bin_comparison_ops do
def unquote(op)(
%Value{function: func} = lhs,
%Value{function: func} = rhs,
typespec,
opts \\ []
) do
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order])
end
end
defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
%{type: lhs_type} = get_typespec(lhs)
%{type: rhs_type} = get_typespec(rhs)
comparison_type =
cond do
Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) ->
[compare_type: attr_comparison_type(:float)]
Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
attr =
if total_order? do
attr_comparison_type(:totalorder)
else
attr_comparison_type(:float)
end
[compare_type: attr]
true ->
[]
end
attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type
result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})])
op(func, "stablehlo.compare", [lhs, rhs], result_types, attributes: attributes) |> one!()
end
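# As an illustration, comparing two f32 values with direction :eq
# emits the following attributes (see attr_enum/3 below):
#
#     comparison_direction: "#stablehlo<comparison_direction EQ>"
#     compare_type: "#stablehlo<comparison_type FLOAT>"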
@unary_ops %{
abs: "stablehlo.abs",
exp: "stablehlo.exponential",
expm1: "stablehlo.exponential_minus_one",
floor: "stablehlo.floor",
ceil: "stablehlo.ceil",
round: "stablehlo.round_nearest_afz",
log: "stablehlo.log",
log1p: "stablehlo.log_plus_one",
sigmoid: "stablehlo.logistic",
sign: "stablehlo.sign",
cos: "stablehlo.cosine",
sin: "stablehlo.sine",
tan: "chlo.tan",
acos: "chlo.acos",
asin: "chlo.asin",
atan: "chlo.atan",
cosh: "chlo.cosh",
sinh: "chlo.sinh",
tanh: "stablehlo.tanh",
acosh: "chlo.acosh",
asinh: "chlo.asinh",
atanh: "chlo.atanh",
sqrt: "stablehlo.sqrt",
cbrt: "stablehlo.cbrt",
bitwise_not: "stablehlo.not",
erf: "chlo.erf",
erfc: "chlo.erfc",
erf_inv: "chlo.erf_inv",
rsqrt: "stablehlo.rsqrt",
negate: "stablehlo.negate",
count_leading_zeros: "stablehlo.count_leading_zeros",
population_count: "stablehlo.popcnt",
real: "stablehlo.real",
imag: "stablehlo.imag",
conjugate: "chlo.conj"
}
for {op, op_name} <- @unary_ops do
def unquote(op)(%Value{function: func} = operand, typespec) do
result_types = typespecs_to_mlir_types([typespec])
op(func, unquote(op_name), [operand], result_types, []) |> one!()
end
end
def is_infinity(%Value{function: func} = operand, out_typespec) do
%{type: type} = get_typespec(operand)
typespec = Typespec.to_type(out_typespec, {:pred, 8})
result =
cond do
Nx.Type.complex?(type) ->
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
real = real(operand, float_typespec)
imag = imag(operand, float_typespec)
is_inf_real = is_infinity(real, typespec)
is_inf_imag = is_infinity(imag, typespec)
bitwise_or(is_inf_real, is_inf_imag, typespec)
Nx.Type.integer?(type) ->
# Integers are never infinity. We use inequality to make sure
# the operand is still part of the computation
not_equal(operand, operand, typespec)
true ->
result_types = typespecs_to_mlir_types([typespec])
op(func, "chlo.is_inf", [operand], result_types) |> one!()
end
if out_typespec.type == typespec.type do
result
else
convert(result, out_typespec)
end
end
def is_nan(%Value{} = operand, out_typespec) do
typespec = Typespec.to_type(out_typespec, {:pred, 8})
# Only NaN is not equal to itself
result = not_equal(operand, operand, typespec)
if out_typespec.type == typespec.type do
result
else
convert(result, out_typespec)
end
end
def reshape(%Value{function: func} = operand, typespec) do
result_types = typespecs_to_mlir_types([typespec])
op(func, "stablehlo.reshape", [operand], result_types) |> one!()
end
def reverse(%Value{function: func} = operand, dims, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [dimensions: attr_array_i64_elements(dims)]
op(func, "stablehlo.reverse", [operand], result_types, attributes: attributes) |> one!()
end
def transpose(%Value{function: func} = operand, axes, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [permutation: attr_array_i64_elements(axes)]
op(func, "stablehlo.transpose", [operand], result_types, attributes: attributes) |> one!()
end
def slice(%Value{function: func} = operand, starts, limits, strides, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [
start_indices: attr_array_i64_elements(starts),
limit_indices: attr_array_i64_elements(limits),
strides: attr_array_i64_elements(strides)
]
op(func, "stablehlo.slice", [operand], result_types, attributes: attributes) |> one!()
end
def dynamic_slice(%Value{function: func} = operand, starts, lengths, typespec) do
result_types = typespecs_to_mlir_types([typespec])
operands = [operand] ++ starts
attributes = [slice_sizes: attr_array_i64_elements(lengths)]
op(func, "stablehlo.dynamic_slice", operands, result_types, attributes: attributes) |> one!()
end
def convert(%Value{function: func} = operand, typespec) do
result_types = typespecs_to_mlir_types([typespec])
op(func, "stablehlo.convert", [operand], result_types) |> one!()
end
def bitcast_convert(%Value{function: func} = operand, typespec) do
result_types = typespecs_to_mlir_types([typespec])
op(func, "stablehlo.bitcast_convert", [operand], result_types) |> one!()
end
def top_k(%Value{function: func} = operand, k, typespecs) do
[typespec, index_typespec] = typespecs
result_types = typespecs_to_mlir_types([typespec, Typespec.to_type(index_typespec, {:s, 32})])
attributes = [k: attr_i64(k)]
[result, idx] = op(func, "chlo.top_k", [operand], result_types, attributes: attributes)
idx = convert(idx, index_typespec)
[result, idx]
end
def sort(
[%Value{function: func} | _] = operands,
%Region{ref: comparator},
axis,
stable,
typespecs
)
when is_integer(axis) and is_boolean(stable) do
result_types = typespecs_to_mlir_types(typespecs)
attributes = [
dimension: attr_i64(axis),
is_stable: attr_boolean(stable)
]
regions = [comparator]
op(func, "stablehlo.sort", operands, result_types, attributes: attributes, regions: regions)
end
def iota(%Function{} = func, dim, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [iota_dimension: attr_i64(dim)]
op(func, "stablehlo.iota", [], result_types, attributes: attributes) |> one!()
end
def constant(%Function{} = func, data, typespec) do
result_types = typespecs_to_mlir_types([typespec])
value = attr_dense_elements(data, typespec.type, typespec.shape)
attributes = [value: value]
op(func, "stablehlo.constant", [], result_types, attributes: attributes) |> one!()
end
def dot_general(
%Value{function: func} = lhs,
%Value{function: func} = rhs,
dnums,
precision_config,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
attr_precision_config = attr_precision_config(precision_config)
{contract_axes1, batch_axes1, contract_axes2, batch_axes2} = dnums
dot_dimension_numbers =
attr_struct("stablehlo.dot",
lhs_batching_dimensions: join_list(batch_axes1),
rhs_batching_dimensions: join_list(batch_axes2),
lhs_contracting_dimensions: join_list(contract_axes1),
rhs_contracting_dimensions: join_list(contract_axes2)
)
attributes = [
dot_dimension_numbers: dot_dimension_numbers,
precision_config: "[#{attr_precision_config}, #{attr_precision_config}]"
]
op(func, "stablehlo.dot_general", [lhs, rhs], result_types, attributes: attributes) |> one!()
end
def broadcast_in_dim(%Value{function: func} = operand, axes, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [
broadcast_dimensions: attr_array_i64_elements(axes)
]
op(func, "stablehlo.broadcast_in_dim", [operand], result_types, attributes: attributes)
|> one!()
end
def concatenate([%Value{function: func} | _rest] = operands, dimension, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [dimension: attr_i64(dimension)]
op(func, "stablehlo.concatenate", operands, result_types, attributes: attributes) |> one!()
end
def clamp(
%Value{function: func} = operand,
%Value{function: func} = min,
%Value{function: func} = max,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
op(func, "stablehlo.clamp", [min, operand, max], result_types) |> one!()
end
def select(
%Value{function: func} = pred,
%Value{function: func} = on_true,
%Value{function: func} = on_false,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
op(func, "stablehlo.select", [pred, on_true, on_false], result_types) |> one!()
end
def pad(
%Value{function: func} = operand,
%Value{function: func} = pad,
padding_config,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
{padding_low, padding_high, padding_mid} = unzip_padding_config(padding_config)
attributes = [
edge_padding_low: attr_array_i64_elements(padding_low),
edge_padding_high: attr_array_i64_elements(padding_high),
interior_padding: attr_array_i64_elements(padding_mid)
]
op(func, "stablehlo.pad", [operand, pad], result_types, attributes: attributes) |> one!()
end
defp unzip_padding_config(padding_config),
do: unzip_padding_config(padding_config, {[], [], []})
defp unzip_padding_config([], {low_acc, high_acc, mid_acc}) do
{Enum.reverse(low_acc), Enum.reverse(high_acc), Enum.reverse(mid_acc)}
end
defp unzip_padding_config([{low, high, mid} | rest], {low_acc, high_acc, mid_acc}) do
unzip_padding_config(rest, {[low | low_acc], [high | high_acc], [mid | mid_acc]})
end
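# For example:
#
#     unzip_padding_config([{0, 1, 0}, {2, 3, 1}])
#     #=> {[0, 2], [1, 3], [0, 1]}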
def fft(%Value{function: func} = value, fft_kind, fft_length, typespec)
when fft_kind in [:fft, :ifft]
when is_list(fft_length) or is_integer(fft_length) do
result_types = typespecs_to_mlir_types([typespec])
fft_type = attr_fft_type(fft_kind)
attributes = [
fft_type: fft_type,
fft_length: attr_array_i64_elements(List.wrap(fft_length))
]
op(func, "stablehlo.fft", [value], result_types, attributes: attributes) |> one!()
end
def scatter(
%Value{function: func} = target,
%Value{function: func} = indices,
%Value{function: func} = updates,
kind,
indices_rank,
update_window_dims,
inserted_window_dims,
index_dims_to_window_dims,
typespec
)
when kind in [:add, :put] and is_integer(indices_rank) and is_list(update_window_dims) and
is_list(inserted_window_dims) and is_list(index_dims_to_window_dims) do
result_types = typespecs_to_mlir_types([typespec])
operands = [target, indices, updates]
scatter_dimension_numbers =
attr_struct("stablehlo.scatter",
update_window_dims: join_list(update_window_dims),
inserted_window_dims: join_list(inserted_window_dims),
scatter_dims_to_operand_dims: join_list(index_dims_to_window_dims),
index_vector_dim: Integer.to_string(indices_rank)
)
attributes = [scatter_dimension_numbers: scatter_dimension_numbers]
scatter_computation = scatter_computation(func, kind, typespec)
regions = [scatter_computation.ref]
op(func, "stablehlo.scatter", operands, result_types,
attributes: attributes,
regions: regions
)
|> one!()
end
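# Builds the update-combination region for scatter: a two-argument
# scalar computation that either adds the update to the current value
# (:add) or replaces it (:put).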
defp scatter_computation(%Function{} = function, kind, typespec) do
arg_typespec = Typespec.to_shape(typespec, {})
{region, [value, update]} = Function.push_region(function, [arg_typespec, arg_typespec])
res =
case kind do
:add -> add(value, update, arg_typespec)
:put -> update
end
return(function, [res])
Function.pop_region(function)
region
end
def select_and_scatter(
%Value{function: func} = target,
%Value{function: func} = source,
%Value{function: func} = init_value,
comparison,
window_dimensions,
window_strides,
padding,
typespec
)
when comparison in [:gt, :lt] do
operands = [target, source, init_value]
result_types = typespecs_to_mlir_types([typespec])
attributes = [
window_dimensions: attr_array_i64_elements(window_dimensions),
window_strides: attr_array_i64_elements(window_strides),
padding: attr_padding(padding)
]
select_computation = select_computation(func, comparison, typespec)
scatter_computation = scatter_computation(func, :add, typespec)
regions = [select_computation.ref, scatter_computation.ref]
op(func, "stablehlo.select_and_scatter", operands, result_types,
attributes: attributes,
regions: regions
)
|> one!()
end
defp select_computation(function, direction, typespec) do
arg_typespec = Typespec.to_shape(typespec, {})
{region, [arg0, arg1]} = Function.push_region(function, [arg_typespec, arg_typespec])
res = compare_and_return_bool(function, arg0, arg1, arg_typespec, direction)
return(function, [res])
Function.pop_region(function)
region
end
def gather(
%Value{function: func} = source,
%Value{function: func} = indices,
index_vector_dim,
slice_sizes,
offset_dims,
collapsed_slice_dims,
start_index_map,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
dimension_numbers =
attr_struct("stablehlo.gather",
offset_dims: join_list(offset_dims),
collapsed_slice_dims: join_list(collapsed_slice_dims),
start_index_map: join_list(start_index_map),
index_vector_dim: Integer.to_string(index_vector_dim)
)
attributes = [
dimension_numbers: dimension_numbers,
slice_sizes: attr_array_i64_elements(slice_sizes),
indices_are_sorted: attr_boolean(false)
]
op(func, "stablehlo.gather", [source, indices], result_types, attributes: attributes)
|> one!()
end
defp attr_precision_config(precision_config) do
case precision_config do
:default ->
attr_precision(:default)
:high ->
attr_precision(:high)
:highest ->
attr_precision(:highest)
_ ->
raise ArgumentError,
"expected precision configuration to be one of" <>
" :default, :high, or :highest," <>
" got: #{inspect(precision_config)}"
end
end
def convolution(
%Value{function: func} = tensor,
%Value{function: func} = kernel,
strides,
padding,
input_dilation,
kernel_dilation,
dimension_numbers,
feature_group_count,
batch_group_count,
precision_config,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
attr_precision_config = attr_precision_config(precision_config)
attributes = [
window_strides: attr_array_i64_elements(strides),
padding: attr_padding(padding),
lhs_dilation: attr_array_i64_elements(input_dilation),
rhs_dilation: attr_array_i64_elements(kernel_dilation),
dimension_numbers: attr_conv_dimension_numbers(dimension_numbers),
feature_group_count: attr_i64(feature_group_count),
batch_group_count: attr_i64(batch_group_count),
precision_config: "[#{attr_precision_config}, #{attr_precision_config}]"
]
op(func, "stablehlo.convolution", [tensor, kernel], result_types, attributes: attributes)
|> one!()
end
defp attr_conv_dimension_numbers(dimension_numbers) do
{input_permutation, kernel_permutation, output_permutation} = dimension_numbers
input_string = convolution_dims_permutation(input_permutation, "b", "f")
kernel_string = convolution_dims_permutation(kernel_permutation, "o", "i")
output_string = convolution_dims_permutation(output_permutation, "b", "f")
"#stablehlo.conv<[#{input_string}]x[#{kernel_string}]->[#{output_string}]>"
end
defp convolution_dims_permutation(permutation, dim1_mark, dim2_mark) do
[dim1, dim2 | spatial_dims] = permutation
dims_with_marks =
[{dim1, dim1_mark}, {dim2, dim2_mark}] ++
Enum.with_index(spatial_dims, fn dim, idx -> {dim, Integer.to_string(idx)} end)
dims_with_marks
|> Enum.sort()
|> Enum.map_join(",", fn {_dim, mark} -> mark end)
end
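# For example, with NCHW input/output and OIHW kernel permutations,
# i.e. dimension_numbers = {[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]},
# attr_conv_dimension_numbers/1 renders:
#
#     #stablehlo.conv<[b,f,0,1]x[o,i,0,1]->[b,f,0,1]>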
def triangular_solve(
%Value{function: func} = a,
%Value{function: func} = b,
left_side,
lower,
transform,
typespec
) do
result_types = typespecs_to_mlir_types([typespec])
complex? = Nx.Type.complex?(typespec.type)
transpose_a =
case transform do
:transpose when complex? -> attr_transpose(:adjoint)
:transpose -> attr_transpose(:transpose)
:none -> attr_transpose(:no_transpose)
end
attributes = [
left_side: attr_boolean(left_side),
lower: attr_boolean(lower),
unit_diagonal: attr_boolean(false),
transpose_a: transpose_a
]
op(func, "stablehlo.triangular_solve", [a, b], result_types, attributes: attributes) |> one!()
end
def dynamic_update_slice(%Value{function: func} = operand, updates, starts, typespec) do
result_types = typespecs_to_mlir_types([typespec])
op(func, "stablehlo.dynamic_update_slice", [operand, updates] ++ starts, result_types)
|> one!()
end
def reduce(
%Region{ref: reducer},
[%Value{function: func} | _] = init_values,
[%Value{function: func} | _] = inputs,
dimensions,
typespecs
) do
operands = inputs ++ init_values
result_types = typespecs_to_mlir_types(typespecs)
attributes = [dimensions: attr_array_i64_elements(dimensions)]
regions = [reducer]
op(func, "stablehlo.reduce", operands, result_types, attributes: attributes, regions: regions)
end
def window_reduce(
%Region{ref: reducer},
[%Value{function: func} | _] = init_values,
[%Value{function: func} | _] = inputs,
window_dimensions,
window_strides,
input_dilations,
window_dilations,
padding,
typespecs
) do
operands = inputs ++ init_values
result_types = typespecs_to_mlir_types(typespecs)
attributes = [
window_dimensions: attr_array_i64_elements(window_dimensions),
window_strides: attr_array_i64_elements(window_strides),
base_dilations: attr_array_i64_elements(input_dilations),
window_dilations: attr_array_i64_elements(window_dilations),
padding: attr_padding(padding)
]
regions = [reducer]
op(func, "stablehlo.reduce_window", operands, result_types,
attributes: attributes,
regions: regions
)
end
def if_op(
%Value{function: func} = pred,
%Region{ref: on_true},
%Region{ref: on_false},
typespecs
) do
result_types = typespecs_to_mlir_types(typespecs)
regions = [on_true, on_false]
pred = convert(pred, Typespec.tensor({:pred, 8}, {}))
op(func, "stablehlo.if", [pred], result_types, regions: regions)
end
def infeed(%Value{function: func} = token, typespecs) do
result_types = typespecs_to_mlir_types(typespecs ++ [Typespec.token()])
results = op(func, "stablehlo.infeed", [token], result_types)
{results, [token]} = Enum.split(results, -1)
{token, results}
end
def outfeed(%Value{} = input, token), do: outfeed([input], token)
def outfeed(inputs, %Value{function: func} = token) do
result_types = [type_token()]
op(func, "stablehlo.outfeed", inputs ++ [token], result_types) |> one!()
end
def create_token(%Function{} = func) do
result_types = [type_token()]
op(func, "stablehlo.create_token", [], result_types) |> one!()
end
def call(%Function{} = func, args, %Function{} = computation, typespecs) do
result_types = typespecs_to_mlir_types(typespecs)
attributes = [callee: attr_symbol_reference(computation.name)]
op(func, "func.call", args, result_types, attributes: attributes)
end
def while(%Function{} = func, %Region{ref: pred}, %Region{ref: body}, initial) do
typespecs = Enum.map(initial, &get_typespec/1)
result_types = typespecs_to_mlir_types(typespecs)
regions = [pred, body]
op(func, "stablehlo.while", initial, result_types, regions: regions)
end
def func_return(func, values) when is_list(values) do
op(func, "func.return", values, [])
end
def return(func, values) when is_list(values) do
op(func, "stablehlo.return", values, [])
end
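# eigh/3, qr/3 and lu/4 below lower to CPU custom calls. Each passes
# the operand together with u64 constant tensors describing the ranks
# and dimensions, and unpacks the tuple result with get_tuple_element/3.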
def eigh(%Value{function: func} = value, eigenvecs_typespec, eigenvals_typespec) do
%{type: op_type, shape: op_shape} = get_typespec(value)
%{type: eigenvecs_type, shape: eigenvecs_shape} = eigenvecs_typespec
%{type: eigenvals_type, shape: eigenvals_shape} = eigenvals_typespec
dim_sizes = [tuple_size(op_shape), tuple_size(eigenvecs_shape), tuple_size(eigenvals_shape)]
operand_dims = Tuple.to_list(op_shape)
eigenvecs_dims = Tuple.to_list(eigenvecs_shape)
eigenvals_dims = Tuple.to_list(eigenvals_shape)
dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
eigenvecs_dims =
constant(func, eigenvecs_dims, Typespec.tensor({:u, 64}, {length(eigenvecs_dims)}))
eigenvals_dims =
constant(func, eigenvals_dims, Typespec.tensor({:u, 64}, {length(eigenvals_dims)}))
operands = [value, dim_sizes, operand_dims, eigenvecs_dims, eigenvals_dims]
eigenvecs_result_type = type_tensor(eigenvecs_type, eigenvecs_shape)
eigenvals_result_type = type_tensor(eigenvals_type, eigenvals_shape)
result_types = [type_tuple([eigenvecs_result_type, eigenvals_result_type])]
call_target_name =
case op_type do
{:f, 32} ->
"eigh_cpu_custom_call_f32"
{:f, 64} ->
"eigh_cpu_custom_call_f64"
type ->
# Due to the matching in EXLA.Defn, we know the device here is always :host
raise "Eigh decomposition not supported on :host device for type #{inspect(type)}"
end
attributes = [
call_target_name: attr_string(call_target_name),
backend_config: attr_string("Host")
]
result =
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()
eigenvecs = get_tuple_element(result, 0, eigenvecs_typespec)
eigenvals = get_tuple_element(result, 1, eigenvals_typespec)
{eigenvecs, eigenvals}
end
def qr(%Value{function: func} = value, q_typespec, r_typespec) do
%{type: op_type, shape: op_shape} = get_typespec(value)
%{type: q_type, shape: q_shape} = q_typespec
%{type: r_type, shape: r_shape} = r_typespec
dim_sizes = [tuple_size(op_shape), tuple_size(q_shape), tuple_size(r_shape)]
operand_dims = Tuple.to_list(op_shape)
q_dims = Tuple.to_list(q_shape)
r_dims = Tuple.to_list(r_shape)
dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
q_dims = constant(func, q_dims, Typespec.tensor({:u, 64}, {length(q_dims)}))
r_dims = constant(func, r_dims, Typespec.tensor({:u, 64}, {length(r_dims)}))
operands = [value, dim_sizes, operand_dims, q_dims, r_dims]
q_result_type = type_tensor(q_type, q_shape)
r_result_type = type_tensor(r_type, r_shape)
result_types = [type_tuple([q_result_type, r_result_type])]
call_target_name =
case op_type do
{:f, 32} ->
"qr_cpu_custom_call_f32"
{:f, 64} ->
"qr_cpu_custom_call_f64"
{:f, 16} ->
"qr_cpu_custom_call_f16"
{:bf, 16} ->
"qr_cpu_custom_call_bf16"
type ->
# Due to the matching in EXLA.Defn, we know the device here is always :host
raise "QR decomposition not supported on :host device for type #{inspect(type)}"
end
attributes = [
call_target_name: attr_string(call_target_name),
backend_config: attr_string("Host")
]
result =
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()
q = get_tuple_element(result, 0, q_typespec)
r = get_tuple_element(result, 1, r_typespec)
{q, r}
end
def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do
%{type: op_type, shape: op_shape} = get_typespec(value)
%{type: _p_type, shape: p_shape} = p_typespec
%{type: l_type, shape: l_shape} = l_typespec
%{type: u_type, shape: u_shape} = u_typespec
dim_sizes = [
tuple_size(op_shape),
tuple_size(p_shape),
tuple_size(l_shape),
tuple_size(u_shape)
]
operand_dims = Tuple.to_list(op_shape)
p_dims = Tuple.to_list(p_shape)
l_dims = Tuple.to_list(l_shape)
u_dims = Tuple.to_list(u_shape)
dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
p_dims = constant(func, p_dims, Typespec.tensor({:u, 64}, {length(p_dims)}))
l_dims = constant(func, l_dims, Typespec.tensor({:u, 64}, {length(l_dims)}))
u_dims = constant(func, u_dims, Typespec.tensor({:u, 64}, {length(u_dims)}))
operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims]
# Force P to always be u8 to avoid requiring too many template
# instances during custom_call registration
p_result_type = type_tensor({:u, 8}, p_shape)
l_result_type = type_tensor(l_type, l_shape)
u_result_type = type_tensor(u_type, u_shape)
result_types = [type_tuple([p_result_type, l_result_type, u_result_type])]
call_target_name =
case op_type do
{:f, 32} ->
"lu_cpu_custom_call_f32"
{:f, 64} ->
"lu_cpu_custom_call_f64"
{:f, 16} ->
"lu_cpu_custom_call_f16"
{:bf, 16} ->
"lu_cpu_custom_call_bf16"
type ->
# Due to the matching in EXLA.Defn, we know the device here is always :host
raise "LU decomposition not supported on :host device for type #{inspect(type)}"
end
attributes = [
call_target_name: attr_string(call_target_name),
backend_config: attr_string("Host")
]
result =
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()
# This is not the best approach, but the alternative would require many more template instances
u8_typespec = Typespec.to_type(p_typespec, {:u, 8})
p = get_tuple_element(result, 0, u8_typespec)
p =
if u8_typespec != p_typespec do
convert(p, p_typespec)
else
p
end
l = get_tuple_element(result, 1, l_typespec)
u = get_tuple_element(result, 2, u_typespec)
{p, l, u}
end
def get_tuple_element(%Value{function: func} = operand, index, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [index: attr_i32(index)]
op(func, "stablehlo.get_tuple_element", [operand], result_types, attributes: attributes)
|> one!()
end
def get_typespec(value) do
EXLA.NIF.mlir_get_typespec(value.ref)
|> unwrap!()
|> Typespec.nif_decode()
end
def typespecs_to_mlir_types(shapes) do
Enum.map(shapes, &typespec_to_mlir_type/1)
end
defp typespec_to_mlir_type(%{type: :token}), do: type_token()
defp typespec_to_mlir_type(%{type: type, shape: shape}), do: type_tensor(type, shape)
defp unwrap!(:ok), do: :ok
defp unwrap!({:ok, value}), do: value
defp unwrap!(other), do: raise("#{inspect(other)}")
defp one!([value]), do: value
defp one!(other) do
raise "expected a list with single element, got: #{inspect(other)}"
end
defp complex_part_type({:c, size}), do: {:f, div(size, 2)}
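# Builds an MLIR operation in the current block of `function`. All
# operands must belong to that same function (enforced by the match
# below); the NIF returns one reference per result, which we wrap
# back into %Value{} structs.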
defp op(%Function{} = function, op_name, operands, result_types, opts \\ []) do
opts = Keyword.validate!(opts, attributes: [], regions: [])
%{ref: function_ref} = function
refs =
Enum.map(operands, fn
%Value{ref: ref, function: %Function{ref: ^function_ref}} -> ref
end)
refs =
EXLA.NIF.mlir_op(
function.ref,
op_name,
refs,
result_types,
opts[:attributes],
opts[:regions]
)
|> unwrap!()
Enum.map(refs, &%Value{function: function, ref: &1})
end
defp type_tensor(type, shape) do
shape_sequence = shape |> Tuple.to_list() |> Enum.map_join("", &"#{&1}x")
"tensor<#{shape_sequence}#{type_number(type)}>"
end
defp type_number({:pred, 8}), do: "i1"
defp type_number({:s, width}), do: "i#{width}"
defp type_number({:u, width}), do: "ui#{width}"
defp type_number({:f, 8}), do: "f8E5M2"
defp type_number({:f, width}), do: "f#{width}"
defp type_number({:bf, width}), do: "bf#{width}"
defp type_number({:c, 64}), do: "complex<f32>"
defp type_number({:c, 128}), do: "complex<f64>"
defp type_token(), do: "!stablehlo.token"
defp type_tuple(children) do
"tuple<#{Enum.join(children, ", ")}>"
end
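# For example:
#
#     type_tensor({:f, 32}, {2, 3})  #=> "tensor<2x3xf32>"
#     type_tensor({:pred, 8}, {})    #=> "tensor<i1>"
#     type_tuple(["tensor<2xf32>", "!stablehlo.token"])
#     #=> "tuple<tensor<2xf32>, !stablehlo.token>"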
defp number_literal(value, type) do
cond do
Nx.Type.complex?(type) ->
{re, im} =
case value do
%Complex{re: re, im: im} -> {re, im}
true -> {1, 0}
false -> {0, 0}
n -> {n, 0}
end
subtype = complex_part_type(type)
"(#{number_literal(re, subtype)}, #{number_literal(im, subtype)})"
Nx.Type.float?(type) ->
# We pass floats using their binary representation, because that
# is likely more robust and not subject to formatting limits or
# rounding. Based on the examples in the docs, the hexadecimal
# representation is always big-endian.
#
# See https://mlir.llvm.org/docs/Dialects/Builtin/#floatattr
hex_data = float_hex(value, type)
"0x#{hex_data}"
true ->
"#{value}"
end
end
defp float_hex(value, {mod, size} = type) do
data =
case value do
:nan -> type |> Nx.Type.nan_binary() |> native_to_big()
:infinity -> type |> Nx.Type.infinity_binary() |> native_to_big()
:neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big()
value when size == 8 -> f8E5M2_to_big(value)
value when mod == :bf and size == 16 -> bf16_to_big(value)
value -> <<value::float-size(size)-big>>
end
Base.encode16(data)
end
defp f8E5M2_to_big(x) do
binary_part(<<x::float-big-16>>, 0, 1)
end
defp bf16_to_big(x) do
binary_part(<<x::float-big-32>>, 0, 2)
end
defp native_to_big(binary) do
size = byte_size(binary) * 8
<<value::size(size)-native>> = binary
<<value::size(size)-big>>
end
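# For example, 1.0 as an f32 is encoded as:
#
#     float_hex(1.0, {:f, 32})       #=> "3F800000"
#     number_literal(1.0, {:f, 32})  #=> "0x3F800000"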
defp attr_array_i64_elements([]) do
"array<i64>"
end
defp attr_array_i64_elements(list) do
"array<i64: #{Enum.join(list, ", ")}>"
end
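# For example:
#
#     attr_array_i64_elements([1, 2, 3])  #=> "array<i64: 1, 2, 3>"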
defp attr_dense_elements([], type, {0} = shape) do
"dense<[]> : #{type_tensor(type, shape)}"
end
defp attr_dense_elements(list, type, shape) do
literals = Enum.map(list, &number_literal(&1, type))
list_literal =
shape
|> Tuple.to_list()
|> List.foldr(literals, fn size, acc ->
acc
|> Enum.chunk_every(size)
|> Enum.map(fn chunk ->
["[", Enum.intersperse(chunk, ", "), "]"]
end)
end)
|> IO.iodata_to_binary()
"dense<#{list_literal}> : #{type_tensor(type, shape)}"
end
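# For example:
#
#     attr_dense_elements([1, 2, 3, 4], {:s, 32}, {2, 2})
#     #=> "dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>"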
defp attr_string(string), do: ~s["#{string}"]
defp attr_symbol_reference(id), do: "@#{id}"
defp attr_boolean(true), do: "true"
defp attr_boolean(false), do: "false"
defp attr_i32(number), do: "#{number} : i32"
defp attr_i64(number), do: "#{number} : i64"
defp attr_padding(padding) do
list = Enum.flat_map(padding, &Tuple.to_list/1)
attr_dense_elements(list, {:s, 64}, {length(padding), 2})
end
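# For example:
#
#     attr_padding([{1, 2}, {0, 0}])
#     #=> "dense<[[1, 2], [0, 0]]> : tensor<2x2xi64>"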
defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne],
do: attr_enum("stablehlo", "comparison_direction", value)
defp attr_comparison_type(value) when value in [:float, :totalorder],
do: attr_enum("stablehlo", "comparison_type", value)
defp attr_precision(value) when value in [:default, :high, :highest],
do: attr_enum("stablehlo", "precision", value)
defp attr_transpose(value) when value in [:adjoint, :transpose, :no_transpose],
do: attr_enum("stablehlo", "transpose", value)
defp attr_fft_type(value) when value in [:fft, :ifft],
do: attr_enum("stablehlo", "fft_type", value)
defp attr_enum(dialect, enum_name, value) do
value = value |> Atom.to_string() |> String.upcase()
"##{dialect}<#{enum_name} #{value}>"
end
defp attr_struct(name, keyword_list) do
content = Enum.map_join(keyword_list, ", ", fn {key, value} -> "#{key} = #{value}" end)
"##{name}<#{content}>"
end
defp join_list(list) do
"[" <> Enum.join(list, ", ") <> "]"
end
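# For example, the dot_general dimension numbers above render as a
# single-line struct attribute:
#
#     attr_struct("stablehlo.dot", lhs_batching_dimensions: join_list([0]))
#     #=> "#stablehlo.dot<lhs_batching_dimensions = [0]>"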
end