use std::collections::HashMap;
use std::ops::Deref;
use std::panic;

use rustler::{NifTaggedEnum, Term};

use tokenizers::models::wordpiece::WordPieceTrainerBuilder;
use tokenizers::models::TrainerWrapper;
use tokenizers::tokenizer::AddedToken;
use tokenizers::Model;
use tokenizers::{EncodeInput, TokenizerImpl};

use crate::added_token::{AddedSpecialTokenInput, AddedTokenInput};
use crate::decoders::ExTokenizersDecoder;
use crate::encoding::{apply_transformations, ExTokenizersEncoding, TransformationElement};
use crate::error::ExTokenizersError;
use crate::models::ExTokenizersModel;
use crate::normalizers::ExTokenizersNormalizer;
use crate::post_processors::ExTokenizersPostProcessor;
use crate::pre_tokenizers::ExTokenizersPreTokenizer;
use crate::trainers::ExTokenizersTrainer;
use crate::util::Direction;

type ExTokenizerImpl = TokenizerImpl<
    ExTokenizersModel,
    ExTokenizersNormalizer,
    ExTokenizersPreTokenizer,
    ExTokenizersPostProcessor,
    ExTokenizersDecoder,
>;

pub struct ExTokenizersTokenizerRef(ExTokenizerImpl);

#[rustler::resource_impl]
impl rustler::Resource for ExTokenizersTokenizerRef {}

#[derive(rustler::NifStruct)]
#[module = "Tokenizers.Tokenizer"]
pub struct ExTokenizersTokenizer {
    pub resource: rustler::ResourceArc<ExTokenizersTokenizerRef>,
}

impl From<ExTokenizerImpl> for ExTokenizersTokenizer {
    fn from(data: ExTokenizerImpl) -> Self {
        Self {
            resource: rustler::ResourceArc::new(ExTokenizersTokenizerRef(data)),
        }
    }
}

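// A note on the update pattern used throughout this module: the setter NIFs
// below never mutate a stored tokenizer in place. Each one clones the inner
// `TokenizerImpl`, modifies the clone, and wraps it in a fresh resource via
// this `From` impl, so the tokenizer visible from Elixir stays immutable.
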
// /////////////////////////////////////////////////////////////////////////////
// / Creators
// /////////////////////////////////////////////////////////////////////////////

#[rustler::nif]
pub fn tokenizer_init(
    model: ExTokenizersModel,
) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
    let tokenizer = TokenizerImpl::new(model);
    Ok(tokenizer.into())
}

#[derive(NifTaggedEnum)]
pub enum LoadOption {
    AdditionalSpecialTokens(Vec<AddedSpecialTokenInput>),
    // Currently only :none is supported for both of these options
    Padding(rustler::Atom),
    Truncation(rustler::Atom),
}

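// Illustrative decoding sketch (hedged; the exact keyword spelling is defined
// on the Elixir side of this library): with `NifTaggedEnum`, an options list
// such as
//
//     [additional_special_tokens: tokens, padding: :none]
//
// arrives in these NIFs as
//
//     vec![
//         LoadOption::AdditionalSpecialTokens(tokens),
//         LoadOption::Padding(none), // `none` being the decoded :none atom
//     ]
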
#[rustler::nif(schedule = "DirtyIo")]
pub fn tokenizer_from_file(
    path: &str,
    options: Vec<LoadOption>,
) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
    let mut tokenizer = TokenizerImpl::from_file(path)?;
    tokenizer = apply_load_options(tokenizer, options);
    Ok(tokenizer.into())
}

#[rustler::nif]
pub fn tokenizer_from_buffer(
    data: String,
    options: Vec<LoadOption>,
) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
    // `parse` goes through `TokenizerImpl`'s `FromStr` impl, which deserializes
    // the JSON contents of a tokenizer file held in memory.
    let mut tokenizer: ExTokenizerImpl = data.parse()?;
    tokenizer = apply_load_options(tokenizer, options);
    Ok(tokenizer.into())
}

fn apply_load_options(mut tokenizer: ExTokenizerImpl, options: Vec<LoadOption>) -> ExTokenizerImpl {
    struct Opts {
        additional_special_tokens: Vec<AddedSpecialTokenInput>,
        disable_padding: bool,
        disable_truncation: bool,
    }

    let mut opts = Opts {
        additional_special_tokens: vec![],
        disable_padding: false,
        disable_truncation: false,
    };

    for opt in options {
        match opt {
            LoadOption::AdditionalSpecialTokens(tokens) => {
                opts.additional_special_tokens = tokens;
            }
            LoadOption::Padding(_) => {
                opts.disable_padding = true;
            }
            LoadOption::Truncation(_) => {
                opts.disable_truncation = true;
            }
        }
    }

    tokenizer.add_special_tokens(
        opts.additional_special_tokens
            .iter()
            .map(|t| t.into())
            .collect::<Vec<_>>()
            .as_ref(),
    );

    if opts.disable_padding {
        tokenizer.with_padding(None);
    }

    if opts.disable_truncation {
        tokenizer.with_truncation(None).unwrap();
    }

    tokenizer
}

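// Hedged usage sketch from the Elixir side (the option names mirror the
// `LoadOption` variants above; the exact API shape is assumed here):
//
//     {:ok, tokenizer} =
//       Tokenizers.Tokenizer.from_file("tokenizer.json", padding: :none)
//
// which loads the file and then drops the saved padding configuration via
// `apply_load_options`.
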
#[derive(NifTaggedEnum)]
pub enum SaveOption {
    Pretty(bool),
}

#[rustler::nif(schedule = "DirtyIo")]
pub fn tokenizer_save(
    tokenizer: ExTokenizersTokenizer,
    path: &str,
    options: Vec<SaveOption>,
) -> Result<String, ExTokenizersError> {
    struct Opts {
        pretty: bool,
    }
    let mut opts = Opts { pretty: false };
    for opt in options {
        match opt {
            SaveOption::Pretty(pretty) => opts.pretty = pretty,
        }
    }

    tokenizer.resource.0.save(path, opts.pretty)?;
    Ok(path.to_string())
}

// tokenizer_from_pretrained is skipped here, as it is implemented in Elixir.
// It uses tokenizer_from_file underneath.

// /////////////////////////////////////////////////////////////////////////////
// / Setters / Getters
// /////////////////////////////////////////////////////////////////////////////

#[rustler::nif]
pub fn tokenizer_get_model(tokenizer: ExTokenizersTokenizer) -> ExTokenizersModel {
    tokenizer.resource.0.get_model().clone()
}

#[rustler::nif]
pub fn tokenizer_set_model(
    tokenizer: ExTokenizersTokenizer,
    model: ExTokenizersModel,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_model(model);
    new_tokenizer.into()
}

// The setters and getters for the pre-tokenizer, normalizer, and the other
// components below are written out explicitly rather than generated by a macro:
#[rustler::nif]
pub fn tokenizer_get_normalizer(
    tokenizer: ExTokenizersTokenizer,
) -> Option<ExTokenizersNormalizer> {
    tokenizer.resource.0.get_normalizer().cloned()
}

#[rustler::nif]
pub fn tokenizer_set_normalizer(
    tokenizer: ExTokenizersTokenizer,
    normalizer: ExTokenizersNormalizer,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_normalizer(Some(normalizer));
    new_tokenizer.into()
}

#[rustler::nif]
pub fn tokenizer_get_pre_tokenizer(
    tokenizer: ExTokenizersTokenizer,
) -> Option<ExTokenizersPreTokenizer> {
    tokenizer.resource.0.get_pre_tokenizer().cloned()
}

#[rustler::nif]
pub fn tokenizer_set_pre_tokenizer(
    tokenizer: ExTokenizersTokenizer,
    pre_tokenizer: ExTokenizersPreTokenizer,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
    new_tokenizer.into()
}

#[rustler::nif]
pub fn tokenizer_get_post_processor(
    tokenizer: ExTokenizersTokenizer,
) -> Option<ExTokenizersPostProcessor> {
    tokenizer.resource.0.get_post_processor().cloned()
}

#[rustler::nif]
pub fn tokenizer_set_post_processor(
    tokenizer: ExTokenizersTokenizer,
    post_processor: ExTokenizersPostProcessor,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_post_processor(Some(post_processor));
    new_tokenizer.into()
}

#[rustler::nif]
pub fn tokenizer_get_decoder(tokenizer: ExTokenizersTokenizer) -> Option<ExTokenizersDecoder> {
    tokenizer.resource.0.get_decoder().cloned()
}

#[rustler::nif]
pub fn tokenizer_set_decoder(
    tokenizer: ExTokenizersTokenizer,
    decoder: ExTokenizersDecoder,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_decoder(Some(decoder));
    new_tokenizer.into()
}

#[rustler::nif]
pub fn tokenizer_get_vocab(
    tokenizer: ExTokenizersTokenizer,
    with_added_tokens: bool,
) -> HashMap<String, u32> {
    tokenizer.resource.0.get_vocab(with_added_tokens)
}

#[rustler::nif]
pub fn tokenizer_get_vocab_size(
    tokenizer: ExTokenizersTokenizer,
    with_added_tokens: bool,
) -> usize {
    tokenizer.resource.0.get_vocab_size(with_added_tokens)
}

#[rustler::nif]
pub fn tokenizer_add_tokens(
    tokenizer: ExTokenizersTokenizer,
    tokens: Vec<AddedTokenInput>,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.add_tokens(&tokens.iter().map(|t| t.into()).collect::<Vec<AddedToken>>());
    new_tokenizer.into()
}

#[rustler::nif]
pub fn tokenizer_add_special_tokens(
    tokenizer: ExTokenizersTokenizer,
    tokens: Vec<AddedSpecialTokenInput>,
) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.add_special_tokens(&tokens.iter().map(|t| t.into()).collect::<Vec<AddedToken>>());
    new_tokenizer.into()
}

#[derive(NifTaggedEnum)]
pub enum TruncationOption {
    MaxLength(usize),
    Stride(usize),
    Strategy(TruncateStrategy),
    Direction(Direction),
}

#[derive(NifTaggedEnum)]
pub enum TruncateStrategy {
    LongestFirst,
    OnlyFirst,
    OnlySecond,
}

impl From<TruncateStrategy> for tokenizers::TruncationStrategy {
    fn from(strategy: TruncateStrategy) -> Self {
        match strategy {
            TruncateStrategy::LongestFirst => tokenizers::TruncationStrategy::LongestFirst,
            TruncateStrategy::OnlyFirst => tokenizers::TruncationStrategy::OnlyFirst,
            TruncateStrategy::OnlySecond => tokenizers::TruncationStrategy::OnlySecond,
        }
    }
}

impl From<&TruncateStrategy> for tokenizers::TruncationStrategy {
    fn from(strategy: &TruncateStrategy) -> Self {
        match strategy {
            TruncateStrategy::LongestFirst => tokenizers::TruncationStrategy::LongestFirst,
            TruncateStrategy::OnlyFirst => tokenizers::TruncationStrategy::OnlyFirst,
            TruncateStrategy::OnlySecond => tokenizers::TruncationStrategy::OnlySecond,
        }
    }
}

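// Hedged alternative: if `TruncateStrategy` also derived `Clone` and `Copy`,
// the by-reference impl above could delegate to the by-value one instead of
// duplicating the match:
//
//     impl From<&TruncateStrategy> for tokenizers::TruncationStrategy {
//         fn from(strategy: &TruncateStrategy) -> Self {
//             (*strategy).into()
//         }
//     }
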
#[rustler::nif]
pub fn tokenizer_set_truncation(
    tokenizer: ExTokenizersTokenizer,
    opts: Vec<TruncationOption>,
) -> ExTokenizersTokenizer {
    let mut truncation: tokenizers::TruncationParams = Default::default();
    opts.iter().for_each(|option| match option {
        TruncationOption::MaxLength(max_length) => truncation.max_length = *max_length,
        TruncationOption::Stride(stride) => truncation.stride = *stride,
        TruncationOption::Strategy(strategy) => truncation.strategy = strategy.into(),
        TruncationOption::Direction(direction) => truncation.direction = direction.into(),
    });
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_truncation(Some(truncation)).unwrap();
    new_tokenizer.into()
}

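// For reference, a hedged sketch of what the option fold above builds, written
// as plain struct-update syntax on `tokenizers::TruncationParams`:
//
//     let truncation = tokenizers::TruncationParams {
//         max_length: 512,
//         stride: 128,
//         ..Default::default()
//     };
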
#[rustler::nif]
pub fn tokenizer_disable_truncation(tokenizer: ExTokenizersTokenizer) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_truncation(None).unwrap();
    new_tokenizer.into()
}

#[derive(NifTaggedEnum)]
pub enum PaddingOption {
    Strategy(PadStrategy),
    Direction(Direction),
    PadToMultipleOf(usize),
    PadId(u32),
    PadTypeId(u32),
    PadToken(String),
}

#[derive(NifTaggedEnum)]
pub enum PadStrategy {
    BatchLongest,
    Fixed(usize),
}

impl From<PadStrategy> for tokenizers::PaddingStrategy {
    fn from(strategy: PadStrategy) -> Self {
        match strategy {
            PadStrategy::BatchLongest => tokenizers::PaddingStrategy::BatchLongest,
            PadStrategy::Fixed(size) => tokenizers::PaddingStrategy::Fixed(size),
        }
    }
}

impl From<&PadStrategy> for tokenizers::PaddingStrategy {
    fn from(strategy: &PadStrategy) -> Self {
        match strategy {
            PadStrategy::BatchLongest => tokenizers::PaddingStrategy::BatchLongest,
            PadStrategy::Fixed(size) => tokenizers::PaddingStrategy::Fixed(*size),
        }
    }
}

#[rustler::nif]
pub fn tokenizer_set_padding(
    tokenizer: ExTokenizersTokenizer,
    opts: Vec<PaddingOption>,
) -> ExTokenizersTokenizer {
    let mut padding: tokenizers::PaddingParams = Default::default();
    opts.iter().for_each(|option| match option {
        PaddingOption::Strategy(strategy) => padding.strategy = strategy.into(),
        PaddingOption::Direction(direction) => padding.direction = direction.into(),
        PaddingOption::PadToMultipleOf(pad_to_multiple_of) => {
            padding.pad_to_multiple_of = Some(*pad_to_multiple_of)
        }
        PaddingOption::PadId(pad_id) => padding.pad_id = *pad_id,
        PaddingOption::PadTypeId(pad_type_id) => padding.pad_type_id = *pad_type_id,
        PaddingOption::PadToken(pad_token) => padding.pad_token = pad_token.clone(),
    });
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_padding(Some(padding));
    new_tokenizer.into()
}

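// Hedged sketch of the option list this NIF expects from Elixir (tags follow
// the `PaddingOption` variants above; exact spelling assumed):
//
//     [strategy: {:fixed, 128}, pad_token: "[PAD]", pad_id: 0]
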
#[rustler::nif]
pub fn tokenizer_disable_padding(tokenizer: ExTokenizersTokenizer) -> ExTokenizersTokenizer {
    let mut new_tokenizer = tokenizer.resource.0.clone();
    new_tokenizer.with_padding(None);
    new_tokenizer.into()
}

// /////////////////////////////////////////////////////////////////////////////
// / Inference
// /////////////////////////////////////////////////////////////////////////////

fn term_to_encode_input<'a, 'b>(term: &'a Term<'b>) -> Result<EncodeInput<'b>, ExTokenizersError> {
    if let Ok(seq) = term.decode::<&'b str>() {
        Ok(EncodeInput::Single(seq.into()))
    } else if let Ok((seq1, seq2)) = term.decode::<(&'b str, &'b str)>() {
        Ok(EncodeInput::Dual(seq1.into(), seq2.into()))
    } else {
        Err(ExTokenizersError::Other(String::from(
            "input must be either a string or a tuple",
        )))
    }
}

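// A minimal, compile-checked sketch of the two input shapes accepted above,
// written directly against the `tokenizers` crate types. On the Elixir side
// these arrive as a binary or a two-element tuple of binaries.
#[cfg(test)]
mod encode_input_shape_tests {
    use super::EncodeInput;

    #[test]
    fn single_and_dual_inputs() {
        // A single sequence, e.g. "hello world".
        let _single: EncodeInput = EncodeInput::Single("hello world".into());
        // A pair of sequences, e.g. {"question", "context"}.
        let _pair: EncodeInput = EncodeInput::Dual("question".into(), "context".into());
    }
}
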
#[derive(NifTaggedEnum)]
pub enum EncodeOption {
    AddSpecialTokens(bool),
    EncodingTransformations(Vec<TransformationElement>),
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn tokenizer_encode(
    tokenizer: ExTokenizersTokenizer,
    input: Term,
    options: Vec<EncodeOption>,
) -> Result<ExTokenizersEncoding, ExTokenizersError> {
    struct Opts {
        add_special_tokens: bool,
        encoding_transformations: Vec<TransformationElement>,
    }
    let mut opts = Opts {
        add_special_tokens: true,
        encoding_transformations: Vec::new(),
    };
    options.into_iter().for_each(|option| match option {
        EncodeOption::AddSpecialTokens(add_special_tokens) => {
            opts.add_special_tokens = add_special_tokens
        }
        EncodeOption::EncodingTransformations(encoding_transformations) => {
            opts.encoding_transformations = encoding_transformations
        }
    });

    let input = term_to_encode_input(&input)?;
    let mut encoding = tokenizer
        .resource
        .0
        .encode(input, opts.add_special_tokens)?;
    apply_transformations(&mut encoding, &opts.encoding_transformations);
    Ok(encoding.into())
}

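// Hedged sketch of what the NIF above boils down to in the plain `tokenizers`
// API, with no rustler types involved:
//
//     let input = EncodeInput::Single("hello world".into());
//     let encoding = tokenizer.encode(input, /* add_special_tokens */ true)?;
//     let ids = encoding.get_ids();
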
#[rustler::nif(schedule = "DirtyCpu")]
pub fn tokenizer_encode_batch(
    tokenizer: ExTokenizersTokenizer,
    inputs: Vec<Term>,
    options: Vec<EncodeOption>,
) -> Result<Vec<ExTokenizersEncoding>, ExTokenizersError> {
    struct Opts {
        add_special_tokens: bool,
        encoding_transformations: Vec<TransformationElement>,
    }
    let mut opts = Opts {
        add_special_tokens: true,
        encoding_transformations: Vec::new(),
    };
    options.into_iter().for_each(|option| match option {
        EncodeOption::AddSpecialTokens(add_special_tokens) => {
            opts.add_special_tokens = add_special_tokens
        }
        EncodeOption::EncodingTransformations(encoding_transformations) => {
            opts.encoding_transformations = encoding_transformations
        }
    });
    let inputs = inputs
        .iter()
        .map(term_to_encode_input)
        .collect::<Result<Vec<EncodeInput>, ExTokenizersError>>()?;
    let mut encodings = tokenizer
        .resource
        .0
        .encode_batch(inputs, opts.add_special_tokens)?;

    // Applying transformations (if any)
    for encoding in encodings.iter_mut() {
        apply_transformations(encoding, &opts.encoding_transformations);
    }

    let ex_encodings = encodings
        .into_iter()
        .map(|encoding| encoding.into())
        .collect();
    Ok(ex_encodings)
}

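// Note (hedged): `encode_batch` in the `tokenizers` crate may parallelize over
// the batch internally (rayon, gated by the TOKENIZERS_PARALLELISM environment
// variable). Scheduling the NIF on a DirtyCpu thread keeps this potentially
// long-running work off the regular BEAM schedulers.
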
#[derive(NifTaggedEnum)]
pub enum DecodeOption {
    SkipSpecialTokens(bool),
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn tokenizer_decode(
    tokenizer: ExTokenizersTokenizer,
    ids: Vec<u32>,
    options: Vec<DecodeOption>,
) -> Result<String, ExTokenizersError> {
    struct Opts {
        skip_special_tokens: bool,
    }
    let mut opts = Opts {
        skip_special_tokens: true,
    };
    options.into_iter().for_each(|option| match option {
        DecodeOption::SkipSpecialTokens(skip_special_tokens) => {
            opts.skip_special_tokens = skip_special_tokens
        }
    });

    Ok(tokenizer
        .resource
        .0
        .decode(&ids, opts.skip_special_tokens)?)
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn tokenizer_decode_batch(
    tokenizer: ExTokenizersTokenizer,
    sentences: Vec<Vec<u32>>,
    options: Vec<DecodeOption>,
) -> Result<Vec<String>, ExTokenizersError> {
    struct Opts {
        skip_special_tokens: bool,
    }
    let mut opts = Opts {
        skip_special_tokens: true,
    };
    options.into_iter().for_each(|option| match option {
        DecodeOption::SkipSpecialTokens(skip_special_tokens) => {
            opts.skip_special_tokens = skip_special_tokens
        }
    });

    // `decode_batch` expects `&[&[u32]]`, so borrow each inner Vec as a slice.
    Ok(tokenizer.resource.0.decode_batch(
        sentences
            .iter()
            .map(Vec::as_slice)
            .collect::<Vec<&[u32]>>()
            .as_slice(),
        opts.skip_special_tokens,
    )?)
}

#[rustler::nif]
pub fn tokenizer_token_to_id(tokenizer: ExTokenizersTokenizer, token: &str) -> Option<u32> {
    tokenizer.resource.0.token_to_id(token)
}

#[rustler::nif]
pub fn tokenizer_id_to_token(tokenizer: ExTokenizersTokenizer, id: u32) -> Option<String> {
    tokenizer.resource.0.id_to_token(id)
}

#[rustler::nif]
pub fn tokenizer_post_processing(
    tokenizer: ExTokenizersTokenizer,
    enc: ExTokenizersEncoding,
    pair: Option<ExTokenizersEncoding>,
    add_special_tokens: bool,
) -> Result<ExTokenizersEncoding, ExTokenizersError> {
    let result: tokenizers::Encoding = tokenizer.resource.0.post_process(
        enc.resource.0.clone(),
        pair.map(|enc| enc.resource.0.clone()),
        add_special_tokens,
    )?;
    Ok(result.into())
}

// /////////////////////////////////////////////////////////////////////////////
// / Training
// /////////////////////////////////////////////////////////////////////////////

#[rustler::nif]
pub fn tokenizer_train_from_files(
    tokenizer: ExTokenizersTokenizer,
    files: Vec<String>,
    trainer: Option<ExTokenizersTrainer>,
) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
    // The current version of the Rust library panics when retraining with another
    // trainer, which leads to unpredictable NIF behaviour. The catch_unwind can be
    // removed once https://github.com/huggingface/tokenizers/issues/525 is fixed.

    let result = panic::catch_unwind(|| {
        let mut new_tokenizer = tokenizer.resource.0.clone();
        let new_model = ExTokenizersModel::new(
            tokenizer
                .resource
                .0
                .get_model()
                .resource
                .0
                .read()
                .unwrap()
                .clone(),
        );
        new_tokenizer.with_model(new_model);
        match trainer {
            Some(trainer) => {
                // TODO: call clone on the trainer wrapper once available (tokenizers > 0.13.3),
                // see https://github.com/huggingface/tokenizers/pull/1317
                let trainer = match trainer.resource.0.read().unwrap().deref() {
                    TrainerWrapper::BpeTrainer(trainer) => {
                        TrainerWrapper::BpeTrainer(trainer.clone())
                    }
                    TrainerWrapper::WordPieceTrainer(trainer) => {
                        // WordPieceTrainer does not derive Clone, so we rebuild it by hand
                        let mut builder = WordPieceTrainerBuilder::default()
                            .min_frequency(trainer.min_frequency())
                            .vocab_size(trainer.vocab_size())
                            .show_progress(trainer.show_progress())
                            .special_tokens(trainer.special_tokens().to_vec())
                            .initial_alphabet(trainer.initial_alphabet().clone());
                        builder = match trainer.limit_alphabet() {
                            Some(limit_alphabet) => builder.limit_alphabet(limit_alphabet),
                            None => builder,
                        };
                        builder = match trainer.continuing_subword_prefix() {
                            Some(continuing_subword_prefix) => builder
                                .continuing_subword_prefix(continuing_subword_prefix.to_string()),
                            None => builder,
                        };
                        builder = match trainer.end_of_word_suffix() {
                            Some(end_of_word_suffix) => {
                                builder.end_of_word_suffix(end_of_word_suffix.to_string())
                            }
                            None => builder,
                        };
                        TrainerWrapper::WordPieceTrainer(builder.build())
                    }
                    TrainerWrapper::WordLevelTrainer(trainer) => {
                        TrainerWrapper::WordLevelTrainer(trainer.clone())
                    }
                    TrainerWrapper::UnigramTrainer(trainer) => {
                        TrainerWrapper::UnigramTrainer(trainer.clone())
                    }
                };

                let mut trainer = ExTokenizersTrainer::new(trainer);

                new_tokenizer.train_from_files(&mut trainer, files)
            }
            None => {
                // No trainer given, so fall back to the model's default trainer
                let mut default_trainer = new_tokenizer.get_model().get_trainer();
                new_tokenizer.train_from_files(&mut default_trainer, files)
            }
        }?;
        Ok(new_tokenizer)
    });
    let new_tokenizer = match result {
        Ok(value) => value,
        Err(panic) => {
            let panic_message = match panic.downcast_ref::<String>() {
                Some(s) => s.clone(),
                None => "Unknown Panic".to_string(),
            };
            Err(ExTokenizersError::Internal(panic_message))
        }
    }?;

    Ok(new_tokenizer.into())
}
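
// Hedged follow-up to the TODO above: once `TrainerWrapper` implements `Clone`
// (tokenizers > 0.13.3, per the PR linked in the comment), the rebuild match
// should collapse to something like:
//
//     let trainer = trainer.resource.0.read().unwrap().clone();
//     let mut trainer = ExTokenizersTrainer::new(trainer);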