338 lines
8.7 KiB
Elixir
338 lines
8.7 KiB
Elixir
defmodule XLA do
|
|
@moduledoc """
|
|
API for accessing precompiled XLA archives.
|
|
"""
|
|
|
|
require Logger
|
|
|
|
# Package version, read from the Mix project at compile time. It is used in
# the release URL, the archive file names and the cache directory layout.
@version Mix.Project.config()[:version]

# Base URL of the GitHub release download for @version.
@base_url "https://github.com/elixir-nx/xla/releases/download/v#{@version}"

# Full target names for which a precompiled archive can be downloaded.
@precompiled_targets ~w(
  x86_64-darwin-cpu
  aarch64-darwin-cpu
  x86_64-linux-gnu-cpu
  aarch64-linux-gnu-cpu
  x86_64-linux-gnu-cuda12
  aarch64-linux-gnu-cuda12
  x86_64-linux-gnu-tpu
)

# Values accepted for the XLA_TARGET environment variable.
@supported_xla_targets ~w(cpu cuda rocm tpu cuda12)
|
|
|
|
@doc """
Returns the path to the XLA archive for the current configuration.

The location is resolved in the following order:

  * when `XLA_BUILD` is `1` or `true` - the path of the locally built
    archive, which is expected to exist already

  * when `XLA_ARCHIVE_URL` is set - the cached copy of the archive at
    that URL, downloaded first if it is not in the cache yet

  * otherwise - the cached precompiled archive for the current target,
    downloaded first if it is not in the cache yet

Raises if a required download fails.
"""
@spec archive_path!() :: Path.t()
def archive_path!() do
  XLA.Utils.start_inets_profile()

  if build?() do
    # The archive should have already been built by this point
    archive_path_for_build()
  else
    case xla_archive_url() do
      nil ->
        precompiled_path = archive_path_for_precompiled_download()

        if not File.exists?(precompiled_path) do
          download_precompiled!(precompiled_path)
        end

        precompiled_path

      external_url ->
        external_path = archive_path_for_external_download(external_url)

        if not File.exists?(external_path) do
          download_external!(external_url, external_path)
        end

        external_path
    end
  end
after
  # Runs whether the body returned or raised
  XLA.Utils.stop_inets_profile()
end
|
|
|
|
# Whether a local build was requested, i.e. XLA_BUILD is "1" or "true".
defp build?() do
  case System.get_env("XLA_BUILD") do
    "1" -> true
    "true" -> true
    _unset_or_other -> false
  end
end
|
|
|
|
# URL of a user-provided archive taken from XLA_ARCHIVE_URL, or nil when unset.
defp xla_archive_url() do
  case System.fetch_env("XLA_ARCHIVE_URL") do
    {:ok, url} -> url
    :error -> nil
  end
end
|
|
|
|
# Resolves the XLA target: XLA_TARGET if set, otherwise whatever can be
# inferred from the environment, otherwise "cpu". Raises when the resolved
# value is not one of the supported targets.
defp xla_target() do
  chosen = System.get_env("XLA_TARGET") || infer_xla_target() || "cpu"
  allowed = @supported_xla_targets

  if chosen in allowed do
    chosen
  else
    listing = Enum.map_join(allowed, ", ", &inspect/1)
    raise "expected XLA_TARGET to be one of #{listing}, but got: #{inspect(chosen)}"
  end
end
|
|
|
|
# Tries to guess the XLA target from the local toolchain. Returns "cuda12"
# when an `nvcc` executable is found and its `--version` output mentions
# "release 12.", and nil in every other case.
defp infer_xla_target() do
  case System.find_executable("nvcc") do
    nil ->
      nil

    nvcc_path ->
      case System.cmd(nvcc_path, ["--version"]) do
        {version_info, 0} ->
          if String.contains?(version_info, "release 12."), do: "cuda12", else: nil

        _non_zero_exit ->
          nil
      end
  end
end
|
|
|
|
# The directory where we store all the archives: the expanded XLA_CACHE_DIR
# when set, otherwise the per-user cache directory for "xla".
defp xla_cache_dir() do
  case System.get_env("XLA_CACHE_DIR") do
    nil -> :filename.basedir(:user_cache, "xla")
    custom_dir -> Path.expand(custom_dir)
  end
end
|
|
|
|
# Full target name: the platform triplet (ABI omitted when absent) followed
# by the XLA target, all joined with dashes.
defp target() do
  {arch, os, abi} = target_triplet()
  platform_parts = if abi == nil, do: [arch, os], else: [arch, os, abi]

  Enum.join(platform_parts ++ [xla_target()], "-")
end
|
|
|
|
# Returns the `{arch, os, abi}` triplet describing the platform to target,
# where `abi` is nil when the platform string carries none.
#
# XLA_TARGET_PLATFORM (ARCHITECTURE-OS-ABI or ARCHITECTURE-OS) takes
# precedence; otherwise the triplet is derived from the architecture string
# reported by the Erlang VM. Raises a RuntimeError when either string has
# an unexpected shape.
defp target_triplet() do
  if target = System.get_env("XLA_TARGET_PLATFORM") do
    case String.split(target, "-") do
      [arch, os, abi] ->
        {arch, os, abi}

      [arch, os] ->
        {arch, os, nil}

      _other ->
        # Report the raw value: interpolating the split list would glue the
        # segments back together without the dashes.
        raise "expected XLA_TARGET_PLATFORM to be either ARCHITECTURE-OS-ABI or ARCHITECTURE-OS, got: #{inspect(target)}"
    end
  else
    system_architecture = List.to_string(:erlang.system_info(:system_architecture))

    case String.split(system_architecture, "-") do
      # Darwin on ARM is always reported as aarch64
      ["arm" <> _, _vendor, "darwin" <> _ | _] -> {"aarch64", "darwin", nil}
      [arch, _vendor, "darwin" <> _ | _] -> {arch, "darwin", nil}
      [arch, _vendor, os, abi] -> {arch, os, abi}
      [arch, _vendor, os] -> {arch, os, nil}
      ["win32"] -> {"x86_64", "windows", nil}
      _other ->
        # Fail with an actionable message instead of a bare CaseClauseError
        raise "could not infer the target platform from system architecture " <>
                "#{inspect(system_architecture)}, set XLA_TARGET_PLATFORM to " <>
                "ARCHITECTURE-OS-ABI or ARCHITECTURE-OS"
    end
  end
end
|
|
|
|
# Cache location of the archive produced by a local build for the current target.
defp archive_path_for_build() do
  cache_path(["build", archive_filename(target())])
end
|
|
|
|
# Cache location for an archive downloaded from a user-provided URL. The
# file name embeds the lowercase, unpadded base32 MD5 digest of the URL, so
# different URLs map to different cache entries.
defp archive_path_for_external_download(url) do
  digest = :erlang.md5(url)
  encoded = Base.encode32(digest, case: :lower, padding: false)

  cache_path(["external", "xla_extension-" <> encoded <> ".tar.gz"])
end
|
|
|
|
# Cache location of the downloaded precompiled archive for the current target.
defp archive_path_for_precompiled_download() do
  cache_path(["download", archive_filename(target())])
end
|
|
|
|
# Archive file name for the given full target name at the current version.
defp archive_filename(target) do
  Enum.join(["xla_extension", @version, target], "-") <> ".tar.gz"
end
|
|
|
|
# Joins the given path segments under the version-specific cache directory.
defp cache_path(parts) do
  Path.join([xla_cache_dir(), @version] ++ parts)
end
|
|
|
|
# Downloads the archive at a user-provided URL into `archive_path`.
# On failure the target file is removed and the error message is raised.
defp download_external!(url, archive_path) do
  Logger.info("Downloading XLA archive from #{url}")

  result = download_archive(url, archive_path)

  case result do
    {:error, message} ->
      File.rm(archive_path)
      raise message

    :ok ->
      Logger.info("Successfully downloaded the XLA archive")
  end
end
|
|
|
|
# Downloads the precompiled archive for the current target into
# `archive_path` and checks it against the checksum file.
#
# Raises when no precompiled archive exists for the target, when the
# download fails or when the integrity check fails; in the latter two cases
# the target file is removed first.
defp download_precompiled!(archive_path) do
  release_filename = Path.basename(archive_path)
  current_target = target()
  available = precompiled_targets()

  if not Enum.member?(available, current_target) do
    listing =
      available
      |> Enum.map(fn name -> " * " <> name end)
      |> Enum.join("\n")

    raise """
    no precompiled XLA archive available for this target: #{current_target}.

    The available targets are:

    #{listing}

    You can compile XLA locally by setting an environment variable: XLA_BUILD=true\
    """
  end

  Logger.info("Downloading a precompiled XLA archive for target #{current_target}")

  download_url = release_file_url(release_filename)

  with :ok <- download_archive(download_url, archive_path),
       :ok <- verify_integrity(archive_path) do
    Logger.info("Successfully downloaded the XLA archive")
  else
    {:error, reason} ->
      File.rm(archive_path)
      raise reason
  end
end
|
|
|
|
# URL of the given file within the GitHub release for the current version.
defp release_file_url(filename) do
  <<@base_url::binary, "/", filename::binary>>
end
|
|
|
|
# Downloads `url` into `archive_path`, creating the parent directory.
#
# Returns :ok, or `{:error, message}` with a human-readable message.
#
# The data is streamed into a sibling temporary file that is renamed to
# `archive_path` only after the download succeeded. Callers treat any file
# existing at `archive_path` as a complete archive, so an interrupted
# download must never leave a truncated file at that location.
defp download_archive(url, archive_path) do
  File.mkdir_p!(Path.dirname(archive_path))

  partial_path = archive_path <> ".partial"
  file = File.stream!(partial_path)

  case XLA.Utils.download(url, file) do
    {:ok, _file} ->
      case File.rename(partial_path, archive_path) do
        :ok ->
          :ok

        {:error, reason} ->
          File.rm(partial_path)

          {:error,
           "failed to store the XLA archive at #{archive_path}, reason: #{:file.format_error(reason)}"}
      end

    {:error, message} ->
      File.rm(partial_path)
      {:error, "failed to download the XLA archive from #{url}, reason: #{message}"}
  end
end
|
|
|
|
# Compares the checksum of the file at `path` with the entry recorded for
# its base name in the checksum file. Returns :ok on a match and
# `{:error, message}` when the checksum differs or no entry exists.
defp verify_integrity(path) do
  filename = Path.basename(path)
  actual_checksum = compute_file_checksum!(path)
  known_checksums = read_checksums!()

  case Map.fetch(known_checksums, filename) do
    {:ok, ^actual_checksum} ->
      :ok

    {:ok, _different_checksum} ->
      {:error, "the integrity check failed for file #{filename}, the checksum does not match"}

    :error ->
      {:error, "no entry for file #{filename} in the checksum file"}
  end
end
|
|
|
|
# Writes the given `filename => checksum` map to the checksum file, one
# "CHECKSUM FILENAME" line per entry, ordered by file name.
@doc false
def write_checksums!(%{} = checksums) do
  sorted_entries = Enum.sort(checksums)

  content =
    for {filename, checksum} <- sorted_entries, into: "" do
      checksum <> " " <> filename <> "\n"
    end

  File.write!(checksum_path(), content)
end
|
|
|
|
# Reads the checksum file into a `filename => checksum` map.
#
# Each non-blank line must hold a checksum followed by a file name. The two
# fields may be separated by any run of whitespace, so both single-space
# lines and the two-space layout produced by `sha256sum` are accepted, and a
# trailing carriage return is ignored. Raises if the file cannot be read or
# a line does not consist of exactly two fields.
defp read_checksums!() do
  checksum_path()
  |> File.read!()
  |> String.split("\n", trim: true)
  |> Enum.reject(&(String.trim(&1) == ""))
  |> Map.new(fn line ->
    [checksum, filename] = String.split(line)
    {filename, checksum}
  end)
end
|
|
|
|
# Streams the file in 64_000-byte chunks into an `XLA.Checksumer`
# collectable and returns whatever that collection yields.
defp compute_file_checksum!(path) do
  chunks = File.stream!(path, [], 64_000)
  Enum.into(chunks, %XLA.Checksumer{})
end
|
|
|
|
# Location of checksum.txt, one directory above this source file.
#
# This points into the project source, which is usually not available at
# runtime (for example in releases). XLA is expected to be called only
# during compilation, when the source tree is still present.
defp checksum_path() do
  __DIR__
  |> Path.join("../checksum.txt")
  |> Path.expand()
end
|
|
|
|
# Full target names for which a precompiled archive can be downloaded.
defp precompiled_targets() do
  @precompiled_targets
end
|
|
|
|
# Used by tasks
|
|
|
|
# Directory that holds the locally built archive.
@doc false
def build_archive_dir() do
  archive_path_for_build() |> Path.dirname()
end
|
|
|
|
# The package version captured at compile time.
@doc false
def version() do
  @version
end
|
|
|
|
# Archive file name for the target resolved from the current environment.
@doc false
def archive_filename_with_target() do
  target() |> archive_filename()
end
|
|
|
|
# Lists `{filename, url}` for every precompiled target, in declaration order.
@doc false
def precompiled_files() do
  Enum.map(@precompiled_targets, fn precompiled_target ->
    filename = archive_filename(precompiled_target)
    {filename, release_file_url(filename)}
  end)
end
|
|
|
|
# Configuration for elixir_make
|
|
|
|
# Builds the map of additional environment variables passed to make:
# the Bazel flags for the selected accelerator and CPU, the project root
# and the location where the built archive is expected.
@doc false
def make_env() do
  xla_target = xla_target()

  accelerator_flags =
    cond do
      String.starts_with?(xla_target, "cuda") ->
        # See https://github.com/google/jax/blob/66a92c41f6bac74960159645158e8d932ca56613/.bazelrc#L68
        [
          ~s/--config=cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"/
        ]

      String.starts_with?(xla_target, "rocm") ->
        [
          "--config=rocm",
          "--action_env=HIP_PLATFORM=hcc",
          # See https://github.com/google/jax/blob/66a92c41f6bac74960159645158e8d932ca56613/.bazelrc#L128
          ~s/--action_env=TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030,gfx1100"/
        ]

      String.starts_with?(xla_target, "tpu") ->
        ["--config=tpu"]

      true ->
        []
    end

  cpu_flags =
    if match?({"aarch64", "darwin", _}, target_triplet()) do
      ["--config=macos_arm64"]
    else
      []
    end

  build_archive = archive_path_for_build()

  # Additional environment variables passed to make
  %{
    "ROOT_DIR" => Path.expand("..", __DIR__),
    "BUILD_ARCHIVE" => build_archive,
    "BUILD_ARCHIVE_DIR" => Path.dirname(build_archive),
    "BUILD_INTERNAL_FLAGS" => Enum.join(accelerator_flags ++ cpu_flags, " ")
  }
end
|
|
end
|