dicee.callbacks
===============

.. py:module:: dicee.callbacks


Classes
-------

.. autoapisummary::

   dicee.callbacks.AccumulateEpochLossCallback
   dicee.callbacks.PrintCallback
   dicee.callbacks.KGESaveCallback
   dicee.callbacks.PseudoLabellingCallback
   dicee.callbacks.ASWA
   dicee.callbacks.Eval
   dicee.callbacks.KronE
   dicee.callbacks.Perturb


Functions
---------

.. autoapisummary::

   dicee.callbacks.estimate_q
   dicee.callbacks.compute_convergence


Module Contents
---------------

.. py:class:: AccumulateEpochLossCallback(path: str)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Abstract class for Callback class for knowledge graph embedding models

   Parameter
   ---------

   .. py:attribute:: path

   .. py:method:: on_fit_end(trainer, model) -> None

      Store the epoch loss.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None


.. py:class:: PrintCallback

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Abstract class for Callback class for knowledge graph embedding models

   Parameter
   ---------

   .. py:attribute:: start_time

   .. py:method:: on_fit_start(trainer, pl_module)

      Called at the beginning of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_fit_end(trainer, pl_module)

      Called at the end of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_train_batch_end(*args, **kwargs)

      Called at the end of each mini-batch during the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_train_epoch_end(*args, **kwargs)

      Called at the end of each epoch during training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None


.. py:class:: KGESaveCallback(every_x_epoch: int, max_epochs: int, path: str)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Abstract class for Callback class for knowledge graph embedding models

   Parameter
   ---------

   .. py:attribute:: every_x_epoch

   .. py:attribute:: max_epochs

   .. py:attribute:: epoch_counter
      :value: 0

   .. py:attribute:: path

   .. py:method:: on_train_batch_end(*args, **kwargs)

      Called at the end of each mini-batch during the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_fit_start(trainer, pl_module)

      Called at the beginning of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_train_epoch_end(*args, **kwargs)

      Called at the end of each epoch during training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_fit_end(*args, **kwargs)

      Called at the end of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_epoch_end(model, trainer, **kwargs)


.. py:class:: PseudoLabellingCallback(data_module, kg, batch_size)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Abstract class for Callback class for knowledge graph embedding models

   Parameter
   ---------

   .. py:attribute:: data_module

   .. py:attribute:: kg

   .. py:attribute:: num_of_epochs
      :value: 0

   .. py:attribute:: unlabelled_size

   .. py:attribute:: batch_size

   .. py:method:: create_random_data()

   .. py:method:: on_epoch_end(trainer, model)


.. py:function:: estimate_q(eps)

   Estimate the rate of convergence q from the sequence eps.


.. py:function:: compute_convergence(seq, i)


.. py:class:: ASWA(num_epochs, path)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Adaptive stochastic weight averaging.

   ASWA keeps track of the validation performance and updates the ensemble model accordingly.

   .. py:attribute:: path

   .. py:attribute:: num_epochs

   .. py:attribute:: initial_eval_setting
      :value: None

   .. py:attribute:: epoch_count
      :value: 0

   .. py:attribute:: alphas
      :value: []

   .. py:attribute:: val_aswa
      :value: -1

   .. py:method:: on_fit_end(trainer, model)

      Called at the end of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: compute_mrr(trainer, model) -> float
      :staticmethod:

   .. py:method:: get_aswa_state_dict(model)

   .. py:method:: decide(running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)

      Perform a hard update, a soft update, or a rejection.

      :param running_model_state_dict:
      :param ensemble_state_dict:
      :param val_running_model:
      :param mrr_updated_ensemble_model:

   .. py:method:: on_train_epoch_end(trainer, model)

      Called at the end of each epoch during training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None


.. py:class:: Eval(path, epoch_ratio: int = None)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Abstract class for Callback class for knowledge graph embedding models

   Parameter
   ---------

   .. py:attribute:: path

   .. py:attribute:: reports
      :value: []

   .. py:attribute:: epoch_ratio
      :value: None

   .. py:attribute:: epoch_counter
      :value: 0

   .. py:method:: on_fit_start(trainer, model)

      Called at the beginning of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_fit_end(trainer, model)

      Called at the end of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_train_epoch_end(trainer, model)

      Called at the end of each epoch during training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None

   .. py:method:: on_train_batch_end(*args, **kwargs)

      Called at the end of each mini-batch during the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None


.. py:class:: KronE

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Abstract class for Callback class for knowledge graph embedding models

   Parameter
   ---------

   .. py:attribute:: f
      :value: None

   .. py:method:: batch_kronecker_product(a, b)
      :staticmethod:

      Kronecker product of matrices a and b with leading batch dimensions.
      Batch dimensions are broadcast. The number of them must

      :type a: torch.Tensor
      :type b: torch.Tensor
      :rtype: torch.Tensor

   .. py:method:: get_kronecker_triple_representation(indexed_triple: torch.LongTensor)

      Get Kronecker embeddings.

   .. py:method:: on_fit_start(trainer, model)

      Called at the beginning of the training.

      Parameter
      ---------
      trainer:

      model:

      :rtype: None


.. py:class:: Perturb(level: str = 'input', ratio: float = 0.0, method: str = None, scaler: float = None, frequency=None)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   A callback for three-level perturbation.

   Input Perturbation: During training, an input x is perturbed by randomly replacing one of its elements.
   In the context of knowledge graph embedding models, x can denote a triple, a tuple of an entity and a relation,
   or a tuple of two entities.
   A perturbation means that a component of x is randomly replaced by an entity or a relation.

   Parameter Perturbation:

   Output Perturbation:

   .. py:attribute:: level
      :value: 'input'

   .. py:attribute:: ratio
      :value: 0.0

   .. py:attribute:: method
      :value: None

   .. py:attribute:: scaler
      :value: None

   .. py:attribute:: frequency
      :value: None

   .. py:method:: on_train_batch_start(trainer, model, batch, batch_idx)

      Called when the train batch begins.
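
The input perturbation described for ``Perturb`` can be illustrated with a minimal sketch. This is not the dicee implementation: the function ``perturb_heads``, its arguments, and the choice to corrupt only the head entity are assumptions made for illustration. It operates on plain Python tuples of integer indices rather than on tensors, and a sampled replacement may happen to equal the original entity.

```python
import random


def perturb_heads(batch, num_entities, ratio, rng=None):
    """Return a copy of `batch` with about `ratio` of the head entities replaced.

    Hypothetical helper, not part of dicee.
    batch: list of (head, relation, tail) integer index triples.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    if rng is None:
        rng = random.Random()
    n_perturbed = int(len(batch) * ratio)
    # Choose the rows to corrupt, without repetition.
    chosen = set(rng.sample(range(len(batch)), n_perturbed))
    perturbed = []
    for i, (h, r, t) in enumerate(batch):
        if i in chosen:
            # Uniformly sampled entity index; it may coincide with the original head.
            h = rng.randrange(num_entities)
        perturbed.append((h, r, t))
    return perturbed


batch = [(0, 0, 1), (2, 1, 3), (4, 0, 5), (6, 1, 7)]
noisy = perturb_heads(batch, num_entities=8, ratio=0.5, rng=random.Random(0))
```

Relations and tails are left untouched, so at most ``int(len(batch) * ratio)`` triples differ from the input. A callback could apply such a function to each batch in ``on_train_batch_start``; how ``Perturb`` actually does this is not stated in this reference.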