"""Source code for dicee.read_preprocess_save_load_kg.save_load_disk."""

import numpy as np
import polars
import pandas
from .util import load_pickle, load_numpy_ndarray
import os
from dicee.static_funcs import save_pickle, save_numpy_ndarray


class LoadSaveToDisk:
    """Save a knowledge-graph container to a directory and load it back.

    ``kg`` is accessed only through attributes. ``save`` writes files into
    ``kg.path_for_serialization``. ``load`` reads them from
    ``kg.path_for_deserialization`` and assigns the results back onto ``kg``.
    """

    def __init__(self, kg):
        # Knowledge-graph container whose attributes are serialized / populated.
        self.kg = kg

    @staticmethod
    def _path(directory, file_name):
        """Return the path of ``file_name`` inside ``directory``."""
        # os.path.join (instead of ``directory + '/' + name``) uses the platform
        # separator, accepts os.PathLike directories, and never turns an empty
        # directory string into a filesystem-root path.
        return os.path.join(directory, file_name)

    def save(self):
        """Serialize ``kg`` into ``kg.path_for_serialization``.

        In byte-pair-encoding mode (``kg.byte_pair_encoding`` truthy), this
        writes ``train_set.npy``, ``ordered_bpe_entities.p`` and
        ``ordered_bpe_relations.p``. Otherwise it writes:

        - the entity/relation mappings (``.p`` for dicts, ``.csv`` for
          polars/pandas DataFrames);
        - ``train_set.npy``;
        - ``valid_set.npy`` / ``test_set.npy``, when those splits are not None.

        Returns:
            None. Nothing is written when ``kg.path_for_serialization`` is None.

        Raises:
            AssertionError: if ``kg.path_for_deserialization`` is not None, or
                (non-BPE mode) ``kg.train_set`` is not a ``np.ndarray``.
            RuntimeError: if ``kg.entity_to_idx`` is neither a dict nor a
                polars/pandas DataFrame.
        """
        assert self.kg.path_for_deserialization is None
        if self.kg.path_for_serialization is None:
            # No serialization
            return None
        if self.kg.byte_pair_encoding:
            self._save_bpe()
        else:
            assert isinstance(self.kg.train_set, np.ndarray)
            # (1) Save dictionary mappings into disk
            self._save_mappings()
            # (2) Save the index-encoded train/valid/test splits.
            self._save_splits()
        return None

    def _save_bpe(self):
        """Save the BPE training set and the ordered BPE entity/relation lists."""
        root = self.kg.path_for_serialization
        save_numpy_ndarray(data=self.kg.train_set, file_path=self._path(root, 'train_set.npy'))
        # NOTE(review): despite this message, train_set and the BPE orderings ARE
        # saved; only the entity/relation mappings and the valid/test splits are
        # skipped in BPE mode. Message kept verbatim for log compatibility.
        print("NO SAVING for BPE at save_load_disk.py")
        save_pickle(data=self.kg.ordered_bpe_entities,
                    file_path=self._path(root, 'ordered_bpe_entities.p'))
        save_pickle(data=self.kg.ordered_bpe_relations,
                    file_path=self._path(root, 'ordered_bpe_relations.p'))

    def _save_mappings(self):
        """Write entity/relation index mappings in a format chosen by their type."""
        root = self.kg.path_for_serialization
        entity_to_idx = self.kg.entity_to_idx
        relation_to_idx = self.kg.relation_to_idx
        # Only entity_to_idx's type is inspected; relation_to_idx is assumed to
        # have the same type.
        if isinstance(entity_to_idx, dict):
            save_pickle(data=entity_to_idx, file_path=self._path(root, 'entity_to_idx.p'))
            save_pickle(data=relation_to_idx, file_path=self._path(root, 'relation_to_idx.p'))
        elif isinstance(entity_to_idx, polars.DataFrame):
            # NOTE(review): polars DataFrames have no row index, so this CSV
            # presumably lacks the leading index column the pandas branch writes,
            # while load() always reads with index_col=0 (the first column becomes
            # the index). Confirm this is intended for polars-produced files.
            entity_to_idx.write_csv(file=self._path(root, "entity_to_idx.csv"), include_header=True)
            relation_to_idx.write_csv(file=self._path(root, "relation_to_idx.csv"), include_header=True)
        elif isinstance(entity_to_idx, pandas.DataFrame):
            # pandas writes its row index as the first CSV column (read back via
            # index_col=0 in _load_mappings).
            entity_to_idx.to_csv(path_or_buf=self._path(root, "entity_to_idx.csv"), header=True)
            relation_to_idx.to_csv(path_or_buf=self._path(root, "relation_to_idx.csv"), header=True)
        else:
            raise RuntimeError("Unexpected type for entity_to_idx or relation_to_idx")

    def _save_splits(self):
        """Save train_set, plus valid_set/test_set when they are not None."""
        root = self.kg.path_for_serialization
        save_numpy_ndarray(data=self.kg.train_set, file_path=self._path(root, 'train_set.npy'))
        for split_name in ('valid_set', 'test_set'):
            split = getattr(self.kg, split_name)
            if split is not None:
                save_numpy_ndarray(data=split, file_path=self._path(root, split_name + '.npy'))

    def load(self):
        """Populate ``kg`` from the files in ``kg.path_for_deserialization``.

        Entity/relation mappings are read as pandas DataFrames from the
        ``*_to_idx.csv`` files when both exist. Otherwise they are read from
        the legacy ``*_to_idx.p`` files. ``num_entities`` / ``num_relations``
        are set to the mapping lengths.

        ``train_set.npy`` is always loaded. ``valid_set.npy`` and
        ``test_set.npy`` are loaded only if present; when a file is absent, the
        corresponding ``kg`` attribute is left untouched.

        When ``kg.eval_model`` is truthy, ``er_vocab.p``, ``re_vocab.p`` and
        ``ee_vocab.p`` are loaded, plus ``constraints.p`` if it exists.

        Raises:
            AssertionError: if ``kg.path_for_deserialization`` is None or differs
                from ``kg.path_for_serialization``.
            FileNotFoundError: if neither mapping format is present.
        """
        assert self.kg.path_for_deserialization is not None
        assert self.kg.path_for_serialization == self.kg.path_for_deserialization
        root = self.kg.path_for_deserialization

        self._load_mappings(root)
        self.kg.num_entities = len(self.kg.entity_to_idx)
        self.kg.num_relations = len(self.kg.relation_to_idx)

        self.kg.train_set = load_numpy_ndarray(file_path=self._path(root, 'train_set.npy'))
        for split_name in ('valid_set', 'test_set'):
            split_path = self._path(root, split_name + '.npy')
            if os.path.isfile(split_path):
                setattr(self.kg, split_name, load_numpy_ndarray(file_path=split_path))

        if self.kg.eval_model:
            # These vocab files are not written by save(); presumably they are
            # produced elsewhere in the pipeline.
            self.kg.er_vocab = load_pickle(file_path=self._path(root, 'er_vocab.p'))
            self.kg.re_vocab = load_pickle(file_path=self._path(root, 're_vocab.p'))
            self.kg.ee_vocab = load_pickle(file_path=self._path(root, 'ee_vocab.p'))
            # NOTE(review): placement inside the eval_model branch reconstructed
            # from collapsed source; verify against upstream.
            constraints_path = self._path(root, 'constraints.p')
            if os.path.isfile(constraints_path):
                self.kg.domain_constraints_per_rel, self.kg.range_constraints_per_rel = load_pickle(
                    file_path=constraints_path
                )

    def _load_mappings(self, root):
        """Load entity/relation mappings from ``root``, preferring CSV over pickle."""
        entity_csv = self._path(root, 'entity_to_idx.csv')
        relation_csv = self._path(root, 'relation_to_idx.csv')
        entity_pkl = self._path(root, 'entity_to_idx.p')
        relation_pkl = self._path(root, 'relation_to_idx.p')
        # Backward compatible loading: prefer CSV, fallback to legacy pickle format.
        if os.path.isfile(entity_csv) and os.path.isfile(relation_csv):
            self.kg.entity_to_idx = pandas.read_csv(entity_csv, index_col=0)
            self.kg.relation_to_idx = pandas.read_csv(relation_csv, index_col=0)
        elif os.path.isfile(entity_pkl) and os.path.isfile(relation_pkl):
            # NOTE(review): load_pickle presumably unpickles; only load
            # directories from a trusted source.
            self.kg.entity_to_idx = load_pickle(file_path=entity_pkl)
            self.kg.relation_to_idx = load_pickle(file_path=relation_pkl)
        else:
            raise FileNotFoundError(
                f"Could not find mapping files in {root}. "
                "Expected entity_to_idx/relation_to_idx as either .csv or .p"
            )