Files

742 lines
23 KiB
Elixir

defmodule NxSignal do
@moduledoc """
Nx library extension for digital signal processing.
"""
import Nx.Defn
@doc ~S"""
Computes the Short-Time Fourier Transform of a tensor.
Returns the complex spectrum Z, the time in seconds for
each frame and the frequency bins in Hz.
The STFT is parameterized through:
* $k$: length of the Discrete Fourier Transform (DFT)
* $N$: length of each frame
* $H$: hop (in samples) between frames (calculated as $H = N - \text{overlap\\_length}$)
* $M$: number of frames
* $x[n]$: the input time-domain signal
* $w[n]$: the window function to be applied to each frame
$$
DFT(x, w) := \sum_{n=0}^{N - 1} x[n]w[n]e^\frac{-2 \pi i k n}{N} \\\\
X[m, k] = DFT(x[mH..(mH + N - 1)], w)
$$
where $m$ assumes all values in the interval $[0, M - 1]$
See also: `NxSignal.Windows`, `istft/3`, `stft_to_mel/3`
## Options
* `:sampling_rate` - the sampling frequency $F_s$ for the input in Hz. Defaults to `1000`.
* `:fft_length` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`.
* `:overlap_length` - the number of samples for the overlap between frames.
Defaults to half the window size.
* `:window_padding` - `:reflect`, `:valid`, `:same` or a padding configuration list. Defaults to `:valid`. See `as_windowed/2` for more details.
* `:scaling` - `nil`, `:spectrum` or `:psd`.
* `:spectrum` - each frame is divided by $\sum_{i} window[i]$.
* `nil` - No scaling is applied.
* `:psd` - each frame is divided by $\sqrt{F\_s\sum_{i} window[i]^2}$.
## Examples
iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(n: 2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
iex> z
#Nx.Tensor<
c64[frames: 3][frequencies: 2]
[
[1.0+0.0i, -1.0+0.0i],
[3.0+0.0i, -1.0+0.0i],
[5.0+0.0i, -1.0+0.0i]
]
>
iex> t
#Nx.Tensor<
f32[frames: 3]
[0.0024999999441206455, 0.004999999888241291, 0.007499999832361937]
>
iex> f
#Nx.Tensor<
f32[frequencies: 2]
[0.0, 200.0]
>
"""
@doc type: :time_frequency
deftransform stft(data, window, opts \\ []) do
  # The window must be rank-1: its single dimension is the frame length N.
  {frame_length} = Nx.shape(window)

  opts =
    Keyword.validate!(opts, [
      :overlap_length,
      # Still accepted so existing callers do not break; it is not read by
      # this function or by stft_n/4.
      :window,
      :scaling,
      window_padding: :valid,
      # 1000 Hz is the documented default and the default used by istft/3.
      # Keeping both sides equal makes a default-option stft -> istft round
      # trip consistent when `scaling: :psd` is used.
      sampling_rate: 1000,
      fft_length: :power_of_two
    ])

  # An explicit `sampling_rate: nil` overrides the default above, so reject it here.
  sampling_rate = opts[:sampling_rate] || raise ArgumentError, "missing sampling_rate option"

  # Default overlap is half of the frame (integer division).
  overlap_length = opts[:overlap_length] || div(frame_length, 2)

  stft_n(data, window, sampling_rate, Keyword.put(opts, :overlap_length, overlap_length))
end
# Numerical core of stft/3: frames the signal, applies the window, runs the
# FFT per frame and builds the time/frequency axes.
#
# Returns `{spectrum, times, frequencies}` where `spectrum` carries the axis
# names `[:frames, :frequencies]`.
defnp stft_n(data, window, sampling_rate, opts) do
  {frame_length} = Nx.shape(window)
  hop = frame_length - opts[:overlap_length]

  frames =
    as_windowed(data,
      padding: opts[:window_padding],
      window_length: frame_length,
      stride: hop
    )

  spectrum = Nx.fft(frames * window, length: opts[:fft_length])
  {num_frames, num_bins} = Nx.shape(spectrum)

  frequencies = fft_frequencies(sampling_rate, fft_length: num_bins)

  # Each frame is stamped with the middle of its time window.
  # NOTE(review): consecutive stamps are spaced frame_length / (2 * sampling_rate)
  # apart regardless of the hop, which only equals the true frame spacing when
  # the hop is exactly half a frame -- confirm this is intended.
  half_frame_time = frame_length / (2 * sampling_rate)

  times =
    Nx.linspace(half_frame_time, half_frame_time * num_frames,
      n: num_frames,
      name: :frames
    )

  scaled =
    case opts[:scaling] do
      nil ->
        spectrum

      :spectrum ->
        spectrum / Nx.sum(window)

      :psd ->
        spectrum / Nx.sqrt(sampling_rate * Nx.sum(window ** 2))

      scaling ->
        raise ArgumentError,
              "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}"
    end

  named = Nx.reshape(scaled, Nx.shape(spectrum), names: [:frames, :frequencies])
  {named, times, frequencies}
end
@doc """
Computes the frequency bins for a FFT with given options.
## Arguments
* `sampling_rate` - Sampling frequency in Hz.
## Options
* `:fft_length` - Number of FFT frequency bins.
* `:type` - Optional output type. Defaults to `{:f, 32}`
* `:name` - Optional axis name for the tensor. Defaults to `:frequencies`
* `:endpoint` - Forwarded as the `:endpoint` option of `Nx.linspace/3`. Defaults to `false`
## Examples
iex> NxSignal.fft_frequencies(1.6e4, fft_length: 10)
#Nx.Tensor<
f32[frequencies: 10]
[0.0, 1.6e3, 3.2e3, 4.8e3, 6.4e3, 8.0e3, 9.6e3, 1.12e4, 1.28e4, 1.44e4]
>
"""
@doc type: :time_frequency
defn fft_frequencies(sampling_rate, opts \\ []) do
  opts = keyword!(opts, [:fft_length, type: {:f, 32}, name: :frequencies, endpoint: false])
  num_bins = opts[:fft_length]

  # Width of a single frequency bin, in Hz.
  bin_width = sampling_rate / num_bins
  upper_limit = bin_width * num_bins

  Nx.linspace(0, upper_limit,
    endpoint: opts[:endpoint],
    name: opts[:name],
    type: opts[:type],
    n: num_bins
  )
end
@doc """
Returns a tensor of K windows of length N
## Options
* `:window_length` - the number of samples in a window
* `:stride` - The number of samples to skip between windows. Defaults to `1`.
* `:padding` - Padding mode, can be `:reflect`, `:valid`, `:same` or a padding configuration given as a
list with one `{low, high}` tuple of integers. Defaults to `:valid`. If `:reflect` or `:same`, the first window will be centered
at the start of the signal. The padding is applied for the whole input, rather than individual
windows. For every mode other than `:reflect` the padded samples are zeros, so each incomplete window is effectively zero-padded.
## Examples
iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 4)
#Nx.Tensor<
s64[5][4]
[
[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 10],
[3, 4, 10, 11],
[4, 10, 11, 12]
]
>
iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 3)
#Nx.Tensor<
s64[6][3]
[
[0, 1, 2],
[1, 2, 3],
[2, 3, 4],
[3, 4, 10],
[4, 10, 11],
[10, 11, 12]
]
>
iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11]), window_length: 2, stride: 2, padding: [{0, 3}])
#Nx.Tensor<
s64[5][2]
[
[0, 1],
[2, 3],
[4, 10],
[11, 0],
[0, 0]
]
>
iex> t = Nx.iota({7});
iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
#Nx.Tensor<
s64[8][6]
[
[3, 2, 1, 0, 1, 2],
[2, 1, 0, 1, 2, 3],
[1, 0, 1, 2, 3, 4],
[0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6, 5],
[3, 4, 5, 6, 5, 4],
[4, 5, 6, 5, 4, 3]
]
>
iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
#Nx.Tensor<
s64[6][6]
[
[3, 2, 1, 0, 1, 2],
[1, 0, 1, 2, 3, 4],
[1, 2, 3, 4, 5, 6],
[3, 4, 5, 6, 7, 8],
[5, 6, 7, 8, 9, 8],
[7, 8, 9, 8, 7, 6]
]
>
"""
@doc type: :windowing
deftransform as_windowed(tensor, opts \\ []) do
  # Reflect padding needs a dedicated code path; every other mode is
  # handled through zero-padding.
  case opts[:padding] do
    :reflect -> as_windowed_reflect_padding(tensor, opts)
    _other -> as_windowed_non_reflect_padding(tensor, opts)
  end
end
# Parses the options for the `:reflect` mode by replacing `:padding` with a
# symmetric half-window padding and delegating to the generic parser.
deftransformp as_windowed_parse_reflect_opts(shape, opts) do
  half_window = div(opts[:window_length], 2)
  symmetric_opts = Keyword.put(opts, :padding, [{half_window, half_window}])

  as_windowed_parse_non_reflect_opts(shape, symmetric_opts)
end
# Validates the windowing options and returns
# `{window_length, stride, padding_config, output_shape}`, where
# `output_shape` is `{number_of_windows, window_length}`.
deftransformp as_windowed_parse_non_reflect_opts(shape, opts) do
  opts = Keyword.validate!(opts, [:window_length, padding: :valid, stride: 1])
  window_length = opts[:window_length]
  padding = opts[:padding]
  kernel = {window_length}

  strides =
    case opts[:stride] do
      list when is_list(list) ->
        list

      int when is_integer(int) and int >= 1 ->
        [int]

      other ->
        raise ArgumentError,
              "expected an integer >= 1 or a list of integers, got: #{inspect(other)}"
    end

  # Exactly one stride is expected; any other list length fails this match.
  [stride] = strides

  padding_config = as_windowed_to_padding_config(shape, kernel, padding)

  # Trick: run a windowed reduction over a symbolic tensor so that Nx
  # computes the pooled shape (and therefore the number of windows) for us.
  pooled =
    Nx.window_max(
      Nx.iota(shape, backend: Nx.Defn.Expr),
      kernel,
      padding: padding,
      strides: strides
    )

  {window_length, stride, padding_config, {Tuple.product(pooled.shape), window_length}}
end
# Converts a padding mode into a list of `{before, after, interior}` tuples.
# `:valid` adds no padding.
defp as_windowed_to_padding_config(shape, _kernel_size, :valid) do
  List.duplicate({0, 0, 0}, tuple_size(shape))
end

# `:same` splits the padding between both sides, with the extra element
# (if any) going to the end.
defp as_windowed_to_padding_config(shape, kernel_size, :same) do
  for {dim, k} <- Enum.zip(Tuple.to_list(shape), Tuple.to_list(kernel_size)) do
    total = max(dim - 1 + k - dim, 0)
    before = div(total, 2)
    {before, total - before, 0}
  end
end

# An explicit configuration only gets the interior padding (0) appended.
defp as_windowed_to_padding_config(_shape, _kernel_size, config) when is_list(config) do
  for entry <- config do
    case entry do
      {x, y} when is_integer(x) and is_integer(y) ->
        {x, y, 0}

      _other ->
        raise ArgumentError,
              "padding must be a list of {high, low} tuples, where each element is an integer. " <>
                "Got: #{inspect(config)}"
    end
  end
end

defp as_windowed_to_padding_config(_shape, _kernel_size, mode) do
  raise ArgumentError,
        "invalid padding mode specified, padding must be one" <>
          " of :valid, :same, or a padding configuration, got:" <>
          " #{inspect(mode)}"
end
# Zero-pads the input according to the parsed options and extracts the windows.
# The current implementation only supports windowing 1D tensors.
defnp as_windowed_non_reflect_padding(tensor, opts \\ []) do
  {window_length, stride, padding_config, output_shape} =
    tensor
    |> Nx.shape()
    |> as_windowed_parse_non_reflect_opts(opts)

  tensor
  |> Nx.pad(0, padding_config)
  |> as_windowed_apply(stride, output_shape, window_length)
end
# Reflect-pads the input by half a window on each side and extracts the windows.
# The current implementation only supports windowing 1D tensors.
defnp as_windowed_reflect_padding(tensor, opts \\ []) do
  {window_length, stride, _padding_config, output_shape} =
    tensor
    |> Nx.shape()
    |> as_windowed_parse_reflect_opts(opts)

  half = div(window_length, 2)

  tensor
  |> Nx.reflect(padding_config: [{half, half}])
  |> as_windowed_apply(stride, output_shape, window_length)
end
# Copies `num_windows` slices of `window_length` samples out of the (already
# padded) 1D `tensor` into the rows of a fresh `output_shape` tensor.
# Consecutive windows start `stride` samples apart.
defnp as_windowed_apply(tensor, stride, output_shape, window_length) do
  # Zero-filled holder with the same element type as the input.
  output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
  {num_windows, _} = Nx.shape(output)
  # Give the holder the same vectorized axes as the input so both can be
  # carried together through the loop state.
  [output, tensor] = Nx.broadcast_vectors([output, tensor])
  {output, _, _, _} =
    # Loop state: output accumulator, start offset `i` into the input, row
    # index `current_window`, and the input itself (`t`), which is passed
    # through unchanged.
    while {output, i = 0, current_window = 0, t = tensor}, current_window < num_windows do
      window = t |> Nx.slice([i], [window_length])
      # Write the slice as row `current_window` of the output.
      updated = Nx.put_slice(output, [current_window, 0], Nx.new_axis(window, 0))
      {updated, i + stride, current_window + 1, t}
    end
  output
end
@doc """
Generates weights for converting an STFT representation into MEL-scale.
See also: `stft/3`, `istft/3`, `stft_to_mel/3`
## Arguments
* `fft_length` - Number of FFT bins
* `mel_bins` - Number of target MEL bins
* `sampling_rate` - Sampling frequency in Hz
## Options
* `:max_mel` - the pitch for the last MEL bin before log scaling. Defaults to 3016
* `:mel_frequency_spacing` - the distance in Hz between two MEL bins before log scaling. Defaults to 66.6
* `:type` - Target output type. Defaults to `{:f, 32}`
## Examples
iex> NxSignal.mel_filters(10, 5, 8.0e3)
#Nx.Tensor<
f32[mels: 5][frequencies: 10]
[
[0.0, 8.129207999445498e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 9.972016559913754e-4, 2.1870288765057921e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 9.510891977697611e-4, 4.150509194005281e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 4.035891906823963e-4, 5.276656011119485e-4, 2.574124082457274e-4, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 7.329034269787371e-5, 2.342205698369071e-4, 3.8295105332508683e-4, 2.8712040511891246e-4, 1.9128978601656854e-4, 9.545915963826701e-5]
]
>
"""
@doc type: :time_frequency
deftransform mel_filters(fft_length, mel_bins, sampling_rate, opts \\ []) do
  defaults = [max_mel: 3016, mel_frequency_spacing: 200 / 3, type: {:f, 32}]
  opts = Keyword.validate!(opts, defaults)

  # Sizes and the output type are passed as a keyword list, separately from
  # the numeric arguments of the private defn.
  static_opts = [type: opts[:type], fft_length: fft_length, mel_bins: mel_bins]

  mel_filters_n(sampling_rate, opts[:max_mel], opts[:mel_frequency_spacing], static_opts)
end
# Builds the `{mel_bins, fft_length}` matrix of triangular mel filter weights.
#
#   * `sampling_rate` - sampling frequency in Hz
#   * `max_mel` - pitch of the last mel point before log scaling
#   * `f_sp` - spacing in Hz between mel points in the linear region
#   * `opts` - `:fft_length`, `:mel_bins` and `:type`
defnp mel_filters_n(sampling_rate, max_mel, f_sp, opts) do
  fft_length = opts[:fft_length]
  mel_bins = opts[:mel_bins]
  type = opts[:type]
  # Center frequency (Hz) of every FFT bin.
  fftfreqs = fft_frequencies(sampling_rate, type: type, fft_length: fft_length)
  # mel_bins + 2 points: each triangular filter needs a left edge, a peak and a right edge.
  mels = Nx.linspace(0, max_mel / f_sp, type: type, n: mel_bins + 2, name: :mels)
  # Linear part of the scale: frequency grows proportionally to the mel value.
  freqs = f_sp * mels
  # From 1 kHz upwards the scale becomes logarithmic.
  min_log_hz = 1_000
  min_log_mel = min_log_hz / f_sp
  # numpy uses the f64 value by default
  logstep = Nx.log(6.4) / 27
  # Mask of the mel points that fall in the logarithmic region.
  log_t = mels >= min_log_mel
  # This is the same as freqs[log_t] = min_log_hz * Nx.exp(logstep * (mels[log_t] - min_log_mel))
  # notice that since freqs and mels are indexed by the same conditional tensor, we don't
  # need to slice either of them
  mel_f = Nx.select(log_t, min_log_hz * Nx.exp(logstep * (mels - min_log_mel)), freqs)
  # Distance in Hz between consecutive mel points, as a column vector.
  fdiff = Nx.new_axis(mel_f[1..-1//1] - mel_f[0..-2//1], 1)
  # ramps[i][j] = mel_f[i] - fftfreqs[j]
  ramps = Nx.new_axis(mel_f, 1) - fftfreqs
  # Rising slope of each triangle: (f - left_edge) / (peak - left_edge).
  lower = -ramps[0..(mel_bins - 1)] / fdiff[0..(mel_bins - 1)]
  # Falling slope of each triangle: (right_edge - f) / (right_edge - peak).
  upper = ramps[2..(mel_bins + 1)//1] / fdiff[1..mel_bins]
  # The triangle is the lower of both slopes, clamped at zero outside its support.
  weights = Nx.max(0, Nx.min(lower, upper))
  # Scale each filter by 2 / (right_edge - left_edge), i.e. by the inverse of its bandwidth.
  enorm = 2.0 / (mel_f[2..(mel_bins + 1)] - mel_f[0..(mel_bins - 1)])
  weights * Nx.new_axis(enorm, 1)
end
@doc """
Converts a given STFT time-frequency spectrum into a MEL-scale time-frequency spectrum.
See also: `stft/3`, `istft/3`, `mel_filters/4`
## Arguments
* `z` - STFT spectrum
* `sampling_rate` - Sampling frequency in Hz
## Options
* `:fft_length` - Number of FFT bins
* `:mel_bins` - Number of target MEL bins. Defaults to 128
* `:type` - Target output type. Defaults to `{:f, 32}`
## Examples
iex> fft_length = 16
iex> sampling_rate = 8.0e3
iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(n: 4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
iex> Nx.axis_size(z, :frequencies)
16
iex> Nx.axis_size(z, :frames)
6
iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
#Nx.Tensor<
f32[frames: 6][mel: 4]
[
[0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
[0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
[0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
[0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
]
>
"""
@doc type: :time_frequency
defn stft_to_mel(z, sampling_rate, opts \\ []) do
  opts =
    keyword!(opts, [
      :fft_length,
      :max_mel,
      :mel_frequency_spacing,
      # Documented default. Without it, omitting :mel_bins passed nil down to
      # mel_filters/4, which then failed on `mel_bins + 2`.
      mel_bins: 128,
      type: {:f, 32}
    ])

  # Power spectrum of every frame.
  magnitudes = Nx.abs(z) ** 2

  filters =
    mel_filters(opts[:fft_length], opts[:mel_bins], sampling_rate, mel_filters_opts(opts))

  # Only the first half of the FFT bins is used, for both the spectrum and the filters.
  freq_size = div(opts[:fft_length], 2)

  real_freqs_mag = Nx.slice_along_axis(magnitudes, 0, freq_size, axis: :frequencies)
  real_freqs_filters = Nx.slice_along_axis(filters, 0, freq_size, axis: :frequencies)

  # Contract the :frequencies axis: {frames, freqs} x {mels, freqs} -> {frames, mels}.
  mel_spec =
    Nx.dot(
      real_freqs_mag,
      [:frequencies],
      real_freqs_filters,
      [:frequencies]
    )

  mel_spec = Nx.reshape(mel_spec, Nx.shape(mel_spec), names: [:frames, :mel])

  # log10 of the mel power, with a floor of 1.0e-10 to avoid log(0).
  log_spec = Nx.log(Nx.clip(mel_spec, 1.0e-10, :infinity)) / Nx.log(10)
  # Limit the dynamic range to 8 (in log10 units) below the global maximum.
  log_spec = Nx.max(log_spec, Nx.reduce_max(log_spec) - 8)
  # Final affine rescaling of the log spectrum.
  (log_spec + 4) / 4
end
# Keeps only the options that mel_filters/4 accepts.
deftransformp mel_filters_opts(opts) do
  Keyword.take(opts, ~w(max_mel mel_frequency_spacing type)a)
end
@doc ~S"""
Computes the Inverse Short-Time Fourier Transform of a tensor.
Returns the reconstructed time-domain signal, obtained by overlap-and-adding the M inverse-transformed frames.
See also: `NxSignal.Windows`, `stft/3`
## Options
* `:fft_length` - the DFT length that will be passed to `Nx.ifft/2`. Defaults to `:power_of_two`.
* `:overlap_length` - the number of samples for the overlap between frames.
Defaults to half the window size.
* `:sampling_rate` - the sampling rate $F_s$ in Hz. Defaults to `1000`.
* `:scaling` - `nil`, `:spectrum` or `:psd`.
* `:spectrum` - each frame is multiplied by $\sum_{i} window[i]$.
* `nil` - No scaling is applied.
* `:psd` - each frame is multiplied by $\sqrt{F\_s\sum_{i} window[i]^2}$.
## Examples
In general, `istft/3` takes in the same parameters and window as the `stft/3` that generated the spectrum.
In the first example, we can notice that the reconstruction is mostly perfect, aside from the first sample.
This is because the Hann window only ensures perfect reconstruction in overlapping regions, so the edges
of the signal end up being distorted.
iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
iex> w = NxSignal.Windows.hann(n: 4)
iex> opts = [sampling_rate: 1, fft_length: 4]
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
iex> result = NxSignal.istft(z, w, opts)
iex> Nx.as_type(result, Nx.type(t))
#Nx.Tensor<
s64[8]
[0, 10, 1, 0, 10, 10, 2, 20]
>
Different scaling options are available (see `stft/3` for a more detailed explanation).
For perfect reconstruction, you want to use the same scaling as the STFT:
iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
iex> w = NxSignal.Windows.hann(n: 4)
iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4]
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
iex> result = NxSignal.istft(z, w, opts)
iex> Nx.as_type(result, Nx.type(t))
#Nx.Tensor<
s64[8]
[0, 10, 1, 0, 10, 10, 2, 20]
>
iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32)
iex> w = NxSignal.Windows.hann(n: 4)
iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4]
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
iex> result = NxSignal.istft(z, w, opts)
iex> Nx.as_type(result, Nx.type(t))
#Nx.Tensor<
f32[8]
[0.0, 10.0, 0.9999999403953552, -2.1900146407460852e-7, 10.0, 10.0, 2.000000238418579, 20.0]
>
"""
@doc type: :time_frequency
defn istft(data, window, opts) do
  opts = keyword!(opts, [:fft_length, :overlap_length, :scaling, sampling_rate: 1000])
  scaling = opts[:scaling]

  ifft_length =
    case opts[:fft_length] do
      nil -> :power_of_two
      given -> given
    end

  # Default overlap is half of the window.
  overlap =
    case opts[:overlap_length] do
      nil -> div(Nx.size(window), 2)
      given -> given
    end

  rate =
    case {scaling, opts[:sampling_rate]} do
      {:psd, nil} -> raise ArgumentError, ":sampling_rate is mandatory if scaling is :psd"
      {_, given} -> given
    end

  time_frames = Nx.ifft(data, length: ifft_length)

  # Undo the scaling applied by the forward transform.
  unscaled_frames =
    case scaling do
      nil ->
        time_frames

      :spectrum ->
        time_frames * Nx.sum(window)

      :psd ->
        time_frames * Nx.sqrt(rate * Nx.sum(window ** 2))

      other ->
        raise ArgumentError,
              "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(other)}"
    end

  signal = overlap_and_add(unscaled_frames * window, overlap_length: overlap)

  # Overlap-added squared window: the weight each output sample received.
  window_energy =
    window
    |> Nx.abs()
    |> Nx.pow(2)
    |> Nx.broadcast(Nx.shape(data))
    |> overlap_and_add(overlap_length: overlap)

  # Samples that received (almost) no window weight are left unnormalized.
  safe_energy = Nx.select(window_energy > 1.0e-10, window_energy, 1.0)

  signal / safe_energy
end
@doc """
Performs the overlap-and-add algorithm over
an {..., M, N}-shaped tensor, where M is the number of
windows and N is the window size.
The tensor is zero-padded on the right so
the last window fully appears in the result.
## Options
* `:overlap_length` - The number of overlapping samples between windows
* `:type` - output type for casting the accumulated result.
If not given, defaults to the type of the input tensor.
## Examples
iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 0)
#Nx.Tensor<
s64[12]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
>
iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 3)
#Nx.Tensor<
s64[6]
[0, 5, 15, 18, 17, 11]
>
iex> t = Nx.tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[10, 11, 12, 13], [14, 15, 16, 17]]]]) |> Nx.vectorize(x: 2, y: 1)
iex> NxSignal.overlap_and_add(t, overlap_length: 3)
#Nx.Tensor<
vectorized[x: 2][y: 1]
s64[5]
[
[
[0, 5, 7, 9, 7]
],
[
[10, 25, 27, 29, 17]
]
]
>
"""
@doc type: :windowing
defn overlap_and_add(tensor, opts \\ []) do
  opts = keyword!(opts, [:overlap_length, type: Nx.type(tensor)])
  overlap_length = opts[:overlap_length]

  # Keep the caller's vectorized axes and inner shape so they can be restored at the end.
  %{vectorized_axes: vectorized_axes, shape: input_shape} = tensor
  num_windows = Nx.axis_size(tensor, -2)
  window_length = Nx.axis_size(tensor, -1)

  if overlap_length >= window_length do
    # Report the offending overlap_length (the message used to print the
    # window size twice).
    raise ArgumentError,
          "overlap_length must be a number less than the window size #{window_length}, got: #{inspect(overlap_length)}"
  end

  # Fold every leading axis into :condensed_vectors and make each window a
  # vectorized entry holding `window_length` samples.
  tensor =
    Nx.revectorize(tensor, [condensed_vectors: :auto, windows: num_windows],
      target_shape: {window_length}
    )

  stride = window_length - overlap_length
  output_holder_shape = {num_windows * stride + overlap_length}

  out =
    Nx.broadcast(
      Nx.tensor(0, type: tensor.type),
      output_holder_shape
    )

  # Destination index of each sample: its offset inside the window plus the
  # start of that window (window number * stride).
  idx_template = Nx.iota({window_length, 1}, vectorized_axes: [windows: 1])
  i = Nx.iota({num_windows}) |> Nx.vectorize(:windows)
  idx = idx_template + i * stride

  [%{vectorized_axes: [condensed_vectors: n, windows: _]} = tensor, idx] =
    Nx.broadcast_vectors([tensor, idx])

  # Flatten the windows axis so that values and indices line up one-to-one.
  tensor = Nx.revectorize(tensor, [condensed_vectors: n], target_shape: {:auto})
  idx = Nx.revectorize(idx, [condensed_vectors: n], target_shape: {:auto, 1})

  out_shape = overlap_and_add_output_shape(out.shape, input_shape)

  # Samples from overlapping windows target the same index and are summed there.
  out
  |> Nx.indexed_add(idx, tensor)
  |> Nx.as_type(opts[:type])
  |> Nx.revectorize(vectorized_axes, target_shape: out_shape)
end
# Replaces the trailing `{windows, window_length}` pair of `in_shape` with the
# single overlap-added length `out_len`, keeping any leading dimensions.
deftransformp overlap_and_add_output_shape({out_len}, in_shape) do
  windows_axis = tuple_size(in_shape) - 2

  in_shape
  |> Tuple.delete_at(windows_axis)
  |> Tuple.delete_at(windows_axis)
  |> Tuple.insert_at(windows_axis, out_len)
end
end