# Source code for dicee.dataset_classes._bpe

"""Byte-pair encoding (BPE) related dataset classes.

Provides datasets for training knowledge graph embedding models that operate
on byte-pair encoded entity and relation representations.
"""

from typing import List, Tuple

import numpy as np
import torch


class BPE_NegativeSamplingDataset(torch.utils.data.Dataset):
    """Dataset for negative sampling with byte-pair encoded triples.

    Each sample is a BPE-encoded triple of shape ``(3, token_length)``.
    The custom :meth:`collate_fn` appends negatives built by corrupting
    either all heads or all tails of the batch (one coin flip per batch)
    with randomly drawn BPE entities.

    Parameters
    ----------
    train_set : torch.LongTensor
        BPE-encoded triples of shape ``(N, 3, token_length)``.
    ordered_shaped_bpe_entities : torch.LongTensor
        All BPE entity representations, ordered by entity index.
    neg_ratio : int
        Number of negative samples per positive triple.
    """

    def __init__(
        self,
        train_set: torch.LongTensor,
        ordered_shaped_bpe_entities: torch.LongTensor,
        neg_ratio: int,
    ):
        super().__init__()
        assert isinstance(train_set, torch.LongTensor)
        assert train_set.shape[1] == 3
        self.train_set = train_set
        self.ordered_bpe_entities = ordered_shaped_bpe_entities
        self.num_bpe_entities = len(self.ordered_bpe_entities)
        self.neg_ratio = neg_ratio
        self.num_datapoints = len(self.train_set)

    def __len__(self) -> int:
        return self.num_datapoints

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.train_set[idx]

    def collate_fn(self, batch_shaped_bpe_triples: List[torch.Tensor]):
        """Collate positives and append ``neg_ratio`` corrupted triples each.

        Parameters
        ----------
        batch_shaped_bpe_triples : List[torch.Tensor]
            Batch items, each of shape ``(3, token_length)``.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Triples of shape ``((1 + neg_ratio) * B, 3, token_length)`` and
            float labels (1 for positives, 0 for negatives). Positives come
            first, in batch order.
        """
        batch_of_bpe_triples = torch.stack(batch_shaped_bpe_triples, dim=0)
        size_of_batch, _, token_length = batch_of_bpe_triples.shape
        bpe_h = batch_of_bpe_triples[:, 0, :]
        bpe_r = batch_of_bpe_triples[:, 1, :]
        bpe_t = batch_of_bpe_triples[:, 2, :]
        label = torch.ones((size_of_batch,))
        num_of_corruption = size_of_batch * self.neg_ratio
        # Random BPE entities used as the corrupted side of the negatives.
        # NOTE(review): sampling is uniform over all entities, so a negative
        # may coincide with a true triple (standard negative-sampling noise).
        corr_bpe_entities = self.ordered_bpe_entities[
            torch.randint(0, high=self.num_bpe_entities, size=(num_of_corruption,))
        ]
        # The uncorrupted components are repeated neg_ratio times so each
        # positive contributes neg_ratio negatives.
        rep_r = torch.repeat_interleave(input=bpe_r, repeats=self.neg_ratio, dim=0)
        # One coin flip per batch decides head- vs tail-corruption.
        if torch.rand(1) >= 0.5:
            # Negatives are (corrupted_h, r, t).
            neg_h = corr_bpe_entities
            neg_t = torch.repeat_interleave(input=bpe_t, repeats=self.neg_ratio, dim=0)
        else:
            # Negatives are (h, r, corrupted_t).
            neg_h = torch.repeat_interleave(input=bpe_h, repeats=self.neg_ratio, dim=0)
            neg_t = corr_bpe_entities
        bpe_h = torch.cat((bpe_h, neg_h), 0)
        bpe_r = torch.cat((bpe_r, rep_r), 0)
        bpe_t = torch.cat((bpe_t, neg_t), 0)
        bpe_triple = torch.stack((bpe_h, bpe_r, bpe_t), dim=1)
        label = torch.cat((label, torch.zeros(num_of_corruption)), 0)
        return bpe_triple, label
class MultiLabelDataset(torch.utils.data.Dataset):
    """Multi-label dataset for BPE-based KvsAll / AllvsAll training.

    Each sample pairs a BPE-encoded ``(head, relation)`` input with a
    binary multi-label target vector over all entities.

    Parameters
    ----------
    train_set : torch.LongTensor
        BPE-encoded input pairs of shape ``(N, 2, token_length)``.
    train_indices_target : torch.LongTensor
        Per-sample lists of positive target entity indices.
    target_dim : int
        Dimensionality of the target vector (number of entities).
    torch_ordered_shaped_bpe_entities : torch.LongTensor
        Ordered BPE entity representations.
    """

    def __init__(
        self,
        train_set: torch.LongTensor,
        train_indices_target: torch.LongTensor,
        target_dim: int,
        torch_ordered_shaped_bpe_entities: torch.LongTensor,
    ):
        super().__init__()
        assert len(train_set) == len(train_indices_target)
        assert target_dim > 0
        self.train_set = train_set
        self.train_indices_target = train_indices_target
        self.target_dim = target_dim
        self.num_datapoints = len(self.train_set)
        self.torch_ordered_shaped_bpe_entities = torch_ordered_shaped_bpe_entities
        # No custom collation required; the default DataLoader collate is used.
        self.collate_fn = None

    def __len__(self) -> int:
        return self.num_datapoints

    def __getitem__(self, idx: int):
        # Scatter ones into an all-zero target vector at the positive indices.
        positives = self.train_indices_target[idx]
        target = torch.zeros(self.target_dim)
        if len(positives) > 0:
            target[positives] = 1.0
        return self.train_set[idx], target
class MultiClassClassificationDataset(torch.utils.data.Dataset):
    """Dataset for autoregressive multi-class classification on sub-word units.

    Slices a flat sequence of sub-word token ids into overlapping windows of
    ``block_size`` tokens; the target is the same window shifted right by one
    position (next-token prediction).

    Parameters
    ----------
    subword_units : numpy.ndarray
        1-D array of sub-word token ids.
    block_size : int, optional
        Context window length (default ``8``).
    """

    def __init__(self, subword_units: np.ndarray, block_size: int = 8):
        super().__init__()
        assert isinstance(subword_units, np.ndarray)
        assert len(subword_units) > 0
        self.train_data = torch.LongTensor(subword_units)
        self.block_size = block_size
        # Last valid window start is len - block_size - 1, so that the
        # shifted target still fits inside the sequence.
        self.num_of_data_points = len(self.train_data) - block_size
        # Default DataLoader collation suffices for fixed-size windows.
        self.collate_fn = None

    def __len__(self) -> int:
        return self.num_of_data_points

    def __getitem__(self, idx: int):
        # Context window and its one-step-shifted target window.
        start, stop = idx, idx + self.block_size
        x = self.train_data[start:stop]
        y = self.train_data[start + 1 : stop + 1]
        return x, y