233 lines
6.1 KiB
Elixir
233 lines
6.1 KiB
Elixir
defmodule Safetensors do
|
|
@moduledoc """
[Safetensors](https://huggingface.co/docs/safetensors/index) implementation for `Nx`.

## Examples

    iex> x = Nx.tensor([1, 2, 3])
    iex> y = Nx.tensor([1.0, 2.0, 3.0])
    iex> tensors = %{"x" => x, "y" => y}
    iex> data = Safetensors.dump(tensors)
    iex> tensors = Safetensors.load!(data)
    iex> tensors["x"]
    #Nx.Tensor<
      s64[3]
      [1, 2, 3]
    >
    iex> tensors["y"]
    #Nx.Tensor<
      f32[3]
      [1.0, 2.0, 3.0]
    >

"""
|
|
|
|
alias Safetensors.Shared
|
|
|
|
# Header key that does not describe a tensor; it is removed from the
# decoded header before the tensor entries are processed.
@header_metadata_key "__metadata__"

# Nx type tuples mapped to the dtype strings stored in the header's
# "dtype" field.
@type_to_dtype Map.new([
                 {{:bf, 16}, "BF16"},
                 {{:f, 64}, "F64"},
                 {{:f, 32}, "F32"},
                 {{:f, 16}, "F16"},
                 {{:s, 64}, "I64"},
                 {{:s, 32}, "I32"},
                 {{:s, 16}, "I16"},
                 {{:s, 8}, "I8"},
                 {{:u, 64}, "U64"},
                 {{:u, 32}, "U32"},
                 {{:u, 16}, "U16"},
                 {{:u, 8}, "U8"}
               ])

# Reverse lookup of the table above: dtype string to Nx type tuple.
@dtype_to_type Map.new(@type_to_dtype, fn {type, dtype} -> {dtype, type} end)
|
|
|
|
@doc """
Writes a map of tensors to the file at `path`.

The header is written first, followed by the data of each tensor in
name-sorted order. Tensors are serialized and written one at a time,
so they never all have to be dumped into memory at once.

Returns `:ok`. Raises if the file cannot be opened or a write does not
return `:ok`.
"""
@spec write!(path :: Path.t(), %{String.t() => Nx.Tensor.t()}) :: :ok
def write!(path, tensors) when is_map(tensors) do
  File.open!(path, [:write, :raw], fn device ->
    sorted = Enum.sort(tensors)

    # Thread the running data offset through the entries so each one
    # records where its bytes start and end.
    {entries, _total_size} =
      Enum.map_reduce(sorted, 0, fn {name, tensor}, offset ->
        tensor_header_entry(name, tensor, offset)
      end)

    :ok = :file.write(device, header_binary(entries))

    Enum.each(sorted, fn {_name, tensor} ->
      :ok = :file.write(device, tensor_to_iodata(tensor))
    end)
  end)

  :ok
end
|
|
|
|
# Builds the header entry for one tensor whose data starts at `offset`.
#
# Returns `{{tensor_name, entry}, next_offset}`: `entry` is an ordered
# JSON object with the `dtype`, `shape` and `data_offsets` fields, and
# `next_offset` is `offset` plus the tensor's byte size. The shape fits
# the reducer contract of `Enum.map_reduce/3`.
defp tensor_header_entry(tensor_name, tensor, offset) do
  next_offset = offset + tensor_byte_size(tensor)

  dtype = type_to_dtype(Nx.type(tensor))
  dims = Tuple.to_list(Nx.shape(tensor))

  entry =
    Jason.OrderedObject.new(
      dtype: dtype,
      shape: dims,
      data_offsets: [offset, next_offset]
    )

  {{tensor_name, entry}, next_offset}
end
|
|
|
|
# Encodes the header entries as JSON (keeping their order) and returns
# iodata made of the JSON byte size as an unsigned 64-bit little-endian
# integer followed by the JSON itself.
defp header_binary(header_entries) do
  json = Jason.encode!(Jason.OrderedObject.new(header_entries))
  size_prefix = <<byte_size(json)::unsigned-64-integer-little>>

  [size_prefix, json]
end
|
|
|
|
# Number of bytes taken by the tensor data: element count times the
# element width in whole bytes (the bit width from the type, divided by 8).
defp tensor_byte_size(tensor) do
  {_kind, bits} = Nx.type(tensor)

  Nx.size(tensor) * div(bits, 8)
end
|
|
|
|
# Serializes the tensor data: takes its binary from `Nx.to_binary/1` and
# passes it through `Shared.new_byte_order/3` with the element bit width
# and `:little`.
# NOTE(review): presumably this converts from native to little-endian
# byte order - confirm in `Safetensors.Shared`.
defp tensor_to_iodata(tensor) do
  {_kind, bits} = Nx.type(tensor)
  raw = Nx.to_binary(tensor)

  Shared.new_byte_order(raw, bits, :little)
end
|
|
|
|
@doc """
Serializes the given map of tensors to iodata.

The result holds the header followed by the data of every tensor in
name-sorted order. Being `iodata`, it can be written as is to any IO
device, such as a file or a socket; call `IO.iodata_to_binary/1` when
a single binary is needed.
"""
@spec dump(%{String.t() => Nx.Tensor.t()}) :: iodata()
def dump(tensors) when is_map(tensors) do
  {reversed_entries, data, _total_size} =
    tensors
    |> Enum.sort()
    |> Enum.reduce({[], [], 0}, fn {name, tensor}, {entries, data, offset} ->
      {entry, next_offset} = tensor_header_entry(name, tensor, offset)
      binary = tensor_to_iodata(tensor)

      # Nesting the previous iodata avoids copying what was already serialized.
      {[entry | entries], [data, binary], next_offset}
    end)

  header = reversed_entries |> Enum.reverse() |> header_binary()

  [header, data]
end
|
|
|
|
@doc """
Reads a serialized map of tensors from the file at `path`.

Only the header is read up front. The data of each tensor is then read
on its own, so the entire file never has to be loaded from disk into
memory at once.

## Options

  * `:lazy` - when `true`, the values of the returned map are lazy
    containers (`Safetensors.FileTensor`) rather than tensors. A
    container can be converted to a tensor using `Nx.to_tensor/1`,
    and it is only at that point that its data is read from the
    file. Defaults to `false`

"""
@spec read!(path :: Path.t(), keyword()) :: %{String.t() => Nx.LazyContainer.t()}
def read!(path, opts \\ []) do
  opts = Keyword.validate!(opts, lazy: false)
  lazy? = Keyword.fetch!(opts, :lazy)

  File.open!(path, [:read, :raw], fn file ->
    {:ok, <<header_size::unsigned-64-integer-little>>} = :file.read(file, 8)
    {:ok, header_json} = :file.read(file, header_size)

    # Offsets in the header are relative to the first byte after the
    # 8-byte size prefix and the JSON header.
    data_start = header_size + 8

    header_json
    |> decode_header!()
    |> Map.new(fn {tensor_name, tensor_info} ->
      %{"data_offsets" => [offset_start, offset_end]} = tensor_info
      {shape, type} = shape_and_type(tensor_info)

      position = data_start + offset_start
      size = offset_end - offset_start

      if lazy? do
        file_tensor = %Safetensors.FileTensor{
          shape: shape,
          type: type,
          path: path,
          byte_offset: position,
          byte_size: size
        }

        {tensor_name, file_tensor}
      else
        {:ok, binary} = :file.pread(file, position, size)

        {tensor_name, Shared.build_tensor(binary, shape, type)}
      end
    end)
  end)
end
|
|
|
|
@doc """
Loads a serialized map of tensors.

It is the opposite of `dump/1`. `data` is iodata (or a binary) made of
an unsigned 64-bit little-endian header size, the JSON header, and the
tensor data the header entries point into.

Raises `ArgumentError` when the `data_offsets` of a tensor do not
describe a valid byte range inside the tensor data.
"""
@spec load!(iodata()) :: %{String.t() => Nx.Tensor.t()}
def load!(data) when is_binary(data) or is_list(data) do
  data = IO.iodata_to_binary(data)

  <<
    header_size::unsigned-64-integer-little,
    header_json::binary-size(header_size),
    buffer::binary
  >> = data

  header = decode_header!(header_json)

  for {tensor_name, tensor_info} <- header, into: %{} do
    %{"data_offsets" => [offset_start, offset_end]} = tensor_info
    {shape, type} = shape_and_type(tensor_info)

    tensor =
      buffer
      |> tensor_data!(tensor_name, offset_start, offset_end)
      |> Shared.build_tensor(shape, type)

    {tensor_name, tensor}
  end
end

# Extracts the bytes in `offset_start..offset_end-1` from `buffer`.
#
# The range is validated explicitly instead of relying on
# `binary_slice/3`, which clamps ranges that run past the end of the
# binary and treats a negative start as relative to the end, so corrupt
# offsets would silently select the wrong (or truncated) bytes.
defp tensor_data!(buffer, _tensor_name, offset_start, offset_end)
     when is_integer(offset_start) and is_integer(offset_end) and offset_start >= 0 and
            offset_end >= offset_start and offset_end <= byte_size(buffer) do
  binary_part(buffer, offset_start, offset_end - offset_start)
end

defp tensor_data!(buffer, tensor_name, offset_start, offset_end) do
  raise ArgumentError,
        "invalid data_offsets #{inspect([offset_start, offset_end])} for tensor " <>
          "#{inspect(tensor_name)}, the tensor data has #{byte_size(buffer)} bytes"
end
|
|
|
|
# Decodes the JSON header and returns it without the "__metadata__"
# entry, leaving only the remaining (tensor) entries. The metadata
# value itself is discarded.
defp decode_header!(header_json) do
  header_json
  |> Jason.decode!()
  |> Map.delete(@header_metadata_key)
end
|
|
|
|
# Converts a decoded header entry into `{shape, type}`: the "shape" list
# becomes a tuple and the "dtype" string is looked up via
# `dtype_to_type/1`, which raises for unknown dtypes.
defp shape_and_type(%{"dtype" => dtype_name, "shape" => dims}) do
  shape = List.to_tuple(dims)
  type = dtype_to_type(dtype_name)

  {shape, type}
end
|
|
|
|
# Returns the header dtype string for an Nx type tuple, raising a
# `RuntimeError` when the type has no entry in the table.
defp type_to_dtype(type) do
  case Map.fetch(@type_to_dtype, type) do
    {:ok, dtype} -> dtype
    :error -> raise "unrecognized type #{inspect(type)}"
  end
end
|
|
|
|
# Returns the Nx type tuple for a header dtype string, raising a
# `RuntimeError` when the dtype has no entry in the table.
defp dtype_to_type(dtype) do
  case Map.fetch(@dtype_to_type, dtype) do
    {:ok, type} -> type
    :error -> raise "unrecognized dtype #{inspect(dtype)}"
  end
end
|
|
end
|