dicee.evaluation.utils
Utility functions for evaluation module.
This module contains shared helper functions used across different evaluation components.
Functions
make_iterable_verbose | Wrap an iterable with a tqdm progress bar if verbose is True.
compute_metrics_from_ranks | Compute standard link prediction metrics from ranks.
compute_metrics_from_ranks_simple | Compute link prediction metrics without scaling factor.
update_hits | Update hits dictionary based on rank.
efficient_zero_grad | Efficiently zero gradients using parameter.grad = None.
Module Contents
- dicee.evaluation.utils.make_iterable_verbose(iterable_object: Iterable, verbose: bool, desc: str = 'Default', position: int = None, leave: bool = True) -> Iterable
Wrap an iterable with a tqdm progress bar if verbose is True.
- Parameters:
iterable_object – The iterable to potentially wrap.
verbose – Whether to show progress bar.
desc – Description for the progress bar.
position – Position of the progress bar.
leave – Whether to leave the progress bar after completion.
- Returns:
The original iterable or a tqdm-wrapped version.
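A minimal sketch of how such a wrapper could behave, assuming the real function simply forwards its arguments to tqdm. The name wrap_with_progress is a hypothetical stand-in used for illustration, not the library's implementation.

```python
from typing import Iterable, Optional


def wrap_with_progress(iterable_object: Iterable, verbose: bool,
                       desc: str = "Default", position: Optional[int] = None,
                       leave: bool = True) -> Iterable:
    """Hypothetical stand-in for make_iterable_verbose."""
    if not verbose:
        # Quiet mode: hand back the very same object, untouched.
        return iterable_object
    # Imported lazily so that quiet callers do not need tqdm installed.
    from tqdm import tqdm
    return tqdm(iterable_object, desc=desc, position=position, leave=leave)


batches = [[1, 2], [3, 4]]
quiet = wrap_with_progress(batches, verbose=False)
total = sum(sum(batch) for batch in quiet)
```

Because the quiet path returns the original object, callers can write one loop and toggle the progress bar through the verbose flag alone.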
- dicee.evaluation.utils.compute_metrics_from_ranks(ranks: List[int], num_triples: int, hits_dict: Dict[int, List[float]], scale_factor: int = 1) -> Dict[str, float]
Compute standard link prediction metrics from ranks.
- Parameters:
ranks – List of ranks for each prediction.
num_triples – Total number of triples evaluated.
hits_dict – Dictionary mapping hit levels to lists of hits.
scale_factor – Factor by which the denominator is scaled (e.g., 2 when both head and tail predictions are ranked for each triple).
- Returns:
Dictionary containing H@1, H@3, H@10, and MRR metrics.
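The sketch below shows one plausible way to derive these metrics. It assumes that H@k is the sum of the recorded hits at level k divided by num_triples * scale_factor, and that MRR is the mean of the reciprocal ranks. The function name sketch_metrics_from_ranks and both formulas are illustrative assumptions, not the library's confirmed implementation.

```python
from typing import Dict, List


def sketch_metrics_from_ranks(ranks: List[int], num_triples: int,
                              hits_dict: Dict[int, List[float]],
                              scale_factor: int = 1) -> Dict[str, float]:
    """Hypothetical re-implementation for illustration only."""
    # Assumed denominator: one prediction per triple, times the scale factor.
    denominator = float(num_triples * scale_factor)
    metrics = {
        f"H@{k}": sum(hits_dict.get(k, [])) / denominator
        for k in (1, 3, 10)
    }
    # Assumed MRR: mean of 1/rank over all recorded ranks.
    metrics["MRR"] = sum(1.0 / rank for rank in ranks) / len(ranks)
    return metrics


# Two triples, each ranked for head and tail prediction, give four ranks.
ranks = [1, 2, 4, 20]
hits_dict = {1: [1.0], 3: [1.0, 1.0], 10: [1.0, 1.0, 1.0]}
metrics = sketch_metrics_from_ranks(ranks, num_triples=2,
                                    hits_dict=hits_dict, scale_factor=2)
```

With these inputs, three of the four ranks fall within the top 10, so H@10 comes out as 0.75 under the assumed formula.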
- dicee.evaluation.utils.compute_metrics_from_ranks_simple(ranks: List[int], num_triples: int, hits_dict: Dict[int, List[float]]) -> Dict[str, float]
Compute link prediction metrics without scaling factor.
- Parameters:
ranks – List of ranks for each prediction.
num_triples – Total number of triples evaluated.
hits_dict – Dictionary mapping hit levels to lists of hits.
- Returns:
Dictionary containing H@1, H@3, H@10, and MRR metrics.
- dicee.evaluation.utils.update_hits(hits: Dict[int, List[float]], rank: int, hits_range: List[int] = None) -> None
Update hits dictionary based on rank.
- Parameters:
hits – Dictionary to update in-place.
rank – The rank to check against hit levels.
hits_range – List of hit levels to check (if None, levels 1-10 are used).
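A minimal sketch of the in-place update. It assumes that a rank counts as a hit at every level k with rank <= k, and that each hit is recorded by appending 1.0 to that level's list. The name sketch_update_hits and this recording scheme are assumptions made for illustration.

```python
from typing import Dict, List, Optional


def sketch_update_hits(hits: Dict[int, List[float]], rank: int,
                       hits_range: Optional[List[int]] = None) -> None:
    """Hypothetical stand-in for update_hits; mutates `hits` in place."""
    if hits_range is None:
        # Assumed default: hit levels 1 through 10 inclusive.
        hits_range = list(range(1, 11))
    for level in hits_range:
        if rank <= level:
            # A rank of 3 is a hit at levels 3, 4, ..., 10 but not at 1 or 2.
            hits.setdefault(level, []).append(1.0)


hits: Dict[int, List[float]] = {}
for rank in (1, 3, 12):
    sketch_update_hits(hits, rank)
```

Here the rank of 12 lies outside every level and records nothing, so only two of the three ranks contribute at level 10.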
- dicee.evaluation.utils.efficient_zero_grad(model) -> None
Efficiently zero gradients using parameter.grad = None.
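Setting each parameter's gradient to None is cheaper than filling the gradient tensors with zeros (as optimizer.zero_grad(set_to_none=False) does): it skips one memory write per parameter, and the next backward pass assigns gradients instead of accumulating into them.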
See: https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html
- Parameters:
model – PyTorch model to zero gradients for.
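The pattern can be sketched without PyTorch installed. FakeParam and FakeModel below are hypothetical stand-ins that mimic only the parameters() and .grad interface of a PyTorch module; with a real torch.nn.Module the same loop applies unchanged. The name sketch_zero_grad is likewise illustrative.

```python
class FakeParam:
    """Hypothetical stand-in for a torch.nn.Parameter with a .grad slot."""

    def __init__(self, grad):
        self.grad = grad


class FakeModel:
    """Hypothetical stand-in exposing parameters() like torch.nn.Module."""

    def __init__(self):
        self._params = [FakeParam([0.1, 0.2]), FakeParam([0.3])]

    def parameters(self):
        return iter(self._params)


def sketch_zero_grad(model) -> None:
    # Drop each gradient reference instead of overwriting it with zeros.
    for param in model.parameters():
        param.grad = None


model = FakeModel()
sketch_zero_grad(model)
grads = [param.grad for param in model.parameters()]
```

One consequence of this approach is that code reading param.grad afterwards has to handle None rather than a zero tensor.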