defmodule NxSignal do
  @moduledoc """
  Nx library extension for digital signal processing.
  """

  import Nx.Defn

  @doc ~S"""
  Computes the Short-Time Fourier Transform of a tensor.

  Returns the complex spectrum Z, the time in seconds for
  each frame and the frequency bins in Hz.

  The STFT is parameterized through:

    * $k$: length of the Discrete Fourier Transform (DFT)
    * $N$: length of each frame
    * $H$: hop (in samples) between frames (calculated as $H = N - \text{overlap\\_length}$)
    * $M$: number of frames
    * $x[n]$: the input time-domain signal
    * $w[n]$: the window function to be applied to each frame

  $$
  DFT(x, w) := \sum_{n=0}^{N - 1} x[n]w[n]e^\frac{-2 \pi i k n}{N} \\\\
  X[m, k] = DFT(x[mH..(mH + N - 1)], w)
  $$

  where $m$ assumes all values in the interval $[0, M - 1]$

  See also: `NxSignal.Windows`, `istft/3`, `stft_to_mel/3`

  ## Options

    * `:sampling_rate` - the sampling frequency $F_s$ for the input in Hz. Defaults to `100`.
      Note that `istft/3` defaults to `1000`, so pass the same value explicitly to both
      functions when round-tripping with `scaling: :psd`.
    * `:fft_length` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`.
    * `:overlap_length` - the number of samples for the overlap between frames.
      Defaults to half the window size.
    * `:window_padding` - `:reflect`, `:valid`, `:same` or a list of `{low, high}` tuples.
      Forwarded as the `:padding` option of `as_windowed/2`. Defaults to `:valid`.
    * `:scaling` - `nil`, `:spectrum` or `:psd`.
      * `:spectrum` - each frame is divided by $\sum_{i} window[i]$.
      * `nil` - No scaling is applied.
      * `:psd` - each frame is divided by $\sqrt{F\_s\sum_{i} window[i]^2}$.

  The `:window` key is also accepted, but it is not read by this function.

  ## Examples

      iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(n: 2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
      iex> z
      #Nx.Tensor<
        c64[frames: 3][frequencies: 2]
        [
          [1.0+0.0i, -1.0+0.0i],
          [3.0+0.0i, -1.0+0.0i],
          [5.0+0.0i, -1.0+0.0i]
        ]
      >
      iex> t
      #Nx.Tensor<
        f32[frames: 3]
        [0.0024999999441206455, 0.004999999888241291, 0.007499999832361937]
      >
      iex> f
      #Nx.Tensor<
        f32[frequencies: 2]
        [0.0, 200.0]
      >
  """
  @doc type: :time_frequency
  deftransform stft(data, window, opts \\ []) do
    {frame_length} = Nx.shape(window)

    opts =
      Keyword.validate!(opts, [
        :overlap_length,
        :window,
        :scaling,
        window_padding: :valid,
        sampling_rate: 100,
        fft_length: :power_of_two
      ])

    # An explicit `sampling_rate: nil` survives validation, hence the check.
    sampling_rate = opts[:sampling_rate] || raise ArgumentError, "missing sampling_rate option"
    overlap_length = opts[:overlap_length] || div(frame_length, 2)

    stft_n(data, window, sampling_rate, Keyword.put(opts, :overlap_length, overlap_length))
  end

  # Numerical body of `stft/3`: frames the signal, applies the window, runs the FFT,
  # builds the time/frequency axes and applies the requested scaling.
  defnp stft_n(data, window, sampling_rate, opts) do
    {frame_length} = Nx.shape(window)
    padding = opts[:window_padding]
    fft_length = opts[:fft_length]
    overlap_length = opts[:overlap_length]

    spectrum =
      data
      |> as_windowed(
        padding: padding,
        window_length: frame_length,
        stride: frame_length - overlap_length
      )
      |> Nx.multiply(window)
      |> Nx.fft(length: fft_length)

    # From here on `fft_length` is the concrete number of bins (the option may be `:power_of_two`).
    {num_frames, fft_length} = Nx.shape(spectrum)

    frequencies = fft_frequencies(sampling_rate, fft_length: fft_length)

    # assign the middle of the equivalent time window as the time for the given frame
    time_step = frame_length / (2 * sampling_rate)
    last_frame = time_step * num_frames
    times = Nx.linspace(time_step, last_frame, n: num_frames, name: :frames)

    output =
      case opts[:scaling] do
        :spectrum ->
          spectrum / Nx.sum(window)

        :psd ->
          spectrum / Nx.sqrt(sampling_rate * Nx.sum(window ** 2))

        nil ->
          spectrum

        scaling ->
          raise ArgumentError,
                "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}"
      end

    {Nx.reshape(output, spectrum.shape, names: [:frames, :frequencies]), times, frequencies}
  end

  @doc """
  Computes the frequency bins for a FFT with given options.

  ## Arguments

    * `sampling_rate` - Sampling frequency in Hz.

  ## Options

    * `:fft_length` - Number of FFT frequency bins. Required.
    * `:type` - Optional output type. Defaults to `{:f, 32}`
    * `:name` - Optional axis name for the tensor. Defaults to `:frequencies`
    * `:endpoint` - Forwarded to `Nx.linspace/3`; controls whether the upper bound
      (`sampling_rate`) is included as the last bin. Defaults to `false`

  ## Examples

      iex> NxSignal.fft_frequencies(1.6e4, fft_length: 10)
      #Nx.Tensor<
        f32[frequencies: 10]
        [0.0, 1.6e3, 3.2e3, 4.8e3, 6.4e3, 8.0e3, 9.6e3, 1.12e4, 1.28e4, 1.44e4]
      >
  """
  @doc type: :time_frequency
  defn fft_frequencies(sampling_rate, opts \\ []) do
    opts = keyword!(opts, [:fft_length, type: {:f, 32}, name: :frequencies, endpoint: false])
    fft_length = opts[:fft_length]

    step = sampling_rate / fft_length

    Nx.linspace(0, step * fft_length,
      n: fft_length,
      type: opts[:type],
      name: opts[:name],
      endpoint: opts[:endpoint]
    )
  end

  @doc """
  Returns a tensor of K windows of length N

  ## Options

    * `:window_length` - the number of samples in a window
    * `:stride` - The number of samples to skip between windows. Defaults to `1`.
    * `:padding` - Padding mode, can be `:reflect`, `:valid`, `:same` or a list with
      one `{low, high}` tuple of integers. Defaults to `:valid`.
      With `:reflect`, the input is reflected by `div(window_length, 2)` samples on
      each side, so the first window is centered at the start of the signal.
      The padding is applied for the whole input, rather than individual windows.
      An explicit `{low, high}` configuration pads with zeros, so incomplete trailing
      windows end up zero-padded.

  ## Examples

      iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 4)
      #Nx.Tensor<
        s64[5][4]
        [
          [0, 1, 2, 3],
          [1, 2, 3, 4],
          [2, 3, 4, 10],
          [3, 4, 10, 11],
          [4, 10, 11, 12]
        ]
      >

      iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 3)
      #Nx.Tensor<
        s64[6][3]
        [
          [0, 1, 2],
          [1, 2, 3],
          [2, 3, 4],
          [3, 4, 10],
          [4, 10, 11],
          [10, 11, 12]
        ]
      >

      iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11]), window_length: 2, stride: 2, padding: [{0, 3}])
      #Nx.Tensor<
        s64[5][2]
        [
          [0, 1],
          [2, 3],
          [4, 10],
          [11, 0],
          [0, 0]
        ]
      >

      iex> t = Nx.iota({7});
      iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
      #Nx.Tensor<
        s64[8][6]
        [
          [3, 2, 1, 0, 1, 2],
          [2, 1, 0, 1, 2, 3],
          [1, 0, 1, 2, 3, 4],
          [0, 1, 2, 3, 4, 5],
          [1, 2, 3, 4, 5, 6],
          [2, 3, 4, 5, 6, 5],
          [3, 4, 5, 6, 5, 4],
          [4, 5, 6, 5, 4, 3]
        ]
      >

      iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
      #Nx.Tensor<
        s64[6][6]
        [
          [3, 2, 1, 0, 1, 2],
          [1, 0, 1, 2, 3, 4],
          [1, 2, 3, 4, 5, 6],
          [3, 4, 5, 6, 7, 8],
          [5, 6, 7, 8, 9, 8],
          [7, 8, 9, 8, 7, 6]
        ]
      >
  """
  @doc type: :windowing
  deftransform as_windowed(tensor, opts \\ []) do
    if opts[:padding] == :reflect do
      as_windowed_reflect_padding(tensor, opts)
    else
      as_windowed_non_reflect_padding(tensor, opts)
    end
  end

  # Reflect padding is sized like an explicit half-window padding on both sides, so the
  # option parsing (and output shape computation) is delegated to the non-reflect parser.
  deftransformp as_windowed_parse_reflect_opts(shape, opts) do
    window_length = opts[:window_length]

    as_windowed_parse_non_reflect_opts(
      shape,
      Keyword.put(opts, :padding, [{div(window_length, 2), div(window_length, 2)}])
    )
  end

  # Validates the windowing options and returns
  # `{window_length, stride, padding_config, output_shape}`.
  deftransformp as_windowed_parse_non_reflect_opts(shape, opts) do
    opts = Keyword.validate!(opts, [:window_length, padding: :valid, stride: 1])
    window_length = opts[:window_length]
    window_dimensions = {window_length}

    padding = opts[:padding]

    # Only a single stride is supported; a list must therefore have exactly one element.
    [stride] =
      strides =
      case opts[:stride] do
        stride when is_list(stride) ->
          stride

        stride when is_integer(stride) and stride >= 1 ->
          [stride]

        stride ->
          raise ArgumentError,
                "expected an integer >= 1 or a list of integers, got: #{inspect(stride)}"
      end

    padding_config = as_windowed_to_padding_config(shape, window_dimensions, padding)

    # trick so that we can get Nx to calculate the pooled shape for us
    %{shape: pooled_shape} =
      Nx.window_max(
        Nx.iota(shape, backend: Nx.Defn.Expr),
        window_dimensions,
        padding: padding,
        strides: strides
      )

    output_shape = {Tuple.product(pooled_shape), window_length}

    {window_length, stride, padding_config, output_shape}
  end

  # Converts a padding mode into a list of `{low, high, interior}` tuples,
  # always using an interior padding of 0.
  defp as_windowed_to_padding_config(shape, kernel_size, mode) do
    case mode do
      :valid ->
        List.duplicate({0, 0, 0}, tuple_size(shape))

      :same ->
        Enum.zip_with(Tuple.to_list(shape), Tuple.to_list(kernel_size), fn dim, k ->
          padding_size = max(dim - 1 + k - dim, 0)
          {floor(padding_size / 2), ceil(padding_size / 2), 0}
        end)

      config when is_list(config) ->
        Enum.map(config, fn
          {x, y} when is_integer(x) and is_integer(y) ->
            {x, y, 0}

          _other ->
            raise ArgumentError,
                  "padding must be a list of {low, high} tuples, where each element is an integer. " <>
                    "Got: #{inspect(config)}"
        end)

      mode ->
        raise ArgumentError,
              "invalid padding mode specified, padding must be one" <>
                " of :valid, :same, or a padding configuration, got:" <>
                " #{inspect(mode)}"
    end
  end

  # Windows `tensor` after zero-padding it according to the parsed padding configuration.
  defnp as_windowed_non_reflect_padding(tensor, opts \\ []) do
    # current implementation only supports windowing 1D tensors
    {window_length, stride, padding, output_shape} =
      as_windowed_parse_non_reflect_opts(Nx.shape(tensor), opts)

    tensor = Nx.pad(tensor, 0, padding)

    as_windowed_apply(tensor, stride, output_shape, window_length)
  end

  # Windows `tensor` after reflecting half a window of samples on each side.
  defnp as_windowed_reflect_padding(tensor, opts \\ []) do
    # current implementation only supports windowing 1D tensors
    {window_length, stride, _padding, output_shape} =
      as_windowed_parse_reflect_opts(Nx.shape(tensor), opts)

    half_window = div(window_length, 2)
    tensor = Nx.reflect(tensor, padding_config: [{half_window, half_window}])

    as_windowed_apply(tensor, stride, output_shape, window_length)
  end

  # Copies one `window_length`-sized slice per output row, advancing the start
  # offset by `stride` samples for each row.
  defnp as_windowed_apply(tensor, stride, output_shape, window_length) do
    output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
    {num_windows, _} = Nx.shape(output)

    [output, tensor] = Nx.broadcast_vectors([output, tensor])

    {output, _, _, _} =
      while {output, i = 0, current_window = 0, t = tensor}, current_window < num_windows do
        window = t |> Nx.slice([i], [window_length])
        updated = Nx.put_slice(output, [current_window, 0], Nx.new_axis(window, 0))
        {updated, i + stride, current_window + 1, t}
      end

    output
  end

  @doc """
  Generates weights for converting an STFT representation into MEL-scale.

  See also: `stft/3`, `istft/3`, `stft_to_mel/3`

  ## Arguments

    * `fft_length` - Number of FFT bins
    * `mel_bins` - Number of target MEL bins
    * `sampling_rate` - Sampling frequency in Hz

  ## Options

    * `:max_mel` - the pitch for the last MEL bin before log scaling. Defaults to 3016
    * `:mel_frequency_spacing` - the distance in Hz between two MEL bins before log scaling. Defaults to 66.6
    * `:type` - Target output type. Defaults to `{:f, 32}`

  ## Examples

      iex> NxSignal.mel_filters(10, 5, 8.0e3)
      #Nx.Tensor<
        f32[mels: 5][frequencies: 10]
        [
          [0.0, 8.129207999445498e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 9.972016559913754e-4, 2.1870288765057921e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 9.510891977697611e-4, 4.150509194005281e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 4.035891906823963e-4, 5.276656011119485e-4, 2.574124082457274e-4, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.0, 7.329034269787371e-5, 2.342205698369071e-4, 3.8295105332508683e-4, 2.8712040511891246e-4, 1.9128978601656854e-4, 9.545915963826701e-5]
        ]
      >
  """
  @doc type: :time_frequency
  deftransform mel_filters(fft_length, mel_bins, sampling_rate, opts \\ []) do
    opts =
      Keyword.validate!(opts,
        max_mel: 3016,
        mel_frequency_spacing: 200 / 3,
        type: {:f, 32}
      )

    mel_filters_n(sampling_rate, opts[:max_mel], opts[:mel_frequency_spacing],
      type: opts[:type],
      fft_length: fft_length,
      mel_bins: mel_bins
    )
  end

  # Numerical body of `mel_filters/4`: builds triangular filters over the FFT bins,
  # with filter edges linearly spaced below 1 kHz and log-spaced above it.
  defnp mel_filters_n(sampling_rate, max_mel, f_sp, opts) do
    fft_length = opts[:fft_length]
    mel_bins = opts[:mel_bins]
    type = opts[:type]

    fftfreqs = fft_frequencies(sampling_rate, type: type, fft_length: fft_length)

    # mel_bins + 2 edges: each triangular filter needs a lower, center and upper edge
    mels = Nx.linspace(0, max_mel / f_sp, type: type, n: mel_bins + 2, name: :mels)
    freqs = f_sp * mels

    min_log_hz = 1_000
    min_log_mel = min_log_hz / f_sp

    # numpy uses the f64 value by default
    logstep = Nx.log(6.4) / 27

    log_t = mels >= min_log_mel

    # This is the same as freqs[log_t] = min_log_hz * Nx.exp(logstep * (mels[log_t] - min_log_mel))
    # notice that since freqs and mels are indexed by the same conditional tensor, we don't
    # need to slice either of them
    mel_f = Nx.select(log_t, min_log_hz * Nx.exp(logstep * (mels - min_log_mel)), freqs)

    fdiff = Nx.new_axis(mel_f[1..-1//1] - mel_f[0..-2//1], 1)
    ramps = Nx.new_axis(mel_f, 1) - fftfreqs

    lower = -ramps[0..(mel_bins - 1)] / fdiff[0..(mel_bins - 1)]
    upper = ramps[2..(mel_bins + 1)//1] / fdiff[1..mel_bins]
    weights = Nx.max(0, Nx.min(lower, upper))

    # scale each filter by 2 / (upper edge - lower edge)
    enorm = 2.0 / (mel_f[2..(mel_bins + 1)] - mel_f[0..(mel_bins - 1)])

    weights * Nx.new_axis(enorm, 1)
  end

  @doc """
  Converts a given STFT time-frequency spectrum into a MEL-scale time-frequency spectrum.

  See also: `stft/3`, `istft/3`, `mel_filters/4`

  ## Arguments

    * `z` - STFT spectrum
    * `sampling_rate` - Sampling frequency in Hz

  ## Options

    * `:fft_length` - Number of FFT bins
    * `:mel_bins` - Number of target MEL bins. Defaults to 128
    * `:max_mel` - Forwarded to `mel_filters/4` when given
    * `:mel_frequency_spacing` - Forwarded to `mel_filters/4` when given
    * `:type` - Target output type. Defaults to `{:f, 32}`

  ## Examples

      iex> fft_length = 16
      iex> sampling_rate = 8.0e3
      iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(n: 4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
      iex> Nx.axis_size(z, :frequencies)
      16
      iex> Nx.axis_size(z, :frames)
      6
      iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
      #Nx.Tensor<
        f32[frames: 6][mel: 4]
        [
          [0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
          [0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
          [0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
          [0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
          [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
          [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
        ]
      >
  """
  @doc type: :time_frequency
  defn stft_to_mel(z, sampling_rate, opts \\ []) do
    # `mel_bins: 128` implements the documented default; without it an omitted
    # `:mel_bins` reached `mel_filters/4` as `nil`.
    opts =
      keyword!(opts, [
        :fft_length,
        :max_mel,
        :mel_frequency_spacing,
        mel_bins: 128,
        type: {:f, 32}
      ])

    magnitudes = Nx.abs(z) ** 2

    filters =
      mel_filters(opts[:fft_length], opts[:mel_bins], sampling_rate, mel_filters_opts(opts))

    # only the first half of the FFT bins is used for the projection
    freq_size = div(opts[:fft_length], 2)

    real_freqs_mag = Nx.slice_along_axis(magnitudes, 0, freq_size, axis: :frequencies)
    real_freqs_filters = Nx.slice_along_axis(filters, 0, freq_size, axis: :frequencies)

    mel_spec =
      Nx.dot(
        real_freqs_mag,
        [:frequencies],
        real_freqs_filters,
        [:frequencies]
      )

    mel_spec = Nx.reshape(mel_spec, Nx.shape(mel_spec), names: [:frames, :mel])

    # log10 of the clipped spectrum, floored at 8 below the maximum and then rescaled
    log_spec = Nx.log(Nx.clip(mel_spec, 1.0e-10, :infinity)) / Nx.log(10)
    log_spec = Nx.max(log_spec, Nx.reduce_max(log_spec) - 8)
    (log_spec + 4) / 4
  end

  # Selects the options that `mel_filters/4` understands.
  deftransformp mel_filters_opts(opts) do
    Keyword.take(opts, [:max_mel, :mel_frequency_spacing, :type])
  end

  @doc ~S"""
  Computes the Inverse Short-Time Fourier Transform of a tensor.

  Returns the time-domain signal reconstructed by inverse-transforming each frame,
  re-applying the window and combining the frames with `overlap_and_add/2`.

  See also: `NxSignal.Windows`, `stft/3`

  ## Options

    * `:fft_length` - the DFT length that will be passed to `Nx.ifft/2`. Defaults to `:power_of_two`.
    * `:overlap_length` - the number of samples for the overlap between frames.
      Defaults to half the window size.
    * `:sampling_rate` - the sampling rate $F_s$ in Hz. Defaults to `1000`.
    * `:scaling` - `nil`, `:spectrum` or `:psd`.
      * `:spectrum` - each frame is multiplied by $\sum_{i} window[i]$.
      * `nil` - No scaling is applied.
      * `:psd` - each frame is multiplied by $\sqrt{F\_s\sum_{i} window[i]^2}$.

  ## Examples

  In general, `istft/3` takes in the same parameters and window as the `stft/3` that generated the spectrum.
  In the first example, we can notice that the reconstruction is mostly perfect, aside from the first sample.

  This is because the Hann window only ensures perfect reconstruction in overlapping regions, so the edges
  of the signal end up being distorted.

      iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
      iex> w = NxSignal.Windows.hann(n: 4)
      iex> opts = [sampling_rate: 1, fft_length: 4]
      iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
      iex> result = NxSignal.istft(z, w, opts)
      iex> Nx.as_type(result, Nx.type(t))
      #Nx.Tensor<
        s64[8]
        [0, 10, 1, 0, 10, 10, 2, 20]
      >

  Different scaling options are available (see `stft/3` for a more detailed explanation).
  For perfect reconstruction, you want to use the same scaling as the STFT:

      iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
      iex> w = NxSignal.Windows.hann(n: 4)
      iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4]
      iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
      iex> result = NxSignal.istft(z, w, opts)
      iex> Nx.as_type(result, Nx.type(t))
      #Nx.Tensor<
        s64[8]
        [0, 10, 1, 0, 10, 10, 2, 20]
      >

      iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32)
      iex> w = NxSignal.Windows.hann(n: 4)
      iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4]
      iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
      iex> result = NxSignal.istft(z, w, opts)
      iex> Nx.as_type(result, Nx.type(t))
      #Nx.Tensor<
        f32[8]
        [0.0, 10.0, 0.9999999403953552, -2.1900146407460852e-7, 10.0, 10.0, 2.000000238418579, 20.0]
      >
  """
  @doc type: :time_frequency
  defn istft(data, window, opts) do
    opts = keyword!(opts, [:fft_length, :overlap_length, :scaling, sampling_rate: 1000])

    fft_length =
      case opts[:fft_length] do
        nil -> :power_of_two
        fft_length -> fft_length
      end

    overlap_length =
      case opts[:overlap_length] do
        nil -> div(Nx.size(window), 2)
        overlap_length -> overlap_length
      end

    # `:sampling_rate` can only be nil here when the caller passes it explicitly as nil.
    sampling_rate =
      case {opts[:scaling], opts[:sampling_rate]} do
        {:psd, nil} -> raise ArgumentError, ":sampling_rate is mandatory if scaling is :psd"
        {_, sampling_rate} -> sampling_rate
      end

    frames = Nx.ifft(data, length: fft_length)

    frames_rescaled =
      case opts[:scaling] do
        :spectrum ->
          frames * Nx.sum(window)

        :psd ->
          frames * Nx.sqrt(sampling_rate * Nx.sum(window ** 2))

        nil ->
          frames

        scaling ->
          raise ArgumentError,
                "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}"
      end

    result_non_normalized =
      overlap_and_add(frames_rescaled * window, overlap_length: overlap_length)

    # Overlap-adding the squared window gives the per-sample gain applied by
    # windowing twice (once in the STFT, once above).
    normalization_factor =
      overlap_and_add(Nx.broadcast(Nx.abs(window) ** 2, data.shape),
        overlap_length: overlap_length
      )

    # Avoid dividing by (near) zero where the windows contribute no energy.
    normalization_factor = Nx.select(normalization_factor > 1.0e-10, normalization_factor, 1.0)

    result_non_normalized / normalization_factor
  end

  @doc """
  Performs the overlap-and-add algorithm over
  an {..., M, N}-shaped tensor, where M is the number of
  windows and N is the window size.

  The tensor is zero-padded on the right so
  the last window fully appears in the result.

  ## Options

    * `:overlap_length` - The number of overlapping samples between windows.
      Must be less than the window size.
    * `:type` - output type for casting the accumulated result.
      If not given, defaults to the type of the input tensor.

  ## Examples

      iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 0)
      #Nx.Tensor<
        s64[12]
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
      >

      iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 3)
      #Nx.Tensor<
        s64[6]
        [0, 5, 15, 18, 17, 11]
      >

      iex> t = Nx.tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[10, 11, 12, 13], [14, 15, 16, 17]]]]) |> Nx.vectorize(x: 2, y: 1)
      iex> NxSignal.overlap_and_add(t, overlap_length: 3)
      #Nx.Tensor<
        vectorized[x: 2][y: 1]
        s64[5]
        [
          [
            [0, 5, 7, 9, 7]
          ],
          [
            [10, 25, 27, 29, 17]
          ]
        ]
      >
  """
  @doc type: :windowing
  defn overlap_and_add(tensor, opts \\ []) do
    opts = keyword!(opts, [:overlap_length, type: Nx.type(tensor)])
    overlap_length = opts[:overlap_length]

    %{vectorized_axes: vectorized_axes, shape: input_shape} = tensor
    num_windows = Nx.axis_size(tensor, -2)
    window_length = Nx.axis_size(tensor, -1)

    if overlap_length >= window_length do
      raise ArgumentError,
            "overlap_length must be a number less than the window size #{window_length}, got: #{inspect(overlap_length)}"
    end

    # Collapse every leading axis into `condensed_vectors` and vectorize over windows,
    # so each vectorized entry is a single window of `window_length` samples.
    tensor =
      Nx.revectorize(tensor, [condensed_vectors: :auto, windows: num_windows],
        target_shape: {window_length}
      )

    stride = window_length - overlap_length
    output_holder_shape = {num_windows * stride + overlap_length}

    out =
      Nx.broadcast(
        Nx.tensor(0, type: tensor.type),
        output_holder_shape
      )

    # Destination index of every sample: position inside its window plus the window offset.
    idx_template = Nx.iota({window_length, 1}, vectorized_axes: [windows: 1])
    i = Nx.iota({num_windows}) |> Nx.vectorize(:windows)
    idx = idx_template + i * stride

    [%{vectorized_axes: [condensed_vectors: n, windows: _]} = tensor, idx] =
      Nx.broadcast_vectors([tensor, idx])

    tensor = Nx.revectorize(tensor, [condensed_vectors: n], target_shape: {:auto})
    idx = Nx.revectorize(idx, [condensed_vectors: n], target_shape: {:auto, 1})

    out_shape = overlap_and_add_output_shape(out.shape, input_shape)

    out
    |> Nx.indexed_add(idx, tensor)
    |> Nx.as_type(opts[:type])
    |> Nx.revectorize(vectorized_axes, target_shape: out_shape)
  end

  # Replaces the trailing `{M, N}` axes of the input shape with the single output length.
  deftransformp overlap_and_add_output_shape({out_len}, in_shape) do
    idx = tuple_size(in_shape) - 2

    in_shape
    |> Tuple.delete_at(idx)
    |> Tuple.delete_at(idx)
    |> Tuple.append(out_len)
  end
end