Source code for dicee.eval_static_funcs

import os
import torch
from typing import Dict, Tuple, List, Callable
from .knowledge_graph_embeddings import KGE
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, root_mean_squared_error


@torch.no_grad()
def evaluate_lp_bpe_k_vs_all(model, triples: List[List[str]], er_vocab=None, batch_size=None,
                             func_triple_to_bpe_representation: Callable = None,
                             str_to_bpe_entity_to_idx=None):
    # (1) Set the model to evaluation mode.
    model.model.eval()
    num_triples = len(triples)
    ranks = []
    # Hit range.
    hits_range = [i for i in range(1, 11)]
    hits = {i: [] for i in hits_range}
    # Iterate over the string triples in a mini-batch fashion.
    for i in range(0, num_triples, batch_size):
        # (2) Get a batch of data and map each string triple to its BPE token-id representation.
        str_data_batch = triples[i:i + batch_size]
        torch_batch_bpe_triple = torch.LongTensor(
            [func_triple_to_bpe_representation(triple) for triple in str_data_batch])
        # (3) Extract entities and relations.
        bpe_hr = torch_batch_bpe_triple[:, [0, 1], :]
        # (4) Predict missing entities, i.e., assign scores to all entities.
        predictions = model(bpe_hr)
        # (5) Filter out all entities except the target entity.
        for j in range(len(predictions)):
            # (5.1) Extract the j-th triple.
            h, r, t = str_data_batch[j]
            # (5.2) Get the ids of all entities occurring with the head entity and relation extracted in 5.1.
            id_e_target = str_to_bpe_entity_to_idx[t]
            filt_idx_entities = [str_to_bpe_entity_to_idx[_] for _ in er_vocab[(h, r)]]
            # (5.3) Store the assigned score of the target tail entity extracted in 5.1.
            target_value = predictions[j, id_e_target].item()
            # (5.4.1) Filter all assigned scores of the known tail entities.
            predictions[j, filt_idx_entities] = -np.inf
            # (5.4.2) Entities could additionally be filtered based on the range of a relation.
            # (5.5) Re-insert the target score of 5.3 after filtering.
            predictions[j, id_e_target] = target_value
        # (6) Sort predictions.
        sort_values, sort_idxs = torch.sort(predictions, dim=1, descending=True)
        # (7) Compute the filtered ranks.
        for j in range(len(predictions)):
            t = str_data_batch[j][2]
            # torch.where returns a 0-based index; +1 makes the rank 1-based.
            rank = torch.where(sort_idxs[j] == str_to_bpe_entity_to_idx[t])[0].item() + 1
            ranks.append(rank)
            for hits_level in hits_range:
                if rank <= hits_level:
                    hits[hits_level].append(1.0)
    # (8) Sanity checking: exactly one rank per triple.
    assert len(triples) == len(ranks) == num_triples
    hit_1 = sum(hits[1]) / num_triples
    hit_3 = sum(hits[3]) / num_triples
    hit_10 = sum(hits[10]) / num_triples
    mean_reciprocal_rank = np.mean(1. / np.array(ranks))
    results = {'H@1': hit_1, 'H@3': hit_3, 'H@10': hit_10, 'MRR': mean_reciprocal_rank}
    return results
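
# Usage sketch (not part of the original module): a minimal, hedged example of how
# evaluate_lp_bpe_k_vs_all could be called for a model trained with byte-pair-encoded
# entities. The attributes `func_triple_to_bpe_representation` and
# `str_to_bpe_entity_to_idx` on the pre-trained wrapper are illustrative assumptions,
# not a guaranteed dicee API; obtain the equivalent objects from your own pipeline.
def _example_evaluate_lp_bpe_k_vs_all(pre_trained_kge, test_triples, er_vocab, batch_size=512):
    # test_triples: List[List[str]] of (head, relation, tail) string triples.
    # er_vocab: Dict[Tuple[str, str], List[str]] mapping (head, relation) to all known tails,
    # used for filtered ranking.
    results = evaluate_lp_bpe_k_vs_all(
        model=pre_trained_kge,
        triples=test_triples,
        er_vocab=er_vocab,
        batch_size=batch_size,
        func_triple_to_bpe_representation=pre_trained_kge.func_triple_to_bpe_representation,  # assumed attribute
        str_to_bpe_entity_to_idx=pre_trained_kge.str_to_bpe_entity_to_idx,  # assumed attribute
    )
    # results is a dict of the form {'H@1': ..., 'H@3': ..., 'H@10': ..., 'MRR': ...}.
    return results
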
def evaluate_literal_prediction(
        kge_model: KGE,
        eval_file_path: str = None,
        store_lit_preds: bool = True,
        eval_literals: bool = True,
        loader_backend: str = "pandas",
        return_attr_error_metrics: bool = False,
):
    """
    Evaluates the trained literal-prediction model on a test file.

    Args:
        kge_model (KGE): Pre-trained KGE model with a trained literal model attached.
        eval_file_path (str): Path to the evaluation file.
        store_lit_preds (bool): If True, stores the predictions in a CSV file.
        eval_literals (bool): If True, evaluates the literal predictions and prints error metrics.
        loader_backend (str): Backend for loading the dataset ('pandas' or 'rdflib').
        return_attr_error_metrics (bool): If True, returns the per-attribute error metrics
            (computed only when eval_literals is True).

    Returns:
        pd.DataFrame: DataFrame containing error metrics for each attribute if
            return_attr_error_metrics is True.

    Raises:
        RuntimeError: If the KGE model does not have a trained literal model.
        AssertionError: If kge_model is not an instance of KGE or if the test set has no
            valid entities or attributes.
    """
    # KGE literal-model sanity checking.
    assert isinstance(kge_model, KGE), "kge_model must be an instance of KGE."
    if not hasattr(kge_model, "literal_model") or kge_model.literal_model is None:
        raise RuntimeError("Literal model is not trained or loaded.")

    # Input sanity checking is done in load_and_validate_literal_data.
    test_df_unfiltered = kge_model.literal_dataset.load_and_validate_literal_data(
        file_path=eval_file_path, loader_backend=loader_backend
    )

    # Keep only rows whose head entity and attribute are known to the model.
    test_df = test_df_unfiltered[
        test_df_unfiltered["head"].isin(kge_model.entity_to_idx.keys())
        & test_df_unfiltered["attribute"].isin(kge_model.data_property_to_idx.keys())
    ]

    entities = test_df["head"].to_list()
    attributes = test_df["attribute"].to_list()
    assert len(entities) > 0, "No valid entities in test set — check entity_to_idx."
    assert len(attributes) > 0, "No valid attributes in test set — check data_property_to_idx."

    test_df["predictions"] = kge_model.predict_literals(
        entity=entities, attribute=attributes
    )

    # If store_lit_preds is True, save the predictions to a CSV file.
    if store_lit_preds:
        prediction_df = test_df[["head", "attribute", "predictions"]]
        prediction_path = os.path.join(kge_model.path, "lit_predictions.csv")
        prediction_df.to_csv(prediction_path, index=False)
        print(f"Literal predictions saved to {prediction_path}")

    # Calculate, print, and store error metrics.
    # Initialized to None so the function returns cleanly when eval_literals is False.
    attr_error_metrics = None
    if eval_literals:
        attr_error_metrics = test_df.groupby("attribute").agg(
            MAE=("value", lambda x: mean_absolute_error(x, test_df.loc[x.index, "predictions"])),
            RMSE=("value", lambda x: root_mean_squared_error(x, test_df.loc[x.index, "predictions"]))
        ).reset_index()

        pd.options.display.float_format = "{:.6f}".format
        print("Literal-Prediction evaluation results on Test Set")
        print(attr_error_metrics)

        results_path = os.path.join(kge_model.path, "lit_eval_results.csv")
        attr_error_metrics.to_csv(results_path, index=False)
        print(f"Literal-Prediction evaluation results saved to {results_path}")

    if return_attr_error_metrics:
        return attr_error_metrics
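
# Usage sketch (not part of the original module): a minimal, hedged example of how
# evaluate_literal_prediction could be called. The experiment directory and test-file
# path below are placeholders; point them at your own pre-trained model (loaded via
# dicee's KGE wrapper) and your own literal test data.
def _example_evaluate_literal_prediction():
    pre_trained_kge = KGE(path="Experiments/my_pretrained_model")  # placeholder experiment directory
    attr_error_metrics = evaluate_literal_prediction(
        kge_model=pre_trained_kge,
        eval_file_path="KGs/my_dataset/literals_test.txt",  # placeholder test file
        store_lit_preds=True,
        eval_literals=True,
        loader_backend="pandas",
        return_attr_error_metrics=True,
    )
    # attr_error_metrics is a pandas DataFrame with per-attribute MAE and RMSE columns.
    return attr_error_metrics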