dicee.models.base_model
Classes
- BaseKGELightning: Thin PyTorch Lightning wrapper shared by all KGE models.
- BaseKGE: Base class for all Knowledge Graph Embedding models.
- IdentityClass: No-op normalisation / dropout placeholder.
Module Contents
- class dicee.models.base_model.BaseKGELightning(*args, **kwargs)[source]
Bases: lightning.LightningModule
Thin PyTorch Lightning wrapper shared by all KGE models.
Provides the standard Lightning training loop hooks (training_step, on_train_epoch_end, configure_optimizers) as well as a helper for reporting model size. All concrete KGE models should extend BaseKGE rather than this class directly.
- training_step_outputs = []
- training_step(batch, batch_idx=None)[source]
Execute one optimisation step for the given mini-batch.
Handles the two- and three-element batches produced by the different dataset classes (KvsAll/NegSample vs. KvsSample).
- Parameters:
batch (tuple) – (x, y) for standard scoring, or (x, y_select, y) for sample-based labelling.
batch_idx (int, optional) – Index of the current batch (unused, kept for Lightning API compatibility).
- Returns:
Scalar loss value for this batch.
- Return type:
torch.FloatTensor
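The two batch layouts described above can be told apart by their length. A minimal sketch of such a step, assuming losses are cached in training_step_outputs; this is not the exact dicee implementation:

```python
# Sketch of a training_step handling both batch layouts described above.
def training_step(self, batch, batch_idx=None):
    if len(batch) == 2:                       # KvsAll / NegSample: (x, y)
        x_batch, y_batch = batch
        yhat_batch = self.forward(x_batch)
    else:                                     # KvsSample: (x, y_select, y)
        x_batch, y_select, y_batch = batch
        yhat_batch = self.forward(x_batch, y_idx=y_select)
    loss = self.loss_function(yhat_batch, y_batch)
    self.training_step_outputs.append(loss.detach())  # assumed caching of step losses
    return loss
```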
- loss_function(yhat_batch: torch.FloatTensor, y_batch: torch.FloatTensor) torch.FloatTensor[source]
Compute the loss between model predictions and targets.
Delegates to self.loss, which is configured in BaseKGE.__init__ based on the scoring technique (BCEWithLogitsLoss for entity/relation prediction, CrossEntropyLoss for classification).
- Parameters:
yhat_batch (torch.FloatTensor) – Model output scores, shape (batch_size, *).
y_batch (torch.FloatTensor) – Ground-truth labels of the same shape as yhat_batch.
- Returns:
Scalar loss value.
- Return type:
torch.FloatTensor
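As a rough illustration of that delegation, a hypothetical selection helper (not the exact construction in BaseKGE.__init__):

```python
import torch

# Hypothetical mapping from scoring technique to loss module;
# the real BaseKGE.__init__ may configure self.loss differently.
def select_loss(scoring_technique: str) -> torch.nn.Module:
    if scoring_technique in ("KvsAll", "NegSample", "KvsSample"):
        return torch.nn.BCEWithLogitsLoss()   # multi-label entity/relation targets
    return torch.nn.CrossEntropyLoss()        # single-class classification targets

loss_fn = select_loss("KvsAll")
loss = loss_fn(torch.randn(4, 10), torch.randint(0, 2, (4, 10)).float())
```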
- on_train_epoch_end(*args, **kwargs)[source]
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

```python
import torch
import lightning as L

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
```
- test_dataloader() None[source]
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
- test()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- val_dataloader() None[source]
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
- predict_dataloader() None[source]
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in prepare_data().
- predict()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- train_dataloader() None[source]
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
- fit()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
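The download/setup split recommended above follows the standard Lightning pattern. A generic sketch, not dicee-specific code; the dataset contents below are placeholders:

```python
import torch
import lightning as L
from torch.utils.data import DataLoader, TensorDataset

class MyKGEModule(L.LightningModule):
    def prepare_data(self):
        # runs once, on a single process: safe place to download / write raw files
        ...

    def setup(self, stage=None):
        # runs on every process: build the actual datasets from local files
        triples = torch.randint(0, 100, (1000, 3))          # placeholder triples
        self.train_set = TensorDataset(triples, torch.ones(1000))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=256, shuffle=True)
```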
- configure_optimizers(parameters=None)[source]
Instantiate and return the optimiser for training.
The optimiser type is taken from self.optimizer_name, which is set in BaseKGE.init_params_with_sanity_checking() from the --optim argument. Supported values: 'SGD', 'Adam', 'Adopt', 'AdamW', 'NAdam', 'Adagrad', 'ASGD', 'Muon'.
- Parameters:
parameters (iterable, optional) – Model parameters to optimise. Defaults to self.parameters() when None.
- Returns:
The configured optimiser instance.
- Return type:
torch.optim.Optimizer
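A rough sketch of such a name-to-optimiser dispatch. Only the built-in torch.optim classes are shown; 'Adopt' and 'Muon' are not part of torch.optim and are omitted, and this is not the exact dicee implementation:

```python
import torch

# Hypothetical dispatch from optimizer_name to a torch.optim class.
def build_optimizer(name, params, lr, weight_decay=0.0):
    torch_optims = {
        "SGD": torch.optim.SGD,
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
        "NAdam": torch.optim.NAdam,
        "Adagrad": torch.optim.Adagrad,
        "ASGD": torch.optim.ASGD,
    }
    if name not in torch_optims:
        raise ValueError(f"Unsupported optimizer: {name}")
    return torch_optims[name](params, lr=lr, weight_decay=weight_decay)

model = torch.nn.Linear(8, 1)
opt = build_optimizer("Adam", model.parameters(), lr=0.01)
```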
- class dicee.models.base_model.BaseKGE(args: dict)[source]
Bases: BaseKGELightning
Base class for all Knowledge Graph Embedding models.
Inherits the Lightning training loop from BaseKGELightning and adds the embedding tables, normalisation / dropout layers, and the routing logic that dispatches forward() calls to the appropriate scoring method. Sub-classes must implement at minimum (see the sketch after the parameter list):
- forward_triples() — score a batch of (h, r, t) triples.
- forward_k_vs_all() — score a (h, r) batch against every entity.
- Parameters:
args (dict) – Flat configuration dictionary produced by vars(argparse.Namespace). Required keys: embedding_dim, num_entities, num_relations, learning_rate (or lr), optim, scoring_technique.
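As referenced above, a minimal concrete sub-class might look like the following DistMult-style sketch. It assumes the parent class exposes nn.Embedding tables named entity_embeddings and relation_embeddings; the actual dicee attribute names and helpers may differ.

```python
import torch
from dicee.models.base_model import BaseKGE

class MyDistMult(BaseKGE):
    """Minimal sketch; assumes BaseKGE provides nn.Embedding tables
    named entity_embeddings / relation_embeddings."""

    def forward_triples(self, x: torch.LongTensor) -> torch.Tensor:
        h = self.entity_embeddings(x[:, 0])
        r = self.relation_embeddings(x[:, 1])
        t = self.entity_embeddings(x[:, 2])
        return (h * r * t).sum(dim=1)             # (batch_size,)

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.Tensor:
        h = self.entity_embeddings(x[:, 0])
        r = self.relation_embeddings(x[:, 1])
        all_t = self.entity_embeddings.weight     # (num_entities, d)
        return (h * r) @ all_t.t()                # (batch_size, num_entities)
```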
- args
- embedding_dim = None
- num_entities = None
- num_relations = None
- num_tokens = None
- learning_rate = None
- apply_unit_norm = None
- input_dropout_rate = None
- optimizer_name = None
- feature_map_dropout_rate = None
- kernel_size = None
- num_of_output_channels = None
- weight_decay = None
- loss
- selected_optimizer = None
- normalizer_class = None
- normalize_head_entity_embeddings
- normalize_relation_embeddings
- normalize_tail_entity_embeddings
- param_init
- input_dp_ent_real
- input_dp_rel_real
- loss_history = []
- byte_pair_encoding
- max_length_subword_tokens
- block_size
- forward_byte_pair_encoded_k_vs_all(x: torch.LongTensor) torch.FloatTensor[source]
KvsAll scoring for BPE-encoded head entities and relations.
Retrieves subword-unit embeddings for the head entity and relation, reduces them to fixed-size vectors via a linear projection, then scores against all BPE entity embeddings.
- Parameters:
x (torch.LongTensor) – Shape (batch_size, 2, T) BPE token indices, where dim 1 indexes [head, relation] and T is max_length_subword_tokens.
- Returns:
Shape (batch_size, num_bpe_entities) score matrix.
- Return type:
torch.FloatTensor
- forward_byte_pair_encoded_triple(x: Tuple[torch.LongTensor, torch.LongTensor]) torch.FloatTensor[source]
NegSample scoring for BPE-encoded (head, relation, tail) triples.
Retrieves subword-unit embeddings for all three elements and reduces them to fixed-size vectors via a linear projection before computing the triple score.
- Parameters:
x (torch.LongTensor) – Shape (batch_size, 3, T) BPE token indices.
- Returns:
Shape (batch_size,) triple scores.
- Return type:
torch.FloatTensor
- init_params_with_sanity_checking() None[source]
Populate model hyper-parameters from self.args with safe defaults.
Reads embedding dimension, learning rate, dropout rates, normalisation strategy, optimizer name, and parameter initialisation scheme from the args dict. Falls back to sensible defaults for any missing key so that minimal args dicts (e.g. for unit tests) are still valid.
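For example, a minimal args dict that satisfies the required keys listed for BaseKGE (values are illustrative only):

```python
# Illustrative minimal configuration; key names follow the required
# keys listed for BaseKGE, values are placeholders.
args = {
    "embedding_dim": 32,
    "num_entities": 135,
    "num_relations": 46,
    "learning_rate": 0.01,        # or "lr"
    "optim": "Adam",
    "scoring_technique": "KvsAll",
}
```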
- forward(x: torch.LongTensor | Tuple[torch.LongTensor, torch.LongTensor], y_idx: torch.LongTensor = None) torch.FloatTensor[source]
Route the forward pass to the appropriate scoring method.
Inspects the shape and type of x to decide which low-level scorer to call (see the sketch after this entry):
- Tuple (x, y_idx) → forward_k_vs_sample()
- (batch, 3) tensor → forward_triples()
- (batch, 2) tensor → forward_k_vs_all()
- BPE triple tensor → forward_byte_pair_encoded_triple()
- BPE pair tensor → forward_byte_pair_encoded_k_vs_all()
- Parameters:
x (torch.LongTensor or Tuple[torch.LongTensor, torch.LongTensor]) – Either a plain index tensor or a (triple_idx, target_idx) tuple for sample-based labelling.
y_idx (torch.LongTensor, optional) – Target entity indices used by forward_k_vs_sample(). Ignored when x is a plain tensor.
- Returns:
Score tensor whose shape depends on the selected scorer.
- Return type:
torch.FloatTensor
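The dispatch listed above can be pictured roughly as follows. This is a sketch of the decision logic under the assumption that BPE-encoded inputs arrive as 3-dimensional tensors; it is not the exact dicee implementation:

```python
# Rough sketch of the routing performed by forward(); not the exact code.
def forward(self, x, y_idx=None):
    if isinstance(x, tuple):                  # (triple_idx, target_idx)
        x, y_idx = x
        return self.forward_k_vs_sample(x, y_idx)
    if x.dim() == 3:                          # assumed BPE input: (batch, 2 or 3, T)
        if x.shape[1] == 3:
            return self.forward_byte_pair_encoded_triple(x)
        return self.forward_byte_pair_encoded_k_vs_all(x)
    if x.shape[1] == 3:                       # (batch, 3): plain triples
        return self.forward_triples(x)
    return self.forward_k_vs_all(x)           # (batch, 2): KvsAll
```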
- forward_triples(x: torch.LongTensor) torch.Tensor[source]
Score a batch of (head, relation, tail) index triples.
- Parameters:
x (torch.LongTensor) – Shape (batch_size, 3) integer tensor where each row is [head_idx, relation_idx, tail_idx].
- Returns:
Shape (batch_size,) triple scores.
- Return type:
torch.FloatTensor
- forward_k_vs_all(*args, **kwargs)[source]
Score a (head, relation) batch against every entity.
Sub-classes must override this method. The default implementation raises ValueError to make missing overrides obvious at runtime.
- Returns:
Shape (batch_size, num_entities) score matrix.
- Return type:
torch.FloatTensor
- forward_k_vs_sample(*args, **kwargs)[source]
Score a (head, relation) batch against a sampled subset of entities.
Used by KvsSample and 1vsSample datasets. Sub-classes that support sample-based labelling must override this method.
- Returns:
Shape (batch_size, k) score matrix where k is the number of sampled target entities.
- Return type:
torch.FloatTensor
- get_triple_representation(idx_hrt) Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor][source]
Retrieve and normalise embedding vectors for a triple index batch.
- Parameters:
idx_hrt (torch.LongTensor) – Shape (batch_size, 3) integer tensor with columns [head_idx, relation_idx, tail_idx].
- Returns:
head_ent_emb, rel_ent_emb, tail_ent_emb – Each has shape (batch_size, embedding_dim) after applying the configured dropout and normalisation.
- Return type:
torch.FloatTensor
- get_head_relation_representation(indexed_triple) Tuple[torch.FloatTensor, torch.FloatTensor][source]
Retrieve and normalise embedding vectors for head entities and relations.
- Parameters:
indexed_triple (torch.LongTensor) – Shape (batch_size, 2) integer tensor with columns [head_idx, relation_idx].
- Returns:
head_ent_emb, rel_ent_emb – Each has shape (batch_size, embedding_dim) after applying the configured dropout and normalisation.
- Return type:
torch.FloatTensor
- get_sentence_representation(x: torch.LongTensor) Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor][source]
Retrieve BPE subword-unit embeddings for a batch of triples.
- Parameters:
x (torch.LongTensor) – Shape (batch_size, 3, T) where T is max_length_subword_tokens.
- Returns:
head_ent_emb, rel_emb, tail_emb – Each has shape (batch_size, T, embedding_dim).
- Return type:
torch.FloatTensor
- get_bpe_head_and_relation_representation(x: torch.LongTensor) Tuple[torch.FloatTensor, torch.FloatTensor][source]
Retrieve unit-normalised BPE embeddings for head entities and relations.
Each entity/relation is represented as a sequence of T subword tokens. Their token embeddings are L2-normalised across the sequence dimension so that the resulting matrix has unit Frobenius norm.
- Parameters:
x (torch.LongTensor) – Shape (batch_size, 2, T) where dim 1 indexes [head, relation] and T is max_length_subword_tokens.
- Returns:
head_ent_emb, rel_emb – Each has shape (batch_size, T, embedding_dim), L2-normalised over the (T, D) dimensions.
- Return type:
torch.FloatTensor
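As a small illustration of the normalisation described above, assuming torch.nn.functional.normalize applied over the flattened (T, D) block; this mirrors the description rather than the exact dicee code:

```python
import torch
import torch.nn.functional as F

batch_size, T, d = 4, 6, 8
head_tokens = torch.randn(batch_size, T, d)    # raw subword-token embeddings

# L2-normalise each (T, d) block so it has unit Frobenius norm
flat = head_tokens.reshape(batch_size, -1)
head_norm = F.normalize(flat, p=2, dim=1).reshape(batch_size, T, d)

print(head_norm.reshape(batch_size, -1).norm(dim=1))  # ~1.0 for every sample
```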
- class dicee.models.base_model.IdentityClass(args=None)[source]
Bases: torch.nn.Module
No-op normalisation / dropout placeholder.
Used whenever no normalisation layer is requested (--normalization None). All inputs are returned unchanged so that the rest of the model code does not need conditional checks around normalisation calls.
- args = None
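Conceptually the class behaves like the following pass-through module; a sketch, not the verbatim dicee implementation:

```python
import torch

class Identity(torch.nn.Module):
    """Pass-through stand-in used where a normalisation layer is optional."""
    def __init__(self, args=None):
        super().__init__()
        self.args = args

    def forward(self, x):
        return x          # inputs are returned unchanged

x = torch.randn(2, 3)
assert torch.equal(Identity()(x), x)
```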