# Source code for dicee.models.dualE

import torch
from .base_model import BaseKGE


class DualE(BaseKGE):
    """Dual Quaternion Knowledge Graph Embeddings (DualE).

    Paper: https://ojs.aaai.org/index.php/AAAI/article/download/16850/16657

    Every entity and relation embedding row of length ``embedding_dim`` is
    split into eight equal chunks with ``torch.hsplit``, so ``embedding_dim``
    must be divisible by 8 (``torch.hsplit`` raises otherwise). Chunks 1-4
    hold the real quaternion part and chunks 5-8 the dual quaternion part;
    since all arithmetic is element-wise, every position inside the chunks is
    an independent dual quaternion.
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = 'DualE'
        self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
        self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
        # Alias kept for backward compatibility; nothing in this class reads it.
        self.num_ent = self.num_entities

    def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3,
               c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3):
        """Compute the dual-quaternion Hamilton product (a + eps*b)(c + eps*d).

        Because eps**2 = 0 the product equals ``a*c + eps*(a*d + b*c)``, where
        ``*`` is the (non-commutative) quaternion Hamilton product. All inputs
        are combined element-wise, so they must be broadcast-compatible.

        Args:
            a_0, a_1, a_2, a_3: Real quaternion part of the left operand.
            b_0, b_1, b_2, b_3: Dual quaternion part of the left operand.
            c_0, c_1, c_2, c_3: Real quaternion part of the right operand.
            d_0, d_1, d_2, d_3: Dual quaternion part of the right operand.

        Returns:
            ``(h_0, h_1, h_2, h_3, h1_0, h1_1, h1_2, h1_3)``: the four
            components of ``a*c`` followed by the four components of
            ``a*d + b*c``.
        """
        # Real part: a * c.
        h_0 = a_0 * c_0 - a_1 * c_1 - a_2 * c_2 - a_3 * c_3
        h_1 = a_0 * c_1 + a_1 * c_0 + a_2 * c_3 - a_3 * c_2
        h_2 = a_0 * c_2 - a_1 * c_3 + a_2 * c_0 + a_3 * c_1
        h_3 = a_0 * c_3 + a_1 * c_2 - a_2 * c_1 + a_3 * c_0
        # Dual part: a * d + b * c, with the two products' terms interleaved.
        h1_0 = (a_0 * d_0 + b_0 * c_0 - a_1 * d_1 - b_1 * c_1
                - a_2 * d_2 - b_2 * c_2 - a_3 * d_3 - b_3 * c_3)
        h1_1 = (a_0 * d_1 + b_0 * c_1 + a_1 * d_0 + b_1 * c_0
                + a_2 * d_3 + b_2 * c_3 - a_3 * d_2 - b_3 * c_2)
        h1_2 = (a_0 * d_2 + b_0 * c_2 - a_1 * d_3 - b_1 * c_3
                + a_2 * d_0 + b_2 * c_0 + a_3 * d_1 + b_3 * c_1)
        h1_3 = (a_0 * d_3 + b_0 * c_3 + a_1 * d_2 + b_1 * c_2
                - a_2 * d_1 - b_2 * c_1 + a_3 * d_0 + b_3 * c_0)
        return (h_0, h_1, h_2, h_3, h1_0, h1_1, h1_2, h1_3)

    def _onorm(self, r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8):
        r"""Normalize a relation dual quaternion :math:`W_r = (c, d)`.

        With :math:`c = (r_1, r_2, r_3, r_4)` and :math:`d = (r_5, r_6, r_7, r_8)`:

        .. math::

            \bar{d} = d - \frac{\langle d, c \rangle}{\langle c, c \rangle} c

            c' = \frac{c}{\|c\|}
               = \frac{c_0 + c_1 i + c_2 j + c_3 k}{\sqrt{c_0^2 + c_1^2 + c_2^2 + c_3^2}}

        i.e. ``d`` is made orthogonal to ``c`` (it is not rescaled) and ``c``
        is scaled to unit norm, element-wise over the chunk positions.

        Returns:
            :math:`W_r' = (c', \bar{d})` as eight tensors, in the same order
            as the inputs.
        """
        # <c, c>, clamped so an all-zero real part gives c' = 0 and d_bar = d
        # instead of 0/0 = NaN; identical to the unclamped value whenever
        # <c, c> >= 1e-12.
        c_sq_norm = torch.clamp(r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2, min=1e-12)
        c_norm = torch.sqrt(c_sq_norm)
        # Projection coefficient <d, c> / <c, c> of d onto c.
        proj = (r_5 * r_1 + r_6 * r_2 + r_7 * r_3 + r_8 * r_4) / c_sq_norm
        r_5 = r_5 - proj * r_1
        r_6 = r_6 - proj * r_2
        r_7 = r_7 - proj * r_3
        r_8 = r_8 - proj * r_4
        r_1 = r_1 / c_norm
        r_2 = r_2 / c_norm
        r_3 = r_3 / c_norm
        r_4 = r_4 / c_norm
        return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8

    def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
              e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
              r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8) -> torch.Tensor:
        r"""Score aligned (head, relation, tail) rows (ref. Eq. 8 of the paper).

        .. math::

            \phi(h, r, t) = \langle a'_h, a_t \rangle + \langle b'_h, b_t \rangle
                          + \langle c'_h, c_t \rangle + \langle d'_h, d_t \rangle

        The relation is normalized with :meth:`_onorm`, multiplied onto the
        head with :meth:`_omult`, and the eight resulting components are
        multiplied element-wise with the eight tail components and summed.

        Args:
            e_1_h..e_8_h: Head chunks (real part 1-4, dual part 5-8).
            e_1_t..e_8_t: Tail chunks, same shape as the head chunks.
            r_1..r_8: Relation chunks, same shape as the head chunks.

        Returns:
            The negated inner product, summed over the last dimension (one
            value per row).
        """
        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(
            r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
        o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(
            e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
            r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
        score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t
                   + o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)
        # NOTE(review): the score is negated, matching kvsall_score; the sign
        # presumably matches the trainer's loss convention — confirm.
        return -torch.sum(score_r, -1)

    def kvsall_score(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
                     e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
                     r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8) -> torch.Tensor:
        """KvsAll scoring: score every (head, relation) row against every tail.

        Args:
            e_1_h..e_8_h: Head chunks, 2-D, presumably (n, embedding_dim / 8)
                — TODO confirm.
            e_1_t..e_8_t: Tail chunks already transposed, e.g.
                (embedding_dim / 8, num_entities) as built by
                ``forward_k_vs_all``.
            r_1..r_8: Relation chunks aligned with the head chunks.

        Returns:
            Negated score matrix with one row per head/relation row and one
            column per tail, e.g. (n, num_entities).
        """
        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(
            r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
        head_rel = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
                               r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
        tails = (e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t)
        # Sum of eight matrix products, accumulated left to right.
        score_r = sum(torch.mm(o, t) for o, t in zip(head_rel, tails))
        return -score_r

    def forward_triples(self, idx_triple: torch.Tensor) -> torch.Tensor:
        """Negative-sampling forward pass: score individual triples.

        Args:
            idx_triple: torch.LongTensor of triple indices, presumably (n, 3)
                as (head, relation, tail) — TODO confirm against
                ``BaseKGE.get_triple_representation``.

        Returns:
            torch.FloatTensor with one negated DualE score per triple,
            presumably shape (n,).
        """
        head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple)
        e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)
        e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8)
        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8)
        return self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
                          e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
                          r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)

    def forward_k_vs_all(self, x):
        """KvsAll forward pass: score each (head, relation) row against all entities.

        Args:
            x: torch.LongTensor of (head, relation) indices, presumably (n, 2)
                — TODO confirm against
                ``BaseKGE.get_head_relation_representation``.

        Returns:
            torch.FloatTensor of negated scores, presumably (n, num_entities).
        """
        # (1) Retrieve head/relation embeddings (presumably with dropout and
        # normalization applied by the BaseKGE helper — not visible here).
        head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)
        e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)
        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_ent_emb, 8)
        # (2) Every entity is a candidate tail: split the raw embedding table
        # and transpose each chunk to (embedding_dim / 8, num_entities) for
        # torch.mm. NOTE(review): tails bypass the helper used for heads —
        # confirm that is intended.
        e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = (
            self.T(chunk) for chunk in torch.hsplit(self.entity_embeddings.weight, 8))
        return self.kvsall_score(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
                                 e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
                                 r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)

    def T(self, x: torch.Tensor) -> torch.Tensor:
        """Transpose a 2-D tensor: shape (n x m) -> (m x n)."""
        return x.transpose(1, 0)