import copy
import gc
import json
import linecache
import logging
import os
import random
from functools import partial
from multiprocessing import Pool
import numpy as np
import torch
from helpers import pad
logger = logging.getLogger(__name__)
def pad_batch_collate(batch, modifier=None):
r"""
Collate function to be passed to ``DataLoaders``. PyTorch Docs:
`https://pytorch.org/docs/stable/data.html#dataloader-collate-fn <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`__
Calculates padding (per batch for efficiency) of ``labels`` and
``token_type_ids`` if they exist within the batch from the ``Dataset``.
Also, pads ``sent_rep_token_ids`` and creates the ``sent_rep_mask`` to
indicate which numbers in the ``sent_rep_token_ids`` list are actually
the locations of sentence representation ids and which are padding.
Finally, calculates the ``attention_mask`` for each set of ``input_ids``
and pads both the ``attention_mask`` and the ``input_ids``. Converts all
inputs to tensors.
If ``sent_lengths`` are found then they will also automatically be
padded. However, the padding for sentence lengths is more involved. Each
list of sentence lengths needs to be the length of the longest list of
sentence lengths, and the values in each list need to sum to the width of
``input_ids`` (the length of each padded ``input_ids`` sequence). The
second requirement exists because ``torch.split()`` (which is used in the
``mean_tokens`` pooling algorithm to convert word vectors to sentence
embeddings in ``pooling.py``) will split a tensor into the lengths
requested but will error if those lengths do not sum to the size of the
tensor. ``torch.split()`` will, however, accept zero-length segments.
Thus, to solve this, zeros are appended to each example's sentence length
list until only one more padding value is needed to reach the maximum
number of sentences. That final value is then set to whatever amount is
needed to make the list sum to the width of the ``input_ids``.
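
For example (hypothetical values): if ``input_ids`` is padded to a width of
12 and the batch contains the sentence length lists ``[5, 4, 3]`` and
``[6, 4]``, the padded lists become ``[5, 4, 3, 0]`` and ``[6, 4, 0, 2]``;
each list now has the same number of entries and each sums to 12.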
``source`` and ``target``, if present, are simply passed on without any
processing. Therefore, the standard ``collate_fn`` function for
``DataLoader``\ s will not work if these are present since they cannot
be converted to tensors without padding. This ``collate_fn`` must be
used if ``source`` or ``target`` is present in the loaded dataset.
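
A minimal usage sketch (assuming ``dataset`` is a ``Dataset`` that yields
the dictionaries described above)::

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=8, collate_fn=pad_batch_collate)
    # `batch` is a dict of padded tensors (plus any raw `source`/`target` lists)
    batch = next(iter(loader))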
The ``modifier`` argument accepts a function that takes the
``final_dictionary`` and returns a modified ``final_dictionary``. The
``modifier`` function will be called directly before
``final_dictionary`` is returned in :meth:`~data.pad_batch_collate`. This allows
for easy extensibility.
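
For example, a hypothetical ``modifier`` that removes the raw ``source`` text
from every batch (continuing the sketch above)::

    from functools import partial

    def drop_source(final_dictionary):
        final_dictionary.pop("source", None)
        return final_dictionary

    loader = DataLoader(
        dataset,
        batch_size=8,
        collate_fn=partial(pad_batch_collate, modifier=drop_source),
    )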
""" # noqa: E501
elem = batch[0]
final_dictionary = {}
for key in elem:
# don't process `sent_lengths`
if key == "sent_lengths":
continue
feature_list = [d[key] for d in batch]
if key == "sent_rep_token_ids":
feature_list = pad(feature_list, -1)
sent_rep_token_ids = torch.tensor(feature_list)
sent_rep_mask = ~(sent_rep_token_ids == -1)
sent_rep_token_ids[sent_rep_token_ids == -1] = 0
final_dictionary["sent_rep_token_ids"] = sent_rep_token_ids
final_dictionary["sent_rep_mask"] = sent_rep_mask
continue # go to next key
if key == "input_ids":
input_ids = feature_list
# Attention
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [[1] * len(ids) for ids in input_ids]
input_ids_width = max(len(ids) for ids in input_ids)
input_ids = pad(input_ids, 0, width=input_ids_width)
input_ids = torch.tensor(input_ids)
attention_mask = pad(attention_mask, 0)
attention_mask = torch.tensor(attention_mask)
if "sent_lengths" in elem:
sent_lengths = []
sent_lengths_mask = []
sent_lengths_width = max(len(d["sent_lengths"]) + 1 for d in batch)
for d in batch:
current_sent_lens = d["sent_lengths"]
current_sent_lengths_mask = [True] * len(current_sent_lens)
num_to_add = sent_lengths_width - len(current_sent_lens)
total_value_to_add = input_ids_width - sum(current_sent_lens)
while num_to_add > 1:
num_to_add -= 1
# total_value_to_add -= 1
current_sent_lens.append(0)
current_sent_lengths_mask.append(False)
# if a value needs to be added to make `sum(current_sent_lens)` the total input
# sequence length OR there is one more number to add (this can happen if the
# input sequence exactly ends with a sentence, making the total of the lengths
# the length of the sequence, or if there is one sentence that takes up the
# entire sequence)
if total_value_to_add > 0 or num_to_add == 1:
current_sent_lens.append(total_value_to_add)
current_sent_lengths_mask.append(False)
sent_lengths.append(current_sent_lens)
sent_lengths_mask.append(current_sent_lengths_mask)
final_dictionary["sent_lengths"] = sent_lengths
final_dictionary["sent_lengths_mask"] = torch.tensor(sent_lengths_mask)
final_dictionary["input_ids"] = input_ids
final_dictionary["attention_mask"] = attention_mask
continue
if key in ("source", "target"):
final_dictionary[key] = feature_list
continue
if key in ("labels", "token_type_ids"):
feature_list = pad(feature_list, 0)
feature_list = torch.tensor(feature_list)
final_dictionary[key] = feature_list
if modifier:
final_dictionary = modifier(final_dictionary)
return final_dictionary
class FSDataset(torch.utils.data.Dataset):
def __init__(self, files_list, shuffle=True, verbose=False):
super().__init__()
if shuffle:
random.shuffle(files_list) # happens in-place
self.files_list = files_list
self.shuffle = shuffle
self.verbose = verbose
self.lengths = self.get_files_lengths(files_list)
def get_files_lengths(self, files_list):
lengths = []
for data_file in files_list:
with open(data_file, "rb") as file:
    data_file_len = sum(1 for _ in file)
lengths.append(data_file_len)
lengths = np.cumsum(lengths)
self.lengths = lengths
return lengths
def __getitem__(self, index):
file_index = int(np.searchsorted(self.lengths, index, side="right"))
if file_index > 0:
total_lines_in_other_files = self.lengths[file_index - 1]
else:
total_lines_in_other_files = 0
file_path = self.files_list[file_index]
linecache_index = index + 1 # linecache starts at 1
linecache_index -= (
total_lines_in_other_files # remove lines counted from other files
)
line_str = linecache.getline(file_path, linecache_index).rstrip("\n")
try:
    line_json = json.loads(line_str)
except Exception:
    print("** JSON Loading Error **")
    print(file_path)
    print(index)
    print(linecache_index)
    print(line_str)
    raise
return line_json
def __len__(self):
return self.lengths[-1]
class FSIterableDataset(torch.utils.data.IterableDataset):
"""
A dataset to yield examples from a list of files that are saved python
objects that can be iterated over. These files could be other PyTorch
datasets (tested with ``TensorDataset``) or other python objects such as
lists, for example. Each file will be loaded one at a time until all the
examples have been yielded, at which point the next file will be loaded
and used to yield examples, and so on. This means a large dataset can be
broken into smaller chunks and this class can be used to load samples as
if those files were one dataset while only using the RAM required for one
chunk.

Explanation about ``batch_size`` and ``__len__()``: If ``len()`` needs to
be accurate then the ``batch_size`` must be specified when constructing
objects of this class. PyTorch ``DataLoader`` objects will report accurate
lengths by dividing the number of examples in the dataset by the batch
size only if the dataset is not an ``IterableDataset``. If the dataset is
an ``IterableDataset`` then a ``DataLoader`` will simply ask the dataset
for its length, without dividing by the batch size, because in some cases
the length of an ``IterableDataset`` might be difficult or impossible to
determine. However, in this case the number of examples (the length of the
dataset) is known. The division by batch size must happen in the dataset
(for datasets of type ``IterableDataset``) since the ``DataLoader`` will
not calculate this.
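
A minimal usage sketch (assuming ``chunk_0.pt`` and ``chunk_1.pt`` are saved
lists of examples)::

    dataset = FSIterableDataset(["chunk_0.pt", "chunk_1.pt"])
    for example in dataset:
        ...  # only one chunk is held in memory at a time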
"""
# TODO: Add shuffling
def __init__(self, files_list, shuffle=True, verbose=False):
super().__init__()
if shuffle:
random.shuffle(files_list) # happens in-place
self.files_list = files_list
self.shuffle = shuffle
self.verbose = verbose
self.total_length = None
self.file_type = os.path.splitext(files_list[0])[1]
def __iter__(self):
for data_file in self.files_list:
if self.verbose:
logger.info("Loading examples from %s", data_file)
if self.file_type == ".pt":
dataset_section = torch.load(data_file)
elif self.file_type == ".txt":
with open(data_file, "r") as file:
dataset_section = [json.loads(x) for x in file.readlines()]
for example in dataset_section:
yield example
# Clear memory used by this chunk before loading the next file
del dataset_section
gc.collect()
def __len__(self):
if self.total_length:
return self.total_length
logger.debug(
"Calculating length of `IterableDataset` by loading each file, getting the length, "
+ "and unloading, which is slow."
)
total_length = 0
for data_file in self.files_list:
if self.file_type == ".pt":
dataset_section = torch.load(data_file)
elif self.file_type == ".txt":
with open(data_file, "r") as file:
dataset_section = file.readlines()
total_length += len(dataset_section)
self.total_length = total_length
return total_length
class SentencesProcessor:
r"""Create a `SentencesProcessor`
Arguments:
name (str, optional): A label for the ``SentencesProcessor`` object, used internally for
saving if a save name is not specified in :meth:`data.SentencesProcessor.get_features`.
Default is None.
labels (list, optional): The label that goes with each sample. Can be a list of lists where
the inside lists are the labels for each sentence in the corresponding
example. Default is None.
examples (list, optional): List of ``InputExample``\ s. Default is None.
verbose (bool, optional): Log extra information (such as examples of processed data
points). Default is False.
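
A minimal usage sketch (hypothetical two-sentence document with word-level
tokens, for illustration only)::

    processor = SentencesProcessor(name="demo")
    processor.add_examples(
        texts=[[["the", "cat", "sat"], ["it", "was", "happy"]]],
        labels=[[1, 0]],
    )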
"""
def __init__(self, name=None, labels=None, examples=None, verbose=False):
self.name = name
self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples
self.verbose = verbose
def __len__(self):
return len(self.examples)
@classmethod
def create_from_examples(cls, texts, labels=None, **kwargs):
"""
Create a ``SentencesProcessor`` with ``**kwargs`` and add ``texts`` and ``labels`` through
:meth:`~data.SentencesProcessor.add_examples`.
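
Equivalent to constructing the processor and calling ``add_examples``
manually (hypothetical data)::

    processor = SentencesProcessor.create_from_examples(
        [[["the", "cat", "sat"], ["it", "was", "happy"]]],
        labels=[[1, 0]],
        name="demo",
    )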
"""
processor = cls(**kwargs)
processor.add_examples(texts, labels=labels)
return processor
def add_examples(
self,
texts,
labels=None,
ids=None,
oracle_ids=None,
targets=None,
overwrite_labels=False,
overwrite_examples=False,
):
r"""Primary method of adding example sets of texts, labels, ids, and targets
to the ``SentencesProcessor``
Arguments:
texts (list): A list of documents where each document is a list of sentences where each
sentence is a list of tokens. This is the output of `convert_to_extractive.py`
and is in the 'src' field for each doc. See
:meth:`extractive.ExtractiveSummarizer.prepare_data`.
labels (list, optional): A list of the labels for each document where each label is a
list of labels where the index of the label corresponds with the index of the
sentence in the respective entry in `texts`. Similarly to `texts`, this is handled
automatically by `ExtractiveSummarizer.prepare_data`. Default is None.
ids (list, optional): A list of ids for each document. Not used by
`ExtractiveSummarizer`. Default is None.
oracle_ids (list, optional): Similar to labels but is a list of indexes of the chosen
sentences instead of a one-hot encoded vector. These will be converted to labels.
Default is None.
targets (list, optional): A list of the abstractive target for each document.
Default is None.
overwrite_labels (bool, optional): Replace any labels currently stored by the
``SentencesProcessor``. Default is False.
overwrite_examples (bool, optional): Replace any examples currently stored by the
``SentencesProcessor``. Default is False.
Returns:
list: The examples as ``InputExample``\ s that have been added.
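
For example (hypothetical three-sentence document), ``oracle_ids`` can be
passed instead of ``labels``::

    processor.add_examples(
        texts=[[["sent", "one"], ["sent", "two"], ["sent", "three"]]],
        oracle_ids=[[0, 2]],  # converted internally to the labels [1, 0, 1]
    )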
"""
assert texts # not an empty array
assert labels is None or len(texts) == len(labels)
assert ids is None or len(texts) == len(ids)
assert not (labels and oracle_ids)
assert isinstance(texts, list)
if ids is None:
ids = [None] * len(texts)
if labels is None:
if oracle_ids: # convert `oracle_ids` to `labels`
labels = []
for text_set, oracle_id in zip(texts, oracle_ids):
text_label = [0] * len(text_set)
for l_id in oracle_id:
text_label[l_id] = 1
labels.append(text_label)
else:
labels = [None] * len(texts)
examples = []
added_labels = []
for idx, (text_set, label_set, guid) in enumerate(zip(texts, labels, ids)):
if not text_set or not label_set:
continue
added_labels.append(label_set)
if targets:
example = InputExample(
guid=guid, text=text_set, labels=label_set, target=targets[idx]
)
else:
example = InputExample(guid=guid, text=text_set, labels=label_set)
examples.append(example)
# Update examples
if overwrite_examples:
self.examples = examples
else:
self.examples.extend(examples)
# Update labels
if overwrite_labels:
self.labels = added_labels
else:
self.labels += added_labels
return self.examples
def get_features_process(
self,
input_information,
num_examples=0,
tokenizer=None,
bert_compatible_cls=True,
sep_token=None,
cls_token=None,
create_sent_rep_token_ids=True,
sent_rep_token_id=None,
create_sent_lengths=True,
create_segment_ids="binary",
segment_token_id=None,
create_source=False,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
create_attention_mask=True,
pad_ids_and_attention=True,
):
"""
The process that actually creates the features.
:meth:`~data.SentencesProcessor.get_features` is the driving function, look there for a
description of how this function works. This function only exists so that processing can
easily be done in parallel using ``Pool.map``.
"""
ex_index, example, label = input_information
if ex_index % 1000 == 0:
logger.info("Generating features for example %s/%s", ex_index, num_examples)
if bert_compatible_cls:
# join the tokens of each sentence in `example.text` into one string per sentence
src_txt = (" ".join(sent) for sent in example.text)
else:
src_txt = example.text
input_ids = self.get_input_ids(
tokenizer, src_txt, bert_compatible_cls, sep_token, cls_token, max_length
)
# Segment (Token Type) IDs
segment_ids = None
if create_segment_ids == "binary":
current_segment_flag = True
segment_ids = []
for token in input_ids:
segment_ids += [0 if current_segment_flag else 1]
if token == segment_token_id:
current_segment_flag = not current_segment_flag
if create_segment_ids == "sequential":
current_segment = 0
segment_ids = []
for token in input_ids:
segment_ids += [current_segment]
if token == segment_token_id:
current_segment += 1
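# Illustration (hypothetical tokens, with `segment_token_id` set to the [SEP] id):
# tokens:        [CLS] a  b  [SEP] [CLS] c  [SEP] [CLS] d  [SEP]
# "binary":        0   0  0    0     1   1    1     0   0    0
# "sequential":    0   0  0    0     1   1    1     2   2    2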
# Sentence Representation Token IDs and Sentence Lengths
sent_rep_ids = None
sent_lengths = None
if create_sent_rep_token_ids:
# create list of indexes for the `sent_rep` tokens
sent_rep_ids = [
i for i, t in enumerate(input_ids) if t == sent_rep_token_id
]
# truncate `label` to the length of the `sent_rep_ids` aka the number of sentences
label = label[: len(sent_rep_ids)]
if create_sent_lengths:
# if there are 1 or 0 sentences then the length of the entire sequence will be
# the only value in `sent_lengths`
if len(sent_rep_ids) < 2:
sent_lengths = [len(input_ids)]
else:
sent_lengths = [
sent_rep_ids[i] - sent_rep_ids[i - 1]
for i in range(1, len(sent_rep_ids))
]
# Add sentence length for the last sentence, if missing.
# If the last sentence representation token position in `input_ids` is not
# the last token in `input_ids` then add the length of the last sentence
# to `sent_lengths` by subtracting the position of the last `sent_rep_token`
# from the length of `input_ids`
if sent_rep_ids[-1] != len(input_ids) - 1:
sent_lengths.append(len(input_ids) - sent_rep_ids[-1])
# Add sentence length for the first sentence, if missing.
# If the first sentence representation token is not the first token in
# `input_ids` then add the length of the first sentence by inserting
# the first value in `sent_rep_ids` at the front of `sent_lengths`.
if sent_rep_ids[0] != 0:
sent_lengths.insert(0, sent_rep_ids[0] + 1)
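# Illustration (hypothetical): for an `input_ids` of 12 tokens with sentence
# representation tokens at indexes [0, 5, 9], `sent_rep_ids` is [0, 5, 9] and
# `sent_lengths` becomes [5, 4, 3]: the differences between consecutive
# positions plus 12 - 9 = 3 appended for the final sentence.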
# Attention
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
if create_attention_mask:
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Padding
# Zero-pad up to the sequence length.
if pad_ids_and_attention:
padding_length = max_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
attention_mask = (
[0 if mask_padding_with_zero else 1] * padding_length
) + attention_mask
else:
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + (
[0 if mask_padding_with_zero else 1] * padding_length
)
assert (
len(input_ids) == max_length
), "Error with input length {} vs {}".format(len(input_ids), max_length)
assert (
len(attention_mask) == max_length
), "Error with input length {} vs {}".format(
len(attention_mask), max_length
)
if ex_index < 5 and self.verbose:
logger.info("*** Example ***")
logger.info("guid: %s", example.guid)
logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
if segment_ids is not None:
logger.info(
"token_type_ids: %s", " ".join([str(x) for x in segment_ids])
)
if sent_rep_ids is not None:
logger.info(
"sent_rep_token_ids: %s", " ".join([str(x) for x in sent_rep_ids])
)
if sent_lengths is not None:
logger.info(
"sent_lengths: %s", " ".join([str(x) for x in sent_lengths])
)
if create_attention_mask:
logger.info(
"attention_mask: %s", " ".join([str(x) for x in attention_mask])
)
logger.info("labels: %s (id = %s)", example.labels, label)
# Return features
# if the attention mask was created then add the mask to the returned features
outputs = {
"input_ids": input_ids,
"labels": label,
"token_type_ids": segment_ids,
"sent_rep_token_ids": sent_rep_ids,
"sent_lengths": sent_lengths,
"target": example.target,
}
if create_attention_mask:
outputs["attention_mask"] = attention_mask
if create_source:
# convert from individual tokens to whole-sentence strings
source = [" ".join(sentence) for sentence in example.text]
outputs["source"] = source
return InputFeatures(**outputs)
def get_features(
self,
tokenizer,
bert_compatible_cls=True,
create_sent_rep_token_ids=True,
sent_rep_token_id=None,
create_sent_lengths=True,
create_segment_ids="binary",
segment_token_id=None,
create_source=False,
n_process=2,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
create_attention_mask=True,
pad_ids_and_attention=True,
return_type=None,
save_to_path=None,
save_to_name=None,
save_as_type="txt",
):
r"""Convert the examples stored by the ``SentencesProcessor`` to features that can be used
by a model. The following processes can be performed: tokenization, token type ids (to
separate sentences), sentence representation token ids (the locations of each sentence
representation token), sentence lengths, and the attention mask. Padding can be applied to
the tokenized examples and the attention masks but it is recommended to instead use the
:meth:`data.pad_batch_collate` function so each batch is padded individually for efficiency
(less zeros passed through model).
Arguments:
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to tokenize the
examples.
bert_compatible_cls (bool, optional): Adds '[CLS]' tokens in front of each sentence.
This is useful so that the '[CLS]' token can be used to obtain sentence
embeddings. This only works if the chosen model has the '[CLS]'
token in its vocabulary. Default is True.
create_sent_rep_token_ids (bool, optional): Option to create sentence representation
token ids. This will store a list of the indexes of all the
``sent_rep_token_id``\ s in the tokenized example. Default is True.
sent_rep_token_id ([type], optional): The token id that should be captured for each
sentence (should have one per sentence and each should represent that sentence).
Default is ``'[CLS]' token if bert_compatible_cls else '[SEP]' token``.
create_sent_lengths (bool, optional): Option to create a list of sentence lengths where
each index in the list corresponds to the respective sentence in the example.
Default is True.
create_segment_ids (str, optional): Option to create segment ids (aka token type ids).
See https://huggingface.co/transformers/glossary.html#token-type-ids for more info.
Set to either "binary", "sequential", or False.
* ``binary`` alternates between 0 and 1 for each sentence.
* ``sequential`` starts at 0 and increments by 1 for each sentence.
* ``False`` does not create any segment ids.
Note: Many pretrained models that accept token type ids use them
for question answering and related tasks where the model receives
two inputs. Therefore, most models have a token type id vocabulary
size of 2, which means they only have learned 2 token type ids. The
"binary" mode exists so that these pretrained models can easily
be used.
Default is "binary".
segment_token_id (str, optional): The token id to be used when creating segment ids.
Can be set to 'period' to treat periods as sentence separation tokens, but this
is a terrible idea for obvious reasons. Default is '[SEP]' token id.
create_source (bool, optional): Option to save the source text (non-tokenized) as a
string. Default is False.
n_process (int, optional): How many processes to use for multithreading for running
get_features_process(). Set higher to run faster and set lower if you experience
OOM issues. Default is 2.
max_length (int, optional): If ``pad_ids_and_attention`` is True then pad to this
amount. Default is ``tokenizer.max_len``.
pad_on_left (bool, optional): Optionally, pad on the left instead of right. Default
is False.
pad_token (int, optional): Which token to use for padding the ``input_ids``. Default
is 0.
mask_padding_with_zero (bool, optional): Use zeros to pad the attention mask. Uses ones
otherwise. Default is True.
create_attention_mask (bool, optional): Option to create the attention mask. It is
recommended to use the :meth:`data.pad_batch_collate` function, which will
automatically create attention masks and pad them on a per batch level. Default
is ``False if return_type == "lists" else True``.
pad_ids_and_attention (bool, optional): Pad the ``input_ids`` with ``pad_token`` and
attention masks with 0s or 1s depending on ``mask_padding_with_zero``. Pad both to
``max_length``. Default is ``False if return_type == "lists" else True``
return_type (str, optional): Either "tensors", "lists", or None. See "Returns" section
below. Default is None.
save_to_path (str, optional): The folder/directory to save the data to OR None to not
save. Will save the data specified by ``return_type`` to disk. Default is None.
save_to_name (str, optional): The name of the file to save. The extension specified by
``save_as_type`` is automatically appended. Default is ``'dataset_' + self.name``.
save_as_type (str, optional): The file extension of saved file if `save_to_path` is
set. Supports "pt" (PyTorch) and "txt" (Text). Saving as "txt" requires the
``return_type`` to be ``lists``. If ``return_type`` is ``tensors`` the only
``save_as_type`` available is "pt". Defaults to "txt".
Returns:
list or torch.TensorDataset: If ``return_type is None``, return the list of calculated
features. If ``return_type == "tensors"``, return the features converted to tensors
and stacked such that features are grouped together into individual tensors. If
``return_type == "lists"`` (the recommended option), convert each ``InputFeatures``
object in the ``features`` list to a dictionary and return the resulting list of
dictionaries.
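
A minimal usage sketch (assuming examples were already added with
``add_examples`` and a hypothetical ``./data`` output directory)::

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = processor.get_features(
        tokenizer,
        max_length=512,
        return_type="lists",
        save_to_path="./data",
        save_to_name="dataset_demo",
    )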
"""
assert return_type in ["tensors", "lists"] or return_type is None
assert save_as_type in ["txt", "pt"] or save_to_path is None
if save_as_type == "txt":
assert return_type == "lists"
if return_type == "tensors":
assert save_as_type == "pt" or save_to_path is None
if return_type == "lists":
create_attention_mask = False
pad_ids_and_attention = False
else: # if `return_type` is None or "tensors"
create_attention_mask = True
pad_ids_and_attention = True
if max_length is None:
max_length = tokenizer.max_len
# batch_length = max(len(input_ids) for input_ids in all_input_ids)
if create_sent_rep_token_ids:
if sent_rep_token_id == "sep": # get the sep token id
sent_rep_token_id = tokenizer.sep_token_id
elif sent_rep_token_id == "cls": # get the cls token id
sent_rep_token_id = tokenizer.cls_token_id
elif not sent_rep_token_id: # if the `sent_rep_token_id` is not set
# if using `bert_compatible_cls` then default to the `cls_token_id`
if bert_compatible_cls:
sent_rep_token_id = tokenizer.cls_token_id
else: # otherwise, get the `sep_token_id`
sent_rep_token_id = tokenizer.sep_token_id
if create_segment_ids:
if segment_token_id == "period": # get the token id for a "."
segment_token_id = tokenizer.convert_tokens_to_ids(["."])[0]
elif (
not segment_token_id
): # default to trying to get the `sep_token_id` if the `segment_token_id` is not set
segment_token_id = tokenizer.sep_token_id
features = []
pool = Pool(n_process)
_get_features_process = partial(
self.get_features_process,
num_examples=len(self.labels),
tokenizer=tokenizer,
bert_compatible_cls=bert_compatible_cls,
sep_token=tokenizer.sep_token,
cls_token=tokenizer.cls_token,
create_sent_rep_token_ids=create_sent_rep_token_ids,
sent_rep_token_id=sent_rep_token_id,
create_sent_lengths=create_sent_lengths,
create_segment_ids=create_segment_ids,
segment_token_id=segment_token_id,
create_source=create_source,
max_length=max_length,
pad_on_left=pad_on_left,
pad_token=pad_token,
mask_padding_with_zero=mask_padding_with_zero,
create_attention_mask=create_attention_mask,
pad_ids_and_attention=pad_ids_and_attention,
)
for rtn_features in pool.map(
_get_features_process,
zip(range(len(self.labels)), self.examples, self.labels),
):
features.append(rtn_features)
pool.close()
pool.join()
if not return_type:
return features
elif return_type == "tensors":
final_tensors = []
all_input_ids = torch.tensor(
[f.input_ids for f in features], dtype=torch.long
)
final_tensors.append(all_input_ids)
all_attention_masks = torch.tensor(
[f.attention_mask for f in features], dtype=torch.long
)
final_tensors.append(all_attention_masks)
all_labels = torch.tensor(
pad([f.labels for f in features], 0), dtype=torch.long
)
final_tensors.append(all_labels)
if create_segment_ids:
all_token_type_ids = torch.tensor(
pad([f.token_type_ids for f in features], 0), dtype=torch.long
)
final_tensors.append(all_token_type_ids)
# Pad sentence representation token ids (`sent_rep_token_ids`)
if create_sent_rep_token_ids:
all_sent_rep_token_ids = torch.tensor(
pad([f.sent_rep_token_ids for f in features], -1), dtype=torch.long
)
all_sent_rep_token_ids_masks = ~(all_sent_rep_token_ids == -1)
all_sent_rep_token_ids[all_sent_rep_token_ids == -1] = 0
final_tensors.append(all_sent_rep_token_ids)
final_tensors.append(all_sent_rep_token_ids_masks)
if create_sent_lengths:
all_sent_lengths = torch.tensor(
pad([f.sent_lengths for f in features], 0), dtype=torch.long
)
final_tensors.append(all_sent_lengths)
dataset = torch.utils.data.TensorDataset(*final_tensors)
elif return_type == "lists":
dataset = [example.to_dict() for example in features]
if save_to_path:
final_save_name = save_to_name if save_to_name else ("dataset_" + self.name)
dataset_path = os.path.join(
save_to_path,
(final_save_name + "." + save_as_type),
)
logger.info("Saving dataset into cached file %s", dataset_path)
if save_as_type == "txt":
with open(dataset_path, "w+") as file:
# `json.dumps` writes double-quoted JSON so each line can be re-loaded with `json.loads`
file.write("\n".join([json.dumps(x) for x in dataset]) + "\n")
elif save_as_type == "pt":
torch.save(dataset, dataset_path)
else:
logger.error("'%s' is an invalid save type.", save_as_type)
return dataset
def load(self, load_from_path, dataset_name=None):
"""Attempts to load the dataset from storage. If that fails, will return None."""
final_load_name = dataset_name if dataset_name else ("dataset_" + self.name)
dataset_path = os.path.join(
load_from_path,
(final_load_name + ".pt"),
)
if os.path.exists(dataset_path):
logger.info("Loading data from file %s", dataset_path)
dataset = torch.load(dataset_path)
return dataset
return None