# NxImage source listing: 557 lines, 14 KiB, Elixir.
defmodule NxImage do
 | 
						|
  @moduledoc """
 | 
						|
  Image processing in `Nx`.
 | 
						|
 | 
						|
  All functions expect images to be tensors in either HWC or CHW order,
 | 
						|
  with an arbitrary number of leading batch axes.
 | 
						|
 | 
						|
  All transformations preserve the input type, rounding if necessary.
 | 
						|
  For higher precision, cast the input to floating-point beforehand.
 | 
						|
  """
 | 
						|
 | 
						|
  import Nx.Defn
 | 
						|
 | 
						|
  @doc """
 | 
						|
  Crops an image at the center.
 | 
						|
 | 
						|
  If the image is too small to be cropped to the desired size, it gets
 | 
						|
  padded with zeros.
 | 
						|
 | 
						|
  ## Options
 | 
						|
 | 
						|
    * `:channels` - channels location, either `:first` or `:last`.
 | 
						|
      Defaults to `:last`
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> image = Nx.iota({4, 4, 1}, type: :u8)
 | 
						|
      iex> NxImage.center_crop(image, {2, 2})
 | 
						|
      #Nx.Tensor<
 | 
						|
        u8[2][2][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [5],
 | 
						|
            [6]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [9],
 | 
						|
            [10]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
      iex> image = Nx.iota({2, 2, 1}, type: :u8)
 | 
						|
      iex> NxImage.center_crop(image, {1, 4})
 | 
						|
      #Nx.Tensor<
 | 
						|
        u8[1][4][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [0],
 | 
						|
            [0],
 | 
						|
            [1],
 | 
						|
            [0]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
  """
 | 
						|
  @doc type: :transformation
 | 
						|
  deftransform center_crop(input, size, opts \\ []) when is_tuple(size) do
    opts = Keyword.validate!(opts, channels: :last)
    validate_image!(input)

    # Start from "no padding on any axis" and only override the two spatial axes.
    no_padding = List.duplicate({0, 0, 0}, Nx.rank(input))

    pad_config =
      input
      |> spatial_axes_with_sizes(size, opts[:channels])
      |> Enum.reduce(no_padding, fn {axis, current, target}, config ->
        # A positive offset crops the axis, a negative one results in zero
        # padding, because Nx.pad/3 accepts negative padding widths.
        crop_start = div(current - target, 2)
        crop_end = crop_start + target
        List.replace_at(config, axis, {-crop_start, crop_end - current, 0})
      end)

    Nx.pad(input, 0, pad_config)
  end
 | 
						|
 | 
						|
  # Returns `[{height_axis, height, out_height}, {width_axis, width, out_width}]`,
  # pairing each spatial axis of `input` with its current and requested size.
  deftransformp spatial_axes_with_sizes(input, size, channels) do
    axes = input |> spatial_axes(channels) |> Tuple.to_list()
    current_sizes = input |> size(channels) |> Tuple.to_list()
    {out_height, out_width} = size

    Enum.zip([axes, current_sizes, [out_height, out_width]])
  end
 | 
						|
 | 
						|
  # Looks up the height and width of `input`, returned as `{height, width}`.
  deftransformp size(input, channels) do
    input
    |> spatial_axes(channels)
    |> Tuple.to_list()
    |> Enum.map(&Nx.axis_size(input, &1))
    |> List.to_tuple()
  end
 | 
						|
 | 
						|
  @doc """
 | 
						|
  Resizes an image.
 | 
						|
 | 
						|
  ## Options
 | 
						|
 | 
						|
    * `:method` - the resizing method to use, either of `:nearest`,
 | 
						|
      `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to
 | 
						|
      `:bilinear`
 | 
						|
 | 
						|
    * `:antialias` - whether an anti-aliasing filter should be used
 | 
						|
      when downsampling. This has no effect with upsampling. Defaults
 | 
						|
      to `true`
 | 
						|
 | 
						|
    * `:channels` - channels location, either `:first` or `:last`.
 | 
						|
      Defaults to `:last`
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> image = Nx.iota({2, 2, 1}, type: :u8)
 | 
						|
      iex> NxImage.resize(image, {3, 3}, method: :nearest)
 | 
						|
      #Nx.Tensor<
 | 
						|
        u8[3][3][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [0],
 | 
						|
            [1],
 | 
						|
            [1]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [2],
 | 
						|
            [3],
 | 
						|
            [3]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [2],
 | 
						|
            [3],
 | 
						|
            [3]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
      iex> image = Nx.iota({2, 2, 1}, type: :f32)
 | 
						|
      iex> NxImage.resize(image, {3, 3}, method: :bilinear)
 | 
						|
      #Nx.Tensor<
 | 
						|
        f32[3][3][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [0.0],
 | 
						|
            [0.5],
 | 
						|
            [1.0]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [1.0],
 | 
						|
            [1.5],
 | 
						|
            [2.0]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [2.0],
 | 
						|
            [2.5],
 | 
						|
            [3.0]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
  """
 | 
						|
  @doc type: :transformation
 | 
						|
  deftransform resize(input, size, opts \\ []) when is_tuple(size) do
    opts = Keyword.validate!(opts, channels: :last, method: :bilinear, antialias: true)
    validate_image!(input)

    # Axes that already have the requested size are left untouched.
    changed_axes =
      input
      |> spatial_axes_with_sizes(size, opts[:channels])
      |> Enum.filter(fn {_axis, current, target} -> current != target end)

    spatial_axes = Enum.map(changed_axes, fn {axis, _current, _target} -> axis end)

    out_shape =
      Enum.reduce(changed_axes, Nx.shape(input), fn {axis, _current, target}, shape ->
        put_elem(shape, axis, target)
      end)

    # All kernel-based methods share the same call, differing only in the kernel.
    with_kernel = fn kernel_fun ->
      resize_with_kernel(input, out_shape, spatial_axes, opts[:antialias], kernel_fun)
    end

    resized_input =
      case opts[:method] do
        :nearest ->
          resize_nearest(input, out_shape, spatial_axes)

        :bilinear ->
          with_kernel.(&fill_linear_kernel/1)

        :bicubic ->
          with_kernel.(&fill_cubic_kernel/1)

        :lanczos3 ->
          with_kernel.(&fill_lanczos_kernel(3, &1))

        :lanczos5 ->
          with_kernel.(&fill_lanczos_kernel(5, &1))

        method ->
          raise ArgumentError,
                "expected :method to be either of :nearest, :bilinear, :bicubic, " <>
                  ":lanczos3, :lanczos5, got: #{inspect(method)}"
      end

    # Bring the result back to the type (and names) of the original input.
    cast_to(resized_input, input)
  end
 | 
						|
 | 
						|
  # Returns `{height_axis, width_axis}` as non-negative axis indices. The axes
  # are counted from the end, so any number of leading batch axes is allowed.
  deftransformp spatial_axes(input, channels) do
    {height_offset, width_offset} =
      case channels do
        :first -> {-2, -1}
        :last -> {-3, -2}
      end

    {Nx.axis_index(input, height_offset), Nx.axis_index(input, width_offset)}
  end
 | 
						|
 | 
						|
  # Converts `left` to the type of `right` and gives it the axis names of `right`.
  defnp cast_to(left, right) do
    source_type = Nx.type(left)
    target_type = Nx.type(right)

    # Float results going to an integer type are rounded before the cast.
    rounded =
      if Nx.Type.float?(source_type) and Nx.Type.integer?(target_type) do
        Nx.round(left)
      else
        left
      end

    converted = Nx.as_type(rounded, target_type)

    # The tensor is reshaped to its own shape, so only the names change.
    Nx.reshape(converted, rounded, names: Nx.names(right))
  end
 | 
						|
 | 
						|
  # Nearest-neighbour resize: for every axis in `spatial_axes`, picks for each
  # output position the input index under the centre of that output pixel.
  deftransformp resize_nearest(input, out_shape, spatial_axes) do
    all_ones = Tuple.duplicate(1, Nx.rank(input))

    Enum.reduce(spatial_axes, input, fn axis, acc ->
      acc_shape = Nx.shape(acc)
      in_size = elem(acc_shape, axis)
      out_size = elem(out_shape, axis)
      inv_scale = in_size / out_size

      indices =
        {out_size}
        |> Nx.iota()
        |> Nx.add(0.5)
        |> Nx.multiply(inv_scale)
        |> Nx.floor()
        |> Nx.as_type({:s, 32})
        # Lay the indices along `axis` and repeat them across all other axes.
        |> Nx.reshape(put_elem(all_ones, axis, out_size))
        |> Nx.broadcast(put_elem(acc_shape, axis, out_size))

      Nx.take_along_axis(acc, indices, axis: axis)
    end)
  end
 | 
						|
 | 
						|
  # 2^-23, the machine epsilon of a 32-bit float. Used below to scale the
  # cutoff under which kernel weights are treated as zero.
  @f32_eps :math.pow(2, -23)
 | 
						|
 | 
						|
  # Resizes one spatial axis at a time, feeding each intermediate result into
  # the next pass.
  deftransformp resize_with_kernel(input, out_shape, spatial_axes, antialias, kernel_fun) do
    Enum.reduce(spatial_axes, input, fn axis, acc ->
      resize_axis_with_kernel(acc,
        axis: axis,
        output_size: elem(out_shape, axis),
        antialias: antialias,
        kernel_fun: kernel_fun
      )
    end)
  end
 | 
						|
 | 
						|
  # Resizes a single axis by contracting it with an {input_size, output_size}
  # matrix of interpolation weights produced by `kernel_fun`.
  defnp resize_axis_with_kernel(input, opts) do
    axis = opts[:axis]
    output_size = opts[:output_size]
    antialias = opts[:antialias]
    kernel_fun = opts[:kernel_fun]

    input_size = Nx.axis_size(input, axis)
    inv_scale = input_size / output_size

    # With anti-aliasing the kernel is widened when downsampling (inv_scale > 1).
    kernel_scale =
      if antialias do
        max(1, inv_scale)
      else
        1
      end

    # Centre of every output pixel, expressed in input coordinates.
    sample_centers = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
    input_positions = Nx.iota({input_size, 1})
    distances = Nx.abs(sample_centers - input_positions) / kernel_scale

    raw_weights = kernel_fun.(distances)

    # Normalize per output position and drop negligible weights.
    column_totals = Nx.sum(raw_weights, axes: [0], keep_axes: true)
    normalized = safe_divide(raw_weights, column_totals)
    weights = Nx.select(Nx.abs(raw_weights) > 1000 * @f32_eps, normalized, 0)

    # The contracted axis ends up last in the dot result, so move it back.
    input
    |> Nx.dot([axis], weights, [0])
    |> reorder_axis(-1, axis)
  end
 | 
						|
 | 
						|
  # Triangle kernel: 1 at distance 0, falling linearly to 0 at distance 1.
  defnp fill_linear_kernel(x) do
    falloff = 1 - x
    Nx.max(0, falloff)
  end
 | 
						|
 | 
						|
  # Piecewise cubic convolution kernel, zero from distance 2 onwards.
  # See https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
  defnp fill_cubic_kernel(x) do
    # Polynomial used for distances below 1.
    near = (1.5 * x - 2.5) * x * x + 1
    # Polynomial used for distances in [1, 2).
    far = ((-0.5 * x + 2.5) * x - 4) * x + 2

    Nx.select(x >= 2, 0, Nx.select(x >= 1, far, near))
  end
 | 
						|
 | 
						|
  @pi :math.pi()

  # Lanczos kernel: radius * sin(pi x) * sin(pi x / radius) / (pi^2 x^2),
  # zero beyond `radius`.
  defnp fill_lanczos_kernel(radius, x) do
    numerator = radius * Nx.sin(@pi * x) * Nx.sin(@pi * x / radius)
    denominator = @pi ** 2 * x ** 2

    # Close to zero the quotient is replaced by 1.
    within_radius = Nx.select(x > 1.0e-3, safe_divide(numerator, denominator), 1)

    Nx.select(x > radius, 0, within_radius)
  end
 | 
						|
 | 
						|
  # Element-wise division that substitutes 1 for zero divisors.
  defnp safe_divide(x, y) do
    divisor = Nx.select(y != 0, y, 1)
    x / divisor
  end
 | 
						|
 | 
						|
  # Transposes `tensor` so that `axis` is moved to position `target_axis`,
  # keeping the relative order of all other axes.
  deftransformp reorder_axis(tensor, axis, target_axis) do
    {moved, remaining} = tensor |> Nx.axes() |> List.pop_at(axis)
    permutation = List.insert_at(remaining, target_axis, moved)

    Nx.transpose(tensor, axes: permutation)
  end
 | 
						|
 | 
						|
  @doc """
 | 
						|
  Scales an image such that the short edge matches the given size.
 | 
						|
 | 
						|
  ## Options
 | 
						|
 | 
						|
    * `:method` - the resizing method to use, same as `resize/2`
 | 
						|
 | 
						|
    * `:antialias` - whether an anti-aliasing filter should be used
 | 
						|
      when downsampling. This has no effect with upsampling. Defaults
 | 
						|
      to `true`
 | 
						|
 | 
						|
    * `:channels` - channels location, either `:first` or `:last`.
 | 
						|
      Defaults to `:last`
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> image = Nx.iota({2, 4, 1}, type: :u8)
 | 
						|
      iex> resized_image = NxImage.resize_short(image, 3, method: :nearest)
 | 
						|
      iex> Nx.shape(resized_image)
 | 
						|
      {3, 6, 1}
 | 
						|
 | 
						|
      iex> image = Nx.iota({4, 2, 1}, type: :u8)
 | 
						|
      iex> resized_image = NxImage.resize_short(image, 3, method: :nearest)
 | 
						|
      iex> Nx.shape(resized_image)
 | 
						|
      {6, 3, 1}
 | 
						|
 | 
						|
  """
 | 
						|
  @doc type: :transformation
 | 
						|
  deftransform resize_short(input, size, opts \\ []) when is_integer(size) do
    opts = Keyword.validate!(opts, channels: :last, method: :bilinear, antialias: true)
    validate_image!(input)

    # The target size travels along with the validated options.
    resize_short_n(input, [{:size, size} | opts])
  end
 | 
						|
 | 
						|
  # Computes the output size that puts the short edge at `opts[:size]` and
  # delegates the actual work to `resize/3`.
  defnp resize_short_n(input, opts) do
    channels = opts[:channels]

    {height, width} = size(input, channels)
    {out_height, out_width} = resize_short_size(height, width, opts[:size])

    resize(input, {out_height, out_width},
      method: opts[:method],
      antialias: opts[:antialias],
      channels: channels
    )
  end
 | 
						|
 | 
						|
  # Returns `{out_height, out_width}` where the shorter edge equals `size` and
  # the longer edge is scaled by the same ratio, rounded down.
  deftransformp resize_short_size(height, width, size) do
    if height < width do
      {size, floor(size * width / height)}
    else
      {floor(size * height / width), size}
    end
  end
 | 
						|
 | 
						|
  @doc """
 | 
						|
  Normalizes an image according to the given per-channel mean and
 | 
						|
  standard deviation.
 | 
						|
 | 
						|
  ## Options

    * `:channels` - channels location, either `:first` or `:last`.
 | 
						|
      Defaults to `:last`
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> image = Nx.iota({2, 2, 3}, type: :f32)
 | 
						|
      iex> mean = Nx.tensor([0.485, 0.456, 0.406])
 | 
						|
      iex> std = Nx.tensor([0.229, 0.224, 0.225])
 | 
						|
      iex> NxImage.normalize(image, mean, std)
 | 
						|
      #Nx.Tensor<
 | 
						|
        f32[2][2][3]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [-2.1179039478302, 2.4285714626312256, 7.084444522857666],
 | 
						|
            [10.982532501220703, 15.821427345275879, 20.41777801513672]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [24.08296775817871, 29.214284896850586, 33.7511100769043],
 | 
						|
            [37.183406829833984, 42.607139587402344, 47.08444595336914]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
  """
 | 
						|
  @doc type: :transformation
 | 
						|
  defn normalize(input, mean, std, opts \\ []) do
    opts = keyword!(opts, channels: :last)
    validate_image!(input)

    channels = opts[:channels]

    # Both statistics are reshaped so they line up with the channels axis.
    mean = broadcast_channel_info(mean, input, channels, "mean")
    std = broadcast_channel_info(std, input, channels, "std")

    input
    |> Nx.subtract(mean)
    |> Nx.divide(std)
    |> cast_to(input)
  end
 | 
						|
 | 
						|
  # Checks that `tensor` holds exactly one value per channel of `input` and
  # reshapes it to the rank of `input`, with size 1 on every non-channel axis.
  # `name` is only used in the error message.
  deftransformp broadcast_channel_info(tensor, input, channels, name) do
    rank = Nx.rank(input)

    offset_from_end =
      case channels do
        :first -> 3
        :last -> 1
      end

    channels_axis = rank - offset_from_end
    num_channels = Nx.axis_size(input, channels_axis)
    actual_shape = Nx.shape(tensor)

    if actual_shape != {num_channels} do
      raise ArgumentError,
            "expected #{name} to have shape {#{num_channels}}, got: #{inspect(actual_shape)}"
    end

    broadcastable_shape = put_elem(Tuple.duplicate(1, rank), channels_axis, :auto)
    Nx.reshape(tensor, broadcastable_shape)
  end
 | 
						|
 | 
						|
  @doc """
 | 
						|
  Converts pixel values (0-255) into a continuous range.
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> image = Nx.tensor([[[0], [128]], [[191], [255]]])
 | 
						|
      iex> NxImage.to_continuous(image, 0.0, 1.0)
 | 
						|
      #Nx.Tensor<
 | 
						|
        f32[2][2][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [0.0],
 | 
						|
            [0.501960813999176]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [0.7490196228027344],
 | 
						|
            [1.0]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
      iex> image = Nx.tensor([[[0], [128]], [[191], [255]]])
 | 
						|
      iex> NxImage.to_continuous(image, -1.0, 1.0)
 | 
						|
      #Nx.Tensor<
 | 
						|
        f32[2][2][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [-1.0],
 | 
						|
            [0.003921627998352051]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [0.49803924560546875],
 | 
						|
            [1.0]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
  """
 | 
						|
  @doc type: :conversion
 | 
						|
  defn to_continuous(input, min, max) do
    validate_image!(input)

    # Map 0..255 onto 0..1 first, then stretch and shift into min..max.
    unit = input / 255.0
    unit * (max - min) + min
  end
 | 
						|
 | 
						|
  @doc """
 | 
						|
  Converts values from continuous range into pixel values (0-255).
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> image = Nx.tensor([[[0.0], [0.5]], [[0.75], [1.0]]])
 | 
						|
      iex> NxImage.from_continuous(image, 0.0, 1.0)
 | 
						|
      #Nx.Tensor<
 | 
						|
        u8[2][2][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [0],
 | 
						|
            [128]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [191],
 | 
						|
            [255]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
      iex> image = Nx.tensor([[[-1.0], [0.0]], [[0.5], [1.0]]])
 | 
						|
      iex> NxImage.from_continuous(image, -1.0, 1.0)
 | 
						|
      #Nx.Tensor<
 | 
						|
        u8[2][2][1]
 | 
						|
        [
 | 
						|
          [
 | 
						|
            [0],
 | 
						|
            [128]
 | 
						|
          ],
 | 
						|
          [
 | 
						|
            [191],
 | 
						|
            [255]
 | 
						|
          ]
 | 
						|
        ]
 | 
						|
      >
 | 
						|
 | 
						|
  """
 | 
						|
  @doc type: :conversion
 | 
						|
  defn from_continuous(input, min, max) do
    validate_image!(input)

    # Map min..max onto 0..255, then round and clip into the u8 pixel range.
    pixels = (input - min) / (max - min) * 255.0
    rounded = Nx.round(pixels)
    clipped = Nx.clip(rounded, 0, 255)

    Nx.as_type(clipped, :u8)
  end
 | 
						|
 | 
						|
  # Raises ArgumentError unless `input` has rank 3 or higher.
  deftransformp validate_image!(input) do
    case Nx.rank(input) do
      rank when rank < 3 ->
        raise ArgumentError,
              "expected the image input to have rank 3 or higher, got: #{inspect(rank)}"

      _rank ->
        nil
    end
  end
 | 
						|
end
 |