Abstractive API Reference
Model/Module
- class abstractive.AbstractiveSummarizer(hparams)[source]
A machine learning model that abstractively summarizes an input text using a seq2seq model. This is the main class: it handles data loading, initial processing, and training/validation/testing setup, and contains the actual model.
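A minimal construction sketch, assuming hparams is an argparse.Namespace of training hyperparameters; the attribute names shown here (model_name_or_path, learning_rate, batch_size) are illustrative assumptions, not the documented interface:

```python
from argparse import Namespace

from abstractive import AbstractiveSummarizer

# Hypothetical hyperparameter names; the real training script defines the
# full set parsed from the command line.
hparams = Namespace(
    model_name_or_path="facebook/bart-base",
    learning_rate=3e-5,
    batch_size=4,
)
model = AbstractiveSummarizer(hparams)
```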
- configure_optimizers()[source]
Configure the optimizers. Returns the optimizer and scheduler specified by the values in self.hparams.
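For orientation, a hedged sketch of the standard PyTorch Lightning shape this method returns; the actual optimizer and scheduler classes depend on the values in self.hparams:

```python
import torch

def configure_optimizers(self):
    # Illustrative only: one optimizer plus one scheduler, the return shape
    # Lightning expects. The real choices come from self.hparams.
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
    return [optimizer], [scheduler]
```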
- forward(source=None, target=None, source_mask=None, target_mask=None, labels=None, **kwargs)[source]
Model forward function. See the PyTorch 60 Minute Blitz tutorial if you are unsure what a forward function is.
- Parameters
  - source (torch.LongTensor of shape (batch_size, sequence_length), optional) – Indices of input sequence tokens in the vocabulary for the encoder. Defaults to None.
  - target (torch.LongTensor of shape (batch_size, target_sequence_length), optional) – Provided for sequence-to-sequence training to the decoder. Defaults to None.
  - source_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices for the encoder. Mask values selected in [0, 1]: 1 for tokens that are NOT MASKED, 0 for MASKED tokens. Defaults to None.
  - target_mask (torch.BoolTensor of shape (batch_size, tgt_seq_len), optional) – Same as source_mask but for the target sequence. Is an attention mask. Defaults to None.
  - labels (torch.LongTensor of shape (batch_size, sequence_length), optional) – Labels for computing the masked language modeling loss for the decoder. Indices should be in [-100, 0, ..., config.vocab_size]. Tokens with indices set to -100 are ignored (masked); the loss is only computed for tokens with labels in [0, ..., config.vocab_size]. Defaults to None.
- Returns
(cross_entropy_loss, prediction_scores) – The cross-entropy loss and the prediction scores, which are the scores over the vocabulary for each token in the output.
- Return type
tuple
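A hedged sketch of calling forward with dummy tensors shaped per the parameter list above, assuming model is an AbstractiveSummarizer instance and that 0 is the padding token ID:

```python
import torch

batch_size, src_len, tgt_len = 2, 16, 8
source = torch.randint(1, 1000, (batch_size, src_len))        # encoder token IDs
source_mask = torch.ones(batch_size, src_len)                 # 1 = NOT MASKED
target = torch.randint(1, 1000, (batch_size, tgt_len))        # decoder token IDs
target_mask = torch.ones(batch_size, tgt_len, dtype=torch.bool)

labels = target.clone()
labels[labels == 0] = -100  # assumed pad ID; -100 positions are ignored by the loss

loss, prediction_scores = model(
    source=source,
    target=target,
    source_mask=source_mask,
    target_mask=target_mask,
    labels=labels,
)
```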
- ids_to_clean_text(generated_ids, replace_sep_with_q=False)[source]
Convert IDs generated from tokenizer.encode to a string using tokenizer.batch_decode and also clean up spacing and special tokens.
- Parameters
  - generated_ids (list) – A list of examples where each example is a list of IDs generated from tokenizer.encode.
  - replace_sep_with_q (bool, optional) – Replace the self.tokenizer.sep_token with "<q>". Useful for determining sentence boundaries and calculating ROUGE scores. Defaults to False.
- Returns
A list of examples where each example is a string or just one string if only one example was passed to this function.
- Return type
list or string
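Hypothetical usage, relying only on the documented self.tokenizer attribute and the signature above:

```python
generated_ids = [model.tokenizer.encode("The cat sat on the mat.")]

texts = model.ids_to_clean_text(generated_ids)
# Replace sep tokens with "<q>" so sentence boundaries survive for ROUGE:
texts_q = model.ids_to_clean_text(generated_ids, replace_sep_with_q=True)
```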
- on_save_checkpoint(checkpoint)[source]
Save the model in the huggingface/transformers format when a checkpoint is saved.
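An illustrative sketch of the described behavior, not the project's actual code; it assumes the wrapped transformer and tokenizer expose the standard huggingface/transformers save_pretrained method, and the save_dir path is hypothetical:

```python
def on_save_checkpoint(self, checkpoint):
    save_dir = "hf_checkpoint"  # hypothetical output directory
    self.model.save_pretrained(save_dir)      # assumes self.model is a transformers model
    self.tokenizer.save_pretrained(save_dir)
```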
- predict(input_sequence)[source]
Summarizes input_sequence using the model. Can summarize a list of sequences at once.
- Parameters
  - input_sequence (str or list[str]) – The text to be summarized.
- Returns
The summary text.
- Return type
str or list[str]
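Usage per the documented signature:

```python
summary = model.predict("Long input document to compress ...")
summaries = model.predict([
    "First document ...",
    "Second document ...",
])
```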
- prepare_data()[source]
Create the data using the huggingface/nlp library. This function handles downloading, preprocessing, tokenization, and feature extraction.
- setup(stage)[source]
Load the data created by prepare_data(). The downloading and loading are split into two functions because prepare_data() is only called on global_rank=0, and is therefore not suitable for assigning state (self.something).
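An illustrative sketch of the Lightning pattern being described, not the project's actual code: one-time work without state assignment in prepare_data, state assignment in setup:

```python
import pytorch_lightning as pl
import torch

class Example(pl.LightningModule):
    def prepare_data(self):
        # Runs only on global_rank 0: download/preprocess to disk,
        # but do NOT assign state (self.something) here.
        torch.save(torch.arange(10), "/tmp/example_data.pt")

    def setup(self, stage):
        # Runs on every process: safe to assign state.
        self.data = torch.load("/tmp/example_data.pt")
```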
- test_epoch_end(outputs)[source]
Called at the end of a testing epoch (see the PyTorch Lightning documentation). Finds the mean of all the metrics logged by test_step().
- test_step(batch, batch_idx)[source]
Test step (see the PyTorch Lightning documentation). Similar to validation_step() in that it runs the inputs through the model. However, this method also calculates the ROUGE scores for each example-summary pair.
- training: bool
- training_step(batch, batch_idx)[source]
Training step (see the PyTorch Lightning documentation).
- validation_step(batch, batch_idx)[source]
Validation step (see the PyTorch Lightning documentation).
- abstractive.longformer_modifier(final_dictionary, tokenizer, attention_window)[source]
Creates the global_attention_mask for the Longformer. Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for task-specific finetuning because it makes the model more flexible at representing the task. For example, for classification, the <s> token should be given global attention. For QA, all question tokens should also have global attention. For summarization, global attention is given to all of the <s> (RoBERTa 'CLS' equivalent) tokens. Please refer to the Longformer paper for more details. Mask values selected in [0, 1]: 0 for local attention, 1 for global attention.
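A hedged sketch of the resulting mask for a single batch, assuming a RoBERTa-style Longformer tokenizer where the <s> token is the bos_token; the dictionary keys used by the real modifier are not shown here:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
input_ids = tokenizer("An article to summarize.", return_tensors="pt")["input_ids"]

# 0 = local attention, 1 = global attention (per the mask values above).
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[input_ids == tokenizer.bos_token_id] = 1  # all <s> tokens
```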