from typing import List, Any, Tuple, Union, Dict
import lightning as pl
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from .adopt import ADOPT
class BaseKGELightning(pl.LightningModule):
    """Base LightningModule shared by all KGE models.

    Accumulates per-step training losses, exposes a model-size report and
    the epoch-end bookkeeping hook. Dataloader hooks are intentionally
    no-ops: data loading is handled outside the LightningModule.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Scalar loss of every training step in the current epoch;
        # averaged and flushed in on_train_epoch_end().
        self.training_step_outputs = []

    def mem_of_model(self) -> Dict:
        """Return the estimated size of the model in MB and its number of parameters.

        Returns
        -------
        Dict with keys 'EstimatedSizeMB' and 'NumParam'.
        """
        # https://discuss.pytorch.org/t/finding-model-size/130275/2
        num_params = sum(p.numel() for p in self.parameters())
        # BUGFIX: the previous code added the raw parameter COUNT to the
        # buffer BYTE count. Size must be bytes for both: numel * element_size.
        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in self.buffers())
        return {'EstimatedSizeMB': (param_size + buffer_size) / 1024 ** 2, 'NumParam': num_params}

    def training_step(self, batch, batch_idx=None):
        """Compute and log the loss of a single mini-batch.

        Parameters
        ----------
        batch: either ``(x, y)`` for the default setting or
            ``(x, y_select, y)`` for KvsSample / 1vsSample scoring.
        batch_idx: unused; present for the Lightning API.

        Returns
        -------
        The scalar loss tensor of this batch.
        """
        if len(batch) == 2:
            # Default
            x_batch, y_batch = batch
            yhat_batch = self.forward(x_batch)
        elif len(batch) == 3:
            # KvsSample or 1vsSample: y_select carries sampled entity indices.
            x_batch, y_select, y_batch = batch
            yhat_batch = self.forward((x_batch, y_select))
        else:
            raise RuntimeError("Invalid batch received.")
        loss_batch = self.loss_function(yhat_batch, y_batch)
        self.training_step_outputs.append(loss_batch.item())
        self.log("loss",
                 value=loss_batch,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 sync_dist=True,
                 logger=False)
        return loss_batch

    def loss_function(self, yhat_batch: torch.FloatTensor, y_batch: torch.FloatTensor):
        """Apply the configured loss (``self.loss``, set by subclasses) to predictions.

        Parameters
        ----------
        yhat_batch: predicted scores/logits.
        y_batch: target labels matching ``yhat_batch``.

        Returns
        -------
        The loss tensor.
        """
        return self.loss(yhat_batch, y_batch)

    def on_train_epoch_end(self, *args, **kwargs):
        """Average the per-step losses into ``self.loss_history`` and reset the buffer.

        Raises
        ------
        RuntimeError if Lightning passes any positional or keyword arguments
        (this hook expects none).
        """
        # BUGFIX: the error messages previously said "must not be empty",
        # the opposite of the condition being enforced.
        if len(args) >= 1:
            raise RuntimeError(f"Arguments must be empty:{args}")
        if len(kwargs) >= 1:
            raise RuntimeError(f"Keyword Arguments must be empty:{kwargs}")
        # NOTE(review): self.loss_history is created in subclasses
        # (see BaseKGE.__init__), not in this base class.
        self.loss_history.append(sum(self.training_step_outputs) / len(self.training_step_outputs))
        self.training_step_outputs.clear()

    def test_epoch_end(self, outputs: List[Any]):
        """No-op; testing is handled outside Lightning's loop."""

    def test_dataloader(self) -> None:
        pass

    def val_dataloader(self) -> None:
        pass

    def predict_dataloader(self) -> None:
        pass

    def train_dataloader(self) -> None:
        pass
class BaseKGE(BaseKGELightning):
    """Base class for knowledge graph embedding models.

    Parses hyperparameters from ``args``, constructs embedding tables
    (entity/relation tables, or byte-pair-encoded subword token tables) and
    provides the shared forward dispatch for triple, KvsAll, KvsSample and
    BPE-based scoring.
    """

    def __init__(self, args: dict):
        super().__init__()
        self.args = args
        # Hyperparameters; filled in by init_params_with_sanity_checking().
        self.embedding_dim = None
        self.num_entities = None
        self.num_relations = None
        self.num_tokens = None
        self.learning_rate = None
        self.apply_unit_norm = None
        self.input_dropout_rate = None
        self.hidden_dropout_rate = None
        self.optimizer_name = None
        self.feature_map_dropout_rate = None
        self.kernel_size = None
        self.num_of_output_channels = None
        self.weight_decay = None
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.selected_optimizer = None
        self.normalizer_class = None
        # No-op defaults; replaced with LayerNorm/BatchNorm1d when requested.
        self.normalize_head_entity_embeddings = IdentityClass()
        self.normalize_relation_embeddings = IdentityClass()
        self.normalize_tail_entity_embeddings = IdentityClass()
        self.hidden_normalizer = IdentityClass()
        self.param_init = IdentityClass
        self.init_params_with_sanity_checking()
        # Dropouts
        self.input_dp_ent_real = torch.nn.Dropout(self.input_dropout_rate)
        self.input_dp_rel_real = torch.nn.Dropout(self.input_dropout_rate)
        # BUGFIX: hidden dropout previously reused input_dropout_rate,
        # silently ignoring the hidden_dropout_rate hyperparameter.
        self.hidden_dropout = torch.nn.Dropout(self.hidden_dropout_rate)
        # average minibatch loss per epoch
        self.loss_history = []
        self.byte_pair_encoding = self.args.get("byte_pair_encoding", False)
        self.max_length_subword_tokens = self.args.get("max_length_subword_tokens", None)
        self.block_size = self.args.get("block_size", None)
        if self.byte_pair_encoding and self.args['model'] != "BytE":
            self.token_embeddings = torch.nn.Embedding(self.num_tokens, self.embedding_dim)
            self.param_init(self.token_embeddings.weight.data)
            # Reduces a subword-unit embedding matrix of shape T x D into D.
            self.lf = nn.Sequential(nn.Linear(self.embedding_dim * self.max_length_subword_tokens,
                                              self.embedding_dim, bias=False))
            if self.args["scoring_technique"] in ["AllvsAll", "KvsAll"]:
                # ordered_bpe_entities items are (str_entity, bpe_entity, shaped_bpe_entity).
                self.str_to_bpe_entity_to_idx = {str_ent: idx for idx, (str_ent, bpe_ent, shaped_bpe_ent) in
                                                 enumerate(self.args["ordered_bpe_entities"])}
                self.bpe_entity_to_idx = {shaped_bpe_ent: idx for idx, (str_ent, bpe_ent, shaped_bpe_ent) in
                                          enumerate(self.args["ordered_bpe_entities"])}
                self.ordered_bpe_entities = torch.tensor(list(self.bpe_entity_to_idx.keys()), dtype=torch.long)
        elif self.byte_pair_encoding and self.args['model'] == "BytE":
            # The BytE transformer implements its own token embeddings.
            pass
        else:
            self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
            self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
            self.param_init(self.entity_embeddings.weight.data)
            self.param_init(self.relation_embeddings.weight.data)

    def forward_byte_pair_encoded_k_vs_all(self, x: torch.LongTensor):
        """Score a (head, relation) BPE batch against all entities.

        Parameters
        ----------
        x : B x 2 x T tensor of subword-token indices
            (B batch size, T max subword length).

        Returns
        -------
        KvsAll scores of shape B x E.
        """
        # (1) Get unit normalized subword units embedding matrices: (B, T, D)
        bpe_head_ent_emb, bpe_rel_ent_emb = self.get_bpe_head_and_relation_representation(x)
        # Future work: Use attention to model similarity between subword units comprising head and relation
        # attentive_head_rel_emb = self.attention_block(torch.cat((bpe_head_ent_emb, bpe_rel_ent_emb), 1))
        # bpe_head_ent_emb = attentive_head_rel_emb[:, :self.max_length_subword_tokens, :]
        # bpe_rel_ent_emb = attentive_head_rel_emb[:, self.max_length_subword_tokens:, :]
        # (2) Reshaping (1) into row vectors.
        B, T, D = bpe_head_ent_emb.shape
        # Multi-node GPU setting: move the all-entity token index tensor to the
        # same device as the batch (get_device() returns -1 for CPU tensors).
        device_r = bpe_head_ent_emb.get_device()
        if device_r >= 0:
            self.ordered_bpe_entities = self.ordered_bpe_entities.to(device_r)
        else:
            self.ordered_bpe_entities = self.ordered_bpe_entities.to("cpu")
        # (3) Get unit normalized subword units embedding matrices of all entities : (E, T, D)
        E = self.token_embeddings(self.ordered_bpe_entities)
        # (4) Reshaping (3) into row vectors (E, T*D).
        E = E.reshape(len(E), T * D)
        # (5) Reshape and Reduce from (_, T*D) into row vectors.
        bpe_head_ent_emb = self.input_dp_ent_real(bpe_head_ent_emb.reshape(B, T * D))
        bpe_rel_ent_emb = self.input_dp_rel_real(bpe_rel_ent_emb.reshape(B, T * D))
        bpe_head_ent_emb = self.lf(bpe_head_ent_emb)
        bpe_rel_ent_emb = self.lf(bpe_rel_ent_emb)
        E = self.lf(E)
        return self.k_vs_all_score(bpe_head_ent_emb, bpe_rel_ent_emb, E)

    def forward_byte_pair_encoded_triple(self, x: Tuple[torch.LongTensor, torch.LongTensor]):
        """Score byte-pair-encoded (head, relation, tail) triples.

        Parameters
        ----------
        x : B x 3 x T tensor of subword-token indices.

        Returns
        -------
        Triple scores of shape B.
        """
        bpe_head_ent_emb, bpe_rel_ent_emb, bpe_tail_ent_emb = self.get_sentence_representation(x)
        B, T, C = bpe_head_ent_emb.shape
        # Flatten the T x C subword matrices into row vectors, then reduce to D via self.lf.
        bpe_head_ent_emb = bpe_head_ent_emb.reshape(B, T * C)
        bpe_rel_ent_emb = bpe_rel_ent_emb.reshape(B, T * C)
        bpe_tail_ent_emb = bpe_tail_ent_emb.reshape(B, T * C)
        bpe_triple_score = self.score(self.lf(bpe_head_ent_emb), self.lf(bpe_rel_ent_emb), self.lf(bpe_tail_ent_emb))
        return bpe_triple_score

    def init_params_with_sanity_checking(self):
        """Populate hyperparameter attributes from ``self.args`` with defaults.

        Raises
        ------
        NotImplementedError for an unknown 'normalization' value.
        """
        self.weight_decay = self.args['weight_decay'] if self.args.get('weight_decay') else 0.0
        self.embedding_dim = self.args['embedding_dim'] if self.args.get('embedding_dim') else 1
        self.num_entities = self.args.get('num_entities', None)
        self.num_relations = self.args.get('num_relations', None)
        self.num_tokens = self.args.get('num_tokens', None)
        self.learning_rate = self.args['learning_rate'] if self.args.get('learning_rate') else .1
        self.input_dropout_rate = self.args['input_dropout_rate'] if self.args.get("input_dropout_rate") else 0.0
        self.hidden_dropout_rate = self.args['hidden_dropout_rate'] if self.args.get("hidden_dropout_rate") else 0.0
        # Convolutional models need kernel/channel/feature-map settings.
        if self.args.get("model") in ['ConvQ', 'ConvO', 'ConEx', 'AConEx', 'AConvQ', 'AConvO']:
            self.kernel_size = self.args['kernel_size'] if self.args.get("kernel_size") else 3
            self.num_of_output_channels = (self.args['num_of_output_channels']
                                           if self.args.get("num_of_output_channels") else 3)
            self.feature_map_dropout_rate = (self.args['feature_map_dropout_rate']
                                             if self.args.get("feature_map_dropout_rate") else 0.0)
        if self.args.get("normalization") == 'LayerNorm':
            self.normalizer_class = torch.nn.LayerNorm
            self.normalize_head_entity_embeddings = self.normalizer_class(self.embedding_dim)
            self.normalize_relation_embeddings = self.normalizer_class(self.embedding_dim)
            if self.args['scoring_technique'] in ['NegSample', 'KvsSample']:
                self.normalize_tail_entity_embeddings = self.normalizer_class(self.embedding_dim)
        elif self.args.get("normalization") == 'BatchNorm1d':
            self.normalizer_class = torch.nn.BatchNorm1d
            self.normalize_head_entity_embeddings = self.normalizer_class(self.embedding_dim, affine=False)
            self.normalize_relation_embeddings = self.normalizer_class(self.embedding_dim, affine=False)
            if self.args['scoring_technique'] in ['NegSample', 'KvsSample']:
                self.normalize_tail_entity_embeddings = self.normalizer_class(self.embedding_dim, affine=False)
        elif self.args.get("normalization") is None:
            self.normalizer_class = IdentityClass
        else:
            raise NotImplementedError()
        self.optimizer_name = self.args.get('optim', None)
        if self.args.get("init_param") is None:
            self.param_init = IdentityClass
        elif self.args['init_param'] == 'xavier_normal':
            self.param_init = torch.nn.init.xavier_normal_
        else:
            print(f'--init_param (***{self.args.get("init_param")}***) not found')
            # BUGFIX: previously assigned IdentityClass to self.optimizer_name,
            # clobbering the optimizer choice instead of the initializer.
            self.param_init = IdentityClass

    def forward(self, x: Union[torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]],
                y_idx: torch.LongTensor = None):
        """Dispatch to the appropriate scoring function based on input shape.

        Parameters
        ----------
        x : index tensor, or a (x, y_idx) tuple for KvsSample scoring.
            2-D inputs are entity/relation index batches (B x 3 triples or
            B x 2 KvsAll); 3-D inputs are byte-pair-encoded token batches.
        y_idx : target entity indices for KvsSample scoring.

        Returns
        -------
        Scores produced by the selected scoring function.
        """
        if isinstance(x, tuple):
            x, y_idx = x
            return self.forward_k_vs_sample(x=x, target_entity_idx=y_idx)
        else:
            shape_info = x.shape
            if len(shape_info) == 2:
                batch_size, dim = x.shape
                if dim == 3:
                    return self.forward_triples(x)
                elif dim == 2:
                    # h, y = x[0], x[1]
                    # Note that y can be relation or tail entity.
                    return self.forward_k_vs_all(x=x)
            else:
                size_of_input_data = shape_info[1]
                if size_of_input_data == 3:
                    # NegSample with BPE
                    return self.forward_byte_pair_encoded_triple(x=x)
                elif size_of_input_data == 2:
                    # KvsAll with BPE
                    return self.forward_byte_pair_encoded_k_vs_all(x)

    def forward_triples(self, x: torch.LongTensor) -> torch.Tensor:
        """Score a batch of (head, relation, tail) index triples.

        Parameters
        ----------
        x : B x 3 tensor of entity/relation indices.

        Returns
        -------
        Triple scores of shape B.
        """
        # (1) Retrieve embeddings & Apply Dropout & Normalization.
        h_emb, r_emb, t_emb = self.get_triple_representation(x)
        return self.score(h_emb, r_emb, t_emb)

    def forward_k_vs_all(self, *args, **kwargs):
        """Subclasses supporting KvsAll scoring must override this."""
        raise ValueError(f'MODEL:{self.name} does not have forward_k_vs_all function')

    def forward_k_vs_sample(self, *args, **kwargs):
        """Subclasses supporting KvsSample scoring must override this."""
        raise ValueError(f'MODEL:{self.name} does not have forward_k_vs_sample function')

    def get_triple_representation(self, idx_hrt):
        """Return normalized, dropout-regularized (head, relation, tail) embeddings."""
        # (1) Split input into indexes.
        idx_head_entity, idx_relation, idx_tail_entity = idx_hrt[:, 0], idx_hrt[:, 1], idx_hrt[:, 2]
        # (2) Retrieve embeddings & Apply Dropout & Normalization
        head_ent_emb = self.normalize_head_entity_embeddings(
            self.input_dp_ent_real(self.entity_embeddings(idx_head_entity)))
        rel_ent_emb = self.normalize_relation_embeddings(self.input_dp_rel_real(self.relation_embeddings(idx_relation)))
        tail_ent_emb = self.normalize_tail_entity_embeddings(self.entity_embeddings(idx_tail_entity))
        return head_ent_emb, rel_ent_emb, tail_ent_emb

    def get_head_relation_representation(self, indexed_triple):
        """Return normalized, dropout-regularized (head, relation) embeddings."""
        # (1) Split input into indexes.
        idx_head_entity, idx_relation = indexed_triple[:, 0], indexed_triple[:, 1]
        # (2) Retrieve embeddings & Apply Dropout & Normalization
        head_ent_emb = self.normalize_head_entity_embeddings(
            self.input_dp_ent_real(self.entity_embeddings(idx_head_entity)))
        rel_ent_emb = self.normalize_relation_embeddings(self.input_dp_rel_real(self.relation_embeddings(idx_relation)))
        return head_ent_emb, rel_ent_emb

    def get_sentence_representation(self, x: torch.LongTensor):
        """Embed byte-pair-encoded (head, relation, tail) token sequences.

        Parameters
        ----------
        x : tensor of shape (B, 3, T) holding subword-token indices.

        Returns
        -------
        Three (B, T, D) token-embedding tensors for head, relation and tail.
        """
        h, r, t = x[:, 0, :], x[:, 1, :], x[:, 2, :]
        head_ent_emb = self.token_embeddings(h)
        rel_emb = self.token_embeddings(r)
        tail_emb = self.token_embeddings(t)
        return head_ent_emb, rel_emb, tail_emb

    def get_bpe_head_and_relation_representation(self, x: torch.LongTensor) -> Tuple[
        torch.FloatTensor, torch.FloatTensor]:
        """Embed and unit-normalize byte-pair-encoded (head, relation) sequences.

        Parameters
        ----------
        x : B x 2 x T tensor of subword-token indices
            (T is the maximum shaped token length).

        Returns
        -------
        Two (B, T, D) tensors, each L2-normalized over its (T, D) slice.
        """
        # h: B x T, r: B x T
        h, r = x[:, 0, :], x[:, 1, :]
        # B, T, D
        head_ent_emb = self.token_embeddings(h)
        # B, T, D
        rel_emb = self.token_embeddings(r)
        # The (T, D) matrix representing one entity is normalized as a whole,
        # so the norm of the flattened row vector of each entity equals 1.
        head_ent_emb = F.normalize(head_ent_emb, p=2, dim=(1, 2))
        rel_emb = F.normalize(rel_emb, p=2, dim=(1, 2))
        return head_ent_emb, rel_emb

    def get_embeddings(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return detached copies of the entity and relation embedding matrices."""
        # BUGFIX: removed a redundant double '.data' access on the entity weights.
        return self.entity_embeddings.weight.data.detach(), self.relation_embeddings.weight.data.detach()
class IdentityClass(torch.nn.Module):
    """A pass-through nn.Module: returns its input unchanged.

    Used as a placeholder for optional normalization layers and as a no-op
    parameter initializer. ``__call__`` is overridden directly, so inputs
    bypass the usual nn.Module hook machinery.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args  # kept only for signature compatibility with real layers

    def __call__(self, x):
        # Identity: hand the input straight back.
        return x

    @staticmethod
    def forward(x):
        # Mirrors __call__ for code paths that invoke forward() explicitly.
        return x