defmodule Axon do
|
|
@moduledoc """
|
|
A high-level interface for creating neural network models.
|
|
|
|
Axon is built entirely on top of Nx numerical definitions,
|
|
so every neural network can be JIT or AOT compiled using
|
|
any Nx compiler, or even transformed into high-level neural
|
|
network formats like TensorFlow Lite and
|
|
[ONNX](https://github.com/elixir-nx/axon_onnx).
|
|
|
|
For a more in-depth overview of Axon, refer to the [Guides](guides.html).
|
|
|
|
## Model Creation
|
|
|
|
All Axon models start with an input layer, optionally specifying
|
|
the expected shape of the input data:
|
|
|
|
input = Axon.input("input", shape: {nil, 784})
|
|
|
|
Notice you can specify some dimensions as `nil`, indicating
|
|
that the dimension size will be filled in at model runtime.
|
|
You can then compose inputs with other layers:
|
|
|
|
model =
|
|
input
|
|
|> Axon.dense(128, activation: :relu)
|
|
|> Axon.batch_norm()
|
|
|> Axon.dropout(rate: 0.8)
|
|
|> Axon.dense(64)
|
|
|> Axon.tanh()
|
|
|> Axon.dense(10)
|
|
|> Axon.activation(:softmax)
|
|
|
|
You can inspect the model for a nice summary:
|
|
|
|
IO.inspect(model)
|
|
|
|
#Axon<
|
|
inputs: %{"input" => {nil, 784}}
|
|
outputs: "softmax_0"
|
|
nodes: 9
|
|
>
|
|
|
|
Or use the `Axon.Display` module to see more in-depth summaries:
|
|
|
|
Axon.Display.as_table(model, Nx.template({1, 784}, :f32)) |> IO.puts
|
|
|
|
+----------------------------------------------------------------------------------------------------------------+
|
|
| Model |
|
|
+=======================================+=============+==============+===================+=======================+
|
|
| Layer | Input Shape | Output Shape | Options | Parameters |
|
|
+=======================================+=============+==============+===================+=======================+
|
|
| input ( input ) | [] | {1, 784} | shape: {nil, 784} | |
|
|
| | | | optional: false | |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| dense_0 ( dense["input"] ) | [{1, 784}] | {1, 128} | | kernel: f32[784][128] |
|
|
| | | | | bias: f32[128] |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| relu_0 ( relu["dense_0"] ) | [{1, 128}] | {1, 128} | | |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| batch_norm_0 ( batch_norm["relu_0"] ) | [{1, 128}] | {1, 128} | epsilon: 1.0e-5 | gamma: f32[128] |
|
|
| | | | channel_index: 1 | beta: f32[128] |
|
|
| | | | momentum: 0.1 | mean: f32[128] |
|
|
| | | | | var: f32[128] |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| dropout_0 ( dropout["batch_norm_0"] ) | [{1, 128}] | {1, 128} | rate: 0.8 | |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| dense_1 ( dense["dropout_0"] ) | [{1, 128}] | {1, 64} | | kernel: f32[128][64] |
|
|
| | | | | bias: f32[64] |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| tanh_0 ( tanh["dense_1"] ) | [{1, 64}] | {1, 64} | | |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| dense_2 ( dense["tanh_0"] ) | [{1, 64}] | {1, 10} | | kernel: f32[64][10] |
|
|
| | | | | bias: f32[10] |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
| softmax_0 ( softmax["dense_2"] ) | [{1, 10}] | {1, 10} | | |
|
|
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
|
|
|
|
### Multiple Inputs
|
|
|
|
Creating a model with multiple inputs is as easy as declaring an
|
|
additional input in your Axon graph. Every input layer present in
|
|
the final Axon graph must be provided at model execution time.
|
|
|
|
inp1 = Axon.input("input_0", shape: {nil, 1})
|
|
inp2 = Axon.input("input_1", shape: {nil, 1})
|
|
|
|
# Both inputs will be used
|
|
model1 = Axon.add(inp1, inp2)
|
|
|
|
# Only inp2 will be used
|
|
model2 = Axon.add(inp2, inp2)
|
|
|
|
Axon graphs are immutable, which means composing and manipulating
|
|
an Axon graph creates an entirely new graph. Additionally, layer
|
|
names are lazily generated at model execution time. To avoid
|
|
non-deterministic input orderings and names, Axon requires each
|
|
input to have a unique binary identifier. You can then reference
|
|
inputs by name when passing to models at execution time:
|
|
|
|
inp1 = Axon.input("input_0", shape: {nil, 1})
|
|
inp2 = Axon.input("input_1", shape: {nil, 1})
|
|
|
|
model1 = Axon.add(inp1, inp2)
|
|
|
|
{init_fn, predict_fn} = Axon.build(model1)
|
|
|
|
params1 = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
|
|
# Inputs are referenced by name
|
|
predict_fn.(params1, %{"input_0" => x, "input_1" => y})
|
|
|
|
### Multiple Outputs
|
|
|
|
Nx offers robust [container](https://hexdocs.pm/nx/Nx.Container.html) support
|
|
which is extended to Axon. Axon allows you to wrap any valid Nx container
|
|
in a layer. Containers are most commonly used to structure outputs:
|
|
|
|
inp1 = Axon.input("input_0", shape: {nil, 1})
|
|
inp2 = Axon.input("input_1", shape: {nil, 1})
|
|
model = Axon.container(%{foo: inp1, bar: inp2})
|
|
|
|
Containers can be arbitrarily nested:
|
|
|
|
inp1 = Axon.input("input_0", shape: {nil, 1})
|
|
inp2 = Axon.input("input_1", shape: {nil, 1})
|
|
model = Axon.container({%{foo: {inp1, %{bar: inp2}}}})
|
|
|
|
You can even use custom structs which implement the container protocol:
|
|
|
|
inp1 = Axon.input("input_0", shape: {nil, 1})
|
|
inp2 = Axon.input("input_1", shape: {nil, 1})
|
|
model = Axon.container(%MyStruct{foo: inp1, bar: inp2})
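
For the custom struct to work as a container, it must implement the
`Nx.Container` protocol, which you can usually derive. For example (a
sketch, assuming `MyStruct` only carries Axon nodes):

    defmodule MyStruct do
      @derive {Nx.Container, containers: [:foo, :bar]}
      defstruct [:foo, :bar]
    end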
|
|
|
|
### Custom Layers
|
|
|
|
If you find that Axon's built-in layers are insufficient for your needs,
|
|
you can create your own using the custom layer API. All of Axon's built-in
|
|
layers (aside from special ones such as `input`, `constant`, and `container`)
|
|
make use of this same API.
|
|
|
|
Axon layers are really just placeholders for Nx computations with trainable
|
|
parameters and possibly state. To define a custom layer, you just need to
|
|
define a `defn` implementation:
|
|
|
|
defn my_layer(x, weight, _opts \\\\ []) do
|
|
Nx.atan2(x, weight)
|
|
end
|
|
|
|
Notice the only stipulation is that your custom layer implementation must
|
|
accept at least 1 input and a list of options. At execution time, every
|
|
layer will be passed a `:mode` option which can be used to control behavior
|
|
at training and inference time.
|
|
|
|
Inputs to your custom layer can be either Axon graph inputs or trainable
|
|
parameters. You can pass Axon graph inputs as-is to a custom layer. To
|
|
declare trainable parameters, use `Axon.param/3`:
|
|
|
|
weight = Axon.param("weight", param_shape)
|
|
|
|
To create a custom layer, you "wrap" your implementation and inputs into
|
|
a layer using `Axon.layer`. You'll notice the API mirrors Elixir's `apply`:
|
|
|
|
def atan2_layer(%Axon{} = input) do
|
|
weight = Axon.param("weight", param_shape)
|
|
Axon.layer(&my_layer/3, [input, weight])
|
|
end
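
Every layer also receives a `:mode` option at execution time, so custom
layers can change their behavior between training and inference. Here is a
sketch of a layer which scales its input only during training (the scaling
factor is purely illustrative):

    def scale_during_training(%Axon{} = input) do
      Axon.layer(
        fn x, opts ->
          # :mode is always present in the layer options at execution time
          case opts[:mode] do
            :train -> Nx.multiply(x, 0.5)
            _ -> x
          end
        end,
        [input],
        op_name: :scale_during_training
      )
    end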
|
|
|
|
## Model Execution
|
|
|
|
Under the hood, Axon models are represented as Elixir structs. You
|
|
can initialize and apply models by building or compiling them with
|
|
`Axon.build/2` or `Axon.compile/4` and then calling the produced
|
|
initialization and predict functions:
|
|
|
|
{init_fn, predict_fn} = Axon.build(model)
|
|
|
|
params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
|
|
predict_fn.(params, inputs)
|
|
|
|
You may either set the default JIT compiler or backend globally, or
|
|
pass a specific compiler to `Axon.build/2`:
|
|
|
|
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
|
|
|
|
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA, mode: :train)
|
|
|
|
params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
|
|
predict_fn.(params, inputs)
|
|
|
|
`predict_fn` by default runs in inference mode, which performs certain
|
|
optimizations and removes layers such as dropout layers. If constructing
|
|
a training step using `Axon.predict/4` or `Axon.build/2`, be sure to specify
|
|
`mode: :train`.
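
When built with `mode: :train`, the prediction function returns a map with
the prediction and any updated layer state (such as batch norm statistics):

    {init_fn, predict_fn} = Axon.build(model, mode: :train)

    params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
    %{prediction: preds, state: state} = predict_fn.(params, inputs)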
|
|
|
|
## Model Training
|
|
|
|
Combining the Axon model creation API with the optimization and training
|
|
APIs, you can create and train neural networks with ease:
|
|
|
|
model =
|
|
Axon.input("input_0", shape: {nil, 784})
|
|
|> Axon.dense(128, activation: :relu)
|
|
|> Axon.layer_norm()
|
|
|> Axon.dropout()
|
|
|> Axon.dense(10, activation: :softmax)
|
|
|
|
IO.inspect model
|
|
|
|
model_state =
|
|
model
|
|
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005))
|
|
|> Axon.Loop.run(train_data, epochs: 10, compiler: EXLA)
|
|
|
|
See `Polaris.Updates` and `Axon.Loop` for a more in-depth treatment of
|
|
model optimization and model training.
|
|
|
|
## Using with `Nx.Serving`
|
|
|
|
When deploying an `Axon` model to production, you usually want to batch
|
|
multiple prediction requests and run the inference for all of them at
|
|
once. Conveniently, `Nx` already has an abstraction for this task in the
|
|
form of `Nx.Serving`. Here's how you could define a serving for an `Axon`
|
|
model:
|
|
|
|
def build_serving() do
|
|
# Configuration
|
|
batch_size = 4
|
|
defn_options = [compiler: EXLA]
|
|
|
|
Nx.Serving.new(
|
|
# This function runs on the serving startup
|
|
fn ->
|
|
# Build the Axon model and load params (usually from file)
|
|
model = build_model()
|
|
params = load_params()
|
|
|
|
# Build the prediction defn function
|
|
{_init_fun, predict_fun} = Axon.build(model)
|
|
|
|
inputs_template = %{"pixel_values" => Nx.template({batch_size, 224, 224, 3}, :f32)}
|
|
template_args = [Nx.to_template(params), inputs_template]
|
|
|
|
# Compile the prediction function upfront for the configured batch_size
|
|
predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options)
|
|
|
|
# The returned function is called for every accumulated batch
|
|
fn inputs ->
|
|
inputs = Nx.Batch.pad(inputs, batch_size - inputs.size)
|
|
predict_fun.(params, inputs)
|
|
end
|
|
end,
|
|
batch_size: batch_size
|
|
)
|
|
end
|
|
|
|
Then you would start the serving server as part of your application's
|
|
supervision tree:
|
|
|
|
children = [
|
|
...,
|
|
{Nx.Serving, serving: build_serving(), name: MyApp.Serving, batch_timeout: 100}
|
|
]
|
|
|
|
With that in place, you can now ask serving for predictions all across
|
|
your application (controllers, live views, async jobs, etc.). Given a
|
|
tensor input, you would do:
|
|
|
|
inputs = %{"pixel_values" => ...}
|
|
batch = Nx.Batch.concatenate([inputs])
|
|
result = Nx.Serving.batched_run(MyApp.Serving, batch)
|
|
|
|
Usually you also want to do pre/post-processing of the model input/output.
|
|
You could make those preparations directly before/after `Nx.Serving.batched_run/2`,
|
|
but you can also use `Nx.Serving.client_preprocessing/2` and
|
|
`Nx.Serving.client_postprocessing/2` to encapsulate that logic as part of
|
|
the serving.
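
For example, a client preprocessing step might wrap a raw tensor input into
a batch before it is sent to the serving process (a sketch; `:client_info`
is an arbitrary placeholder):

    serving =
      build_serving()
      |> Nx.Serving.client_preprocessing(fn input ->
        {Nx.Batch.concatenate([input]), :client_info}
      end)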
|
|
"""
|
|
alias __MODULE__, as: Axon
|
|
alias Axon.Parameter
|
|
|
|
import Axon.Shared
|
|
|
|
require Logger
|
|
|
|
@type t :: %__MODULE__{}
|
|
|
|
defstruct [
|
|
:nodes,
|
|
:output
|
|
]
|
|
|
|
@doc """
|
|
Custom Axon layer with given inputs.
|
|
|
|
Inputs may be other Axon layers or trainable parameters created
|
|
with `Axon.param`. At inference time, `op` will be applied with
|
|
inputs in specified order and an additional `opts` parameter which
|
|
specifies inference options. All options passed to the layer are forwarded
to the inference function except:
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:op_name` - layer operation name used for inspection and when building the parameter map.
|
|
|
|
* `:mode` - whether the layer should run only during `:inference` or `:train`. Defaults to `:both`.
|
|
|
|
* `:global_options` - a list of global option names that this layer
|
|
supports. Global options passed to `build/2` will be forwarded to
|
|
the layer, as long as they are declared.
|
|
|
|
Note this means your layer should not use these as input options,
|
|
as they will always be dropped during inference compilation.
|
|
|
|
Axon's compiler will additionally forward the following options to
|
|
every layer at inference time:
|
|
|
|
* `:mode` - `:inference` or `:train`. Use this to control layer behavior
  at inference or training time.
|
|
|
|
`op` is a function of the form:
|
|
|
|
fun = fn input, weight, bias, _opts ->
|
|
input * weight + bias
|
|
end
|
|
|
|
"""
|
|
@doc type: :special
|
|
def layer(op, inputs, opts \\ []) when (is_atom(op) or is_function(op)) and is_list(inputs) do
|
|
{inputs, params, args, updated_nodes} = split_inputs(op, inputs)
|
|
|
|
inputs = Enum.reverse(inputs)
|
|
params = Enum.reverse(params)
|
|
args = Enum.reverse(args)
|
|
|
|
{mode, opts} = Keyword.pop(opts, :mode, :both)
|
|
{name, opts} = Keyword.pop(opts, :name)
|
|
{op_name, opts} = Keyword.pop(opts, :op_name, :custom)
|
|
{global_options, opts} = Keyword.pop(opts, :global_options, [])
|
|
{meta, opts} = Keyword.pop(opts, :meta, %{})
|
|
name = name(op_name, name)
|
|
|
|
id = System.unique_integer([:positive, :monotonic])
|
|
|
|
axon_node =
|
|
make_node(id, op, name, op_name, mode, inputs, params, args, meta, opts, global_options)
|
|
|
|
%Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
|
|
end
|
|
|
|
defp make_node(
|
|
id,
|
|
op,
|
|
name,
|
|
op_name,
|
|
mode,
|
|
inputs,
|
|
params,
|
|
args,
|
|
meta,
|
|
layer_opts,
|
|
global_options
|
|
) do
|
|
{:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
|
|
Process.info(self(), :current_stacktrace)
|
|
|
|
%Axon.Node{
|
|
id: id,
|
|
mode: mode,
|
|
name: name,
|
|
parent: inputs,
|
|
parameters: params,
|
|
args: args,
|
|
op: op,
|
|
policy: Axon.MixedPrecision.create_policy(),
|
|
hooks: [],
|
|
opts: layer_opts,
|
|
global_options: global_options,
|
|
op_name: op_name,
|
|
meta: meta,
|
|
stacktrace: stacktrace
|
|
}
|
|
end
|
|
|
|
defp split_inputs(_op, inputs) do
|
|
Enum.reduce(inputs, {[], [], [], %{}}, fn
|
|
%Axon{output: layer_input, nodes: nodes}, {layers, params, args, cache} ->
|
|
{[layer_input | layers], params, [:layer | args], Map.merge(nodes, cache)}
|
|
|
|
%Parameter{} = param, {layers, params, args, cache} ->
|
|
{layers, [param | params], [:parameter | args], cache}
|
|
|
|
invalid, _ ->
|
|
raise ArgumentError, "invalid input given to layer: #{inspect(invalid)}"
|
|
end)
|
|
end
|
|
|
|
@doc """
|
|
Trainable Axon parameter used to create custom layers.
|
|
|
|
Parameters are specified in usages of `Axon.layer` and will be
|
|
automatically initialized and used in subsequent applications of
|
|
Axon models.
|
|
|
|
You must specify a parameter "template" which can be a static template
|
|
tensor or a function which takes model input templates and returns a
|
|
template. It's most common to use functions because most parameters'
|
|
shapes rely on input shape information.
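
For example, assuming a layer with a single input, a parameter whose shape
depends on that input could be declared as (the name and output size here
are illustrative):

    weight =
      Axon.parameter("weight", fn input ->
        Nx.template({Nx.axis_size(input, -1), 64}, :f32)
      end)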
|
|
"""
|
|
@doc type: :special
|
|
def parameter(name, template, opts \\ [])
|
|
|
|
def parameter(name, %Nx.Tensor{} = template, opts) do
|
|
opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
|
|
initializer = validate_initializer!(opts[:initializer])
|
|
kind = opts[:kind] || :parameter
|
|
|
|
template = Nx.to_template(template)
|
|
|
|
%Axon.Parameter{
|
|
name: name,
|
|
template: template,
|
|
initializer: initializer,
|
|
kind: kind,
|
|
# Legacy
|
|
type: Nx.type(template),
|
|
shape: Nx.shape(template)
|
|
}
|
|
end
|
|
|
|
def parameter(name, function, opts) when is_function(function) do
|
|
opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
|
|
initializer = validate_initializer!(opts[:initializer])
|
|
kind = opts[:kind] || :parameter
|
|
|
|
%Axon.Parameter{
|
|
name: name,
|
|
template: function,
|
|
initializer: initializer,
|
|
kind: kind
|
|
}
|
|
end
|
|
|
|
@doc """
|
|
Trainable Axon parameter used to create custom layers.
|
|
|
|
Parameters are specified in usages of `Axon.layer` and will
|
|
be automatically initialized and used in subsequent applications
|
|
of Axon models.
|
|
|
|
You may specify the parameter shape as either a static shape or
|
|
as a function of the inputs to the given layer. If you specify the
parameter shape as a function, it will be given the shapes of every
layer input and must return the desired parameter shape.
|
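
For example, a sketch of a kernel whose shape depends on the layer's single
rank-2 input (the 64 output units are illustrative):

    kernel = Axon.param("kernel", fn input_shape -> {elem(input_shape, 1), 64} end)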
|
|
|
## Options
|
|
|
|
* `:initializer` - parameter initializer. Defaults to `:glorot_uniform`.
|
|
|
|
"""
|
|
@doc type: :special
|
|
def param(name, shape, opts \\ [])
|
|
|
|
def param(name, shape, opts) when is_binary(name) and is_tuple(shape) do
|
|
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
|
|
{type, opts} = Keyword.pop(opts, :type, {:f, 32})
|
|
|
|
template = Nx.template(shape, type)
|
|
parameter(name, template, opts)
|
|
end
|
|
|
|
def param(name, shape, opts) when is_binary(name) and is_function(shape) do
|
|
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
|
|
{type, opts} = Keyword.pop(opts, :type, {:f, 32})
|
|
|
|
{:arity, arity} = Function.info(shape, :arity)
|
|
|
|
template =
|
|
shape_fun(arity, fn templates ->
|
|
shapes = Enum.map(List.wrap(templates), &Nx.shape/1)
|
|
out_shape = apply(shape, shapes)
|
|
Nx.template(out_shape, type)
|
|
end)
|
|
|
|
parameter(name, template, opts)
|
|
end
|
|
|
|
for i <- 0..128 do
|
|
args = Macro.generate_arguments(i, __MODULE__)
|
|
|
|
defp shape_fun(unquote(i), callback) do
|
|
fn unquote_splicing(args) -> callback.(unquote(args)) end
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds an input layer to the network.
|
|
|
|
Input layers specify a model's inputs. Input layers are
|
|
always the root layers of the neural network.
|
|
|
|
You must specify the input layer's name, which will be used
|
|
to uniquely identify it in the case of multiple inputs.
|
|
|
|
## Options
|
|
|
|
* `:shape` - the expected input shape, use `nil` for dimensions
|
|
of a dynamic size.
|
|
|
|
* `:optional` - if `true`, the input may be omitted when using
|
|
the model. This needs to be handled in one of the subsequent
|
|
layers. See `optional/2` for more details.
|
|
|
|
"""
|
|
@doc type: :special
|
|
def input(name, opts \\ [])
|
|
|
|
def input(name, opts) when is_binary(name) and is_list(opts) do
|
|
opts = Keyword.validate!(opts, [:shape, :meta, optional: false])
|
|
optional = opts[:optional]
|
|
meta = opts[:meta]
|
|
|
|
input_shape = opts[:shape]
|
|
|
|
output_shape = input_shape && Axon.Shape.input(input_shape)
|
|
|
|
layer(:input, [],
|
|
name: name,
|
|
shape: output_shape,
|
|
meta: meta,
|
|
op_name: :input,
|
|
optional: optional
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Wraps an Axon model in an optional node.
|
|
|
|
By default, when an optional input is missing, all subsequent layers
|
|
are nullified. For example, consider this model:
|
|
|
|
values = Axon.input("values")
|
|
mask = Axon.input("mask", optional: true)
|
|
|
|
model =
|
|
values
|
|
|> Axon.dense(10)
|
|
|> Axon.multiply(mask)
|
|
|> Axon.dense(1)
|
|
|> Axon.sigmoid()
|
|
|
|
In case the mask is not provided, the input node will resolve to
|
|
`%Axon.None{}` and so will all the layers that depend on it. By
|
|
using `optional/2` a layer may opt-in to receive `%Axon.None{}`.
|
|
To fix our example, we could define a custom layer to apply the
|
|
mask only when present:
|
|
|
|
def apply_optional_mask(%Axon{} = x, %Axon{} = mask) do
|
|
Axon.layer(
|
|
fn x, mask, _opts ->
|
|
case mask do
|
|
%Axon.None{} -> x
|
|
mask -> Nx.multiply(x, mask)
|
|
end
|
|
end,
|
|
[x, Axon.optional(mask)]
|
|
)
|
|
end
|
|
|
|
# ...
|
|
|
|
model =
|
|
values
|
|
|> Axon.dense(10)
|
|
|> apply_optional_mask(mask)
|
|
|> Axon.dense(1)
|
|
|> Axon.sigmoid()
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :special
|
|
def optional(%Axon{} = x, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
layer(:optional, [x], name: opts[:name], meta: opts[:meta], op_name: :optional)
|
|
end
|
|
|
|
@doc """
|
|
Implements an "or else" operation: returns the first input unless it resolves to `%Axon.None{}`, in which case the second input is returned (similar to Elixir's `||`).
|
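
For example, you might fall back to a constant when an optional input is
missing:

    mask = Axon.input("mask", optional: true)
    default = Axon.constant(Nx.tensor([[1.0]]))
    mask_or_default = Axon.or_else(mask, default)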
|
"""
|
|
@doc type: :special
|
|
def or_else(%Axon{} = a, %Axon{} = b, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
Axon.layer(
|
|
fn x, y, _ ->
|
|
case x do
|
|
%Axon.None{} -> y
|
|
_ -> x
|
|
end
|
|
end,
|
|
[a, b],
|
|
op_name: :or_else,
|
|
name: opts[:name],
|
|
meta: opts[:meta]
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a constant layer to the network.
|
|
|
|
Constant layers encapsulate Nx tensors in an Axon layer for ease
|
|
of use with other Axon layers. They can be used interchangeably
|
|
with other Axon layers:
|
|
|
|
inp = Axon.input("input", shape: {nil, 32})
|
|
my_constant = Axon.constant(Nx.iota({1, 32}))
|
|
model = Axon.add(inp, my_constant)
|
|
|
|
Constant layers will be cast according to the mixed precision policy.
|
|
If it's important for your constant to retain its type during
|
|
the computation, you will need to set the mixed precision policy to
|
|
ignore constant layers.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
def constant(tensor, opts \\ [])
|
|
|
|
@doc type: :special
|
|
def constant(%Nx.Tensor{} = tensor, opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(:constant, [], name: opts[:name], meta: opts[:meta], value: tensor, op_name: :constant)
|
|
end
|
|
|
|
def constant(number, opts) when is_number(number) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(:constant, [],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
value: Nx.tensor(number),
|
|
op_name: :constant
|
|
)
|
|
end
|
|
|
|
def constant(value, _) do
|
|
raise ArgumentError,
      "value passed to constant must be an Nx tensor or a number," <>
        " but got #{inspect(value)}"
|
|
end
|
|
|
|
@doc """
|
|
Adds a container layer to the network.
|
|
|
|
In certain cases you may want your model to have multiple
|
|
outputs. In order to make this work, you must "join" the
|
|
outputs into an Axon layer using this function for use in
|
|
initialization and inference later on.
|
|
|
|
The given container can be any valid Nx container of Axon layers.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
## Examples
|
|
|
|
iex> inp1 = Axon.input("input_0", shape: {nil, 1})
|
|
iex> inp2 = Axon.input("input_1", shape: {nil, 2})
|
|
iex> model = Axon.container(%{a: inp1, b: inp2})
|
|
iex> %{a: a, b: b} = Axon.predict(model, Axon.ModelState.empty(), %{
|
|
...> "input_0" => Nx.tensor([[1.0]]),
|
|
...> "input_1" => Nx.tensor([[1.0, 2.0]])
|
|
...> })
|
|
iex> a
|
|
#Nx.Tensor<
|
|
f32[1][1]
|
|
[
|
|
[1.0]
|
|
]
|
|
>
|
|
iex> b
|
|
#Nx.Tensor<
|
|
f32[1][2]
|
|
[
|
|
[1.0, 2.0]
|
|
]
|
|
>
|
|
|
|
"""
|
|
@doc type: :special
|
|
def container(container, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
{structure_fn, nodes} = destructure(container)
|
|
layer(structure_fn, nodes, name: opts[:name], meta: opts[:meta], op_name: :container)
|
|
end
|
|
|
|
defp destructure(container) do
|
|
{structure, {nodes, _}} = recur_destructure(container, {[], 0})
|
|
fun = restructure(length(nodes) + 1, structure)
|
|
{fun, Enum.reverse(nodes)}
|
|
end
|
|
|
|
defp recur_destructure(container, acc) do
|
|
Nx.Container.traverse(container, acc, fn value, {leaves, idx} ->
|
|
case value do
|
|
%Axon{} = leaf ->
|
|
{idx, {[leaf | leaves], idx + 1}}
|
|
|
|
container ->
|
|
recur_destructure(container, {leaves, idx})
|
|
end
|
|
end)
|
|
end
|
|
|
|
for i <- 0..128 do
|
|
args = Macro.generate_arguments(i, __MODULE__)
|
|
|
|
defp restructure(unquote(i), structure) do
|
|
fn unquote_splicing(args) ->
|
|
args_tuple = {unquote_splicing(args)}
|
|
{container, :ok} = recur_restructure(structure, args_tuple)
|
|
container
|
|
end
|
|
end
|
|
end
|
|
|
|
defp recur_restructure(structure, args_tuple) do
|
|
Nx.Container.traverse(structure, :ok, fn value, :ok ->
|
|
case value do
|
|
idx when is_integer(idx) -> {elem(args_tuple, idx), :ok}
|
|
container -> recur_restructure(container, args_tuple)
|
|
end
|
|
end)
|
|
end
|
|
|
|
@doc """
|
|
Returns a function which represents a self-contained re-usable block
|
|
of operations in a neural network. All parameters in the block are
|
|
shared between every usage of the block.
|
|
|
|
This returns an arity-1 function which accepts a list of inputs which
|
|
are forwarded to `fun`. This is most often used in situations where
|
|
you wish to re-use parameters in a block:
|
|
|
|
reused_dense = Axon.block(&Axon.dense(&1, 32))
|
|
|
|
Every time `reused_dense` is invoked, it re-uses the same parameters:
|
|
|
|
input = Axon.input("features")
|
|
# unique parameters
|
|
x1 = Axon.dense(input, 32)
|
|
# unique parameters
|
|
x2 = reused_dense.(x1)
|
|
# parameters shared
|
|
x3 = reused_dense.(x2)
|
|
|
|
Subgraphs in blocks can be arbitrarily complex:
|
|
|
|
reused_block = Axon.block(fn x ->
|
|
x
|
|
|> Axon.dense(32)
|
|
|> Axon.dense(64)
|
|
|> Axon.dense(32)
|
|
end)
|
|
|
|
Blocks can also have multiple inputs. You can invoke such a block by
passing a list of arguments:
|
|
|
|
reused_block = Axon.block(fn x, y, z ->
|
|
x = Axon.dense(x, 32)
|
|
y = Axon.dense(y, 32)
|
|
z = Axon.dense(z, 32)
|
|
|
|
Axon.add([x, y, z])
|
|
end)
|
|
|
|
# invoke with a list
|
|
reused_block.([x, y, z])
|
|
|
|
Blocks prefix subgraph parameters with their name and a dot. As with other
|
|
Axon layers, if a name is not explicitly provided, one will be dynamically
|
|
generated.
|
|
"""
|
|
@doc type: :special
|
|
def block(fun, opts \\ []) when is_function(fun) do
|
|
{:arity, arity} = Function.info(fun, :arity)
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
block_id = System.unique_integer([:positive, :monotonic])
|
|
|
|
block_fun(arity, fn inputs ->
|
|
layer(:block, List.wrap(inputs),
|
|
op_name: :block,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
block_fun: fun,
|
|
block_id: block_id
|
|
)
|
|
end)
|
|
end
|
|
|
|
for i <- 0..128 do
|
|
args = Macro.generate_arguments(i, __MODULE__)
|
|
|
|
defp block_fun(unquote(i), callback) do
|
|
fn unquote_splicing(args) -> callback.(unquote(args)) end
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a dense layer to the network.
|
|
|
|
The dense layer implements:
|
|
|
|
output = activation(dot(input, kernel) + bias)
|
|
|
|
where `activation` is given by the `:activation` option and both
|
|
`kernel` and `bias` are layer parameters. `units` specifies the
|
|
number of output units.
|
|
|
|
Compiles to `Axon.Layers.dense/4`.
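
For example, a dense layer with 64 units and a ReLU activation:

    model =
      Axon.input("features", shape: {nil, 32})
      |> Axon.dense(64, activation: :relu)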
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`.
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`.
|
|
|
|
"""
|
|
@doc type: :linear
|
|
def dense(%Axon{} = x, units, opts \\ [])
|
|
when is_integer(units) and units > 0 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true
|
|
])
|
|
|
|
meta =
|
|
opts[:meta] ||
|
|
%{}
|
|
|> Map.put(:units, units)
|
|
|> Map.put(:use_bias, opts[:use_bias])
|
|
|
|
kernel_shape = &Axon.Shape.dense_kernel(&1, units)
|
|
bias_shape = &Axon.Shape.dense_bias(&1, units)
|
|
|
|
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
|
|
{[x, kernel, bias], :dense}
|
|
else
|
|
{[x, kernel], :dense}
|
|
end
|
|
|
|
node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a bilinear layer to the network.
|
|
|
|
The bilinear layer implements:
|
|
|
|
output = activation(dot(dot(input1, kernel), input2) + bias)
|
|
|
|
where `activation` is given by the `:activation` option and both
|
|
`kernel` and `bias` are layer parameters. `units` specifies the
|
|
number of output units.
|
|
|
|
All dimensions but the last of `input1` and `input2` must match. The
|
|
batch sizes of both inputs must also match or at least one must be `nil`.
|
|
Inferred output batch size coerces to the strictest input batch size.
|
|
|
|
Compiles to `Axon.Layers.bilinear/5`.
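
For example, combining two inputs into 8 output units:

    input1 = Axon.input("input1", shape: {nil, 32})
    input2 = Axon.input("input2", shape: {nil, 16})
    model = Axon.bilinear(input1, input2, 8)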
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`.
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`.
|
|
|
|
"""
|
|
@doc type: :linear
|
|
def bilinear(
|
|
%Axon{} = input1,
|
|
%Axon{} = input2,
|
|
units,
|
|
opts \\ []
|
|
)
|
|
when is_integer(units) and units > 0 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true
|
|
])
|
|
|
|
kernel_shape = &Axon.Shape.bilinear_kernel(&1, &2, units)
|
|
bias_shape = &Axon.Shape.bilinear_bias(&1, &2, units)
|
|
|
|
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
|
|
{[input1, input2, kernel, bias], :bilinear}
|
|
else
|
|
{[input1, input2, kernel], :bilinear}
|
|
end
|
|
|
|
node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :bilinear)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a convolution layer to the network.
|
|
|
|
The convolution layer implements a general dimensional
|
|
convolutional layer - which convolves a kernel over the input
|
|
to produce an output.
|
|
|
|
Compiles to `Axon.Layers.conv/4`.
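
For example, a convolution over single-channel images with channels last:

    model =
      Axon.input("images", shape: {nil, 28, 28, 1})
      |> Axon.conv(16, kernel_size: 3, activation: :relu)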
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`
|
|
|
|
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
|
|
to `1`.
|
|
|
|
* `:strides` - stride during convolution. Defaults to `1`.
|
|
|
|
* `:padding` - padding to the spatial dimensions of the input.
|
|
Defaults to `:valid`.
|
|
|
|
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
|
|
|
|
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
|
|
|
|
* `:feature_group_size` - feature group size for convolution. Defaults
|
|
to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :convolution
|
|
def conv(%Axon{} = x, units, opts \\ [])
|
|
when is_integer(units) and units > 0 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true,
|
|
kernel_size: 1,
|
|
strides: 1,
|
|
padding: :valid,
|
|
input_dilation: 1,
|
|
kernel_dilation: 1,
|
|
channels: :last,
|
|
feature_group_size: 1
|
|
])
|
|
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
padding = opts[:padding]
|
|
input_dilation = opts[:input_dilation]
|
|
kernel_dilation = opts[:kernel_dilation]
|
|
channels = opts[:channels]
|
|
feature_group_size = opts[:feature_group_size]
|
|
|
|
kernel_shape = &Axon.Shape.conv_kernel(&1, units, kernel_size, channels, feature_group_size)
|
|
bias_shape = &Axon.Shape.conv_bias(&1, units, kernel_size, channels, feature_group_size)
|
|
|
|
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
|
|
{[x, kernel, bias], :conv}
|
|
else
|
|
{[x, kernel], :conv}
|
|
end
|
|
|
|
node =
|
|
layer(op, inputs,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
strides: strides,
|
|
padding: padding,
|
|
input_dilation: input_dilation,
|
|
kernel_dilation: kernel_dilation,
|
|
feature_group_size: feature_group_size,
|
|
channels: channels,
|
|
op_name: :conv
|
|
)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a transposed convolution layer to the network.
|
|
|
|
The transposed convolution layer is sometimes referred to as a
|
|
fractionally strided convolution or (incorrectly) as a deconvolution.
|
|
|
|
Compiles to `Axon.Layers.conv_transpose/4`.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`
|
|
|
|
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
|
|
to `1`.
|
|
|
|
* `:strides` - stride during convolution. Defaults to `1`.
|
|
|
|
* `:padding` - padding to the spatial dimensions of the input.
|
|
Defaults to `:valid`.
|
|
|
|
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :convolution
|
|
def conv_transpose(%Axon{} = x, units, opts \\ []) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true,
|
|
kernel_size: 1,
|
|
strides: 1,
|
|
padding: :valid,
|
|
kernel_dilation: 1,
|
|
channels: :last
|
|
])
|
|
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
padding = opts[:padding]
|
|
kernel_dilation = opts[:kernel_dilation]
|
|
channels = opts[:channels]
|
|
|
|
kernel_shape = &Axon.Shape.conv_kernel(&1, units, kernel_size, channels, 1)
|
|
bias_shape = &Axon.Shape.conv_bias(&1, units, kernel_size, channels, 1)
|
|
|
|
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
|
|
{[x, kernel, bias], :conv_transpose}
|
|
else
|
|
{[x, kernel], :conv_transpose}
|
|
end
|
|
|
|
node =
|
|
layer(op, inputs,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
strides: strides,
|
|
padding: padding,
|
|
kernel_dilation: kernel_dilation,
|
|
channels: channels,
|
|
op_name: :conv_transpose
|
|
)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a depthwise convolution layer to the network.
|
|
|
|
The depthwise convolution layer implements a general
|
|
dimensional depthwise convolution - which is a convolution
|
|
where the feature group size is equal to the number of
|
|
input channels.
|
|
|
|
The channel multiplier grows the input channels by the given
factor. A channel multiplier of 1 means the output channels
are the same as the input channels.
|
|
|
|
Compiles to `Axon.Layers.depthwise_conv/4`.
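
For example, with a channel multiplier of 2, a 3-channel input produces
6 output channels:

    model =
      Axon.input("images", shape: {nil, 28, 28, 3})
      |> Axon.depthwise_conv(2, kernel_size: 3)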
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`
|
|
|
|
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
|
|
to `1`.
|
|
|
|
* `:strides` - stride during convolution. Defaults to `1`.
|
|
|
|
* `:padding` - padding to the spatial dimensions of the input.
|
|
Defaults to `:valid`.
|
|
|
|
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
|
|
|
|
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :convolution
|
|
def depthwise_conv(%Axon{} = x, channel_multiplier, opts \\ [])
|
|
when is_integer(channel_multiplier) and channel_multiplier >= 1 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true,
|
|
kernel_size: 1,
|
|
strides: 1,
|
|
padding: :valid,
|
|
input_dilation: 1,
|
|
kernel_dilation: 1,
|
|
channels: :last
|
|
])
|
|
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
padding = opts[:padding]
|
|
input_dilation = opts[:input_dilation]
|
|
kernel_dilation = opts[:kernel_dilation]
|
|
channels = opts[:channels]
|
|
|
|
kernel_shape =
|
|
&Axon.Shape.depthwise_conv_kernel(&1, channel_multiplier, kernel_size, channels)
|
|
|
|
bias_shape = &Axon.Shape.depthwise_conv_bias(&1, channel_multiplier, kernel_size, channels)
|
|
|
|
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
|
|
|
|
{[x, kernel, bias], :depthwise_conv}
|
|
else
|
|
{[x, kernel], :depthwise_conv}
|
|
end
|
|
|
|
node =
|
|
layer(op, inputs,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
strides: strides,
|
|
padding: padding,
|
|
input_dilation: input_dilation,
|
|
kernel_dilation: kernel_dilation,
|
|
channels: channels,
|
|
op_name: :depthwise_conv
|
|
)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a depthwise separable 2-dimensional convolution to the
|
|
network.
|
|
|
|
Depthwise separable convolutions break the kernel into kernels
|
|
for each dimension of the input and perform a depthwise conv
|
|
over the input with each kernel.
|
|
|
|
Compiles to `Axon.Layers.separable_conv2d/6`.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`
|
|
|
|
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
|
|
to `1`.
|
|
|
|
* `:strides` - stride during convolution. Defaults to `1`.
|
|
|
|
* `:padding` - padding to the spatial dimensions of the input.
|
|
Defaults to `:valid`.
|
|
|
|
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
|
|
|
|
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :convolution
|
|
def separable_conv2d(%Axon{} = x, channel_multiplier, opts \\ [])
|
|
when is_integer(channel_multiplier) and channel_multiplier >= 1 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true,
|
|
kernel_size: 1,
|
|
strides: 1,
|
|
padding: :valid,
|
|
input_dilation: 1,
|
|
kernel_dilation: 1,
|
|
channels: :last
|
|
])
|
|
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
padding = opts[:padding]
|
|
input_dilation = opts[:input_dilation]
|
|
kernel_dilation = opts[:kernel_dilation]
|
|
channels = opts[:channels]
|
|
|
|
k1_shape =
|
|
&Axon.Shape.separable_conv2d_kernel(
|
|
&1,
|
|
channel_multiplier,
|
|
kernel_size,
|
|
1,
|
|
channels
|
|
)
|
|
|
|
k2_shape =
|
|
&Axon.Shape.separable_conv2d_kernel(
|
|
&1,
|
|
channel_multiplier,
|
|
kernel_size,
|
|
2,
|
|
channels
|
|
)
|
|
|
|
b1_shape = &Axon.Shape.separable_conv2d_bias(&1, channel_multiplier, kernel_size, channels)
|
|
b2_shape = &Axon.Shape.separable_conv2d_bias(&1, channel_multiplier, kernel_size, channels)
|
|
|
|
kernel_initializer = opts[:kernel_initializer]
|
|
k1 = param("kernel_1", k1_shape, initializer: kernel_initializer)
|
|
k2 = param("kernel_2", k2_shape, initializer: kernel_initializer)
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias_initializer = opts[:bias_initializer]
|
|
b1 = param("bias_1", b1_shape, initializer: bias_initializer)
|
|
b2 = param("bias_2", b2_shape, initializer: bias_initializer)
|
|
{[x, k1, b1, k2, b2], :separable_conv2d}
|
|
else
|
|
{[x, k1, k2], :separable_conv2d}
|
|
end
|
|
|
|
node =
|
|
layer(
|
|
op,
|
|
inputs,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
strides: strides,
|
|
padding: padding,
|
|
input_dilation: input_dilation,
|
|
kernel_dilation: kernel_dilation,
|
|
channels: channels,
|
|
op_name: :separable_conv2d
|
|
)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Adds a depthwise separable 3-dimensional convolution to the
|
|
network.
|
|
|
|
Depthwise separable convolutions break the kernel into kernels
|
|
for each dimension of the input and perform a depthwise conv
|
|
over the input with each kernel.
|
|
|
|
Compiles to `Axon.Layers.separable_conv3d/8`.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights.
|
|
Defaults to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`
|
|
|
|
* `:activation` - element-wise activation function.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`
|
|
|
|
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
|
|
to `1`.
|
|
|
|
* `:strides` - stride during convolution. Defaults to `1`.
|
|
|
|
* `:padding` - padding to the spatial dimensions of the input.
|
|
Defaults to `:valid`.
|
|
|
|
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
|
|
|
|
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :convolution
|
|
def separable_conv3d(%Axon{} = x, channel_multiplier, opts \\ [])
|
|
when is_integer(channel_multiplier) and channel_multiplier >= 1 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:activation,
|
|
:meta,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true,
|
|
kernel_size: 1,
|
|
strides: 1,
|
|
padding: :valid,
|
|
input_dilation: 1,
|
|
kernel_dilation: 1,
|
|
channels: :last
|
|
])
|
|
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
padding = opts[:padding]
|
|
input_dilation = opts[:input_dilation]
|
|
kernel_dilation = opts[:kernel_dilation]
|
|
channels = opts[:channels]
|
|
|
|
k1_shape =
|
|
&Axon.Shape.separable_conv3d_kernel(
|
|
&1,
|
|
channel_multiplier,
|
|
kernel_size,
|
|
1,
|
|
channels
|
|
)
|
|
|
|
k2_shape =
|
|
&Axon.Shape.separable_conv3d_kernel(
|
|
&1,
|
|
channel_multiplier,
|
|
kernel_size,
|
|
2,
|
|
channels
|
|
)
|
|
|
|
k3_shape =
|
|
&Axon.Shape.separable_conv3d_kernel(
|
|
&1,
|
|
channel_multiplier,
|
|
kernel_size,
|
|
3,
|
|
channels
|
|
)
|
|
|
|
b1_shape = &Axon.Shape.separable_conv3d_bias(&1, channel_multiplier, kernel_size, channels)
|
|
|
|
b2_shape = &Axon.Shape.separable_conv3d_bias(&1, channel_multiplier, kernel_size, channels)
|
|
|
|
b3_shape = &Axon.Shape.separable_conv3d_bias(&1, channel_multiplier, kernel_size, channels)
|
|
|
|
kernel_initializer = opts[:kernel_initializer]
|
|
k1 = param("kernel_1", k1_shape, initializer: kernel_initializer)
|
|
k2 = param("kernel_2", k2_shape, initializer: kernel_initializer)
|
|
k3 = param("kernel_3", k3_shape, initializer: kernel_initializer)
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias_initializer = opts[:bias_initializer]
|
|
b1 = param("bias_1", b1_shape, initializer: bias_initializer)
|
|
b2 = param("bias_2", b2_shape, initializer: bias_initializer)
|
|
b3 = param("bias_3", b3_shape, initializer: bias_initializer)
|
|
{[x, k1, b1, k2, b2, k3, b3], :separable_conv3d}
|
|
else
|
|
{[x, k1, k2, k3], :separable_conv3d}
|
|
end
|
|
|
|
node =
|
|
layer(
|
|
op,
|
|
inputs,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
strides: strides,
|
|
padding: padding,
|
|
input_dilation: input_dilation,
|
|
kernel_dilation: kernel_dilation,
|
|
channels: channels,
|
|
op_name: :separable_conv3d
|
|
)
|
|
|
|
if activation = opts[:activation] do
|
|
activation(node, activation)
|
|
else
|
|
node
|
|
end
|
|
end
|
|
|
|
@activation_layers [
|
|
{:celu, "Continuously-differentiable exponential linear unit", "a"},
|
|
{:elu, "Exponential linear unit", "an"},
|
|
{:exp, "Exponential", "an"},
|
|
{:gelu, "Gaussian error linear unit", "a"},
|
|
{:hard_sigmoid, "Hard sigmoid", "a"},
|
|
{:hard_silu, "Hard sigmoid weighted linear unit", "a"},
|
|
{:hard_tanh, "Hard hyperbolic tangent", "a"},
|
|
{:leaky_relu, "Leaky rectified linear unit", "a"},
|
|
{:linear, "Linear", "a"},
|
|
{:log_sumexp, "Log-sumexp", "a"},
|
|
{:log_sigmoid, "Log-sigmoid", "a"},
|
|
{:log_softmax, "Log-softmax", "a"},
|
|
{:mish, "Mish", "a"},
|
|
{:relu, "Rectified linear unit", "a"},
|
|
{:relu6, "Rectified linear unit 6", "a"},
|
|
{:sigmoid, "Sigmoid", "a"},
|
|
{:silu, "Sigmoid weighted linear unit", "a"},
|
|
{:selu, "Scaled exponential linear unit", "a"},
|
|
{:softmax, "Softmax", "a"},
|
|
{:softplus, "Softplus", "a"},
|
|
{:softsign, "Softsign", "a"},
|
|
{:tanh, "Hyperbolic tangent", "a"}
|
|
]
|
|
|
|
@doc """
|
|
Adds an activation layer to the network.
|
|
|
|
Activation layers are element-wise functions typically called
|
|
after the output of another layer.
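
Activations may be referenced by name or given as an arity-2 function of
the input and the layer options, for example:

    model =
      Axon.input("input", shape: {nil, 10})
      |> Axon.dense(10)
      |> Axon.activation(fn x, _opts -> Nx.cos(x) end)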
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :activation
|
|
def activation(x, activation, opts \\ [])
|
|
|
|
def activation(%Axon{} = x, activation, opts) when is_atom(activation) do
|
|
opts = opts ++ [op_name: activation]
|
|
layer(activation, [x], opts)
|
|
end
|
|
|
|
def activation(%Axon{} = x, activation, opts)
|
|
when is_function(activation) do
|
|
layer(activation, [x], opts)
|
|
end
|
|
|
|
## Activation
|
|
|
|
for {activation, name, a_or_an} <- @activation_layers do
|
|
@doc """
|
|
Adds #{a_or_an} #{name} activation layer to the network.
|
|
|
|
See `Axon.Activations.#{Atom.to_string(activation)}/1` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :activation
|
|
def unquote(activation)(%Axon{} = x, opts \\ []) do
|
|
activation(x, unquote(activation), opts)
|
|
end
|
|
end
|
|
|
|
## Dropout
|
|
|
|
@dropout_layers [
|
|
{:dropout, "Dropout", "a"},
|
|
{:feature_alpha_dropout, "Feature alpha dropout", "a"},
|
|
{:spatial_dropout, "Spatial dropout", "a"},
|
|
{:alpha_dropout, "Alpha dropout", "an"}
|
|
]
|
|
|
|
for {dropout, name, a_or_an} <- @dropout_layers do
|
|
@doc """
|
|
Adds #{a_or_an} #{name} layer to the network.
|
|
|
|
See `Axon.Layers.#{Atom.to_string(dropout)}/2` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:rate` - dropout rate. Defaults to `0.5`.
|
|
Must be greater than or equal to zero and less than one.
|
|
|
|
"""
|
|
@doc type: :dropout
|
|
def unquote(dropout)(%Axon{} = x, opts \\ []) do
|
|
dropout(x, unquote(dropout), opts)
|
|
end
|
|
end
|
|
|
|
defp dropout(%Axon{} = x, dropout, opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, :seed, rate: 0.5])
|
|
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
|
|
|
|
if opts[:rate] < 0 or opts[:rate] >= 1 do
|
|
raise ArgumentError,
|
|
"The dropout rate needs to be >= 0 and < 1, got #{inspect(opts[:rate])}"
|
|
end
|
|
|
|
key_state =
|
|
param("key", fn _ -> {2} end,
|
|
type: {:u, 32},
|
|
initializer: fn _, _ -> Nx.Random.key(seed) end,
|
|
kind: :state
|
|
)
|
|
|
|
layer(dropout, [x, key_state],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
rate: opts[:rate],
|
|
op_name: dropout,
|
|
mode: :train
|
|
)
|
|
end
|
|
|
|
## Pooling
|
|
|
|
@pooling_layers [
|
|
{:max_pool, "Max pool", "a"},
|
|
{:avg_pool, "Average pool", "an"},
|
|
{:lp_pool, "Power average pool", "a"}
|
|
]
|
|
|
|
for {pool, name, a_or_an} <- @pooling_layers do
|
|
@doc """
|
|
Adds #{a_or_an} #{name} layer to the network.
|
|
|
|
See `Axon.Layers.#{Atom.to_string(pool)}/2` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
|
|
to `1`.
|
|
|
|
* `:strides` - stride during convolution. Defaults to size of kernel.
|
|
|
|
* `:padding` - padding to the spatial dimensions of the input.
|
|
Defaults to `:valid`.
|
|
|
|
* `:dilations` - window dilations. Defaults to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :pooling
|
|
def unquote(pool)(%Axon{} = x, opts \\ []) do
|
|
pool(x, unquote(pool), opts)
|
|
end
|
|
end
|
|
|
|
defp pool(%Axon{} = x, pool, opts) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:strides,
|
|
:meta,
|
|
kernel_size: 1,
|
|
padding: :valid,
|
|
channels: :last,
|
|
dilations: 1,
|
|
norm: 2
|
|
])
|
|
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
padding = opts[:padding]
|
|
channels = opts[:channels]
|
|
dilations = opts[:dilations]
|
|
name = opts[:name]
|
|
|
|
opts =
|
|
if pool == :lp_pool do
|
|
norm = opts[:norm]
|
|
|
|
[
|
|
name: name,
|
|
meta: opts[:meta],
|
|
kernel_size: kernel_size,
|
|
strides: strides,
|
|
padding: padding,
|
|
channels: channels,
|
|
window_dilations: dilations,
|
|
norm: norm,
|
|
op_name: pool
|
|
]
|
|
else
|
|
[
|
|
name: name,
|
|
meta: opts[:meta],
|
|
kernel_size: kernel_size,
|
|
strides: strides,
|
|
padding: padding,
|
|
channels: channels,
|
|
window_dilations: dilations,
|
|
op_name: pool
|
|
]
|
|
end
|
|
|
|
layer(pool, [x], opts)
|
|
end
|
|
|
|
@doc """
|
|
Adds a blur pooling layer to the network.
|
|
|
|
See `Axon.Layers.blur_pool/2` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:strides` - stride during convolution. Defaults to `1`.
|
|
|
|
* `:channels` - channels location. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
"""
|
|
def blur_pool(%Axon{} = x, opts \\ []) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
channels: :last
|
|
])
|
|
|
|
channels = opts[:channels]
|
|
name = opts[:name]
|
|
|
|
opts = [
|
|
name: name,
|
|
meta: opts[:meta],
|
|
channels: channels,
|
|
op_name: :blur_pool
|
|
]
|
|
|
|
layer(:blur_pool, [x], opts)
|
|
end
|
|
|
|
## Adaptive Pooling
|
|
|
|
@adaptive_pooling_layers [
|
|
{:adaptive_avg_pool, "Adaptive average pool", "an"},
|
|
{:adaptive_max_pool, "Adaptive max pool", "an"},
|
|
{:adaptive_lp_pool, "Adaptive power average pool", "an"}
|
|
]
|
|
|
|
for {pool, name, a_or_an} <- @adaptive_pooling_layers do
|
|
@doc """
|
|
Adds #{a_or_an} #{name} layer to the network.
|
|
|
|
See `Axon.Layers.#{Atom.to_string(pool)}/2` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:output_size` - layer output size.
|
|
|
|
* `:channels` - channel configuration. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :pooling
|
|
def unquote(pool)(%Axon{} = x, opts \\ []) do
|
|
adaptive_pool(x, unquote(pool), opts)
|
|
end
|
|
end
|
|
|
|
defp adaptive_pool(%Axon{} = x, pool, opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, :output_size, channels: :last, norm: 2])
|
|
|
|
channels = opts[:channels]
|
|
name = opts[:name]
|
|
output_size = opts[:output_size]
|
|
|
|
opts =
|
|
if pool == :adaptive_lp_pool do
|
|
norm = opts[:norm]
|
|
|
|
[
|
|
name: name,
|
|
meta: opts[:meta],
|
|
output_size: output_size,
|
|
norm: norm,
|
|
channels: channels,
|
|
op_name: pool
|
|
]
|
|
else
|
|
[
|
|
name: name,
|
|
meta: opts[:meta],
|
|
output_size: output_size,
|
|
channels: channels,
|
|
op_name: pool
|
|
]
|
|
end
|
|
|
|
layer(pool, [x], opts)
|
|
end
|
|
|
|
## Global Pooling
|
|
|
|
@global_pooling_layers [
|
|
{:global_avg_pool, "Global average pool"},
|
|
{:global_max_pool, "Global max pool"},
|
|
{:global_lp_pool, "Global LP pool"}
|
|
]
|
|
|
|
for {pool, name} <- @global_pooling_layers do
|
|
@doc """
|
|
Adds a #{name} layer to the network.
|
|
|
|
See `Axon.Layers.#{Atom.to_string(pool)}/2` for more details.
|
|
|
|
Typically used to connect feature extractors such as those in convolutional
|
|
neural networks to fully-connected models by reducing inputs along spatial
|
|
dimensions to only feature and batch dimensions.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:keep_axes` - option to keep reduced axes. If `true`, keeps reduced axes
|
|
with a dimension size of 1.
|
|
|
|
* `:channels` - channel configuration. One of `:first` or `:last`.
|
|
Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :pooling
|
|
def unquote(pool)(%Axon{} = x, opts \\ []) do
|
|
global_pool(x, unquote(pool), opts)
|
|
end
|
|
end
|
|
|
|
defp global_pool(%Axon{} = x, pool, opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, keep_axes: false, channels: :last, norm: 2])
|
|
|
|
keep_axes = opts[:keep_axes]
|
|
name = opts[:name]
|
|
channels = opts[:channels]
|
|
|
|
opts =
|
|
if pool == :global_lp_pool do
|
|
norm = opts[:norm]
|
|
|
|
[
|
|
name: name,
|
|
meta: opts[:meta],
|
|
channels: channels,
|
|
keep_axes: keep_axes,
|
|
norm: norm,
|
|
op_name: pool
|
|
]
|
|
else
|
|
[name: name, meta: opts[:meta], channels: channels, keep_axes: keep_axes, op_name: pool]
|
|
end
|
|
|
|
layer(pool, [x], opts)
|
|
end
|
|
|
|
## Normalization
|
|
|
|
@normalization_with_stats_layers [
|
|
{:batch_norm, "Batch normalization", "a"},
|
|
{:instance_norm, "Instance normalization", "an"}
|
|
]
|
|
|
|
for {norm, name, a_or_an} <- @normalization_with_stats_layers do
|
|
@doc """
|
|
Adds #{a_or_an} #{name} layer to the network.
|
|
|
|
See `Axon.Layers.#{Atom.to_string(norm)}/6` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:gamma_initializer` - gamma parameter initializer. Defaults
|
|
to `:glorot_uniform`.
|
|
|
|
* `:beta_initializer` - beta parameter initializer. Defaults to
|
|
`:zeros`.
|
|
|
|
* `:channel_index` - input feature index used for calculating
|
|
mean and variance. Defaults to `-1`.
|
|
|
|
* `:epsilon` - numerical stability term. Defaults to `1.0e-5`.
|
|
|
|
"""
|
|
@doc type: :normalization
|
|
def unquote(norm)(%Axon{} = x, opts \\ []) do
|
|
norm_with_stats(x, unquote(norm), opts)
|
|
end
|
|
end
|
|
|
|
defp norm_with_stats(%Axon{} = x, norm, opts) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
gamma_initializer: :glorot_uniform,
|
|
beta_initializer: :zeros,
|
|
channel_index: -1,
|
|
epsilon: 1.0e-5,
|
|
momentum: 0.1
|
|
])
|
|
|
|
channel_index = opts[:channel_index]
|
|
|
|
gamma_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
beta_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
mean_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
var_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
|
|
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
|
|
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
|
|
|
|
mean = param("mean", mean_shape, initializer: :zeros, kind: :state)
|
|
var = param("var", var_shape, initializer: :ones, kind: :state)
|
|
|
|
layer(
|
|
norm,
|
|
[x, gamma, beta, mean, var],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
epsilon: opts[:epsilon],
|
|
channel_index: channel_index,
|
|
momentum: opts[:momentum],
|
|
op_name: norm
|
|
)
|
|
end
|
|
|
|
@normalization_layers [
|
|
{:layer_norm, "Layer normalization", "a"}
|
|
]
|
|
|
|
for {norm, name, a_or_an} <- @normalization_layers do
|
|
@doc """
|
|
Adds #{a_or_an} #{name} layer to the network.
|
|
|
|
See `Axon.Layers.#{Atom.to_string(norm)}/4` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:gamma_initializer` - gamma parameter initializer. Defaults
|
|
to `:glorot_uniform`.
|
|
|
|
* `:beta_initializer` - beta parameter initializer. Defaults to
|
|
`:zeros`.
|
|
|
|
* `:channel_index` - input feature index used for calculating
|
|
mean and variance. Defaults to `-1`.
|
|
|
|
* `:epsilon` - numerical stability term.
|
|
|
|
"""
|
|
@doc type: :normalization
|
|
def unquote(norm)(%Axon{} = x, opts \\ []) do
|
|
norm(x, unquote(norm), opts)
|
|
end
|
|
end
|
|
|
|
defp norm(%Axon{} = x, norm, opts) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
gamma_initializer: :glorot_uniform,
|
|
beta_initializer: :zeros,
|
|
channel_index: -1,
|
|
epsilon: 1.0e-5
|
|
])
|
|
|
|
channel_index = opts[:channel_index]
|
|
|
|
gamma_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
beta_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
|
|
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
|
|
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
|
|
|
|
layer(norm, [x, gamma, beta],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
epsilon: opts[:epsilon],
|
|
channel_index: channel_index,
|
|
op_name: norm
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a group normalization layer to the network.
|
|
|
|
See `Axon.Layers.group_norm/4` for more details.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:gamma_initializer` - gamma parameter initializer. Defaults
|
|
to `:glorot_uniform`.
|
|
|
|
* `:beta_initializer` - beta parameter initializer. Defaults to
|
|
`:zeros`.
|
|
|
|
* `:channel_index` - input feature index used for calculating
|
|
mean and variance. Defaults to `-1`.
|
|
|
|
* `:epsilon` - numerical stability term.
|
|
|
|
"""
|
|
@doc type: :normalization
|
|
def group_norm(%Axon{} = x, num_groups, opts \\ [])
|
|
when is_integer(num_groups) and num_groups >= 1 do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
gamma_initializer: :ones,
|
|
beta_initializer: :zeros,
|
|
channel_index: -1,
|
|
epsilon: 1.0e-5
|
|
])
|
|
|
|
channel_index = opts[:channel_index]
|
|
|
|
gamma_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
beta_shape = &Axon.Shape.norm_param(&1, channel_index)
|
|
|
|
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
|
|
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
|
|
|
|
layer(:group_norm, [x, gamma, beta],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
epsilon: opts[:epsilon],
|
|
channel_index: channel_index,
|
|
num_groups: num_groups,
|
|
op_name: :group_norm
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Applies the given `Nx` expression to the input.
|
|
|
|
Nx layers are meant for quick applications of functions without
|
|
trainable parameters. For example, they are useful for applying
|
|
functions which apply accessors to containers:
|
|
|
|
model = Axon.container({foo, bar})
|
|
Axon.nx(model, &elem(&1, 0))
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
def nx(input, fun, opts \\ [])
|
|
|
|
@doc type: :special
|
|
def nx(%Axon{} = x, fun, opts) when is_function(fun, 1) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, :op_name])
|
|
op_name = opts[:op_name] || :nx
|
|
fun_with_params = fn x, _opts -> fun.(x) end
|
|
layer(fun_with_params, [x], name: opts[:name], meta: opts[:meta], op_name: op_name)
|
|
end
|
|
|
|
@doc """
|
|
Adds a flatten layer to the network.
|
|
|
|
This layer will flatten all but the batch dimensions
|
|
of the input into a single layer. Typically called to flatten
|
|
the output of a convolution for use with a dense layer.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :shape
|
|
def flatten(%Axon{} = x, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(:flatten, [x],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
op_name: :flatten
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a reshape layer to the network.
|
|
|
|
This layer implements a special case of `Nx.reshape` which accounts
|
|
for possible batch dimensions in the input tensor. You may pass the
|
|
magic dimension `:batch` as a placeholder for dynamic batch sizes.
|
|
You can use `:batch` seamlessly with `:auto` dimension sizes.
|
|
|
|
If the input is an Axon constant, the reshape behavior matches that of
|
|
`Nx.reshape/2`.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
"""
|
|
@doc type: :shape
|
|
def reshape(%Axon{} = x, new_shape, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(:reshape, [x],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
shape: new_shape,
|
|
op_name: :reshape
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a transpose layer to the network.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :shape
|
|
def transpose(%Axon{} = x, permutation \\ nil, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(:transpose, [x],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
axes: permutation,
|
|
op_name: :transpose
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a pad layer to the network.
|
|
|
|
This layer will pad the spatial dimensions of the input.
|
|
Padding configuration is a list of tuples for each spatial
|
|
dimension.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:channels` - channel configuration. One of `:first` or
|
|
`:last`. Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :shape
|
|
def pad(%Axon{} = x, config, value \\ 0.0, opts \\ [])
|
|
when is_list(config) and is_number(value) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, channels: :last])
|
|
channels = opts[:channels]
|
|
|
|
layer(:pad, [x],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
padding_config: config,
|
|
value: value,
|
|
channels: channels,
|
|
op_name: :pad
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a resize layer to the network.
|
|
|
|
Resizing can be used for interpolation or upsampling input
|
|
values in a neural network. For example, you can use this
|
|
layer as an upsampling layer within a GAN.
|
|
|
|
Resize shape must be a tuple representing the resized spatial
|
|
dimensions of the input tensor.
|
|
|
|
Compiles to `Axon.Layers.resize/2`.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:method` - resize method. Defaults to `:nearest`.
|
|
|
|
* `:antialias` - whether an anti-aliasing filter should be used
|
|
when downsampling. Defaults to `true`.
|
|
|
|
* `:channels` - channel configuration. One of `:first` or
|
|
`:last`. Defaults to `:last`.
|
|
|
|
"""
|
|
@doc type: :shape
|
|
def resize(%Axon{} = x, resize_shape, opts \\ []) do
|
|
opts =
|
|
Keyword.validate!(opts, [:name, :meta, method: :nearest, antialias: true, channels: :last])
|
|
|
|
channels = opts[:channels]
|
|
|
|
layer(:resize, [x],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
method: opts[:method],
|
|
antialias: opts[:antialias],
|
|
channels: channels,
|
|
size: resize_shape,
|
|
op_name: :resize
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a concatenate layer to the network.
|
|
|
|
This layer will concatenate inputs along the last
|
|
dimension unless specified otherwise.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:axis` - concatenate axis. Defaults to `-1`.
|
|
|
|
"""
|
|
@doc type: :combinator
|
|
def concatenate(%Axon{} = x, %Axon{} = y, opts)
|
|
when is_list(opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, axis: -1])
|
|
axis = opts[:axis]
|
|
|
|
layer(:concatenate, [container({x, y})],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
axis: axis,
|
|
op_name: :concatenate
|
|
)
|
|
end
|
|
|
|
@doc type: :combinator
|
|
def concatenate([%Axon{} | _] = inputs, opts)
|
|
when is_list(inputs) and is_list(opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, axis: -1])
|
|
axis = opts[:axis]
|
|
|
|
layer(:concatenate, [container(List.to_tuple(inputs))],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
axis: axis,
|
|
op_name: :concatenate
|
|
)
|
|
end
|
|
|
|
@doc false
|
|
def concatenate(%Axon{} = x, %Axon{} = y), do: concatenate(x, y, [])
|
|
|
|
@doc false
|
|
def concatenate(inputs) when is_list(inputs), do: concatenate(inputs, [])
|
|
|
|
@element_wise_layers [:add, :subtract, :multiply]
|
|
|
|
for op <- @element_wise_layers do
|
|
@doc """
|
|
Adds a #{op} layer to the network.
|
|
|
|
This layer performs an element-wise #{Atom.to_string(op)} operation
|
|
on input layers. All input layers must be capable of being
|
|
broadcast together.
|
|
|
|
If one shape has a static batch size, all other shapes must have a
|
|
static batch size as well.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :combinator
|
|
def unquote(op)(%Axon{} = x, %Axon{} = y, opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(unquote(op), [container({x, y})],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
op_name: unquote(op)
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Adds a #{op} layer to the network.
|
|
|
|
This layer performs an element-wise #{Atom.to_string(op)} operation
|
|
on all input layers. All input layers must be capable of being
|
|
broadcast together.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
"""
|
|
@doc type: :combinator
|
|
def unquote(op)(inputs, opts) when is_list(inputs) and is_list(opts) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(unquote(op), [container(List.to_tuple(inputs))],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
op_name: unquote(op)
|
|
)
|
|
end
|
|
|
|
@doc false
|
|
def unquote(op)(%Axon{} = x, %Axon{} = y) do
|
|
unquote(op)(x, y, [])
|
|
end
|
|
|
|
@doc false
|
|
def unquote(op)([%Axon{} | _] = inputs), do: unquote(op)(inputs, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a conditional layer which conditionally executes
|
|
`true_graph` or `false_graph` based on the condition `cond_fn`
|
|
at runtime.
|
|
|
|
`cond_fn` is an arity-1 function executed on the output of the
|
|
parent graph. It must return a boolean scalar tensor (e.g. 1 or 0).
|
|
|
|
The shapes of `true_graph` and `false_graph` must be equal.
|
|
"""
|
|
@doc type: :combinator
|
|
def cond(
|
|
%Axon{} = parent,
|
|
cond_fn,
|
|
%Axon{} = true_graph,
|
|
%Axon{} = false_graph,
|
|
opts \\ []
|
|
)
|
|
when is_function(cond_fn, 1) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
layer(:cond, [parent, true_graph, false_graph],
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
cond: cond_fn,
|
|
op_name: :cond
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Splits input graph into a container of `n` input graphs
|
|
along the given axis.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:axis` - concatenate axis. Defaults to `-1`.
|
|
|
|
"""
|
|
@doc type: :combinator
|
|
def split(parent, splits, opts \\ [])
|
|
|
|
def split(%Axon{} = parent, splits, opts) when is_list(splits) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, axis: -1])
|
|
axis = opts[:axis]
|
|
|
|
{_, split_layers} =
|
|
for {split, i} <- Enum.with_index(splits), reduce: {0, []} do
|
|
{num_split, split_layers} ->
|
|
name =
|
|
case opts[:name] do
|
|
names when is_list(names) ->
|
|
Enum.at(names, i)
|
|
|
|
name ->
|
|
name
|
|
end
|
|
|
|
layer =
|
|
layer(
|
|
fn x, _ -> Nx.slice_along_axis(x, num_split, split, axis: axis) end,
|
|
[parent],
|
|
name: name,
|
|
meta: opts[:meta],
|
|
op_name: :split
|
|
)
|
|
|
|
{num_split + split, [layer | split_layers]}
|
|
end
|
|
|
|
split_layers |> Enum.reverse() |> List.to_tuple()
|
|
end
|
|
|
|
def split(%Axon{} = parent, n, opts) when is_integer(n) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, axis: -1])
|
|
axis = opts[:axis]
|
|
|
|
splits =
|
|
for i <- 0..(n - 1) do
|
|
name =
|
|
case opts[:name] do
|
|
names when is_list(names) ->
|
|
Enum.at(names, i)
|
|
|
|
name ->
|
|
name
|
|
end
|
|
|
|
layer(
|
|
&Axon.Layers.split/2,
|
|
[parent],
|
|
name: name,
|
|
meta: opts[:meta],
|
|
index: i,
|
|
splits: n,
|
|
axis: axis,
|
|
op_name: :split
|
|
)
|
|
end
|
|
|
|
List.to_tuple(splits)
|
|
end
|
|
|
|
@doc """
|
|
Computes a sequence mask according to the given EOS token.
|
|
|
|
Masks can be propagated to recurrent layers or custom layers to
|
|
indicate that a given token should be ignored in processing. This
|
|
is useful when you have sequences of variable length.
|
|
|
|
Most commonly, `eos_token` is `0`.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
"""
|
|
@doc type: :recurrent
|
|
def mask(%Axon{} = input, eos_token, opts \\ []) when is_integer(eos_token) do
|
|
opts = Keyword.validate!(opts, [:name, :meta])
|
|
|
|
fun = fn x, opts ->
|
|
Nx.equal(Nx.as_type(x, :s64), opts[:eos_token])
|
|
end
|
|
|
|
layer(fun, [input],
|
|
eos_token: eos_token,
|
|
op_name: :mask,
|
|
meta: opts[:meta],
|
|
name: opts[:name]
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Applies the given forward function bidirectionally and merges
|
|
the results with the given merge function.
|
|
|
|
This is most commonly used with RNNs to capture the dependencies
|
|
of a sequence in both directions.
|
|
|
|
## Options
|
|
|
|
* `axis` - Axis to reverse.
|
|
"""
|
|
def bidirectional(%Axon{} = input, forward_fun, merge_fun, opts \\ [])
|
|
when is_function(forward_fun, 1) and is_function(merge_fun, 2) do
|
|
opts = Keyword.validate!(opts, [:name, axis: 1])
|
|
|
|
fun =
|
|
Axon.block(
|
|
fn x ->
|
|
Axon.container(forward_fun.(x))
|
|
end,
|
|
name: opts[:name]
|
|
)
|
|
|
|
forward_out = fun.(input)
|
|
|
|
backward_out =
|
|
input
|
|
|> Axon.nx(&Nx.reverse(&1, axes: [opts[:axis]]))
|
|
|> fun.()
|
|
|> Axon.nx(fn x ->
|
|
deep_new(x, &Nx.reverse(&1, axes: [opts[:axis]]))
|
|
end)
|
|
|
|
{forward_out, backward_out}
|
|
|> Axon.container()
|
|
|> Axon.nx(fn {forward, backward} ->
|
|
deep_merge(forward, backward, merge_fun)
|
|
end)
|
|
end
|
|
|
|
@doc """
|
|
See `lstm/3`.
|
|
"""
|
|
@doc type: :recurrent
|
|
def lstm(%Axon{} = x, units) when is_integer(units) and units > 0 do
|
|
lstm(x, units, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a long short-term memory (LSTM) layer to the network
|
|
with a random initial hidden state.
|
|
|
|
See `lstm/4` for more details.
|
|
|
|
## Additional options
|
|
|
|
* `:recurrent_initializer` - initializer for hidden state.
|
|
Defaults to `:orthogonal`.
|
|
|
|
"""
|
|
@doc type: :recurrent
|
|
def lstm(%Axon{} = x, units, opts)
|
|
when is_integer(units) and units > 0 and is_list(opts) do
|
|
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
|
|
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
|
|
c = rnn_state(x, units, :lstm, opts[:name], "c", recurrent_initializer, seed)
|
|
h = rnn_state(x, units, :lstm, opts[:name], "h", recurrent_initializer, seed)
|
|
lstm(x, {c, h}, units, opts)
|
|
end
|
|
|
|
def lstm(%Axon{} = x, {%Axon{}, %Axon{}} = hidden_state, units)
|
|
when is_integer(units) and units > 0 do
|
|
lstm(x, hidden_state, units, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a long short-term memory (LSTM) layer to the network
|
|
with the given initial hidden state.
|
|
|
|
LSTMs apply `Axon.Layers.lstm_cell/7` over an entire input
|
|
sequence and return:
|
|
|
|
{output_sequence, {new_cell, new_hidden}}
|
|
|
|
You can use the output state as the hidden state of another
|
|
LSTM layer.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:activation` - recurrent activation. Defaults to `:tanh`.
|
|
|
|
* `:gate` - recurrent gate function. Defaults to `:sigmoid`.
|
|
|
|
* `:unroll` - `:dynamic` (loop preserving) or `:static` (compiled)
|
|
unrolling of RNN.
|
|
|
|
* `:kernel_initializer` - initializer for kernel weights. Defaults
|
|
to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for bias weights. Defaults to
|
|
`:zeros`.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`.
|
|
|
|
"""
|
|
@doc type: :recurrent
|
|
def lstm(
|
|
%Axon{} = x,
|
|
{%Axon{}, %Axon{}} = hidden_state,
|
|
units,
|
|
opts \\ []
|
|
)
|
|
when is_integer(units) and units > 0 and is_list(opts) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
activation: :tanh,
|
|
gate: :sigmoid,
|
|
unroll: :dynamic,
|
|
use_bias: true,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
mask: Axon.constant(0)
|
|
])
|
|
|
|
activation = opts[:activation]
|
|
gate = opts[:gate]
|
|
unroll = opts[:unroll]
|
|
|
|
kernel_initializer = opts[:kernel_initializer]
|
|
|
|
input_kernel_template = fn inp, _, _ ->
|
|
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :lstm)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
hidden_kernel_template = fn inp, _, _ ->
|
|
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :lstm)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
bias_template = fn inp, _, _ ->
|
|
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :lstm)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
initializer = fn prefix, init ->
|
|
fn shape, type, key ->
|
|
split_key = Nx.Random.split(key, parts: 4)
|
|
|
|
init =
|
|
if is_atom(init) do
|
|
apply(Axon.Initializers, init, [])
|
|
else
|
|
init
|
|
end
|
|
|
|
fun =
|
|
case init do
|
|
init when is_function(init, 2) ->
|
|
fn _ -> init.(shape, type) end
|
|
|
|
init when is_function(init, 3) ->
|
|
fn key -> init.(shape, type, key) end
|
|
end
|
|
|
|
%{
|
|
"#{prefix}i" => fun.(split_key[0]),
|
|
"#{prefix}f" => fun.(split_key[1]),
|
|
"#{prefix}g" => fun.(split_key[2]),
|
|
"#{prefix}o" => fun.(split_key[3])
|
|
}
|
|
end
|
|
end
|
|
|
|
# Parameters
|
|
input_kernel =
|
|
parameter("input_kernel", input_kernel_template,
|
|
initializer: initializer.("wi", kernel_initializer)
|
|
)
|
|
|
|
hidden_kernel =
|
|
parameter("hidden_kernel", hidden_kernel_template,
|
|
initializer: initializer.("wh", kernel_initializer)
|
|
)
|
|
|
|
hidden_state_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"lstm_#{op_counts[:lstm]}_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_hidden_state"
|
|
end
|
|
|
|
hidden_state = Axon.container(hidden_state, name: hidden_state_name)
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias_initializer = opts[:bias_initializer]
|
|
|
|
bias = parameter("bias", bias_template, initializer: initializer.("b", bias_initializer))
|
|
|
|
{[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias], :lstm}
|
|
else
|
|
{[x, hidden_state, opts[:mask], input_kernel, hidden_kernel], &Axon.Layers.lstm/6}
|
|
end
|
|
|
|
output =
|
|
layer(
|
|
op,
|
|
inputs,
|
|
name: opts[:name],
|
|
meta: opts[:meta],
|
|
activation: activation,
|
|
gate: gate,
|
|
unroll: unroll,
|
|
op_name: :lstm
|
|
)
|
|
|
|
new_c_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"lstm_#{op_counts[:lstm]}_c_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_c_hidden_state"
|
|
end
|
|
|
|
new_h_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"lstm_#{op_counts[:lstm]}_h_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_h_hidden_state"
|
|
end
|
|
|
|
output_sequence_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"lstm_#{op_counts[:lstm]}_output_sequence"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_output_sequence"
|
|
end
|
|
|
|
output_sequence =
|
|
layer(fn x, _ -> elem(x, 0) end, [output],
|
|
name: output_sequence_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
new_c =
|
|
layer(fn x, _ -> elem(elem(x, 1), 0) end, [output],
|
|
name: new_c_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
new_h =
|
|
layer(fn x, _ -> elem(elem(x, 1), 1) end, [output],
|
|
name: new_h_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
{output_sequence, {new_c, new_h}}
|
|
end
|
|
|
|
@doc """
|
|
See `gru/3`.
|
|
"""
|
|
@doc type: :recurrent
|
|
def gru(%Axon{} = x, units) do
|
|
gru(x, units, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a gated recurrent unit (GRU) layer to the network with
|
|
a random initial hidden state.
|
|
|
|
See `gru/4` for more details.
|
|
|
|
## Additional options
|
|
|
|
* `:recurrent_initializer` - initializer for hidden state.
|
|
Defaults to `:orthogonal`.
|
|
|
|
"""
|
|
@doc type: :recurrent
|
|
def gru(%Axon{} = x, units, opts)
|
|
when is_integer(units) and units > 0
|
|
when is_list(opts) do
|
|
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
|
|
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
|
|
h = rnn_state(x, units, :gru, opts[:name], "h", recurrent_initializer, seed)
|
|
gru(x, {h}, units, opts)
|
|
end
|
|
|
|
def gru(%Axon{} = x, {%Axon{}} = hidden_state, units) when is_integer(units) and units > 0 do
|
|
gru(x, hidden_state, units, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a gated recurrent unit (GRU) layer to the network with
|
|
the given initial hidden state.
|
|
|
|
GRUs apply `Axon.Layers.gru_cell/7` over an entire input
|
|
sequence and return:
|
|
|
|
{{new_hidden}, output_sequence}
|
|
|
|
You can use the output state as the hidden state of another
|
|
GRU layer.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:activation` - recurrent activation. Defaults to `:tanh`.
|
|
|
|
* `:gate` - recurrent gate function. Defaults to `:sigmoid`.
|
|
|
|
* `:unroll` - `:dynamic` (loop preserving) or `:static` (compiled)
|
|
unrolling of RNN.
|
|
|
|
* `:kernel_initializer` - initializer for kernel weights. Defaults
|
|
to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for bias weights. Defaults to
|
|
`:zeros`.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`.
|
|
|
|
"""
|
|
@doc type: :recurrent
|
|
def gru(
|
|
%Axon{} = x,
|
|
{%Axon{}} = hidden_state,
|
|
units,
|
|
opts
|
|
)
|
|
when is_integer(units) and units > 0 and is_list(opts) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
mask: Axon.constant(0),
|
|
activation: :tanh,
|
|
gate: :sigmoid,
|
|
unroll: :dynamic,
|
|
use_bias: true,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros
|
|
])
|
|
|
|
activation = opts[:activation]
|
|
gate = opts[:gate]
|
|
unroll = opts[:unroll]
|
|
|
|
input_kernel_template = fn inp, _, _ ->
|
|
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :gru)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
hidden_kernel_template = fn inp, _, _ ->
|
|
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :gru)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
bias_template = fn inp, _, _ ->
|
|
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :gru)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
initializer = fn prefix, init ->
|
|
fn shape, type, key ->
|
|
split_key = Nx.Random.split(key, parts: 3)
|
|
|
|
init =
|
|
if is_atom(init) do
|
|
apply(Axon.Initializers, init, [])
|
|
else
|
|
init
|
|
end
|
|
|
|
fun =
|
|
case init do
|
|
init when is_function(init, 2) ->
|
|
fn _ -> init.(shape, type) end
|
|
|
|
init when is_function(init, 3) ->
|
|
fn key -> init.(shape, type, key) end
|
|
end
|
|
|
|
%{
|
|
"#{prefix}r" => fun.(split_key[0]),
|
|
"#{prefix}z" => fun.(split_key[1]),
|
|
"#{prefix}n" => fun.(split_key[2])
|
|
}
|
|
end
|
|
end
|
|
|
|
input_kernel =
|
|
parameter("input_kernel", input_kernel_template,
|
|
initializer: initializer.("wi", opts[:kernel_initializer])
|
|
)
|
|
|
|
hidden_kernel =
|
|
parameter("hidden_kernel", hidden_kernel_template,
|
|
initializer: initializer.("wh", opts[:kernel_initializer])
|
|
)
|
|
|
|
hidden_state_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"gru_#{op_counts[:gru]}_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_hidden_state"
|
|
end
|
|
|
|
hidden_state = Axon.container(hidden_state, name: hidden_state_name)
|
|
|
|
inputs =
|
|
if opts[:use_bias] do
|
|
bias_initializer = fn shape, type, key ->
|
|
split_key = Nx.Random.split(key, parts: 4)
|
|
|
|
init =
|
|
if is_atom(opts[:bias_initializer]) do
|
|
apply(Axon.Initializers, opts[:bias_initializer], [])
|
|
else
|
|
opts[:bias_initializer]
|
|
end
|
|
|
|
fun =
|
|
case init do
|
|
init when is_function(init, 2) ->
|
|
fn _ -> init.(shape, type) end
|
|
|
|
init when is_function(init, 3) ->
|
|
fn key -> init.(shape, type, key) end
|
|
end
|
|
|
|
%{
|
|
"br" => fun.(split_key[0]),
|
|
"bz" => fun.(split_key[1]),
|
|
"bin" => fun.(split_key[2]),
|
|
"bhn" => fun.(split_key[3])
|
|
}
|
|
end
|
|
|
|
bias = parameter("bias", bias_template, initializer: bias_initializer)
|
|
|
|
[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias]
|
|
else
|
|
[x, hidden_state, opts[:mask], input_kernel, hidden_kernel]
|
|
end
|
|
|
|
output =
|
|
layer(
|
|
:gru,
|
|
inputs,
|
|
meta: opts[:meta],
|
|
name: opts[:name],
|
|
activation: activation,
|
|
gate: gate,
|
|
unroll: unroll,
|
|
op_name: :gru
|
|
)
|
|
|
|
new_h_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"gru_#{op_counts[:gru]}_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_hidden_state"
|
|
end
|
|
|
|
output_sequence_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"gru_#{op_counts[:gru]}_output_sequence"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_output_sequence"
|
|
end
|
|
|
|
output_sequence =
|
|
layer(fn x, _ -> elem(x, 0) end, [output],
|
|
name: output_sequence_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
new_h =
|
|
layer(fn x, _ -> elem(elem(x, 1), 0) end, [output],
|
|
name: new_h_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
{output_sequence, {new_h}}
|
|
end
|
|
|
|
@doc """
|
|
See `conv_lstm/3`.
|
|
"""
|
|
@doc type: :recurrent
|
|
def conv_lstm(%Axon{} = x, units) when is_integer(units) and units > 0 do
|
|
conv_lstm(x, units, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a convolutional long short-term memory (LSTM) layer to the network
|
|
with a random initial hidden state.
|
|
|
|
See `conv_lstm/4` for more details.
|
|
|
|
## Additional options
|
|
|
|
* `:recurrent_initializer` - initializer for hidden state. Defaults
|
|
to `:orthogonal`.
|
|
|
|
"""
|
|
@doc type: :recurrent
|
|
def conv_lstm(%Axon{} = x, units, opts)
|
|
when is_integer(units) and units > 0 and is_list(opts) do
|
|
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
|
|
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
|
|
c = rnn_state(x, units, :conv_lstm, opts[:name], "c", recurrent_initializer, seed)
|
|
h = rnn_state(x, units, :conv_lstm, opts[:name], "h", recurrent_initializer, seed)
|
|
conv_lstm(x, {c, h}, units, opts)
|
|
end
|
|
|
|
def conv_lstm(%Axon{} = x, {%Axon{}, %Axon{}} = hidden_state, units)
|
|
when is_integer(units) and units > 0 do
|
|
conv_lstm(x, hidden_state, units, [])
|
|
end
|
|
|
|
@doc """
|
|
Adds a convolutional long short-term memory (LSTM) layer to the network
|
|
with the given initial hidden state..
|
|
|
|
ConvLSTMs apply `Axon.Layers.conv_lstm_cell/5` over an entire input
|
|
sequence and return:
|
|
|
|
{{new_cell, new_hidden}, output_sequence}
|
|
|
|
You can use the output state as the hidden state of another
|
|
ConvLSTM layer.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:padding` - convolutional padding. Defaults to `:same`.
|
|
|
|
* `:kernel_size` - convolutional kernel size. Defaults to `1`.
|
|
|
|
* `:strides` - convolutional strides. Defaults to `1`.
|
|
|
|
* `:unroll` - `:dynamic` (loop preserving) or `:static` (compiled)
|
|
unrolling of RNN.
|
|
|
|
* `:kernel_initializer` - initializer for kernel weights. Defaults
|
|
to `:glorot_uniform`.
|
|
|
|
* `:bias_initializer` - initializer for bias weights. Defaults to
|
|
`:zeros`.
|
|
|
|
* `:use_bias` - whether the layer should add bias to the output.
|
|
Defaults to `true`.
|
|
|
|
"""
|
|
@doc type: :recurrent
|
|
def conv_lstm(
|
|
%Axon{} = x,
|
|
{%Axon{}, %Axon{}} = hidden_state,
|
|
units,
|
|
opts
|
|
)
|
|
when is_integer(units) and units > 0 and is_list(opts) do
|
|
opts =
|
|
Keyword.validate!(opts, [
|
|
:name,
|
|
:meta,
|
|
mask: Axon.constant(0),
|
|
padding: :same,
|
|
kernel_size: 1,
|
|
strides: 1,
|
|
unroll: :dynamic,
|
|
kernel_initializer: :glorot_uniform,
|
|
bias_initializer: :zeros,
|
|
use_bias: true
|
|
])
|
|
|
|
padding = opts[:padding]
|
|
kernel_size = opts[:kernel_size]
|
|
strides = opts[:strides]
|
|
unroll = opts[:unroll]
|
|
kernel_initializer = opts[:kernel_initializer]
|
|
|
|
hidden_kernel_template = fn _, {inp, _}, _ ->
|
|
shape = Tuple.delete_at(Nx.shape(inp), 1)
|
|
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
input_kernel_template = fn inp, _, _ ->
|
|
shape = Tuple.delete_at(Nx.shape(inp), 1)
|
|
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
bias_template = fn inp, _, _ ->
|
|
shape = Tuple.delete_at(Nx.shape(inp), 1)
|
|
shape = Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
|
|
Nx.template(shape, :f32)
|
|
end
|
|
|
|
wi = parameter("input_kernel", input_kernel_template, initializer: kernel_initializer)
|
|
wh = parameter("hidden_kernel", hidden_kernel_template, initializer: kernel_initializer)
|
|
|
|
hidden_state_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"conv_lstm_#{op_counts[:conv_lstm]}_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_hidden_state"
|
|
end
|
|
|
|
hidden_state = Axon.container(hidden_state, name: hidden_state_name)
|
|
|
|
{inputs, op} =
|
|
if opts[:use_bias] do
|
|
bias_initializer = opts[:bias_initializer]
|
|
b = parameter("bias", bias_template, initializer: bias_initializer)
|
|
{[x, hidden_state, opts[:mask], wi, wh, b], :conv_lstm}
|
|
else
|
|
{[x, hidden_state, opts[:mask], wi, wh], :conv_lstm}
|
|
end
|
|
|
|
output =
|
|
layer(
|
|
op,
|
|
inputs,
|
|
meta: opts[:meta],
|
|
name: opts[:name],
|
|
conv_opts: [
|
|
strides: strides,
|
|
padding: padding
|
|
],
|
|
unroll: unroll,
|
|
op_name: :conv_lstm
|
|
)
|
|
|
|
new_c_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"conv_lstm_#{op_counts[:lstm]}_c_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_c_hidden_state"
|
|
end
|
|
|
|
new_h_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"conv_lstm_#{op_counts[:lstm]}_h_hidden_state"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_h_hidden_state"
|
|
end
|
|
|
|
output_sequence_name =
|
|
case opts[:name] do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
"conv_lstm_#{op_counts[:lstm]}_output_sequence"
|
|
end
|
|
|
|
name when is_binary(name) ->
|
|
"#{name}_output_sequence"
|
|
end
|
|
|
|
output_sequence =
|
|
layer(fn x, _ -> elem(x, 0) end, [output],
|
|
name: output_sequence_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
new_c =
|
|
layer(fn x, _ -> elem(elem(x, 1), 0) end, [output],
|
|
name: new_c_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
new_h =
|
|
layer(fn x, _ -> elem(elem(x, 1), 1) end, [output],
|
|
name: new_h_name,
|
|
op_name: :elem
|
|
)
|
|
|
|
{output_sequence, {new_c, new_h}}
|
|
end
|
|
|
|
defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer, seed) do
|
|
initializer = initializer || :glorot_uniform
|
|
|
|
key_state =
|
|
param("key", fn _ -> {2} end,
|
|
type: {:u, 32},
|
|
initializer: fn _, _ -> Nx.Random.key(seed) end,
|
|
kind: :state
|
|
)
|
|
|
|
name =
|
|
case parent_name do
|
|
nil ->
|
|
fn _, op_counts ->
|
|
count = op_counts[rnn_type] || 0
|
|
"#{Atom.to_string(rnn_type)}_#{count}_#{state_name}_hidden_state"
|
|
end
|
|
|
|
parent_name when is_binary(parent_name) ->
|
|
"#{parent_name}_#{state_name}_hidden_state"
|
|
end
|
|
|
|
initializer =
|
|
if is_function(initializer) do
|
|
initializer
|
|
else
|
|
apply(Axon.Initializers, initializer, [])
|
|
end
|
|
|
|
{:arity, arity} = Function.info(initializer, :arity)
|
|
|
|
{fun, inputs} =
|
|
cond do
|
|
arity == 2 ->
|
|
fun =
|
|
fn inputs, _opts ->
|
|
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
|
|
initializer.(shape, {:f, 32})
|
|
end
|
|
|
|
{fun, [x]}
|
|
|
|
arity == 3 ->
|
|
fun =
|
|
fn inputs, key, opts ->
|
|
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
|
|
keys = Nx.Random.split(key)
|
|
out = initializer.(shape, {:f, 32}, keys[1])
|
|
|
|
if opts[:mode] == :train do
|
|
%Axon.StatefulOutput{output: out, state: %{"key" => keys[0]}}
|
|
else
|
|
out
|
|
end
|
|
end
|
|
|
|
{fun, [x, key_state]}
|
|
|
|
true ->
|
|
raise ArgumentError, "bad arity for initializer"
|
|
end
|
|
|
|
layer(fun, inputs, name: name, op_name: :recurrent_state)
|
|
end
|
|
|
|
@doc """
|
|
Adds an embedding layer to the network.
|
|
|
|
An embedding layer initializes a kernel of shape `{vocab_size, embedding_size}`
|
|
which acts as a lookup table for sequences of discrete tokens (e.g. sentences).
|
|
Embeddings are typically used to obtain a dense representation of a sparse input
|
|
space.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:kernel_initializer` - initializer for `kernel` weights. Defaults
|
|
to `:uniform`.
|
|
|
|
"""
|
|
@doc type: :linear
|
|
def embedding(%Axon{} = x, vocab_size, embedding_size, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, kernel_initializer: :uniform])
|
|
|
|
kernel_shape = &Axon.Shape.embedding_kernel(&1, vocab_size, embedding_size)
|
|
|
|
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
|
|
|
|
layer(:embedding, [x, kernel], name: opts[:name], meta: opts[:meta], op_name: :embedding)
|
|
end
|
|
|
|
@doc """
|
|
Adds a bias layer to the network.
|
|
|
|
A bias layer simply adds a trainable bias to an input.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:bias_initializer` - initializer for `bias` weights. Defaults
|
|
to `:zeros`.
|
|
|
|
"""
|
|
@doc type: :linear
|
|
def bias(%Axon{} = x, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, :meta, bias_initializer: :zeros])
|
|
|
|
bias_shape = fn shape -> {elem(shape, tuple_size(shape) - 1)} end
|
|
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
|
|
|
|
layer(:bias, [x, bias], name: opts[:name], meta: opts[:meta], op_name: :bias)
|
|
end
|
|
|
|
@doc """
|
|
Adds a stack columns layer to the network.
|
|
|
|
A stack columns layer is designed to be used with `Nx.LazyContainer`
|
|
data structures like Explorer DataFrames. Given an input which is a
|
|
DataFrame, `stack_columns/2` will stack the columns in each row to
|
|
create a single vector.
|
|
|
|
You may optionally specify `:ignore` to ignore certain columns in
|
|
the container.
|
|
|
|
## Options
|
|
|
|
* `:name` - layer name.
|
|
|
|
* `:ignore` - keys to ignore when stacking.
|
|
"""
|
|
@doc type: :special
|
|
def stack_columns(%Axon{} = x, opts \\ []) do
|
|
opts = Keyword.validate!(opts, [:name, ignore: []])
|
|
|
|
layer(:stack_columns, [x],
|
|
meta: opts[:meta],
|
|
name: opts[:name],
|
|
ignore: opts[:ignore],
|
|
op_name: :stack_columns
|
|
)
|
|
end
|
|
|
|
@doc """
|
|
Freezes parameters returned from the given function or predicate.
|
|
|
|
`fun` can be a predicate `:all`, `up: n`, or `down: n`. `:all`
|
|
freezes all parameters in the model, `up: n` freezes the first `n`
|
|
layers up (starting from output), and `down: n` freezes the first `n`
|
|
layers down (starting from input).
|
|
|
|
`fun` may also be a predicate function which takes a parameter and
|
|
returns `true` if a parameter should be frozen or `false` otherwise.
|
|
|
|
Freezing parameters is useful when performing transfer learning
|
|
to leverage features learned from another problem in a new problem.
|
|
For example, it's common to combine the convolutional base from
|
|
larger models trained on ImageNet with fresh fully-connected classifiers.
|
|
The combined model is then trained on fresh data, with the convolutional
|
|
base frozen so as not to lose information. You can see this example
|
|
in code here:
|
|
|
|
cnn_base = get_pretrained_cnn_base()
|
|
model =
|
|
cnn_base
|
|
|> Axon.freeze()
|
|
|> Axon.flatten()
|
|
|> Axon.dense(1024, activation: :relu)
|
|
|> Axon.dropout()
|
|
|> Axon.dense(1000, activation: :softmax)
|
|
|
|
model
|
|
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.005))
|
|
|> Axon.Loop.run(data, epochs: 10)
|
|
|
|
When compiled, frozen parameters are wrapped in `Nx.Defn.Kernel.stop_grad/1`,
|
|
which zeros out the gradient with respect to the frozen parameter. Gradients
|
|
of frozen parameters will return `0.0`, meaning they won't be changed during
|
|
the update process.
|
|
"""
|
|
@doc type: :model
|
|
@deprecated "Use Axon.ModelState.freeze/2 instead"
|
|
def freeze(model, fun_or_predicate \\ :all) do
|
|
freeze(model, fun_or_predicate, true)
|
|
end
|
|
|
|
defp freeze(%Axon{output: id, nodes: nodes} = axon, fun_or_predicate, flag) do
|
|
{nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())
|
|
|
|
nodes =
|
|
case fun_or_predicate do
|
|
:all ->
|
|
freeze_nodes(nodes, flag)
|
|
|
|
[{:up, n}] ->
|
|
{pre, post} = Enum.split(nodes, n)
|
|
freeze_nodes(pre, flag) ++ post
|
|
|
|
[{:down, n}] ->
|
|
{pre, post} = Enum.split(nodes, -n)
|
|
pre ++ freeze_nodes(post, flag)
|
|
|
|
fun ->
|
|
Enum.map(nodes, fn %Axon.Node{parameters: params} = axon_node ->
|
|
%{
|
|
axon_node
|
|
| parameters:
|
|
Enum.map(params, fn p ->
|
|
if fun.(p), do: %{p | frozen: flag}, else: p
|
|
end)
|
|
}
|
|
end)
|
|
end
|
|
|
|
%{axon | nodes: Map.new(nodes, fn %{id: id} = node -> {id, node} end)}
|
|
end
|
|
|
|
defp freeze_nodes(nodes, flag) do
|
|
Enum.map(nodes, fn %Axon.Node{parameters: params} = axon_node ->
|
|
%{axon_node | parameters: Enum.map(params, fn p -> %{p | frozen: flag} end)}
|
|
end)
|
|
end
|
|
|
|
@doc """
|
|
Unfreezes parameters returned from the given function or predicate.
|
|
|
|
`fun` can be a predicate `:all`, `up: n`, or `down: n`. `:all`
|
|
freezes all parameters in the model, `up: n` unfreezes the first `n`
|
|
layers up (starting from output), and `down: n` freezes the first `n`
|
|
layers down (starting from input).
|
|
|
|
`fun` may also be a predicate function which takes a parameter and
|
|
returns `true` if a parameter should be unfrozen or `false` otherwise.
|
|
|
|
Unfreezing parameters is useful when fine tuning a model which you
|
|
have previously frozen and performed transfer learning on. You may
|
|
want to unfreeze some of the later frozen layers in a model and
|
|
fine tune them specifically for your application:
|
|
|
|
cnn_base = get_pretrained_cnn_base()
|
|
model =
|
|
frozen_model
|
|
|> Axon.unfreeze(up: 25)
|
|
|
|
model
|
|
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.0005))
|
|
|> Axon.Loop.run(data, epochs: 10)
|
|
|
|
When compiled, frozen parameters are wrapped in `Nx.Defn.Kernel.stop_grad/1`,
|
|
which zeros out the gradient with respect to the frozen parameter. Gradients
|
|
of frozen parameters will return `0.0`, meaning they won't be changed during
|
|
the update process.
|
|
"""
|
|
@doc type: :model
|
|
@deprecated "Use Axon.ModelState.freeze/2 instead"
|
|
def unfreeze(model, fun_or_predicate \\ :all) do
|
|
freeze(model, fun_or_predicate, false)
|
|
end
|
|
|
|
@doc """
|
|
Attaches a hook to the given Axon model.
|
|
|
|
Hooks compile down to `Nx.Defn.Kernel.hook/3` and provide the same
|
|
functionality for adding side-effecting operations to a compiled
|
|
model. For example, you can use hooks to inspect intermediate activations,
|
|
send data to an external service, and more.
|
|
|
|
Hooks can be configured to be invoked on the following events:
|
|
|
|
* `:initialize` - on model initialization.
|
|
* `:pre_forward` - before layer forward pass is invoked.
|
|
* `:forward` - after layer forward pass is invoked.
|
|
* `:backward` - after layer backward pass is invoked.
|
|
|
|
To invoke a hook on every single event, you may pass `:all` to `on:`.
|
|
|
|
Axon.input("input", shape: {nil, 1}) |> Axon.attach_hook(&IO.inspect/1, on: :all)
|
|
|
|
The default event is `:forward`, assuming you want a hook invoked
|
|
on the layers forward pass.
|
|
|
|
You may configure hooks to run in one of only training or inference
|
|
mode using the `:mode` option. The default mode is `:both` to be invoked
|
|
during both train and inference mode.
|
|
|
|
Axon.input("input", shape: {nil, 1}) |> Axon.attach_hook(&IO.inspect/1, on: :forward, mode: :train)
|
|
|
|
You can also attach multiple hooks to a single layer. Hooks are invoked in
|
|
the order in which they are declared. If order is important, you should attach
|
|
hooks in the order you want them to be executed:
|
|
|
|
Axon.input("input", shape: {nil, 1})
|
|
# I will be executed first
|
|
|> Axon.attach_hook(&IO.inspect/1)
|
|
# I will be executed second
|
|
|> Axon.attach_hook(fn _ -> IO.write("HERE") end)
|
|
|
|
Hooks are executed at their point of attachment. You must insert hooks at each point
|
|
you want a hook to execute during model execution.
|
|
|
|
Axon.input("input", shape: {nil, 1})
|
|
|> Axon.attach_hook(&IO.inspect/1)
|
|
|> Axon.relu()
|
|
|> Axon.attach_hook(&IO.inspect/1)
|
|
|
|
"""
|
|
@doc type: :debug
|
|
def attach_hook(x, fun, opts \\ [])
|
|
|
|
def attach_hook(%Axon{output: id, nodes: nodes} = axon, fun, opts) do
|
|
updated_nodes =
|
|
Map.update!(nodes, id, fn axon_node ->
|
|
attach_hook(axon_node, fun, opts)
|
|
end)
|
|
|
|
%{axon | nodes: updated_nodes}
|
|
end
|
|
|
|
def attach_hook(%Axon.Node{hooks: hooks} = axon_node, fun, opts) do
|
|
opts = Keyword.validate!(opts, on: :forward, mode: :both)
|
|
on_event = opts[:on]
|
|
mode = opts[:mode]
|
|
|
|
%{axon_node | hooks: [{on_event, mode, fun} | hooks]}
|
|
end
|
|
|
|
## Graph Manipulation and Utilities
|
|
|
|
# TODO: Revisit later with new decoupled structs
|
|
# e.g. there should be a node API and graph API
|
|
|
|
@doc """
|
|
Returns a node's immediate parameters.
|
|
|
|
Note this does not take into account parameters of
|
|
parent layers - only the parameters which belong to
|
|
the immediate layer.
|
|
"""
|
|
@doc type: :graph
|
|
def get_parameters(%Axon{output: id, nodes: nodes}) do
|
|
Access.get(nodes, [id, :parameters])
|
|
end
|
|
|
|
@doc """
|
|
Sets a node's immediate parameters to the given
|
|
parameters.
|
|
|
|
Note this does not take into account parameters of
|
|
parent layers - only the parameters which belong to
|
|
the immediate layer.
|
|
|
|
The new parameters must be compatible with the layer's
|
|
old parameters.
|
|
"""
|
|
@doc type: :graph
|
|
def set_parameters(%Axon{output: id, nodes: nodes} = axon, new_params) do
|
|
# TODO: Check compatibility
|
|
updated_nodes =
|
|
Map.update!(nodes, id, fn axon_node ->
|
|
%{axon_node | parameters: new_params}
|
|
end)
|
|
|
|
%{axon | nodes: updated_nodes}
|
|
end
|
|
|
|
@doc """
|
|
Returns a node's immediate input options.
|
|
|
|
Note that this does not take into account options of
|
|
parent layers, only the option which belong to the
|
|
immediate layer.
|
|
"""
|
|
@doc type: :graph
|
|
def get_options(%Axon{output: id, nodes: nodes}) do
|
|
Access.get(nodes, [id, :opts])
|
|
end
|
|
|
|
@doc """
|
|
Sets a node's immediate options to the given input
|
|
options.
|
|
|
|
Note that this does not take into account options of
|
|
parent layers, only the option which belong to the
|
|
immediate layer.
|
|
|
|
New options must be compatible with the given layer
|
|
op. Adding unsupported options to an Axon layer will
|
|
result in an error at graph execution time.
|
|
"""
|
|
@doc type: :graph
|
|
def set_options(%Axon{output: id, nodes: nodes} = axon, new_opts) do
|
|
updated_nodes =
|
|
Map.update!(nodes, id, fn axon_node ->
|
|
%{axon_node | opts: new_opts}
|
|
end)
|
|
|
|
%{axon | nodes: updated_nodes}
|
|
end
|
|
|
|
@doc """
|
|
Returns information about a model's inputs.
|
|
"""
|
|
@doc type: :graph
|
|
def get_inputs(%Axon{} = axon) do
|
|
reduce_nodes(axon, %{}, fn
|
|
%Axon.Node{op: :input, name: name, opts: opts}, inputs ->
|
|
name = name.(:input, %{})
|
|
Map.put(inputs, name, opts[:shape])
|
|
|
|
_, inputs ->
|
|
inputs
|
|
end)
|
|
end
|
|
|
|
@doc """
|
|
Returns a model's output template from the given input
|
|
template.
|
|
|
|
The output template gives you access to the output shape
|
|
and type of the given input graph.
|
|
"""
|
|
@doc type: :graph
|
|
def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do
|
|
{init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false])
|
|
|
|
inputs =
|
|
case inputs do
|
|
%Nx.Tensor{} = input -> Nx.to_template(input)
|
|
inputs when is_map(inputs) -> Map.new(inputs, fn {k, v} -> {k, Nx.to_template(v)} end)
|
|
end
|
|
|
|
fun =
|
|
Nx.Defn.jit(
|
|
fn inputs ->
|
|
forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs)
|
|
end,
|
|
compiler: Axon.Defn
|
|
)
|
|
|
|
deep_new(apply(fun, [inputs]), &Nx.to_template/1)
|
|
end
|
|
|
|
@doc """
|
|
Returns a map of model op counts for each unique operation
|
|
in a model by their given `:op_name`.
|
|
|
|
## Examples
|
|
|
|
iex> model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
|
|
iex> Axon.get_op_counts(model)
|
|
%{input: 1, dense: 1}
|
|
|
|
iex> model = Axon.input("input", shape: {nil, 1}) |> Axon.tanh() |> Axon.tanh()
|
|
iex> Axon.get_op_counts(model)
|
|
%{input: 1, tanh: 2}
|
|
|
|
"""
|
|
@doc type: :graph
|
|
def get_op_counts(%Axon{} = axon) do
|
|
reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts ->
|
|
Map.update(op_counts, op, 1, fn x -> x + 1 end)
|
|
end)
|
|
end
|
|
|
|
@doc """
|
|
Traverses graph nodes in order, applying `fun` to each
|
|
node exactly once to return a transformed node in its
|
|
place(s) in the graph.
|
|
|
|
This function maintains an internal cache which ensures
|
|
each node is only visited and transformed exactly once.
|
|
|
|
`fun` must accept an Axon node and return an Axon node.
|
|
|
|
Please note that modifying node lineage (e.g. altering
|
|
a node's parent) will result in disconnected graphs.
|
|
|
|
## Examples
|
|
|
|
One common use of this function is to implement common
|
|
instrumentation between layers without needing to build
|
|
a new explicitly instrumented version of a model. For example,
|
|
you can use this function to visualize intermediate activations
|
|
of all convolutional layers in a model:
|
|
|
|
instrumented_model = Axon.map_nodes(model, fn
|
|
%Axon.Node{op: :conv} = axon_node ->
|
|
Axon.attach_hook(axon_node, &visualize_activations/1)
|
|
|
|
axon_node ->
|
|
axon_node
|
|
end)
|
|
|
|
Another use case is to replace entire classes of layers
|
|
with another. For example, you may want to replace all
|
|
relu layers with tanh layers:
|
|
|
|
new_model = Axon.map_nodes(model, fn
|
|
%Axon.Node{op: :relu} = axon_node ->
|
|
%{axon_node | op: :tanh}
|
|
|
|
graph ->
|
|
graph
|
|
end)
|
|
|
|
For more complex graph rewriting and manipulation cases, see
|
|
`Axon.rewrite_nodes/2`.
|
|
"""
|
|
@doc type: :graph
|
|
def map_nodes(%Axon{output: id, nodes: nodes} = axon, fun) when is_function(fun, 1) do
|
|
{inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())
|
|
updated_nodes = Map.new(inorder_nodes, fn %{id: id} = axon_node -> {id, fun.(axon_node)} end)
|
|
%{axon | nodes: updated_nodes}
|
|
end
|
|
|
|
@doc """
|
|
Traverses graph nodes in order, applying `fun` to each
|
|
node exactly once to return a transformed node in its
|
|
place(s) in the graph.
|
|
|
|
This function maintains an internal cache which ensures
|
|
each node is only visited and transformed exactly once.
|
|
|
|
`fun` must accept an Axon node and accumulator and return
|
|
an updated accumulator.
|
|
|
|
## Examples
|
|
|
|
Internally this function is used in several places to accumulate
|
|
graph metadata. For example, you can use it to count the number
|
|
of a certain type of operation in the graph:
|
|
|
|
Axon.reduce_nodes(model, 0, fn
|
|
%Axon.Nodes{op: :relu}, acc -> acc + 1
|
|
_, acc -> acc
|
|
end)
|
|
|
|
"""
|
|
@doc type: :graph
|
|
def reduce_nodes(%Axon{output: id, nodes: nodes}, acc, fun) when is_function(fun, 2) do
|
|
{inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())
|
|
|
|
Enum.reduce(inorder_nodes, acc, fun)
|
|
end
|
|
|
|
@doc """
|
|
Rewrite and manipulate nodes in the Axon execution graph.
|
|
|
|
Axon models are represented as a graph of nodes. Working on these nodes
|
|
directly can be difficult and lead to disconnected and invalid graphs.
|
|
In some cases, you simply want to rewrite patterns. This function takes
|
|
an Axon model and traverses the nodes, applying the rewrite `fun` on each
|
|
node to rewrite some or all of the nodes in the Axon model.
|
|
|
|
The rewrite function is an arity-1 function which takes the current Axon node
|
|
as input and returns a function that replaces or rewrites the given node.
|
|
For example, you can define a simple rewriter which replaces the `:relu`
|
|
layers with `:tanh` layers:
|
|
|
|
tanh_rewriter = fn [%Axon{} = x], _output ->
|
|
Axon.relu(x)
|
|
end
|
|
|
|
Axon.rewrite_nodes(model, fn
|
|
%Axon.Node{op: :relu} -> tanh_rewriter
|
|
_ -> :skip
|
|
end)
|
|
|
|
Notice that the rewriter receives all of the original graph inputs *as well as*
|
|
the original graph outputs. This makes certain transformations which may rely
|
|
on both the input and output, such as LoRA, much easier to perform.
|
|
"""
|
|
@doc type: :graph
|
|
def rewrite_nodes(%Axon{output: id, nodes: nodes}, fun) when is_function(fun, 1) do
|
|
{inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())
|
|
|
|
updated_nodes =
|
|
Enum.reduce(inorder_nodes, nodes, fn
|
|
%{id: original_id, parent: parents} = current_node, nodes ->
|
|
rewriter = fun.(current_node)
|
|
|
|
case rewriter do
|
|
:skip ->
|
|
nodes
|
|
|
|
rewriter when is_function(rewriter, 2) ->
|
|
input_axons = Enum.map(parents, &%Axon{output: &1, nodes: nodes})
|
|
%Axon{output: swapped_id} = placeholder_output = Axon.input("placeholder_output")
|
|
|
|
%Axon{output: new_node_id, nodes: updated_nodes} =
|
|
rewriter.(input_axons, placeholder_output)
|
|
|
|
# now we have to swap the IDs for the rewritten model so that
|
|
# anything that references this node takes the new, rewritten form
|
|
# as an input properly
|
|
original_node = %{updated_nodes[original_id] | id: swapped_id}
|
|
updated_node = %{updated_nodes[new_node_id] | id: original_id}
|
|
|
|
updated_nodes
|
|
|> Map.replace(swapped_id, original_node)
|
|
|> Map.replace(original_id, updated_node)
|
|
end
|
|
end)
|
|
|
|
# if we removed any nodes (like by just using the input instead)
|
|
# then technically we will have extra nodes in the graph, so we
|
|
# can prune them by traversing once again
|
|
{pruned_nodes, _} = traverse_nodes(id, updated_nodes, [], MapSet.new())
|
|
pruned_nodes = Map.new(pruned_nodes, fn %{id: id} = axon_node -> {id, axon_node} end)
|
|
|
|
%Axon{output: id, nodes: pruned_nodes}
|
|
end
|
|
|
|
defp traverse_nodes(id, nodes, acc, visited) do
|
|
if MapSet.member?(visited, id) do
|
|
{acc, visited}
|
|
else
|
|
%{parent: parents} = parent = nodes[id]
|
|
|
|
{acc, visited} =
|
|
Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
|
|
traverse_nodes(pid, nodes, acc, visited)
|
|
end)
|
|
|
|
{[parent | acc], MapSet.put(visited, id)}
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Pops the top node off of the graph.
|
|
|
|
This returns the popped node and the updated graph:
|
|
|
|
{_node, model} = Axon.pop_node(model)
|
|
"""
|
|
@doc type: :graph
|
|
def pop_node(%Axon{nodes: nodes, output: id}) do
|
|
{popped, nodes} = Map.pop!(nodes, id)
|
|
|
|
case popped do
|
|
%{op_name: :container, parent: parents, op: fun} = popped ->
|
|
{popped, apply(fun, Enum.map(parents, &%Axon{nodes: nodes, output: &1}) ++ [[]])}
|
|
|
|
%{parent: [_ | _] = parents} = popped ->
|
|
{popped, Enum.map(parents, &%Axon{nodes: nodes, output: &1})}
|
|
|
|
%{parent: [parent_id]} = popped ->
|
|
{popped, %Axon{nodes: nodes, output: parent_id}}
|
|
end
|
|
end
|
|
|
|
@doc """
|
|
Builds the given model to `{init_fn, predict_fn}`.
|
|
|
|
The given functions can be either given as arguments to `Nx.Defn`
|
|
functions or be invoked directly, to perform just-in-time compilation
|
|
and execution. If you want to compile the model (instead of just-in-time)
|
|
based on a predefined initialization shape, see `compile/4`.
|
|
|
|
## `init_fn`
|
|
|
|
The `init_fn` receives two arguments, the input template and
|
|
an optional map with initial parameters for layers or namespaces:
|
|
|
|
{init_fn, predict_fn} = Axon.build(model)
|
|
init_fn.(Nx.template({1, 1}, {:f, 32}), %{"dense_0" => dense_params})
|
|
|
|
## `predict_fn`
|
|
|
|
The `predict_fn` receives two arguments, the trained parameters
|
|
and the actual inputs:
|
|
|
|
{_init_fn, predict_fn} = Axon.build(model, opts)
|
|
predict_fn.(params, input)
|
|
|
|
## Options
|
|
|
|
* `:compiler` - the underlying `Nx.Defn` compiler to perform
|
|
JIT compilation when the functions are invoked. If none is
|
|
passed, it uses the default compiler configured in `Nx.Defn`;
|
|
|
|
* `:debug` - if `true`, will log graph traversal and generation
|
|
metrics. Also forwarded to JIT if debug mode is available
|
|
for your chosen compiler or backend. Defaults to `false`
|
|
|
|
* `:print_values` - if `true`, will print intermediate layer
|
|
values to the screen for inspection. This is useful if you need
|
|
to debug intermediate values of a model
|
|
|
|
* `:mode` - one of `:inference` or `:train`. Forwarded to layers
|
|
to control differences in compilation at training or inference time.
|
|
Defaults to `:inference`
|
|
|
|
* `:global_layer_options` - a keyword list of options passed to
|
|
layers that accept said options
|
|
|
|
All other options are forwarded to the underlying JIT compiler.
|
|
"""
|
|
@doc type: :model
|
|
def build(model, opts \\ []) when is_list(opts) do
|
|
if opts[:backend] do
|
|
IO.warn(
|
|
"the :backend option has no effect on Axon.build/2. " <>
|
|
"Use Nx.default_backend/1 to set a backend instead"
|
|
)
|
|
end
|
|
|
|
{init_fn, predict_fn} = Axon.Compiler.build(model, opts)
|
|
opts = [on_conflict: :reuse] ++ opts
|
|
{Nx.Defn.jit(init_fn, opts), Nx.Defn.jit(predict_fn, opts)}
|
|
end
|
|
|
|
@doc """
|
|
Compiles the given model to `{init_params, predict_fn}`.
|
|
|
|
This function will compile a model specialized to the given
|
|
input shapes and types. This is useful for avoiding the overhead
|
|
of long compilations at program runtime. You must provide template
|
|
inputs which match the expected shapes and types of inputs at
|
|
execution time. Depending on the Nx compiler, such as EXLA v0.9.1+,
|
|
both `init_params` the `predict_fn` can be sent across nodes, as
|
|
long the node that owns them keeps a reference to the underlying
|
|
resources.
|
|
|
|
This function makes use of the built-in `Nx.Defn.compile/3`. Note
|
|
that passing inputs which differ in shape or type from the templates
|
|
provided to this function will result in a crash.
|
|
|
|
## Options
|
|
|
|
It accepts the same options as `build/2`.
|
|
"""
|
|
@doc type: :model
|
|
def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ [])
|
|
when is_list(opts) do
|
|
{init_fn, predict_fn} = build(model, opts)
|
|
model_state = Axon.ModelState.new(init_params)
|
|
|
|
# If there is a disk cache, we only want it to apply to the predict function
|
|
init_opts = if is_binary(opts[:cache]), do: Keyword.delete(opts, :cache), else: opts
|
|
init_params = Nx.Defn.jit_apply(init_fn, [template, model_state], init_opts)
|
|
|
|
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts)
|
|
{init_params, predict_compiled_fn}
|
|
end
|
|
|
|
@doc """
|
|
Compiles and returns the given model's init function
|
|
expression with the given options.
|
|
|
|
The returned expression is an Nx expression which can be
|
|
traversed and lowered to an IR or inspected for debugging
|
|
purposes.
|
|
|
|
You may optionally specify initial parameters for some layers or
|
|
namespaces by passing a partial parameter map:
|
|
|
|
Axon.trace_init(model, %{"dense_0" => dense_params})
|
|
|
|
The parameter map will be merged with the initialized model
|
|
parameters.
|
|
|
|
## Options
|
|
|
|
* `:debug` - if `true`, will log graph traversal and generation
|
|
metrics. Also forwarded to JIT if debug mode is available
|
|
for your chosen compiler or backend. Defaults to `false`
|
|
|
|
"""
|
|
@doc type: :debug
|
|
def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do
|
|
{init_fn, _} = build(model, opts)
|
|
Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, Axon.ModelState.new(params))
|
|
end

  @doc """
  Compiles and returns the given model's forward function
  expression with the given options.

  The returned expression is an Nx expression which can be
  traversed and lowered to an IR or inspected for debugging
  purposes.

  ## Options

    * `:mode` - one of `:inference` or `:train`. Forwarded to layers
      to control differences in compilation at training or inference time.
      Defaults to `:inference`

    * `:debug` - if `true`, will log graph traversal and generation
      metrics. Also forwarded to JIT if debug mode is available
      for your chosen compiler or backend. Defaults to `false`
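
  For example, a minimal sketch (assuming `model`, `inputs`, and `params`
  are defined elsewhere) that inspects the traced forward expression:

      model
      |> Axon.trace_forward(inputs, params, mode: :inference)
      |> IO.inspect()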

  """
  @doc type: :debug
  def trace_forward(model, inputs, params, opts \\ []) when is_list(opts) do
    {_, forward_fun} = build(model, opts)
    Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs)
  end

  @doc """
  Compiles and returns the given model's backward function
  expression with respect to the given loss function.

  The returned expression is an Nx expression which can be
  traversed and lowered to an IR or inspected for debugging
  purposes.

  The given loss function must be a scalar loss function which
  expects inputs and targets with the same shapes as the model's
  output shapes as determined by the model's signature.

  ## Options

    * `:debug` - if `true`, will log graph traversal and generation
      metrics. Also forwarded to JIT if debug mode is available
      for your chosen compiler or backend. Defaults to `false`
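
  For example, a rough sketch of a compatible scalar loss function built
  on `Axon.Losses` (assuming `model`, `inputs`, and `params` are defined
  elsewhere):

      loss = fn targets, predictions ->
        Axon.Losses.mean_squared_error(targets, predictions, reduction: :mean)
      end

      Axon.trace_backward(model, inputs, params, loss)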

  """
  @doc type: :debug
  def trace_backward(model, inputs, params, loss, opts \\ []) do
    {_, forward_fn} = build(model, opts ++ [mode: :train])

    backward_fn = fn params, inputs, targets ->
      Nx.Defn.grad(params, fn params ->
        %{prediction: preds} = forward_fn.(params, inputs)
        loss.(targets, preds)
      end)
    end

    %{prediction: outputs} =
      Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs)

    inputs = [params, inputs, outputs]

    apply(Nx.Defn.jit(backward_fn, compiler: Axon.Defn), inputs)
  end

  @doc false
  @deprecated "Use Axon.build/2 instead"
  def init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) when is_list(opts) do
    {init_fn, _predict_fn} = build(model, opts)
    init_fn.(template, Axon.ModelState.new(params))
  end

  @doc """
  Builds and runs the given Axon `model` with `params` and `input`.

  This is equivalent to calling `build/2` and then invoking the
  predict function.

  ## Options

    * `:mode` - one of `:inference` or `:train`. Forwarded to layers
      to control differences in compilation at training or inference time.
      Defaults to `:inference`

    * `:debug` - if `true`, will log graph traversal and generation
      metrics. Also forwarded to JIT if debug mode is available
      for your chosen compiler or backend. Defaults to `false`

  All other options are forwarded to the default JIT compiler
  or backend.
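
  For example (a minimal sketch, assuming `model`, `params`, and `input`
  are defined elsewhere):

      Axon.predict(model, params, input, mode: :inference)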

  """
  @doc type: :model
  def predict(%Axon{} = model, params, input, opts \\ []) when is_list(opts) do
    {_init_fn, predict_fn} = build(model, opts)
    predict_fn.(Axon.ModelState.new(params), input)
  end

  ## Inspection

  defimpl Inspect do
    import Inspect.Algebra

    def inspect(%Axon{output: id, nodes: nodes} = axon, opts) do
      inputs =
        axon
        |> Axon.get_inputs()
        |> Enum.sort()
        |> Map.new()

      op_counts = Axon.get_op_counts(axon)
      %Axon.Node{op_name: op_name, name: name_fn} = nodes[id]
      op_counts = Map.update(op_counts, op_name, 0, fn x -> x - 1 end)
      output_name = name_fn.(op_name, op_counts)

      node_count = Enum.count(axon.nodes)

      inner =
        concat([
          line(),
          "inputs: #{inspect(inputs)}",
          line(),
          "outputs: #{inspect(output_name)}",
          line(),
          "nodes: #{inspect(node_count)}"
        ])

      force_unfit(
        concat([
          color("#Axon<", :map, opts),
          nest(inner, 2),
          line(),
          color(">", :map, opts)
        ])
      )
    end
  end

  @doc """
  Returns a mapping of layer names to layer properties
  (currently, each layer's op name).
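
  For example (illustrative output for a hypothetical model with a single
  dense layer; actual names depend on the model):

      Axon.properties(model)
      #=> %{"input" => :input, "dense_0" => :dense}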

  """
  def properties(%Axon{output: id, nodes: nodes}) do
    {_, _, properties} = node_properties(id, nodes, {%{}, %{}, %{}})
    properties
  end

  defp node_properties(id, nodes, {cache, op_counts, properties} = acc) do
    case cache do
      %{^id => _} ->
        {cache, op_counts, properties}

      %{} ->
        %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id]

        {cache, op_counts, properties} =
          Enum.reduce(parents, acc, &node_properties(&1, nodes, &2))

        name = name_fn.(op_name, op_counts)
        op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)
        properties = Map.put(properties, name, op_name)

        {Map.put(cache, id, name), op_counts, properties}
    end
  end

  ## Helpers

  @valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++
                        [:lecun_uniform, :lecun_normal, :he_uniform, :he_normal] ++
                        [:glorot_uniform, :glorot_normal, :variance_scaling]

  defp validate_initializer!(initializer)
       when is_atom(initializer) and initializer in @valid_initializers do
    apply(Axon.Initializers, initializer, [])
  end

  defp validate_initializer!(initializer) when is_function(initializer, 2) do
    initializer
  end

  defp validate_initializer!(initializer) when is_function(initializer, 3) do
    initializer
  end

  defp validate_initializer!(initializer) do
    raise ArgumentError,
          "initializer must be one of #{inspect(@valid_initializers)}, " <>
            "an arity-2 function accepting initializer shape and type, " <>
            "or an arity-3 function accepting initializer shape, type, and key, " <>
            "got: #{inspect(initializer)}"
  end

  # Names are generated lazily at inspect, initialization, and compile
  # time, so for name we return a function which takes `op` and `op_counts`
  # and returns a unique name for the given model.
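  #
  # For example (illustrative): `name(:dense, nil)` returns a function such
  # that calling it with op counts `%{dense: 2}` yields the name "dense_2".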
  defp name(type, nil) do
    fn op, op_counts ->
      count = op_counts[op] || 0
      Atom.to_string(type) <> "_#{count}"
    end
  end

  defp name(_type, name_fn) when is_function(name_fn, 2) do
    name_fn
  end

  defp name(_type, name) when is_binary(name) do
    fn _, _ -> name end
  end

  defp name(_type, name) do
    raise ArgumentError,
          "expected layer name to be a binary, a function or nil, " <>
            "got: #{inspect(name)}"
  end
end