Source code for dicee.eval_static_funcs

import torch
from typing import Dict, Tuple, List, Callable
from .knowledge_graph_embeddings import KGE
from tqdm import tqdm
import numpy as np

# @torch.no_grad()

@torch.no_grad()
def evaluate_lp_bpe_k_vs_all(model, triples: List[List[str]], er_vocab=None, batch_size=None,
                             func_triple_to_bpe_representation: Callable = None,
                             str_to_bpe_entity_to_idx=None):
    """Evaluate tail prediction of a BPE-based KvsAll model using filtered ranks.

    For every ``(h, r, t)`` in ``triples`` the model scores all entities given the
    BPE representation of ``(h, r)``. The scores of all tails listed in
    ``er_vocab[(h, r)]`` are masked to ``-inf`` except the target ``t`` itself,
    and the 1-based position of ``t`` in the descending score order is its rank.

    Args:
        model: Callable wrapper whose ``model`` attribute is put into eval mode.
            Calling it on a ``LongTensor`` of BPE-encoded ``(h, r)`` pairs must return
            a score tensor indexed as ``[batch_row, entity_idx]``.
        triples: String triples ``[h, r, t]`` to evaluate.
        er_vocab: Mapping ``(h, r) -> iterable of tail strings`` used for filtering.
            It is required despite the ``None`` default.
        batch_size: Number of triples scored per forward pass. ``None`` scores all
            triples in a single batch.
        func_triple_to_bpe_representation: Maps a string triple to its BPE
            representation. Every triple must map to equally-shaped nested lists so
            the batch stacks into one tensor (slots 0/1 = head/relation).
        str_to_bpe_entity_to_idx: Mapping from entity string to its column in the
            score tensor.

    Returns:
        dict with keys ``'H@1'``, ``'H@3'``, ``'H@10'`` (Python floats) and ``'MRR'``.

    Raises:
        ZeroDivisionError: if ``triples`` is empty.
    """
    # (1) Put the underlying torch module into evaluation mode.
    model.model.eval()
    num_triples = len(triples)
    if batch_size is None:
        # range(..., None) raises TypeError; treat None as "everything in one batch".
        batch_size = max(num_triples, 1)
    ranks = []
    for start in range(0, num_triples, batch_size):
        str_data_batch = triples[start:start + batch_size]
        # (2) BPE-encode the batch and keep only the (head, relation) slots.
        torch_batch_bpe_triple = torch.LongTensor(
            [func_triple_to_bpe_representation(triple) for triple in str_data_batch])
        bpe_hr = torch_batch_bpe_triple[:, [0, 1], :]
        # (3) Assign a score to every entity as a candidate tail.
        predictions = model(bpe_hr)
        # (4) Filtered setting: mask every known tail of (h, r) except the target.
        for j, (h, r, t) in enumerate(str_data_batch):
            id_e_target = str_to_bpe_entity_to_idx[t]
            filt_idx_entities = [str_to_bpe_entity_to_idx[e] for e in er_vocab[(h, r)]]
            # Remember the target score; the mask below may overwrite it.
            target_value = predictions[j, id_e_target].item()
            # np.Inf was removed in NumPy 2.0; np.inf works on all versions.
            predictions[j, filt_idx_entities] = -np.inf
            predictions[j, id_e_target] = target_value
        # (5) Order entities by descending score.
        _, sort_idxs = torch.sort(predictions, dim=1, descending=True)
        # (6) Filtered rank = 1-based position of the target entity.
        for j, triple in enumerate(str_data_batch):
            target_idx = str_to_bpe_entity_to_idx[triple[2]]
            rank = torch.where(sort_idxs[j] == target_idx)[0].item() + 1
            ranks.append(rank)
    # (7) Sanity check: exactly one rank per triple.
    assert len(triples) == len(ranks) == num_triples
    hit_1 = sum(rank <= 1 for rank in ranks) / num_triples
    hit_3 = sum(rank <= 3 for rank in ranks) / num_triples
    hit_10 = sum(rank <= 10 for rank in ranks) / num_triples
    mean_reciprocal_rank = np.mean(1. / np.array(ranks))
    results = {'H@1': hit_1, 'H@3': hit_3, 'H@10': hit_10, 'MRR': mean_reciprocal_rank}
    return results