Source code for dicee.callbacks

import datetime
import time
import numpy as np
import torch

import dicee.models.base_model
from .static_funcs import save_checkpoint_model, save_pickle
from .abstracts import AbstractCallback
import pandas as pd


class AccumulateEpochLossCallback(AbstractCallback):
    def __init__(self, path: str):
        super().__init__()
        self.path = path

    def on_fit_end(self, trainer, model) -> None:
        """
        Store the epoch losses of the trained model as a CSV file.

        Parameters
        ----------
        trainer:
        model:

        Returns
        -------
        None
        """
        pd.DataFrame(model.loss_history, columns=['EpochLoss']).to_csv(f'{self.path}/epoch_losses.csv')
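# Illustrative usage sketch (not part of dicee): the callback only assumes that
# the trained model exposes a ``loss_history`` sequence and that ``path`` is an
# existing directory. ``_ToyModel`` below is a hypothetical stand-in.
def _demo_accumulate_epoch_loss(tmp_dir: str = ".") -> None:
    class _ToyModel:
        loss_history = [0.91, 0.52, 0.33]  # one entry per epoch

    cb = AccumulateEpochLossCallback(path=tmp_dir)
    cb.on_fit_end(trainer=None, model=_ToyModel())  # writes {tmp_dir}/epoch_losses.csv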
class PrintCallback(AbstractCallback):
    def __init__(self):
        super().__init__()
        self.start_time = time.time()

    def on_fit_start(self, trainer, pl_module):
        print(f"\nTraining is starting {datetime.datetime.now()}...")

    def on_fit_end(self, trainer, pl_module):
        training_time = time.time() - self.start_time
        if training_time < 60:
            message = f'{training_time:.3f} seconds.'
        elif training_time < 60 * 60:
            message = f'{training_time / 60:.3f} minutes.'
        else:
            message = f'{training_time / (60 * 60):.3f} hours.'
        print(f"Training Runtime: {message}\n")

    def on_train_batch_end(self, *args, **kwargs):
        return

    def on_train_epoch_end(self, *args, **kwargs):
        return
class KGESaveCallback(AbstractCallback):
    def __init__(self, every_x_epoch: int, max_epochs: int, path: str):
        super().__init__()
        self.every_x_epoch = every_x_epoch
        self.max_epochs = max_epochs
        self.epoch_counter = 0
        self.path = path
        if self.every_x_epoch is None:
            self.every_x_epoch = max(self.max_epochs // 2, 1)

    def on_train_batch_end(self, *args, **kwargs):
        return

    def on_fit_start(self, trainer, pl_module):
        pass

    def on_train_epoch_end(self, *args, **kwargs):
        pass

    def on_fit_end(self, *args, **kwargs):
        pass

    def on_epoch_end(self, model, trainer, **kwargs):
        if self.epoch_counter % self.every_x_epoch == 0 and self.epoch_counter > 1:
            print(f'\nStoring model {self.epoch_counter}...')
            save_checkpoint_model(model,
                                  path=self.path + f'/model_at_{self.epoch_counter}_epoch_{datetime.datetime.now()}.pt')
        self.epoch_counter += 1
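# Illustrative sketch (not part of dicee): shows at which epochs the callback
# above would store a checkpoint. With every_x_epoch=None and max_epochs=10,
# the default cadence is max(10 // 2, 1) = 5, so within the first 10 epochs
# only epoch 5 triggers a save (epochs 0 and 1 are excluded by the
# ``epoch_counter > 1`` guard).
def _demo_kge_save_cadence(every_x_epoch=None, max_epochs=10):
    cadence = every_x_epoch if every_x_epoch is not None else max(max_epochs // 2, 1)
    return [e for e in range(max_epochs) if e % cadence == 0 and e > 1]  # [5]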
class PseudoLabellingCallback(AbstractCallback):
    def __init__(self, data_module, kg, batch_size):
        super().__init__()
        self.data_module = data_module
        self.kg = kg
        self.num_of_epochs = 0
        self.unlabelled_size = len(self.kg.unlabelled_set)
        self.batch_size = batch_size

    def create_random_data(self):
        entities = torch.randint(low=0, high=self.kg.num_entities, size=(self.batch_size, 2))
        relations = torch.randint(low=0, high=self.kg.num_relations, size=(self.batch_size,))
        # Random unlabelled triples (head, relation, tail).
        return torch.stack((entities[:, 0], relations, entities[:, 1]), dim=1)

    def on_epoch_end(self, trainer, model):
        model.eval()
        with torch.no_grad():
            # (1) Sample a batch from the unlabelled set
            #     (alternatively, self.create_random_data() would generate random triples).
            unlabelled_input_batch = self.kg.unlabelled_set[
                torch.randint(low=0, high=self.unlabelled_size, size=(self.batch_size,))]
            # (2) Predict scores for the unlabelled batch and use the predictions as pseudo-labels.
            pseudo_label = torch.sigmoid(model(unlabelled_input_batch))
            selected_triples = unlabelled_input_batch[pseudo_label >= .90]
        if len(selected_triples) > 0:
            # Extend the training set with confidently predicted triples.
            self.data_module.train_set_idx = np.concatenate(
                (self.data_module.train_set_idx, selected_triples.detach().numpy()), axis=0)
            trainer.train_dataloader = self.data_module.train_dataloader()
            print(f'\tEpoch:{trainer.current_epoch}: Pseudo-labelling\t |D|= {len(self.data_module.train_set_idx)}')
        model.train()
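# Illustrative sketch (not part of dicee): the pseudo-labelling step keeps only
# those unlabelled triples whose sigmoid score is at least 0.9. The triples and
# raw scores below are made up for demonstration.
def _demo_pseudo_label_selection():
    unlabelled_batch = torch.tensor([[0, 0, 1], [1, 1, 2], [2, 0, 0]])
    scores = torch.tensor([2.5, -1.0, 3.0])       # raw model outputs (logits)
    pseudo_label = torch.sigmoid(scores)          # ~ [0.92, 0.27, 0.95]
    return unlabelled_batch[pseudo_label >= .90]  # rows 0 and 2 are promoted to the train set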
def estimate_q(eps):
    """ Estimate the rate of convergence q from the sequence eps. """
    x = np.arange(len(eps) - 1)
    y = np.log(np.abs(np.diff(np.log(eps))))
    line = np.polyfit(x, y, 1)  # fit a degree-1 polynomial
    q = np.exp(line[0])         # the slope of the fit is log(q)
    return q
def compute_convergence(seq, i):
    assert len(seq) >= i > 0
    return estimate_q(seq[-i:] / (np.arange(i) + 1))
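# Worked example (not part of dicee): for a q-order convergent sequence with
# eps_{k+1} ~ eps_k ** q, log(eps_k) scales like q**k, so diff(log(eps)) also
# scales like q**k and y = log|diff(log(eps))| is linear in k with slope log(q);
# exp(slope) then recovers q.
def _demo_estimate_q():
    # Quadratic convergence: eps_k = 0.5 ** (2 ** k), i.e. eps_{k+1} = eps_k ** 2.
    eps = 0.5 ** (2.0 ** np.arange(6))
    return estimate_q(eps)  # ~ 2.0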
class ASWA(AbstractCallback):
    """
    Adaptive Stochastic Weight Averaging (ASWA).

    ASWA keeps track of the validation performance and updates the ensemble
    model accordingly.
    """

    def __init__(self, num_epochs, path):
        super().__init__()
        self.path = path
        self.num_epochs = num_epochs
        self.initial_eval_setting = None
        self.epoch_count = 0
        self.alphas = []
        self.val_aswa = -1

    def on_fit_end(self, trainer, model):
        if self.initial_eval_setting:
            # Restore the initial evaluation setting.
            trainer.evaluator.args.eval_model = self.initial_eval_setting
        param_ensemble = torch.load(f"{self.path}/aswa.pt", torch.device("cpu"))
        model.load_state_dict(param_ensemble)

    @staticmethod
    def compute_mrr(trainer, model) -> float:
        # (1) Enable eval mode and move the model to CPU.
        model.eval()
        device_name = model.device
        model.to("cpu")
        # (2) MRR performance of the running model on the validation data.
        last_val_mrr_running_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=model,
                                                            form_of_labelling=trainer.form_of_labelling,
                                                            during_training=True)["Val"]["MRR"]
        model.to(device_name)
        # (3) Re-enable train mode.
        model.train()
        return last_val_mrr_running_model

    def get_aswa_state_dict(self, model):
        # Question: soft update or rejection?
        ensemble_state_dict = torch.load(f"{self.path}/aswa.pt", torch.device(model.device))
        # Perform a provisional parameter update.
        with torch.no_grad():
            for k, parameters in model.state_dict().items():
                if parameters.dtype == torch.float:
                    ensemble_state_dict[k] = (ensemble_state_dict[k] * sum(self.alphas) + parameters) / (
                            1 + sum(self.alphas))
        return ensemble_state_dict

    def decide(self, running_model_state_dict, ensemble_state_dict, val_running_model,
               mrr_updated_ensemble_model):
        """
        Perform a hard update, a soft update, or a rejection.

        Parameters
        ----------
        running_model_state_dict
        ensemble_state_dict
        val_running_model
        mrr_updated_ensemble_model

        Returns
        -------
        """
        # (1) HARD UPDATE:
        # If the validation performance of the running model is greater than both
        # the validation performance of the updated ASWA and that of the current ASWA.
        if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
            # (1.1) Save the running model as ASWA.
            torch.save(running_model_state_dict, f=f"{self.path}/aswa.pt")
            # (1.2) Reset the alphas/ensemble weights.
            self.alphas.clear()
            # (1.3) Store the validation performance of ASWA.
            self.val_aswa = val_running_model
            return True
        # (2) SOFT UPDATE:
        # If the updated ensemble improves on the current ASWA.
        if mrr_updated_ensemble_model > self.val_aswa:
            self.val_aswa = mrr_updated_ensemble_model
            torch.save(ensemble_state_dict, f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            return True
        # (3) REJECTION: keep the current ASWA and record a zero weight.
        if self.val_aswa > mrr_updated_ensemble_model:
            self.alphas.append(0)
            return True

    def on_train_epoch_end(self, trainer, model):
        # (1) Increment the epoch counter.
        self.epoch_count += 1
        # (2) Save the given eval setting if it has not been saved yet.
        if self.initial_eval_setting is None:
            self.initial_eval_setting = trainer.evaluator.args.eval_model
            trainer.evaluator.args.eval_model = "val"
        # (3) Compute the MRR of the running model.
        val_running_model = self.compute_mrr(trainer, model)
        # (4) Initialize ASWA if it has not been initialized yet.
        if self.val_aswa == -1:
            torch.save(model.state_dict(), f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            self.val_aswa = val_running_model
            return True
        else:
            # (5) Load the ASWA parameters, provisionally updated with the running model.
            ensemble_state_dict = self.get_aswa_state_dict(model)
            # (6) Initialize an ASWA ensemble with (5).
            ensemble = type(model)(model.args)
            ensemble.load_state_dict(ensemble_state_dict)
            # (7) Evaluate (6) on the validation data, i.e., perform the lookahead operation.
            mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
                                                                form_of_labelling=trainer.form_of_labelling,
                                                                during_training=True)["Val"]["MRR"]
            # (8) Decide whether ASWA should be updated with the current running model.
            self.decide(model.state_dict(), ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)
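# Illustrative sketch (not part of dicee): the provisional update in
# ``get_aswa_state_dict`` is a running weighted average. With accumulated
# weights ``alphas``, each float parameter tensor is updated as
#     w_ens <- (w_ens * sum(alphas) + w_running) / (1 + sum(alphas)),
# i.e. the running model enters the ensemble with weight 1. The tensors below
# stand in for a single parameter of the ensemble and the running model.
def _demo_aswa_soft_update():
    w_ens, w_run, alphas = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), [1.0, 1.0]
    return (w_ens * sum(alphas) + w_run) / (1 + sum(alphas))  # tensor([1.6667, 2.6667])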
class Eval(AbstractCallback):
    def __init__(self, path, epoch_ratio: int = None):
        super().__init__()
        self.path = path
        self.reports = []
        self.epoch_ratio = epoch_ratio if epoch_ratio is not None else 1
        self.epoch_counter = 0

    def on_fit_start(self, trainer, model):
        pass

    def on_fit_end(self, trainer, model):
        save_pickle(data=self.reports, file_path=trainer.attributes.full_storage_path + '/evals_per_epoch')

    def on_train_epoch_end(self, trainer, model):
        self.epoch_counter += 1
        if self.epoch_counter % self.epoch_ratio == 0:
            model.eval()
            report = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=model,
                                            form_of_labelling=trainer.form_of_labelling,
                                            during_training=True)
            model.train()
            self.reports.append(report)

    def on_train_batch_end(self, *args, **kwargs):
        return
class KronE(AbstractCallback):
    def __init__(self):
        super().__init__()
        self.f = None

    @staticmethod
    def batch_kronecker_product(a, b):
        """
        Kronecker product of matrices a and b with leading batch dimensions.

        Batch dimensions are broadcast; the numbers of batch dimensions must match.

        :type a: torch.Tensor
        :type b: torch.Tensor
        :rtype: torch.Tensor
        """
        a, b = a.unsqueeze(1), b.unsqueeze(1)
        siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        res = res.reshape(siz0 + siz1)
        return res.flatten(1)

    def get_kronecker_triple_representation(self, indexed_triple: torch.LongTensor):
        """ Get Kronecker embeddings. """
        n, d = indexed_triple.shape
        assert d == 3
        # Get the embeddings of the head entity, relation, and tail entity.
        head_ent_emb, rel_ent_emb, tail_ent_emb = self.f(indexed_triple)
        head_ent_kron_emb = self.batch_kronecker_product(*torch.hsplit(head_ent_emb, 2))
        rel_ent_kron_emb = self.batch_kronecker_product(*torch.hsplit(rel_ent_emb, 2))
        tail_ent_kron_emb = self.batch_kronecker_product(*torch.hsplit(tail_ent_emb, 2))
        return torch.cat((head_ent_emb, head_ent_kron_emb), dim=1), \
            torch.cat((rel_ent_emb, rel_ent_kron_emb), dim=1), \
            torch.cat((tail_ent_emb, tail_ent_kron_emb), dim=1)

    def on_fit_start(self, trainer, model):
        if isinstance(model.normalize_head_entity_embeddings, dicee.models.base_model.IdentityClass):
            self.f = model.get_triple_representation
            model.get_triple_representation = self.get_kronecker_triple_representation
        else:
            raise NotImplementedError('Normalizer should be reinitialized')
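# Shape check (not part of dicee): for a of shape (n, p) and b of shape (n, q),
# ``batch_kronecker_product`` returns shape (n, p * q). Since the embedding of
# dimension d is hsplit into two halves of size d/2, the KronE representation
# grows from d to d + (d / 2) ** 2 per embedding.
def _demo_batch_kronecker_shape():
    a = torch.randn(8, 4)  # e.g. first half of an 8-dimensional embedding, batch of 8
    b = torch.randn(8, 4)  # second half
    return KronE.batch_kronecker_product(a, b).shape  # torch.Size([8, 16])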
class Perturb(AbstractCallback):
    """
    A callback for three-level perturbation.

    Input perturbation: During training, an input x is perturbed by randomly replacing one of its
    components. In the context of knowledge graph embedding models, x can denote a triple, a tuple
    of an entity and a relation, or a tuple of two entities. A perturbation replaces a component
    of x with a randomly drawn entity or relation.

    Parameter perturbation: Gaussian or uniform noise is added to the embeddings of the
    entities/relations of a randomly chosen subset of a mini-batch.

    Output perturbation: The labels of a randomly chosen subset of a mini-batch are softened
    (akin to label smoothing) or flipped.
    """

    def __init__(self, level: str = "input", ratio: float = 0.0, method: str = None,
                 scaler: float = None, frequency=None):
        """
        level: in {"input", "param", "out"}.
        ratio: a float in [0, 1], the percentage of mini-batch data points to be perturbed.
        method: the perturbation technique, e.g. "GN" (Gaussian noise) or "RN" (random noise)
                at the parameter level, "Soft" or "Hard" at the output level.
        """
        super().__init__()
        assert level in {"input", "param", "out"}
        assert ratio >= 0.0
        self.level = level
        self.ratio = ratio
        self.method = method
        self.scaler = scaler
        self.frequency = frequency  # per epoch or per mini-batch?

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        # Modifications should be in-place.
        # (1) Extract the input and output data points of a given batch.
        x, y = batch
        n, _ = x.shape
        assert n > 0
        # (2) Compute the number of perturbed data points.
        num_of_perturbed_data = int(n * self.ratio)
        if num_of_perturbed_data == 0:
            return None
        # (3) Detect the device on which the data points reside.
        device = x.get_device()
        if device == -1:
            device = "cpu"
        # (4) Sample random integers from 0 to n without replacement and take num_of_perturbed_data of them.
        random_indices = torch.randperm(n, device=device)[:num_of_perturbed_data]
        # (5) Apply the perturbation depending on the level.
        if self.level == "input":
            # (5.1) Input-level perturbation.
            if torch.rand(1) > 0.5:
                # (5.1.1) Perturb the heads: sample random entity indices.
                perturbation = torch.randint(low=0, high=model.num_entities,
                                             size=(num_of_perturbed_data,), device=device)
                # Replace the head entities of the randomly selected data points with (5.1.1).
                x[random_indices] = torch.column_stack((perturbation, x[:, 1][random_indices]))
            else:
                # (5.1.2) Perturb the relations: sample random relation indices.
                perturbation = torch.randint(low=0, high=model.num_relations,
                                             size=(num_of_perturbed_data,), device=device)
                # Replace the relations of the randomly selected data points with (5.1.2).
                x[random_indices] = torch.column_stack((x[:, 0][random_indices], perturbation))
        elif self.level == "param":
            # (5.2) Parameter-level perturbation.
            h, r = torch.hsplit(x, 2)
            if self.method == "GN":
                # (5.2.1) Gaussian perturbation.
                if torch.rand(1) > 0.0:
                    # (5.2.1.1) Gaussian perturbation of head-entity embeddings.
                    h_selected = h[random_indices]
                    with torch.no_grad():
                        model.entity_embeddings.weight[h_selected] += torch.normal(
                            mean=0, std=self.scaler,
                            size=model.entity_embeddings.weight[h_selected].shape,
                            device=model.device)
                else:
                    # (5.2.1.2) Gaussian perturbation of relation embeddings.
                    r_selected = r[random_indices]
                    with torch.no_grad():
                        model.relation_embeddings.weight[r_selected] += torch.normal(
                            mean=0, std=self.scaler,
                            size=model.relation_embeddings.weight[r_selected].shape,
                            device=model.device)
            elif self.method == "RN":
                # (5.2.2) Random (uniform) perturbation.
                if torch.rand(1) > 0.0:
                    # (5.2.2.1) Random perturbation of head-entity embeddings.
                    h_selected = h[random_indices]
                    with torch.no_grad():
                        model.entity_embeddings.weight[h_selected] += torch.rand(
                            size=model.entity_embeddings.weight[h_selected].shape,
                            device=model.device) * self.scaler
                else:
                    # (5.2.2.2) Random perturbation of relation embeddings.
                    r_selected = r[random_indices]
                    with torch.no_grad():
                        model.relation_embeddings.weight[r_selected] += torch.rand(
                            size=model.relation_embeddings.weight[r_selected].shape,
                            device=model.device) * self.scaler
            else:
                raise RuntimeError(f"--method is given as {self.method}!")
        elif self.level == "out":
            # (5.3) Output-level perturbation.
            if self.method == "Soft":
                # (5.3.1) Soft perturbation resembles label smoothing:
                # compute the perturbation rate, then map 1.0 -> 1.0 - perturb and 0.0 -> perturb.
                # See https://pytorch.org/docs/stable/generated/torch.where.html
                perturb = torch.rand(1, device=model.device) * self.scaler
                batch[1][random_indices] = torch.where(batch[1][random_indices] == 1.0,
                                                       1.0 - perturb, perturb)
            elif self.method == "Hard":
                # (5.3.2) Hard perturbation flips 1s to 0s and 0s to 1s.
                batch[1][random_indices] = torch.where(batch[1][random_indices] == 1.0, 0.0, 1.0)
            else:
                raise NotImplementedError(f"--method is given as {self.method}!")
        else:
            raise RuntimeError(f"--level is given as {self.level}!")
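# Illustrative sketch (not part of dicee): output-level "Soft" perturbation on a
# binary label vector, as in step (5.3.1) above. Ones move down to 1 - perturb
# and zeros move up to perturb; "Hard" would flip the selected labels outright.
# The labels, indices, and perturbation rate below are made up for demonstration.
def _demo_output_perturbation():
    y = torch.tensor([1.0, 0.0, 1.0, 0.0])
    idx = torch.tensor([0, 1])      # randomly selected positions
    perturb = torch.tensor(0.1)     # scaled perturbation rate
    y[idx] = torch.where(y[idx] == 1.0, 1.0 - perturb, perturb)
    return y                        # tensor([0.9000, 0.1000, 1.0000, 0.0000])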