dicee.abstracts
Classes
AbstractTrainer – Abstract base class for KGE model trainers.
BaseInteractiveKGE – Base class for interactive, post-training use of KGE models.
InteractiveQueryDecomposition – Mixin that provides fuzzy-logic operators for multi-hop EPFO query answering.
AbstractCallback – Abstract base class for KGE training lifecycle callbacks.
AbstractPPECallback – Abstract base class for Post-training Parameter Ensembling (PPE) callbacks.
BaseInteractiveTrainKGE – Base class for training knowledge graph embedding models interactively.
Module Contents
- class dicee.abstracts.AbstractTrainer(args, callbacks)[source]
Abstract base class for KGE model trainers.
Provides the callback dispatch mechanism shared by all concrete trainer implementations (TorchTrainer, TorchDDPTrainer, etc.). Sub-classes call the on_* hooks at the appropriate points in the training loop so that any registered AbstractCallback can react.
- Parameters:
args (argparse.Namespace or similar) – Processed configuration object. Must expose at least random_seed (int).
callbacks (list of AbstractCallback) – Ordered list of callback instances to invoke at each lifecycle hook.
- attributes
- callbacks
- is_global_zero = True
- global_rank = 0
- local_rank = 0
- strategy = None
- on_fit_start(*args, **kwargs)[source]
Dispatch on_fit_start to all registered callbacks. Called once before the first training epoch begins.
- on_fit_end(*args, **kwargs)[source]
Dispatch on_fit_end to all registered callbacks. Called once after the last training epoch completes.
- on_train_epoch_start(*args, **kwargs)[source]
Dispatch on_train_epoch_start to all registered callbacks. Called at the beginning of every epoch.
- on_train_epoch_end(*args, **kwargs)[source]
Dispatch on_train_epoch_end to all registered callbacks. Called at the end of every epoch after the loss has been accumulated.
- on_train_batch_end(*args, **kwargs)[source]
Dispatch on_train_batch_end to all registered callbacks. Called after each mini-batch gradient update.
- static save_checkpoint(full_path: str, model) None[source]
Persist model weights to disk.
- Parameters:
full_path (str) – Absolute or relative file path (including filename) where the state_dict will be written, e.g. 'Experiments/run1/model.pt'.
model (torch.nn.Module) – The model whose state_dict is to be saved.
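The dispatch mechanism described above can be sketched with plain Python objects. This is a minimal illustration, not the dicee implementation; the MiniTrainer and Recorder classes below are hypothetical, while the hook names mirror the ones documented here:

```python
class Recorder:
    """Hypothetical callback that records which hooks fired."""
    def __init__(self):
        self.events = []

    def on_fit_start(self, trainer, model):
        self.events.append("fit_start")

    def on_train_epoch_end(self, trainer, model):
        self.events.append("epoch_end")

    def on_fit_end(self, trainer, model):
        self.events.append("fit_end")


class MiniTrainer:
    """Minimal sketch of the AbstractTrainer callback dispatch pattern."""
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_fit_start(self, *args, **kwargs):
        for cb in self.callbacks:
            cb.on_fit_start(*args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        for cb in self.callbacks:
            cb.on_train_epoch_end(*args, **kwargs)

    def on_fit_end(self, *args, **kwargs):
        for cb in self.callbacks:
            cb.on_fit_end(*args, **kwargs)

    def fit(self, model, epochs=2):
        self.on_fit_start(self, model)
        for _ in range(epochs):
            # ... one epoch of training would run here ...
            self.on_train_epoch_end(self, model)
        self.on_fit_end(self, model)


rec = Recorder()
MiniTrainer([rec]).fit(model=None, epochs=2)
print(rec.events)  # ['fit_start', 'epoch_end', 'epoch_end', 'fit_end']
```

Every hook simply forwards its arguments to each registered callback in order, which is why callbacks can be freely composed.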
- class dicee.abstracts.BaseInteractiveKGE(path: str = None, url: str = None, construct_ensemble: bool = False, model_name: str = None, apply_semantic_constraint: bool = False)[source]
Base class for interactive, post-training use of KGE models.
Loads a pre-trained model from disk (or a remote URL) together with its entity/relation index mappings and exposes the prediction API used by KGE.
- Parameters:
path (str, optional) – Path to the experiment directory produced by Execute. Must contain model.pt, configuration.json, entity_to_idx.csv and relation_to_idx.csv.
url (str, optional) – Remote URL of a pre-trained model to download. Mutually exclusive with path.
construct_ensemble (bool, optional) – When True, load all checkpoint files in path and average their weights to form an ensemble model. Defaults to False.
model_name (str, optional) – Filename (without extension) of the checkpoint to load when multiple .pt files exist in path.
apply_semantic_constraint (bool, optional) – Reserved for future use. Defaults to False.
- construct_ensemble = False
- apply_semantic_constraint = False
- configs
- get_bpe_token_representation(str_entity_or_relation: List[str] | str) List[List[int]] | List[int][source]
- Parameters:
str_entity_or_relation (str or List[str]) – Entity or relation string(s) to be tokenized via BPE and shaped.
- Return type:
A list of integers, or a list of lists containing integers
- get_padded_bpe_triple_representation(triples: List[List[str]]) Tuple[List, List, List][source]
- Parameters:
triples
- set_model_train_mode() None[source]
Switch the underlying model to training mode.
Calls model.train() and re-enables gradient computation for all parameters so that subsequent optimisation steps work correctly after a period of inference.
- set_model_eval_mode() None[source]
Switch the underlying model to evaluation mode.
Calls model.eval() and freezes all parameters (requires_grad = False) so that dropout and batch-norm layers behave deterministically during inference.
- property name
- sample_entity(n: int) List[str][source]
Return n random entity strings without replacement.
- Parameters:
n (int) – Number of entities to sample. Must be non-negative and at most num_entities.
- Returns:
Randomly selected entity string labels.
- Return type:
List[str]
- sample_relation(n: int) List[str][source]
Return n random relation strings without replacement.
- Parameters:
n (int) – Number of relations to sample. Must be non-negative and at most num_relations.
- Returns:
Randomly selected relation string labels.
- Return type:
List[str]
- is_seen(entity: str = None, relation: str = None) bool[source]
Check whether an entity or relation was present in the training set.
Exactly one of entity or relation should be provided.
- Parameters:
entity (str, optional) – Entity string label to look up.
relation (str, optional) – Relation string label to look up.
- Returns:
True if the given string is in the respective index mapping, False otherwise.
- Return type:
bool
- save() None[source]
Persist the current model weights to the experiment directory.
The checkpoint filename encodes the current timestamp so successive calls do not overwrite each other. Ensemble models are saved with an _ensemble_ infix in the filename.
- get_entity_index(x: str) int[source]
Return the integer index for a given entity string.
- Parameters:
x (str) – Entity string label (must have been seen during training).
- Returns:
Corresponding row index in the entity embedding matrix.
- Return type:
int
- get_relation_index(x: str) int[source]
Return the integer index for a given relation string.
- Parameters:
x (str) – Relation string label (must have been seen during training).
- Returns:
Corresponding row index in the relation embedding matrix.
- Return type:
int
- index_triple(head_entity: List[str], relation: List[str], tail_entity: List[str]) Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor][source]
Convert string triple lists to integer index tensors.
- Parameters:
head_entity (List[str]) – Head entity string labels.
relation (List[str]) – Relation string labels.
tail_entity (List[str]) – Tail entity string labels.
- Returns:
idx_head_entity, idx_relation, idx_tail_entity – Each has shape (n, 1) and contains the integer indices for the corresponding strings.
- Return type:
torch.LongTensor
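Setting aside the tensor conversion, the string-to-index step can be sketched with plain dictionaries. The mappings below are toy stand-ins for the real ones loaded from entity_to_idx.csv and relation_to_idx.csv:

```python
# Toy index mappings; the real ones come from the experiment directory.
entity_to_idx = {"berlin": 0, "germany": 1, "paris": 2, "france": 3}
relation_to_idx = {"capital_of": 0}

def index_triple(heads, relations, tails):
    """Map string triples to integer index columns of shape (n, 1) each."""
    h = [[entity_to_idx[e]] for e in heads]
    r = [[relation_to_idx[p]] for p in relations]
    t = [[entity_to_idx[e]] for e in tails]
    return h, r, t

h, r, t = index_triple(["berlin", "paris"],
                       ["capital_of", "capital_of"],
                       ["germany", "france"])
print(h, r, t)  # [[0], [2]] [[0], [0]] [[1], [3]]
```

The real method wraps each column in a torch.LongTensor so it can be fed directly to the embedding lookup.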
- add_new_entity_embeddings(entity_name: str = None, embeddings: torch.FloatTensor = None) None[source]
Extend the entity embedding table with a new entity at inference time.
The new entity is appended to both the entity_to_idx / idx_to_entity mappings and the entity_embeddings weight tensor so that subsequent calls to prediction methods can reference it by name.
- Parameters:
entity_name (str) – String label for the new entity. If the entity already exists in the index no modification is made.
embeddings (torch.FloatTensor) – 1-D float tensor of length embedding_dim containing the pre-computed embedding for the new entity.
- get_entity_embeddings(items: List[str]) torch.FloatTensor[source]
Return the embedding vectors for the given entity strings.
For standard (non-BPE) models the method looks up each string in entity_to_idx and returns the corresponding rows of the entity embedding matrix. For BPE models subword token embeddings are fetched and flattened into a single vector per entity.
- Parameters:
items (List[str]) – Entity string labels to retrieve.
- Returns:
Shape (len(items), embedding_dim).
- Return type:
torch.FloatTensor
- get_relation_embeddings(items: List[str]) torch.FloatTensor[source]
Return the embedding vectors for the given relation strings.
- Parameters:
items (List[str]) – Relation string labels to retrieve.
- Returns:
Shape (len(items), embedding_dim).
- Return type:
torch.FloatTensor
- construct_input_and_output(head_entity: List[str], relation: List[str], tail_entity: List[str], labels) Tuple[torch.LongTensor, torch.FloatTensor][source]
Build an indexed triple tensor and a label tensor from string inputs.
- Parameters:
head_entity (List[str]) – Head entity string labels.
relation (List[str]) – Relation string labels.
tail_entity (List[str]) – Tail entity string labels.
labels (list or array-like) – Binary or soft labels (one per triple) used as training targets.
- Returns:
x (torch.LongTensor) – Shape (n, 3), integer-indexed triples.
labels (torch.FloatTensor) – Shape (n,), float label tensor.
- class dicee.abstracts.InteractiveQueryDecomposition[source]
Mixin that provides fuzzy-logic operators for multi-hop EPFO query answering.
The three families of operators — T-norm, T-conorm, and negation norm — are applied element-wise over entity score tensors to compose complex queries from atomic link-prediction results (e.g. 2p, 3p, 2i, ip, up).
- t_norm(tens_1: torch.Tensor, tens_2: torch.Tensor, tnorm: str = 'min') torch.Tensor[source]
Apply a T-norm to combine two entity score distributions.
- Parameters:
tens_1 (torch.Tensor) – Score tensors of identical shape, values in [0, 1].
tens_2 (torch.Tensor) – Score tensors of identical shape, values in [0, 1].
tnorm (str) – Operator to use. 'min' applies the Gödel (min) T-norm; 'prod' applies the product T-norm.
- Returns:
Element-wise combined scores of the same shape as the inputs.
- Return type:
torch.Tensor
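Element-wise, the two T-norms reduce to min(a, b) and a · b. A minimal sketch with plain floats (the actual method operates on torch tensors):

```python
def t_norm(a, b, tnorm="min"):
    """Fuzzy conjunction of two scores in [0, 1]."""
    if tnorm == "min":      # Gödel T-norm: element-wise minimum
        return min(a, b)
    elif tnorm == "prod":   # product T-norm
        return a * b
    raise ValueError(f"Unknown T-norm: {tnorm}")

print(t_norm(0.8, 0.5, "min"))   # 0.5
print(t_norm(0.8, 0.5, "prod"))  # 0.4
```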
- tensor_t_norm(subquery_scores: torch.FloatTensor, tnorm: str = 'min') torch.FloatTensor[source]
Compute T-norm over [0,1] ^{n imes d} where n denotes the number of hops and d denotes number of entities
- t_conorm(tens_1: torch.Tensor, tens_2: torch.Tensor, tconorm: str = 'min') torch.Tensor[source]
Apply a T-conorm (S-norm) to combine two score distributions (union).
- Parameters:
tens_1 (torch.Tensor) – Score tensors of identical shape, values in [0, 1].
tens_2 (torch.Tensor) – Score tensors of identical shape, values in [0, 1].
tconorm (str) – Operator to use. 'min' applies the Gödel (max) T-conorm; 'prod' applies the probabilistic sum T-conorm.
- Returns:
Element-wise combined scores of the same shape as the inputs.
- Return type:
torch.Tensor
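Element-wise, these operators correspond to max(a, b) and a + b − a · b. A sketch with plain floats; note that, as documented above, the value 'min' selects the max-based Gödel T-conorm:

```python
def t_conorm(a, b, tconorm="min"):
    """Fuzzy disjunction (union) of two scores in [0, 1]."""
    if tconorm == "min":     # Gödel T-conorm: element-wise maximum
        return max(a, b)
    elif tconorm == "prod":  # probabilistic sum
        return a + b - a * b
    raise ValueError(f"Unknown T-conorm: {tconorm}")

print(t_conorm(0.8, 0.5, "min"))   # 0.8
print(t_conorm(0.8, 0.5, "prod"))  # 0.9
```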
- negnorm(tens_1: torch.Tensor, lambda_: float, neg_norm: str = 'standard') torch.Tensor[source]
Apply a negation norm (complement) to an entity score distribution.
- Parameters:
tens_1 (torch.Tensor) – Input score tensor, values in [0, 1].
lambda_ (float) – Shape parameter used by the Sugeno and Yager negation norms. Ignored for the standard complement.
neg_norm (str) – Which negation to apply: 'standard' (1 - x), 'sugeno', or 'yager'.
- Returns:
Complemented score tensor of the same shape as tens_1.
- Return type:
torch.Tensor
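The textbook forms of these negations are the standard complement 1 − x, the Sugeno negation (1 − x) / (1 + λx), and the Yager negation (1 − x^λ)^(1/λ); the exact formulas used internally may differ. A sketch with plain floats:

```python
def negnorm(x, lambda_, neg_norm="standard"):
    """Fuzzy negation of a score in [0, 1] (textbook formulas)."""
    if neg_norm == "standard":   # classical complement
        return 1 - x
    elif neg_norm == "sugeno":   # Sugeno-class negation
        return (1 - x) / (1 + lambda_ * x)
    elif neg_norm == "yager":    # Yager-class negation
        return (1 - x ** lambda_) ** (1 / lambda_)
    raise ValueError(f"Unknown negation norm: {neg_norm}")

print(negnorm(0.25, 0.0, "standard"))  # 0.75
print(negnorm(0.25, 1.0, "sugeno"))    # 0.6
print(negnorm(0.25, 2.0, "yager"))     # ≈ 0.9682
```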
- class dicee.abstracts.AbstractCallback[source]
Bases: abc.ABC, lightning.pytorch.callbacks.Callback
Abstract base class for KGE training lifecycle callbacks.
Concrete sub-classes override one or more hook methods to perform custom actions at specific points during training (e.g. weight averaging, periodic evaluation, model checkpointing). All hooks have empty default implementations so sub-classes only need to override the hooks they care about.
Callbacks are registered by passing them to the trainer’s callbacks list. They are also compatible with PyTorch Lightning trainers because this class extends lightning.pytorch.callbacks.Callback.
- on_init_start(*args, **kwargs)[source]
Called when the trainer is about to be constructed.
Override to perform setup that must happen before any trainer state is initialised.
- on_init_end(*args, **kwargs)[source]
Called immediately after the trainer has been constructed.
Override to perform setup that requires a fully initialised trainer.
- on_fit_start(trainer, model)[source]
Called once before the first training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model about to be trained.
- on_train_epoch_end(trainer, model)[source]
Called at the end of each training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model being trained. model.loss_history contains the per-epoch average losses accumulated so far.
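A concrete callback only overrides the hooks it needs. The sketch below uses plain Python without the lightning base class, and BestLossTracker and DummyModel are hypothetical names for illustration:

```python
class BestLossTracker:
    """Hypothetical callback tracking the best per-epoch loss seen so far."""
    def __init__(self):
        self.best = float("inf")

    def on_fit_start(self, trainer, model):
        pass  # unused hooks keep their empty default behaviour

    def on_train_epoch_end(self, trainer, model):
        last = model.loss_history[-1]
        if last < self.best:
            self.best = last


class DummyModel:
    """Stand-in exposing the loss_history attribute documented above."""
    def __init__(self):
        self.loss_history = []


model = DummyModel()
cb = BestLossTracker()
for loss in [0.9, 0.4, 0.6]:          # simulated per-epoch average losses
    model.loss_history.append(loss)
    cb.on_train_epoch_end(trainer=None, model=model)
print(cb.best)  # 0.4
```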
- class dicee.abstracts.AbstractPPECallback(num_epochs, path, epoch_to_start, last_percent_to_consider)[source]
Bases: AbstractCallback
Abstract base class for Post-training Parameter Ensembling (PPE) callbacks.
Sub-classes implement weight-averaging strategies (SWA, EMA, SWAG, …) by overriding on_train_epoch_end() and on_fit_end(). Common book-keeping (epoch counter, sample counter, alpha weights) is managed here.
- Parameters:
num_epochs (int) – Total number of training epochs.
path (str) – Experiment directory where averaged checkpoints will be written.
epoch_to_start (int) – First epoch at which the averaging procedure should begin.
last_percent_to_consider (float) – Fraction of the total training epochs (counted from the end) whose checkpoints are included in the ensemble.
- num_epochs
- path
- sample_counter = 0
- epoch_count = 0
- alphas = None
- on_fit_start(trainer, model)[source]
Called once before the first training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model about to be trained.
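As a rough illustration of what a PPE-style callback computes, the sketch below uniformly averages parameter snapshots from the tail of training, controlled by last_percent_to_consider. This is an assumption-laden toy with plain lists; concrete strategies such as SWA or EMA weight the snapshots differently (via the alpha weights mentioned above):

```python
def average_snapshots(snapshots, last_percent_to_consider=0.5):
    """Uniformly average per-epoch parameter snapshots from the end of training.

    snapshots: list of parameter vectors, one per epoch (toy representation).
    """
    k = max(1, int(len(snapshots) * last_percent_to_consider))
    tail = snapshots[-k:]                      # keep only the last k snapshots
    n = len(tail)
    return [sum(vals) / n for vals in zip(*tail)]

# Four per-epoch snapshots of a two-parameter model.
snaps = [[0.0, 4.0], [1.0, 3.0], [2.0, 2.0], [3.0, 1.0]]
print(average_snapshots(snaps, 0.5))  # [2.5, 1.5]
```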
- class dicee.abstracts.BaseInteractiveTrainKGE[source]
Abstract/base class for training knowledge graph embedding models interactively. This class provides methods for re-training KGE models as well as the Literal Embedding model.
- train_triples(h: List[str], r: List[str], t: List[str], labels: List[float], iteration=2, optimizer=None)[source]
- train_k_vs_all(h, r, iteration=1, lr=0.001)[source]
Train in KvsAll mode.
- Parameters:
h – Head entity string labels.
r – Relation string labels.
iteration (int) – Number of training iterations.
lr (float) – Learning rate.
- train(kg, lr=0.1, epoch=10, batch_size=32, neg_sample_ratio=10, num_workers=1) None[source]
Retrain a pre-trained model on an input KG via negative sampling.
- train_literals(train_file_path: str = None, num_epochs: int = 100, lit_lr: float = 0.001, lit_normalization_type: str = 'z-norm', batch_size: int = 1024, sampling_ratio: float = None, random_seed=1, loader_backend: str = 'pandas', freeze_entity_embeddings: bool = True, gate_residual: bool = True, device: str = None, suffle_data: bool = True)[source]
Trains the Literal Embeddings model using literal data.
- Parameters:
train_file_path (str) – Path to the training data file.
num_epochs (int) – Number of training epochs.
lit_lr (float) – Learning rate for the literal model.
lit_normalization_type (str) – Normalization type to use (‘z-norm’, ‘min-max’, or None).
batch_size (int) – Batch size for training.
sampling_ratio (float) – Ratio of training triples to use.
loader_backend (str) – Backend for loading the dataset (‘pandas’ or ‘rdflib’).
freeze_entity_embeddings (bool) – If True, freeze the entity embeddings during training.
gate_residual (bool) – If True, use gate residual connections in the model.
device (str) – Device to use for training (‘cuda’ or ‘cpu’). If None, will use available GPU or CPU.
suffle_data (bool) – If True, shuffle the dataset before training.