"""Source code for dicee.sanity_checkers."""

import glob
import os
import warnings

import requests
import torch


def is_sparql_endpoint_alive(sparql_endpoint: str = None, timeout: float = 10.0):
    """Check whether a SPARQL endpoint is reachable and answering queries.

    Sends a COUNT(*) query to the endpoint via HTTP POST.

    Parameters
    ----------
    sparql_endpoint: URL of the endpoint. Falsy values (None, empty string)
        are treated as "no endpoint configured".
    timeout: seconds to wait for the HTTP response. New backward-compatible
        parameter; previously the request could block indefinitely on an
        unresponsive host.

    Returns
    -------
    bool: True when the endpoint answered successfully, False when no
        endpoint was given.

    Raises
    ------
    AssertionError: if the endpoint responded with a non-OK HTTP status.
    requests.RequestException: on connection failure or timeout.
    """
    if not sparql_endpoint:
        return False
    query = """SELECT (COUNT(*) as ?num_triples) WHERE { ?s ?p ?o .} """
    # The timeout prevents an unreachable endpoint from hanging start-up forever.
    response = requests.post(sparql_endpoint, data={'query': query}, timeout=timeout)
    assert response.ok, f"SPARQL endpoint responded with HTTP status {response.status_code}"
    print('SPARQL connection is successful')
    return response.ok
def validate_knowledge_graph(args):
    """Validate the source of the knowledge graph.

    Exactly one of ``args.sparql_endpoint``, ``args.path_single_kg`` and
    ``args.dataset_dir`` is kept; the competing two are reset to ``None``
    on ``args`` in place.

    Raises
    ------
    AssertionError: if ``dataset_dir`` is not a string.
    FileNotFoundError: if ``dataset_dir`` does not exist.
    ValueError: if ``dataset_dir`` lacks a train file, or no source is given.
    RuntimeError: if none of the branches applies (should be unreachable).
    """
    # (1) A live SPARQL endpoint takes precedence over the other sources.
    if is_sparql_endpoint_alive(args.sparql_endpoint):
        if not (args.dataset_dir is None and args.path_single_kg is None):
            # BUGFIX: this previously *raised* RuntimeWarning, which made the
            # clean-up below unreachable and contradicted the message
            # "These two parameters are set to None." — warn and continue.
            warnings.warn(f'The dataset_dir and path_single_kg arguments '
                          f'must be None if sparql_endpoint is given.'
                          f'***{args.dataset_dir}***\n'
                          f'***{args.path_single_kg}***\n'
                          f'These two parameters are set to None.',
                          RuntimeWarning)
        # Set None.
        args.dataset_dir = None
        args.path_single_kg = None
    # (2) Single KG file (e.g. an OWL/RDF document).
    elif args.path_single_kg is not None:
        # path_single_kg wins; discard the competing sources.
        args.dataset_dir = None
        args.sparql_endpoint = None
    # (3) Standard dataset folder containing train/valid/test splits.
    elif args.dataset_dir:
        if not isinstance(args.dataset_dir, str):
            raise AssertionError(f'The dataset_dir must be string sparql_endpoint is not given.'
                                 f'***{args.dataset_dir}***')
        if not (os.path.isdir(args.dataset_dir) or os.path.isfile(args.dataset_dir)):
            raise FileNotFoundError(
                f"Dataset directory not found: {args.dataset_dir}\n"
                f"\nSuggestions:\n"
                f" 1. Download datasets:\n"
                f" wget https://files.dice-research.org/datasets/dice-embeddings/KGs.zip --no-check-certificate\n"
                f" unzip KGs.zip\n"
                f" 2. Use absolute path: --dataset_dir /absolute/path/to/KGs/UMLS\n"
                f" 3. Check current directory: {os.getcwd()}\n"
            )
        # Check whether the input parameter leads a standard data format (e.g. FOLDER/train.txt)
        if not glob.glob(args.dataset_dir + '/train*'):
            raise ValueError(
                f"Dataset directory must contain train.txt file: {args.dataset_dir}\n"
                f"\nExpected structure:\n"
                f" {args.dataset_dir}/\n"
                f" ├── train.txt (required)\n"
                f" ├── valid.txt (optional)\n"
                f" └── test.txt (optional)\n"
                f"\nFor single file datasets, use: --path_single_kg folder/dataset.owl\n"
                f"For SPARQL endpoints, use: --sparql_endpoint http://localhost:3030/dataset/\n"
            )
        # dataset_dir wins; discard the competing sources.
        args.sparql_endpoint = None
        args.path_single_kg = None
    elif args.dataset_dir is None and args.path_single_kg is None and args.sparql_endpoint is None:
        raise ValueError(
            "No data source specified. You must provide ONE of the following:\n"
            "\nOption 1: Standard dataset folder\n"
            " --dataset_dir KGs/UMLS\n"
            "\nOption 2: Single RDF/OWL file\n"
            " --path_single_kg KGs/Family/family.owl --backend rdflib\n"
            "\nOption 3: SPARQL endpoint\n"
            " --sparql_endpoint http://localhost:3030/mydata/\n"
            "\nFor examples, see: tests/test_different_backends.py\n"
        )
    else:
        raise RuntimeError('Invalid computation flow!')
def sanity_checking_with_arguments(args):
    """Assert basic hyper-parameter validity, then validate the KG source.

    Mutates ``args`` in place: a missing ``num_folds_for_cv`` becomes 0.
    """
    assert args.embedding_dim > 0, f"embedding_dim must be strictly positive. Currently:{args.embedding_dim}"
    allowed_techniques = ["AllvsAll", "1vsSample", "KvsSample", "KvsAll", "FixedNegSample",
                          "NegSample", "1vsAll", "Pyke", "Sentence"]
    assert args.scoring_technique in allowed_techniques, f"Invalid training strategy => {args.scoring_technique}."
    assert args.learning_rate > 0, f"Learning rate must be greater than 0. Currently:{args.learning_rate}"
    # Normalise the optional cross-validation fold count before range-checking it.
    args.num_folds_for_cv = 0 if args.num_folds_for_cv is None else args.num_folds_for_cv
    assert args.num_folds_for_cv >= 0, f"num_folds_for_cv can not be negative. Currently:{args.num_folds_for_cv}"
    # Finally ensure exactly one knowledge-graph source is configured.
    validate_knowledge_graph(args)
def sanity_check_callback_args(args):
    """ Perform sanity checks on callback-related arguments. """
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    # Multi-GPU training imposes extra requirements regardless of callbacks.
    multi_gpu = args.trainer == "torchDDP" or (args.trainer == "PL" and num_gpus >= 2)
    if multi_gpu:
        if args.path_to_store_single_run is None:
            raise NotImplementedError("Path to store experiments must be provided for Multi-GPU training.")
        if args.adaptive_lr:
            raise NotImplementedError("Adaptive learning rate is not supported with Multi-GPU training.")
    # Bail out early when no callback of any kind was requested.
    callbacks_requested = (args.swa or args.swag or args.ema or args.adaptive_swa
                           or args.twa or args.adaptive_lr
                           or args.eval_every_n_epochs > 0
                           or args.eval_at_epochs is not None)
    if not callbacks_requested:
        return
    # Weight-averaging callbacks share the swa_start_epoch setting.
    if args.swa or args.swag or args.ema or args.twa:
        args.swa_start_epoch = args.swa_start_epoch or 1
        assert args.swa_start_epoch > 0, "SWA Start Epoch must be greater than 0"
    # TWA/SWAG are incompatible with tensor-parallel / DDP trainers.
    if (args.twa or args.swag) and args.trainer in {"TP", "torchDDP"}:
        raise NotImplementedError("TWA and SWAG are not supported with TP or torchDDP trainers.")