# Source code for dicee.dataset_classes._label_based

"""Label-based (multi-label / multi-class) dataset classes.

Provides ``KvsAll``, ``AllvsAll``, ``KvsSampleDataset``, and
``OnevsAllDataset`` — datasets where each sample is a ``(head, relation)``
pair and the target is a label vector over all entities (or relations).
"""

import numpy as np
import torch

from ..static_preprocess_funcs import mapping_from_first_two_cols_to_third


class OnevsAllDataset(torch.utils.data.Dataset):
    """Dataset for the 1-vs-All training strategy (multi-class).

    Yields ``(head, relation)`` index pairs together with a one-hot vector
    over all entities whose single active entry marks the gold tail entity.

    Parameters
    ----------
    train_set_idx : numpy.ndarray
        ``(N, 3)`` integer-indexed triples.
    entity_idxs : dict
        Entity-name → index mapping (its size fixes the target dimension).
    """

    def __init__(self, train_set_idx: np.ndarray, entity_idxs):
        super().__init__()
        assert isinstance(train_set_idx, (np.memmap, np.ndarray))
        assert len(train_set_idx) > 0
        # Canonical (head, relation, tail) ordering makes training
        # independent of how the input triples happened to be ordered.
        order = np.lexsort(
            (train_set_idx[:, 2], train_set_idx[:, 1], train_set_idx[:, 0])
        )
        self.train_data = train_set_idx[order]
        self.target_dim = len(entity_idxs)
        self.collate_fn = None

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, idx):
        # torch.tensor copies the numpy row, so the memmap/ndarray backing
        # store is never aliased by the returned tensor.
        triple = torch.tensor(self.train_data[idx], dtype=torch.long)
        one_hot = torch.zeros(self.target_dim)
        one_hot[triple[2]] = 1
        return triple[:2], one_hot
class KvsAll(torch.utils.data.Dataset):
    """Dataset for KvsAll training (multi-label).

    D := {(x, y)_i}_{i=1}^{N}, where

    * x = (h, r) is a unique (entity, relation) pair observed in the KG,
    * y ∈ [0, 1]^{|E|} is a multi-label vector with y_j = 1 iff
      (h, r, e_j) ∈ KG.

    Parameters
    ----------
    train_set_idx : numpy.ndarray
        ``(N, 3)`` integer-indexed triples.
    entity_idxs : dict
        Entity-name → index mapping.
    relation_idxs : dict
        Relation-name → index mapping.
    form : str
        ``'EntityPrediction'`` or ``'RelationPrediction'``.
    store : None
        Reserved; must be ``None`` (a pre-built store is not supported).
    label_smoothing_rate : float, optional
        Label smoothing coefficient (default ``0.0``).

    Raises
    ------
    ValueError
        If a non-``None`` ``store`` is passed.
    NotImplementedError
        If ``form`` is not one of the two supported modes.
    RuntimeError
        If the packed target array is malformed.
    """

    def __init__(
        self,
        train_set_idx: np.ndarray,
        entity_idxs,
        relation_idxs,
        form,
        store=None,
        label_smoothing_rate: float = 0.0,
    ):
        super().__init__()
        assert len(train_set_idx) > 0
        assert isinstance(train_set_idx, (np.memmap, np.ndarray))
        self.train_data = None
        self.train_target = None
        self.label_smoothing_rate = torch.tensor(label_smoothing_rate)
        self.collate_fn = None
        if store is not None:
            raise ValueError("KvsAll does not accept a pre-built store; pass store=None.")

        if form == "RelationPrediction":
            self.target_dim = len(relation_idxs)
            store = dict()
            for s_idx, p_idx, o_idx in train_set_idx:
                store.setdefault((s_idx, o_idx), list()).append(p_idx)
            # Sort keys to ensure order-independent training
            store = dict(sorted(store.items()))
        elif form == "EntityPrediction":
            self.target_dim = len(entity_idxs)
            # mapping_from_first_two_cols_to_third already returns sorted dict
            store = mapping_from_first_two_cols_to_third(train_set_idx)
        else:
            raise NotImplementedError(f"Unsupported form: {form!r}")

        assert len(store) > 0
        self.train_data = torch.LongTensor(list(store.keys()))
        # If every pair has exactly one target, pack targets into a
        # rectangular ndarray; otherwise keep a ragged list of lists.
        if sum(len(v) for v in store.values()) == len(store):
            self.train_target = np.array(list(store.values()))
            # FIX: previously printed the data and called exit(1) on this
            # failure; raise a proper exception so callers can handle it.
            if not isinstance(self.train_target[0], np.ndarray):
                raise RuntimeError(
                    f"Malformed train_target; expected ndarray rows, "
                    f"got {type(self.train_target[0])!r}"
                )
        else:
            self.train_target = list(store.values())
            assert isinstance(self.train_target[0], list)
        del store

    def __len__(self):
        assert len(self.train_data) == len(self.train_target)
        return len(self.train_data)

    def __getitem__(self, idx):
        y_vec = torch.zeros(self.target_dim)
        y_vec[self.train_target[idx]] = 1.0
        if self.label_smoothing_rate:
            # NOTE(review): this adds a constant 1/dim rather than the
            # textbook rate/dim; kept as-is to preserve training behavior.
            y_vec = y_vec * (1 - self.label_smoothing_rate) + (1 / y_vec.size(0))
        return self.train_data[idx], y_vec
class AllvsAll(torch.utils.data.Dataset):
    """Dataset for AllvsAll training (multi-label, exhaustive).

    Builds on the ``KvsAll`` idea by enumerating every possible
    ``(entity, relation)`` combination, not only the ones observed in the
    KG. A pair with no known tail entities gets an all-zeros label vector.

    Parameters
    ----------
    train_set_idx : numpy.ndarray
        ``(N, 3)`` integer-indexed triples.
    entity_idxs : dict
        Entity-name → index mapping.
    relation_idxs : dict
        Relation-name → index mapping.
    label_smoothing_rate : float, optional
        Label smoothing coefficient (default ``0.0``).
    """

    def __init__(
        self,
        train_set_idx: np.ndarray,
        entity_idxs,
        relation_idxs,
        label_smoothing_rate=0.0,
    ):
        super().__init__()
        assert len(train_set_idx) > 0
        assert isinstance(train_set_idx, (np.memmap, np.ndarray))
        self.train_data = None
        self.train_target = None
        self.label_smoothing_rate = torch.tensor(label_smoothing_rate)
        self.collate_fn = None
        self.target_dim = len(entity_idxs)

        # mapping_from_first_two_cols_to_third already returns sorted dict
        store = mapping_from_first_two_cols_to_third(train_set_idx)
        print("Number of unique pairs:", len(store))
        # Augment with every unseen (entity, relation) pair; such pairs
        # carry an empty positive-target list.
        for ent in range(len(entity_idxs)):
            for rel in range(len(relation_idxs)):
                if (ent, rel) not in store:
                    store[(ent, rel)] = list()
        print("Number of unique augmented pairs:", len(store))
        # Re-sort after augmentation so iteration order stays deterministic.
        store = dict(sorted(store.items()))
        assert len(store) > 0
        self.train_data = torch.LongTensor(list(store.keys()))

        # Rectangular packing when every pair has exactly one target,
        # otherwise a ragged list of lists.
        if sum(len(v) for v in store.values()) == len(store):
            self.train_target = np.array(list(store.values()))
            assert isinstance(self.train_target[0], np.ndarray)
        else:
            self.train_target = list(store.values())
            assert isinstance(self.train_target[0], list)
        del store

    def __len__(self):
        assert len(self.train_data) == len(self.train_target)
        return len(self.train_data)

    def __getitem__(self, idx):
        labels = torch.zeros(self.target_dim)
        positives = self.train_target[idx]
        # Augmented pairs may have no positives at all.
        if len(positives) > 0:
            labels[positives] = 1.0
        if self.label_smoothing_rate:
            labels = labels * (1 - self.label_smoothing_rate) + (1 / labels.size(0))
        return self.train_data[idx], labels
class KvsSampleDataset(torch.utils.data.Dataset):
    """Dataset for KvsSample training (dynamic multi-label).

    Behaves like ``KvsAll`` but, on every access, pads the positive tail
    entities with freshly sampled negatives so each item exposes a fixed
    number of candidate classes — keeping mini-batches manageable when the
    entity set is large.

    Parameters
    ----------
    train_set_idx : numpy.ndarray
        ``(N, 3)`` integer-indexed triples.
    entity_idxs : dict
        Entity-name → index mapping.
    relation_idxs : dict
        Relation-name → index mapping.
    form : str
        ``'EntityPrediction'``.
    neg_ratio : int
        Number of negative samples per positive target.
    label_smoothing_rate : float, optional
        Label smoothing coefficient (default ``0.0``).
    """

    def __init__(
        self,
        train_set_idx: np.ndarray,
        entity_idxs,
        relation_idxs,
        form,
        store=None,
        neg_ratio=None,
        label_smoothing_rate: float = 0.0,
    ):
        super().__init__()
        assert len(train_set_idx) > 0
        assert isinstance(train_set_idx, np.ndarray)
        assert neg_ratio is not None
        self.train_data = None
        self.train_target = None
        self.neg_ratio = neg_ratio
        self.num_entities = len(entity_idxs)
        self.label_smoothing_rate = torch.tensor(label_smoothing_rate)
        self.collate_fn = None
        store = mapping_from_first_two_cols_to_third(train_set_idx)
        assert len(store) > 0
        self.train_data = torch.LongTensor(list(store.keys()))
        self.train_target = list(store.values())
        # Every item is padded to the size of the largest positive set
        # plus the requested number of negatives.
        self.max_num_of_classes = (
            max(len(v) for v in self.train_target) + self.neg_ratio
        )
        del store

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, idx):
        # Unique (head, relation) pair and its known tail entities.
        pair = self.train_data[idx]
        positives = self.train_target[idx]
        n_pos = len(positives)
        n_neg = self.max_num_of_classes - n_pos
        # Sample negatives uniformly from entities outside the positive set
        # (with replacement, so duplicates are possible by design).
        sampling_weights = torch.ones(self.num_entities)
        sampling_weights[positives] = 0.0
        negatives = torch.multinomial(
            sampling_weights, num_samples=n_neg, replacement=True
        )
        candidate_idx = torch.cat((torch.LongTensor(positives), negatives), 0)
        candidate_labels = torch.cat(
            (torch.ones(n_pos), torch.zeros(n_neg)), 0
        )
        return pair, candidate_idx, candidate_labels