Source code for dicee.models.octonion

import torch
from .base_model import BaseKGE, IdentityClass


def octonion_mul(*, O_1, O_2):
    """Multiply two octonions given as 8-tuples of components.

    Both arguments are keyword-only and must unpack into exactly eight
    components ``(e0, e1, ..., e7)``, where ``e0`` is the real part. The
    components only need to support ``*``, ``+`` and ``-`` (e.g. tensors or
    plain numbers).

    Returns:
        A tuple with the eight components of the product ``O_1 * O_2``.
    """
    x0, x1, x2, x3, x4, x5, x6, x7 = O_1
    y0, y1, y2, y3, y4, y5, y6, y7 = O_2
    left = (x0, x1, x2, x3, x4, x5, x6, x7)
    right = (y0, y1, y2, y3, y4, y5, y6, y7)

    # One row per output component. In every row the k-th term is
    # ``left[k] * right[idx[k]]`` and carries the k-th sign; the first term is
    # always positive. Terms are accumulated strictly left to right.
    table = (
        ((0, 1, 2, 3, 4, 5, 6, 7), "+-------"),
        ((1, 0, 3, 2, 5, 4, 7, 6), "+++-+--+"),
        ((2, 3, 0, 1, 6, 7, 4, 5), "+-++++--"),
        ((3, 2, 1, 0, 7, 6, 5, 4), "++-++-+-"),
        ((4, 5, 6, 7, 0, 1, 2, 3), "+---++++"),
        ((5, 4, 7, 6, 1, 0, 3, 2), "++-+-+-+"),
        ((6, 7, 4, 5, 2, 3, 0, 1), "+++--++-"),
        ((7, 6, 5, 4, 3, 2, 1, 0), "+-++--++"),
    )

    product = []
    for idx, signs in table:
        acc = left[0] * right[idx[0]]
        for lhs, j, sign in zip(left[1:], idx[1:], signs[1:]):
            term = lhs * right[j]
            acc = acc + term if sign == "+" else acc - term
        product.append(acc)
    return tuple(product)
def octonion_mul_norm(*, O_1, O_2, eps=1e-12):
    """Multiply octonion ``O_1`` by the unit-normalised octonion ``O_2``.

    ``O_2`` is divided component-wise by its Euclidean norm before the
    product is taken, which removes the scaling effect of the relation.

    Args:
        O_1: 8-tuple of tensors ``(e0, ..., e7)``; left operand.
        O_2: 8-tuple of tensors ``(e0, ..., e7)``; right operand, normalised
            before the multiplication.
        eps: Lower bound applied to the norm of ``O_2``. It prevents the
            0 / 0 = NaN that an all-zero ``O_2`` would otherwise produce.

    Returns:
        A tuple with the eight components of ``O_1 * (O_2 / |O_2|)``.
    """
    x0, x1, x2, x3, x4, x5, x6, x7 = O_1
    y0, y1, y2, y3, y4, y5, y6, y7 = O_2

    squared_norm = (y0 ** 2 + y1 ** 2 + y2 ** 2 + y3 ** 2
                    + y4 ** 2 + y5 ** 2 + y6 ** 2 + y7 ** 2)
    # Clamp *before* the square root: this bounds the denominator away from
    # zero and also keeps the gradient of sqrt finite at a zero-norm input.
    denominator = torch.sqrt(torch.clamp(squared_norm, min=eps * eps))
    y0, y1, y2, y3, y4, y5, y6, y7 = (
        y / denominator for y in (y0, y1, y2, y3, y4, y5, y6, y7))

    x = x0 * y0 - x1 * y1 - x2 * y2 - x3 * y3 - x4 * y4 - x5 * y5 - x6 * y6 - x7 * y7
    e1 = x0 * y1 + x1 * y0 + x2 * y3 - x3 * y2 + x4 * y5 - x5 * y4 - x6 * y7 + x7 * y6
    e2 = x0 * y2 - x1 * y3 + x2 * y0 + x3 * y1 + x4 * y6 + x5 * y7 - x6 * y4 - x7 * y5
    e3 = x0 * y3 + x1 * y2 - x2 * y1 + x3 * y0 + x4 * y7 - x5 * y6 + x6 * y5 - x7 * y4
    e4 = x0 * y4 - x1 * y5 - x2 * y6 - x3 * y7 + x4 * y0 + x5 * y1 + x6 * y2 + x7 * y3
    e5 = x0 * y5 + x1 * y4 - x2 * y7 + x3 * y6 - x4 * y1 + x5 * y0 - x6 * y3 + x7 * y2
    e6 = x0 * y6 + x1 * y7 + x2 * y4 - x3 * y5 - x4 * y2 + x5 * y3 + x6 * y0 - x7 * y1
    e7 = x0 * y7 - x1 * y6 + x2 * y5 + x3 * y4 - x4 * y3 - x5 * y2 + x6 * y1 + x7 * y0
    return x, e1, e2, e3, e4, e5, e6, e7
class OMult(BaseKGE):
    """Knowledge graph embedding model that scores triples with octonion multiplication.

    Every embedding row is split column-wise into eight equal parts that are
    treated as the components ``e0, ..., e7`` of an octonion. The score of
    ``(h, r, t)`` is the inner product between the octonion product ``h * r``
    and ``t``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = 'OMult'

    @staticmethod
    def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3,
                            emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7, eps=1e-12):
        """Scale the eight octonion components to unit Euclidean norm.

        Args:
            emb_rel_e0 .. emb_rel_e7: The eight component tensors.
            eps: Lower bound applied to the norm, so that an all-zero octonion
                maps to zeros instead of NaN (0 / 0).

        Returns:
            A tuple with the eight normalised components.
        """
        parts = (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3,
                 emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7)
        squared_norm = parts[0] ** 2
        for part in parts[1:]:
            squared_norm = squared_norm + part ** 2
        # Clamp before the square root so both the division and the gradient
        # of sqrt stay finite for a zero-norm input.
        denominator = torch.sqrt(torch.clamp(squared_norm, min=eps * eps))
        return tuple(part / denominator for part in parts)

    def _relation_components(self, rel_ent_emb):
        """Split a relation embedding into eight components, unit-normalised when applicable."""
        parts = torch.hsplit(rel_ent_emb, 8)
        # The octonion normalisation is applied only while
        # ``normalize_relation_embeddings`` is an IdentityClass instance
        # (presumably: no other relation normaliser is configured - confirm
        # against BaseKGE).
        if isinstance(self.normalize_relation_embeddings, IdentityClass):
            parts = self.octonion_normalizer(*parts)
        return parts

    def score(self, head_ent_emb: torch.FloatTensor, rel_ent_emb: torch.FloatTensor,
              tail_ent_emb: torch.FloatTensor):
        """Score aligned (head, relation, tail) embedding rows.

        Each argument is split column-wise into eight octonion components.

        Returns:
            One score per row: the inner product of ``head * relation``
            (octonion product) with ``tail``, summed over dim 1.
        """
        head = torch.hsplit(head_ent_emb, 8)
        rel = self._relation_components(rel_ent_emb)
        tail = torch.hsplit(tail_ent_emb, 8)
        # Octonion multiplication of heads and relations.
        hr = octonion_mul(O_1=head, O_2=rel)
        # Component-wise inner products with the tails, then summed.
        partial = [(component * t).sum(dim=1) for component, t in zip(hr, tail)]
        return sum(partial[1:], partial[0])

    def k_vs_all_score(self, bpe_head_ent_emb, bpe_rel_ent_emb, E):
        """Score every (head, relation) row against every row of ``E``.

        Args:
            bpe_head_ent_emb: Head embeddings, one row per query.
            bpe_rel_ent_emb: Relation embeddings, one row per query.
            E: Candidate tail embeddings, one row per candidate.

        Returns:
            A (number of queries, number of rows in ``E``) score matrix.
        """
        head = torch.hsplit(bpe_head_ent_emb, 8)
        rel = self._relation_components(bpe_rel_ent_emb)
        # Octonion multiplication of heads and relations.
        hr = octonion_mul(O_1=head, O_2=rel)
        # Matrix product of each component with the matching (transposed)
        # component of all candidate tails.
        partial = [torch.mm(component, t.transpose(1, 0))
                   for component, t in zip(hr, torch.hsplit(E, 8))]
        return sum(partial[1:], partial[0])

    def forward_k_vs_all(self, x):
        r"""Score all entities as tails for each (head, relation) pair in ``x``.

        Given a head entity and a relation (h,r), we compute scores for all possible triples,i.e.,
        [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|)
        Given a batch of head entities and relations => shape (size of batch,| Entities|)
        """
        # Retrieve embeddings (dropout and normalisation are applied inside the helper).
        head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)
        return self.k_vs_all_score(head_ent_emb, rel_ent_emb, self.entity_embeddings.weight)
class ConvO(BaseKGE):
    """Convolutional octonion knowledge graph embedding model.

    The octonion product ``h * r`` is multiplied element-wise by the output of
    a convolution over the stacked head and relation components before the
    inner product with the tail is taken.
    """

    def __init__(self, args: dict):
        super().__init__(args=args)
        self.name = 'ConvO'
        # Convolution over the stacked (head, relation) components.
        self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_of_output_channels,
                                      kernel_size=(self.kernel_size, self.kernel_size),
                                      stride=1, padding=1, bias=True)
        # NOTE: residual_convolution feeds a (batch, 1, 16, embedding_dim // 8)
        # tensor to conv2d. This size matches the flattened conv output only
        # when the convolution keeps that spatial shape, e.g. kernel_size=3
        # with padding=1.
        self.fc_num_input = self.embedding_dim * 2 * self.num_of_output_channels
        self.fc1 = torch.nn.Linear(self.fc_num_input, self.embedding_dim)  # Hard compression.
        self.bn_conv2d = torch.nn.BatchNorm2d(self.num_of_output_channels)
        self.norm_fc1 = self.normalizer_class(self.embedding_dim)
        self.feature_map_dropout = torch.nn.Dropout2d(self.feature_map_dropout_rate)

    @staticmethod
    def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3,
                            emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7, eps=1e-12):
        """Scale the eight octonion components to unit Euclidean norm.

        Args:
            emb_rel_e0 .. emb_rel_e7: The eight component tensors.
            eps: Lower bound applied to the norm, so that an all-zero octonion
                maps to zeros instead of NaN (0 / 0).

        Returns:
            A tuple with the eight normalised components.
        """
        parts = (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3,
                 emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7)
        squared_norm = parts[0] ** 2
        for part in parts[1:]:
            squared_norm = squared_norm + part ** 2
        # Clamp before the square root so both the division and the gradient
        # of sqrt stay finite for a zero-norm input.
        denominator = torch.sqrt(torch.clamp(squared_norm, min=eps * eps))
        return tuple(part / denominator for part in parts)

    def _relation_components(self, rel_ent_emb):
        """Split a relation embedding into eight components, unit-normalised when applicable."""
        parts = torch.hsplit(rel_ent_emb, 8)
        # The octonion normalisation is applied only while
        # ``normalize_relation_embeddings`` is an IdentityClass instance
        # (presumably: no other relation normaliser is configured - confirm
        # against BaseKGE).
        if isinstance(self.normalize_relation_embeddings, IdentityClass):
            parts = self.octonion_normalizer(*parts)
        return parts

    def residual_convolution(self, O_1, O_2):
        """Convolve the stacked components of two octonions.

        Args:
            O_1: Eight entity component tensors.
            O_2: Eight relation component tensors.

        Returns:
            The result of ``torch.chunk(..., 8, dim=1)`` on the output of the
            fully connected layer, i.e. the eight components of the
            convolution output.
        """
        width = self.embedding_dim // 8
        # Entity components first, then relation components, stacked along dim 2.
        x = torch.cat([part.view(-1, 1, 1, width) for part in (*O_1, *O_2)], 2)
        x = torch.nn.functional.relu(self.bn_conv2d(self.conv2d(x)))
        x = self.feature_map_dropout(x)
        x = x.view(x.shape[0], -1)  # Flatten for the linear layer.
        x = torch.nn.functional.relu(self.norm_fc1(self.fc1(x)))
        return torch.chunk(x, 8, dim=1)

    def forward_triples(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of triples.

        Returns:
            One score per triple: sum over components of
            ``(conv * (h * r) * t).sum(dim=1)``, where ``h * r`` is the
            octonion product and the other products are element-wise.
        """
        # (1) Retrieve embeddings (dropout and normalisation are applied inside the helper).
        head_ent_emb, rel_ent_emb, tail_ent_emb = self.get_triple_representation(x)
        # (2) Split every embedding into its eight octonion components.
        head = torch.hsplit(head_ent_emb, 8)
        rel = self._relation_components(rel_ent_emb)
        tail = torch.hsplit(tail_ent_emb, 8)
        # (3) Convolution over heads and relations.
        conv = self.residual_convolution(O_1=head, O_2=rel)
        # (4) Octonion multiplication of heads and relations.
        hr = octonion_mul(O_1=head, O_2=rel)
        # (5) Inner product with the tails.
        partial = [(c * component * t).sum(dim=1) for c, component, t in zip(conv, hr, tail)]
        return sum(partial[1:], partial[0])

    def forward_k_vs_all(self, x: torch.Tensor):
        r"""Score all entities as tails for each (head, relation) pair in ``x``.

        Given a head entity and a relation (h,r), we compute scores for all entities.
        [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|)
        Given a batch of head entities and relations => shape (size of batch,| Entities|)
        """
        # (1) Retrieve embeddings (dropout and normalisation are applied inside the helper).
        head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)
        # (2) Split every embedding into its eight octonion components.
        head = torch.hsplit(head_ent_emb, 8)
        rel = self._relation_components(rel_ent_emb)
        # (3) Convolution over heads and relations.
        conv = self.residual_convolution(O_1=head, O_2=rel)
        # (4) Octonion multiplication of heads and relations.
        hr = octonion_mul(O_1=head, O_2=rel)
        # (5) Inner product with the (transposed) components of all entities.
        tails = torch.hsplit(self.entity_embeddings.weight, 8)
        partial = [torch.mm(c * component, t.transpose(1, 0))
                   for c, component, t in zip(conv, hr, tails)]
        return sum(partial[1:], partial[0])
class AConvO(BaseKGE):
    """Additive Convolutional Octonion Knowledge Graph Embeddings.

    The output of a convolution over the stacked head and relation components
    is added to the octonion product ``h * r`` before the inner product with
    the tail is taken.
    """

    def __init__(self, args: dict):
        super().__init__(args=args)
        self.name = 'AConvO'
        # Convolution over the stacked (head, relation) components.
        self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_of_output_channels,
                                      kernel_size=(self.kernel_size, self.kernel_size),
                                      stride=1, padding=1, bias=True)
        # NOTE: residual_convolution feeds a (batch, 1, 16, embedding_dim // 8)
        # tensor to conv2d. This size matches the flattened conv output only
        # when the convolution keeps that spatial shape, e.g. kernel_size=3
        # with padding=1.
        self.fc_num_input = self.embedding_dim * 2 * self.num_of_output_channels
        self.fc1 = torch.nn.Linear(self.fc_num_input, self.embedding_dim)  # Hard compression.
        self.bn_conv2d = torch.nn.BatchNorm2d(self.num_of_output_channels)
        self.norm_fc1 = self.normalizer_class(self.embedding_dim)
        self.feature_map_dropout = torch.nn.Dropout2d(self.feature_map_dropout_rate)

    @staticmethod
    def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3,
                            emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7, eps=1e-12):
        """Scale the eight octonion components to unit Euclidean norm.

        Args:
            emb_rel_e0 .. emb_rel_e7: The eight component tensors.
            eps: Lower bound applied to the norm, so that an all-zero octonion
                maps to zeros instead of NaN (0 / 0).

        Returns:
            A tuple with the eight normalised components.
        """
        parts = (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3,
                 emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7)
        squared_norm = parts[0] ** 2
        for part in parts[1:]:
            squared_norm = squared_norm + part ** 2
        # Clamp before the square root so both the division and the gradient
        # of sqrt stay finite for a zero-norm input.
        denominator = torch.sqrt(torch.clamp(squared_norm, min=eps * eps))
        return tuple(part / denominator for part in parts)

    def _relation_components(self, rel_ent_emb):
        """Split a relation embedding into eight components, unit-normalised when applicable."""
        parts = torch.hsplit(rel_ent_emb, 8)
        # The octonion normalisation is applied only while
        # ``normalize_relation_embeddings`` is an IdentityClass instance
        # (presumably: no other relation normaliser is configured - confirm
        # against BaseKGE).
        if isinstance(self.normalize_relation_embeddings, IdentityClass):
            parts = self.octonion_normalizer(*parts)
        return parts

    def residual_convolution(self, O_1, O_2):
        """Convolve the stacked components of two octonions.

        Args:
            O_1: Eight entity component tensors.
            O_2: Eight relation component tensors.

        Returns:
            The result of ``torch.chunk(..., 8, dim=1)`` on the output of the
            fully connected layer, i.e. the eight components of the
            convolution output.
        """
        width = self.embedding_dim // 8
        # Entity components first, then relation components, stacked along dim 2.
        x = torch.cat([part.view(-1, 1, 1, width) for part in (*O_1, *O_2)], 2)
        x = torch.nn.functional.relu(self.bn_conv2d(self.conv2d(x)))
        x = self.feature_map_dropout(x)
        x = x.view(x.shape[0], -1)  # Flatten for the linear layer.
        x = torch.nn.functional.relu(self.norm_fc1(self.fc1(x)))
        return torch.chunk(x, 8, dim=1)

    def forward_triples(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of triples.

        Returns:
            One score per triple: sum over components of
            ``((conv + h * r) * t).sum(dim=1)``, where ``h * r`` is the
            octonion product.
        """
        # (1) Retrieve embeddings (dropout and normalisation are applied inside the helper).
        head_ent_emb, rel_ent_emb, tail_ent_emb = self.get_triple_representation(x)
        # (2) Split every embedding into its eight octonion components.
        head = torch.hsplit(head_ent_emb, 8)
        rel = self._relation_components(rel_ent_emb)
        tail = torch.hsplit(tail_ent_emb, 8)
        # (3) Convolution over heads and relations.
        conv = self.residual_convolution(O_1=head, O_2=rel)
        # (4) Octonion multiplication of heads and relations.
        hr = octonion_mul(O_1=head, O_2=rel)
        # (5) Inner product with the tails. The convolution output is added to
        # the product *before* multiplying with the tail, so that this method
        # computes the same score as forward_k_vs_all (mm(conv + h * r, t)).
        partial = [((c + component) * t).sum(dim=1) for c, component, t in zip(conv, hr, tail)]
        return sum(partial[1:], partial[0])

    def forward_k_vs_all(self, x: torch.Tensor):
        r"""Score all entities as tails for each (head, relation) pair in ``x``.

        Given a head entity and a relation (h,r), we compute scores for all entities.
        [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|)
        Given a batch of head entities and relations => shape (size of batch,| Entities|)
        """
        # (1) Retrieve embeddings (dropout and normalisation are applied inside the helper).
        head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)
        # (2) Split every embedding into its eight octonion components.
        head = torch.hsplit(head_ent_emb, 8)
        rel = self._relation_components(rel_ent_emb)
        # (3) Convolution over heads and relations.
        conv = self.residual_convolution(O_1=head, O_2=rel)
        # (4) Octonion multiplication of heads and relations.
        hr = octonion_mul(O_1=head, O_2=rel)
        # (5) Inner product with the (transposed) components of all entities.
        tails = torch.hsplit(self.entity_embeddings.weight, 8)
        partial = [torch.mm(c + component, t.transpose(1, 0))
                   for c, component, t in zip(conv, hr, tails)]
        return sum(partial[1:], partial[0])