defmodule Bumblebee do @external_resource "README.md" [_, readme_docs, _] = "README.md" |> File.read!() |> String.split("") @moduledoc """ Pre-trained `Axon` models for easy inference and boosted training. Bumblebee provides state-of-the-art, configurable `Axon` models. On top of that, it streamlines the process of loading pre-trained models by integrating with Hugging Face Hub and [🤗 Transformers](https://github.com/huggingface/transformers). ## Usage You can load one of the supported models by specifying the model repository: {:ok, model_info} = Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}) {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"}) Then you are ready to make predictions: inputs = Bumblebee.apply_tokenizer(tokenizer, "Hello Bumblebee!") outputs = Axon.predict(model_info.model, model_info.params, inputs) ### Tasks On top of bare models, Bumblebee provides a number of **"servings"** that act as end-to-end pipelines for specific tasks. serving = Bumblebee.Text.fill_mask(model_info, tokenizer) Nx.Serving.run(serving, "The capital of [MASK] is Paris.") #=> %{ #=> predictions: [ #=> %{score: 0.9279842972755432, token: "france"}, #=> %{score: 0.008412551134824753, token: "brittany"}, #=> %{score: 0.007433671969920397, token: "algeria"}, #=> %{score: 0.004957548808306456, token: "department"}, #=> %{score: 0.004369721747934818, token: "reunion"} #=> ] #=> } As you can see the **serving** takes care of pre-processing the text input, runs the model and also post-processes its output into more structured data. In the above example we `run` serving on the fly, however for production usage you can start serving as a process and it will automatically batch requests from multiple clients. Processing inputs in batches is usually much more efficient, since it can take advantage of parallel capabilities of the target device, which is particularly relevant in case of GPU. For more details read the `Nx.Serving` docs. For more examples see the [Examples](examples.livemd) notebook. > #### Note {: .info} > > The models are generally large, so make sure to configure an efficient > `Nx` backend, such as `EXLA` or `Torchx`. #{readme_docs} """ alias Bumblebee.HuggingFace @config_filename "config.json" @featurizer_filename "preprocessor_config.json" @tokenizer_filename "tokenizer.json" @tokenizer_config_filename "tokenizer_config.json" @tokenizer_special_tokens_filename "special_tokens_map.json" @generation_filename "generation_config.json" @scheduler_filename "scheduler_config.json" @params_filenames [ "pytorch_model.bin", "diffusion_pytorch_model.bin", "model.safetensors", "diffusion_pytorch_model.safetensors" ] @transformers_class_to_model %{ "AlbertForMaskedLM" => {Bumblebee.Text.Albert, :for_masked_language_modeling}, "AlbertForMultipleChoice" => {Bumblebee.Text.Albert, :for_multiple_choice}, "AlbertForPreTraining" => {Bumblebee.Text.Albert, :for_pre_training}, "AlbertForQuestionAnswering" => {Bumblebee.Text.Albert, :for_question_answering}, "AlbertForSequenceClassification" => {Bumblebee.Text.Albert, :for_sequence_classification}, "AlbertForTokenClassification" => {Bumblebee.Text.Albert, :for_token_classification}, "AlbertModel" => {Bumblebee.Text.Albert, :base}, "BartForCausalLM" => {Bumblebee.Text.Bart, :for_causal_language_modeling}, "BartForConditionalGeneration" => {Bumblebee.Text.Bart, :for_conditional_generation}, "BartForQuestionAnswering" => {Bumblebee.Text.Bart, :for_question_answering}, "BartForSequenceClassification" => {Bumblebee.Text.Bart, :for_sequence_classification}, "BartModel" => {Bumblebee.Text.Bart, :base}, "BertForMaskedLM" => {Bumblebee.Text.Bert, :for_masked_language_modeling}, "BertForMultipleChoice" => {Bumblebee.Text.Bert, :for_multiple_choice}, "BertForNextSentencePrediction" => {Bumblebee.Text.Bert, :for_next_sentence_prediction}, "BertForPreTraining" => {Bumblebee.Text.Bert, :for_pre_training}, "BertForQuestionAnswering" => {Bumblebee.Text.Bert, :for_question_answering}, "BertForSequenceClassification" => {Bumblebee.Text.Bert, :for_sequence_classification}, "BertForTokenClassification" => {Bumblebee.Text.Bert, :for_token_classification}, "BertLMHeadModel" => {Bumblebee.Text.Bert, :for_causal_language_modeling}, "BertModel" => {Bumblebee.Text.Bert, :base}, "BlenderbotForConditionalGeneration" => {Bumblebee.Text.Blenderbot, :for_conditional_generation}, "BlenderbotModel" => {Bumblebee.Text.Blenderbot, :base}, "BlipForConditionalGeneration" => {Bumblebee.Multimodal.Blip, :for_conditional_generation}, # These models are just RoBERTa models, but the config will list them as CamemBERT "CamembertModel" => {Bumblebee.Text.Roberta, :base}, "CamembertForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling}, "CamembertForSequenceClassification" => {Bumblebee.Text.Roberta, :for_sequence_classification}, "CamembertForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice}, "CamembertForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, "CamembertForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering}, "CLIPModel" => {Bumblebee.Multimodal.Clip, :base}, "CLIPTextModel" => {Bumblebee.Text.ClipText, :base}, "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base}, "ControlNetModel" => {Bumblebee.Diffusion.ControlNet, :base}, "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification}, "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base}, "DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification}, "DeiTForImageClassificationWithTeacher" => {Bumblebee.Vision.Deit, :for_image_classification_with_teacher}, "DeiTForMaskedImageModeling" => {Bumblebee.Vision.Deit, :for_masked_image_modeling}, "DeiTModel" => {Bumblebee.Vision.Deit, :base}, "Dinov2Model" => {Bumblebee.Vision.DinoV2, :base}, "Dinov2Backbone" => {Bumblebee.Vision.DinoV2, :backbone}, "Dinov2ForImageClassification" => {Bumblebee.Vision.DinoV2, :for_image_classification}, "DistilBertModel" => {Bumblebee.Text.Distilbert, :base}, "DistilBertForMaskedLM" => {Bumblebee.Text.Distilbert, :for_masked_language_modeling}, "DistilBertForSequenceClassification" => {Bumblebee.Text.Distilbert, :for_sequence_classification}, "DistilBertForQuestionAnswering" => {Bumblebee.Text.Distilbert, :for_question_answering}, "DistilBertForTokenClassification" => {Bumblebee.Text.Distilbert, :for_token_classification}, "DistilBertForMultipleChoice" => {Bumblebee.Text.Distilbert, :for_multiple_choice}, "GemmaModel" => {Bumblebee.Text.Gemma, :base}, "GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling}, "GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification}, "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification}, "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification}, "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling}, "GPT2Model" => {Bumblebee.Text.Gpt2, :base}, "GPTBigCodeModel" => {Bumblebee.Text.GptBigCode, :base}, "GPTBigCodeForCausalLM" => {Bumblebee.Text.GptBigCode, :for_causal_language_modeling}, "GPTBigCodeForSequenceClassification" => {Bumblebee.Text.GptBigCode, :for_sequence_classification}, "GPTBigCodeForTokenClassification" => {Bumblebee.Text.GptBigCode, :for_token_classification}, "GPTNeoXModel" => {Bumblebee.Text.GptNeoX, :base}, "GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling}, "GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification}, "GPTNeoXForTokenClassification" => {Bumblebee.Text.GptNeoX, :for_token_classification}, "LayoutLMForMaskedLM" => {Bumblebee.Multimodal.LayoutLm, :for_masked_language_modeling}, "LayoutLMForQuestionAnswering" => {Bumblebee.Multimodal.LayoutLm, :for_question_answering}, "LayoutLMForSequenceClassification" => {Bumblebee.Multimodal.LayoutLm, :for_sequence_classification}, "LayoutLMForTokenClassification" => {Bumblebee.Multimodal.LayoutLm, :for_token_classification}, "LayoutLMModel" => {Bumblebee.Multimodal.LayoutLm, :base}, "LlamaModel" => {Bumblebee.Text.Llama, :base}, "LlamaForCausalLM" => {Bumblebee.Text.Llama, :for_causal_language_modeling}, "LlamaForSequenceClassification" => {Bumblebee.Text.Llama, :for_sequence_classification}, "MBartForCausalLM" => {Bumblebee.Text.Mbart, :for_causal_language_modeling}, "MBartForConditionalGeneration" => {Bumblebee.Text.Mbart, :for_conditional_generation}, "MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering}, "MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification}, "MBartModel" => {Bumblebee.Text.Mbart, :base}, "M2M100Model" => {Bumblebee.Text.M2m100, :base}, "M2M100ForConditionalGeneration" => {Bumblebee.Text.M2m100, :for_conditional_generation}, "MistralModel" => {Bumblebee.Text.Mistral, :base}, "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling}, "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification}, "PhiModel" => {Bumblebee.Text.Phi, :base}, "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling}, "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification}, "PhiForTokenClassification" => {Bumblebee.Text.Phi, :for_token_classification}, "Phi3Model" => {Bumblebee.Text.Phi3, :base}, "Phi3ForCausalLM" => {Bumblebee.Text.Phi3, :for_causal_language_modeling}, "Phi3ForSequenceClassification" => {Bumblebee.Text.Phi3, :for_sequence_classification}, "Phi3ForTokenClassification" => {Bumblebee.Text.Phi3, :for_token_classification}, "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification}, "ResNetModel" => {Bumblebee.Vision.ResNet, :base}, "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling}, "RobertaForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice}, "RobertaForPreTraining" => {Bumblebee.Text.Roberta, :for_pre_training}, "RobertaForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering}, "RobertaForSequenceClassification" => {Bumblebee.Text.Roberta, :for_sequence_classification}, "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling}, "RobertaModel" => {Bumblebee.Text.Roberta, :base}, "SwinModel" => {Bumblebee.Vision.Swin, :base}, "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification}, "T5Model" => {Bumblebee.Text.T5, :base}, "T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation}, "T5EncoderModel" => {Bumblebee.Text.T5, :encoder}, "ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification}, "ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling}, "ViTModel" => {Bumblebee.Vision.Vit, :base}, "WhisperModel" => {Bumblebee.Audio.Whisper, :base}, "WhisperForConditionalGeneration" => {Bumblebee.Audio.Whisper, :for_conditional_generation}, # These models are just RoBERTa models, but the config will list them as XLM-RoBERTa "XLMRobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling}, "XLMRobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling}, "XLMRobertaForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice}, "XLMRobertaForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering}, "XLMRobertaForSequenceClassification" => {Bumblebee.Text.Roberta, :for_sequence_classification}, "XLMRobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, "XLMRobertaModel" => {Bumblebee.Text.Roberta, :base}, # Diffusers "AutoencoderKL" => {Bumblebee.Diffusion.VaeKl, :base}, "StableDiffusionSafetyChecker" => {Bumblebee.Diffusion.StableDiffusion.SafetyChecker, :base}, "UNet2DConditionModel" => {Bumblebee.Diffusion.UNet2DConditional, :base} } @transformers_class_to_featurizer %{ "CLIPFeatureExtractor" => Bumblebee.Vision.ClipFeaturizer, "ConvNextFeatureExtractor" => Bumblebee.Vision.ConvNextFeaturizer, "DeiTFeatureExtractor" => Bumblebee.Vision.DeitFeaturizer, "ViTFeatureExtractor" => Bumblebee.Vision.VitFeaturizer, "WhisperFeatureExtractor" => Bumblebee.Audio.WhisperFeaturizer } @transformers_image_processor_type_to_featurizer %{ "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer, "BitImageProcessor" => Bumblebee.Vision.BitFeaturizer } @model_type_to_featurizer %{ "convnext" => Bumblebee.Vision.ConvNextFeaturizer, "deit" => Bumblebee.Vision.DeitFeaturizer, "resnet" => Bumblebee.Vision.ConvNextFeaturizer, "vit" => Bumblebee.Vision.VitFeaturizer, "whisper" => Bumblebee.Audio.WhisperFeaturizer } @model_type_to_tokenizer_type %{ "albert" => :albert, "bart" => :bart, "bert" => :bert, "blenderbot" => :blenderbot, "blip" => :bert, "distilbert" => :distilbert, "camembert" => :camembert, "clip" => :clip, "gemma" => :gemma, "gpt_neox" => :gpt_neo_x, "gpt2" => :gpt2, "gpt_bigcode" => :gpt2, "layoutlm" => :layout_lm, "llama" => :llama, "mistral" => :llama, "mbart" => :mbart, "phi" => :code_gen, "phi3" => :llama, "roberta" => :roberta, "t5" => :t5, "whisper" => :whisper, "xlm-roberta" => :xlm_roberta, # Both M2M100 and NLLB model checkpoints use the M2M100 model, # but have distinct tokenizers. Consequently, model type is # "m2m_100" in both cases. Currently only NLLB has fast tokenizer # implementation, so if we load the tokenizer correctly, it is # safe to assume it's NLLB "m2m_100" => :nllb } @diffusers_class_to_scheduler %{ "DDIMScheduler" => Bumblebee.Diffusion.DdimScheduler, "LCMScheduler" => Bumblebee.Diffusion.LcmScheduler, "PNDMScheduler" => Bumblebee.Diffusion.PndmScheduler } @typedoc """ A location to fetch model files from. Can be either: * `{:hf, repository_id}` - the repository on Hugging Face. Options may be passed as the third element: * `:revision` - the specific model version to use, it can be any valid git identifier, such as branch name, tag name, or a commit hash * `:cache_dir` - the directory to store the downloaded files in. Defaults to the standard cache location for the given operating system. You can also configure it globally by setting the `BUMBLEBEE_CACHE_DIR` environment variable * `:offline` - if `true`, only cached files are accessed and missing files result in an error. You can also configure it globally by setting the `BUMBLEBEE_OFFLINE` environment variable to `true` * `:auth_token` - the token to use as HTTP bearer authorization for remote files * `:subdir` - the directory within the repository where the files are located * `{:local, directory}` - the directory containing model files """ @type repository :: {:hf, String.t()} | {:hf, String.t(), keyword()} | {:local, Path.t()} @typedoc """ A model together with its state and metadata. """ @type model_info :: %{ model: Axon.t(), params: %Axon.ModelState{}, spec: Bumblebee.ModelSpec.t() } @doc """ Builds or updates a configuration object with the given options. Expects a configuration struct or a module supporting configuration. These are usually configurable: * model specification (`Bumblebee.ModelSpec`) * featurizer (`Bumblebee.Featurizer`) * scheduler (`Bumblebee.Scheduler`) ## Examples To build a new configuration, pass a module: featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer) spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :for_image_classification) Similarly, you can update an existing configuration: featurizer = Bumblebee.configure(featurizer, resize_method: :bilinear) spec = Bumblebee.configure(spec, embedding_size: 128) """ @spec configure(module() | Bumblebee.Configurable.t(), keyword()) :: Bumblebee.Configurable.t() def configure(config, options \\ []) do %module{} = config = struct!(config) module.config(config, options) end @doc """ Builds an `Axon` model according to the given specification. ## Options * `:type` - either a type or `Axon.MixedPrecision` policy to apply to the model ## Example spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128) model = Bumblebee.build_model(spec) """ @doc type: :model @spec build_model(Bumblebee.ModelSpec.t(), keyword()) :: Axon.t() def build_model(%module{} = spec, opts \\ []) do opts = Keyword.validate!(opts, [:type]) model = module.model(spec) case opts[:type] do nil -> model %Axon.MixedPrecision.Policy{} = policy -> Axon.MixedPrecision.apply_policy(model, policy) type -> type = Nx.Type.normalize!(type) policy = Axon.MixedPrecision.create_policy(params: type, compute: type, output: type) Axon.MixedPrecision.apply_policy(model, policy) end end @doc """ Loads model specification from a model repository. ## Options * `:module` - the model specification module. By default it is inferred from the configuration file, if that is not possible, it must be specified explicitly * `:architecture` - the model architecture, must be supported by `:module`. By default it is inferred from the configuration file ## Examples {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"}) You can explicitly specify a different architecture: {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"}, architecture: :base) """ @doc type: :model @spec load_spec(repository(), keyword()) :: {:ok, Bumblebee.ModelSpec.t()} | {:error, String.t()} def load_spec(repository, opts \\ []) do repository = normalize_repository!(repository) opts = Keyword.validate!(opts, [:module, :architecture]) module = opts[:module] architecture = opts[:architecture] with {:ok, repo_files} <- get_repo_files(repository) do do_load_spec(repository, repo_files, module, architecture) end end defp do_load_spec(repository, repo_files, module, architecture) do case repo_files do %{@config_filename => etag} -> with {:ok, path} <- download(repository, @config_filename, etag), {:ok, spec_data} <- decode_config(path) do {inferred_module, inferred_architecture, inference_error} = case infer_model_type(spec_data) do {:ok, module, architecture} -> {module, architecture, nil} {:error, error} -> {nil, nil, error} end module = module || inferred_module architecture = architecture || inferred_architecture unless module do raise ArgumentError, "#{inference_error}, please specify the :module and :architecture options" end architectures = module.architectures() if architecture && architecture not in architectures do raise ArgumentError, "expected architecture to be one of: #{Enum.map_join(architectures, ", ", &inspect/1)}, but got: #{inspect(architecture)}" end spec = if architecture do configure(module, architecture: architecture) else configure(module) end spec = HuggingFace.Transformers.Config.load(spec, spec_data) {:ok, spec} end %{} -> raise ArgumentError, "no config file found in the given repository. Please refer to Bumblebee" <> " README to learn about repositories and supported models" end end defp decode_config(path) do path |> File.read!() |> Jason.decode() |> case do {:ok, data} -> {:ok, data} _ -> {:error, "failed to parse the config file, it is not a valid JSON"} end end defp infer_model_type(%{"architectures" => [class_name]}) do case @transformers_class_to_model[class_name] do nil -> {:error, "could not match the class name #{inspect(class_name)} to any of the supported models"} {module, architecture} -> {:ok, module, architecture} end end defp infer_model_type(%{"_class_name" => class_name}) do infer_model_type(%{"architectures" => [class_name]}) end defp infer_model_type(_spec_data) do {:error, "could not infer model type from the configuration"} end @doc """ Loads a pre-trained model from a model repository. The model is downloaded and cached on your disk, use `cache_dir/0` to find the location. ## Parameters precision On GPUs computations that use numeric type of lower precision can be faster and use less memory, while still providing valid results. You can configure the model to use particular type by passing the `:type` option, such as `:bf16`. Some repositories have multiple variants of the parameter files with different numeric types. The variant is usually indicated in the file extension and you can load a particular file by specifying `:params_variant`, or `:params_filename`. Note however that this does not determine the numeric type used for inference. The file type is relevant in context of download bandwidth and disk space. If you want to use a lower precision for inference, make sure to also specify `:type`. ## Options * `:spec` - the model specification to use when building the model. By default the specification is loaded using `load_spec/2` * `:spec_overrides` - additional options to configure the model specification with. This is a shorthand for using `load_spec/2`, `configure/2` and passing as `:spec` * `:module` - the model specification module. By default it is inferred from the configuration file, if that is not possible, it must be specified explicitly * `:architecture` - the model architecture, must be supported by `:module`. By default it is inferred from the configuration file * `:params_variant` - when specified, instead of loading parameters from ".", loads from ".." * `:params_filename` - the file with the model parameters to be loaded * `:log_params_diff` - whether to log missing, mismatched and unused parameters. By default diff is logged only if some parameters cannot be loaded * `:backend` - the backend to allocate the tensors on. It is either an atom or a tuple in the shape `{backend, options}` * `:type` - either a type or `Axon.MixedPrecision` policy to apply to the model. Passing this option automatically casts parameters to the desired type ## Examples By default the model type is inferred from configuration, so loading is as simple as: {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}) %{model: model, params: params, spec: spec} = resnet You can explicitly specify a different architecture, in which case matching parameters are still loaded: {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, architecture: :base) To further customize the model, you can also pass the specification: {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"}) spec = Bumblebee.configure(spec, num_labels: 10) {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec) Or as a shorthand, you can pass just the options to override: {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec_overrides: [num_labels: 10]) """ @doc type: :model @spec load_model(repository(), keyword()) :: {:ok, model_info()} | {:error, String.t()} def load_model(repository, opts \\ []) do repository = normalize_repository!(repository) opts = Keyword.validate!(opts, [ :spec, :spec_overrides, :module, :architecture, :params_variant, :params_filename, :log_params_diff, :backend, :type ]) with {:ok, repo_files} <- get_repo_files(repository), {:ok, spec} <- maybe_load_model_spec(opts, repository, repo_files), model <- build_model(spec, Keyword.take(opts, [:type])), {:ok, params} <- load_params(spec, model, repository, repo_files, opts) do {:ok, %{model: model, params: params, spec: spec}} end end defp maybe_load_model_spec(opts, repository, repo_files) do spec_result = if spec = opts[:spec] do {:ok, spec} else do_load_spec(repository, repo_files, opts[:module], opts[:architecture]) end with {:ok, spec} <- spec_result do if options = opts[:spec_overrides] do {:ok, configure(spec, options)} else {:ok, spec} end end end defp load_params(%module{} = spec, model, repository, repo_files, opts) do input_template = module.input_template(spec) params_mapping = Bumblebee.HuggingFace.Transformers.Model.params_mapping(spec) {filename, sharded?} = infer_params_filename(repo_files, opts[:params_filename], opts[:params_variant]) loader_fun = filename |> String.replace_suffix(".index.json", "") |> Path.extname() |> params_file_loader_fun() with {:ok, paths} <- download_params_files(repository, repo_files, filename, sharded?) do opts = [ params_mapping: params_mapping, loader_fun: loader_fun ] ++ Keyword.take(opts, [:backend, :log_params_diff]) params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts) {:ok, params} end end defp infer_params_filename(repo_files, nil = _filename, variant) do validate_variant!(repo_files, variant) Enum.find_value(@params_filenames, &lookup_params_filename(repo_files, &1, variant)) || raise ArgumentError, "none of the expected parameters files found in the repository." <> " If the file exists under an unusual name, try specifying :params_filename" end defp infer_params_filename(repo_files, filename, variant) do if variant do IO.warn("ignoring :params_variant, because :params_filename was specified") end lookup_params_filename(repo_files, filename, nil) || raise ArgumentError, "could not find file #{inspect(filename)} in the repository" end defp lookup_params_filename(repo_files, filename, variant) do full_filename = add_variant(filename, variant) full_filename_sharded = add_variant(filename <> ".index.json", variant) cond do Map.has_key?(repo_files, full_filename) -> {full_filename, false} Map.has_key?(repo_files, full_filename_sharded) -> {full_filename_sharded, true} true -> nil end end defp add_variant(filename, nil), do: filename defp add_variant(filename, variant) do ext = Path.extname(filename) base = Path.basename(filename, ext) base <> "." <> variant <> ext end defp validate_variant!(_repo_files, nil), do: :ok defp validate_variant!(repo_files, variant) do variants = params_variants_in_repo(repo_files) cond do variant in variants -> :ok Enum.empty?(variants) -> raise ArgumentError, "parameters variant #{inspect(variant)} not found, the repository has no variants" true -> raise ArgumentError, "parameters variant #{inspect(variant)} not found, available variants: " <> Enum.map_join(variants, ", ", &inspect/1) end end defp params_variants_in_repo(repo_files) do params_filenames = MapSet.new(@params_filenames) Enum.reduce(repo_files, MapSet.new(), fn {name, _etag}, variants -> parts = String.split(name, ".") {variant, parts} = List.pop_at(parts, -2) name = Enum.join(parts, ".") if String.replace_suffix(name, ".index.json", "") in params_filenames and not String.contains?(variant, "-of-") do MapSet.put(variants, variant) else variants end end) end defp download_params_files(repository, repo_files, filename, false = _sharded?) do with {:ok, path} <- download(repository, filename, repo_files[filename]) do {:ok, [path]} end end defp download_params_files(repository, repo_files, index_filename, true = _sharded?) do with {:ok, path} <- download(repository, index_filename, repo_files[index_filename]), {:ok, sharded_metadata} <- decode_config(path) do filenames = for {_layer, filename} <- sharded_metadata["weight_map"], uniq: true, do: filename Enum.reduce_while(filenames, {:ok, []}, fn filename, {:ok, paths} -> case download(repository, filename, repo_files[filename]) do {:ok, path} -> {:cont, {:ok, [path | paths]}} error -> {:halt, error} end end) end end defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!(&1, lazy: true) defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorchLoader.load!/1 @doc """ Featurizes `input` with the given featurizer. ## Options * `:defn_options` - the options for JIT compilation. Note that this is only relevant for featurizers implemented with Nx. Defaults to `[]` ## Examples featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer) {:ok, img} = StbImage.read_file(path) inputs = Bumblebee.apply_featurizer(featurizer, [img]) """ @doc type: :featurizer @spec apply_featurizer(Bumblebee.Featurizer.t(), any(), keyword()) :: any() def apply_featurizer(%module{} = featurizer, input, opts \\ []) do opts = Keyword.validate!(opts, defn_options: []) batch = module.process_input(featurizer, input) if Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2) do Nx.Defn.jit_apply(&module.process_batch(featurizer, &1), [batch], opts[:defn_options]) else batch end end @doc """ Loads featurizer from a model repository. ## Options * `:module` - the featurizer module. By default it is inferred from the preprocessor configuration file, if that is not possible, it must be specified explicitly ## Examples {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"}) """ @doc type: :featurizer @spec load_featurizer(repository(), keyword()) :: {:ok, Bumblebee.Featurizer.t()} | {:error, String.t()} def load_featurizer(repository, opts \\ []) do repository = normalize_repository!(repository) opts = Keyword.validate!(opts, [:module]) module = opts[:module] case get_repo_files(repository) do {:ok, %{@featurizer_filename => etag} = repo_files} -> with {:ok, path} <- download(repository, @featurizer_filename, etag), {:ok, featurizer_data} <- decode_config(path) do module = module || case infer_featurizer_type(featurizer_data, repository, repo_files) do {:ok, module} -> module {:error, error} -> raise ArgumentError, "#{error}, please specify the :module option" end featurizer = configure(module) featurizer = HuggingFace.Transformers.Config.load(featurizer, featurizer_data) {:ok, featurizer} end {:ok, %{}} -> raise ArgumentError, "no featurizer found in the given repository" {:error, message} -> {:error, message} end end defp infer_featurizer_type(%{"feature_extractor_type" => class_name}, _repository, _repo_files) do case @transformers_class_to_featurizer[class_name] do nil -> {:error, "could not match the class name #{inspect(class_name)} to any of the supported featurizers"} module -> {:ok, module} end end defp infer_featurizer_type(%{"image_processor_type" => class_name}, _repository, _repo_files) do case @transformers_image_processor_type_to_featurizer[class_name] do nil -> {:error, "could not match the class name #{inspect(class_name)} to any of the supported featurizers"} module -> {:ok, module} end end defp infer_featurizer_type(_featurizer_data, repository, repo_files) do with {:ok, path} <- download(repository, @config_filename, repo_files[@config_filename]), {:ok, featurizer_data} <- decode_config(path) do case featurizer_data do %{"model_type" => model_type} -> case @model_type_to_featurizer[model_type] do nil -> {:error, "could not match model type #{inspect(model_type)} to any of the supported featurizers"} module -> {:ok, module} end _ -> {:error, "could not infer featurizer type from the configuration"} end end end @doc """ Tokenizes and encodes `input` with the given tokenizer. ## Examples tokenizer = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"}) inputs = Bumblebee.apply_tokenizer(tokenizer, ["The capital of France is [MASK]."]) """ @doc type: :tokenizer @spec apply_tokenizer( Bumblebee.Tokenizer.t(), Bumblebee.Tokenizer.input() | list(Bumblebee.Tokenizer.input()), keyword() ) :: any() def apply_tokenizer(%module{} = tokenizer, input, opts \\ []) do tokenizer = if opts == [] do tokenizer else # TODO: remove options on v0.6 IO.warn( "passing options to Bumblebee.apply_tokenizer/3 is deprecated," <> " please use Bumblebee.configure/2 to set tokenizer options" ) Bumblebee.configure(tokenizer, opts) end module.apply(tokenizer, input) end @doc """ Loads tokenizer from a model repository. ## Options * `:type` - the tokenizer type. By default it is inferred from the configuration files, if that is not possible, it must be specified explicitly ## Examples {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"}) """ @doc type: :tokenizer @spec load_tokenizer(repository(), keyword()) :: {:ok, Bumblebee.Tokenizer.t()} | {:error, String.t()} def load_tokenizer(repository, opts \\ []) do repository = normalize_repository!(repository) opts = Keyword.validate!(opts, [:type]) type = opts[:type] case get_repo_files(repository) do {:ok, %{@tokenizer_filename => etag} = repo_files} -> with {:ok, path} <- download(repository, @tokenizer_filename, etag) do type = type || case infer_tokenizer_type(repository, repo_files) do {:ok, type} -> type {:error, error} -> raise ArgumentError, "#{error}, please specify the :type option" end tokenizer_config_result = if Map.has_key?(repo_files, @tokenizer_config_filename) do etag = repo_files[@tokenizer_config_filename] with {:ok, path} <- download(repository, @tokenizer_config_filename, etag) do decode_config(path) end else {:ok, %{}} end special_tokens_map_result = if Map.has_key?(repo_files, @tokenizer_special_tokens_filename) do etag = repo_files[@tokenizer_special_tokens_filename] with {:ok, path} <- download(repository, @tokenizer_special_tokens_filename, etag) do decode_config(path) end else {:ok, %{}} end with {:ok, tokenizer_config} <- tokenizer_config_result, {:ok, special_tokens_map} <- special_tokens_map_result do tokenizer = struct!(Bumblebee.Text.PreTrainedTokenizer, type: type) tokenizer = HuggingFace.Transformers.Config.load(tokenizer, %{ "tokenizer_file" => path, # Note: special_tokens_map.json is a legacy file, now # tokenizer_config.json includes the same information # and takes precedence "special_tokens_map" => Map.merge(tokenizer_config, special_tokens_map) }) {:ok, tokenizer} end end {:ok, %{@tokenizer_config_filename => _}} -> raise ArgumentError, "expected a Rust-compatible tokenizer.json file, however the repository" <> " includes tokenizer in a different format. Please refer to Bumblebee" <> " README to see the possible steps you can take" {:ok, %{}} -> raise ArgumentError, "no tokenizer found in the given repository" {:error, message} -> {:error, message} end end defp infer_tokenizer_type(repository, repo_files) do with {:ok, path} <- download(repository, @config_filename, repo_files[@config_filename]), {:ok, tokenizer_data} <- decode_config(path) do case tokenizer_data do %{"model_type" => model_type} -> case @model_type_to_tokenizer_type[model_type] do nil -> {:error, "could not match model type #{inspect(model_type)} to any of the supported tokenizer types"} type -> {:ok, type} end _ -> {:error, "could not infer tokenizer type from the model configuration"} end end end @doc """ Loads generation config from a model repository. Generation config includes a number of model-specific properties, so it is usually best to load the config and further configure, rather than building from scratch. See `Bumblebee.Text.GenerationConfig` for all the available options. ## Options * `:spec_module` - the model specification module. By default it is inferred from the configuration file, if that is not possible, it must be specified explicitly. Some models have extra options related to generations and those are loaded into a separate struct, stored under the `:extra_config` attribute ## Examples {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) generation_config = Bumblebee.configure(generation_config, max_new_tokens: 10) """ @spec load_generation_config(repository()) :: {:ok, Bumblebee.Text.GenerationConfig.t()} | {:error, String.t()} def load_generation_config(repository, opts \\ []) do opts = Keyword.validate!(opts, [:spec_module]) repository = normalize_repository!(repository) case get_repo_files(repository) do {:ok, %{@config_filename => etag} = repo_files} -> with {:ok, path} <- download(repository, @config_filename, etag), {:ok, spec_data} <- decode_config(path) do spec_module = opts[:spec_module] {inferred_module, inference_error} = case infer_model_type(spec_data) do {:ok, module, _architecture} -> {module, nil} {:error, error} -> {nil, error} end spec_module = spec_module || inferred_module unless spec_module do raise ArgumentError, "#{inference_error}, please specify the :spec_module option" end generation_data_result = if Map.has_key?(repo_files, @generation_filename) do etag = repo_files[@generation_filename] with {:ok, path} <- download(repository, @generation_filename, etag) do decode_config(path) end else # Fallback to the spec data, since it used to include # generation attributes {:ok, spec_data} end with {:ok, generation_data} <- generation_data_result do config = struct!(Bumblebee.Text.GenerationConfig) config = HuggingFace.Transformers.Config.load(config, generation_data) extra_config_module = Bumblebee.Text.Generation.extra_config_module(struct!(spec_module)) extra_config = if extra_config_module do extra_config = struct!(extra_config_module) HuggingFace.Transformers.Config.load(extra_config, generation_data) end config = %{config | extra_config: extra_config} {:ok, config} end end {:error, message} -> {:error, message} end end @doc """ Initializes state for a new scheduler loop. Returns a pair of `{state, timesteps}`, where `state` is an opaque container expected by `scheduler_step/4` and `timesteps` is a sequence of subsequent timesteps for model forward pass. Note that the number of `timesteps` may not match `num_steps` exactly. `num_steps` parameterizes sampling points, however depending on the method, sampling certain points may require multiple forward passes of the model and each element in `timesteps` corresponds to a single forward pass. """ @doc type: :scheduler @spec scheduler_init( Bumblebee.Scheduler.t(), non_neg_integer(), Nx.Tensor.t(), Nx.Tensor.t() ) :: {Bumblebee.Scheduler.state(), Nx.Tensor.t()} def scheduler_init(%module{} = scheduler, num_steps, sample_template, prng_key) do module.init(scheduler, num_steps, sample_template, prng_key) end @doc """ Predicts sample at the previous timestep using the given scheduler. Takes the current `sample` and `prediction` (usually noise) returned by the model at the current timestep. Returns `{state, prev_sample}`, where `state` is the updated scheduler loop state and `prev_sample` is the predicted sample at the previous timestep. Note that some schedulers require several forward passes of the model (and a couple calls to this function) to make an actual prediction for the previous sample. """ @doc type: :scheduler @spec scheduler_step( Bumblebee.Scheduler.t(), Bumblebee.Scheduler.state(), Nx.Tensor.t(), Nx.Tensor.t() ) :: {Bumblebee.Scheduler.state(), Nx.Tensor.t()} def scheduler_step(%module{} = scheduler, state, sample, prediction) do module.step(scheduler, state, sample, prediction) end @doc """ Loads scheduler from a model repository. ## Options * `:module` - the scheduler module. By default it is inferred from the scheduler configuration file, if that is not possible, it must be specified explicitly ## Examples {:ok, scheduler} = Bumblebee.load_scheduler({:hf, "CompVis/stable-diffusion-v1-4", subdir: "scheduler"}) """ @doc type: :scheduler @spec load_scheduler(repository(), keyword()) :: {:ok, Bumblebee.Scheduler.t()} | {:error, String.t()} def load_scheduler(repository, opts \\ []) do repository = normalize_repository!(repository) opts = Keyword.validate!(opts, [:module]) module = opts[:module] case get_repo_files(repository) do {:ok, %{@scheduler_filename => etag}} -> with {:ok, path} <- download(repository, @scheduler_filename, etag), {:ok, scheduler_data} <- decode_config(path) do module = module || case infer_scheduler_type(scheduler_data) do {:ok, module} -> module {:error, error} -> raise ArgumentError, "#{error}, please specify the :module option" end scheduler = configure(module) scheduler = HuggingFace.Transformers.Config.load(scheduler, scheduler_data) {:ok, scheduler} end {:ok, %{}} -> raise ArgumentError, "no scheduler found in the given repository" {:error, message} -> {:error, message} end end defp infer_scheduler_type(%{"_class_name" => class_name}) do case @diffusers_class_to_scheduler[class_name] do nil -> {:error, "could not match the class name #{inspect(class_name)} to any of the supported schedulers"} module -> {:ok, module} end end defp infer_scheduler_type(_scheduler_data) do {:error, "could not infer featurizer type from the configuration"} end defp get_repo_files({:local, dir}) do case File.ls(dir) do {:ok, filenames} -> repo_files = for filename <- filenames, path = Path.join(dir, filename), File.regular?(path), into: %{}, do: {filename, nil} {:ok, repo_files} {:error, reason} -> {:error, "could not read #{dir}, reason: #{:file.format_error(reason)}"} end end defp get_repo_files({:hf, repository_id, opts}) do subdir = opts[:subdir] url = HuggingFace.Hub.file_listing_url(repository_id, subdir, opts[:revision]) cache_scope = repository_id_to_cache_scope(repository_id) result = HuggingFace.Hub.cached_download( url, [cache_scope: cache_scope] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token]) ) with {:ok, path} <- result, {:ok, data} <- decode_config(path) do repo_files = for entry <- data, entry["type"] == "file", into: %{} do path = entry["path"] name = if subdir do String.replace_leading(path, subdir <> "/", "") else path end etag_content = entry["lfs"]["oid"] || entry["oid"] etag = <> {name, etag} end {:ok, repo_files} end end defp download({:local, dir}, filename, _etag) do path = Path.join(dir, filename) if File.exists?(path) do {:ok, path} else {:error, "local file #{inspect(path)} does not exist"} end end defp download({:hf, repository_id, opts}, filename, etag) do filename = if subdir = opts[:subdir] do subdir <> "/" <> filename else filename end url = HuggingFace.Hub.file_url(repository_id, filename, opts[:revision]) cache_scope = repository_id_to_cache_scope(repository_id) HuggingFace.Hub.cached_download( url, [etag: etag, cache_scope: cache_scope] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token]) ) end defp repository_id_to_cache_scope(repository_id) do repository_id |> String.replace("/", "--") |> String.replace(~r/[^\w-]/, "") end defp normalize_repository!({:hf, repository_id}) when is_binary(repository_id) do {:hf, repository_id, []} end defp normalize_repository!({:hf, repository_id, opts}) when is_binary(repository_id) do opts = Keyword.validate!(opts, [:revision, :cache_dir, :offline, :auth_token, :subdir]) {:hf, repository_id, opts} end defp normalize_repository!({:local, dir}) when is_binary(dir) do {:local, dir} end defp normalize_repository!(other) do raise ArgumentError, "expected repository to be either {:hf, repository_id}, {:hf, repository_id, options}" <> " or {:local, directory}, got: #{inspect(other)}" end @doc """ Returns the directory where downloaded files are stored. Defaults to the standard cache location for the given operating system. Can be configured with the `BUMBLEBEE_CACHE_DIR` environment variable. """ @spec cache_dir() :: String.t() def cache_dir() do if dir = System.get_env("BUMBLEBEE_CACHE_DIR") do Path.expand(dir) else :filename.basedir(:user_cache, "bumblebee") end end end