233 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Elixir
		
	
	
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Elixir
		
	
	
	
	
	
defmodule Safetensors do
 | 
						|
  @moduledoc """
 | 
						|
  [Safetensors](https://huggingface.co/docs/safetensors/index) implementation for `Nx`.
 | 
						|
 | 
						|
  ## Examples
 | 
						|
 | 
						|
      iex> x = Nx.tensor([1, 2, 3])
 | 
						|
      iex> y = Nx.tensor([1.0, 2.0, 3.0])
 | 
						|
      iex> tensors = %{"x" => x, "y" => y}
 | 
						|
      iex> data = Safetensors.dump(tensors)
 | 
						|
      iex> tensors = Safetensors.load!(data)
 | 
						|
      iex> tensors["x"]
 | 
						|
      #Nx.Tensor<
 | 
						|
        s64[3]
 | 
						|
        [1, 2, 3]
 | 
						|
      >
 | 
						|
      iex> tensors["y"]
 | 
						|
      #Nx.Tensor<
 | 
						|
        f32[3]
 | 
						|
        [1.0, 2.0, 3.0]
 | 
						|
      >
 | 
						|
 | 
						|
  """
 | 
						|
 | 
						|
  alias Safetensors.Shared
 | 
						|
 | 
						|
  # Header key that `decode_header!/1` removes before the remaining
  # entries are treated as tensors.
  @header_metadata_key "__metadata__"

  # Nx type tuples mapped to the dtype strings stored under the "dtype"
  # key of each header entry.
  @type_to_dtype %{
    {:u, 8} => "U8",
    {:u, 16} => "U16",
    {:u, 32} => "U32",
    {:u, 64} => "U64",
    {:s, 8} => "I8",
    {:s, 16} => "I16",
    {:s, 32} => "I32",
    {:s, 64} => "I64",
    {:f, 16} => "F16",
    {:f, 32} => "F32",
    {:f, 64} => "F64",
    {:bf, 16} => "BF16"
  }

  # Inverse of `@type_to_dtype`: dtype string => Nx type tuple.
  @dtype_to_type Map.new(@type_to_dtype, fn {type, dtype} -> {dtype, type} end)
 | 
						|
 | 
						|
  @doc """
  Writes a map of tensors to the file at `path`.

  The header is written first and the tensors follow one at a time, so
  they never all need to be dumped into memory at once.
  """
  @spec write!(path :: Path.t(), %{String.t() => Nx.Tensor.t()}) :: :ok
  def write!(path, tensors) when is_map(tensors) do
    File.open!(path, [:write, :raw], fn device ->
      # The same sorted list drives both the header entries and the data
      # that follows, so offsets and payload stay in the same order.
      sorted = Enum.sort(tensors)

      {entries, _total_bytes} =
        Enum.map_reduce(sorted, 0, fn {name, tensor}, start ->
          tensor_header_entry(name, tensor, start)
        end)

      :ok = :file.write(device, header_binary(entries))

      Enum.each(sorted, fn {_name, tensor} ->
        :ok = :file.write(device, tensor_to_iodata(tensor))
      end)
    end)

    :ok
  end
 | 
						|
 | 
						|
  # Builds the header entry for a single tensor whose data starts at
  # `offset`. Returns `{{tensor_name, entry}, end_offset}`, which is the
  # shape expected from an `Enum.map_reduce/3` step.
  defp tensor_header_entry(tensor_name, tensor, offset) do
    end_offset = offset + tensor_byte_size(tensor)
    dtype = type_to_dtype(Nx.type(tensor))
    shape = Tuple.to_list(Nx.shape(tensor))

    # An ordered object keeps the keys in the order given here.
    info =
      Jason.OrderedObject.new(
        dtype: dtype,
        shape: shape,
        data_offsets: [offset, end_offset]
      )

    {{tensor_name, info}, end_offset}
  end
 | 
						|
 | 
						|
  # Encodes the header entries as a JSON object (entry order preserved)
  # and returns iodata: the JSON byte length as an unsigned 64-bit
  # little-endian integer, followed by the JSON itself.
  defp header_binary(header_entries) do
    json = Jason.encode!(Jason.OrderedObject.new(header_entries))
    json_size = byte_size(json)

    [<<json_size::unsigned-64-integer-little>>, json]
  end
 | 
						|
 | 
						|
  # Size of the tensor data in bytes: the element count multiplied by the
  # element width, derived from the bit size reported by `Nx.type/1`.
  defp tensor_byte_size(tensor) do
    {_kind, bits} = Nx.type(tensor)
    Nx.size(tensor) * div(bits, 8)
  end
 | 
						|
 | 
						|
  # Converts the tensor to its binary form and hands it to
  # `Shared.new_byte_order/3` together with the element bit size and
  # `:little`.
  # NOTE(review): presumably this yields little-endian data; the helper
  # lives in another module, so confirm there.
  defp tensor_to_iodata(tensor) do
    {_kind, bits} = Nx.type(tensor)
    raw = Nx.to_binary(tensor)

    Shared.new_byte_order(raw, bits, :little)
  end
 | 
						|
 | 
						|
  @doc """
  Serializes the given map of tensors to iodata.

  The returned `iodata` can be written as-is to any IO device, such as
  a file or a socket. Call `IO.iodata_to_binary/1` on the result when a
  single binary is required.
  """
  @spec dump(%{String.t() => Nx.Tensor.t()}) :: iodata()
  def dump(tensors) when is_map(tensors) do
    # The accumulator carries the data payload built so far and the
    # offset at which the next tensor starts.
    {entries, {payload, _total_bytes}} =
      tensors
      |> Enum.sort()
      |> Enum.map_reduce({[], 0}, fn {name, tensor}, {payload, start} ->
        {entry, stop} = tensor_header_entry(name, tensor, start)
        {entry, {[payload, tensor_to_iodata(tensor)], stop}}
      end)

    [header_binary(entries), payload]
  end
 | 
						|
 | 
						|
  @doc """
  Reads a serialized map of tensors from the file at `path`.

  Only the header is read up front. Each tensor's data is then read
  from its own position in the file, so the file never has to be loaded
  into memory as a whole.

  ## Options

    * `:lazy` - when `true`, lazy containers are returned instead of
      tensors. A lazy container can be turned into a tensor with
      `Nx.to_tensor/1`, and its data is read from the file only at that
      point. Defaults to `false`.

  """
  @spec read!(path :: Path.t(), keyword()) :: %{String.t() => Nx.LazyContainer.t()}
  def read!(path, opts \\ []) do
    opts = Keyword.validate!(opts, lazy: false)
    lazy? = opts[:lazy]

    File.open!(path, [:read, :raw], fn file ->
      {:ok, <<header_size::unsigned-64-integer-little>>} = :file.read(file, 8)
      {:ok, header_json} = :file.read(file, header_size)

      # Tensor offsets are relative to the first byte after the header,
      # which itself sits behind the 8-byte size prefix.
      data_start = header_size + 8

      header_json
      |> decode_header!()
      |> Map.new(fn {tensor_name, tensor_info} ->
        %{"data_offsets" => [first, last]} = tensor_info
        {shape, type} = shape_and_type(tensor_info)

        position = data_start + first
        data_size = last - first

        if lazy? do
          lazy_tensor = %Safetensors.FileTensor{
            shape: shape,
            type: type,
            path: path,
            byte_offset: position,
            byte_size: data_size
          }

          {tensor_name, lazy_tensor}
        else
          {:ok, binary} = :file.pread(file, position, data_size)
          {tensor_name, Shared.build_tensor(binary, shape, type)}
        end
      end)
    end)
  end
 | 
						|
 | 
						|
  @doc """
  Loads a serialized map of tensors from a binary or iodata.

  This is the inverse of `dump/1`.
  """
  @spec load!(iodata()) :: %{String.t() => Nx.Tensor.t()}
  def load!(data) when is_binary(data) or is_list(data) do
    # Layout: 8-byte little-endian header size, the JSON header, then
    # the tensor data that the header offsets point into.
    <<
      header_size::unsigned-64-integer-little,
      header_json::binary-size(header_size),
      payload::binary
    >> = IO.iodata_to_binary(data)

    header_json
    |> decode_header!()
    |> Map.new(fn {tensor_name, tensor_info} ->
      %{"data_offsets" => [first, last]} = tensor_info
      {shape, type} = shape_and_type(tensor_info)
      chunk = binary_slice(payload, first, last - first)

      {tensor_name, Shared.build_tensor(chunk, shape, type)}
    end)
  end
 | 
						|
 | 
						|
  # Decodes the JSON header and drops the metadata entry, leaving only
  # the per-tensor entries.
  defp decode_header!(header_json) do
    header_json
    |> Jason.decode!()
    |> Map.delete(@header_metadata_key)
  end
 | 
						|
 | 
						|
  # Extracts `{shape, type}` from a header entry: the "shape" list becomes
  # a tuple and the "dtype" string is translated via `dtype_to_type/1`.
  defp shape_and_type(%{"dtype" => dtype, "shape" => shape}) do
    dims = List.to_tuple(shape)
    type = dtype_to_type(dtype)

    {dims, type}
  end
 | 
						|
 | 
						|
  # Looks up the dtype string for an Nx type tuple; raises when the type
  # has no entry in `@type_to_dtype`.
  defp type_to_dtype(type) do
    case @type_to_dtype do
      %{^type => dtype} -> dtype
      _ -> raise "unrecognized type #{inspect(type)}"
    end
  end
 | 
						|
 | 
						|
  # Looks up the Nx type tuple for a dtype string; raises when the dtype
  # has no entry in `@dtype_to_type`.
  defp dtype_to_type(dtype) do
    case @dtype_to_type do
      %{^dtype => type} -> type
      _ -> raise "unrecognized dtype #{inspect(dtype)}"
    end
  end
 | 
						|
end
 |