Source code for dicee.models.pykeen_models

import traceback
from collections import namedtuple

import torch
import torch.utils.data

from .base_model import BaseKGE

class PykeenKGE(BaseKGE):
    """A class for using knowledge graph embedding models implemented in Pykeen.

    Notes:
        Supported wrappers include Pykeen_DistMult, Pykeen_ComplEx, Pykeen_QuatE,
        Pykeen_MuRE, Pykeen_CP, and Pykeen_HolE.
    """

    def __init__(self, args: dict):
        super().__init__(args)
        self.model_kwargs = {'embedding_dim': args['embedding_dim'],
                             'entity_initializer': None if args['init_param'] is None else torch.nn.init.xavier_normal_,
                             "random_seed": args["random_seed"]}
        self.model_kwargs.update(args['pykeen_model_kwargs'])
        self.name = args['model'].split("_")[1]
        # Work around the memory issue of Pykeen models caused by their regularizers.
        # See https://github.com/pykeen/pykeen/issues/1297
        if self.name == "MuRE":
            # No regularizer => no memory leakage.
            # https://pykeen.readthedocs.io/en/stable/api/pykeen.models.MuRE.html
            pass
        elif self.name == "QuatE":
            self.model_kwargs["entity_regularizer"] = None
            self.model_kwargs["relation_regularizer"] = None
        elif self.name == "DistMult":
            self.model_kwargs["regularizer"] = None
        elif self.name == "BoxE":
            pass
        elif self.name == "CP":
            # No regularizers.
            pass
        elif self.name == "HolE":
            # No regularizers.
            pass
        elif self.name == "ProjE":
            # Nothing to disable.
            pass
        elif self.name == "RotatE":
            pass
        elif self.name == "TransE":
            self.model_kwargs["regularizer"] = None
        else:
            print("Pykeen models have a memory leak caused by their implementation of regularizers.")
            print(f"{self.name} does not seem to have any regularizer.")
        try:
            # Lazy import: Pykeen is only needed when a Pykeen model is selected.
            from pykeen.models import model_resolver
        except ImportError:
            print(traceback.format_exc())
            print("Pykeen does not work with pytorch>2.0.0. Current pytorch version:", torch.__version__)
            exit(1)
        # Only num_entities, num_relations and create_inverse_triples are read from
        # the triples factory here, so a namedtuple stands in for a full TriplesFactory.
        self.model = model_resolver.make(
            self.name, self.model_kwargs,
            triples_factory=namedtuple('triples_factory',
                                       ['num_entities', 'num_relations', 'create_inverse_triples'])(
                self.num_entities, self.num_relations, False))
        self.loss_history = []
        self.args = args
        self.entity_embeddings = None
        self.relation_embeddings = None
        # Extract the embedding tables and the interaction function from the resolved Pykeen model.
        for (k, v) in self.model.named_modules():
            if "entity_representations" == k:
                self.entity_embeddings = v[0]._embeddings
            elif "relation_representations" == k:
                self.relation_embeddings = v[0]._embeddings
            elif "interaction" == k:
                self.interaction = v
            else:
                pass
    def forward_k_vs_all(self, x: torch.LongTensor):
        """
        # => Explicit version; with this we can apply batch normalization and dropout.
        # (1) Retrieve embeddings of heads and relations + apply Dropout & Normalization if given.
        h, r = self.get_head_relation_representation(x)
        # (2) Reshape (1).
        if self.last_dim > 0:
            h = h.reshape(len(x), self.embedding_dim, self.last_dim)
            r = r.reshape(len(x), self.embedding_dim, self.last_dim)
        # (3) Reshape all entities.
        if self.last_dim > 0:
            t = self.entity_embeddings.weight.reshape(self.num_entities, self.embedding_dim, self.last_dim)
        else:
            t = self.entity_embeddings.weight
        # (4) Call score_t from the interaction function to generate triple scores.
        return self.interaction.score_t(h=h, r=r, all_entities=t, slice_size=1)
        """
        return self.model.score_t(x)
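    # Shape note (an assumption based on Pykeen's Model.score_t API, not stated in
    # this module): x holds (head, relation) index pairs of shape (n, 2), and the
    # returned score matrix has shape (n, num_entities), one column per candidate tail.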
    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        """
        # => Explicit version; with this we can apply batch normalization and dropout.
        # (1) Retrieve embeddings of heads, relations and tails and apply Dropout & Normalization if given.
        h, r, t = self.get_triple_representation(x)
        # (2) Reshape (1).
        if self.last_dim > 0:
            h = h.reshape(len(x), self.embedding_dim, self.last_dim)
            r = r.reshape(len(x), self.embedding_dim, self.last_dim)
            t = t.reshape(len(x), self.embedding_dim, self.last_dim)
        # (3) Compute the triple score.
        return self.interaction.score(h=h, r=r, t=t, slice_size=None, slice_dim=0)
        """
        return self.model.score_hrt(hrt_batch=x, mode=None).flatten()
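    # Shape note (an assumption based on Pykeen's Model.score_hrt API): x holds
    # (head, relation, tail) index triples of shape (n, 3); score_hrt returns a
    # column vector of shape (n, 1), flattened here to (n,).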
    def forward_k_vs_sample(self, x: torch.LongTensor, target_entity_idx):
        raise NotImplementedError(f"KvsSample has not yet been implemented for {self.name}")
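A minimal usage sketch of this wrapper, assuming that BaseKGE consumes only the argument keys referenced above plus num_entities and num_relations; the real dicee configuration dictionary may require further keys, and all values below are illustrative only.

import torch
from dicee.models.pykeen_models import PykeenKGE

args = {"model": "Pykeen_DistMult",   # "Pykeen_<Name>" selects the Pykeen class
        "embedding_dim": 32,
        "init_param": None,           # None keeps Pykeen's default initializer
        "random_seed": 1,
        "pykeen_model_kwargs": {},
        "num_entities": 100,          # assumption: read by BaseKGE
        "num_relations": 10}          # assumption: read by BaseKGE
model = PykeenKGE(args)

# Score two (head, relation, tail) index triples.
scores = model.forward_triples(torch.LongTensor([[0, 0, 1], [2, 3, 4]]))  # shape (2,)

# Score every entity as a tail for two (head, relation) pairs.
all_scores = model.forward_k_vs_all(torch.LongTensor([[0, 0], [2, 3]]))   # shape (2, 100)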