"""Factory functions for constructing training datasets.
The two public functions — ``reload_dataset`` and ``construct_dataset`` —
select the appropriate ``torch.utils.data.Dataset`` sub-class based on the
requested scoring technique and labelling strategy.
"""
from typing import Union
import numpy as np
import torch
from ..static_funcs import timeit, load_term_mapping
from ._bpe import (
BPE_NegativeSamplingDataset,
MultiClassClassificationDataset,
MultiLabelDataset,
)
from ._label_based import AllvsAll, KvsAll, KvsSampleDataset, OnevsAllDataset
from ._negative_sampling import FixedNegSampleDataset, OnevsSample, TriplePredictionDataset
@timeit
def reload_dataset(
    path: str,
    form_of_labelling,
    scoring_technique,
    neg_ratio,
    label_smoothing_rate,
):
    """Reload a serialized training split from disk and build a PyTorch dataset.

    Parameters
    ----------
    path : str
        Directory containing ``train_set.npy`` plus the ``entity_to_idx`` and
        ``relation_to_idx`` term-mapping files.
    form_of_labelling, scoring_technique, neg_ratio, label_smoothing_rate
        Forwarded verbatim to :func:`construct_dataset`; see that function for
        their meaning.

    Returns
    -------
    torch.utils.data.Dataset
        Whatever dataset sub-class :func:`construct_dataset` selects.
    """
    # Validation/test splits are intentionally omitted on reload — only the
    # training split is materialized into a Dataset here.
    return construct_dataset(
        train_set=np.load(path + "/train_set.npy"),
        valid_set=None,
        test_set=None,
        entity_to_idx=load_term_mapping(file_path=path + "/entity_to_idx"),
        relation_to_idx=load_term_mapping(file_path=path + "/relation_to_idx"),
        form_of_labelling=form_of_labelling,
        scoring_technique=scoring_technique,
        neg_ratio=neg_ratio,
        label_smoothing_rate=label_smoothing_rate,
    )
@timeit
def construct_dataset(
    *,
    train_set: Union[np.ndarray, list],
    valid_set=None,
    test_set=None,
    ordered_bpe_entities=None,
    train_target_indices=None,
    target_dim: int = None,
    entity_to_idx: dict,
    relation_to_idx: dict,
    form_of_labelling: str,
    scoring_technique: str,
    neg_ratio: int,
    label_smoothing_rate: float,
    byte_pair_encoding=None,
    block_size: int = None,
    seed: int = None,
) -> torch.utils.data.Dataset:
    """Build the appropriate dataset for the given training configuration.

    Parameters
    ----------
    train_set : numpy.ndarray or list
        Raw integer-indexed triples.
    valid_set, test_set
        Accepted for interface symmetry but not used by this function.
    ordered_bpe_entities : sequence of (str, bpe, shaped_bpe) triples, optional
        Per-entity byte-pair encodings; only the shaped (third) element is
        consumed here. Enables the BPE dataset variants when combined with
        ``byte_pair_encoding``.
    train_target_indices, target_dim
        Multi-label target description, used only by the BPE KvsAll/AllvsAll
        path.
    entity_to_idx, relation_to_idx : dict
        Name → index mappings; only their lengths are used on the sampling
        paths.
    form_of_labelling : str
        ``'EntityPrediction'`` or ``'RelationPrediction'``.
    scoring_technique : str
        One of ``'NegSample'``, ``'FixedNegSample'``, ``'1vsAll'``,
        ``'1vsSample'``, ``'KvsAll'``, ``'AllvsAll'``, ``'KvsSample'``.
    neg_ratio : int
        Negative sample ratio.
    label_smoothing_rate : float
        Label smoothing coefficient.
    byte_pair_encoding
        Truthy when the model operates on byte-pair-encoded input.
    block_size : int, optional
        Sequence block size for the multi-class BPE dataset.
    seed : int, optional
        Random seed; forwarded only on the ``NegSample`` and
        ``FixedNegSample`` paths.

    Returns
    -------
    torch.utils.data.Dataset

    Raises
    ------
    ValueError
        Unknown ``scoring_technique`` under ``'EntityPrediction'``.
    KeyError
        Unknown ``form_of_labelling`` when no earlier branch matched.
    """

    def _shaped_bpe_entity_tensor() -> torch.Tensor:
        # Only the shaped (padded) BPE encoding — the third element of each
        # triple — is needed by the BPE datasets.
        return torch.tensor(
            [shaped for (_str_ent, _bpe_ent, shaped) in ordered_bpe_entities]
        )

    # BPE entity tensors are only meaningful when both the encodings and the
    # BPE flag are provided.
    bpe_with_entities = bool(ordered_bpe_entities) and bool(byte_pair_encoding)

    if bpe_with_entities and scoring_technique == "NegSample":
        train_set = BPE_NegativeSamplingDataset(
            train_set=torch.tensor(train_set, dtype=torch.long),
            ordered_shaped_bpe_entities=_shaped_bpe_entity_tensor(),
            neg_ratio=neg_ratio,
        )
    elif bpe_with_entities and scoring_technique in ["KvsAll", "AllvsAll"]:
        train_set = MultiLabelDataset(
            train_set=torch.tensor(train_set, dtype=torch.long),
            train_indices_target=train_target_indices,
            target_dim=target_dim,
            torch_ordered_shaped_bpe_entities=_shaped_bpe_entity_tensor(),
        )
    elif byte_pair_encoding:
        # BPE without precomputed entity encodings: next-token style
        # multi-class classification over blocks.
        train_set = MultiClassClassificationDataset(
            train_set, block_size=block_size
        )
    elif scoring_technique == "NegSample":
        train_set = TriplePredictionDataset(
            train_set=train_set,
            num_entities=len(entity_to_idx),
            num_relations=len(relation_to_idx),
            neg_sample_ratio=neg_ratio,
            label_smoothing_rate=label_smoothing_rate,
            seed=seed,
        )
    elif scoring_technique == "FixedNegSample":
        train_set = FixedNegSampleDataset(
            train_set=train_set,
            num_entities=len(entity_to_idx),
            num_relations=len(relation_to_idx),
            neg_sample_ratio=neg_ratio,
            label_smoothing_rate=label_smoothing_rate,
            seed=seed,
        )
    elif form_of_labelling == "EntityPrediction":
        if scoring_technique == "1vsAll":
            train_set = OnevsAllDataset(train_set, entity_idxs=entity_to_idx)
        elif scoring_technique == "1vsSample":
            # NOTE(review): unlike the NegSample paths above, ``seed`` is not
            # forwarded here — confirm whether OnevsSample accepts one.
            train_set = OnevsSample(
                train_set=train_set,
                num_entities=len(entity_to_idx),
                num_relations=len(relation_to_idx),
                neg_sample_ratio=neg_ratio,
                label_smoothing_rate=label_smoothing_rate,
            )
        elif scoring_technique == "KvsAll":
            train_set = KvsAll(
                train_set,
                entity_idxs=entity_to_idx,
                relation_idxs=relation_to_idx,
                form=form_of_labelling,
                label_smoothing_rate=label_smoothing_rate,
            )
        elif scoring_technique == "AllvsAll":
            train_set = AllvsAll(
                train_set,
                entity_idxs=entity_to_idx,
                relation_idxs=relation_to_idx,
                label_smoothing_rate=label_smoothing_rate,
            )
        elif scoring_technique == "KvsSample":
            train_set = KvsSampleDataset(
                train_set,
                entity_idxs=entity_to_idx,
                relation_idxs=relation_to_idx,
                form=form_of_labelling,
                neg_ratio=neg_ratio,
                label_smoothing_rate=label_smoothing_rate,
            )
        else:
            raise ValueError(f"Invalid scoring technique : {scoring_technique}")
    elif form_of_labelling == "RelationPrediction":
        train_set = KvsAll(
            train_set,
            entity_idxs=entity_to_idx,
            relation_idxs=relation_to_idx,
            form=form_of_labelling,
            label_smoothing_rate=label_smoothing_rate,
        )
    else:
        raise KeyError("Illegal input.")
    print(f"Number of datapoints: {len(train_set)}")
    return train_set