voice_recognition/whisper_server/deps/axon/lib/axon/quantization.ex

defmodule Axon.Quantization do
  @moduledoc """
  Model quantization.

  Model quantization is a technique for reducing the memory footprint of
  a model by converting portions of a model to use quantized representations.
  Typically, these quantized representations are low-precision integers.

  This is an **experimental** API which implements weight-only quantization.
  The implementation in this module will convert dense layers in a large
  model to quantized variants. The only supported quantization type is
  `{:s, 8}`. Axon quantization is inference-only. Training is not currently
  supported.
  """
  alias Axon.Quantization.Layers
  alias Axon.Quantization.QTensor
@doc """
Quantizes a model and a model state.
Given a model and model state, this method will rewrite all
of the dense layers in the model to perform weight-only 8-bit
integer versions of the same operation. It will also replace values
for all dense kernels in the given model state with quantized
tensors.
"""
def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
quantized_model = quantize_model(model)
quantized_model_state = quantize_model_state(model, model_state)
{quantized_model, quantized_model_state}
end
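
  # Usage sketch for `quantize/2`; `model`, `model_state`, and `input` are
  # placeholders for a real network, its trained state, and an input tensor:
  #
  #     {q_model, q_state} = Axon.Quantization.quantize(model, model_state)
  #     Axon.predict(q_model, q_state, input)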
@doc """
Replaces standard operations with quantized variants.
The only supported conversion is to convert regular dense layers
to a weight-only 8-bit integer variant. Note that this only replaces
the properties of the model. If you have a pre-trained model state
that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/2`.
All `:dense` layers in the model are replaced with `Axon.Quantization.weight_only_quantized_dense/3`.
"""
def quantize_model(%Axon{} = model) do
quantized_dense_rewriter = fn [%Axon{} = x], _output, name_fn, units, use_bias ->
weight_only_quantized_dense(x, units,
use_bias: use_bias,
name: name_fn
)
end
Axon.rewrite_nodes(model, fn
%Axon.Node{op: :dense, meta: meta, name: name_fn} ->
&quantized_dense_rewriter.(&1, &2, name_fn, meta[:units], meta[:use_bias])
_ ->
:skip
end)
end
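
  # Usage sketch for `quantize_model/1`: only the graph is rewritten, so a
  # pre-trained state must be quantized separately (`model` and `model_state`
  # are placeholders):
  #
  #     q_model = Axon.Quantization.quantize_model(model)
  #     q_state = Axon.Quantization.quantize_model_state(model, model_state)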
@doc """
Returns a quantized model state.
Given a model and a model state, this function will replace
all dense layer kernels with a quantized version of the weight.
Training is not currently supported, so all quantized layers are
automatically frozen.
"""
def quantize_model_state(model, model_state) do
dense_layer_names =
model
|> Axon.properties()
|> Enum.filter(fn {_, v} -> v == :dense end)
|> Enum.map(fn {k, _} -> k end)
|> MapSet.new()
state =
Enum.reduce(dense_layer_names, model_state, fn layer_name, state ->
update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1)
end)
Axon.ModelState.freeze(state, fn [name | _] ->
MapSet.member?(dense_layer_names, name)
end)
end
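
  # Sketch of the resulting state: every dense kernel becomes a `QTensor`
  # and the dense layers are frozen. The layer name "dense_0" is illustrative:
  #
  #     q_state = Axon.Quantization.quantize_model_state(model, model_state)
  #     %Axon.Quantization.QTensor{} = q_state.data["dense_0"]["kernel"]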

  ## Layers

  @doc """
  Adds a weight-only quantized dense layer to the network.

  This is equivalent to a dense layer, but works on quantized
  weights for reducing model memory footprint.

  Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`
  (or the 3-arity variant when `use_bias: false`).

  ## Options

    * `:name` - layer name.

    * `:kernel_initializer` - initializer for `kernel` weights.
      Defaults to `:glorot_uniform`.

    * `:bias_initializer` - initializer for `bias` weights. Defaults
      to `:zeros`.

    * `:use_bias` - whether the layer should add bias to the output.
      Defaults to `true`.
  """
  def weight_only_quantized_dense(x, units, opts \\ []) do
    opts =
      Keyword.validate!(opts, [
        :name,
        :meta,
        use_bias: true,
        kernel_initializer: :glorot_uniform,
        bias_initializer: :zeros
      ])

    meta =
      opts[:meta] ||
        %{}
        |> Map.put(:units, units)
        |> Map.put(:use_bias, opts[:use_bias])

    kernel_shape = &Axon.Shape.dense_kernel(&1, units)
    bias_shape = &Axon.Shape.dense_bias(&1, units)

    kernel =
      Axon.param("kernel", kernel_shape,
        initializer: fn shape, type, key ->
          # Resolve the configured initializer: an atom names a function in
          # Axon.Initializers; otherwise an initializer function was given.
          fun =
            case opts[:kernel_initializer] do
              init when is_atom(init) ->
                apply(Axon.Initializers, init, [])

              fun when is_function(fun) ->
                fun
            end

          # Initializers are either arity-2 (shape, type) or arity-3
          # (shape, type, key) functions.
          tensor =
            case fun do
              fun when is_function(fun, 2) ->
                fun.(shape, type)

              fun when is_function(fun, 3) ->
                fun.(shape, type, key)
            end

          # Store the initialized kernel in its quantized representation.
          QTensor.from_tensor(tensor)
        end
      )

    {inputs, op} =
      if opts[:use_bias] do
        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
        {[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
      else
        {[x, kernel], &Layers.weight_only_quantized_dense/3}
      end

    Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
  end
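
  # Usage sketch: a quantized dense layer can also be added directly when
  # building a network; the input name and sizes here are illustrative:
  #
  #     model =
  #       Axon.input("features", shape: {nil, 128})
  #       |> Axon.Quantization.weight_only_quantized_dense(64, name: "dense_0")
  #       |> Axon.activation(:relu)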
end