import logging
import typing
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union, cast
import more_itertools
import torch
from datasets import Dataset
from datasets import DatasetDict as HfDatasetDict
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
CTRLLMHeadModel,
CTRLTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
OpenAIGPTLMHeadModel,
OpenAIGPTTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TransfoXLLMHeadModel,
TransfoXLTokenizer,
XLMTokenizer,
XLMWithLMHeadModel,
XLNetLMHeadModel,
XLNetTokenizer,
)
from tango import Format, JsonFormat, SqliteDictFormat, Step
from tango.common import DatasetDict
from tango.common.sequences import MappedSequence, SqliteSparseSequence
from tango.common.tqdm import Tqdm
from tango.integrations.torch import Model
from tango.integrations.torch.util import resolve_device, set_seed_all
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

#
# A lot of the code in this step is borrowed from the run_generation.py script in transformers. Unfortunately
# their examples don't ship when you `pip install transformers`, so we have to duplicate it here.
#

# Fallback cap on the sequence length. `adjust_length_to_model` uses it when the model config has no
# `max_position_embeddings`, or when no usable limit is available at all.
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

# Maps a model-type name to its (LM-head model class, tokenizer class) pair.
# NOTE(review): nothing in the visible code reads this table -- confirm it is still needed.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
# `_generate` falls back to this text for "xlnet" and "transfo-xl" models when the caller gives no prefix.
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

# Keys of the seq2seq and causal-LM auto-model mappings. `_generate` treats them as config classes: it
# checks `model.config` against both lists, and uses membership in SEQ2SEQ to decide whether a model
# is a seq2seq model.
# NOTE(review): `_model_mapping` is a private attribute of the auto classes, and `_generate` concatenates
# the two results with `+`, so `keys()` is presumably expected to return lists -- confirm when upgrading
# transformers.
SEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys()  # type: ignore
CAUSAL = AutoModelForCausalLM._model_mapping.keys()  # type: ignore
def adjust_length_to_model(length, model):
    """
    Clamp a requested generation length to what the model supports.

    :param length:
        The requested total sequence length. A negative value means "no explicit request".
    :param model:
        Any object with a ``config`` attribute. If the config has ``max_position_embeddings``, that
        value is the model's limit; otherwise ``MAX_LENGTH`` is used as the limit.

    :returns:
        * For a negative ``length``: the limit if it is positive, else ``MAX_LENGTH``.
        * For a ``length`` above a positive limit: the limit.
        * Otherwise ``length`` unchanged.
    """
    config = model.config
    if hasattr(config, "max_position_embeddings"):
        limit = config.max_position_embeddings
    else:
        limit = MAX_LENGTH

    if length < 0:
        # No explicit request: use the model's limit, or the hardcoded cap to avoid an infinite loop.
        return limit if limit > 0 else MAX_LENGTH

    if 0 < limit < length:
        # No generation bigger than model size
        return limit

    return length
@typing.no_type_check  # mypy has somehow lost the ability to tell what PreTrainedTokenizer and Model are.
def _generate(
    model: Model,
    # TODO: Change type to `Tokenizer` once HF includes `convert_tokens_to_ids` in `PretrainedTokenizerBase` class.
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompts: Iterable[str],
    *,
    batch_size: int = 4,
    max_length: int = 20,
    temperature: float = 1.0,
    repetition_penalty: float = 1.0,
    k: int = 0,
    p: float = 0.9,
    prefix: str = "",
    xlm_language: str = "",
    seed: int = 42,
    num_return_sequences: int = 1,
    fp16: bool = False,
) -> Iterable[List[str]]:
    """
    Lazily sample generations for every prompt from a Huggingface seq2seq or causal model.

    This is a generator function: nothing, not even the model-type check, runs before the result
    is iterated.

    Side effects on the arguments: the tokenizer is switched to left padding and may get a pad and/or
    EOS token assigned; the model is put into eval mode, moved to the device returned by
    ``resolve_device()`` and, with ``fp16``, converted with ``half()``. For XLM models with language
    embeddings, ``model.config.lang_id`` is overwritten.

    :param model:
        The model to sample from. Its config must be an instance of one of the classes in
        ``SEQ2SEQ`` or ``CAUSAL``.
    :param tokenizer:
        The tokenizer that belongs to ``model``.
    :param prompts:
        The prompts. They are consumed lazily, ``batch_size`` at a time.
    :param batch_size:
        The number of prompts that are encoded and passed to ``generate()`` together.
    :param max_length:
        For seq2seq models, passed to ``generate()`` as ``max_length`` unchanged. For all other models,
        the padded prompt length of the batch is added first, and the sum is clamped with
        :func:`adjust_length_to_model`.
    :param temperature:
        Passed to ``generate()`` as ``temperature``.
    :param repetition_penalty:
        Passed to ``generate()`` as ``repetition_penalty``.
    :param k:
        Passed to ``generate()`` as ``top_k``.
    :param p:
        Passed to ``generate()`` as ``top_p``.
    :param prefix:
        Text that is prepended (followed by one space) to every prompt. It is ignored for "xlm" models.
        For "xlnet" and "transfo-xl" models an empty prefix is replaced with ``PREFIX``.
    :param xlm_language:
        Only used for XLM models with language embeddings: the language to generate in. A name that
        is found in ``model.config.lang2id`` is translated to its id; any other value is stored in
        ``model.config.lang_id`` as it is.
    :param seed:
        Passed to ``set_seed_all()`` before anything else happens.
    :param num_return_sequences:
        The number of generations per prompt.
    :param fp16:
        Whether to convert the model to half precision.

    :returns:
        Yields one list per prompt, in prompt order. Each list holds the ``num_return_sequences``
        decoded generations for that prompt.

    :raises NotImplementedError:
        If ``model.config`` is not an instance of any class in ``SEQ2SEQ`` or ``CAUSAL``.
    """
    if not isinstance(model.config, tuple(SEQ2SEQ + CAUSAL)):
        raise NotImplementedError(
            "This function is only defined for huggingface models seq2seq/causal models."
        )

    device = resolve_device()
    set_seed_all(seed)

    tokenizer_kwargs: Dict[str, Any] = {}
    tokenizer.padding_side = "left"
    # Make sure we have both a pad and an EOS token: prefer re-using EOS for padding, and only invent
    # new special tokens when the tokenizer has neither.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})

    eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    # Seq2Seq models don't return their own prefix.
    seq2seq_model = model.config_class in SEQ2SEQ

    # Apparently HF does not do this for us, so switch to eval mode explicitly.
    model.eval()

    model.to(device)
    if fp16:
        model.half()

    def prepare_batch_without_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:
        # Encode one batch of prompts as padded tensors and move them to the target device.
        result = tokenizer.batch_encode_plus(
            prompts,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            **tokenizer_kwargs,
        )
        result = {key: tensor.to(device) for key, tensor in result.items()}
        return result

    def prepare_batch_with_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:
        # `prefix` is read when a batch is encoded, so the model-specific override below is honored.
        if len(prefix) > 0:
            prompts = [f"{prefix} {t}" for t in prompts]
        return prepare_batch_without_prefix(prompts)

    prepare_batch_fn = prepare_batch_with_prefix
    num_prefix_tokens: Optional[int] = None

    # transformer model-specific exceptions
    if isinstance(model, PreTrainedModel) and model.config_class:
        if model.config_class.model_type == "xlm":
            use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
            if hasattr(model.config, "lang2id") and use_lang_emb:
                # XLM needs `lang_id` to be the numeric id of the language (this is what HF's
                # run_generation.py sets), so translate the language name when the model knows it.
                # Unknown values are stored unchanged, as before.
                lang2id = model.config.lang2id or {}
                model.config.lang_id = lang2id.get(xlm_language, xlm_language)
            # Original HF code ignores the prefix, but it looks like a bug?
            prepare_batch_fn = prepare_batch_without_prefix
            num_prefix_tokens = 0
        elif model.config_class.model_type in {"xlnet", "transfo-xl"}:
            prefix = prefix or PREFIX
        if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            # This actually doesn't work in the current version of transformers, which is probably a bug in the
            # transformers library.
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}

    if num_prefix_tokens is None:
        num_prefix_tokens = len(tokenizer.tokenize(prefix))

    # Token ids that are trimmed from both ends of every generated sequence. Id 0 is always included.
    strippable_ids = {0, eos_token_id, pad_token_id}

    def strip_special_tokens(t: torch.Tensor) -> torch.Tensor:
        # amazing that torch has no capability for this
        start = 0
        while start < len(t) and int(t[start]) in strippable_ids:
            start += 1
        end = len(t)
        # Check the bounds before indexing, so an empty tensor does not raise an IndexError.
        while end > start and int(t[end - 1]) in strippable_ids:
            end -= 1
        return t[start:end]

    batches = more_itertools.chunked(Tqdm.tqdm(prompts, desc="Pre-processing prompts"), batch_size)
    encoded_batches = map(prepare_batch_fn, batches)

    for encoded_batch in Tqdm.tqdm(encoded_batches, desc="Processing batches"):
        if seq2seq_model:
            length = max_length
        else:
            length = adjust_length_to_model(max_length + encoded_batch["input_ids"].size(1), model)
        with torch.inference_mode():
            generated_sequences: torch.Tensor = model.generate(  # type: ignore
                **encoded_batch,
                max_length=length,
                temperature=temperature,
                top_k=k,
                top_p=p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                synced_gpus=False,  # Needs to be True if we have more than one GPU running.
            )

        # Regroup the flat output so that the first dimension indexes prompts and the second
        # indexes the generations for that prompt.
        generated_sequences = generated_sequences.view(
            -1, num_return_sequences, *generated_sequences.shape[1:]
        ).to("cpu")

        # strip padding
        generated_sequences_list = [
            [strip_special_tokens(sequence) for sequence in per_prompt_sequences]
            for per_prompt_sequences in generated_sequences
        ]

        # strip prefix
        if not seq2seq_model:
            generated_sequences_list = [
                [sequence[num_prefix_tokens:] for sequence in per_prompt_sequences]
                for per_prompt_sequences in generated_sequences_list
            ]

        texts = [
            tokenizer.batch_decode(per_prompt_sequences, clean_up_tokenization_spaces=True)
            for per_prompt_sequences in generated_sequences_list
        ]

        yield from texts
def _generate_with_model_name(model_name: str, *args, **kwargs) -> Iterable[List[str]]:
    """
    Load the model and tokenizer called ``model_name`` and run :func:`_generate` with them.

    The model is loaded with ``AutoModelForSeq2SeqLM`` first; if that raises a ``ValueError``, it is
    loaded with ``AutoModelForCausalLM`` instead. All remaining arguments are forwarded to
    :func:`_generate` unchanged.
    """
    try:
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    except ValueError:
        # Not loadable as a seq2seq model, so try it as a causal LM.
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name)

    return _generate(loaded_model, AutoTokenizer.from_pretrained(model_name), *args, **kwargs)
@Step.register("transformers::run_generation")
class RunGeneration(Step[Iterable[List[str]]]):
    """
    A step that runs seq2seq or causal Huggingface language models in inference mode.

    .. tip::

        Registered as a :class:`~tango.step.Step` under the name "transformers::run_generation".
    """

    FORMAT: Format = JsonFormat("gz")
    VERSION = "001"
    SKIP_ID_ARGUMENTS = {"batch_size"}

    # TODO: multiple GPUs
    def run(  # type: ignore
        self,
        model: Union[str, Model],
        prompts: Iterable[str],
        *,
        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
        batch_size: int = 4,
        max_length: int = 20,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        k: int = 0,
        p: float = 0.9,
        prefix: str = "",
        xlm_language: str = "",
        seed: int = 42,
        num_return_sequences: int = 1,
        fp16: bool = False,
    ) -> Iterable[List[str]]:
        """
        Run a Huggingface seq2seq or causal language model in inference mode.

        :param model:
            The name of the model to run. Any name that works in the transformers library works here.
            A name is loaded as a seq2seq model first, and as a causal model if that fails with a
            ``ValueError``. Or, you can directly provide the model to run.
        :param prompts:
            The prompts to run through the model. You can specify prompts directly in the config, but
            more commonly the prompts are produced by another step that reads a dataset, for example.
        :param tokenizer:
            The tokenizer to use. If it is not given, it is loaded with ``AutoTokenizer`` from the
            model's ``name_or_path``.
        :param batch_size:
            The number of sequences to process at one time. This has no bearing on the output, so
            you can change this number without invalidating cached results.
        :param max_length:
            The maximum number of tokens/word pieces that the model will generate. For models that
            extend the prompt, the prefix does not count towards this limit.
        :param temperature:
            Passed directly to transformer's ``generate()`` method.
            The value used to model the next token probabilities.
        :param repetition_penalty:
            Passed directly to transformer's ``generate()`` method.
            The parameter for repetition penalty. 1.0 means no penalty.
        :param k:
            Passed directly to transformer's ``generate()`` method.
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        :param p:
            Passed directly to transformer's ``generate()`` method.
            If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher
            are kept for generation.
        :param prefix:
            A prefix that gets pre-pended to all prompts.
        :param xlm_language:
            For the XLM model, this is a way to specify the language you want to use.
        :param seed:
            Random seed
        :param num_return_sequences:
            The number of generations to return for each prompt.
        :param fp16:
            Whether to use 16-bit floats.

        :returns:
            Returns an iterator of lists of string. Each list contains the predictions for one prompt.
        """
        if isinstance(model, str):
            try:
                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))
            except ValueError:
                # Not loadable as a seq2seq model, so try it as a causal LM.
                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))

        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path)

        return _generate(
            model,
            tokenizer,
            prompts,
            batch_size=batch_size,
            max_length=max_length,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            k=k,
            p=p,
            prefix=prefix,
            xlm_language=xlm_language,
            seed=seed,
            num_return_sequences=num_return_sequences,
            fp16=fp16,
        )
@Step.register("transformers::run_generation_dataset")
class RunGenerationDataset(Step[DatasetDict]):
    """
    A step that runs seq2seq or causal Huggingface language models in inference mode.

    This is similar to :class:`RunGeneration`, but it takes a dataset as input and produces
    a new dataset as output, which contains the predictions in a new field.

    .. tip::

        Registered as a :class:`~tango.step.Step` under the name "transformers::run_generation_dataset".
    """

    FORMAT: Format = SqliteDictFormat()
    VERSION = "002"
    SKIP_ID_ARGUMENTS = {"batch_size"}

    def run(  # type: ignore
        self,
        model: Union[str, Model],
        input: Union[DatasetDict, HfDatasetDict],
        prompt_field: str,
        *,
        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
        output_field: Optional[str] = None,
        splits: Optional[Union[str, Set[str]]] = None,
        batch_size: int = 4,
        max_length: int = 20,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        k: int = 0,
        p: float = 0.9,
        prefix: str = "",
        xlm_language: str = "",
        seed: int = 42,
        num_return_sequences: int = 1,
        fp16: bool = False,
    ) -> DatasetDict:
        """
        Augment an input dataset with generations from a Huggingface seq2seq or causal language model.

        :param model:
            The name of the model to run. Any name that works in the transformers library works here.
            A name is loaded as a seq2seq model first, and as a causal model if that fails with a
            ``ValueError``. Or, you can directly provide the model to run.
        :param input:
            The input dataset.
        :param prompt_field:
            The field in the dataset that contains the text of the prompts.
        :param tokenizer:
            The tokenizer to use. If it is not given, it is loaded with ``AutoTokenizer`` from the
            model's ``name_or_path``.
        :param output_field:
            The field in the dataset that we will write the predictions into. In the result, this field
            will contain ``List[str]``. Defaults to ``prompt_field`` with ``"_generated"`` appended.
        :param splits:
            A split, or set of splits, to process. If this is not specified, we will process all splits.
            Splits that are not processed are copied to the output unchanged.
        :param batch_size:
            The number of sequences to process at one time. This has no bearing on the output, so
            you can change this number without invalidating cached results.
        :param max_length:
            The maximum number of tokens/word pieces that the model will generate. For models that
            extend the prompt, the prefix does not count towards this limit.
        :param temperature:
            Passed directly to transformer's `generate()` method.
            The value used to model the next token probabilities.
        :param repetition_penalty:
            Passed directly to transformer's `generate()` method.
            The parameter for repetition penalty. 1.0 means no penalty.
        :param k:
            Passed directly to transformer's `generate()` method.
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        :param p:
            Passed directly to transformer's `generate()` method.
            If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher
            are kept for generation.
        :param prefix:
            A prefix that gets pre-pended to all prompts.
        :param xlm_language:
            For the XLM model, this is a way to specify the language you want to use.
        :param seed:
            Random seed
        :param num_return_sequences:
            The number of generations to return for each prompt.
        :param fp16:
            Whether to use 16-bit floats.

        :returns:
            Returns a dataset with an extra field containing the predictions.
        """
        if isinstance(model, str):
            try:
                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))
            except ValueError:
                # Not loadable as a seq2seq model, so try it as a causal LM.
                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))

        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path)

        if isinstance(input, HfDatasetDict):
            input = DatasetDict(input, {})
        if splits is None:
            splits = input.keys()
        elif isinstance(splits, str):
            splits = {splits}

        # The name of the field that receives the generations is the same for every instance.
        generated_field = output_field or prompt_field + "_generated"

        result: Dict[str, Sequence] = {}
        for split_name, input_split in input.items():
            if split_name not in splits:
                # Splits that were not selected are passed through untouched.
                result[split_name] = input_split
                continue

            output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite")
            num_done = len(output_split)
            if num_done > 0:
                # The output file in the work dir already has entries (presumably from an earlier,
                # interrupted run), so only generate for the instances that come after them.
                logger.info(
                    "Found %d items already generated. Will generate %d more.",
                    num_done,
                    len(input_split) - num_done,
                )
                if isinstance(input_split, Dataset):
                    input_split = input_split.select(range(num_done, len(input_split)))
                else:
                    input_split = input_split[num_done:]

            prompts = MappedSequence(lambda i: i[prompt_field], input_split)
            generations = _generate(
                model,
                tokenizer,
                prompts,
                batch_size=batch_size,
                max_length=max_length,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                k=k,
                p=p,
                prefix=prefix,
                xlm_language=xlm_language,
                seed=seed,
                num_return_sequences=num_return_sequences,
                fp16=fp16,
            )
            # `generations` is lazy, so every result is stored as soon as its batch is done.
            for instance, generation in zip(input_split, generations):
                output_split.append({**instance, generated_field: generation})
            result[split_name] = output_split

        return DatasetDict(result, input.metadata)