Source code for dicee.models.real

from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
from torch import nn

from dicee.models.transformers import Block

from .base_model import BaseKGE


class DistMult(BaseKGE):
    """DistMult: bilinear diagonal knowledge graph embedding.

    Scores a triple ``(h, r, t)`` as the element-wise product of the head,
    relation, and tail embeddings summed over the embedding dimension::

        f(h, r, t) = \\sum_i h_i \\cdot r_i \\cdot t_i

    Simple yet effective baseline; incapable of modelling asymmetric relations.

    References
    ----------
    Yang et al., *Embedding Entities and Relations for Learning and Inference
    in Knowledge Bases*, ICLR 2015. https://arxiv.org/abs/1412.6575
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = 'DistMult'
    def k_vs_all_score(self, emb_h: torch.FloatTensor, emb_r: torch.FloatTensor,
                       emb_E: torch.FloatTensor) -> torch.FloatTensor:
        """Score a head/relation batch against all entity embeddings.

        Computes ``(h * r) @ E^T`` after applying hidden dropout and
        normalisation to the element-wise product.

        Parameters
        ----------
        emb_h : torch.FloatTensor
            Head entity embeddings, shape ``(batch_size, embedding_dim)``.
        emb_r : torch.FloatTensor
            Relation embeddings, shape ``(batch_size, embedding_dim)``.
        emb_E : torch.FloatTensor
            All entity embeddings, shape ``(num_entities, embedding_dim)``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size, num_entities)`` score matrix.
        """
        return torch.mm(self.hidden_dropout(self.hidden_normalizer(emb_h * emb_r)),
                        emb_E.transpose(1, 0))
    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """KvsAll forward pass: score head/relation against all entities.

        Parameters
        ----------
        x : torch.LongTensor
            Shape ``(batch_size, 2)`` integer tensor ``[head_idx, relation_idx]``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size, num_entities)`` score matrix.
        """
        emb_head, emb_rel = self.get_head_relation_representation(x)
        return self.k_vs_all_score(emb_h=emb_head, emb_r=emb_rel,
                                   emb_E=self.entity_embeddings.weight)
    def forward_k_vs_sample(self, x: torch.LongTensor,
                            target_entity_idx: torch.LongTensor) -> torch.FloatTensor:
        """KvsSample forward pass: score head/relation against a sampled entity subset.

        Parameters
        ----------
        x : torch.LongTensor
            Shape ``(batch_size, 2)`` integer tensor ``[head_idx, relation_idx]``.
        target_entity_idx : torch.LongTensor
            Shape ``(batch_size, k)`` indices of the *k* target entities per sample.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size, k)`` score matrix.
        """
        # (b, d), (b, d)
        emb_head_real, emb_rel_real = self.get_head_relation_representation(x)
        # (b, d)
        hr = torch.einsum('bd, bd -> bd', emb_head_real, emb_rel_real)
        # (b, k, d)
        t = self.entity_embeddings(target_entity_idx)
        return torch.einsum('bd, bkd -> bk', hr, t)
    def score(self, h: torch.FloatTensor, r: torch.FloatTensor,
              t: torch.FloatTensor) -> torch.FloatTensor:
        """Score a batch of ``(head, relation, tail)`` embedding triples.

        Parameters
        ----------
        h, r, t : torch.FloatTensor
            Each has shape ``(batch_size, embedding_dim)``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size,)`` triple scores.
        """
        return (self.hidden_dropout(self.hidden_normalizer(h * r)) * t).sum(dim=1)
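
# --- Illustrative sketch (not part of the dicee source) ---------------------
# A minimal, hypothetical helper (the name `_distmult_scoring_sketch` is not
# part of dicee) showing the DistMult formula on random tensors:
# f(h, r, t) = sum_i h_i * r_i * t_i for single triples, and (h * r) @ E^T for
# KvsAll scoring. BaseKGE's dropout and normalisation are deliberately omitted.
def _distmult_scoring_sketch():
    import torch
    batch, dim, num_entities = 4, 8, 10
    h, r, t = (torch.randn(batch, dim) for _ in range(3))
    E = torch.randn(num_entities, dim)                  # stand-in entity embedding matrix
    triple_scores = (h * r * t).sum(dim=1)              # shape (batch,)
    kvsall_scores = torch.mm(h * r, E.transpose(1, 0))  # shape (batch, num_entities)
    return triple_scores, kvsall_scores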
class TransE(BaseKGE):
    """TransE: translation-based knowledge graph embedding.

    Models a relation *r* as a translation in embedding space such that
    ``h + r ≈ t`` for a true triple ``(h, r, t)``. The score function is
    defined as::

        f(h, r, t) = margin - ||h + r - t||_2

    TransE is effective for 1-to-1 relations but struggles with reflexive,
    one-to-many, and many-to-one patterns.

    References
    ----------
    Bordes et al., *Translating Embeddings for Modeling Multi-relational Data*,
    NeurIPS 2013.
    https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = 'TransE'
        self._norm = 2
        self.margin = 4
    def score(self, head_ent_emb: torch.FloatTensor, rel_ent_emb: torch.FloatTensor,
              tail_ent_emb: torch.FloatTensor) -> torch.FloatTensor:
        """Score a batch of triples using the TransE margin-distance formula.

        Parameters
        ----------
        head_ent_emb, rel_ent_emb, tail_ent_emb : torch.FloatTensor
            Each has shape ``(batch_size, embedding_dim)``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size,)`` scores equal to ``margin - ||h + r - t||_2``.
        """
        # d := ||h + r - t||_2 is approximately 0 for a true triple.
        # The returned value is \gamma - d, so with a margin of 5:
        # sigmoid(5 - 0) => ~1 and sigmoid(5 - 5) => 0.5.
        return self.margin - torch.nn.functional.pairwise_distance(
            head_ent_emb + rel_ent_emb, tail_ent_emb, p=self._norm)
    def forward_k_vs_all(self, x: torch.Tensor) -> torch.FloatTensor:
        """KvsAll forward pass: score head/relation against all entities.

        Computes ``margin - ||h + r - e||_2`` for every entity embedding *e*.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, 2)`` integer tensor ``[head_idx, relation_idx]``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size, num_entities)`` score matrix.
        """
        emb_head_real, emb_rel_real = self.get_head_relation_representation(x)
        distance = torch.nn.functional.pairwise_distance(
            torch.unsqueeze(emb_head_real + emb_rel_real, 1),
            self.entity_embeddings.weight, p=self._norm)
        return self.margin - distance
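
# --- Illustrative sketch (not part of the dicee source) ---------------------
# A hypothetical helper (the name is illustrative only) showing the TransE
# margin-distance score on random tensors: margin - ||h + r - t||_2. The
# margin value 4.0 mirrors the constant set in TransE.__init__ above.
def _transe_scoring_sketch():
    import torch
    batch, dim, margin = 4, 8, 4.0
    h, r, t = (torch.randn(batch, dim) for _ in range(3))
    # L2 distance between the translated head (h + r) and the tail t
    return margin - torch.nn.functional.pairwise_distance(h + r, t, p=2)  # shape (batch,)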
class Shallom(BaseKGE):
    """Shallom: shallow neural model for relation prediction.

    Represents each triple as the concatenation of head and tail entity
    embeddings and feeds it through a two-layer MLP to predict the relation.
    Designed for the ``RelationPrediction`` labelling form.

    References
    ----------
    Demir et al., *A Shallow Neural Model for Relation Prediction*, ICSC 2021.
    https://arxiv.org/abs/2101.09090
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = 'Shallom'
        shallom_width = int(2 * self.embedding_dim)
        self.shallom = torch.nn.Sequential(
            torch.nn.Dropout(self.input_dropout_rate),
            torch.nn.Linear(self.embedding_dim * 2, shallom_width),
            self.normalizer_class(shallom_width),
            torch.nn.ReLU(),
            torch.nn.Dropout(self.hidden_dropout_rate),
            torch.nn.Linear(shallom_width, self.num_relations))
    def get_embeddings(self) -> Tuple[np.ndarray, None]:
        """Return the entity embedding matrix and ``None`` (Shallom learns no relation embeddings)."""
        return self.entity_embeddings.weight.data.detach(), None
    def forward_k_vs_all(self, x) -> torch.FloatTensor:
        """Score every relation for a batch of ``[head_idx, tail_idx]`` pairs."""
        e1_idx: torch.Tensor
        e2_idx: torch.Tensor
        e1_idx, e2_idx = x[:, 0], x[:, 1]
        emb_s, emb_o = self.entity_embeddings(e1_idx), self.entity_embeddings(e2_idx)
        return self.shallom(torch.cat((emb_s, emb_o), 1))
    def forward_triples(self, x) -> torch.FloatTensor:
        """Score a batch of triples by looking up relation scores from ``forward_k_vs_all``.

        Parameters
        ----------
        x : torch.LongTensor
            Shape ``(batch_size, 3)`` integer tensor ``[head_idx, relation_idx, tail_idx]``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size,)`` triple scores.
        """
        n, d = x.shape
        assert d == 3
        scores_for_all_relations = self.forward_k_vs_all(x[:, [0, 2]])
        # Select, for each sample, the score of its own relation index.
        return scores_for_all_relations[torch.arange(n, device=x.device), x[:, 1]]
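
# --- Illustrative sketch (not part of the dicee source) ---------------------
# A hypothetical stand-in for Shallom's relation scoring: concatenate head and
# tail embeddings and push them through a two-layer MLP whose output size is
# the number of relations. Dropout and normalisation from the real model are
# omitted, and all sizes below are made up.
def _shallom_scoring_sketch():
    import torch
    batch, dim, num_relations = 4, 8, 5
    emb_s, emb_o = torch.randn(batch, dim), torch.randn(batch, dim)
    mlp = torch.nn.Sequential(
        torch.nn.Linear(2 * dim, 2 * dim),  # hidden width = 2 * embedding_dim, as above
        torch.nn.ReLU(),
        torch.nn.Linear(2 * dim, num_relations),
    )
    return mlp(torch.cat((emb_s, emb_o), dim=1))  # shape (batch, num_relations)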
class Pyke(BaseKGE):
    """Pyke: Physical Embedding Model for Knowledge Graphs.

    Scores a triple ``(h, r, t)`` by the average of the head-to-relation and
    relation-to-tail pairwise distances in embedding space::

        f(h, r, t) = margin - (||h - r||_2 + ||r - t||_2) / 2

    The model encodes geometric proximity between entities and the relations
    that connect them.
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = 'Pyke'
        self.dist_func = torch.nn.PairwiseDistance(p=2)
        self.margin = 1.0
    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score a batch of triples using the Pyke distance formula.

        Parameters
        ----------
        x : torch.LongTensor
            Shape ``(batch_size, 3)`` integer tensor ``[head_idx, relation_idx, tail_idx]``.

        Returns
        -------
        torch.FloatTensor
            Shape ``(batch_size,)`` triple scores.
        """
        # (1) Get embeddings for a batch of entities and relations.
        head_ent_emb, rel_ent_emb, tail_ent_emb = self.get_triple_representation(x)
        # (2) Compute the Euclidean distances head-to-relation and relation-to-tail.
        dist_head_rel = self.dist_func(head_ent_emb, rel_ent_emb)
        dist_rel_tail = self.dist_func(rel_ent_emb, tail_ent_emb)
        # (3) Average the two distances and subtract from the margin.
        avg_dist = (dist_head_rel + dist_rel_tail) / 2
        return self.margin - avg_dist
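
# --- Illustrative sketch (not part of the dicee source) ---------------------
# A hypothetical helper showing Pyke's distance-based score on random tensors:
# margin - (||h - r||_2 + ||r - t||_2) / 2, with margin = 1.0 as set in
# Pyke.__init__ above.
def _pyke_scoring_sketch():
    import torch
    batch, dim, margin = 4, 8, 1.0
    h, r, t = (torch.randn(batch, dim) for _ in range(3))
    dist = torch.nn.PairwiseDistance(p=2)
    # Average of head-to-relation and relation-to-tail distances, subtracted from the margin
    return margin - (dist(h, r) + dist(r, t)) / 2  # shape (batch,)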
@dataclass
class CoKEConfig:
    """Configuration for the CoKE (Contextualized Knowledge Graph Embedding) model.

    Attributes:
        block_size: Sequence length for the transformer (3 for triples: head, relation, tail)
        vocab_size: Total vocabulary size (num_entities + num_relations)
        n_layer: Number of transformer layers
        n_head: Number of attention heads per layer
        n_embd: Embedding dimension (set to match the model's embedding_dim)
        dropout: Dropout rate applied throughout the model
        bias: Whether to use bias in linear layers
        causal: Whether to use causal masking (False for bidirectional attention)
    """
    block_size: int = 3     # triples -> TODO: LF: for multi-hop this needs to be bigger
    vocab_size: int = None  # must be set to num_entities + num_relations before initializing CoKE
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = None
    dropout: float = 0.3    # according to the paper, in [0.1, 0.5]
    bias: bool = True       # unclear whether False would work better
    causal: bool = False    # non-causal, so that the mask token can gather information
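
# --- Illustrative sketch (not part of the dicee source) ---------------------
# CoKEConfig is defined directly above, so constructing one is safe to show;
# the entity/relation counts below are hypothetical example numbers. Note that
# CoKE.__init__ overwrites vocab_size and n_embd from the knowledge graph and
# args, so in practice only the remaining fields need to be customised.
def _coke_config_sketch():
    return CoKEConfig(
        vocab_size=14_541 + 237,  # e.g. num_entities + num_relations (made-up counts)
        n_embd=256,
        n_layer=4,
        n_head=4,
        dropout=0.1,
    )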
class CoKE(BaseKGE):
    """Contextualized Knowledge Graph Embedding (CoKE) model.

    Based on: https://arxiv.org/pdf/1911.02168.

    CoKE uses a transformer encoder to learn contextualized representations of
    entities and relations. For link prediction, it predicts masked elements in
    (head, relation, tail) triples using bidirectional attention, similar to
    BERT's masked language modeling approach.

    The model creates a sequence ``[head_emb, relation_emb, mask_emb]``, adds
    positional embeddings, and processes it through transformer layers to
    predict the tail entity.
    """

    def __init__(self, args, config: CoKEConfig = CoKEConfig()):
        super().__init__(args)
        self.name = 'CoKE'
        # Configure model dimensions
        self.config = config
        self.config.vocab_size = self.num_entities + self.num_relations
        self.config.n_embd = self.embedding_dim
        # Positional and mask embeddings
        self.pos_emb = torch.nn.Embedding(config.block_size, self.embedding_dim)
        self.mask_emb = torch.nn.Parameter(torch.zeros(self.embedding_dim))
        # Transformer layers
        self.blocks = torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embedding_dim)
        self.coke_dropout = nn.Dropout(config.dropout)
    def forward_k_vs_all(self, x: torch.Tensor):
        """KvsAll forward pass: predict the masked tail against all entities."""
        device = x.device
        b = x.size(dim=0)
        # Get embeddings for head and relation
        head_emb, rel_emb = self.get_head_relation_representation(x)  # (b, dim), (b, dim)
        mask_emb = self.mask_emb.unsqueeze(0).expand(b, -1)           # (b, dim)
        # Create sequence: [head, relation, mask]
        seq = torch.stack([head_emb, rel_emb, mask_emb], dim=1)       # (b, 3, dim)
        # Add positional embeddings
        pos_ids = torch.arange(0, 3, device=device)  # (3,) -> TODO: LF: here 3 has to change according to block size (in case we want multi-hop)
        pos_ids = pos_ids.unsqueeze(0).expand(b, 3)  # (b, 3) TODO: LF: same as above
        pos_emb = self.pos_emb(pos_ids)              # (b, 3, dim)
        x_tok = seq + pos_emb                        # (b, 3, dim)
        # Pass through transformer layers
        for block in self.blocks:
            x_tok = block(x_tok)
        x_tok = self.ln_f(x_tok)
        # Extract the mask token's hidden state (position 2)
        h_mask = x_tok[:, 2, :]
        h_mask = self.coke_dropout(h_mask)
        # Score against all entity embeddings
        E = self.entity_embeddings.weight
        E = self.normalize_tail_entity_embeddings(E)
        scores = h_mask.mm(E.t())
        return scores
    def score(self, emb_h, emb_r, emb_t):
        """Score (head, relation, tail) embedding triples via the mask-token representation."""
        b = emb_h.size(0)
        device = emb_h.device
        # Create sequence with mask token
        mask_emb = self.mask_emb.unsqueeze(0).expand(b, -1)
        seq = torch.stack([emb_h, emb_r, mask_emb], dim=1)
        # Add positional embeddings
        pos_ids = torch.arange(0, 3, device=device).unsqueeze(0).expand(b, 3)
        pos_emb = self.pos_emb(pos_ids)
        x_tok = seq + pos_emb
        # Pass through transformer
        for block in self.blocks:
            x_tok = block(x_tok)
        x_tok = self.ln_f(x_tok)
        # Extract mask token hidden state
        h_mask = x_tok[:, 2, :]
        h_mask = self.coke_dropout(h_mask)
        # Compute similarity between mask representation and tail embedding
        score = torch.einsum('bd,bd -> b', h_mask, emb_t)
        return score
    def forward_k_vs_sample(self, x: torch.LongTensor, target_entity_idx: torch.LongTensor):
        """KvsSample forward pass: score the mask-token representation against *k* sampled entities."""
        emb_head, emb_rel = self.get_head_relation_representation(x)
        b = emb_head.size(0)
        emb_tail = self.entity_embeddings(target_entity_idx)  # (b, k, dim)
        device = emb_head.device
        # Create sequence with mask token
        mask_emb = self.mask_emb.unsqueeze(0).expand(b, -1)
        seq = torch.stack([emb_head, emb_rel, mask_emb], dim=1)
        # Add positional embeddings
        pos_ids = torch.arange(0, 3, device=device).unsqueeze(0).expand(b, 3)
        pos_emb = self.pos_emb(pos_ids)
        x_tok = seq + pos_emb
        # Pass through transformer
        for block in self.blocks:
            x_tok = block(x_tok)
        x_tok = self.ln_f(x_tok)
        # Extract mask token hidden state
        h_mask = x_tok[:, 2, :]
        h_mask = self.coke_dropout(h_mask)
        # Dot product between the mask representation and each of the k candidate tails
        # per sample; output shape (b, k), i.e. k scores per batch element.
        scores = torch.einsum('bd, bkd -> bk', h_mask, emb_tail)
        return scores
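
# --- Illustrative sketch (not part of the dicee source) ---------------------
# A simplified, hypothetical version of CoKE's masked scoring that omits the
# transformer blocks entirely: build the sequence [head, relation, MASK], add
# positional embeddings, take the hidden state at the mask position, and score
# it against all entity embeddings. All sizes are made up; the real model runs
# the sequence through Block layers and LayerNorm before scoring.
def _coke_masked_scoring_sketch():
    import torch
    batch, dim, num_entities = 4, 8, 10
    head_emb, rel_emb = torch.randn(batch, dim), torch.randn(batch, dim)
    mask_emb = torch.zeros(dim).unsqueeze(0).expand(batch, -1)   # shared MASK embedding
    pos_emb = torch.nn.Embedding(3, dim)                         # positions: head, relation, mask
    seq = torch.stack([head_emb, rel_emb, mask_emb], dim=1)      # (batch, 3, dim)
    seq = seq + pos_emb(torch.arange(3)).unsqueeze(0)            # add positional embeddings
    h_mask = seq[:, 2, :]                                        # hidden state at the mask position
    E = torch.randn(num_entities, dim)                           # stand-in entity embedding matrix
    return h_mask.mm(E.t())                                      # (batch, num_entities) scores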