Source code for dicee.read_preprocess_save_load_kg.save_load_disk

import numpy as np
import polars
import pandas
from .util import load_pickle, load_numpy_ndarray
import os
from dicee.static_funcs import save_pickle, save_numpy_ndarray


class LoadSaveToDisk:
    """Write a knowledge-graph object's index mappings and triple arrays to disk, and read them back.

    The wrapped ``kg`` object is accessed purely through attributes; the ones each
    method reads or writes are listed in :meth:`save` and :meth:`load`.
    """

    def __init__(self, kg):
        # NOTE(review): ``kg`` is presumably a dicee KG instance; only attribute access is used here.
        self.kg = kg

    # ------------------------------------------------------------------ paths
    def _serialization_file(self, file_name: str) -> str:
        """Return the path of ``file_name`` inside ``kg.path_for_serialization``."""
        return os.path.join(self.kg.path_for_serialization, file_name)

    def _deserialization_file(self, file_name: str) -> str:
        """Return the path of ``file_name`` inside ``kg.path_for_deserialization``."""
        return os.path.join(self.kg.path_for_deserialization, file_name)

    # ------------------------------------------------------------------ save
    def save(self):
        """Write the KG's mappings and datasets under ``kg.path_for_serialization``.

        Returns ``None`` without writing anything when ``kg.path_for_serialization`` is ``None``.

        With ``kg.byte_pair_encoding`` truthy, writes ``train_set.npy``,
        ``ordered_bpe_entities.p`` and ``ordered_bpe_relations.p``.
        Otherwise writes the entity/relation mappings (``*.p`` pickles for dicts,
        ``*.csv`` for polars/pandas DataFrames), ``train_set.npy``, and
        ``valid_set.npy`` / ``test_set.npy`` when those sets are not ``None``.

        Raises:
            AssertionError: if ``kg.path_for_deserialization`` is not ``None``, or (non-BPE)
                if ``kg.train_set`` is not a ``numpy.ndarray``.
            RuntimeError: if ``kg.entity_to_idx`` is neither a dict, a polars DataFrame
                nor a pandas DataFrame.
        """
        # save() refuses to run on a KG configured for deserialization.
        assert self.kg.path_for_deserialization is None
        if self.kg.path_for_serialization is None:
            # No serialization requested.
            return None

        if self.kg.byte_pair_encoding:
            self._save_bpe()
        else:
            self._save_indexed()
        return None

    def _save_bpe(self) -> None:
        """Write the BPE training set plus the ordered BPE entity/relation sequences."""
        save_numpy_ndarray(data=self.kg.train_set, file_path=self._serialization_file('train_set.npy'))
        # NOTE(review): this message is printed even though train_set and the BPE lists are
        # saved; presumably it means index mappings/valid/test sets are not saved in BPE mode.
        # load() below reads entity_to_idx.p, which this branch never writes — confirm intent.
        print("NO SAVING for BPE at save_load_disk.py")
        save_pickle(data=self.kg.ordered_bpe_entities,
                    file_path=self._serialization_file('ordered_bpe_entities.p'))
        save_pickle(data=self.kg.ordered_bpe_relations,
                    file_path=self._serialization_file('ordered_bpe_relations.p'))

    def _save_indexed(self) -> None:
        """Write the index mappings, then the train/valid/test arrays (valid/test only if present)."""
        assert isinstance(self.kg.train_set, np.ndarray)
        # (1) Save dictionary mappings into disk.
        self._save_mappings()
        # (2) Save the datasets.
        save_numpy_ndarray(data=self.kg.train_set, file_path=self._serialization_file('train_set.npy'))
        if self.kg.valid_set is not None:
            save_numpy_ndarray(data=self.kg.valid_set, file_path=self._serialization_file('valid_set.npy'))
        if self.kg.test_set is not None:
            save_numpy_ndarray(data=self.kg.test_set, file_path=self._serialization_file('test_set.npy'))

    def _save_mappings(self) -> None:
        """Write entity/relation mappings in a format chosen by the type of ``entity_to_idx``."""
        # The type of entity_to_idx alone decides the format; relation_to_idx is assumed to match.
        entity_to_idx = self.kg.entity_to_idx
        relation_to_idx = self.kg.relation_to_idx
        if isinstance(entity_to_idx, dict):
            save_pickle(data=entity_to_idx, file_path=self._serialization_file('entity_to_idx.p'))
            save_pickle(data=relation_to_idx, file_path=self._serialization_file('relation_to_idx.p'))
        elif isinstance(entity_to_idx, polars.DataFrame):
            entity_to_idx.write_csv(file=self._serialization_file("entity_to_idx.csv"), include_header=True)
            relation_to_idx.write_csv(file=self._serialization_file("relation_to_idx.csv"), include_header=True)
        elif isinstance(entity_to_idx, pandas.DataFrame):
            entity_to_idx.to_csv(path_or_buf=self._serialization_file("entity_to_idx.csv"), header=True)
            relation_to_idx.to_csv(path_or_buf=self._serialization_file("relation_to_idx.csv"), header=True)
        else:
            raise RuntimeError("Unexpected type for entity_to_idx or relation_to_idx")

    # ------------------------------------------------------------------ load
    def load(self):
        """Populate ``kg`` from files under ``kg.path_for_deserialization``.

        Sets ``entity_to_idx``, ``relation_to_idx`` (both must unpickle to dicts),
        ``num_entities``, ``num_relations`` and ``train_set``; sets ``valid_set`` /
        ``test_set`` only if the corresponding ``.npy`` file exists (otherwise they are
        left untouched). When ``kg.eval_model`` is truthy, also loads ``er_vocab``,
        ``re_vocab``, ``ee_vocab`` and the domain/range constraints.

        Raises:
            AssertionError: if ``kg.path_for_deserialization`` is ``None``, differs from
                ``kg.path_for_serialization``, or a mapping does not unpickle to a dict.
            FileNotFoundError: if a mapping exists only as the ``.csv`` written by the
                polars/pandas branch of :meth:`save`, which this loader cannot read.
        """
        assert self.kg.path_for_deserialization is not None
        assert self.kg.path_for_serialization == self.kg.path_for_deserialization

        self.kg.entity_to_idx = self._load_mapping('entity_to_idx')
        self.kg.relation_to_idx = self._load_mapping('relation_to_idx')
        assert isinstance(self.kg.entity_to_idx, dict)
        assert isinstance(self.kg.relation_to_idx, dict)
        self.kg.num_entities = len(self.kg.entity_to_idx)
        self.kg.num_relations = len(self.kg.relation_to_idx)

        self.kg.train_set = load_numpy_ndarray(file_path=self._deserialization_file('train_set.npy'))
        valid_path = self._deserialization_file('valid_set.npy')
        if os.path.isfile(valid_path):
            self.kg.valid_set = load_numpy_ndarray(file_path=valid_path)
        test_path = self._deserialization_file('test_set.npy')
        if os.path.isfile(test_path):
            self.kg.test_set = load_numpy_ndarray(file_path=test_path)

        if self.kg.eval_model:
            # NOTE(review): save() in this class does not write these files; presumably
            # another component does — confirm before relying on them.
            self.kg.er_vocab = load_pickle(file_path=self._deserialization_file('er_vocab.p'))
            self.kg.re_vocab = load_pickle(file_path=self._deserialization_file('re_vocab.p'))
            self.kg.ee_vocab = load_pickle(file_path=self._deserialization_file('ee_vocab.p'))
            self.kg.domain_constraints_per_rel, self.kg.range_constraints_per_rel = load_pickle(
                file_path=self._deserialization_file('constraints.p'))

    def _load_mapping(self, name: str):
        """Load ``<name>.p``; raise a clear error if only the CSV variant from save() exists."""
        pickle_path = self._deserialization_file(name + '.p')
        csv_path = self._deserialization_file(name + '.csv')
        if not os.path.isfile(pickle_path) and os.path.isfile(csv_path):
            # save() writes CSV for polars/pandas mappings, but this loader only reads pickles.
            raise FileNotFoundError(
                f"{pickle_path} not found; the mapping was saved as {csv_path} "
                "(polars/pandas backend), which load() does not support.")
        # NOTE(review): load_pickle presumably unpickles, which can execute arbitrary code —
        # only deserialize directories you trust.
        return load_pickle(file_path=pickle_path)