# Source code for dicee.evaluation.link_prediction

"""Link prediction evaluation functions.

This module provides various functions for evaluating link prediction
performance of knowledge graph embedding models.
"""

from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm

from .utils import (
    compute_metrics_from_ranks,
    compute_metrics_from_ranks_simple,
    update_hits,
    create_hits_dict,
    ALL_HITS_RANGE,
)


















@torch.no_grad()
def evaluate_lp(
    model,
    triple_idx,
    num_entities: int,
    er_vocab: Dict[Tuple, List],
    re_vocab: Dict[Tuple, List],
    info: str = 'Eval Starts',
    batch_size: int = 128,
    chunk_size: int = 1000
) -> Dict[str, float]:
    """Evaluate link prediction with batched processing.

    Memory-efficient evaluation using chunked entity scoring: each batch of
    triples is scored against all candidate entities in chunks of
    ``chunk_size`` so peak memory stays bounded.

    Args:
        model: The KGE model to evaluate; called as ``model(x)`` on an
            (N, 3) long tensor of (head, relation, tail) index triples.
        triple_idx: Integer-indexed triples (numpy array or sequence of
            3-element rows).
        num_entities: Total number of entities.
        er_vocab: Mapping (head_idx, rel_idx) -> list of true tail indices,
            used for filtered tail ranking.
        re_vocab: Mapping (rel_idx, tail_idx) -> list of true head indices,
            used for filtered head ranking.
        info: Description printed before evaluation.
        batch_size: Number of triples scored per batch.
        chunk_size: Number of candidate entities scored per forward pass.

    Returns:
        Dictionary with H@1, H@3, H@10, and MRR metrics.
    """
    assert model is not None, "Model must be provided"
    assert triple_idx is not None, "triple_idx must be provided"
    assert num_entities is not None, "num_entities must be provided"
    assert er_vocab is not None, "er_vocab must be provided"
    assert re_vocab is not None, "re_vocab must be provided"

    model.eval()
    print(info)
    print(f'Num of triples {len(triple_idx)}')

    hits = {}
    reciprocal_ranks = []
    all_entities = torch.arange(0, num_entities).long()

    for batch_start in tqdm(range(0, len(triple_idx), batch_size), desc="Evaluating Batches"):
        batch_end = min(batch_start + batch_size, len(triple_idx))
        batch_triples = triple_idx[batch_start:batch_end]
        batch_size_current = len(batch_triples)

        h_batch = torch.tensor([dp[0] for dp in batch_triples])
        r_batch = torch.tensor([dp[1] for dp in batch_triples])
        t_batch = torch.tensor([dp[2] for dp in batch_triples])

        predictions_tails = torch.zeros(batch_size_current, num_entities)
        predictions_heads = torch.zeros(batch_size_current, num_entities)

        # Score candidate entities in chunks to keep memory bounded.
        for chunk_start in range(0, num_entities, chunk_size):
            chunk_end = min(chunk_start + chunk_size, num_entities)
            entities_chunk = all_entities[chunk_start:chunk_end]
            chunk_size_current = entities_chunk.size(0)

            # Tail prediction: score (h, r, ?) for every entity in the chunk.
            x_tails = torch.stack((
                h_batch.repeat_interleave(chunk_size_current),
                r_batch.repeat_interleave(chunk_size_current),
                entities_chunk.repeat(batch_size_current)
            ), dim=1)
            predictions_tails[:, chunk_start:chunk_end] = model(x_tails).view(
                batch_size_current, chunk_size_current)
            del x_tails

            # Head prediction: score (?, r, t) for every entity in the chunk.
            x_heads = torch.stack((
                entities_chunk.repeat(batch_size_current),
                r_batch.repeat_interleave(chunk_size_current),
                t_batch.repeat_interleave(chunk_size_current)
            ), dim=1)
            predictions_heads[:, chunk_start:chunk_end] = model(x_heads).view(
                batch_size_current, chunk_size_current)
            del x_heads

        # Compute filtered ranks for each triple in the batch.
        for i in range(batch_size_current):
            h = h_batch[i].item()
            r = r_batch[i].item()
            t = t_batch[i].item()

            # Tail filtering: mask every other known-true tail, keep the
            # target's own score so its rank is unaffected by the mask.
            # FIX: np.Inf was removed in NumPy 2.0; -np.inf is version-stable.
            filt_tails = set(er_vocab[(h, r)]) - {t}
            target_value = predictions_tails[i, t].item()
            predictions_tails[i, list(filt_tails)] = -np.inf
            predictions_tails[i, t] = target_value
            _, sort_idxs = torch.sort(predictions_tails[i], descending=True)
            filt_tail_entity_rank = np.where(sort_idxs.detach() == t)[0][0]

            # Head filtering: same scheme with the head candidates.
            filt_heads = set(re_vocab[(r, t)]) - {h}
            target_value = predictions_heads[i, h].item()
            predictions_heads[i, list(filt_heads)] = -np.inf
            predictions_heads[i, h] = target_value
            _, sort_idxs = torch.sort(predictions_heads[i], descending=True)
            filt_head_entity_rank = np.where(sort_idxs.detach() == h)[0][0]

            # Convert 0-based sort positions to 1-based ranks.
            filt_head_entity_rank += 1
            filt_tail_entity_rank += 1

            rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
            reciprocal_ranks.append(rr)

            for hits_level in range(1, 11):
                res = 1 if filt_head_entity_rank <= hits_level else 0
                res += 1 if filt_tail_entity_rank <= hits_level else 0
                if res > 0:
                    hits.setdefault(hits_level, []).append(res)

    # scale_factor=2: head and tail directions contribute one rank each.
    results = compute_metrics_from_ranks(
        ranks=[],
        num_triples=len(triple_idx),
        hits_dict=hits,
        scale_factor=2
    ) | {'MRR': sum(reciprocal_ranks) / (float(len(triple_idx) * 2))}
    print(results)
    return results
@torch.no_grad()
def evaluate_bpe_lp(
    model,
    triple_idx: List[Tuple],
    all_bpe_shaped_entities,
    er_vocab: Dict[Tuple, List],
    re_vocab: Dict[Tuple, List],
    info: str = 'Eval Starts'
) -> Dict[str, float]:
    """Evaluate link prediction with BPE-encoded entities.

    Each triple is scored against every entity's BPE representation in a
    single forward pass, then head and tail ranks are computed with the
    standard filtered protocol.

    Args:
        model: The KGE model to evaluate; called as ``model(x)`` on stacked
            long tensors of BPE token-id sequences.
        triple_idx: List of BPE-encoded (head, relation, tail) tuples.
        all_bpe_shaped_entities: Iterable of
            (str_entity, bpe_entity, shaped_bpe_entity) for all entities.
        er_vocab: Mapping (bpe_head, bpe_rel) -> true tails, for filtering.
        re_vocab: Mapping (bpe_rel, bpe_tail) -> true heads, for filtering.
        info: Description printed before evaluation.

    Returns:
        Dictionary with H@1, H@3, H@10, and MRR metrics.
    """
    assert isinstance(triple_idx, list)
    assert isinstance(triple_idx[0], tuple)
    assert len(triple_idx[0]) == 3

    model.eval()
    print(info)
    print(f'Num of triples {len(triple_idx)}')

    hits = {}
    reciprocal_ranks = []
    num_entities = len(all_bpe_shaped_entities)

    # Map each shaped BPE entity to a dense index and collect all entities
    # into one tensor used as the candidate set for both directions.
    bpe_entity_to_idx = {}
    all_bpe_entities = []
    for idx, (str_entity, bpe_entity, shaped_bpe_entity) in tqdm(enumerate(all_bpe_shaped_entities)):
        bpe_entity_to_idx[shaped_bpe_entity] = idx
        all_bpe_entities.append(shaped_bpe_entity)
    all_bpe_entities = torch.LongTensor(all_bpe_entities)

    for (bpe_h, bpe_r, bpe_t) in tqdm(triple_idx):
        idx_bpe_h = bpe_entity_to_idx[bpe_h]
        idx_bpe_t = bpe_entity_to_idx[bpe_t]

        torch_bpe_h = torch.LongTensor(bpe_h).unsqueeze(0)
        torch_bpe_r = torch.LongTensor(bpe_r).unsqueeze(0)
        torch_bpe_t = torch.LongTensor(bpe_t).unsqueeze(0)

        # Tail predictions: (h, r, ?) against every candidate entity.
        x = torch.stack((
            torch.repeat_interleave(input=torch_bpe_h, repeats=num_entities, dim=0),
            torch.repeat_interleave(input=torch_bpe_r, repeats=num_entities, dim=0),
            all_bpe_entities
        ), dim=1)
        predictions_tails = model(x)

        # Head predictions: (?, r, t) against every candidate entity.
        x = torch.stack((
            all_bpe_entities,
            torch.repeat_interleave(input=torch_bpe_r, repeats=num_entities, dim=0),
            torch.repeat_interleave(input=torch_bpe_t, repeats=num_entities, dim=0)
        ), dim=1)
        predictions_heads = model(x)

        # Filter tails: mask all known-true tails, then restore the target's
        # own score so its rank is unaffected by the mask.
        # FIX: np.Inf was removed in NumPy 2.0; -np.inf is version-stable.
        filt_tails = [bpe_entity_to_idx[i] for i in er_vocab[(bpe_h, bpe_r)]]
        target_value = predictions_tails[idx_bpe_t].item()
        predictions_tails[filt_tails] = -np.inf
        predictions_tails[idx_bpe_t] = target_value
        _, sort_idxs = torch.sort(predictions_tails, descending=True)
        filt_tail_entity_rank = np.where(sort_idxs.detach() == idx_bpe_t)[0][0]

        # Filter heads: same scheme with the head candidates.
        filt_heads = [bpe_entity_to_idx[i] for i in re_vocab[(bpe_r, bpe_t)]]
        target_value = predictions_heads[idx_bpe_h].item()
        predictions_heads[filt_heads] = -np.inf
        predictions_heads[idx_bpe_h] = target_value
        _, sort_idxs = torch.sort(predictions_heads, descending=True)
        filt_head_entity_rank = np.where(sort_idxs.detach() == idx_bpe_h)[0][0]

        # Convert 0-based sort positions to 1-based ranks.
        filt_head_entity_rank += 1
        filt_tail_entity_rank += 1

        rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
        reciprocal_ranks.append(rr)

        for hits_level in range(1, 11):
            res = 1 if filt_head_entity_rank <= hits_level else 0
            res += 1 if filt_tail_entity_rank <= hits_level else 0
            if res > 0:
                hits.setdefault(hits_level, []).append(res)

    # scale_factor=2: head and tail directions contribute one rank each.
    results = compute_metrics_from_ranks(
        ranks=[],
        num_triples=len(triple_idx),
        hits_dict=hits,
        scale_factor=2
    ) | {'MRR': sum(reciprocal_ranks) / (float(len(triple_idx) * 2))}
    print(results)
    return results
@torch.no_grad()
def evaluate_lp_bpe_k_vs_all(
    model,
    triples: List[List[str]],
    er_vocab: Optional[Dict] = None,
    batch_size: Optional[int] = None,
    func_triple_to_bpe_representation: Optional[Callable] = None,
    str_to_bpe_entity_to_idx: Optional[Dict] = None
) -> Dict[str, float]:
    """Evaluate BPE link prediction with KvsAll scoring.

    Batches (head, relation) pairs through the model, which scores all
    entities at once, then computes filtered tail ranks.

    Args:
        model: The KGE model wrapper; ``model.model`` is the underlying
            module and ``model(bpe_hr)`` returns per-entity scores.
        triples: List of string (head, relation, tail) triples.
        er_vocab: Mapping (head, relation) -> true tails, for filtering.
        batch_size: Batch size for processing.
        func_triple_to_bpe_representation: Converts a string triple to its
            BPE representation.
        str_to_bpe_entity_to_idx: Mapping from entity string to the column
            index of its score in the model output.

    Returns:
        Dictionary with H@1, H@3, H@10, and MRR metrics.

    Raises:
        ValueError: If batch_size is not provided.
    """
    if batch_size is None:
        raise ValueError("batch_size must be provided")

    model.model.eval()
    num_triples = len(triples)
    ranks: List[int] = []
    hits_range = ALL_HITS_RANGE
    hits = create_hits_dict(hits_range)

    for i in range(0, num_triples, batch_size):
        str_data_batch = triples[i:i + batch_size]
        torch_batch_bpe_triple = torch.LongTensor([
            func_triple_to_bpe_representation(t) for t in str_data_batch
        ])
        # Keep only (head, relation); the model scores all tails at once.
        bpe_hr = torch_batch_bpe_triple[:, [0, 1], :]
        predictions = model(bpe_hr)

        # Filtered setting: mask every known-true tail except the target,
        # whose own score is restored before ranking.
        # FIX: np.Inf was removed in NumPy 2.0; -np.inf is version-stable.
        for j in range(len(predictions)):
            h, r, t = str_data_batch[j]
            id_e_target = str_to_bpe_entity_to_idx[t]
            filt_idx_entities = [str_to_bpe_entity_to_idx[_] for _ in er_vocab[(h, r)]]
            target_value = predictions[j, id_e_target].item()
            predictions[j, filt_idx_entities] = -np.inf
            predictions[j, id_e_target] = target_value

        _, sort_idxs = torch.sort(predictions, dim=1, descending=True)
        for j in range(len(predictions)):
            t = str_data_batch[j][2]
            # 1-based rank of the target tail in the sorted scores.
            rank = torch.where(sort_idxs[j] == str_to_bpe_entity_to_idx[t])[0].item() + 1
            ranks.append(rank)
            update_hits(hits, rank, hits_range)

    assert len(triples) == len(ranks) == num_triples
    return compute_metrics_from_ranks_simple(ranks, num_triples, hits)