dicee.models.base_model

Classes

BaseKGELightning

Thin PyTorch Lightning wrapper shared by all KGE models.

BaseKGE

Base class for all Knowledge Graph Embedding models.

IdentityClass

No-op normalisation / dropout placeholder.

Module Contents

class dicee.models.base_model.BaseKGELightning(*args, **kwargs)[source]

Bases: lightning.LightningModule

Thin PyTorch Lightning wrapper shared by all KGE models.

Provides the standard Lightning training loop hooks (training_step, on_train_epoch_end, configure_optimizers) as well as a helper for reporting model size. All concrete KGE models should extend BaseKGE rather than this class directly.

training_step_outputs = []
mem_of_model() Dict[source]

Return the size of the model in megabytes and its number of parameters.
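A sketch of how such a helper can be computed from the parameter tensors (the returned key names are illustrative, not necessarily dicee's exact ones):

```python
import torch

def mem_of_model(model: torch.nn.Module) -> dict:
    """Report parameter count and approximate memory footprint in MB."""
    num_params = sum(p.numel() for p in model.parameters())
    # Bytes per element depend on the dtype (4 for float32).
    size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return {"EstimatedSizeMB": size_bytes / (1024 ** 2), "NumParam": num_params}

model = torch.nn.Linear(100, 50)  # 100 * 50 weights + 50 biases = 5050 parameters
stats = mem_of_model(model)
```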

training_step(batch, batch_idx=None)[source]

Execute one optimisation step for the given mini-batch.

Handles two- and three-element batches produced by the different dataset classes (KvsAll / NegSample vs. KvsSample).

Parameters:
  • batch (tuple) – (x, y) for standard scoring, or (x, y_select, y) for sample-based labelling.

  • batch_idx (int, optional) – Index of the current batch (unused, kept for Lightning API compat).

Returns:

Scalar loss value for this batch.

Return type:

torch.FloatTensor
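The arity-based dispatch can be sketched as follows; the toy scorer and the loss call are placeholders, not dicee's implementation:

```python
import torch

def training_step(model, batch):
    """Handle both 2-element (x, y) and 3-element (x, y_select, y) batches."""
    if len(batch) == 2:            # KvsAll / NegSample
        x, y = batch
        yhat = model(x)
    else:                          # KvsSample
        x, y_select, y = batch
        yhat = model(x, y_select)  # score only the sampled targets
    return torch.nn.functional.binary_cross_entropy_with_logits(yhat, y)

# Toy scorer: 4 (head, relation) pairs scored against 5 entities.
toy = lambda x, y_idx=None: torch.zeros(x.shape[0], 5)
loss = training_step(toy, (torch.zeros(4, 2, dtype=torch.long), torch.ones(4, 5)))
```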

loss_function(yhat_batch: torch.FloatTensor, y_batch: torch.FloatTensor) torch.FloatTensor[source]

Compute the loss between model predictions and targets.

Delegates to self.loss which is configured in BaseKGE.__init__ based on the scoring technique (BCEWithLogitsLoss for entity/relation prediction, CrossEntropyLoss for classification).

Parameters:
  • yhat_batch (torch.FloatTensor) – Model output scores, shape (batch_size, *).

  • y_batch (torch.FloatTensor) – Ground-truth labels of the same shape as yhat_batch.

Returns:

Scalar loss value.

Return type:

torch.FloatTensor
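A minimal sketch of how the criterion could be selected from the scoring technique (the concrete mapping below is an assumption based on the description above):

```python
import torch

def make_loss(scoring_technique: str) -> torch.nn.Module:
    """Pick a criterion: BCE-with-logits for entity/relation prediction,
    cross-entropy for classification (illustrative mapping)."""
    if scoring_technique in ("NegSample", "KvsAll", "KvsSample"):
        return torch.nn.BCEWithLogitsLoss()
    return torch.nn.CrossEntropyLoss()

loss_fn = make_loss("KvsAll")
loss = loss_fn(torch.zeros(2, 3), torch.ones(2, 3))  # logits 0 vs. label 1
```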

on_train_epoch_end(*args, **kwargs)[source]

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

import lightning as L
import torch

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

test_epoch_end(outputs: List[Any])[source]
test_dataloader() None[source]

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

These hooks are called in the following order:

  • test()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

val_dataloader() None[source]

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

These hooks are called in the following order:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

predict_dataloader() None[source]

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

These hooks are called in the following order:

  • predict()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

train_dataloader() None[source]

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

These hooks are called in the following order:

  • fit()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

configure_optimizers(parameters=None)[source]

Instantiate and return the optimiser for training.

The optimiser type is taken from self.optimizer_name which is set in BaseKGE.init_params_with_sanity_checking() from the --optim argument. Supported values: 'SGD', 'Adam', 'Adopt', 'AdamW', 'NAdam', 'Adagrad', 'ASGD', 'Muon'.

Parameters:

parameters (iterable, optional) – Model parameters to optimise. Defaults to self.parameters() when None.

Returns:

The configured optimiser instance.

Return type:

torch.optim.Optimizer
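For the subset of the supported names that ship with torch.optim ('Adopt' and 'Muon' come from external implementations), the lookup can be sketched as:

```python
import torch

OPTIMIZERS = {
    "SGD": torch.optim.SGD,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "NAdam": torch.optim.NAdam,
    "Adagrad": torch.optim.Adagrad,
    "ASGD": torch.optim.ASGD,
}

def configure_optimizers(model, optimizer_name, lr, weight_decay=0.0):
    """Resolve the optimiser class by name and bind it to the model parameters."""
    try:
        cls = OPTIMIZERS[optimizer_name]
    except KeyError:
        raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")
    return cls(model.parameters(), lr=lr, weight_decay=weight_decay)

opt = configure_optimizers(torch.nn.Linear(4, 2), "Adam", lr=0.01)
```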

class dicee.models.base_model.BaseKGE(args: dict)[source]

Bases: BaseKGELightning

Base class for all Knowledge Graph Embedding models.

Inherits the Lightning training loop from BaseKGELightning and adds the embedding tables, normalisation / dropout layers, and the routing logic that dispatches forward() calls to the appropriate scoring method.

Sub-classes must implement at minimum forward_triples(); models trained with KvsAll or KvsSample scoring should additionally override forward_k_vs_all() and forward_k_vs_sample().

Parameters:

args (dict) – Flat configuration dictionary produced by vars(argparse.Namespace). Required keys: embedding_dim, num_entities, num_relations, learning_rate (or lr), optim, scoring_technique.
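A hypothetical minimal configuration containing the documented required keys (the values are illustrative only):

```python
# Any argparse-produced dict with these keys works; vars(argparse.Namespace)
# yields exactly this flat structure.
args = {
    "embedding_dim": 32,
    "num_entities": 135,
    "num_relations": 46,
    "learning_rate": 0.01,
    "optim": "Adam",
    "scoring_technique": "KvsAll",
}
```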

args
embedding_dim = None
num_entities = None
num_relations = None
num_tokens = None
learning_rate = None
apply_unit_norm = None
input_dropout_rate = None
hidden_dropout_rate = None
optimizer_name = None
feature_map_dropout_rate = None
kernel_size = None
num_of_output_channels = None
weight_decay = None
loss
selected_optimizer = None
normalizer_class = None
normalize_head_entity_embeddings
normalize_relation_embeddings
normalize_tail_entity_embeddings
hidden_normalizer
param_init
input_dp_ent_real
input_dp_rel_real
hidden_dropout
loss_history = []
byte_pair_encoding
max_length_subword_tokens
block_size
forward_byte_pair_encoded_k_vs_all(x: torch.LongTensor) torch.FloatTensor[source]

KvsAll scoring for BPE-encoded head entities and relations.

Retrieves subword-unit embeddings for the head entity and relation, reduces them to fixed-size vectors via a linear projection, then scores against all BPE entity embeddings.

Parameters:

x (torch.LongTensor) – Shape (batch_size, 2, T) BPE token indices where dim 1 indexes [head, relation] and T is max_length_subword_tokens.

Returns:

Shape (batch_size, num_bpe_entities) score matrix.

Return type:

torch.FloatTensor

forward_byte_pair_encoded_triple(x: Tuple[torch.LongTensor, torch.LongTensor]) torch.FloatTensor[source]

NegSample scoring for BPE-encoded (head, relation, tail) triples.

Retrieves subword-unit embeddings for all three elements and reduces them to fixed-size vectors via a linear projection before computing the triple score.

Parameters:

x (torch.LongTensor) – Shape (batch_size, 3, T) BPE token indices.

Returns:

Shape (batch_size,) triple scores.

Return type:

torch.FloatTensor

init_params_with_sanity_checking() None[source]

Populate model hyper-parameters from self.args with safe defaults.

Reads embedding dimension, learning rate, dropout rates, normalisation strategy, optimizer name, and parameter initialisation scheme from the args dict. Falls back to sensible defaults for any missing key so that minimal args dicts (e.g. for unit tests) are still valid.
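The fallback pattern can be sketched like this (the concrete default values are assumptions, not dicee's exact choices):

```python
def init_params_with_defaults(args: dict) -> dict:
    """Read each hyper-parameter, falling back to a safe default when absent."""
    return {
        "embedding_dim": args.get("embedding_dim", 1),
        "learning_rate": args.get("learning_rate", args.get("lr", 0.1)),
        "input_dropout_rate": args.get("input_dropout_rate", 0.0),
        "hidden_dropout_rate": args.get("hidden_dropout_rate", 0.0),
        "optimizer_name": args.get("optim", "Adam"),
    }

params = init_params_with_defaults({"embedding_dim": 64})  # minimal args dict
```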

forward(x: torch.LongTensor | Tuple[torch.LongTensor, torch.LongTensor], y_idx: torch.LongTensor = None) torch.FloatTensor[source]

Route the forward pass to the appropriate scoring method.

Inspects the shape and type of x to decide which low-level scorer to call: a (triple_idx, target_idx) tuple is dispatched to forward_k_vs_sample(), a (batch_size, 3) index tensor to forward_triples(), and a (batch_size, 2) index tensor to forward_k_vs_all(); byte-pair-encoded inputs are routed to the corresponding forward_byte_pair_encoded_* methods.

Parameters:
  • x (torch.LongTensor or Tuple[torch.LongTensor, torch.LongTensor]) – Either a plain index tensor or a (triple_idx, target_idx) tuple for sample-based labelling.

  • y_idx (torch.LongTensor, optional) – Target entity indices used by forward_k_vs_sample(). Ignored when x is a plain tensor.

Returns:

Score tensor whose shape depends on the selected scorer.

Return type:

torch.FloatTensor
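The routing logic can be sketched with a stand-in model that exposes the three low-level scorers; the entity count and shapes are illustrative:

```python
import torch

class DummyScorer:
    """Stand-in exposing the three low-level scorers the router dispatches to."""
    def forward_triples(self, x):
        return torch.zeros(x.shape[0])
    def forward_k_vs_all(self, x):
        return torch.zeros(x.shape[0], 7)  # 7 entities in this toy setting
    def forward_k_vs_sample(self, x, y_idx):
        return torch.zeros(x.shape[0], y_idx.shape[1])

def route_forward(model, x, y_idx=None):
    """Tuples go to KvsSample; (B, 3) tensors to triples; (B, 2) to KvsAll."""
    if isinstance(x, tuple):
        x, y_idx = x
        return model.forward_k_vs_sample(x, y_idx)
    if x.shape[1] == 3:
        return model.forward_triples(x)
    return model.forward_k_vs_all(x)

m = DummyScorer()
triple_scores = route_forward(m, torch.zeros(4, 3, dtype=torch.long))
kvsall_scores = route_forward(m, torch.zeros(4, 2, dtype=torch.long))
sample_scores = route_forward(m, (torch.zeros(4, 2, dtype=torch.long),
                                  torch.zeros(4, 3, dtype=torch.long)))
```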

forward_triples(x: torch.LongTensor) torch.Tensor[source]

Score a batch of (head, relation, tail) index triples.

Parameters:

x (torch.LongTensor) – Shape (batch_size, 3) integer tensor where each row is [head_idx, relation_idx, tail_idx].

Returns:

Shape (batch_size,) triple scores.

Return type:

torch.FloatTensor
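For example, a DistMult-style scorer (shown purely to illustrate the expected shapes, not as dicee's implementation) would compute:

```python
import torch

def distmult_forward_triples(E, R, x):
    """score(h, r, t) = <e_h, e_r, e_t>, the trilinear dot product of DistMult."""
    h, r, t = E[x[:, 0]], R[x[:, 1]], E[x[:, 2]]
    return (h * r * t).sum(dim=1)

E = torch.ones(5, 8)  # 5 entities, embedding_dim 8
R = torch.ones(4, 8)  # 4 relations
x = torch.tensor([[0, 1, 2], [3, 0, 4]])  # two (h, r, t) index triples
scores = distmult_forward_triples(E, R, x)  # all-ones embeddings -> each score is 8.0
```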

forward_k_vs_all(*args, **kwargs)[source]

Score a (head, relation) batch against every entity.

Sub-classes must override this method. The default implementation raises ValueError to make missing overrides obvious at runtime.

Returns:

Shape (batch_size, num_entities) score matrix.

Return type:

torch.FloatTensor
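Continuing the DistMult illustration, a typical override reduces KvsAll scoring to a single matrix product against the entity table:

```python
import torch

def distmult_forward_k_vs_all(E, R, x):
    """Score every entity as a tail: (h * r) times the transposed entity table."""
    h, r = E[x[:, 0]], R[x[:, 1]]
    return (h * r) @ E.t()  # (batch_size, num_entities)

E = torch.ones(5, 8)
R = torch.ones(4, 8)
scores = distmult_forward_k_vs_all(E, R, torch.tensor([[0, 1], [2, 3]]))
```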

forward_k_vs_sample(*args, **kwargs)[source]

Score a (head, relation) batch against a sampled subset of entities.

Used by KvsSample and 1vsSample datasets. Sub-classes that support sample-based labelling must override this method.

Returns:

Shape (batch_size, k) score matrix where k is the number of sampled target entities.

Return type:

torch.FloatTensor
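A DistMult-style override restricted to k sampled targets per row might look like this sketch:

```python
import torch

def distmult_forward_k_vs_sample(E, R, x, y_idx):
    """Score only the k sampled target entities per (head, relation) pair."""
    h, r = E[x[:, 0]], R[x[:, 1]]                 # (B, D)
    targets = E[y_idx]                            # (B, k, D)
    return torch.bmm(targets, (h * r).unsqueeze(-1)).squeeze(-1)  # (B, k)

E = torch.ones(5, 8)
R = torch.ones(4, 8)
y_idx = torch.tensor([[0, 2, 4], [1, 3, 0]])      # k = 3 sampled targets per row
scores = distmult_forward_k_vs_sample(E, R, torch.tensor([[0, 1], [2, 3]]), y_idx)
```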

get_triple_representation(idx_hrt) Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor][source]

Retrieve and normalise embedding vectors for a triple index batch.

Parameters:

idx_hrt (torch.LongTensor) – Shape (batch_size, 3) integer tensor with columns [head_idx, relation_idx, tail_idx].

Returns:

head_ent_emb, rel_ent_emb, tail_ent_emb – Each has shape (batch_size, embedding_dim) after applying the configured dropout and normalisation.

Return type:

torch.FloatTensor

get_head_relation_representation(indexed_triple) Tuple[torch.FloatTensor, torch.FloatTensor][source]

Retrieve and normalise embedding vectors for head entities and relations.

Parameters:

indexed_triple (torch.LongTensor) – Shape (batch_size, 2) integer tensor with columns [head_idx, relation_idx].

Returns:

head_ent_emb, rel_ent_emb – Each has shape (batch_size, embedding_dim) after applying the configured dropout and normalisation.

Return type:

torch.FloatTensor

get_sentence_representation(x: torch.LongTensor) Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor][source]

Retrieve BPE subword-unit embeddings for a batch of triples.

Parameters:

x (torch.LongTensor) – Shape (batch_size, 3, T) where T is max_length_subword_tokens.

Returns:

head_ent_emb, rel_emb, tail_emb – Each has shape (batch_size, T, embedding_dim).

Return type:

torch.FloatTensor

get_bpe_head_and_relation_representation(x: torch.LongTensor) Tuple[torch.FloatTensor, torch.FloatTensor][source]

Retrieve unit-normalised BPE embeddings for head entities and relations.

Each entity/relation is represented as a sequence of T subword tokens. Their token embeddings are L2-normalised jointly over the sequence and embedding dimensions so that each resulting (T, embedding_dim) matrix has unit Frobenius norm.

Parameters:

x (torch.LongTensor) – Shape (batch_size, 2, T) where dim 1 indexes [head, relation] and T is max_length_subword_tokens.

Returns:

head_ent_emb, rel_emb – Each has shape (batch_size, T, embedding_dim), L2-normalised over the (T, D) dimensions.

Return type:

torch.FloatTensor
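The unit-Frobenius-norm property described above can be reproduced by flattening each (T, D) block before L2-normalising, as in this sketch:

```python
import torch
import torch.nn.functional as F

def normalize_bpe_embeddings(emb: torch.Tensor) -> torch.Tensor:
    """Flatten (T, D) per example, L2-normalise, reshape back:
    each example then has unit Frobenius norm."""
    B, T, D = emb.shape
    flat = F.normalize(emb.reshape(B, T * D), p=2, dim=1)
    return flat.reshape(B, T, D)

torch.manual_seed(0)
emb = torch.randn(3, 4, 8)                   # batch of 3, T=4 tokens, D=8
out = normalize_bpe_embeddings(emb)
norms = out.reshape(3, -1).norm(dim=1)       # each entry is ~1.0
```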

get_embeddings() Tuple[numpy.ndarray, numpy.ndarray][source]

Return the entity and relation embedding matrices as numpy arrays.

Returns:

  • entity_embeddings (numpy.ndarray) – Shape (num_entities, embedding_dim).

  • relation_embeddings (numpy.ndarray) – Shape (num_relations, embedding_dim).
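The conversion typically detaches the embedding weights from the autograd graph before moving them to numpy; a minimal sketch:

```python
import torch

entity_emb = torch.nn.Embedding(10, 4)    # num_entities x embedding_dim
relation_emb = torch.nn.Embedding(3, 4)   # num_relations x embedding_dim

def get_embeddings(ent, rel):
    """Detach, move to CPU, and convert both weight matrices to numpy arrays."""
    return (ent.weight.detach().cpu().numpy(),
            rel.weight.detach().cpu().numpy())

E, R = get_embeddings(entity_emb, relation_emb)
```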

class dicee.models.base_model.IdentityClass(args=None)[source]

Bases: torch.nn.Module

No-op normalisation / dropout placeholder.

Used whenever no normalisation layer is requested (--normalization None). All inputs are returned unchanged so that the rest of the model code does not need conditional checks around normalisation calls.
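A minimal equivalent sketch of such a pass-through module:

```python
import torch

class IdentityClass(torch.nn.Module):
    """Pass-through stand-in for a normalisation or dropout layer."""
    def __init__(self, args=None):
        super().__init__()
        self.args = args

    def forward(self, x):
        return x  # inputs are returned unchanged

layer = IdentityClass()
x = torch.randn(2, 3)
out = layer(x)
```

Note that dicee implements forward as a static method; an instance method behaves identically for this purpose.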

args = None
__call__(x)[source]
static forward(x)[source]