diff --git a/docs/source/03_api/core/edge_generators.rst b/docs/source/03_api/core/edge_generators.rst
index b156a27..bb1635c 100644
--- a/docs/source/03_api/core/edge_generators.rst
+++ b/docs/source/03_api/core/edge_generators.rst
@@ -1,38 +1,38 @@
 Edge Generators
 ========================
 
-- :func:`element_element_neighborsByGroupPeriod <matgraphdb.core.edges.element_element_neighborsByGroupPeriod>` - A function that generates the neighbors of an element by group and period.
+- :func:`element_element_neighborsByGroupPeriod <matgraphdb.generators.edges.element_element_neighborsByGroupPeriod>` - A function that generates the neighbors of an element by group and period.
 
-- :func:`element_element_bonds <matgraphdb.core.edges.element_element_bonds>` - A function that generates the bonds of an element.
+- :func:`element_element_bonds <matgraphdb.generators.edges.element_element_bonds>` - A function that generates the bonds of an element.
 
-- :func:`element_oxiState_canOccur <matgraphdb.core.edges.element_oxiState_canOccur>` - A function that generates the possible oxidation states of an element.
+- :func:`element_oxiState_canOccur <matgraphdb.generators.edges.element_oxiState_canOccur>` - A function that generates the possible oxidation states of an element.
 
-- :func:`material_chemenv_containsSite <matgraphdb.core.edges.material_chemenv_containsSite>` - A function that generates the sites of a material.
+- :func:`material_chemenv_containsSite <matgraphdb.generators.edges.material_chemenv_containsSite>` - A function that generates the chemical environments of the sites of a material.
 
-- :func:`material_crystalSystem_has <matgraphdb.core.edges.material_crystalSystem_has>` - A function that generates the crystal system of a material.
+- :func:`material_crystalSystem_has <matgraphdb.generators.edges.material_crystalSystem_has>` - A function that generates the crystal system of a material.
 
-- :func:`material_element_has <matgraphdb.core.edges.material_element_has>` - A function that generates the elements of a material.
+- :func:`material_element_has <matgraphdb.generators.edges.material_element_has>` - A function that generates the elements of a material.
 
-- :func:`material_lattice_has <matgraphdb.core.edges.material_lattice_has>` - A function that generates the lattice of a material.
+- :func:`material_lattice_has <matgraphdb.generators.edges.material_lattice_has>` - A function that generates the lattice of a material.
 
-- :func:`material_spg_has <matgraphdb.core.edges.material_spg_has>` - A function that generates the space group of a material.
+- :func:`material_spg_has <matgraphdb.generators.edges.material_spg_has>` - A function that generates the space group of a material.
 
-- :func:`element_chemenv_canOccur <matgraphdb.core.edges.element_chemenv_canOccur>` - A function that generates the possible oxidation states of an element.
+- :func:`element_chemenv_canOccur <matgraphdb.generators.edges.element_chemenv_canOccur>` - A function that generates the possible chemical environments of an element.
 
-- :func:`spg_crystalSystem_isApart <matgraphdb.core.edges.spg_crystalSystem_isApart>` - A function that generates the crystal system of a material.
+- :func:`spg_crystalSystem_isApart <matgraphdb.generators.edges.spg_crystalSystem_isApart>` - A function that generates the crystal system of a space group.
 
 .. autosummary::
    :toctree: _autosummary
 
-   matgraphdb.core.edges.element_element_neighborsByGroupPeriod
-   matgraphdb.core.edges.element_element_bonds
-   matgraphdb.core.edges.element_oxiState_canOccur
-   matgraphdb.core.edges.material_chemenv_containsSite
-   matgraphdb.core.edges.material_crystalSystem_has
-   matgraphdb.core.edges.material_element_has
-   matgraphdb.core.edges.material_lattice_has
-   matgraphdb.core.edges.material_spg_has
-   matgraphdb.core.edges.element_chemenv_canOccur
-   matgraphdb.core.edges.spg_crystalSystem_isApart
+   matgraphdb.generators.edges.element_element_neighborsByGroupPeriod
+   matgraphdb.generators.edges.element_element_bonds
+   matgraphdb.generators.edges.element_oxiState_canOccur
+   matgraphdb.generators.edges.material_chemenv_containsSite
+   matgraphdb.generators.edges.material_crystalSystem_has
+   matgraphdb.generators.edges.material_element_has
+   matgraphdb.generators.edges.material_lattice_has
+   matgraphdb.generators.edges.material_spg_has
+   matgraphdb.generators.edges.element_chemenv_canOccur
+   matgraphdb.generators.edges.spg_crystalSystem_isApart
diff --git a/docs/source/03_api/core/material_store.rst b/docs/source/03_api/core/material_store.rst
index cf7be9d..90e1536 100644
--- a/docs/source/03_api/core/material_store.rst
+++ b/docs/source/03_api/core/material_store.rst
@@ -1,9 +1,9 @@
 MaterialStore
 ========================
 
-- :class:`MaterialStore <matgraphdb.core.nodes.materials.MaterialStore>` - A store for managing materials in a graph database.
+- :class:`MaterialStore <matgraphdb.core.material_store.MaterialStore>` - A store for managing materials in a graph database.
 
 .. autosummary::
    :toctree: _autosummary
 
-   matgraphdb.core.nodes.materials.MaterialStore
+   matgraphdb.core.material_store.MaterialStore
diff --git a/docs/source/03_api/core/node_generators.rst b/docs/source/03_api/core/node_generators.rst
index c699279..69a0396 100644
--- a/docs/source/03_api/core/node_generators.rst
+++ b/docs/source/03_api/core/node_generators.rst
@@ -2,28 +2,39 @@ Node Generators
 ========================
 
-- :func:`element <matgraphdb.core.nodes.generators.element>` - A function that generates the elements of a material.
+- :func:`element <matgraphdb.generators.nodes.element>` - A function that generates the elements of a material.
 
-- :func:`chemenv <matgraphdb.core.nodes.generators.chemenv>` - A function that generates the chemical environments of a material.
+- :func:`chemenv <matgraphdb.generators.nodes.chemenv>` - A function that generates the chemical environments of a material.
 
-- :func:`crystal_system <matgraphdb.core.nodes.generators.crystal_system>` - A function that generates the crystal systems of a material.
+- :func:`crystal_system <matgraphdb.generators.nodes.crystal_system>` - A function that generates the crystal systems of a material.
+
+- :func:`magnetic_state <matgraphdb.generators.nodes.magnetic_state>` - A function that generates the magnetic states of a material.
+
+- :func:`oxidation_state <matgraphdb.generators.nodes.oxidation_state>` - A function that generates the oxidation states of a material.
+
+- :func:`space_group <matgraphdb.generators.nodes.space_group>` - A function that generates the space groups of a material.
+
+- :func:`wyckoff <matgraphdb.generators.nodes.wyckoff>` - A function that generates the Wyckoff positions of a material.
+
+- :func:`lattice <matgraphdb.generators.nodes.lattice>` - A function that generates the lattices of a material.
+
+- :func:`material_site <matgraphdb.generators.nodes.material_site>` - A function that generates the sites of a material.
 
-- :func:`magnetic_state <matgraphdb.core.nodes.generators.magnetic_state>` - A function that generates the magnetic states of a material.
-- :func:`oxidation_state <matgraphdb.core.nodes.generators.oxidation_state>` - A function that generates the oxidation states of a material.
-- :func:`space_group <matgraphdb.core.nodes.generators.space_group>` - A function that generates the space groups of a material.
-- :func:`wyckoff <matgraphdb.core.nodes.generators.wyckoff>` - A function that generates the wyckoffs of a material.
 
 
 ..
autosummary:: :toctree: _autosummary - matgraphdb.core.nodes.generators.element - matgraphdb.core.nodes.generators.chemenv - matgraphdb.core.nodes.generators.crystal_system - matgraphdb.core.nodes.generators.magnetic_state - matgraphdb.core.nodes.generators.oxidation_state - matgraphdb.core.nodes.generators.space_group - matgraphdb.core.nodes.generators.wyckoff + matgraphdb.generators.nodes.element + matgraphdb.generators.nodes.chemenv + matgraphdb.generators.nodes.crystal_system + matgraphdb.generators.nodes.magnetic_state + matgraphdb.generators.nodes.oxidation_state + matgraphdb.generators.nodes.space_group + matgraphdb.generators.nodes.wyckoff + matgraphdb.generators.nodes.lattice + matgraphdb.generators.nodes.material_site + diff --git a/examples/notebooks/01 - Getting Started.ipynb b/examples/notebooks/01 - Getting Started.ipynb index bfae0c2..695bd4c 100644 --- a/examples/notebooks/01 - Getting Started.ipynb +++ b/examples/notebooks/01 - Getting Started.ipynb @@ -259,7 +259,6 @@ } ], "source": [ - "\n", "from matgraphdb import MatGraphDB\n", "\n", "if not os.path.exists(MATGRAPHDB_PATH):\n", @@ -284,27 +283,24 @@ "metadata": {}, "outputs": [], "source": [ - "from matgraphdb.core.nodes import (\n", - " element, chemenv, crystal_system, magnetic_state, \n", - " oxidation_state, space_group, wyckoff, material_site, material_lattice\n", - ")\n", + "from matgraphdb import generators\n", "\n", "# Here we define the generator functions and arguments if they are needed. \n", "# For instance, to get the materials sites and lattices, we need to pass the materials store to the generator function.\n", "node_generators = [\n", - " {\"generator_func\": element},\n", - " {\"generator_func\": chemenv},\n", - " {\"generator_func\": crystal_system},\n", - " {\"generator_func\": magnetic_state},\n", - " {\"generator_func\": oxidation_state},\n", - " {\"generator_func\": space_group},\n", - " {\"generator_func\": wyckoff},\n", + " {\"generator_func\": generators.element},\n", + " {\"generator_func\": generators.chemenv},\n", + " {\"generator_func\": generators.crystal_system},\n", + " {\"generator_func\": generators.magnetic_state},\n", + " {\"generator_func\": generators.oxidation_state},\n", + " {\"generator_func\": generators.space_group},\n", + " {\"generator_func\": generators.wyckoff},\n", " {\n", - " \"generator_func\": material_site,\n", + " \"generator_func\": generators.material_site,\n", " \"generator_args\": {\"material_store\": mdb.node_stores[\"material\"]},\n", " },\n", " {\n", - " \"generator_func\": material_lattice,\n", + " \"generator_func\": generators.material_lattice,\n", " \"generator_args\": {\"material_store\": mdb.node_stores[\"material\"]},\n", " },\n", "]\n" @@ -805,71 +801,57 @@ } ], "source": [ - "from matgraphdb.core.edges import (\n", - " material_element_has,\n", - " material_lattice_has,\n", - " material_spg_has,\n", - " element_element_neighborsByGroupPeriod,\n", - " element_element_bonds,\n", - " element_oxiState_canOccur,\n", - " material_chemenv_containsSite,\n", - " material_crystalSystem_has,\n", - " element_chemenv_canOccur,\n", - " spg_crystalSystem_isApart,\n", - ")\n", - "\n", - "\n", "\n", "# List of edge generator configurations.\n", "edge_generators = [\n", " {\n", - " \"generator_func\": element_element_neighborsByGroupPeriod,\n", + " \"generator_func\": generators.element_element_neighborsByGroupPeriod,\n", " \"generator_args\": {\"element_store\": mdb.node_stores[\"element\"]},\n", " },\n", " {\n", - " \"generator_func\": element_oxiState_canOccur,\n", 
+ " \"generator_func\": generators.element_oxiState_canOccur,\n", " \"generator_args\": {\n", " \"element_store\": mdb.node_stores[\"element\"],\n", " \"oxiState_store\": mdb.node_stores[\"oxidation_state\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": material_chemenv_containsSite,\n", + " \"generator_func\": generators.material_chemenv_containsSite,\n", " \"generator_args\": {\n", " \"material_store\": mdb.node_stores[\"material\"],\n", " \"chemenv_store\": mdb.node_stores[\"chemenv\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": material_crystalSystem_has,\n", + " \"generator_func\": generators.material_crystalSystem_has,\n", " \"generator_args\": {\n", " \"material_store\": mdb.node_stores[\"material\"],\n", " \"crystal_system_store\": mdb.node_stores[\"crystal_system\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": material_element_has,\n", + " \"generator_func\": generators.material_element_has,\n", " \"generator_args\": {\n", " \"material_store\": mdb.node_stores[\"material\"],\n", " \"element_store\": mdb.node_stores[\"element\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": material_lattice_has,\n", + " \"generator_func\": generators.material_lattice_has,\n", " \"generator_args\": {\n", " \"material_store\": mdb.node_stores[\"material\"],\n", " \"lattice_store\": mdb.node_stores[\"material_lattice\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": material_spg_has,\n", + " \"generator_func\": generators.material_spg_has,\n", " \"generator_args\": {\n", " \"material_store\": mdb.node_stores[\"material\"],\n", " \"spg_store\": mdb.node_stores[\"space_group\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": element_chemenv_canOccur,\n", + " \"generator_func\": generators.element_chemenv_canOccur,\n", " \"generator_args\": {\n", " \"element_store\": mdb.node_stores[\"element\"],\n", " \"chemenv_store\": mdb.node_stores[\"chemenv\"],\n", @@ -877,14 +859,14 @@ " },\n", " },\n", " {\n", - " \"generator_func\": spg_crystalSystem_isApart,\n", + " \"generator_func\": generators.spg_crystalSystem_isApart,\n", " \"generator_args\": {\n", " \"spg_store\": mdb.node_stores[\"space_group\"],\n", " \"crystal_system_store\": mdb.node_stores[\"crystal_system\"],\n", " },\n", " },\n", " {\n", - " \"generator_func\": element_element_bonds,\n", + " \"generator_func\": generators.element_element_bonds,\n", " \"generator_args\": {\n", " \"element_store\": mdb.node_stores[\"element\"],\n", " \"material_store\": mdb.node_stores[\"material\"],\n", diff --git a/matgraphdb/__init__.py b/matgraphdb/__init__.py index efcf75e..2e9e3d9 100644 --- a/matgraphdb/__init__.py +++ b/matgraphdb/__init__.py @@ -1,3 +1,8 @@ +from parquetdb.utils.log_utils import setup_logging + from matgraphdb._version import __version__ + +setup_logging() + from matgraphdb.core import MaterialStore, MatGraphDB from matgraphdb.utils.config import PKG_DIR, config diff --git a/matgraphdb/calculations/README.md b/matgraphdb/calculations/README.md deleted file mode 100644 index 0ada8a8..0000000 --- a/matgraphdb/calculations/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Calculations Module for MatGraphDB - -Welcome to the `calculations` module of the MatGraphDB package. This module is dedicated to performing and managing various types of calculations essential for analyzing material properties and interactions. It is structured into two primary subdirectories, each specializing in a distinct type of calculation relevant to material science research. 
- -## Module Structure - -The `calculations` module is organized into the following subdirectories: - -- **`mat_calcs`**: Focuses on material-specific calculations that are pivotal for understanding and predicting material properties and behaviors. This includes bonding calculations, chemenv calculations, similarity calculations, and more. The goal of these calculations is to provide insights into the material's structural, chemical, and physical properties from a theoretical or empirical perspective. - -- **`dft_calcs`**: Dedicated to Density Functional Theory (DFT) calculations. DFT is a quantum mechanical modeling method used in physics and chemistry to investigate the electronic structure (principally the ground state) of many-body systems, especially atoms, molecules, and the condensed phases. This directory contains scripts and utilities for setting up, running, and analyzing DFT calculations, focusing on materials science applications. - -## Purpose - -The `calculations` module is designed to serve as a comprehensive toolkit for conducting advanced material analyses within the MatGraphDB framework. By segregating the calculation tools into material-specific (`mat_calcs`) and DFT-focused (`dft_calcs`) categories, the module ensures a structured approach to material research, enabling users to apply the most appropriate computational techniques for their specific needs. - -Whether you are looking to explore new material properties, validate theoretical models, or conduct in-depth material analyses, the `calculations` module provides the necessary computational tools and frameworks to support your research. - -## Getting Started - -To begin using the `calculations` module: - -1. Ensure you have MatGraphDB and its dependencies properly installed. -2. Familiarize yourself with the specific functions and utilities available within the `mat_calcs` and `dft_calcs` directories. Each subdirectory contains a set of scripts designed for specific types of calculations. -3. Consult the documentation and example scripts provided within each subdirectory to understand the input requirements, execution process, and output interpretation for each type of calculation. - -## Contribution - -We welcome contributions to the `calculations` module, including enhancements to existing calculation scripts, addition of new calculation types, and improvements to documentation. If you're interested in contributing, please review our contribution guidelines and submit your pull requests or issues through our repository's issue tracker. - ---- - -This README.md provides a starting point for documenting the `calculations` module within your MatGraphDB package. Adjust and expand upon it as necessary to fit the specifics of your implementation and the needs of your user base. 
\ No newline at end of file diff --git a/matgraphdb/calculations/__init__.py b/matgraphdb/calculations/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/calculations/calc_manager.py b/matgraphdb/calculations/calc_manager.py deleted file mode 100644 index 177f3c4..0000000 --- a/matgraphdb/calculations/calc_manager.py +++ /dev/null @@ -1,890 +0,0 @@ -import logging -import os -import json -import subprocess -from typing import Callable, Dict, List, Tuple, Union - -from matgraphdb.calculations.job_scheduler_generator import SlurmScriptGenerator -from matgraphdb.utils.mp_utils import multiprocess_task -from matgraphdb.utils.general_utils import get_function_args - -logger = logging.getLogger(__name__) - -# TODO: Add calculation name validation -# TODO: Add validation on calculation function -# TODO: Think about letting the calculation function take as arugment a dictionary vs making the user define the needed fields as the arguments -# Doing the latter would allow me to reduce the number of columns loaded into memory from the database. -# However, the former would be neater and if the user wants to speed up loading they can provide the need fields to read_args - -class CalculationManager: - def __init__(self, main_dir, matdb, n_cores=1, job_submission_script_name='run.slurm'): - """ - Initializes the `CalculationManager` with the specified main directory, database manager, number of cores, - and the name of the job submission script. - - This constructor sets up the `CalculationManager` to manage calculations, handle database interactions, - and support job submission via SLURM. - - Parameters: - ----------- - main_dir : str - Main directory path where calculations will be stored and accessed. - matdb : object - Database manager object for handling database operations. - n_cores : int, optional - Number of cores to use for multiprocessing. Defaults to `N_CORES`. - job_submission_script_name : str, optional - Name of the job submission script. Defaults to 'run.slurm'. - - Examples: - --------- - # Example usage: - # Initialize the CalculationManager with a main directory, database manager, and custom settings - .. highlight:: python - .. 
code-block:: python - - from matgraphdb.data.material_manager import MaterialDatabaseManager - from matgraphdb.data.calc_manager import CalculationManager - - # Initialize the database manager - matdb = MaterialDatabaseManager(db_dir="/path/to/main/directory/db") - - # Initialize the CalculationManager - calc_manager = CalculationManager(main_dir="/path/to/main/directory/calculations", - matdb=matdb, - n_cores=4, - job_submission_script_name='custom_script.slurm') - """ - - self.matdb = matdb - self.main_dir = main_dir - self.n_cores = n_cores - self.job_submission_script_name = job_submission_script_name - - self.calculation_dir = os.path.join(self.main_dir, 'MaterialsData') - self.metadata_file = os.path.join(self.main_dir, 'metadata.json') - - os.makedirs(self.calculation_dir, exist_ok=True) - - self.initialized=False - - logger.info(f"Initializing CalculationManager with main directory: {main_dir}") - logger.debug(f"Calculation directory set to: {self.calculation_dir}") - logger.debug(f"Metadata file path set to: {self.metadata_file}") - logger.debug(f"Job submission script name: {self.job_submission_script_name}") - logger.debug(f"Number of cores for multiprocessing: {self.n_cores}") - logger.info("Make sure to initialize the calculation manager before using it") - - def _setup_material_directory(self, directory): - """ - Creates the directory structure for a specific material if it doesn't exist. - - Parameters: - directory (str): Path to the material directory. - - Returns: - None - """ - logger.debug(f"Setting up material directory at: {directory}") - os.makedirs(directory, exist_ok=True) - return None - - def _setup_material_directories(self): - """ - Creates directories for all materials in the database and returns their paths. - - This method reads the list of material IDs from the database and creates corresponding directories - for each material in the calculation directory. - - Parameters: - ----------- - None - - Returns: - -------- - list - A list of material directory paths. - - Examples: - --------- - # Example usage: - # Set up material directories for all materials - .. highlight:: python - .. code-block:: python - # Set up material directories - material_dirs = calc_manager._setup_material_directories() - """ - logger.info("Setting up material directories.") - logger.debug("Reading materials from the database.") - table=self.matdb.read(columns=['id']) - id_df=table.to_pandas() - - material_dirs = [] - for i, row in id_df.iterrows(): - material_id = row['id'] - material_directory = os.path.join(self.calculation_dir, material_id) - logger.debug(f"Setting up directory for material ID {material_id} at {material_directory}") - self._setup_material_directory(material_directory) - material_dirs.append(material_directory) - return material_dirs - - def initialize(self): - """ - Initializes the `CalculationManager` by loading metadata and setting up material directories. - - This method is responsible for preparing the `CalculationManager` by loading necessary metadata - and configuring the material directories where calculations will be stored and accessed. - - Parameters: - ----------- - None - - Returns: - -------- - None - - Examples: - --------- - # Example usage: - # Initialize the CalculationManager - .. highlight:: python - .. 
code-block:: python - - calc_manager.initialize() - """ - logger.info("Initializing CalculationManager.") - self.metadata = self.load_metadata() - logger.debug("Metadata loaded.") - self.material_dirs = self._setup_material_directories() - logger.debug("Material directories set up.") - self.initialized=True - logger.info("CalculationManager initialization complete.") - - def run_inmemory_calculation(self, calc_func:Callable, save_results=False, verbose=False, read_args=None, **kwargs): - """ - Runs a calculation on the data retrieved from the `MaterialDatabaseManager`. - - This method processes data from the material database using a user-defined calculation function. - The `calc_func` is expected to accept a dictionary-like object (each row's data) and return a - dictionary representing the results of the calculation. Optionally, the results can be saved - back to the database. They will save as the key in the return dictionary - - Parameters: - ----------- - calc_func : Callable - A function that processes each row of data. This function should accept a dictionary-like - object representing the row's data and return a dictionary containing the calculated results - for that row. - save_results : bool, optional - A flag indicating whether to save the results back to the database after processing. - Defaults to False. - verbose : bool, optional - A flag indicating whether to print error messages. Defaults to False. - read_args : dict, optional - Additional arguments to pass to the `MaterialDatabaseManager.read` method. - **kwargs - Additional keyword arguments to pass to the calculation function. - - Returns: - -------- - list - A list of result dictionaries returned by the `calc_func` for each row in the database. - - Examples: - --------- - # Example usage: - # Define a calculation function and apply it to in-memory material data - .. highlight:: python - .. code-block:: python - - # Define a custom calculation function - def my_calc_func(row_data, **kwargs): - # Perform some calculations on row_data - return {'result': sum(row_data.values())} - - # Run the in-memory calculation function on all material data - results = calc_manager.run_inmemory_calculation(my_calc_func, save_results=True, verbose=True) - """ - - # arg_names, kwarg_names = get_function_args(calc_func) - # if read_args is None: - # read_args=dict(columns=arg_names) - # read_args['columns'].append('id') - - logger.info("Running in-memory calculation on material data.") - df=self.matdb.read(**read_args) - ids=[] - data=[] - for i,row in df.iterrows(): - ids.append(row['id']) - data.append(row.drop('id').to_dict()) - - logger.debug(f"Retrieved {len(data)} rows from the database.") - - calc_func = calculation_error_handler(calc_func,verbose=verbose) - results=multiprocess_task(calc_func, data, n_cores=self.n_cores, **kwargs) - logger.info("Calculation completed.") - - if save_results: - update_list=[(id,result) for id,result in zip(ids,results)] - self.matdb.update_many(update_list) - logger.info("Results saved back to the database.") - - return results - - def get_calculation_names(self): - """ - Retrieves a list of all calculation names in the database. - - Parameters: - ----------- - None - - Returns: - -------- - List - A list of all calculation names in the database. - - Examples: - --------- - # Example usage: - # Retrieve the list of calculation names - .. highlight:: python - .. 
code-block:: python - - from matgraphdb.data.material_manager import MaterialDatabaseManager - from matgraphdb.data.calc_manager import CalculationManager - - matdb = MaterialDatabaseManager(db_dir="/path/to/main/directory/db") - calc_manager = CalculationManager(main_dir="/path/to/main/directory/calculations", matdb=matdb) - - # Get the list of calculation names - calculation_names = calc_manager.get_calculation_names() - print(calculation_names) - """ - logger.info("Retrieving calculation names.") - calculation_names = os.listdir(self.material_dirs[0]) - logger.debug(f"Calculation names found: {calculation_names}") - self.update_metadata({'calculation_names': calculation_names}) - return calculation_names - - def create_disk_calculation(self, calc_func: Callable, calc_name: str = None, read_args: dict = None, **kwargs): - """ - Creates a new calculation by applying the provided function to each row in the database. - The `calc_func` expects a dictionary-like object (each row's data) and a directory path where - the calculation results should be stored. The function is responsible for saving the results in - this directory, which is specific to each row and named based on the row's unique ID and the - calculation name. - - Parameters: - ----------- - calc_func : Callable - The function to apply to each material. - The first argument should be a dictionary-like object (each row's data) - and the second argument should be the calculation directory path. - calc_name : str, optional - The name of the calculation. If not provided, it defaults to the name of the function. - read_args : dict, optional - Additional arguments to pass to the `MaterialDatabaseManager.read` method. - **kwargs - Additional arguments to pass to the calculation function. - - Returns: - -------- - List - The results of the calculation for each material. - - Examples: - --------- - # Example usage: - # Define a calculation function and apply it to all materials - .. highlight:: python - .. code-block:: python - - # Define a custom calculation function - def my_calc_func(row_data: dict, calc_dir: str, **kwargs): - # Perform some calculation - return None - - # Apply the calculation function to all materials - results = calc_manager.create_disk_calculation(my_calc_func, 'my_calc') - """ - - if read_args is None: - read_args={} - if calc_name is None: - calc_name = calc_func.__name__ - - logger.info(f"Creating calculation '{calc_name}' for all materials.") - - table = self.matdb.read(**read_args) - df=table.to_pandas() - logger.debug(f"Retrieved {len(table.shape[0])} rows from the database.") - - multi_task_list=[] - for row in df.iterrows(): - id=row['id'] - row_data_dict=row.drop('id').to_dict() - calc_dir=os.path.join(self.material_dirs[id],calc_name) - - multi_task_list.append((row_data_dict,calc_dir)) - - - logger.info(f"Prepared tasks for {len(multi_task_list)} materials.") - # Process each row using multiprocessing, passing the directory structure - logger.debug("Starting calculation tasks.") - results=multiprocess_task(calc_func, multi_task_list, n_cores=self.n_cores, **kwargs) - logger.info(f"Calculation '{calc_name}' completed for all materials.") - return results - - def generate_job_scheduler_script_for_calc(self, calc_dir: str, slurm_config: Dict = None, script_string: str = None): - """ - Generates a SLURM job scheduler submission script for a specific calculation. - - Parameters: - ----------- - calc_dir : str - The directory where the calculation is stored. 
- slurm_config : dict, optional - Configuration settings for the SLURM script. Defaults to None. - script_string : str, optional - A the job submission script content. Defaults to None. - - Returns: - -------- - Tuple - The path to the generated SLURM script and its content. - - Examples: - --------- - # Example usage: - # Generate a SLURM script for a specific calculation directory - .. highlight:: python - .. code-block:: python - - from matgraphdb.data.material_manager import MaterialDatabaseManager - from matgraphdb.data.calc_manager import CalculationManager - - matdb = MaterialDatabaseManager(db_dir="/path/to/main/directory/db") - calc_manager = CalculationManager(main_dir="/path/to/main/directory/calculations", matdb=matdb) - - # Generate SLURM script with default settings - calc_dir = "/path/to/material/calculation_directory" - slurm_script_path, slurm_script_content = calc_manager.generate_job_scheduler_script_for_calc(calc_dir) - - # Generate SLURM script with custom configuration - slurm_config = {"job_name": "custom_job", "partition": "batch", "time": "02:00:00", "command": "./custom_command.sh"} - slurm_script_path, slurm_script_content = calc_manager.generate_job_scheduler_script_for_calc(calc_dir, slurm_config=slurm_config) - - # Generate jobn submission scripts based on custom string - script_string="Your script string" - results = calc_manager.generate_job_scheduler_script_for_calcs("calculation_name", slurm_script=script_string) - """ - calc_name=os.path.basename(calc_dir) - material_id=os.path.basename(os.path.dirname(calc_dir)) - - logger.info(f"Generating SLURM script for calculation '{calc_name}' in material '{material_id}'.") - - if slurm_config: - logger.debug(f"SLURM config: {slurm_config}") - slurm_generator = SlurmScriptGenerator( - job_name=slurm_config.get('job_name', f"{calc_name}_calc_{material_id}"), - partition=slurm_config.get('partition', 'comm_small_day'), - time=slurm_config.get('time', '24:00:00') - ) - - slurm_generator.init_header() - slurm_generator.add_slurm_header_comp_resources( - n_nodes=slurm_config.get('n_nodes'), - n_tasks=slurm_config.get('n_tasks'), - cpus_per_task=slurm_config.get('cpus_per_task') - ) - slurm_generator.add_slurm_script_body(f"cd {calc_dir}") - slurm_generator.add_slurm_script_body(slurm_config.get('command', './run_calculation.sh')) - - slurm_script = slurm_generator.finalize() - - # Save the SLURM script in the job directory - slurm_script_path = os.path.join(calc_dir, self.job_submission_script_name) - - with open(slurm_script_path, 'w') as f: - f.write(slurm_script) - logger.info(f"SLURM script saved at {slurm_script_path}") - - return slurm_script_path, slurm_script - - def generate_job_scheduler_script_for_calcs(self, calc_name, slurm_config: Dict = None, script_string: str = None, **kwargs): - """ - Generates job scheduler submission scripts for all materials using the specified calculation name. - - Parameters: - ----------- - calc_name : str - The name of the calculation for which job scheduler scripts will be generated. - slurm_config : dict, optional - Configuration settings for the SLURM script. Defaults to None. - script_string : str, optional - A the job submission script content. Defaults to None. - - **kwargs - Additional keyword arguments to pass to the script generator. - - Returns: - -------- - List - The results of script generation for each material. - - Examples: - --------- - # Example usage: - # Generate SLURM scripts for all materials using a specific calculation name - .. highlight:: python - .. 
code-block:: python - - # Generate SLURM scripts for all materials with default settings - results = calc_manager.generate_job_scheduler_script_for_calcs("calculation_name") - - # Generate SLURM scripts with custom SLURM configuration - slurm_config = {"time": "02:00:00", "partition": "batch"} - results = calc_manager.generate_job_scheduler_script_for_calcs("calculation_name", slurm_config=slurm_config) - - # Generate jobn submission scripts based on custom string - script_string="Your script string" - results = calc_manager.generate_job_scheduler_script_for_calcs("calculation_name", script_string=script_string) - """ - logger.info(f"Generating SLURM scripts for calculation '{calc_name}' for all materials.") - multi_task_list=[] - for material_dir in self.material_dirs: - calc_dir=os.path.join(material_dir,calc_name) - multi_task_list.append(calc_dir) - logger.debug(f"Prepared SLURM script generation tasks for {len(multi_task_list)} materials.") - - results=multiprocess_task(self.generate_job_scheduler_script_for_calc, multi_task_list, n_cores=self.n_cores, slurm_config=slurm_config, script_string=script_string, **kwargs) - logger.info("SLURM script generation completed for all materials.") - return results - - def submit_disk_job(self, slurm_script_path: str, capture_output=True, text=True): - """ - Submits a SLURM job using a specified SLURM script path. - - Parameters: - ----------- - slurm_script_path : str - The path to the SLURM script to be submitted. - capture_output : bool, optional - Whether to capture the output of the SLURM job submission. Defaults to True. - text : bool, optional - Whether to capture output as text. Defaults to True. - - Returns: - -------- - str - The SLURM job ID if the submission is successful. - - Examples: - --------- - # Example usage: - # Submit a SLURM job using a specified script path - .. highlight:: python - .. code-block:: python - - slurm_script = "/path/to/slurm_script.sh" - job_id = calc_manager.submit_disk_jobs(slurm_script) - - # Submit a SLURM job without capturing output - job_id = calc_manager.submit_disk_jobs(slurm_script, capture_output=False) - """ - logger.info(f"Submitting SLURM job with script: {slurm_script_path}") - result = subprocess.run(['sbatch', slurm_script_path], capture_output=capture_output, text=text) - if result.returncode == 0: - # Extract the SLURM job ID from sbatch output - slurm_job_id = result.stdout.strip().split()[-1] - logger.info(f"SLURM job submitted successfully. Job ID: {slurm_job_id}") - return slurm_job_id - else: - logger.error(f"Failed to submit SLURM job with script {slurm_script_path}. Error: {result.stderr}") - raise RuntimeError(f"Failed to submit SLURM job. Error: {result.stderr}") - - def submit_disk_jobs(self, calc_name, ids=None, **kwargs): - """ - Submits SLURM jobs for all materials or a subset of materials by calculation name. - - Parameters: - ----------- - calc_name : str - The name of the calculation for which jobs will be submitted. - ids : list, optional - A list of material IDs to submit jobs for. If not provided, jobs are submitted for all materials. - **kwargs - Additional keyword arguments to pass to the job submission function. - - Returns: - -------- - List - The results of job submission for each material. - - Examples: - --------- - # Example usage: - # Submit SLURM jobs for all materials - .. highlight:: python - .. 
code-block:: python - - # Submit jobs for all materials for a specific calculation - results = calc_manager.submit_disk_jobs("calculation_name") - - # Submit jobs for specific material IDs - results = calc_manager.submit_disk_jobs("calculation_name", ids=[0, 1]) - """ - logger.info(f"Submitting SLURM jobs for calculation '{calc_name}'") - multi_task_list=[] - if ids is None: - logger.debug("No specific IDs provided, submitting jobs for all materials.") - for material_dir in self.material_dirs: - calc_dir=os.path.join(material_dir,calc_name) - job_submission_script_path=os.path.join(calc_dir, self.job_submission_script_name) - multi_task_list.append(job_submission_script_path) - else: - logger.debug(f"Submitting jobs for material IDs: {ids}") - for id in ids: - calc_dir=os.path.join(self.calculation_dir,id,calc_name) - job_submission_script_path=os.path.join(calc_dir, self.job_submission_script_name) - multi_task_list.append(job_submission_script_path) - logger.debug(f"Prepared job submission scripts for {len(multi_task_list)} materials.") - results=multiprocess_task(self.generate_job_scheduler_script_for_calc, multi_task_list, n_cores=self.n_cores, **kwargs) - - slurm_paths, _ = zip(*results) - - submission_results = multiprocess_task(self.submit_job, slurm_paths, - n_cores=self.n_cores, - capture_output=kwargs.get('capture_output', True), - text=kwargs.get('text', True)) - - - - logger.info("Job submissions completed.") - return results - - def run_func_on_disk_calculation(self, material_id: str, calc_func:Callable, calc_name: str, **kwargs): - """ - Runs a specified function on a specific calculation directory. - The `calc_func` expects the calculation directory path as its only argument. - - Parameters: - ----------- - material_id : str - The ID of the material for which the function will be run. - calc_func : Callable - The function to run on the calculation directory. - This functions first argument should be the calculation directory path. - calc_name : str - The name of the calculation, used to locate the directory for the material. - **kwargs - Additional keyword arguments to pass to the `calc_func`. - - Returns: - -------- - None - This method does not return a value. - - Examples: - --------- - # Example usage: - # Define a function that processes the calculation directory - .. highlight:: python - .. code-block:: python - - def process_calc_directory(calc_dir, **kwargs): - # Custom logic for processing - return {"result_key": "result_value"} - - # Run function on a specific material's calculation directory - calc_manager.run_func_on_disk_calculation("material_1", process_calc_directory, "calc_name") - """ - - logger.info(f"Running function '{calc_func.__name__}' on calculation '{calc_name}' for material ID '{material_id}'.") - calc_dir=os.path.join(self.calculation_dir,material_id,calc_name) - calc_func(calc_dir,**kwargs) - return None - - def run_func_on_disk_calculations(self, calc_func:Callable, calc_name: str, ids=None, **kwargs): - """ - Runs a specified function on all calculation directories or a subset of directories. - - Parameters: - ----------- - calc_func : Callable - The function to run on each calculation directory. - This functions first argument should be the calculation directory path. - calc_name : str - The name of the calculation, used to locate directories for each material where the calculation is stored. - ids : list, optional - A list of material IDs. If not provided, the function will run on all material directories. 
- **kwargs - Additional keyword arguments to pass to the `calc_func` and the multiprocessing task. - - Returns: - -------- - List - The results of running the function on each calculation directory. - - Examples: - --------- - # Example usage: - # Define a function that processes the calculation directory - .. highlight:: python - .. code-block:: python - - def process_directory(calc_dir, **kwargs): - # Custom logic for processing - return None - - # Run function on specific material IDs - results = calc_manager.run_func_on_disk_calculations(process_directory, "calculation_name", ids=[0, 1]) - - # Run function on all material directories - results = calc_manager.run_func_on_disk_calculations(process_directory, "calculation_name") - """ - - logger.info(f"Running function '{calc_func.__name__}' on calculation '{calc_name}' for multiple materials.") - multi_task_list=[] - if ids is None: - logger.debug("No specific IDs provided, running function on all materials.") - for material_dir in self.material_dirs: - calc_dir=os.path.join(material_dir,calc_name) - multi_task_list.append(calc_dir) - - else: - logger.debug(f"Running function on material IDs: {ids}") - for id in ids: - calc_dir=os.path.join(self.calculation_dir,id,calc_name) - multi_task_list.append(calc_dir) - logger.debug(f"Prepared function tasks for {len(multi_task_list)} materials.") - results=multiprocess_task(calc_func, multi_task_list, n_cores=self.n_cores, **kwargs) - logger.info(f"Function '{calc_func.__name__}' completed for all specified materials.") - return results - - def add_field_from_disk_calculation(self, func:Callable, calc_name: str, ids=None, update_args: dict = None, **kwargs): - """ - Adds calculation data from disk to the database by processing each material's calculation directory and - updating the database with the results. - - Parameters: - ----------- - func : Callable - The function that performs the calculation or processing task on each calculation directory. - The first argument should be the calculation directory path. - This function must return a dictionary with field names as keys and values as values. - calc_name : str - The name of the calculation, used to locate directories for each material where the calculation is stored. - ids : list, optional - A list of material IDs to process. If not provided, all material directories in `self.material_dirs` are processed. - update_args : dict, optional - Additional arguments to pass to the database update operation. - **kwargs - Additional keyword arguments to pass to the `func` and the multiprocessing task. - - Returns: - -------- - None - This method does not return a value. It updates the database with the results of the calculations. - - Examples: - --------- - # Example usage: - # Define a function that processes the calculation directory and returns a dictionary of results - .. highlight:: python - .. 
code-block:: python - - def process_directory(calc_dir, **kwargs): - # Custom logic for processing - return {"field_name": "value"} - - # Add calculation data to the database - calc_manager.add_field_from_disk_calculation(process_directory, "calculation_name", ids=[0, 1], - update_args={"table_name": "main", field_type_dict={"field_name": float}}) - """ - logger.info(f"Adding calculation data to database for calculation '{calc_name}'.") - multi_task_list=[] - if ids is None: - ids=[] - for material_dir in self.material_dirs: - calc_dir=os.path.join(material_dir,calc_name) - multi_task_list.append(calc_dir) - ids.append(os.path.dirname(material_dir)) - else: - logger.debug(f"Processing material IDs: {ids}") - for id in ids: - calc_dir=os.path.join(self.calculation_dir,id,calc_name) - multi_task_list.append(calc_dir) - logger.info("Calculation data processing completed.") - results=multiprocess_task(func, multi_task_list, n_cores=self.n_cores, **kwargs) - - update_list=[(id,result) for id,result in zip(ids,results)] - update_data=[] - - for id, result in zip(ids,results): - update_dict={} - update_dict.update(result) - update_dict['id']=id - update_data.append(update_dict) - - - logger.debug(f"Updating database with results.") - - self.matdb.update(update_list, **update_args) - logger.info("Database updated with calculation data.") - - def load_metadata(self): - """ - Loads metadata from a JSON file in the main directory. - - This method reads the metadata stored in a JSON file located in the main directory. - If the file does not exist, it returns an empty dictionary. - - Parameters: - ----------- - None - - Returns: - -------- - dict - The loaded metadata. If the file does not exist, returns an empty dictionary. - - Examples: - --------- - # Example usage: - # Load metadata from the main directory - .. highlight:: python - .. code-block:: python - - # Load metadata - metadata = calc_manager.load_metadata() - print(metadata) - """ - logger.info("Loading metadata.") - if os.path.exists(self.metadata_file): - logger.debug(f"Metadata file found at {self.metadata_file}") - with open(self.metadata_file, 'r') as f: - metadata = json.load(f) - return metadata - else: - logger.warning(f"Metadata file not found at {self.metadata_file}, returning empty metadata.") - return {} - - def update_metadata(self, metadata): - """ - Updates the metadata with new information and saves it to the metadata file. - - This method updates the current metadata with new data and saves the updated metadata to the JSON file in the main directory. - - Parameters: - ----------- - metadata : dict - The new metadata to be added or updated. - - Returns: - -------- - None - - Examples: - --------- - # Example usage: - # Update the metadata with new information - .. highlight:: python - .. code-block:: python - - # Update metadata - new_metadata = {"calculation_names": ["calc_1", "calc_2"]} - calc_manager.update_metadata(new_metadata) - """ - logger.info("Updating metadata.") - logger.debug(f"New metadata to update: {metadata}") - self.metadata.update(metadata) - self.save_metadata(self.metadata) - - def save_metadata(self, metadata): - """ - This method writes the provided metadata to a JSON file located in the main directory, ensuring the data is persisted. - - Parameters: - ----------- - metadata : dict - The metadata to save. - - Returns: - -------- - None - - Examples: - --------- - # Example usage: - # Save metadata to the JSON file - .. highlight:: python - .. 
code-block:: python - - # Save metadata - metadata = {"calculation_names": ["calc_1", "calc_2"]} - calc_manager.save_metadata(metadata) - """ - logger.info(f"Saving metadata to {self.metadata_file}.") - with open(self.metadata_file, 'w') as f: - json.dump(metadata, f, indent=4) - logger.debug("Metadata saved successfully.") - - - -def calculation_error_handler(calc_func: Callable): - """ - A decorator that wraps the user-defined calculation function in a try-except block. - This allows graceful handling of any errors that may occur during the calculation process. - - Args: - calc_func (Callable): The user-defined calculation function to be wrapped. - - Returns: - Callable: A wrapped function that executes the calculation and handles any exceptions. - """ - def wrapper(data): - try: - # Call the user-defined calculation function - return calc_func(data) - except Exception as e: - # Log the error (you could customize this as needed) - logger.error(f"Error in calculation function: {e}\nData: {data}") - # Optionally, you could return a default or partial result, or propagate the error - return {} - - return wrapper - -def disk_calculation(disk_calc_func: Callable, base_directory: str): - """ - A decorator that wraps the disk-based calculation function in a try-except block. - It also ensures that the specified directory for each row (ID) and calculation name - is created before the calculation function is executed. - - Args: - disk_calc_func (Callable): The user-defined disk-based calculation function to be wrapped. - base_directory (str): The base directory where files related to the calculation will be stored. - verbose (bool): If True, prints error messages. Defaults to False. - - Returns: - Callable: A wrapped function that handles errors and ensures the calculation directory exists. 
- """ - - def wrapper(row_data, row_id, calculation_name): - # Define the directory for this specific row (ID) and calculation - material_specific_directory = os.path.join('MaterialsData',row_id,calculation_name) - calculation_directory = os.path.join(base_directory, material_specific_directory) - os.makedirs(calculation_directory, exist_ok=True) - - try: - # Call the user-defined calculation function, passing the directory as a parameter - return disk_calc_func(row_data, directory=calculation_directory) - except Exception as e: - # Log the error - logger.error(f"Error in disk calculation function: {e}\nData: {row_data}") - # Optionally, return a default or partial result, or propagate the error - return {} - - return wrapper diff --git a/matgraphdb/calculations/dft_calcs/__init__.py b/matgraphdb/calculations/dft_calcs/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/calculations/dft_calcs/calc_generators.py b/matgraphdb/calculations/dft_calcs/calc_generators.py deleted file mode 100644 index 04b6f64..0000000 --- a/matgraphdb/calculations/dft_calcs/calc_generators.py +++ /dev/null @@ -1,482 +0,0 @@ -import os -import json -from glob import glob -import subprocess -from typing import Dict, List, Tuple, Union -from multiprocessing import Pool -from functools import partial - -import numpy as np -from pymatgen.core import Structure -from pymatgen.io.vasp.inputs import Incar, Poscar, Potcar, Kpoints - - -from matgraphdb import DBManager -from matgraphdb.calculations.job_scheduler_generator import SlurmScriptGenerator -from matgraphdb.utils import get_logger - -logger=get_logger(__name__, console_out=False, log_level='info') - -class CalculationGenerator: - def __init__(self, calc_dir): - self.calc_dir=calc_dir - os.makedirs(self.calc_dir,exist_ok=True) - - def create_job_scheduler_script(self): - raise NotImplementedError("get_estimated_computaional_resources must be implemented in the child class") - - def write_job(self): - raise NotImplementedError("write_job must be implemented in the child class") - - -class VaspCalcGenerator(CalculationGenerator): - def __init__(self, species, frac_coords, lattice, calc_dir, vasp_pseudos_dir): - super().__init__(calc_dir) - self.vasp_pseudos_dir=vasp_pseudos_dir - - self.structure=None - self.incar_args={} - - self.incar_str='' - self.poscar_str='' - self.potcar_str='' - self.kpoints_str='' - self.job_scheduler_str='' - - self._add_structure(species=species, frac_coords=frac_coords, lattice=lattice) - - def _add_structure(self, species, frac_coords, lattice,): - self.structure=Structure(lattice=lattice, species=species, coords=frac_coords) - self.composition=self.structure.composition - self.elements=list(self.composition.as_dict().keys()) - self.n_atoms=len(species) - return None - - def create_incar(self, incar_args={}): - if self.job_scheduler_str is None: - raise ValueError("Call create_job_scheduler_script first") - for key,value in incar_args.items(): - self.incar_str+=f'{key} = {value}\n' - - def create_poscar(self): - self.poscar_str=self.structure.to(fmt='poscar') - - def create_potcar(self, pseudo_type='potpaw_PBE.52', potcar_str=None): - if potcar_str is not None: - self.potcar_str=potcar_str - return None - - # Create default POTCAR files - pseudos_dir=os.path.join(self.vasp_pseudos_dir,pseudo_type) - tmp_potcar='' - for symbol in self.elements: - pseudo_file=os.path.join(pseudos_dir,symbol,'POTCAR') - - if not os.path.exists(pseudos_dir): - pseudo_file=os.path.join(pseudos_dir,symbol+'_sv','POTCAR') - # 
open pseudo file and add it to potcar string - if symbol=='Zr_sv': - with open(pseudo_file,'r') as f: - lines=f.readlines() - new_line=lines[3].replace('r','Zr') - lines[3]=new_line - tmp_text=''.join(lines) - tmp_potcar+=tmp_text - tmp_potcar+='' - else: - with open(pseudo_file,'r') as f: - tmp_potcar+=f.read() - tmp_potcar+='' - self.potcar_str=tmp_potcar - return None - - def create_kpoints(self,kpoints:Kpoints): - """ - Create a KPOINTS file from a Kpoints object. Refer to their documentation for more information. - https://pymatgen.org/pymatgen.io.vasp.html#pymatgen.io.vasp.inputs.Kpoints - - Args: - kpoints (Kpoints): A Kpoints object. - """ - self.kpoints_str=str(kpoints) - print(self.kpoints_str) - return None - - def create_job_scheduler_script(self,slurm_script_body=None): - if self.incar_str=='': - raise ValueError("Incar must be created first. Call create_incar") - if self.poscar_str=='': - raise ValueError("Poscar must be created first. Call create_poscar") - if self.potcar_str=='': - raise ValueError("Potcar must be created first. Call create_potcar") - if self.kpoints_str=='': - raise ValueError("Kpoints must be created first. Call create_kpoints") - if self.n_atoms >= 60: - nnode=4 - ncore = 32 - kpar=4 - ntasks=160 - elif self.n_atoms >= 40: - nnode=3 - ncore = 20 - kpar=3 - ntasks=120 - elif self.n_atoms >= 20: - nnode=2 - ncore = 16 - kpar = 2 - ntasks=80 - else: - nnode=1 - ntasks=40 - ncore = 40 - kpar = 1 - - self.job_scheduler_script_generator=SlurmScriptGenerator() - self.job_scheduler_script_generator.init_header() - self.job_scheduler_script_generator.add_slurm_header_comp_resources(n_nodes=nnode, n_tasks=ntasks) - if slurm_script_body is None: - slurm_script_body="\n" - 'source ~/.bashrc\n' - 'module load atomistic/vasp/6.2.1_intel22_impi22\n' - f'cd {self.calc_dir}\n' - f'echo "CALC_DIR: {self.calc_dir}"\n' - f'echo "NCORES: $((SLURM_NTASKS))"\n' - '\n' - f'mpirun -np $SLURM_NTASKS vasp_std\n' - - - - self.job_scheduler_script_generator.add_slurm_script_body(slurm_script_body) - self.job_scheduler_str=self.job_scheduler_script_generator.finalize() - - self.incar_args['NCORE']=ncore - self.incar_args['KPAR']=kpar - - return None - - def write_job(self): - if self.structure is None: - raise ValueError("Structure must be added first") - - if self.incar_str=='': - raise ValueError("Incar must be created first. Call create_incar") - if self.poscar_str=='': - raise ValueError("Poscar must be created first. Call create_poscar") - if self.potcar_str=='': - raise ValueError("Potcar must be created first. Call create_potcar") - if self.kpoints_str=='': - raise ValueError("Kpoints must be created first. Call create_kpoints") - if self.job_scheduler_str=='': - raise ValueError("Job scheduler script must be created first. 
Call create_job_scheduler_script") - - - with open(os.path.join(self.calc_dir,'INCAR'),'w') as incar: - incar.write(self.incar_str) - - with open(os.path.join(self.calc_dir,'POSCAR'),'w') as poscar: - poscar.write(self.poscar_str) - - with open(os.path.join(self.calc_dir,'POTCAR'),'w') as potcar: - potcar.write(self.potcar_str) - - with open(os.path.join(self.calc_dir,'KPOINTS'),'w') as kpoints: - kpoints.write(self.kpoints_str) - - with open(os.path.join(self.calc_dir,'job_submit.run'),'w') as job_control: - job_control.write(self.job_scheduler_str) - - -class ChargemolCalcGenerator(CalculationGenerator): - def __init__(self, species, frac_coords, lattice, calc_dir, vasp_pseudos_dir): - super().__init__(calc_dir) - self.vasp_pseudos_dir=vasp_pseudos_dir - - self.structure=None - self.incar_args={} - - self.incar_str='' - self.poscar_str='' - self.potcar_str='' - self.kpoints_str='' - self.job_scheduler_str='' - self.job_control_str='' - - self._add_structure(species=species, frac_coords=frac_coords, lattice=lattice) - - def _add_structure(self, species, frac_coords, lattice,): - self.structure=Structure(lattice=lattice, species=species, coords=frac_coords) - self.composition=self.structure.composition - self.elements=list(self.composition.as_dict().keys()) - self.n_atoms=len(species) - return None - - def create_incar(self, incar_args={}): - if self.job_scheduler_str is None: - raise ValueError("Call create_job_scheduler_script first") - for key,value in incar_args.items(): - self.incar_str+=f'{key} = {value}\n' - - def create_poscar(self): - self.poscar_str=self.structure.to(fmt='poscar') - - def create_potcar(self, pseudo_type='potpaw_PBE.52', potcar_str=None): - if potcar_str is not None: - self.potcar_str=potcar_str - return None - - # Create default POTCAR files - pseudos_dir=os.path.join(self.vasp_pseudos_dir,pseudo_type) - tmp_potcar='' - for symbol in self.elements: - pseudo_file=os.path.join(pseudos_dir,symbol,'POTCAR') - - if not os.path.exists(pseudos_dir): - pseudo_file=os.path.join(pseudos_dir,symbol+'_sv','POTCAR') - # open pseudo file and add it to potcar string - if symbol=='Zr_sv': - with open(pseudo_file,'r') as f: - lines=f.readlines() - new_line=lines[3].replace('r','Zr') - lines[3]=new_line - tmp_text=''.join(lines) - tmp_potcar+=tmp_text - tmp_potcar+='' - else: - with open(pseudo_file,'r') as f: - tmp_potcar+=f.read() - tmp_potcar+='' - self.potcar_str=tmp_potcar - return None - - def create_kpoints(self,kpoints:Kpoints): - """ - Create a KPOINTS file from a Kpoints object. Refer to their documentation for more information. - https://pymatgen.org/pymatgen.io.vasp.html#pymatgen.io.vasp.inputs.Kpoints - - Args: - kpoints (Kpoints): A Kpoints object. - """ - self.kpoints_str=str(kpoints) - print(self.kpoints_str) - return None - - def create_job_scheduler_script(self,slurm_script_body=None): - if self.incar_str=='': - raise ValueError("Incar must be created first. Call create_incar") - if self.poscar_str=='': - raise ValueError("Poscar must be created first. Call create_poscar") - if self.potcar_str=='': - raise ValueError("Potcar must be created first. Call create_potcar") - if self.kpoints_str=='': - raise ValueError("Kpoints must be created first. 
Call create_kpoints") - if self.n_atoms >= 60: - nnode=4 - ncore = 32 - kpar=4 - ntasks=160 - elif self.n_atoms >= 40: - nnode=3 - ncore = 20 - kpar=3 - ntasks=120 - elif self.n_atoms >= 20: - nnode=2 - ncore = 16 - kpar = 2 - ntasks=80 - else: - nnode=1 - ntasks=40 - ncore = 40 - kpar = 1 - - self.job_scheduler_script_generator=SlurmScriptGenerator() - self.job_scheduler_script_generator.init_header() - self.job_scheduler_script_generator.add_slurm_header_comp_resources(n_nodes=nnode, n_tasks=ntasks) - - - if slurm_script_body is None: - slurm_script_body=("\n" - 'source ~/.bashrc\n' - 'module load atomistic/vasp/6.2.1_intel22_impi22\n' - f'cd {self.calc_dir}\n' - f'echo "CALC_DIR: {self.calc_dir}"\n' - f'echo "NCORES: $((SLURM_NTASKS))"\n' - '\n' - f'mpirun -np $SLURM_NTASKS vasp_std\n' - '\n' - f'export OMP_NUM_THREADS=$SLURM_NTASKS\n' - '~/SCRATCH/Codes/chargemol_09_26_2017/chargemol_FORTRAN_09_26_2017/compiled_binaries' - '/linux/Chargemol_09_26_2017_linux_parallel> chargemol_debug.txt 2>&1\n' - '\n' - f'echo "run complete on `hostname`: `date`" 1>&2\n') - - self.job_scheduler_script_generator.add_slurm_script_body(slurm_script_body) - self.job_scheduler_str=self.job_scheduler_script_generator.finalize() - - self.incar_args['NCORE']=ncore - self.incar_args['KPAR']=kpar - - return None - - def create_job_control_script(self,atomic_densities_dir): - job_control_script=("\n"+ - atomic_densities_dir+ '\n' - "\n\n" - - "\n" - "DDEC6\n" - "\n\n" - - "\n" - ".true.\n" - "\n") - - self.job_control_str=job_control_script - return job_control_script - - - def write_job(self): - if self.incar_str=='': - raise ValueError("Incar must be created first. Call create_incar") - if self.poscar_str=='': - raise ValueError("Poscar must be created first. Call create_poscar") - if self.potcar_str=='': - raise ValueError("Potcar must be created first. Call create_potcar") - if self.kpoints_str=='': - raise ValueError("Kpoints must be created first. Call create_kpoints") - if self.job_scheduler_str=='': - raise ValueError("Job scheduler script must be created first. 
Call create_job_scheduler_script") - - - with open(os.path.join(self.calc_dir,'INCAR'),'w') as incar: - incar.write(self.incar_str) - - with open(os.path.join(self.calc_dir,'POSCAR'),'w') as poscar: - poscar.write(self.poscar_str) - - with open(os.path.join(self.calc_dir,'POTCAR'),'w') as potcar: - potcar.write(self.potcar_str) - - with open(os.path.join(self.calc_dir,'KPOINTS'),'w') as kpoints: - kpoints.write(self.kpoints_str) - - with open(os.path.join(self.calc_dir,'job_submit.run'),'w') as job_control: - job_control.write(self.job_scheduler_str) - - with open(os.path.join(self.calc_dir,'job_control.txt'),'w') as job_control: - job_control.write(self.job_control_str) - - - - -if __name__=='__main__': - import pyarrow.parquet as pq - import pyarrow as pa - - # calc_manager=CalculationManager() - pseudos_dir=os.path.join('data','PP_Vasp') - - materials_parquet=os.path.join('data','production','materials_project','materials_database.parquet') - data_dir=os.path.join('data','raw','test_dir') - - - - - table = pq.read_table(materials_parquet, columns=['material_id','lattice','frac_coords','species']) - df = table.to_pandas() - index=1 - lattice=np.array(list(df.iloc[index]['lattice'])) - frac_coords=np.array(list(df.iloc[index]['frac_coords'])) - species=list(df.iloc[index]['species']) - - print(df.head()) - # generator=VaspCalcGenerator(species=species, frac_coords=frac_coords, lattice=lattice, - # calc_dir=os.path.join(data_dir,'mp-1000'), - # vasp_pseudos_dir=pseudos_dir) - - - # kpoints=Kpoints.gamma_automatic(kpts=(9,9,9), shift=(0,0,0)) - # generator.create_kpoints(kpoints=kpoints) - - # vasp_parameters = { - # "EDIFF": 1e-08, - # "ENCUT": 600, - # "IBRION": -1, - # "ISMEAR": -5, - # "ISPIN": 1, - # "ISTART": 0, - # "LASPH": True, - # "LMAXMIN": 4, - # "LORBIT": 11, - # "LREAL": False, - # "NELM": 100, - # "NELMIN": 8, - # "NSIM": 2, - # "NSW": 0, - # "LCHARG": True, - # "LAECHG": True, - # "LWAVE": False, - # "PREC": "Accurate", - # "SIGMA": 0.01, - # "NWRITE": 3 - # } - - # generator.create_incar(incar_args=vasp_parameters) - - # generator.create_poscar() - # generator.create_potcar() - # generator.create_job_scheduler_script() - - # generator.write_job() - - - - generator=ChargemolCalcGenerator(species=species, frac_coords=frac_coords, lattice=lattice, - calc_dir=os.path.join(data_dir,'mp-1000','chargemol'), - vasp_pseudos_dir=pseudos_dir) - - - kpoints=Kpoints.gamma_automatic(kpts=(9,9,9), shift=(0,0,0)) - generator.create_kpoints(kpoints=kpoints) - - vasp_parameters = { - "EDIFF": 1e-08, - "ENCUT": 600, - "IBRION": -1, - "ISMEAR": -5, - "ISPIN": 1, - "ISTART": 0, - "LASPH": True, - "LMAXMIN": 4, - "LORBIT": 11, - "LREAL": False, - "NELM": 100, - "NELMIN": 8, - "NSIM": 2, - "NSW": 0, - "LCHARG": True, - "LAECHG": True, - "LWAVE": False, - "PREC": "Accurate", - "SIGMA": 0.01, - "NWRITE": 3 - } - - generator.create_incar(incar_args=vasp_parameters) - - generator.create_poscar() - generator.create_potcar() - - atomic_densities_dir="/users/lllang/SCRATCH/Codes/chargemol_09_26_2017/atomic_densities/" - generator.create_job_control_script(atomic_densities_dir=atomic_densities_dir) - generator.create_job_scheduler_script() - - generator.write_job() - - - - - - diff --git a/matgraphdb/calculations/dft_calcs/chargemol_calc_setup.py b/matgraphdb/calculations/dft_calcs/chargemol_calc_setup.py deleted file mode 100644 index d6cfe28..0000000 --- a/matgraphdb/calculations/dft_calcs/chargemol_calc_setup.py +++ /dev/null @@ -1,142 +0,0 @@ -import os -import json -import shutil -from glob import 
glob - -from matgraphdb.utils import MP_DIR,DB_CALC_DIR - -def generate_batch_scripts(calc_dirs,calc_file_dir): - for calc_dir in calc_dirs[:]: - scf_dir=os.path.join(calc_dir,'static') - - if os.path.exists(os.path.join(scf_dir,'run.slurm')): - os.remove(os.path.join(scf_dir,'run.slurm')) - - try: - shutil.copy(os.path.join(calc_file_dir,'run.slurm'), os.path.join(scf_dir,'run.slurm')) - except: - pass - -def generate_potcar(calc_dirs,pseudos_dir): - - for calc_dir in calc_dirs[:]: - scf_dir=os.path.join(calc_dir,'static') - potcar_file=os.path.join(scf_dir,'POTCAR') - incomplete_dir=os.path.dirname(os.path.dirname(calc_dir)) - incomplete_dir=os.path.join(incomplete_dir,'incomplete_database') - - # # Remove pre-existing potcar file - if os.path.exists(potcar_file): - os.remove(potcar_file) - - try: - with open(os.path.join(scf_dir,'POTCAR_files.json'),'r') as f: - data = json.load(f) - functional=data['functional'] - symbols=data['symbols'] - tmp_potcar='' - - # Loop through element symbols - for symbol in symbols: - pseudo_file=os.path.join(pseudos_dir,symbol,'POTCAR') - - # open pseudo file and add it to potcar string - if symbol=='Zr_sv': - with open(pseudo_file,'r') as f: - lines=f.readlines() - new_line=lines[3].replace('r','Zr') - lines[3]=new_line - tmp_text=''.join(lines) - tmp_potcar+=tmp_text - tmp_potcar+='' - else: - with open(pseudo_file,'r') as f: - tmp_potcar+=f.read() - tmp_potcar+='' - - # Save file in new POTCAR file - with open(potcar_file,'w') as potcar: - potcar.write(tmp_potcar) - - except Exception as e: - shutil.move(calc_dir, incomplete_dir) - print(e) - pass - -def generate_incar(calc_dirs,calc_file_dir): - template_incar_file=os.path.join(calc_file_dir,'INCAR') - - for calc_dir in calc_dirs[:]: - scf_dir=os.path.join(calc_dir,'static') - incar_file=os.path.join(scf_dir,'INCAR') - - - # Rename existing INCAR file to INCAR_old - incar_old_file=os.path.join(scf_dir,'INCAR_old') - if os.path.exists(incar_file): - os.rename(incar_file, incar_old_file) - - # Copy template INCAR file to calc dir - try: - shutil.copy(template_incar_file, incar_file) - except: - pass - - -def generate_kpoints(calc_dirs,calc_file_dir): - template_incar_file=os.path.join(calc_file_dir,'KPOINTS') - - for calc_dir in calc_dirs[:]: - scf_dir=os.path.join(calc_dir,'static') - incar_file=os.path.join(scf_dir,'KPOINTS') - - # Rename existing INCAR file to INCAR_old - incar_old_file=os.path.join(scf_dir,'KPOINTS_old') - if os.path.exists(incar_file): - os.rename(incar_file, incar_old_file) - - # Copy template INCAR file to calc dir - try: - shutil.copy(template_incar_file, incar_file) - except: - pass - -def generate_job_control(calc_dirs,calc_file_dir): - template_file=os.path.join(calc_file_dir,'job_control.txt') - - for calc_dir in calc_dirs[:]: - scf_dir=os.path.join(calc_dir,'static') - copy_file=os.path.join(scf_dir,'job_control.txt') - - # # Remove pre-existing potcar file - if os.path.exists(copy_file): - os.remove(copy_file) - - try: - shutil.copy(template_file, copy_file) - except: - pass - - - -def chargemol_calc_setup(): - calc_file_dir=os.path.join(MP_DIR,'calculations','calculation_files','chargemol') - calc_dirs=glob(DB_CALC_DIR + '/mp-*') - - pseudos_dir=os.path.join("/users/lllang/SCRATCH",'PP_Vasp','potpaw_PBE.52') - - calc_dir=os.path.join(MP_DIR,'calculations','database') - print(len(os.listdir(calc_dir))) - # generate_potcar(calc_dirs,pseudos_dir) - - # generate_incar(calc_dirs,calc_file_dir) - - # generate_batch_scripts(calc_dirs,calc_file_dir) - - # 
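The create_potcar methods and the generate_potcar helpers in this diff all assemble a POTCAR by concatenating per-element POTCAR files, falling back to the '_sv' pseudopotential folder and patching the Zr_sv header line. A condensed sketch of that logic, assuming a potpaw_PBE.52-style directory layout; the path in the commented call is illustrative.

import os

def build_potcar(symbols, pseudos_dir):
    """Concatenate per-element POTCAR files into one POTCAR string, falling back to
    the '_sv' variant when the plain element folder is missing and applying the same
    Zr_sv header patch used in the deleted code."""
    potcar = ""
    for symbol in symbols:
        pseudo_file = os.path.join(pseudos_dir, symbol, "POTCAR")
        if not os.path.exists(pseudo_file):
            symbol = symbol + "_sv"
            pseudo_file = os.path.join(pseudos_dir, symbol, "POTCAR")
        with open(pseudo_file, "r") as f:
            lines = f.readlines()
        if symbol == "Zr_sv":
            lines[3] = lines[3].replace("r", "Zr")
        potcar += "".join(lines)
    return potcar

# potcar_str = build_potcar(["Mg", "O"], "/path/to/potpaw_PBE.52")  # illustrative path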
generate_kpoints(calc_dirs,calc_file_dir) - - # generate_job_control(calc_dirs,calc_file_dir) - - -if __name__=='__main__': - chargemol_calc_setup() \ No newline at end of file diff --git a/matgraphdb/calculations/dft_calcs/generate_calc_dir.py b/matgraphdb/calculations/dft_calcs/generate_calc_dir.py deleted file mode 100644 index cf0262b..0000000 --- a/matgraphdb/calculations/dft_calcs/generate_calc_dir.py +++ /dev/null @@ -1,112 +0,0 @@ - - -import os -from glob import glob -import json - -from pymatgen.core import Structure -from pymatgen.io.vasp import Poscar - -from matgraphdb.utils import DB_DIR, DB_CALC_DIR -from matgraphdb.data.utils import process_database - -def generate_calc_dir_task(file): - with open(file, 'r') as f: - data = json.load(f) - - # Get mpid name - mpid = data['material_id'] - - # Create the Structure object - structure = Structure.from_dict(data['structure']) - - # Create calc directory - calc_dir=os.path.join(DB_CALC_DIR,mpid) - os.makedirs(calc_dir, exist_ok=True) - - - # Write the structure to a POSCAR file - poscar_file=os.path.join(calc_dir,'POSCAR') - poscar = Poscar(structure) - poscar.write_file(poscar_file) - - return None - -def generate_calc_dir(): - results=process_database(generate_calc_dir_task) - - - -def generate_potcars_task(file): - mpid=file.split(os.sep)[-1].split('.')[0] - - # Create calc directory - calc_dir=os.path.join(DB_CALC_DIR,mpid) - poscar_file=os.path.join(calc_dir,'POSCAR') - - # Create local potcar directory to store POTCAR. - #This is if we have to switch pseudopotentials in future - potcar_dir=os.path.join(calc_dir,'potcar') - os.makedirs(potcar_dir, exist_ok=True) - - # Get element symbols from POSCAR - with open(poscar_file) as f: - lines=f.readlines() - elements=lines[5].split() - - # Create POTCAR files - pseudos_dir=os.path.join("/users/lllang/SCRATCH",'PP_Vasp','potpaw_PBE.52') - tmp_potcar='' - for symbol in elements: - pseudo_file=os.path.join(pseudos_dir,symbol,'POTCAR') - - if not os.path.exists(pseudo_file): - pseudo_file=os.path.join(pseudos_dir,symbol+'_sv','POTCAR') - # open pseudo file and add it to potcar string - if symbol=='Zr_sv': - with open(pseudo_file,'r') as f: - lines=f.readlines() - new_line=lines[3].replace('r','Zr') - lines[3]=new_line - tmp_text=''.join(lines) - tmp_potcar+=tmp_text - tmp_potcar+='' - else: - with open(pseudo_file,'r') as f: - tmp_potcar+=f.read() - tmp_potcar+='' - - # Save file in new POTCAR file - potcar_file=os.path.join(potcar_dir,'POTCAR_PBE') - with open(potcar_file,'w') as potcar: - potcar.write(tmp_potcar) - - return None - -def generate_potcars(): - results=process_database(generate_potcars_task) - - -def main(): - - database_files=glob(DB_DIR + '/*.json') - - print('#'*100) - print('Generating Calculation dir') - print('#'*100) - - generate_calc_dir() - - - print('#'*100) - print('Generating potcar dir') - print('#'*100) - - - calc_dirs=glob(DB_CALC_DIR + '/*') - print(calc_dirs[:10]) - generate_potcars() - - -if __name__=='__main__': - main() diff --git a/matgraphdb/calculations/dft_calcs/generate_chargemol.py b/matgraphdb/calculations/dft_calcs/generate_chargemol.py deleted file mode 100644 index 747c229..0000000 --- a/matgraphdb/calculations/dft_calcs/generate_chargemol.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -import json -import shutil -from glob import glob - -from matgraphdb.utils import MP_DIR,DB_CALC_DIR - - -def generate_batch_script(calc_dir,chargemol_file_dir): - if os.path.exists(os.path.join(calc_dir,'run.slurm')): - 
os.remove(os.path.join(calc_dir,'run.slurm')) - - shutil.copy(os.path.join(chargemol_file_dir,'run.slurm'), os.path.join(calc_dir,'run.slurm')) - -def generate_incar(calc_dir,chargemol_file_dir): - template_file=os.path.join(chargemol_file_dir,'INCAR') - file=os.path.join(calc_dir,'INCAR') - - shutil.copy(template_file, file) - -def generate_kpoints(calc_dir,chargemol_file_dir): - template_file=os.path.join(chargemol_file_dir,'KPOINTS') - file=os.path.join(calc_dir,'KPOINTS') - - shutil.copy(template_file, file) - -def generate_job_control(calc_dir,chargemol_file_dir): - template_file=os.path.join(chargemol_file_dir,'job_control.txt') - file=os.path.join(calc_dir,'job_control.txt') - - shutil.copy(template_file, file) - -def generate_potcar(calc_dir,potcar_dir): - template_file=os.path.join(potcar_dir,'POTCAR_PBE') - file=os.path.join(calc_dir,'POTCAR') - - shutil.copy(template_file, file) - -def generate_poscar(calc_dir,poscar): - template_file=poscar - file=os.path.join(calc_dir,'POSCAR') - - shutil.copy(template_file, file) - - -def chargemol_calc_setup(): - chargemol_file_dir=os.path.join(MP_DIR,'calculations','calculation_files','chargemol') - calc_dirs=glob(DB_CALC_DIR + '/*') - - # print(calc_dirs[:10]) - # chargemol - for calc_dir in calc_dirs: - # Create the calculation directory - potcar_dir=os.path.join(calc_dir,'potcar') - chargemol_dir=os.path.join(calc_dir,'chargemol') - poscar=os.path.join(calc_dir,'POSCAR') - os.makedirs(chargemol_dir,exist_ok=True) - - generate_incar(chargemol_dir,chargemol_file_dir) - generate_kpoints(chargemol_dir,chargemol_file_dir) - generate_potcar(chargemol_dir,potcar_dir) - generate_poscar(chargemol_dir,poscar) - generate_job_control(chargemol_dir,chargemol_file_dir) - generate_batch_script(chargemol_dir,chargemol_file_dir) - - - - - - -if __name__=='__main__': - chargemol_calc_setup() \ No newline at end of file diff --git a/matgraphdb/calculations/dft_calcs/remove_wavecar.py b/matgraphdb/calculations/dft_calcs/remove_wavecar.py deleted file mode 100644 index 4b6a89a..0000000 --- a/matgraphdb/calculations/dft_calcs/remove_wavecar.py +++ /dev/null @@ -1,23 +0,0 @@ - -import os -import json -import re -import shutil -from glob import glob - -scratch_dir='/users/lllang/SCRATCH' -root_dir=os.path.join(scratch_dir,'projects','crystal_graph') -database_dir=os.path.join(root_dir,'data','raw','mp_database_calcs_no_restriction') -calc_dirs=glob(database_dir + '/mp-*') - - -for calc_dir in calc_dirs[:]: - print(calc_dir) - scf_dir=os.path.join(calc_dir,'static') - wavecar_file=os.path.join(scf_dir,'WAVECAR') - - try: - - os.remove(wavecar_file) - except: - pass \ No newline at end of file diff --git a/matgraphdb/calculations/job_scheduler_generator.py b/matgraphdb/calculations/job_scheduler_generator.py deleted file mode 100644 index 660bc62..0000000 --- a/matgraphdb/calculations/job_scheduler_generator.py +++ /dev/null @@ -1,76 +0,0 @@ -class JobSchedulerGenerator: - def init_header(self): - raise NotImplementedError("init_header must be implemented in the child class") - def finalize(self): - raise NotImplementedError("finalize must be implemented in the child class") - -class SlurmScriptGenerator(): - - def __init__(self, - job_name='mp_database_job', - partition='comm_small_day', - time='24:00:00'): - - self.job_name=job_name - - self.partition=partition - self.time=time - - - self.add_comp_resources=False - - self.slurm_header='' - self.slurm_body='' - self.slurm_script='' - - def add_slurm_header_argument(self, argument:str): - if 
self.slurm_header=='': - raise ValueError("Slurm header is empty. Call init_slurm_header first") - self.slurm_header+= argument + '\n' - - def add_slurm_script_body(self, command:str): - self.slurm_body+=command + '\n' - return self.slurm_body - - def init_header(self): - self.slurm_header+='#!/bin/bash\n' - self.slurm_header+=f'#SBATCH -J {self.job_name}\n' - self.slurm_header+=f'#SBATCH -p {self.partition}\n' - self.slurm_header+=f'#SBATCH -t {self.time}\n' - return self.slurm_header - - def add_slurm_header_comp_resources(self, n_nodes=None, n_tasks=None, cpus_per_task=None): - self.add_comp_resources=True - self.n_tasks=n_tasks - self.cpus_per_task=cpus_per_task - self.n_nodes=n_nodes - command='' - if n_nodes is not None: - command+=f'#SBATCH --nodes={self.n_nodes}\n' - if n_tasks is not None: - command+=f'#SBATCH -n {self.n_tasks}\n' - if cpus_per_task is not None: - command+=f'#SBATCH -c {self.cpus_per_task}\n' - - self.slurm_header+=command - return self.slurm_header - - def exclude_nodes(self, node_names=[]): - if self.slurm_header=='': - raise ValueError("Slurm header is empty. Call init_slurm_header first") - node_list_string= ','.join(node_names) - command=f'#SBATCH --exclude={node_list_string}\n' - self.slurm_header+=command - return command - - def finalize(self): - if self.slurm_header=='': - raise ValueError("Slurm header is empty. Call init_slurm_header first") - if not self.add_comp_resources: - raise ValueError("Add computational resources before finalizing. add_slurm_header_comp_resources") - if self.slurm_body=='': - raise ValueError("Slurm body is empty. Call add_slurm_script first") - - self.slurm_script=self.slurm_header + '\n' + self.slurm_body - return self.slurm_script - \ No newline at end of file diff --git a/matgraphdb/calculations/mat_calcs/__init__.py b/matgraphdb/calculations/mat_calcs/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/calculations/mat_calcs/bond_stats_calc.py b/matgraphdb/calculations/mat_calcs/bond_stats_calc.py deleted file mode 100644 index 53234ed..0000000 --- a/matgraphdb/calculations/mat_calcs/bond_stats_calc.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np - -from matgraphdb.utils import LOGGER -from matgraphdb.utils.chem_utils.periodic import atomic_symbols - - -def calculate_bond_orders_sum(bond_orders,bond_connections, site_element_names): - """ - Calculates the sum and count of bond orders for a given file. - - Args: - file (str): The path to the JSON file containing the database. - - Returns: - tuple: A tuple containing two numpy arrays. The first array represents the sum of bond orders - between different elements, and the second array represents the count of bond orders. - - Raises: - Exception: If there is an error processing the file. 
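For reference, the SlurmScriptGenerator removed above is driven in a fixed order, which finalize() enforces. A minimal driver, assuming the class definition is in scope; the job name is a placeholder and the partition is simply the class default.

gen = SlurmScriptGenerator(job_name="mp-1000_chargemol", partition="comm_small_day", time="24:00:00")
gen.init_header()                                          # required before finalize()
gen.add_slurm_header_comp_resources(n_nodes=2, n_tasks=80)
gen.add_slurm_script_body("mpirun -np $SLURM_NTASKS vasp_std")
script = gen.finalize()                                    # header + blank line + body
print(script)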
- """ - # List of element names from pymatgen's Element - ELEMENTS = atomic_symbols[1:] - # Initialize arrays for bond order calculations - n_elements = len(ELEMENTS) - n_bond_orders = np.zeros(shape=(n_elements, n_elements)) - bond_orders_sum = np.zeros(shape=(n_elements, n_elements)) - - try: - # First iteration: calculate sum and count of bond orders - for isite, site in enumerate(bond_orders): - site_element = site_element_names[isite] - neighbors = bond_connections[isite] - - for jsite in neighbors: - neighbor_site_element = site_element_names[jsite] - bond_order = bond_orders[isite][jsite] - - i_element = ELEMENTS.index(site_element) - j_element = ELEMENTS.index(neighbor_site_element) - - bond_orders_sum[i_element, j_element] += bond_order - n_bond_orders[i_element, j_element] += 1 - - except Exception as e: - LOGGER.error(f"Error processing file {e}") - - return bond_orders_sum, n_bond_orders - -def calculate_bond_orders_sum_squared_differences(bond_orders,bond_connections, site_element_names, bond_orders_avg, n_bond_orders): - """ - Calculate the standard deviation of bond orders for a given material. - - Parameters: - bond_orders (numpy.ndarray): The bond orders between different elements for a material. - bond_connections (numpy.ndarray): The bond connections between different elements for a material. - site_element_names (list): The names of the elements in the structure for a material. - bond_orders_avg (numpy.ndarray): The average bond orders between different elements in the material database - n_bond_orders (numpy.ndarray): The count of bond orders between different elements in the material database. - - Returns: - bond_orders_std (numpy.ndarray): The standard deviation of bond orders between different elements. - """ - # List of element names from pymatgen's Element - ELEMENTS = atomic_symbols[1:] - # Initialize arrays for bond order calculations - n_elements = len(n_bond_orders) - bond_orders_std = np.zeros(shape=(n_elements, n_elements)) - try: - # First iteration: calculate sum and count of bond orders - for isite, site in enumerate(bond_orders): - site_element = site_element_names[isite] - neighbors = bond_connections[isite] - - for jsite in neighbors: - neighbor_site_element = site_element_names[jsite] - bond_order = bond_orders[isite][jsite] - - i_element = ELEMENTS.index(site_element) - j_element = ELEMENTS.index(neighbor_site_element) - - bond_order_avg = bond_orders_avg[i_element, j_element] - bond_orders_std[i_element, j_element] += (bond_order - bond_order_avg) ** 2 - - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - - return bond_orders_std - diff --git a/matgraphdb/calculations/mat_calcs/bonding_calc.py b/matgraphdb/calculations/mat_calcs/bonding_calc.py deleted file mode 100644 index e7bff30..0000000 --- a/matgraphdb/calculations/mat_calcs/bonding_calc.py +++ /dev/null @@ -1,169 +0,0 @@ - -from pymatgen.analysis.local_env import CutOffDictNN - -from matgraphdb.utils.chem_utils.periodic import covalent_cutoff_map -from matgraphdb.utils import LOGGER - -def calculate_geometric_electric_consistent_bonds(geo_coord_connections,elec_coord_connections, bond_orders, threshold=0.1): - """ - Adjusts the electric bond orders and connections to be consistent with the geometric bond connections above a given threshold. - - Args: - geo_coord_connections (list): List of geometric bond connections. - elec_coord_connections (list): List of electric bond connections. - bond_orders (list): List of bond orders. 
- threshold (float, optional): Threshold for bond orders. Defaults to 0.1. - - Returns: - tuple: A tuple containing the adjusted electric bond connections and bond orders. - - """ - try: - final_connections=[] - final_bond_orders=[] - - for elec_site_connections,geo_site_connections, site_bond_orders in zip(elec_coord_connections,geo_coord_connections,bond_orders): - - # Determine most likely electric bonds - elec_reduced_bond_indices = [i for i,order in enumerate(site_bond_orders) if order > threshold] - n_elec_bonds=len(elec_reduced_bond_indices) - n_geo_bonds=len(geo_site_connections) - - # If there is only one geometric bond and one or less electric bonds, then we can use the electric bond orders and connections as is - if n_geo_bonds == 1 and n_elec_bonds <= 1: - reduced_bond_orders=site_bond_orders - reduced_elec_site_connections=elec_site_connections - - # Else if there is only one geometric bond and more than 1 electric bonds, then we can use the electric reduced bond orders and connections as is - elif n_geo_bonds == 1 and n_elec_bonds > 1: - reduced_elec_site_connections = [elec_site_connections[i] for i in elec_reduced_bond_indices] - reduced_bond_orders = [site_bond_orders[i] for i in elec_reduced_bond_indices] - - # If there are more than one geometric bonds, then we need to sort the bond orders and elec connections by the total number of geometric connections - # Geometric bonds and electric bonds should have a correspondence with each other - else: - geo_reduced_bond_order_indices = sorted(range(len(site_bond_orders)), key=lambda i: site_bond_orders[i], reverse=True)[:n_geo_bonds] - - geo_reduced_bond_orders = [site_bond_orders[i] for i in geo_reduced_bond_order_indices] - geo_reduced_elec_site_connections = [elec_site_connections[i] for i in geo_reduced_bond_order_indices] - - # I take only bond orders greater than 0.1 because geometric connection alone can be wrong sometimes. For example in the case of oxygen. - geo_elec_reduced_bond_indices = [i for i,order in enumerate(geo_reduced_bond_orders) if order > 0.1] - - reduced_elec_site_connections = [geo_reduced_elec_site_connections[i] for i in geo_elec_reduced_bond_indices] - reduced_bond_orders = [geo_reduced_bond_orders[i] for i in geo_elec_reduced_bond_indices] - - final_site_connections=reduced_elec_site_connections - final_site_bond_orders=reduced_bond_orders - - final_connections.append(final_site_connections) - final_bond_orders.append(final_site_bond_orders) - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - final_connections=None - final_bond_orders=None - - return final_connections, final_bond_orders - -def calculate_electric_consistent_bonds(elec_coord_connections, bond_orders, threshold=0.1): - """ - Calculates the electric consistent bonds for a given set of electric bond connections and bond orders above a given threshold. - - Args: - elec_coord_connections (list): List of electric bond connections. - bond_orders (list): List of bond orders. - threshold (float, optional): Threshold for bond orders. Defaults to 0.1. - - Returns: - tuple: A tuple containing the adjusted electric bond connections and bond orders. 
- - """ - try: - final_connections=[] - final_bond_orders=[] - - for elec_site_connections, site_bond_orders in zip(elec_coord_connections,bond_orders): - - # Determine most likely electric bonds - elec_reduced_bond_indices = [i for i,order in enumerate(site_bond_orders) if order > threshold] - n_elec_bonds=len(elec_reduced_bond_indices) - - reduced_elec_site_connections = [elec_site_connections[i] for i in elec_reduced_bond_indices] - reduced_bond_orders = [site_bond_orders[i] for i in elec_reduced_bond_indices] - - final_site_connections=reduced_elec_site_connections - final_site_bond_orders=reduced_bond_orders - - final_connections.append(final_site_connections) - final_bond_orders.append(final_site_bond_orders) - - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - final_connections=None - final_bond_orders=None - - return final_connections, final_bond_orders - -def calculate_geometric_consistent_bonds(geo_coord_connections,elec_coord_connections, bond_orders): - """ - Adjusts the electric bond orders and connections to be consistent with the geometric bond connections. - - Args: - geo_coord_connections (list): List of geometric bond connections. - elec_coord_connections (list): List of electric bond connections. - bond_orders (list): List of bond orders. - - Returns: - tuple: A tuple containing the adjusted electric bond connections and bond orders. - - """ - try: - final_connections=[] - final_bond_orders=[] - - for geo_site_connections,elec_site_connections, site_bond_orders in zip(geo_coord_connections,elec_coord_connections,bond_orders): - - # Orders the electric bond orders by magnitudes up to the total amount of geometric bonds - n_geo_bonds=len(geo_site_connections) - geo_reduced_bond_order_indices = sorted(range(len(site_bond_orders)), key=lambda i: site_bond_orders[i], reverse=True)[:n_geo_bonds] - - # Reduces the electric bond orders and reduces the number of electric bond connections to the number of geometric bonds - geo_reduced_bond_orders = [site_bond_orders[i] for i in geo_reduced_bond_order_indices] - reduced_elec_site_connections = [elec_site_connections[i] for i in geo_reduced_bond_order_indices] - - final_site_connections=reduced_elec_site_connections - final_site_bond_orders=geo_reduced_bond_orders - - final_connections.append(final_site_connections) - final_bond_orders.append(final_site_bond_orders) - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - final_connections=None - final_bond_orders=None - return final_connections, final_bond_orders - -def calculate_cutoff_bonds(structure): - """ - Calculates the cutoff bonds for a given crystal structure. - - Args: - structure (Structure): The crystal structure for which to calculate the cutoff bonds. - - Returns: - list: A list of lists, where each inner list contains the indices of the nearest neighbors for each site in the structure. 
- """ - try: - CUTOFF_DICT = covalent_cutoff_map(tol=0.1) - cutoff_nn = CutOffDictNN(cut_off_dict=CUTOFF_DICT) - all_nn = cutoff_nn.get_all_nn_info(structure=structure) - nearest_neighbors = [] - for site_nn in all_nn: - neighbor_index = [] - for nn in site_nn: - index = nn['site_index'] - neighbor_index.append(index) - nearest_neighbors.append(neighbor_index) - except Exception as e: - LOGGER.error(f"Error calculating cutoff bonds: {e}") - nearest_neighbors = None - return nearest_neighbors diff --git a/matgraphdb/calculations/mat_calcs/chemenv_calc.py b/matgraphdb/calculations/mat_calcs/chemenv_calc.py deleted file mode 100644 index f5217f6..0000000 --- a/matgraphdb/calculations/mat_calcs/chemenv_calc.py +++ /dev/null @@ -1,73 +0,0 @@ -import copy -import logging - -from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import MultiWeightsChemenvStrategy -from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder -from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments - -logger = logging.getLogger(__name__) - - -def calculate_chemenv_connections(structure): - """ - Calculate the coordination environments, nearest neighbors, and coordination numbers for a given structure. - - Parameters: - structure (Structure): The input structure for which the chemenv calculations are performed. - - Returns: - tuple: A tuple containing the coordination environments, nearest neighbors, and coordination numbers. - - coordination_environments (list): A list of dictionaries representing the coordination environments for each site. - - nearest_neighbors (list): A list of lists containing the indices of the nearest neighbors for each site. - - coordination_numbers (list): A list of coordination numbers for each site. 
- """ - - try: - error=None - lgf = LocalGeometryFinder() - lgf.setup_structure(structure=structure) - - # Compute the structure environments - se = lgf.compute_structure_environments(maximum_distance_factor=1.41, only_cations=False) - - # Define the strategy for environment calculation - strategy = MultiWeightsChemenvStrategy.stats_article_weights_parameters() - lse = LightStructureEnvironments.from_structure_environments(strategy=strategy, structure_environments=se) - - # Get a list of possible coordination environments per site - coordination_environments = copy.copy(lse.coordination_environments) - - # Replace empty environments with default value - for i, env in enumerate(lse.coordination_environments): - if env is None or env==[]: - coordination_environments[i] = [{'ce_symbol': 'S:1', 'ce_fraction': 1.0, 'csm': 0.0, 'permutation': [0]}] - - # Calculate coordination numbers - coordination_numbers = [] - for env in coordination_environments: - if env is None: - coordination_numbers.append('NaN') - else: - coordination_numbers.append(int(env[0]['ce_symbol'].split(':')[-1])) - - # Determine nearest neighbors - nearest_neighbors = [] - for i_site, neighbors in enumerate(lse.neighbors_sets): - neighbor_index = [] - if neighbors!=[]: - neighbors = neighbors[0] - for neighbor_site in neighbors.neighb_sites_and_indices: - index = neighbor_site['index'] - neighbor_index.append(index) - nearest_neighbors.append(neighbor_index) - else: - pass - - except Exception as error: - logger.error(f"Error processing file: {error}") - coordination_environments = None - nearest_neighbors = None - coordination_numbers = None - - - return coordination_environments, nearest_neighbors, coordination_numbers \ No newline at end of file diff --git a/matgraphdb/calculations/mat_calcs/embeddings.py b/matgraphdb/calculations/mat_calcs/embeddings.py deleted file mode 100644 index d93e6cc..0000000 --- a/matgraphdb/calculations/mat_calcs/embeddings.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import json - - -import pandas as pd -import numpy as np -from matminer.datasets import load_dataset -from matminer.featurizers.base import MultipleFeaturizer -from matminer.featurizers.structure import XRDPowderPattern,SineCoulombMatrix,CoulombMatrix -from matminer.featurizers.composition import ElementFraction, ElementProperty - -from pymatgen.core import Structure - - - - -# def num_tokens_from_string(string: str, encoding_name: str) -> int: -# """ -# Calculates the number of tokens in a given string using the specified encoding. - -# Args: -# string (str): The input string. -# encoding_name (str): The name of the encoding to use. - -# Returns: -# int: The number of tokens in the string. -# """ -# encoding = tiktoken.get_encoding(encoding_name) -# num_tokens = len(encoding.encode(string)) -# return num_tokens - - -# def get_embedding(text, client, model="text-embedding-3-small"): -# """ -# Get the embedding for a given text using OpenAI's text-embedding API. - -# Parameters: -# text (str): The input text to be embedded. -# client: The OpenAI client object used to make API requests. -# model (str): The name of the model to use for embedding. Default is "text-embedding-3-small". - -# Returns: -# list: The embedding vector for the input text. -# """ -# text = text.replace("\n", " ") -# return client.embeddings.create(input=[text], model=model).data[0].embedding - -# def extract_text_from_json(json_file): -# """ -# Extracts specific text data from a JSON file and returns it as a compact JSON string. 
- -# Args: -# json_file (str): The path to the JSON file. - -# Returns: -# str: A compact JSON string containing the extracted text data. -# """ -# import json -# from matgraphdb.graph.node_types import PROPERTIES -# PROPERTY_NAMES = [prop[0] for prop in PROPERTIES] -# # Extract text from json file -# with open(json_file, 'r') as f: -# data = json.load(f) - -# emd_dict = {} -# for key in data.keys(): -# if key in PROPERTY_NAMES: -# emd_dict[key] = data[key] -# elif key == 'structure': -# emd_dict['lattice'] = data[key]['lattice'] - -# compact_json_text = json.dumps(emd_dict, separators=(',', ':')) -# return compact_json_text - -# def generate_openai_embeddings( -# materials_text, -# material_ids, -# model="text-embedding-3-small", -# embedding_encoding = "cl100k_base" -# ): -# """ -# Main function for processing database and generating embeddings using OpenAI models. - -# This function performs the following steps: -# 1. Sets up the parameters for the models and cost per token. -# 2. Initializes the OpenAI client using the API key. -# 3. Processes the database and extracts raw JSON text. -# 4. Calculates the total number of tokens and the cost. -# 5. Retrieves the mp_ids from the database directory. -# 6. Creates a dataframe of the results and adds the ada_embedding column. -# 7. Creates a dataframe of the embeddings. -# 8. Saves the embeddings to a CSV file. - -# Args: -# materials_text (list): A list of materials represented as text. -# material_ids (list): A list of material IDs. -# model (str): The name of the OpenAI model to use for embedding. Default is "text-embedding-3-small". -# Possible values are "text-embedding-3-small", "text-embedding-3-large", and "ada v2". -# embedding_encoding (str): The name of the encoding to use for embedding. Default is "cl100k_base". - -# Returns: -# None -# """ - -# models_cost_per_token={ -# "text-embedding-3-small":0.00000002, -# "text-embedding-3-large":0.00000013, -# "ada v2":0.00000010 -# } - -# cost_per_token=models_cost_per_token[model] - - -# client = openai.OpenAI(api_key=OPENAI_API_KEY) - -# # Calculate the total number of tokens and the cost -# token_count=0 -# for material_text in materials_text: -# token_count+=num_tokens_from_string(material_text,encoding_name=embedding_encoding) -# LOGGER.info(f"Total number of tokens: {token_count}") -# LOGGER.info(f"Cost per token: {cost_per_token}") - -# # put reselts into a dataframe under the column 'combined' -# df = pd.DataFrame(materials_text, columns=['combined'], index=material_ids) -# df['embedding'] = df.combined.apply(lambda x: get_embedding(x,client, model=model)) - -# # Create a dataframe of the embeddings. emb dim span columns size 1535 -# tmp_dict={ -# "ada_embedding":df['embedding'].tolist(), -# } -# df_embeddings = pd.DataFrame(tmp_dict, index=material_ids) -# # df_embeddings = pd.DataFrame(np.array(df['ada_embedding'].tolist()), index=material_ids) - -# return df_embeddings - - - -def generate_composition_embeddings(compositions,material_ids): - """Generate composition embeddings using Matminer and OpenAI. - - Args: - compositions (list): A list of compositions. - material_ids (list): A list of material IDs. 
- - Returns: - None - """ - composition_data = pd.DataFrame({'composition': compositions}, index=material_ids) - composition_featurizer = MultipleFeaturizer([ElementFraction()]) - composition_features = composition_featurizer.featurize_dataframe(composition_data,"composition") - composition_features=composition_features.drop(columns=['composition']) - features=composition_features - return features - - - -def generate_matminer_embeddings(structures,material_ids,features=[]): - """Generate structures embeddings using Matminer and OpenAI. - - Args: - structures (list): A list of structures. - material_ids (list): A list of material IDs. - - Returns: - None - """ - allowed_features=['sine_coulomb_matrix','element_fraction','element_property','xrd_pattern'] - for feature in features: - if feature not in allowed_features: - raise ValueError(f"Feature {feature} not allowed") - - compositions=[structure.composition for structure in structures] - data = pd.DataFrame({'structure': structures, 'composition':compositions}, index=material_ids) - - composition_featurizers=[] - if 'element_fraction' in features: - composition_featurizers.append(ElementFraction()) - if 'element_property' in features: - composition_featurizers.append(ElementProperty.from_preset(preset_name="magpie")) - - if len(composition_featurizers)>0: - composition_featurizer = MultipleFeaturizer(composition_featurizers) - data = composition_featurizer.featurize_dataframe(data,"composition",pbar=False) - - structure_featurizers=[] - if 'xrd_pattern' in features: - structure_featurizers.append(XRDPowderPattern()) - if 'sine_coulomb_matrix' in features: - columb_matrix=SineCoulombMatrix() - columb_matrix.fit(structures) - structure_featurizers.append(columb_matrix) - - if len(structure_featurizers)>0: - structure_featurizer=MultipleFeaturizer(structure_featurizers) - data = structure_featurizer.featurize_dataframe(data,"structure",pbar=False) - - data=data.drop(columns=['structure','composition']) - return data - - -if __name__=='__main__': - from matgraphdb.data.manager import DBManager - from pymatgen.core import Structure - import itertools - - - data=DBManager().load_json('mp-1000.json') - features=['sine_coulomb_matrix','element_fraction','element_property','xrd_pattern'] - - feature_sets=[ - ['sine_coulomb_matrix'], - ['xrd_pattern'], - ['element_fraction'], - ['element_property'], - ['sine_coulomb_matrix','element_property'], - ['sine_coulomb_matrix','element_fraction'], - ['element_property','element_fraction'], - ['sine_coulomb_matrix','element_property','element_fraction'], - ] - structures = [Structure.from_dict(data['structure'])] - print(structures) - materials_ids=['mp-1000'] - # features=generate_matminer_embeddings(structures,materials_ids,features=['sine_coulomb_matrix','xrd_pattern']) - # print(features) - - for feature_set in feature_sets: - print(feature_set) - features=generate_matminer_embeddings(structures,materials_ids,features=feature_set) - print(features) - - - diff --git a/matgraphdb/calculations/mat_calcs/similarity_calc.py b/matgraphdb/calculations/mat_calcs/similarity_calc.py deleted file mode 100644 index 17e0a76..0000000 --- a/matgraphdb/calculations/mat_calcs/similarity_calc.py +++ /dev/null @@ -1,145 +0,0 @@ -import os -import json -from glob import glob -from multiprocessing import Pool -import itertools - -import numpy as np -import pandas as pd -import pymatgen.core as pmat -from matminer.datasets import load_dataset -from matminer.featurizers.base import MultipleFeaturizer -from 
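A small usage sketch for generate_matminer_embeddings, using only the element_fraction featurizer; the material id is made up and matminer/pymatgen are assumed to be installed.

from pymatgen.core import Lattice, Structure

structures = [Structure(Lattice.cubic(4.2), ["Na", "Cl"], [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])]
embeddings = generate_matminer_embeddings(structures, ["toy-0001"], features=["element_fraction"])
print(embeddings.shape)  # one row per material, one column per element-fraction feature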
matminer.featurizers.structure import XRDPowderPattern -from matminer.featurizers.composition import ElementFraction - -from matgraphdb.utils import MP_DIR, DB_DIR, SIMILARITY_DIR,N_CORES -from matgraphdb.utils.math_utils import cosine_similarity -from matgraphdb.utils.general_utils import chunk_list - - -CHUNK_DIR=os.path.join(SIMILARITY_DIR,'chunks') - -def similarity_calc(material_combs_chunk): - """ - Calculate the similarity between pairs of materials. - - Parameters: - - material_combs_chunk (tuple): A tuple containing the chunk index and a list of material combinations. - - Returns: - - None - - This function takes a chunk of material combinations and calculates the similarity between each pair of materials. - The similarity is calculated based on the structure features of the materials using cosine similarity. - - The calculated similarity values are stored in a dictionary and saved to a JSON file. - - Note: This function assumes that the necessary directories and files are already set up. - - Example usage: - similarity_calc((0, [('mat1', 'mat2'), ('mat3', 'mat4')])) - """ - - # Function code goes here - # ... -def similarity_calc(material_combs_chunk): - - i_chunk,material_combs=material_combs_chunk - - chunk_file=os.path.join(CHUNK_DIR,f'chunk_{i_chunk}.json') - - material_ids=[] - for material_comb in material_combs: - mat_1,mat_2=material_comb - if mat_1 not in material_ids: - material_ids.append(mat_1) - if mat_2 not in material_ids: - material_ids.append(mat_2) - - structures=[] - compositions=[] - for material_id in material_ids: - material_json=os.path.join(DB_DIR,material_id + '.json') - - with open(material_json) as f: - db = json.load(f) - struct = pmat.Structure.from_dict(db['structure']) - structures.append(struct) - compositions.append(struct.composition) - - structure_data = pd.DataFrame({'structure': structures}, index=material_ids) - # composition_data = pd.DataFrame({'composition': compositions}, index=material_ids) - - structure_featurizer = MultipleFeaturizer([XRDPowderPattern()]) - # composition_featurizer = MultipleFeaturizer([ElementFraction()]) - - structure_features = structure_featurizer.featurize_dataframe(structure_data,"structure") - # composition_features = composition_featurizer.featurize_dataframe(composition_data,"composition") - - structure_features=structure_features.drop(columns=['structure']) - # composition_features=composition_features.drop(columns=['composition']) - - - features=structure_features - similarity_dict={} - for material_comb in material_combs: - mat_1,mat_2=material_comb - row_1 = features.loc[mat_1].values - row_2 = features.loc[mat_2].values - - similarity=cosine_similarity(a=row_1,b=row_2) - - pair_name=f'{mat_1}_{mat_2}' - similarity_dict.update({pair_name : similarity}) - - with open(chunk_file,'w') as f: - json.dump(similarity_dict, f, indent=4) - - -if __name__=='__main__': - print('Running Similarity analysis') - print('Database Dir : ', DB_DIR) - - - CHUNK_SIZE=1000 - - database_files=glob(DB_DIR + os.sep +'*.json') - mpids=[file.split(os.sep)[-1].split('.')[0] for file in database_files] - # print(mpids) - material_combs=list(itertools.combinations_with_replacement(mpids[:100], r=2 )) - - print(len(material_combs)) - # print(material_combs[:100]) - material_combs_chunks = chunk_list(material_combs, CHUNK_SIZE) - material_combs_chunks= [(i,material_combs_chunk) for i,material_combs_chunk in enumerate(material_combs_chunks)] - - print(len(material_combs_chunks)) - # with Pool(N_CORES) as p: - # p.map(similarity_analysis, 
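similarity_calc scores each material pair with cosine_similarity from matgraphdb.utils.math_utils; that implementation is not shown in this diff, but under the standard definition it is simply:

import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity: dot(a, b) / (|a| * |b|)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_similarity([1.0, 0.0, 1.0], [1.0, 1.0, 0.0]))  # 0.5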
material_combs_chunks) - for material_combs_chunk in material_combs_chunks: - similarity_calc(material_combs_chunk) - - - - - # Create empty similarity file - similarity_file= os.path.join(SIMILARITY_DIR, 'similarity.json') - tmp_dict={} - with open(similarity_file,'w') as f: - try: - data=json.load(f) - except: - json.dump(tmp_dict, f, indent=4) - - chunk_files=glob(CHUNK_DIR + os.sep + '*.json') - similarity_dict={} - for chunk_file in chunk_files: - with open(chunk_file) as f: - chunk_dict = json.load(f) - similarity_dict.update(chunk_dict) - - - with open(similarity_file,'w') as f: - json.dump(similarity_dict, f) - - diff --git a/matgraphdb/calculations/mat_calcs/wyckoff_calc.py b/matgraphdb/calculations/mat_calcs/wyckoff_calc.py deleted file mode 100644 index 8035e25..0000000 --- a/matgraphdb/calculations/mat_calcs/wyckoff_calc.py +++ /dev/null @@ -1,34 +0,0 @@ -import pymatgen.core as pmat -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - -from matgraphdb.utils import LOGGER - - -def calculate_wyckoff_positions(struct): - """ - Calculates the Wyckoff positions of a given structure. - - Args: - struct (Structure): The structure to be processed. - - Returns: - list: A list of Wyckoff positions. - """ - wyckoffs=None - try: - spg_a = SpacegroupAnalyzer(struct) - sym_dataset=spg_a.get_symmetry_dataset() - wyckoffs=sym_dataset['wyckoffs'] - except Exception as e: - LOGGER.error(f"Error processing structure: {e}") - - return wyckoffs - -if __name__=='__main__': - # Testing calculate_wyckoff_positions - struct = pmat.Structure.from_dict({ - "lattice": [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5]], - "species": ["H", "H", "H"], - "coords": [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5]], - }) - calculate_wyckoff_positions(struct) diff --git a/matgraphdb/calculations/parsers.py b/matgraphdb/calculations/parsers.py deleted file mode 100644 index 14dd741..0000000 --- a/matgraphdb/calculations/parsers.py +++ /dev/null @@ -1,215 +0,0 @@ - -import re -import os - -from matgraphdb.utils import LOGGER - -def parse_chargemol_bond_orders(file, bond_order_cutoff=0.0): - """ - Parses the Chargemol bond order file and extracts the bonding orders and connections. - - Args: - file (str): The path to the Chargemol bond order file. - bond_order_cutoff (float, optional): The minimum bond order cutoff. Bonds with order below this value will be ignored. Defaults to 0.0. - - Returns: - tuple: A tuple containing two lists. The first list contains the bonding orders for each atom, and the second list contains the atom indices of the connected atoms. 
- """ - try: - with open(file,'r') as f: - text=f.read() - - bond_blocks=re.findall('(?<=Printing BOs for ATOM).*\n([\s\S]*?)(?=The sum of bond orders for this atom is SBO)',text) - - bonding_connections=[] - bonding_orders=[] - - for bond_block in bond_blocks: - - bonds=bond_block.strip().split('\n') - - bond_orders=[] - atom_indices=[] - # Catches cases where there are no bonds - if bonds[0]!='': - for bond in bonds: - - bond_order=float(re.findall('bond order\s=\s*([.0-9-]*)\s*',bond)[0]) - - # shift so index starts at 0 - atom_index=int(re.findall('translated image of atom number\s*([0-9]*)\s*',bond)[0]) -1 - - if bond_order >= bond_order_cutoff: - bond_orders.append(bond_order) - atom_indices.append(atom_index) - else: - pass - - bonding_connections.append(atom_indices) - bonding_orders.append(bond_orders) - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - bonding_orders=None - bonding_connections=None - - return bonding_connections,bonding_orders - -def parse_chargemol_atomic_moments(file): - """ - Parses the Chargemol atomic moments file and extracts the atomic moments. - - Args: - file (str): The path to the Chargemol atomic moments file. - - Returns: - list: A list containing the atomic moments. - """ - try: - with open(file,'r') as f: - text=f.read() - - - raw_atomic_moments_info=re.findall('Same information as above printed with atom number.*\n([\s\S]*)',text)[0].strip().split('\n') - atomic_moments=[] - for moment_info_line in raw_atomic_moments_info: - moment_info=moment_info_line.split() - moment= float(moment_info[5]) - atomic_moments.append(moment) - - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - atomic_moments=None - - return atomic_moments - - -def parse_chargemol_net_atomic_charges(file): - """ - Parses the Chargemol net atomic charges file and extracts the net atomic charges. - - Args: - file (str): The path to the Chargemol net atomic charges file. - - Returns: - list: A list containing the net atomic charges. - """ - try: - with open(file,'r') as f: - text=f.read() - - num_atoms=int(text.split('\n')[0]) - - parameter_names= ('Number of radial integration shells', - 'Cutoff radius \(pm\)', - 'Error in the integrated total number of electrons before renormalization \(e\)', - 'Charge convergence tolerance', - 'Minimum radius for electron cloud penetration fitting \(pm\)', - 'Minimum decay exponent for electron density of buried atom tails', - 'Maximum decay exponent for electron density of buried atom tails', - 'Number of iterations to convergence') - parameters=[] - for parameter_name in parameter_names: - reg_expression=parameter_name+'\s*=\s*([-E\d.]+)' - raw_parameter_name=parameter_name.replace('\\','') - parameter_value=float(re.findall(reg_expression,text)[0]) - parameters.append((raw_parameter_name,parameter_value)) - - - raw_info=re.findall('The following XYZ coordinates are in angstroms. The atomic dipoles and quadrupoles are in atomic units.*\n([\s\S]*)',text)[0].split('\n \n') - - moment_info=raw_info[0].split('\n') - moment_info_description='The following XYZ coordinates are in angstroms. 
The atomic dipoles and quadrupoles are in atomic units' - moment_info_names=[name.strip() for name in moment_info[0].split(',')] - moment_info_values=[] - for raw_values in moment_info[1:]: - values = [value if i == 1 else float(value) for i, value in enumerate(raw_values.split())] - moment_info_values.append(values) - - electron_density_fit_description='The sperically averaged electron density of each atom fit to a function of the form exp(a - br) for r >=rmin_cloud_penetration.' - electron_density_fit_info=re.findall('The sperically averaged electron density of each atom fit to a function of the form exp\(a \- br\) for r \>\=rmin_cloud_penetration.*\n([\s\S]*)',raw_info[1])[0].split('\n') - electron_density_fit_names=[name.strip() for name in electron_density_fit_info[0].split(',')] - electron_density_fit_values=[] - for raw_values in electron_density_fit_info[1:]: - values = [value if i == 1 else float(value) for i, value in enumerate(raw_values.split())] - electron_density_fit_values.append(values) - - - net_atomic_charges_info={ - 'moment_info_description':moment_info_description, - 'moment_info_names':moment_info_names, - 'moment_info_values':moment_info_values, - 'electron_density_fit_description':electron_density_fit_description, - 'electron_density_fit_names':electron_density_fit_names, - 'electron_density_fit_values':electron_density_fit_values, - 'computational_parameters':parameters - } - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - net_atomic_charges_info=None - - return net_atomic_charges_info - -def parse_chargemol_overlap_populations(file): - """ - Parses the Chargemol overlap populations file and extracts the overlap populations. - - Args: - file (str): The path to the Chargemol overlap populations file. - - Returns: - list: A list containing the overlap populations. 
- """ - - try: - with open(file,'r') as f: - text=f.read() - - num_atoms=int(text.split('\n')[0]) - overlap_populations_names='atom1, atom2, translation A, translation B, translation C, overlap population'.split(',') - raw_overlap_populations_values=re.findall('atom1, atom2, translation A, translation B, translation C, overlap population.*\n([\s\S]*)',text)[0].strip().split('\n') - - overlap_populations_values=[[ float(value) for value in overlap_population_line.split()] for overlap_population_line in raw_overlap_populations_values] - - overlap_populations_info={ - 'overlap_populations_names':overlap_populations_names, - 'overlap_populations_values':overlap_populations_values - } - except Exception as e: - LOGGER.error(f"Error processing file: {e}") - overlap_populations_info=None - return overlap_populations_info - - -if __name__=='__main__': - - - squared_moments_file='/users/lllang/SCRATCH/projects/MatGraphDB/data/raw/materials_project_nelements_2/calculations/MaterialsData/mp-170/chargemol/DDEC_atomic_Rsquared_moments.xyz' - cubed_moments_file='/users/lllang/SCRATCH/projects/MatGraphDB/data/raw/materials_project_nelements_2/calculations/MaterialsData/mp-170/chargemol/DDEC_atomic_Rcubed_moments.xyz' - fourth_moments_file='/users/lllang/SCRATCH/projects/MatGraphDB/data/raw/materials_project_nelements_2/calculations/MaterialsData/mp-170/chargemol/DDEC_atomic_Rfourth_moments.xyz' - bond_orders_file='/users/lllang/SCRATCH/projects/MatGraphDB/data/raw/materials_project_nelements_2/calculations/MaterialsData/mp-170/chargemol/DDEC6_even_tempered_bond_orders.xyz' - atomic_charges_file='/users/lllang/SCRATCH/projects/MatGraphDB/data/raw/materials_project_nelements_2/calculations/MaterialsData/mp-170/chargemol/DDEC6_even_tempered_net_atomic_charges.xyz' - overlap_population_file='/users/lllang/SCRATCH/projects/MatGraphDB/data/raw/materials_project_nelements_2/calculations/MaterialsData/mp-170/chargemol/overlap_populations.xyz' - - - # bond_order_info= parse_chargemol_bond_orders(file=bond_orders_file) - - # for bond_orders, neihbors in zip(*bond_order_info): - # print(bond_orders) - # print(neihbors) - # print('_'*200) - squared_moments_file='/gpfs20/scratch/lllang/projects/MatGraphDB/data/production/materials_project/calculations/MaterialsData/mp-1228566/chargemol/DDEC_atomic_Rsquared_moments.xyz' - print(squared_moments_file) - squared_moments_info=parse_chargemol_atomic_moments(file=squared_moments_file) - print(squared_moments_info) - # cubed_moments_info=parse_chargemol_atomic_moments(file=cubed_moments_file) - # fourth_moments_info=parse_chargemol_atomic_moments(file=fourth_moments_file) - - # print(squared_moments_info) - # print(cubed_moments_info) - # print(fourth_moments_info) - - # net_atomic_charges_info=parse_chargemol_net_atomic_charges(file=atomic_charges_file) - # print(net_atomic_charges_info) - - # overlap_population_info=parse_chargemol_overlap_populations(file=overlap_population_file) - # print(overlap_population_info) \ No newline at end of file diff --git a/matgraphdb/calculations/slurm_job_launcher.py b/matgraphdb/calculations/slurm_job_launcher.py deleted file mode 100644 index 38f899a..0000000 --- a/matgraphdb/calculations/slurm_job_launcher.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import subprocess - -from matgraphdb.data.manager import DBManager - -import subprocess - -def launch_calcs(slurm_scripts=[]): - """ - Launches SLURM calculations by submitting SLURM scripts. - - Args: - slurm_scripts (list): A list of SLURM script file paths. 
- - Returns: - None - """ - if slurm_scripts != []: - for slurm_script in slurm_scripts: - result = subprocess.run(['sbatch', slurm_script], capture_output=False, text=True) - -def launch_failed_chargemol_calcs(): - """ - Launches failed Chargemol calculations. - - This function retrieves the list of failed Chargemol calculations from the database - and launches the corresponding SLURM scripts for each failed calculation. - - Returns: - None - """ - db = DBManager() - success, failed = db.check_chargemol() - - slurm_scripts = [] - - print(f"About to launch {len(failed)} calculations") - for path in failed[:]: - slurm_script = os.path.join(path, 'run.slurm') - slurm_scripts.append(slurm_script) - - launch_calcs(slurm_scripts) - - - -if __name__=='__main__': - launch_failed_chargemol_calcs() \ No newline at end of file diff --git a/matgraphdb/core/__init__.py b/matgraphdb/core/__init__.py index cd96b05..6d3b94f 100644 --- a/matgraphdb/core/__init__.py +++ b/matgraphdb/core/__init__.py @@ -1,2 +1,2 @@ +from matgraphdb.core.material_store import MaterialStore from matgraphdb.core.matgraphdb import MatGraphDB -from matgraphdb.core.nodes import MaterialStore diff --git a/matgraphdb/core/datasets/__init__.py b/matgraphdb/core/datasets/__init__.py deleted file mode 100644 index aa0e707..0000000 --- a/matgraphdb/core/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from matgraphdb.core.datasets.mp_near_hull import MPNearHull diff --git a/matgraphdb/core/nodes/materials.py b/matgraphdb/core/material_store.py similarity index 82% rename from matgraphdb/core/nodes/materials.py rename to matgraphdb/core/material_store.py index 12a8581..fd29790 100644 --- a/matgraphdb/core/nodes/materials.py +++ b/matgraphdb/core/material_store.py @@ -1,23 +1,15 @@ -import json import logging -import os -from functools import partial -from glob import glob from typing import Callable, Dict, List, Tuple, Union import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc -import pyarrow.dataset as ds -import pyarrow.parquet as pq -import spglib -from parquetdb import NodeStore, node_generator +from parquetdb import NodeStore from parquetdb.core.parquetdb import LoadConfig, NormalizeConfig -from pymatgen.core import Composition, Structure -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from pymatgen.core import Structure -from matgraphdb.utils.general_utils import set_verbosity +from matgraphdb.utils.log_utils import set_verbose_level from matgraphdb.utils.mp_utils import multiprocess_task logger = logging.getLogger(__name__) @@ -450,7 +442,6 @@ def delete_materials( ids: List[int] = None, columns: List[str] = None, normalize_config: NormalizeConfig = NormalizeConfig(), - verbose: int = 3, ): """ Deletes records from the database by ID. @@ -465,8 +456,6 @@ def delete_materials( A list of column names to delete from the database. normalize_config : NormalizeConfig, optional The normalize configuration to be applied to the data. This is the NormalizeConfig object from Parquet - verbose : int, optional - The verbosity level for logging (default is 3). Returns ------- @@ -480,7 +469,6 @@ def delete_materials( .. 
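launch_calcs simply shells out to sbatch for each script; a sketch of the same pattern with basic error reporting, where the output capture and return-code check are additions.

import subprocess

def submit_slurm_scripts(slurm_scripts):
    """Submit each SLURM script with sbatch and report the result."""
    for script in slurm_scripts:
        result = subprocess.run(["sbatch", script], capture_output=True, text=True)
        if result.returncode == 0:
            print(result.stdout.strip())                       # sbatch prints the new job id
        else:
            print(f"{script} failed: {result.stderr.strip()}")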
code-block:: python manager.delete(ids=[1, 2, 3]) """ - set_verbosity(verbose) logger.info(f"Deleting data {ids}") self.delete(ids=ids, columns=columns, normalize_config=normalize_config) @@ -526,107 +514,3 @@ def check_all_params_provided(**kwargs): ) -@node_generator -def material_lattice(material_store: NodeStore): - """ - Creates Lattice nodes if no file exists, otherwise loads them from a file. - """ - # Retrieve material nodes with lattice properties - try: - # material_nodes = NodeStore(material_store_path) - material_nodes = material_store - - table = material_nodes.read( - columns=[ - "structure.lattice.a", - "structure.lattice.b", - "structure.lattice.c", - "structure.lattice.alpha", - "structure.lattice.beta", - "structure.lattice.gamma", - "structure.lattice.volume", - "structure.lattice.pbc", - "structure.lattice.matrix", - "id", - "core.material_id", - ] - ) - - for i, column in enumerate(table.columns): - field = table.schema.field(i) - field_name = field.name - if "." in field_name: - field_name = field_name.split(".")[-1] - if "id" == field_name: - field_name = "material_node_id" - new_field = field.with_name(field_name) - table = table.set_column(i, new_field, column) - - except Exception as e: - logger.error(f"Error creating lattice nodes: {e}") - return None - - return table - - -@node_generator -def material_site(material_store: NodeStore): - try: - material_nodes = material_store - lattice_names = [ - "structure.lattice.a", - "structure.lattice.b", - "structure.lattice.c", - "structure.lattice.alpha", - "structure.lattice.beta", - "structure.lattice.gamma", - "structure.lattice.volume", - ] - id_names = ["id", "core.material_id"] - tmp_dict = {field: [] for field in id_names} - tmp_dict.update({field: [] for field in lattice_names}) - table = material_nodes.read( - columns=["structure.sites", *id_names, *lattice_names] - ) - # table=material_nodes.read(columns=['structure.sites', *id_names])#, *lattice_names]) - material_sites = table["structure.sites"].combine_chunks() - flatten_material_sites = pc.list_flatten(material_sites) - material_sites_length_list = pc.list_value_length(material_sites).to_numpy() - - for i, legnth in enumerate(material_sites_length_list): - for field_name in tmp_dict.keys(): - column = table[field_name].combine_chunks() - value = column[i] - tmp_dict[field_name].extend([value] * legnth) - table = None - - arrays = flatten_material_sites.flatten() - - if hasattr(flatten_material_sites, "names"): - names = flatten_material_sites.type.names - else: - names= [val.name for val in flatten_material_sites.type] - - flatten_material_sites = None - material_sites_length_list = None - - for name, column_values in tmp_dict.items(): - arrays.append(pa.array(column_values)) - names.append(name) - - table = pa.Table.from_arrays(arrays, names=names) - - for i, column in enumerate(table.columns): - field = table.schema.field(i) - field_name = field.name - if "." 
in field_name: - field_name = field_name.split(".")[-1] - if "id" == field_name: - field_name = "material_node_id" - new_field = field.with_name(field_name) - table = table.set_column(i, new_field, column) - - except Exception as e: - logger.error(f"Error creating site nodes: {e}") - raise e - return table diff --git a/matgraphdb/core/matgraphdb.py b/matgraphdb/core/matgraphdb.py index 8f43f7c..b32e274 100644 --- a/matgraphdb/core/matgraphdb.py +++ b/matgraphdb/core/matgraphdb.py @@ -1,11 +1,10 @@ import logging import os -from typing import Dict, List, Union +from typing import List -import pyarrow as pa from parquetdb import ParquetGraphDB -from matgraphdb.core.nodes import MaterialStore +from matgraphdb.core.material_store import MaterialStore logger = logging.getLogger(__name__) diff --git a/matgraphdb/core/nodes/__init__.py b/matgraphdb/core/nodes/__init__.py deleted file mode 100644 index d80f699..0000000 --- a/matgraphdb/core/nodes/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from matgraphdb.core.nodes.generators import ( - chemenv, - crystal_system, - element, - magnetic_state, - oxidation_state, - space_group, - wyckoff, -) -from matgraphdb.core.nodes.materials import ( - MaterialStore, - material_lattice, - material_site, -) diff --git a/matgraphdb/datasets/__init__.py b/matgraphdb/datasets/__init__.py new file mode 100644 index 0000000..8e031cd --- /dev/null +++ b/matgraphdb/datasets/__init__.py @@ -0,0 +1 @@ +from matgraphdb.datasets.mp_near_hull import MPNearHull diff --git a/matgraphdb/core/datasets/mp_near_hull.py b/matgraphdb/datasets/mp_near_hull.py similarity index 79% rename from matgraphdb/core/datasets/mp_near_hull.py rename to matgraphdb/datasets/mp_near_hull.py index c857231..db0cdf4 100644 --- a/matgraphdb/core/datasets/mp_near_hull.py +++ b/matgraphdb/datasets/mp_near_hull.py @@ -1,11 +1,11 @@ import logging import os +import shutil from huggingface_hub import snapshot_download +from matgraphdb import generators from matgraphdb.core import MatGraphDB -from matgraphdb.core.edges import * -from matgraphdb.core.nodes import * from matgraphdb.utils.config import config MPNEARHULL_PATH = os.path.join(config.data_dir, "datasets", "MPNearHull") @@ -58,19 +58,19 @@ def __init__( def initialize_nodes(self): node_generators = [ - {"generator_func": element}, - {"generator_func": chemenv}, - {"generator_func": crystal_system}, - {"generator_func": magnetic_state}, - {"generator_func": oxidation_state}, - {"generator_func": space_group}, - {"generator_func": wyckoff}, + {"generator_func": generators.element}, + {"generator_func": generators.chemenv}, + {"generator_func": generators.crystal_system}, + {"generator_func": generators.magnetic_state}, + {"generator_func": generators.oxidation_state}, + {"generator_func": generators.space_group}, + {"generator_func": generators.wyckoff}, { - "generator_func": material_site, + "generator_func": generators.material_site, "generator_args": {"material_store": self.node_stores["material"]}, }, { - "generator_func": material_lattice, + "generator_func": generators.material_lattice, "generator_args": {"material_store": self.node_stores["material"]}, }, ] @@ -87,53 +87,53 @@ def initialize_nodes(self): def initialize_edges(self): edge_generators = [ { - "generator_func": element_element_neighborsByGroupPeriod, + "generator_func": generators.element_element_neighborsByGroupPeriod, "generator_args": {"element_store": self.node_stores["element"]}, }, { - "generator_func": element_oxiState_canOccur, + "generator_func": 
generators.element_oxiState_canOccur, "generator_args": { "element_store": self.node_stores["element"], "oxiState_store": self.node_stores["oxidation_state"], }, }, { - "generator_func": material_chemenv_containsSite, + "generator_func": generators.material_chemenv_containsSite, "generator_args": { "material_store": self.node_stores["material"], "chemenv_store": self.node_stores["chemenv"], }, }, { - "generator_func": material_crystalSystem_has, + "generator_func": generators.material_crystalSystem_has, "generator_args": { "material_store": self.node_stores["material"], "crystal_system_store": self.node_stores["crystal_system"], }, }, { - "generator_func": material_element_has, + "generator_func": generators.material_element_has, "generator_args": { "material_store": self.node_stores["material"], "element_store": self.node_stores["element"], }, }, { - "generator_func": material_lattice_has, + "generator_func": generators.material_lattice_has, "generator_args": { "material_store": self.node_stores["material"], "lattice_store": self.node_stores["material_lattice"], }, }, { - "generator_func": material_spg_has, + "generator_func": generators.material_spg_has, "generator_args": { "material_store": self.node_stores["material"], "spg_store": self.node_stores["space_group"], }, }, { - "generator_func": element_chemenv_canOccur, + "generator_func": generators.element_chemenv_canOccur, "generator_args": { "element_store": self.node_stores["element"], "chemenv_store": self.node_stores["chemenv"], diff --git a/matgraphdb/generators/__init__.py b/matgraphdb/generators/__init__.py new file mode 100644 index 0000000..a85cf97 --- /dev/null +++ b/matgraphdb/generators/__init__.py @@ -0,0 +1,2 @@ +from matgraphdb.generators.edges import * +from matgraphdb.generators.nodes import * diff --git a/matgraphdb/core/edges.py b/matgraphdb/generators/edges.py similarity index 100% rename from matgraphdb/core/edges.py rename to matgraphdb/generators/edges.py diff --git a/matgraphdb/core/nodes/generators.py b/matgraphdb/generators/nodes.py similarity index 61% rename from matgraphdb/core/nodes/generators.py rename to matgraphdb/generators/nodes.py index 7f963c0..f2c2351 100644 --- a/matgraphdb/core/nodes/generators.py +++ b/matgraphdb/generators/nodes.py @@ -1,16 +1,13 @@ import logging import os -import shutil import warnings import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc -from parquetdb import ParquetDB, node_generator -from parquetdb.utils import pyarrow_utils +from parquetdb import NodeStore, node_generator -from matgraphdb.core.nodes import * from matgraphdb.utils.config import PKG_DIR logger = logging.getLogger(__name__) @@ -200,3 +197,110 @@ def wyckoff(): return None return df + + + +@node_generator +def material_lattice(material_store: NodeStore): + """ + Creates Lattice nodes if no file exists, otherwise loads them from a file. + """ + # Retrieve material nodes with lattice properties + try: + # material_nodes = NodeStore(material_store_path) + material_nodes = material_store + + table = material_nodes.read( + columns=[ + "structure.lattice.a", + "structure.lattice.b", + "structure.lattice.c", + "structure.lattice.alpha", + "structure.lattice.beta", + "structure.lattice.gamma", + "structure.lattice.volume", + "structure.lattice.pbc", + "structure.lattice.matrix", + "id", + "core.material_id", + ] + ) + + for i, column in enumerate(table.columns): + field = table.schema.field(i) + field_name = field.name + if "." 
in field_name: + field_name = field_name.split(".")[-1] + if "id" == field_name: + field_name = "material_node_id" + new_field = field.with_name(field_name) + table = table.set_column(i, new_field, column) + + except Exception as e: + logger.error(f"Error creating lattice nodes: {e}") + return None + + return table + + +@node_generator +def material_site(material_store: NodeStore): + try: + material_nodes = material_store + lattice_names = [ + "structure.lattice.a", + "structure.lattice.b", + "structure.lattice.c", + "structure.lattice.alpha", + "structure.lattice.beta", + "structure.lattice.gamma", + "structure.lattice.volume", + ] + id_names = ["id", "core.material_id"] + tmp_dict = {field: [] for field in id_names} + tmp_dict.update({field: [] for field in lattice_names}) + table = material_nodes.read( + columns=["structure.sites", *id_names, *lattice_names] + ) + # table=material_nodes.read(columns=['structure.sites', *id_names])#, *lattice_names]) + material_sites = table["structure.sites"].combine_chunks() + flatten_material_sites = pc.list_flatten(material_sites) + material_sites_length_list = pc.list_value_length(material_sites).to_numpy() + + for i, length in enumerate(material_sites_length_list): + for field_name in tmp_dict.keys(): + column = table[field_name].combine_chunks() + value = column[i] + tmp_dict[field_name].extend([value] * length) + table = None + + arrays = flatten_material_sites.flatten() + + if hasattr(flatten_material_sites, "names"): + names = flatten_material_sites.type.names + else: + names = [val.name for val in flatten_material_sites.type] + + flatten_material_sites = None + material_sites_length_list = None + + for name, column_values in tmp_dict.items(): + arrays.append(pa.array(column_values)) + names.append(name) + + table = pa.Table.from_arrays(arrays, names=names) + + for i, column in enumerate(table.columns): + field = table.schema.field(i) + field_name = field.name + if "." in field_name: + field_name = field_name.split(".")[-1] + if "id" == field_name: + field_name = "material_node_id" + new_field = field.with_name(field_name) + table = table.set_column(i, new_field, column) + + except Exception as e: + logger.error(f"Error creating site nodes: {e}") + raise e + return table diff --git a/matgraphdb/graph_kit/__init__.py b/matgraphdb/graph_kit/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/graph_kit/data.py b/matgraphdb/graph_kit/data.py deleted file mode 100644 index 374e66c..0000000 --- a/matgraphdb/graph_kit/data.py +++ /dev/null @@ -1,932 +0,0 @@ -import io -import os - -import pandas as pd -import pyarrow.parquet as pq -import torch -from torch_geometric.data import HeteroData - -from matgraphdb.graph_kit.pyg.encoders import * -from matgraphdb.utils import get_child_logger - -logger=get_child_logger(__name__, console_out=False, log_level='debug') - -def get_parquet_field_metadata(path, columns=None): - """ - Retrieves the metadata for each field (column) in a Parquet file and returns it in a dictionary format. - - Args: - path (str): The file path to the Parquet file. - columns (list, optional): A list of column names to filter and return metadata for. If None, metadata for all fields is returned. - - Returns: - dict: A dictionary where keys are the column names, and values are dictionaries containing metadata for each column. - The metadata is extracted from the Arrow schema in the Parquet file and includes key-value pairs converted to UTF-8 strings.
- - Example: - field_metadata = get_parquet_field_metadata('data.parquet', columns=['column1', 'column2']) - """ - parquet_file = pq.ParquetFile(path) - field_metadata={} - for field in parquet_file.metadata.schema.to_arrow_schema(): - name=field.name - - if columns and name not in columns: - continue - - field_metadata[name]={} - for key,value in field.metadata.items(): - field_metadata[name][key.decode('utf-8')]=value.decode('utf-8') - return field_metadata - -def load_node_parquet(path, feature_columns=[], target_columns=[], custom_encoders={}, filter={}, keep_nan=False): - """ - Loads a Parquet file containing node data and applies custom encoders and filters, returning feature and target tensors - along with a mapping of node names to indices. - - Args: - path (str): The file path to the Parquet file. - feature_columns (list, optional): List of column names to use as features. Defaults to an empty list. - target_columns (list, optional): List of column names to use as targets. Defaults to an empty list. - custom_encoders (dict, optional): A dictionary where keys are column names and values are custom encoders to apply to those columns. - If not provided, the default encoder from the Parquet metadata will be used. - filter (dict, optional): A dictionary where keys are column names and values are tuples specifying the (min, max) range for filtering rows. - keep_nan (bool, optional): If False, rows with NaN values will be dropped. Defaults to False. - - Returns: - tuple: A tuple containing: - - x (torch.Tensor): The feature tensor concatenated from all feature columns. - - target (torch.Tensor): The target tensor concatenated from all target columns. - - index_name_map (dict): A mapping of the original index to the name of the node. - - feature_names (list): A list of feature column names, including any columns expanded by the encoders. - - target_names (list): A list of target column names, including any columns expanded by the encoders. 
- - Example: - x, target, index_name_map, feature_names, target_names = load_node_parquet('nodes.parquet', feature_columns=['f1'], target_columns=['t1']) - """ - if target_columns is None: - target_columns=[] - if feature_columns is None: - feature_columns=[] - - all_columns=feature_columns+target_columns - - if 'name' not in all_columns: - all_columns.append('name') - - # Getting field metadata for all columns - field_metadata=get_parquet_field_metadata(path,columns=all_columns) - - # Reading all columns into single dataframe - df = pd.read_parquet(path, columns=all_columns) - names=df['name'] - df = df.drop(columns=['name']) - all_columns.remove('name') - column_names=list(df.columns) - - # Ensure all columns have no NaN values, otherwise drop them - if not keep_nan: - df.dropna(subset=df.columns, inplace=True) - - logger.info(f"Dataframe shape after removing NaN values: {df.shape}") - - # Apply data filter to the nodes - for key, (min_value, max_value) in filter.items(): - df = df[(df[key] >= min_value) & (df[key] <= max_value)] - - logger.info(f"Dataframe shape after applying node filter: {df.shape}") - logger.info(f"Dataframe shape: {df.shape}") - logger.info(f"Column names: {column_names}") - - # Applying encoders to the nodes - xs, ys = [], [] - feature_names, target_names = [], [] - for column_name in column_names: - tmp_names=[] - # Applying custom encoder if provided, - # otherwise use default encoder inside parquet feild metadata - if column_name in custom_encoders: - if isinstance(custom_encoders[column_name],str): - encoder=eval(custom_encoders[column_name]) - else: - encoder=custom_encoders[column_name] - else: - encoder=eval(field_metadata[column_name]['encoder']) - - encoded_values=encoder(df[column_name]) - - # Getting feature names from encoder. Some encoders return multiple feature - # columns due to the nature of the encoder - encoder_feature_names=None - if hasattr(encoder,'column_names'): - encoder_feature_names=encoder.column_names - - # Generating feature names - if encoder_feature_names: - for feature_name in encoder_feature_names: - tmp_names.append(f"{column_name}_{feature_name}") - elif encoded_values.shape[1]>1: - for i in range(encoded_values.shape[1]): - tmp_names.append(f"{column_name}_{i}") - else: - tmp_names.append(column_name) - - # Filtering values into features or targets - if column_name in feature_columns: - xs.append(encoded_values) - feature_names.extend(tmp_names) - if column_name in target_columns: - ys.append(encoded_values) - target_names.extend(tmp_names) - - # Concatenate features and targets - x=None - if xs: - x = torch.cat(xs, dim=-1) - target=None - if ys: - target = torch.cat(ys, dim=-1) - - # This maps the original index to the name of the node. - # This is needed since we filter out nodes that contain NaN. - index_name_map = {index: names[index] for i, index in enumerate(df.index.unique())} - return x, target, index_name_map, feature_names, target_names - -def load_relationship_parquet(path, node_id_mappings, - feature_columns=[], - target_columns=[], - custom_encoders={}, - filter={}): - """ - Loads a Parquet file containing relationship data (edges between nodes), applies custom encoders and filters, - and returns edge indices, edge attributes, and target tensors along with a mapping of edge indices. - - Args: - path (str): The file path to the Parquet file. - node_id_mappings (dict): A dictionary where keys are node types (e.g., 'src', 'dst') and values are mappings of node indices to IDs. 
- feature_columns (list, optional): List of column names to use as edge features. Defaults to an empty list. - target_columns (list, optional): List of column names to use as edge targets. Defaults to an empty list. - custom_encoders (dict, optional): A dictionary where keys are column names and values are custom encoders to apply to those columns. - If not provided, the default encoder from the Parquet metadata will be used. - filter (dict, optional): A dictionary where keys are column names and values are tuples specifying the (min, max) range for filtering rows. - - Returns: - tuple: A tuple containing: - - edge_index (torch.Tensor): The edge index tensor with source and destination node indices. - - edge_attr (torch.Tensor): The edge attributes tensor concatenated from all feature columns. - - target (torch.Tensor): The target tensor concatenated from all target columns. - - index_name_mapping (dict): A mapping of the original edge index to the filtered edge index. - - feature_names (list): A list of feature column names, including any columns expanded by the encoders. - - target_names (list): A list of target column names, including any columns expanded by the encoders. - - Example: - edge_index, edge_attr, target, index_name_mapping, feature_names, target_names = load_relationship_parquet( - 'edges.parquet', node_id_mappings={'src': {0: 'A'}, 'dst': {1: 'B'}}, feature_columns=['f1']) - """ - - all_columns=feature_columns+target_columns - - # Getting field metadata for all columns - field_metadata=get_parquet_field_metadata(path,columns=all_columns) - - # Getting relationship information - edge_name=os.path.basename(path).split('.')[0] - src_name,edge_type,dst_name=edge_name.split('-') - src_column_name=f'{src_name}-START_ID' - dst_column_name=f'{dst_name}-END_ID' - src_index_name_mapping=node_id_mappings[src_name] - dst_index_name_mapping=node_id_mappings[dst_name] - - # This maps the reduced index to the original index. - # Again this is because we filter out nodes that contain NaN. 
- src_index_translation={index:reduced_index for reduced_index,index in enumerate(src_index_name_mapping.keys())} - dst_index_translation={index:reduced_index for reduced_index,index in enumerate(dst_index_name_mapping.keys())} - - type_column_name='TYPE' - - # Reading all columns into single dataframe - df = pd.read_parquet(path) - - edges_in_graph_mask=df[src_column_name].isin(src_index_name_mapping.keys()) & df[dst_column_name].isin(dst_index_name_mapping.keys()) - df = pd.DataFrame(df[edges_in_graph_mask].to_dict()) - - edge_index=None - edge_attr = None - df['src_mapped'] = df[src_column_name].map(src_index_translation) - df['dst_mapped'] = df[dst_column_name].map(dst_index_translation) - - edge_index = torch.tensor([df['src_mapped'].tolist(), - df['dst_mapped'].tolist()]) - - df=df.drop(columns=[src_column_name,dst_column_name,type_column_name,'src_mapped','dst_mapped']) - - column_names=list(df.columns) - - logger.info(f"Dataframe shape: {df.shape}") - logger.info(f"Column names: {column_names}") - - # Ensure all columns have no NaN values, otherwise drop them - for col in column_names: - df=df.dropna(subset=[col]) - - logger.info(f"Dataframe shape after removing NaN values: {df.shape}") - - # Apply data filter to the nodes - if filter != {}: - for key, value in filter.items(): - for name in column_names: - if key in name: - column = name - break - min_value, max_value = value - df = df[(df[column] >= min_value) & (df[column] <= max_value)] - - logger.info(f"Dataframe shape after applying filter: {df.shape}") - - # Applying encoders to the nodes - xs=[] - ys=[] - feature_names=[] - target_names=[] - for column_name in column_names: - tmp_names=[] - # Applying custom encoder if provided, - # otherwise use default encoder inside parquet feild metadata - if column_name in custom_encoders: - if isinstance(custom_encoders[column_name],str): - encoder=eval(custom_encoders[column_name]) - else: - encoder=custom_encoders[column_name] - else: - encoder=eval(field_metadata[column_name]['encoder']) - - tmp_values=encoder(df[column_name]) - - # Getting feature names from encoder. Some encoders return multiple feature - # columns due to the nature of the encoder - encoder_feature_names=None - if hasattr(encoder,'column_names'): - encoder_feature_names=encoder.column_names - - # Generating feature names - if encoder_feature_names: - for feature_name in encoder_feature_names: - tmp_names.append(f"{column_name}_{feature_name}") - elif tmp_values.shape[1]>1: - for i in range(tmp_values.shape[1]): - tmp_names.append(f"{column_name}_{i}") - else: - tmp_names.append(column_name) - - # Filtering values into features or targets - if column_name in feature_columns: - xs.append(tmp_values) - feature_names.extend(tmp_names) - if column_name in target_columns: - ys.append(tmp_values) - target_names.extend(tmp_names) - - # Concatenate features and targets - edge_attr=None - if xs: - edge_attr = torch.cat(xs, dim=-1) - target=None - if ys: - target = torch.cat(ys, dim=-1) - - - # Create maping from original index to unqique index after removing posssible nan values - index_name_mapping = {index: i for i, index in enumerate(df.index.unique())} - - return edge_index, edge_attr, target, index_name_mapping, feature_names, target_names - -class DataGenerator: - """ - A class used to generate and manage heterogeneous and homogeneous graph data for machine learning models. 
- This class handles the loading of nodes and edges, conversion between heterogeneous and homogeneous graph formats, - and saving/loading of graph data. - - Attributes - ---------- - hetero_data : HeteroData - Stores the heterogeneous graph data, where each node and edge type can have its own features and labels. - node_id_mappings : dict - Maps node names to their index IDs, helping to manage relationships between nodes when edges are added. - _homo_data : torch_geometric.data.Data, optional - Stores the homogeneous graph data, if converted from the heterogeneous data. - - Methods - ------- - homo_data() - Getter property that converts the heterogeneous graph to homogeneous format if requested. - add_node_type(node_path, feature_columns=[], target_columns=[], custom_encoders={}, filter={}, keep_nan=False) - Adds a node type (with optional features and targets) to the heterogeneous graph. - add_edge_type(edge_path, feature_columns=[], target_columns=[], custom_encoders={}, filter={}, undirected=True) - Adds an edge type between two nodes in the heterogeneous graph, with optional edge features and labels. - to_homogeneous() - Converts the heterogeneous graph to a homogeneous format, where all node and edge types are merged. - save_graph(filepath, use_buffer=False, homogeneous=False) - Saves the graph data (either homogeneous or heterogeneous) to a specified file in .pt format. - load_graph(filepath, use_buffer=False, homogeneous=False) - Loads graph data (either homogeneous or heterogeneous) from a .pt file. - """ - def __init__(self): - """ - Initializes the DataGenerator class. - - Sets up an empty heterogeneous graph (`hetero_data`) and initializes mappings for node IDs. - Also initializes `_homo_data` to store the homogeneous graph, if needed. - """ - self.hetero_data = HeteroData() - - # This is need to map if nodes are filtered out. - self.node_id_mappings={} - - logger.info(f"Initializing DataGenerator") - self._homo_data = None - - @property - def homo_data(self): - - return self._homo_data - @homo_data.getter - def homo_data(self): - """ - Getter for the homogeneous graph data. - - Returns the homogeneous graph by converting the heterogeneous graph. Raises an exception if - the node features between types do not match. - - Returns - ------- - torch_geometric.data.Data - The homogeneous graph data. - - Raises - ------ - ValueError - If node features are inconsistent across types. - """ - try: - return self.hetero_data.to_homogeneous() - except Exception as e: - raise ValueError(f"Make sure to only upload a the nodes have the same amount of features: {e}") - @homo_data.setter - def homo_data(self, value): - """ - Setter for the homogeneous graph data. - - Parameters - ---------- - value : torch_geometric.data.Data - The homogeneous graph data to be assigned. - """ - self._homo_data = value - - def add_node_type(self, node_path, - feature_columns=[], - target_columns=[], - custom_encoders={}, - filter={}, - keep_nan=False): - """ - Adds a new node type to the heterogeneous graph, loading features and targets from a Parquet file. - - Parameters - ---------- - node_path : str - Path to the Parquet file containing node data. - feature_columns : list, optional - List of column names to be used as features for the nodes. - target_columns : list, optional - List of column names to be used as target labels for the nodes. - custom_encoders : dict, optional - Custom encoders for specific feature columns. - filter : dict, optional - A dictionary of filters to apply when loading the node data. 
- keep_nan : bool, optional - Whether to keep NaN values in the data (default is False). - - Returns - ------- - dict - A mapping of node IDs to their original index names. - - Raises - ------ - Exception - If the node cannot be added due to an error in loading the data. - """ - logger.info(f"Adding node type: {node_path}") - - - node_name=os.path.basename(node_path).split('.')[0] - - x,target,index_name_map,feature_names,target_names=load_node_parquet(node_path, - feature_columns=feature_columns, - target_columns=target_columns, - custom_encoders=custom_encoders, - filter=filter, - keep_nan=keep_nan) - - if x is not None: - logger.info(f"{node_name} feature shape: {x.shape}") - logger.info(f"{node_name} feature names: {len(feature_names)}") - - # logger.info(f"{node_name} index name map: {feature_names}") - logger.info(f"{node_name} target name map: {target_names}") - - - self.hetero_data[node_name].node_id=torch.arange(len(index_name_map)) - self.hetero_data[node_name].names=list(index_name_map.values()) - if x is not None: - self.hetero_data[node_name].x = x - self.hetero_data[node_name].feature_names=feature_names - else: - self.hetero_data[node_name].num_nodes=len(index_name_map) - - if target is not None: - - self.hetero_data[node_name].y_label_name=target_columns - out_channels=target.shape[1] - self.hetero_data[node_name].out_channels = out_channels - - if out_channels==1: - self.hetero_data[node_name].y=target - self.hetero_data[node_name].y_names=target_names - else: - self.hetero_data[node_name].y=torch.argmax(target, dim=1) - - logger.info(f"{node_name} target shape: {target.shape}") - logger.info(f"{node_name} out channels: {out_channels}") - - logger.info(f"Node {node_name} added to the graph") - - self.node_id_mappings.update({node_name:index_name_map}) - - return index_name_map - - def add_edge_type(self, edge_path, - feature_columns=[], - target_columns=[], - custom_encoders={}, - filter={}, - undirected=True): - """ - Adds a new edge type between two nodes in the heterogeneous graph, with optional edge features and targets. - - Parameters - ---------- - edge_path : str - Path to the Parquet file containing edge data. - feature_columns : list, optional - List of column names to be used as features for the edges. - target_columns : list, optional - List of column names to be used as target labels for the edges. - custom_encoders : dict, optional - Custom encoders for specific feature columns. - filter : dict, optional - A dictionary of filters to apply when loading the edge data. - undirected : bool, optional - Whether to treat the edge as undirected and add a reverse edge (default is True). - - Raises - ------ - Exception - If the source or destination node types do not exist in the node ID mappings. - """ - - edge_name=os.path.basename(edge_path).split('.')[0] - - src_name,edge_type,dst_name=edge_name.split('-') - - if src_name not in self.node_id_mappings: - raise Exception(f"Node {src_name} not found in node ID mappings. Call add_node_type first") - - if dst_name not in self.node_id_mappings: - raise Exception(f"Node {dst_name} not found in node ID mappings. 
Call add_node_type first") - - graph_dir=os.path.dirname(os.path.dirname(edge_path)) - node_dir=os.path.join(graph_dir,'nodes') - - logger.info(f"Edge name | {edge_name}") - logger.info(f"Edge type | {edge_type}") - logger.info(f"Edge src name | {src_name}") - logger.info(f"Edge dst name | {dst_name}") - logger.info(f"Graph dir | {graph_dir}") - logger.info(f"Node dir | {node_dir}") - - edge_index, edge_attr, target, index_name_map, feature_names, target_names = load_relationship_parquet(edge_path, - self.node_id_mappings, - feature_columns=feature_columns, - target_columns=target_columns, - custom_encoders=custom_encoders, - filter=filter) - - logger.info(f"Edge index shape: {edge_index.shape}") - if edge_attr is not None: - logger.info(f"Edge attr shape: {edge_attr.shape}") - logger.info(f"Feature names: {feature_names}") - logger.info(f"Target names: {target_names}") - - self.hetero_data[src_name,edge_type,dst_name].edge_index=edge_index - self.hetero_data[src_name,edge_type,dst_name].edge_attr=edge_attr - self.hetero_data[src_name,edge_type,dst_name].property_names=feature_names - - if target is not None: - logger.info(f"Target shape: {target.shape}") - self.hetero_data[src_name,edge_type,dst_name].y_label_name=target_names - out_channels=target.shape[1] - self.hetero_data[src_name,edge_type,dst_name].out_channels=out_channels - - if out_channels==1: - self.hetero_data[src_name,edge_type,dst_name].y=target - else: - self.hetero_data[src_name,edge_type,dst_name].y=torch.argmax(target, dim=1) - - - if undirected: - - row, col = edge_index - rev_edge_index = torch.stack([col, row], dim=0) - self.hetero_data[dst_name,f'rev_{edge_type}',src_name].edge_index=rev_edge_index - self.hetero_data[dst_name,f'rev_{edge_type}',src_name].edge_attr=edge_attr - - if target is not None: - self.hetero_data[dst_name,f'rev_{edge_type}',src_name].y_label_name=target_names - out_channels=target.shape[1] - self.hetero_data[dst_name,f'rev_{edge_type}',src_name].out_channels = out_channels - - - if out_channels==1: - self.hetero_data[src_name,edge_type,dst_name].y=target - else: - self.hetero_data[src_name,edge_type,dst_name].y=torch.argmax(target, dim=1) - - logger.info(f"Adding {edge_type} edge | {src_name} -> {dst_name}") - - - logger.info(f"undirected: {undirected}") - - def to_homogeneous(self): - """ - Converts the heterogeneous graph into a homogeneous graph, merging all node and edge types. - - Returns - ------- - torch_geometric.data.Data - The homogeneous graph data. - """ - logger.info(f"Converting to homogeneous graph") - self.homo_data=self.hetero_data.to_homogeneous() - - def save_graph(self, filepath, use_buffer=False, homogeneous=False): - """ - Saves the graph data (either homogeneous or heterogeneous) to a specified file in .pt format. - - Parameters - ---------- - filepath : str - The path to save the graph data, which must have a `.pt` extension. - use_buffer : bool, optional - Whether to save the graph data to a buffer instead of directly to a file (default is False). - homogeneous : bool, optional - Whether to save the homogeneous graph data (default is False, saving the heterogeneous graph). - - Raises - ------ - ValueError - If the file type is not `.pt`. 
- """ - file_type=filepath.split('.')[-1] - - if homogeneous==True: - data=self.homo_data - logger.info(f"Saving homogeneous graph") - else: - data=self.hetero_data - logger.info(f"Saving heterogeneous graph") - - if file_type!='pt': - raise ValueError("Only .pt files are supported") - - if use_buffer==True: - buffer = io.BytesIO() - torch.save(data, buffer) - else: - torch.save(data, filepath) - - def load_graph(self, filepath, use_buffer=False, homogeneous=False): - """ - Loads graph data (either homogeneous or heterogeneous) from a .pt file. - - Parameters - ---------- - filepath : str - The path to load the graph data from, which must have a `.pt` extension. - use_buffer : bool, optional - Whether to load the graph data from a buffer instead of directly from a file (default is False). - homogeneous : bool, optional - Whether to load the homogeneous graph data (default is False, loading the heterogeneous graph). - - Returns - ------- - torch_geometric.data.Data - The loaded graph data. - - Raises - ------ - ValueError - If the file type is not `.pt`. - """ - file_type=filepath.split('.')[-1] - if file_type!='pt': - raise ValueError("Only .pt files are supported") - - - if use_buffer==True: - with open(filepath, 'rb') as f: - buffer = io.BytesIO(f.read()) - data=torch.load(buffer, weights_only=False) - else: - data=torch.load(filepath, weights_only=False) - - if homogeneous==True: - self.homo_data=data - logger.info(f"Saving homogeneous graph") - else: - self.hetero_data=data - logger.info(f"Saving heterogeneous graph") - - return data - - -if __name__ == "__main__": - from matgraphdb.graph_kit.graph_manager import GraphManager - import pandas as pd - import os - material_graph=GraphManager(skip_init=False) - graph_dir = material_graph.graph_dir - nodes_dir = material_graph.node_dir - relationship_dir = material_graph.relationship_dir - - - node_names=material_graph.list_nodes() - relationship_names=material_graph.list_relationships() - - node_files=material_graph.get_node_filepaths() - relationship_files=material_graph.get_relationship_filepaths() - - print('-'*100) - print('Nodes') - print('-'*100) - for i,node_file in enumerate(node_files): - print(i,node_file) - print('-'*100) - print('Relationships') - print('-'*100) - for i,relationship_file in enumerate(relationship_files): - print(i,relationship_file) - print('-'*100) - - - node_path=node_files[2] - edge_path=relationship_files[3] - - # node_path=node_files[5] - # df=pd.read_parquet(node_path) - - node_path=node_files[0] - # df=pd.read_parquet(node_path) - # # for x in df.columns: - # # print(f"'{x}',") - # df.to_csv(node_path.replace('.parquet','.csv')) - - # df=pd.read_parquet(edge_path) - # df.to_csv(edge_path.replace('.parquet','.csv')) - - ############################################################################################################################## - # Example of using DataGenerator - ############################################################################################################################## - - node_properties={ - 'CHEMENV':[ - 'coordination', - 'name', - ], - 'ELEMENT':[ - 'abundance_universe', - 'abundance_solar', - 'abundance_meteor', - 'abundance_crust', - 'abundance_ocean', - 'abundance_human', - 'atomic_mass', - 'atomic_number', - 'block', - # 'boiling_point', - 'critical_pressure', - 'critical_temperature', - 'density_stp', - 'electron_affinity', - 'electronegativity_pauling', - 'extended_group', - 'heat_specific', - 'heat_vaporization', - 'heat_fusion', - 'heat_molar', - 
'magnetic_susceptibility_mass', - 'magnetic_susceptibility_molar', - 'magnetic_susceptibility_volume', - 'melting_point', - 'molar_volume', - 'neutron_cross_section', - 'neutron_mass_absorption', - 'period', - 'phase', - 'radius_calculated', - 'radius_empirical', - 'radius_covalent', - 'radius_vanderwaals', - 'refractive_index', - 'speed_of_sound', - # 'valence_electrons', - 'conductivity_electric', - 'electrical_resistivity', - 'modulus_bulk', - 'modulus_shear', - 'modulus_young', - 'poisson_ratio', - 'coefficient_of_linear_thermal_expansion', - 'hardness_vickers', - 'hardness_brinell', - 'hardness_mohs', - 'superconduction_temperature', - 'is_actinoid', - 'is_alkali', - 'is_alkaline', - 'is_chalcogen', - 'is_halogen', - 'is_lanthanoid', - 'is_metal', - 'is_metalloid', - 'is_noble_gas', - 'is_post_transition_metal', - 'is_quadrupolar', - 'is_rare_earth_metal', - 'experimental_oxidation_states', - 'name', - 'type', - ], - 'MATERIAL':[ - 'nsites', - 'nelements', - 'volume', - 'density', - 'density_atomic', - 'crystal_system', - # 'space_group', - # 'point_group', - 'a', - 'b', - 'c', - 'alpha', - 'beta', - 'gamma', - 'unit_cell_volume', - # 'energy_per_atom', - # 'formation_energy_per_atom', - # 'energy_above_hull', - # 'band_gap', - # 'cbm', - # 'vbm', - # 'efermi', - # 'is_stable', - # 'is_gap_direct', - # 'is_metal', - # 'is_magnetic', - # 'ordering', - # 'total_magnetization', - # 'total_magnetization_normalized_vol', - # 'num_magnetic_sites', - # 'num_unique_magnetic_sites', - # 'e_total', - # 'e_ionic', - # 'e_electronic', - # 'sine_coulomb_matrix', - # 'element_fraction', - # 'element_property', - # 'xrd_pattern', - # 'uncorrected_energy_per_atom', - # 'equilibrium_reaction_energy_per_atom', - # 'n', - # 'e_ij_max', - # 'weighted_surface_energy_EV_PER_ANG2', - # 'weighted_surface_energy', - # 'weighted_work_function', - # 'surface_anisotropy', - # 'shape_factor', - # 'elasticity-k_vrh', - # 'elasticity-k_reuss', - # 'elasticity-k_voigt', - # 'elasticity-g_vrh', - # 'elasticity-g_reuss', - # 'elasticity-g_voigt', - # 'elasticity-sound_velocity_transverse', - # 'elasticity-sound_velocity_longitudinal', - # 'elasticity-sound_velocity_total', - # 'elasticity-sound_velocity_acoustic', - # 'elasticity-sound_velocity_optical', - # 'elasticity-thermal_conductivity_clarke', - # 'elasticity-thermal_conductivity_cahill', - # 'elasticity-young_modulus', - # 'elasticity-universal_anisotropy', - # 'elasticity-homogeneous_poisson', - # 'elasticity-debye_temperature', - # 'elasticity-state', - # 'name', - ] - } - - - relationship_properties={ - 'ELEMENT-CAN_OCCUR-CHEMENV':[ - 'weight', - ], - - 'ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT':[ - 'weight', - ], - 'MATERIAL-HAS-ELEMENT':[ - 'weight', - ] - } - - generator=DataGenerator() - generator.add_node_type(node_path=node_files[0], - feature_columns=node_properties['CHEMENV'], - target_columns=[]) - - generator.add_node_type(node_path=node_files[2], - feature_columns=node_properties['ELEMENT'], - target_columns=[]) - - generator.add_node_type(node_path=node_files[5], - feature_columns=node_properties['MATERIAL'], - target_columns=['elasticity-k_vrh'], - filter={'elasticity-k_vrh':(0,300)} - ) - - generator.add_edge_type(edge_path=edge_path, - feature_columns=relationship_properties['ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT'], - # target_columns=['weight'], - # custom_encoders={}, - # node_filter={}, - undirected=True) - 
generator.add_edge_type(edge_path='Z:/Research_Projects/crystal_generation_project/MatGraphDB/data/production/materials_project/graph_database/main/relationships/MATERIAL-HAS-ELEMENT.parquet', - feature_columns=relationship_properties['MATERIAL-HAS-ELEMENT'], - # target_columns=['weight'], - # custom_encoders={}, - # node_filter={}, - undirected=True) - - print(generator.hetero_data) - - print(generator.hetero_data['MATERIAL'].node_id) - - - - - # # generator.save_graph(filepath=os.path.join('data','raw','main.pt')) - - # generator.load_graph(filepath=os.path.join('data','raw','main.pt')) - - # print(generator.data) - - ############################################################################################################################## - # Creating graph node embeddings - ############################################################################################################################## - - # generator=DataGenerator() - # generator.add_node_type(node_path=node_files[2], - # feature_columns=node_properties['ELEMENT'], - # target_columns=[]) - # generator.add_edge_type(edge_path=relationship_files[8], - # feature_columns=relationship_properties['ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT'], - # # target_columns=['weight'], - # # custom_encoders={}, - # # node_filter={}, - # undirected=True) - # print(generator.hetero_data) - # data=generator.homo_data - # print(data) - - - - - # generator.save_graph(filepath=os.path.join('data','raw','main.pt')) - - # generator.load_graph(filepath=os.path.join('data','raw','main.pt')) - - # print(generator.data) - diff --git a/matgraphdb/graph_kit/graph_manager.py b/matgraphdb/graph_kit/graph_manager.py deleted file mode 100644 index 27dd58c..0000000 --- a/matgraphdb/graph_kit/graph_manager.py +++ /dev/null @@ -1,227 +0,0 @@ - -from glob import glob -import os -import warnings -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq -import logging - -from matgraphdb import config -from matgraphdb.utils.chem_utils.periodic import get_group_period_edge_index -# from matgraphdb.utils.chem_utils.coord_geometry import mp_coord_encoding - -from matgraphdb.graph_kit.metadata import get_node_schema,get_relationship_schema -from matgraphdb.graph_kit.metadata import NodeTypes, RelationshipTypes - -from matgraphdb.graph_kit.nodes import NodeManager -from matgraphdb.graph_kit.relationships import RelationshipManager -logger = logging.getLogger(__name__) - -# TODO: Add screen_graph method -class GraphManager: - def __init__(self, graph_dir, - output_format='pandas', - node_dirname='nodes', - relationship_dirname='relationships'): - """ - Initialize the GraphManager with node and relationship directories. - - Args: - graph_dir (str): The directory where the graph data (nodes and relationships) are stored. - output_format (str, optional): The format in which to retrieve nodes and relationships (default is 'pandas'). - node_dirname (str, optional): The directory name where node files are stored (default is 'nodes'). - relationship_dirname (str, optional): The directory name where relationship files are stored (default is 'relationships'). - - Attributes: - node_manager (NodeManager): An instance of the NodeManager responsible for managing nodes. - relationship_manager (RelationshipManager): An instance of the RelationshipManager responsible for managing relationships. - output_format (str): The format for outputting data (e.g., 'pandas'). 
- """ - node_dir=os.path.join(graph_dir,node_dirname) - relationship_dir=os.path.join(graph_dir,relationship_dirname) - - self.node_manager = NodeManager(node_dir, output_format) - self.relationship_manager = RelationshipManager(relationship_dir, output_format) - self.output_format = output_format - - def get_all_nodes(self): - """ - Retrieve all existing nodes using the NodeManager. - - Returns: - list: A list of all existing nodes. - """ - return self.node_manager.get_existing_nodes() - - def get_all_relationships(self): - """ - Retrieve all existing relationships using the RelationshipManager. - - Returns: - list: A list of all existing relationships. - """ - return self.relationship_manager.get_existing_relationships() - - def get_node(self, node_type): - """ - Retrieve a specific node by its type. - - Args: - node_type (str): The type of the node to retrieve. - - Returns: - pd.DataFrame or dict: The data associated with the requested node type, in the specified output format (e.g., 'pandas'). - """ - return self.node_manager.get_node(node_type, output_format=self.output_format) - - def get_relationship(self, relationship_type): - """ - Retrieve a specific relationship by its type. - - Args: - relationship_type (str): The type of the relationship to retrieve. - - Returns: - pd.DataFrame or dict: The data associated with the requested relationship type, in the specified output format (e.g., 'pandas'). - """ - return self.relationship_manager.get_relationship(relationship_type, output_format=self.output_format) - - def add_node(self, node_class): - """ - Add a new node using the NodeManager. - - Args: - node_class (Node): An instance of the Node class to add to the graph. - """ - self.node_manager.add_node(node_class) - - def add_relationship(self, relationship_class): - """ - Add a new relationship using the RelationshipManager. - - Args: - relationship_class (Relationship): An instance of the Relationship class to add to the graph. - """ - self.relationship_manager.add_relationship(relationship_class) - - def delete_node(self, node_type): - """ - Delete a node by its type using the NodeManager. - - Args: - node_type (str): The type of the node to delete. - """ - self.node_manager.delete_node(node_type) - - def delete_relationship(self, relationship_type): - """ - Delete a relationship by its type using the RelationshipManager. - - Args: - relationship_type (str): The type of the relationship to delete. - """ - self.relationship_manager.delete_relationship(relationship_type) - - def check_node_relationship_consistency(self): - """ - Check if all relationships refer to existing nodes. - - This method verifies if all nodes referenced in the relationships exist in the node manager. - It identifies any relationships that reference nodes that are missing from the graph. - - Returns: - set: A set of missing node types that are referenced in relationships but do not exist in the node manager. 
- """ - node_types = self.node_manager.get_existing_nodes() - relationships = self.relationship_manager.get_existing_relationships() - - # You could enhance this by implementing relationship-specific logic - # (for example, checking if all node types in relationships exist) - missing_nodes = set() - for rel in relationships: - rel_data = self.relationship_manager.get_relationship_dataframe(rel) - nodes_in_relationship = set(rel_data['source']).union(set(rel_data['target'])) - missing_nodes.update(nodes_in_relationship - node_types) - - if missing_nodes: - logger.warning(f"These nodes are missing in the relationships: {missing_nodes}") - else: - logger.info("All relationships are consistent with the nodes.") - - return missing_nodes - - def export_graph_to_neo4j(self, save_dir): - """ - Export both nodes and relationships to Neo4j CSV format. - - This method exports the graph data (nodes and relationships) to the specified directory in a format compatible - with Neo4j, allowing the graph to be imported into a Neo4j database. - - Args: - save_dir (str): The directory where the Neo4j CSV files will be saved. - """ - os.makedirs(save_dir, exist_ok=True) - logger.info("Converting all nodes to Neo4j CSV format.") - self.node_manager.convert_all_to_neo4j(save_dir) - - logger.info("Converting all relationships to Neo4j CSV format.") - self.relationship_manager.convert_all_to_neo4j(save_dir) - - logger.info(f"Graph successfully exported to Neo4j CSV in {save_dir}") - - def visualize_graph(self): - """ - Visualize the graph using NetworkX and Matplotlib. - - This method creates a visual representation of the graph by adding nodes and relationships - to a NetworkX graph and then displaying it using Matplotlib. - - If the required libraries (NetworkX, Matplotlib) are not installed, the method will log an error. - - Raises: - ImportError: If NetworkX or Matplotlib is not installed. - """ - try: - import networkx as nx - import matplotlib.pyplot as plt - - G = nx.Graph() - - # Add nodes to the graph - for node_type in self.node_manager.get_existing_nodes(): - node_df = self.node_manager.get_node_dataframe(node_type) - for index, row in node_df.iterrows(): - G.add_node(row['name'], type=node_type) - - # Add relationships to the graph - for rel_type in self.relationship_manager.get_existing_relationships(): - rel_df = self.relationship_manager.get_relationship_dataframe(rel_type) - for index, row in rel_df.iterrows(): - G.add_edge(row['source'], row['target'], relationship=rel_type) - - # Draw the graph - plt.figure(figsize=(10, 8)) - nx.draw(G, with_labels=True, node_size=700, node_color="skyblue", font_size=10, font_weight="bold") - plt.show() - - except ImportError: - logger.error("NetworkX and/or matplotlib is not installed. Please install them to visualize the graph.") - - def summary(self): - """ - Return a summary of the current state of the graph. - - The summary includes the total number of nodes and relationships currently present in the graph. - - Returns: - dict: A dictionary containing the counts of nodes and relationships in the graph. 
- """ - node_count = len(self.get_all_nodes()) - relationship_count = len(self.get_all_relationships()) - logger.info(f"Graph contains {node_count} nodes and {relationship_count} relationships.") - return { - "nodes": node_count, - "relationships": relationship_count - } diff --git a/matgraphdb/graph_kit/graph_tool/__init__.py b/matgraphdb/graph_kit/graph_tool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/graph_kit/graph_tool/graph_tool_analyzer.py b/matgraphdb/graph_kit/graph_tool/graph_tool_analyzer.py deleted file mode 100644 index bd34be0..0000000 --- a/matgraphdb/graph_kit/graph_tool/graph_tool_analyzer.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil -from glob import glob -from typing import List, Tuple, Union - -import pandas as pd -import graph_tool as gt - -from matgraphdb import DBManager -from matgraphdb.utils import (GRAPH_DIR,LOGGER) - - -class GraphToolAnalyzer: - def __init__(self,graphml_file,db_manager=DBManager()): - """ - Initializes the GraphToolAnalyzer object. - - Args: - db_manager (DBManager,optional): The database manager object. Defaults to DBManager(). - - """ - self.db_manager=db_manager - self.graph = gt.load_graph(graphml_file) - - def compute_graph_entropy(self,**kwargs): - """ - Computes the graph entropy for a given graphml file. - - Args: - graphml_file (str): The path to the graphml file. - from_scratch (bool,optional): If True, deletes the graph database and recreates it from scratch. Defaults to False. - - Returns: - float: The graph entropy. - """ - - return None - - def plot_graph(self,filename,**kwargs): - """ - Plots a graph using the graphml file. - - Args: - graphml_file (str): The path to the graphml file. - filename (str): The name of the file to save the plot. - - Returns: - None - """ - gt.draw.graph_draw(self.graph,output=filename,**kwargs) - return None - -if __name__=='__main__': - analyzer=GraphToolAnalyzer() - # analyzer.compute_graph_entropy(graphml_file="data/production/materials_project/graph_database/nelements-2-2/nelements-2-2.graphml") - - - analyzer.plot_graph(graphml_file="data/production/materials_project/graph_database/nelements-2-2/nelements-2-2.graphml", - filename="data/production/materials_project/graph_database/nelements-2-2/nelements-2-2.png") \ No newline at end of file diff --git a/matgraphdb/graph_kit/metadata.py b/matgraphdb/graph_kit/metadata.py deleted file mode 100644 index 7ae5a5e..0000000 --- a/matgraphdb/graph_kit/metadata.py +++ /dev/null @@ -1,500 +0,0 @@ -from enum import Enum - -import pyarrow as pa -from matgraphdb.data.utils import MATERIAL_PARQUET_SCHEMA - -class NodeTypes(Enum): - """ - An enumeration of different node types used in a heterogeneous graph for material science data. - - Each node type represents a specific category of data in the material science domain. These types are - used when creating or managing nodes in the graph and help distinguish different kinds of entities. - - Attributes - ---------- - ELEMENT : str - Represents chemical elements in the graph. - CHEMENV : str - Represents chemical environments, which describe the local coordination environment of an atom. - CRYSTAL_SYSTEM : str - Represents different crystal systems, which describe the symmetry of a crystal's lattice. - MAGNETIC_STATE : str - Represents the magnetic state of a material or element (e.g., ferromagnetic, paramagnetic). - SPACE_GROUP : str - Represents space groups, which describe the symmetry of the crystal structure. 
- OXIDATION_STATE : str - Represents the oxidation state of an element. - MATERIAL : str - Represents materials as a whole in the graph. - SPG_WYCKOFF : str - Represents Wyckoff positions, which describe specific atomic positions within space groups. - CHEMENV_ELEMENT : str - Represents the relationship between chemical environments and specific elements. - LATTICE : str - Represents the lattice structure of a material. - SITE : str - Represents atomic sites within a material's structure. - """ - - ELEMENT='ELEMENT' - CHEMENV='CHEMENV' - CRYSTAL_SYSTEM='CRYSTAL_SYSTEM' - MAGNETIC_STATE='MAGNETIC_STATE' - SPACE_GROUP='SPACE_GROUP' - OXIDATION_STATE='OXIDATION_STATE' - MATERIAL='MATERIAL' - SPG_WYCKOFF='SPG_WYCKOFF' - CHEMENV_ELEMENT='CHEMENV_ELEMENT' - LATTICE='LATTICE' - SITE='SITE' - - -class RelationshipTypes(Enum): - """ - An enumeration of different relationship types used in a heterogeneous graph for material science data. - - Each relationship type defines a connection between two different node types in the graph, representing - how different entities interact or relate to each other. These relationships are crucial for understanding - the structure, properties, and behavior of materials. - - Attributes - ---------- - MATERIAL_SPG : str - Represents the relationship between a material and its space group. - MATERIAL_CRYSTAL_SYSTEM : str - Represents the relationship between a material and its crystal system. - MATERIAL_LATTICE : str - Represents the relationship between a material and its lattice structure. - MATERIAL_SITE : str - Represents the relationship between a material and its atomic sites. - MATERIAL_CHEMENV : str - Represents the relationship between a material and its chemical environment. - MATERIAL_CHEMENV_ELEMENT : str - Represents the relationship between a material's chemical environment and its elements. - MATERIAL_ELEMENT : str - Represents the relationship between a material and its constituent elements. - ELEMENT_OXIDATION_STATE : str - Represents the relationship between an element and its possible oxidation states. - ELEMENT_CHEMENV : str - Represents the relationship between an element and its chemical environment. - ELEMENT_GEOMETRIC_CONNECTS_ELEMENT : str - Represents the geometric connection between two elements. - ELEMENT_ELECTRIC_CONNECTS_ELEMENT : str - Represents the electric connection between two elements. - ELEMENT_GEOMETRIC_ELECTRIC_CONNECTS_ELEMENT : str - Represents both geometric and electric connections between two elements. - ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT : str - Represents connections between elements based on group and period similarities in the periodic table. - CHEMENV_GEOMETRIC_CONNECTS_CHEMENV : str - Represents the geometric connection between two chemical environments. - CHEMENV_ELECTRIC_CONNECTS_CHEMENV : str - Represents the electric connection between two chemical environments. - CHEMENV_GEOMETRIC_ELECTRIC_CONNECTS_CHEMENV : str - Represents both geometric and electric connections between two chemical environments. 
- """ - MATERIAL_SPG=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.SPACE_GROUP.value}' - MATERIAL_CRYSTAL_SYSTEM=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.CRYSTAL_SYSTEM.value}' - MATERIAL_LATTICE=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.LATTICE.value}' - MATERIAL_SITE=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.SITE.value}' - - MATERIAL_CHEMENV=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.CHEMENV.value}' - MATERIAL_CHEMENV_ELEMENT=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.CHEMENV_ELEMENT.value}' - MATERIAL_ELEMENT=f'{NodeTypes.MATERIAL.value}-HAS-{NodeTypes.ELEMENT.value}' - - ELEMENT_OXIDATION_STATE=f'{NodeTypes.ELEMENT.value}-CAN_OCCUR-{NodeTypes.OXIDATION_STATE.value}' - ELEMENT_CHEMENV=f'{NodeTypes.ELEMENT.value}-CAN_OCCUR-{NodeTypes.CHEMENV.value}' - - ELEMENT_GEOMETRIC_CONNECTS_ELEMENT=f'{NodeTypes.ELEMENT.value}-GEOMETRIC_CONNECTS-{NodeTypes.ELEMENT.value}' - ELEMENT_ELECTRIC_CONNECTS_ELEMENT=f'{NodeTypes.ELEMENT.value}-ELECTRIC_CONNECTS-{NodeTypes.ELEMENT.value}' - ELEMENT_GEOMETRIC_ELECTRIC_CONNECTS_ELEMENT=f'{NodeTypes.ELEMENT.value}-GEOMETRIC_ELECTRIC_CONNECTS-{NodeTypes.ELEMENT.value}' - ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT=f'{NodeTypes.ELEMENT.value}-GROUP_PERIOD_CONNECTS-{NodeTypes.ELEMENT.value}' - - CHEMENV_GEOMETRIC_CONNECTS_CHEMENV=f'{NodeTypes.CHEMENV.value}-GEOMETRIC_CONNECTS-{NodeTypes.CHEMENV.value}' - CHEMENV_ELECTRIC_CONNECTS_CHEMENV=f'{NodeTypes.CHEMENV.value}-ELECTRIC_CONNECTS-{NodeTypes.CHEMENV.value}' - CHEMENV_GEOMETRIC_ELECTRIC_CONNECTS_CHEMENV=f'{NodeTypes.CHEMENV.value}-GEOMETRIC_ELECTRIC_CONNECTS-{NodeTypes.CHEMENV.value}' - -t_string=pa.string() -t_int=pa.int64() -t_float=pa.float64() -t_bool=pa.bool_() - - -############################################################################################################################## -# Elements - -# Create parquet schem from the column names above -element_property_schema_list = [ - pa.field('long_name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('symbol', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('abundance_universe', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('abundance_solar', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('abundance_meteor', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('abundance_crust', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('abundance_ocean', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('abundance_human', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('adiabatic_index', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('allotropes', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('appearance', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('atomic_mass', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('atomic_number', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('block', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('boiling_point', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('classifications_cas_number', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('classifications_cid_number', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('classifications_rtecs_number', 
t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('classifications_dot_numbers', t_string, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('classifications_dot_hazard_class', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('conductivity_thermal', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('cpk_hex', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('critical_pressure', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('critical_temperature', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('crystal_structure', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('density_stp', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('discovered_year', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('discovered_by', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('discovered_location', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('electron_affinity', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('electron_configuration', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('electron_configuration_semantic', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('electronegativity_pauling', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('energy_levels', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('gas_phase', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('group', t_int, metadata={'encoder':'IntegerOneHotEncoder(dtype=torch.float32)'}), - pa.field('extended_group', t_int, metadata={'encoder':'IntegerOneHotEncoder(dtype=torch.float32)'}), - pa.field('half_life', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('heat_specific', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('heat_vaporization', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('heat_fusion', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('heat_molar', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('isotopes_known', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('isotopes_stable', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('isotopic_abundances', t_string, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('lattice_angles', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('lattice_constants', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('lifetime', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('magnetic_susceptibility_mass', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('magnetic_susceptibility_molar', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('magnetic_susceptibility_volume', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('oxidation_states', pa.list_(t_float), metadata={'encoder':'OxidationStatesEncoder(dtype=torch.float32)'}), - pa.field('magnetic_type', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('melting_point', t_float, 
metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('molar_volume', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('neutron_cross_section', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('neutron_mass_absorption', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('period', t_int, metadata={'encoder':'IntegerOneHotEncoder(dtype=torch.float32)'}), - pa.field('phase', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('quantum_numbers', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('radius_calculated', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('radius_empirical', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('radius_covalent', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('radius_vanderwaals', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('refractive_index', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('series', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('source', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('space_group_name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('space_group_number', t_int, metadata={'encoder':'SpaceGroupOneHotEncoder(dtype=torch.float32)'}), - pa.field('speed_of_sound', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('summary', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('valence_electrons', t_int, metadata={'encoder':'IntegerOneHotEncoder(dtype=torch.float32)'}), - pa.field('conductivity_electric', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('electrical_resistivity', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('electrical_type', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('modulus_bulk', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('modulus_shear', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('modulus_young', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('poisson_ratio', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('coefficient_of_linear_thermal_expansion', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('hardness_vickers', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('hardness_brinell', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('hardness_mohs', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('superconduction_temperature', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('is_actinoid', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_alkali', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_alkaline', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_chalcogen', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_halogen', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_lanthanoid', t_bool, 
metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_metalloid', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_noble_gas', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_post_transition_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_quadrupolar', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('is_rare_earth_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), - pa.field('experimental_oxidation_states', pa.list_(t_int), metadata={'encoder':'OxidationStatesEncoder(dtype=torch.float32)'}), - pa.field('ionization_energies', pa.list_(t_float), metadata={'encoder':'IonizationEnergiesEncoder(dtype=torch.float32)'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'}) - ] - -############################################################################################################################## -# chemenv - -chemenv_property_schema_list = [ - pa.field('chemenv_name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('coordination', t_int, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - - -############################################################################################################################## -# Crystal System - -crystal_system_property_schema_list = [ - pa.field('crystal_system', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - -############################################################################################################################## -# Lattice - -lattice_property_schema_list = [ - pa.field('lattice', pa.list_(pa.list_(t_float))), - pa.field('a', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('b', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('c', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('alpha', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('beta', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('gamma', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('volume', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - -############################################################################################################################## -# Sites - -site_property_schema_list = [ - pa.field('species', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('frac_coords', pa.list_(t_float), metadata={'encoder':'ListIdentityEncoder(dtype=torch.float32)'}), - pa.field('lattice', pa.list_(pa.list_(t_float))), - pa.field('material_id', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('name', t_string, 
metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - -############################################################################################################################## -# Magnetic States - -magnetic_state_property_schema_list = [ - pa.field('magnetic_state', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - -############################################################################################################################## -# Oxidation States - -oxidation_state_property_schema_list = [ - pa.field('oxidation_state', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - -############################################################################################################################## -# SPG - -spg_property_schema_list = [ - pa.field('spg', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - - -############################################################################################################################## -# spg_wyckoff - -spg_wyckoff_property_schema_list = [ - pa.field('spg_wyckoff', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - - - -############################################################################################################################## -# relationship schema - - -def get_relationship_schema(relationship_type:RelationshipTypes): - """ - Generates a schema for a given relationship type in a material science graph, defining the fields that - describe the relationship between two nodes. The schema is used to specify the properties of the relationship - such as the start node, end node, type of relationship, and the weight of the connection. - - Parameters - ---------- - relationship_type : RelationshipTypes - An instance of the `RelationshipTypes` Enum that defines the specific relationship between two nodes - in the graph. - - Returns - ------- - pyarrow.Schema - A PyArrow schema object that defines the fields for the relationship properties, including: - - `START_ID`: The ID of the source node in the relationship. - - `END_ID`: The ID of the target node in the relationship. - - `TYPE`: The type of the relationship (encoded as a classification). - - `weight`: The weight of the relationship (encoded as an integer). 
- - Example - ------- - To generate a schema for a specific relationship type: - - >>> schema = get_relationship_schema(RelationshipTypes.MATERIAL_SPG) - >>> print(schema) - """ - if not isinstance(relationship_type,RelationshipTypes): - raise ValueError("relationship_type must be an instance of RelationshipTypes.{}") - node_a_name,connection_type,node_b_name=relationship_type.value.split('-') - - relationship_property_schema_list = [ - pa.field(f'{node_a_name}-START_ID', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field(f'{node_b_name}-END_ID', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), - pa.field('TYPE', t_string, metadata={'encoder':'ClassificationEncoder()'}), - pa.field('weight', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'})] - - return pa.schema(relationship_property_schema_list) - -############################################################################################################################## - -LATTICE_PARQUET_SCHEMA = pa.schema(lattice_property_schema_list) -SITE_PARQUET_SCHEMA = pa.schema(site_property_schema_list) -ELEMENT_PARQUET_SCHEMA = pa.schema(element_property_schema_list) -CHEMENV_PARQUET_SCHEMA = pa.schema(chemenv_property_schema_list) -MAGNETIC_STATE_PARQUET_SCHEMA = pa.schema(magnetic_state_property_schema_list) -CRYSTAL_SYSTEM_PARQUET_SCHEMA = pa.schema(crystal_system_property_schema_list) -OXIDATION_STATE_PARQUET_SCHEMA = pa.schema(oxidation_state_property_schema_list) -SPG_PARQUET_SCHEMA = pa.schema(spg_property_schema_list) -SPG_WYCKOFF_PARQUET_SCHEMA = pa.schema(spg_wyckoff_property_schema_list) - - -def get_node_schema(node_type:NodeTypes): - """ - Retrieves the schema for a given node type in a material science graph, where each node type corresponds to - a specific category of data in the material science domain. The schema defines the structure and properties - of the node data for that specific node type. - - Parameters - ---------- - node_type : NodeTypes - An instance of the `NodeTypes` Enum that specifies the type of node for which the schema is requested. - This parameter determines which pre-defined schema will be returned for the corresponding node type. - - Returns - ------- - pyarrow.Schema - A PyArrow schema object that defines the fields for the node data, corresponding to the given `node_type`. - The returned schema can be one of the following based on the node type: - - `MATERIAL_PARQUET_SCHEMA`: For material-related nodes. - - `ELEMENT_PARQUET_SCHEMA`: For chemical elements or pre-imputed elements. - - `CHEMENV_PARQUET_SCHEMA`: For chemical environment nodes. - - `CRYSTAL_SYSTEM_PARQUET_SCHEMA`: For crystal system nodes. - - `MAGNETIC_STATE_PARQUET_SCHEMA`: For magnetic state nodes. - - `SPG_PARQUET_SCHEMA`: For space group nodes. - - `OXIDATION_STATE_PARQUET_SCHEMA`: For oxidation state nodes. - - `SPG_WYCKOFF_PARQUET_SCHEMA`: For Wyckoff position nodes. - - `LATTICE_PARQUET_SCHEMA`: For lattice structure nodes. - - `SITE_PARQUET_SCHEMA`: For atomic site nodes. - - Example - ------- - To get the schema for a specific node type: - - >>> schema = get_node_schema(NodeTypes.MATERIAL) - >>> print(schema) - - If an unsupported node type is passed: - - >>> get_node_schema("invalid_node") - ValueError: node_type must be an instance of NodeTypes. 
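A minimal, self-contained sketch of how the Parquet schemas above carry their encoder hints, assuming only that ``pyarrow`` is installed (the field names and row values here are illustrative, not taken from the data)::

    import pyarrow as pa
    import pyarrow.parquet as pq

    t_string = pa.string()
    t_int = pa.int64()

    # Two fields in the same style as the schema lists above; the 'encoder' entry
    # rides along as Arrow field metadata.
    schema = pa.schema([
        pa.field("name", t_string, metadata={"encoder": "ClassificationEncoder()"}),
        pa.field("spg", t_int, metadata={"encoder": "IdentityEncoder(dtype=torch.float32)"}),
    ])

    table = pa.table({"name": ["spg-225"], "spg": [225]}, schema=schema)
    pq.write_table(table, "spg_nodes.parquet")

    # With pyarrow's defaults the Arrow schema (including field metadata) is embedded
    # in the file, so the encoder hint comes back on read (keys/values as bytes).
    print(pq.read_schema("spg_nodes.parquet").field("spg").metadata)
    # {b'encoder': b'IdentityEncoder(dtype=torch.float32)'}

Presumably the point of storing the encoder string per column is that downstream feature-building code can pick the right encoder when it loads the node or relationship tables.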
- """ - if not isinstance(node_type,NodeTypes): - raise ValueError("node_type must be an instance of NodeTypes.{}") - node_name=node_type.value.lower() - - if node_name=='material': - schema=MATERIAL_PARQUET_SCHEMA - elif node_name=='element' or node_name=='pre_imputed_element': - schema=ELEMENT_PARQUET_SCHEMA - elif node_name=='chemenv': - schema=CHEMENV_PARQUET_SCHEMA - elif node_name=='crystal_system': - schema=CRYSTAL_SYSTEM_PARQUET_SCHEMA - elif node_name=='magnetic_state': - schema=MAGNETIC_STATE_PARQUET_SCHEMA - elif node_name=='space_group': - schema=SPG_PARQUET_SCHEMA - elif node_name=='oxidation_state': - schema=OXIDATION_STATE_PARQUET_SCHEMA - elif node_name=='spg_wyckoff': - schema=SPG_WYCKOFF_PARQUET_SCHEMA - elif node_name=='lattice': - schema=LATTICE_PARQUET_SCHEMA - elif node_name=='site': - schema=SITE_PARQUET_SCHEMA - - return schema - - - - - - -# element_property_schema_list = [ -# pa.field('group', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('row', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('Z', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('symbol', t_string, metadata={'encoder':'ClassificationEncoder()'}), -# pa.field('long_name', t_string, metadata={'encoder':'ClassificationEncoder()'}), -# pa.field('A', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('atomic_radius_calculated', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('van_der_waals_radius', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('mendeleev_no', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('electrical_resistivity', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('velocity_of_sound', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('reflectivity', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('refractive_index', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('poissons_ratio', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('molar_volume', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('thermal_conductivity', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('boiling_point', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('melting_point', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('critical_temperature', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('superconduction_temperature', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('liquid_range', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('bulk_modulus', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('youngs_modulus', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('rigidity_modulus', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('vickers_hardness', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('density_of_solid', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('coefficient_of_linear_thermal_expansion', 
t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('ionization_energies', pa.list_(t_float), metadata={'encoder':'ListIdentityEncoder(dtype=torch.float32)'}), -# pa.field('block', t_string, metadata={'encoder':'ClassificationEncoder()'}), -# pa.field('common_oxidation_states', pa.list_(t_int), metadata={'encoder':'ListIdentityEncoder(dtype=torch.float32)'}), -# pa.field('electron_affinity', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('X', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('atomic_mass', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('atomic_mass_number', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('atomic_radius', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('average_anionic_radius', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('average_cationic_radius', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('average_ionic_radius', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('ground_state_term_symbol', t_string, metadata={'encoder':'ClassificationEncoder()'}), - -# pa.field('icsd_oxidation_states', pa.list_(t_int), metadata={'encoder':'IdentityEncoder(dtype=torch.int64)'}), -# pa.field('is_actinoid', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_alkali', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_alkaline', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_chalcogen', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_halogen', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_lanthanoid', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_metalloid', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_noble_gas', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_post_transition_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_quadrupolar', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_rare_earth', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_rare_earth_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('is_transition_metal', t_bool, metadata={'encoder':'BooleanEncoder(dtype=torch.int64)'}), -# pa.field('iupac_ordering', t_float, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('max_oxidation_state', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('min_oxidation_state', t_int, metadata={'encoder':'IdentityEncoder(dtype=torch.float32)'}), -# pa.field('oxidation_states', pa.list_(t_int), metadata={'encoder':'ListIdentityEncoder(dtype=torch.int64)'}), -# pa.field('valence', pa.list_(t_int), metadata={'encoder':'ListIdentityEncoder(dtype=torch.int64)'}), -# pa.field('name', t_string, metadata={'encoder':'ClassificationEncoder()'}), -# pa.field('type', t_string, metadata={'encoder':'ClassificationEncoder()'})] - diff --git a/matgraphdb/graph_kit/neo4j/README.md b/matgraphdb/graph_kit/neo4j/README.md 
deleted file mode 100644 index 639eaf4..0000000 --- a/matgraphdb/graph_kit/neo4j/README.md +++ /dev/null @@ -1,3 +0,0 @@ -This modulue is used to control the neo4j database. The user, password, location, and database can be defined in __init__.py - -To switch between dev you must activate the dev dbms in Neo4j Desktop. Once you have made changes to dev dbms deactivate the dev dbms, then clone the dbms, and the remove the dev from the name for it to become the production db. diff --git a/matgraphdb/graph_kit/neo4j/__init__.py b/matgraphdb/graph_kit/neo4j/__init__.py deleted file mode 100644 index 0c888b9..0000000 --- a/matgraphdb/graph_kit/neo4j/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from matgraphdb.graph_kit.neo4j.neo4j_gds_manager import Neo4jGDSManager -from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager \ No newline at end of file diff --git a/matgraphdb/graph_kit/neo4j/neo4j_analysis.py b/matgraphdb/graph_kit/neo4j/neo4j_analysis.py deleted file mode 100644 index 074712f..0000000 --- a/matgraphdb/graph_kit/neo4j/neo4j_analysis.py +++ /dev/null @@ -1,55 +0,0 @@ -import math - -from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager -from matgraphdb.graph_kit.neo4j.neo4j_gds_manager import Neo4jGDSManager -from matgraphdb.graph_kit.neo4j.utils import format_dictionary - -def shannon_entropy(probabilities): - """ - Calculates the Shannon entropy of a list of probabilities. - - Args: - probabilities (list): A list of probabilities. - - Returns: - float: The Shannon entropy. - """ - probabilities = [p for p in probabilities if p > 0] - total_probability = sum(probabilities) - entropy=0 - for p in probabilities: - entropy += - p * math.log(p, 2) - return entropy - -class Neo4jAnalyzer: - def __init__(self,neo4j_manager=Neo4jManager()): - - self.neo4j_manager=neo4j_manager - - def get_node_degrees(self,database_name): - with self.neo4j_manager as session: - results=session.query(f"MATCH (n) RETURN n.name AS name",database_name) - - name=results[0]['name'] - prop_dict={'name':name} - degrees_dict={} - node_count=len(results) - for result in results: - name=result['name'] - prop_dict={'name':name} - cypher_statement=f"MATCH (n {format_dictionary(prop_dict)})-[r]-() RETURN n.name, COUNT(r) AS degree" - degree_results=session.query(cypher_statement,database_name) - if len(degree_results)==0: - degree=0 - else: - degree=degree_results[0]['degree'] - - tmp_dict={'degree':degree,'degree_probability':degree/node_count} - degrees_dict[name]=tmp_dict - return degrees_dict - -if __name__=='__main__': - analyzer=Neo4jAnalyzer() - degrees_dict=analyzer.get_node_degrees('nelements-1-2') - degree_probabilities=[degrees_dict[name]['degree_probability'] for name in degrees_dict] - print("Shannon Entropy: ",shannon_entropy(degree_probabilities)) \ No newline at end of file diff --git a/matgraphdb/graph_kit/neo4j/neo4j_experiment_manager.py b/matgraphdb/graph_kit/neo4j/neo4j_experiment_manager.py deleted file mode 100644 index 130311c..0000000 --- a/matgraphdb/graph_kit/neo4j/neo4j_experiment_manager.py +++ /dev/null @@ -1,323 +0,0 @@ - - -from matgraphdb.utils import LOGGER -from matgraphdb.graph_kit.metadata import NodeTypes -from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager -from matgraphdb.graph_kit.neo4j.neo4j_gds_manager import Neo4jGDSManager - - -class GraphProjection: - def __init__(self, - name, - node_projections, - relationship_projections, - write_property, - use_weights, - use_node_properties): - self.name = name - self.node_projections = 
node_projections - self.relationship_projections = relationship_projections - self.write_property = write_property - self.use_weights = use_weights - self.use_node_properties = use_node_properties - node_properties=[] - if use_node_properties: - for key,value in self.node_projections.items(): - if 'properties' in value: - node_properties.append(value['properties']) - self.node_properties=node_properties - - @classmethod - def ec_element_chemenv(cls, use_weights=True, use_node_properties=True,algorithm_name='fastRP'): - relationship_projections={ - "`CHEMENV-ELECTRIC_CONNECTS-CHEMENV`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`ELEMENT-ELECTRIC_CONNECTS-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`CHEMENV-CAN_OCCUR-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`MATERIAL-HAS-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`MATERIAL-HAS-CHEMENV`": {"orientation": 'UNDIRECTED', "properties": 'weight'} - } - node_projections={ - "Chemenv": {"label":'Chemenv'}, - "Element": {"label":'Element', - "properties":{ - 'atomic_number':{'default_value':0.0}, - 'X':{'default_value':0.0}, - 'atomic_radius':{'default_value':0.0}, - 'group':{'default_value':0}, - 'row':{'default_value':0}, - 'atomic_mass':{'default_value':0.0} - } - }, - "Material":{"label":'Material'} - } - - write_property=f'{algorithm_name}-embedding-ec-element-chemenv' - - if use_weights: - write_property+='-weighted' - if use_node_properties: - write_property+='-node_properties' - return cls( - name="EC Element ChemEnv", - node_projections=node_projections, - relationship_projections=relationship_projections, - write_property=write_property, - use_weights=use_weights, - use_node_properties=use_node_properties - ) - - @classmethod - def gc_element_chemenv(cls, use_weights=True, use_node_properties=True,algorithm_name='fastRP'): - relationship_projections={ - "`CHEMENV-GEOMETRIC_CONNECTS-CHEMENV`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`ELEMENT-GEOMETRIC_CONNECTS-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`CHEMENV-CAN_OCCUR-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`MATERIAL-HAS-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`MATERIAL-HAS-CHEMENV`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - } - node_projections={ - "Chemenv": {"label":'Chemenv'}, - "Element": {"label":'Element', - "properties":{ - 'atomic_number':{'default_value':0.0}, - 'X':{'default_value':0.0}, - 'atomic_radius':{'default_value':0.0}, - 'group':{'default_value':0}, - 'row':{'default_value':0}, - 'atomic_mass':{'default_value':0.0} - } - }, - "Material":{"label":'Material'} - } - write_property=f'{algorithm_name}-embedding-gc-element-chemenv' - - if use_weights: - write_property+='-weighted' - if use_node_properties: - write_property+='-node_properties' - - - return cls( - name="GC Element ChemEnv", - node_projections=node_projections, - relationship_projections=relationship_projections, - write_property=write_property, - use_weights=use_weights, - use_node_properties=use_node_properties - ) - - @classmethod - def gec_element_chemenv(cls, use_weights=True, use_node_properties=True,algorithm_name='fastRP'): - relationship_projections={ - "`CHEMENV-GEOMETRIC_ELECTRIC_CONNECTS-CHEMENV`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`ELEMENT-GEOMETRIC_ELECTRIC_CONNECTS-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`CHEMENV-CAN_OCCUR-ELEMENT`": 
{"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`MATERIAL-HAS-ELEMENT`": {"orientation": 'UNDIRECTED', "properties": 'weight'}, - "`MATERIAL-HAS-CHEMENV`": {"orientation": 'UNDIRECTED', "properties": 'weight'} - } - node_projections={ - "Chemenv": {"label":'Chemenv'}, - "Element": {"label":'Element', - "properties":{ - 'atomic_number':{'default_value':0.0}, - 'X':{'default_value':0.0}, - 'atomic_radius':{'default_value':0.0}, - 'group':{'default_value':0}, - 'row':{'default_value':0}, - 'atomic_mass':{'default_value':0.0} - } - }, - "Material":{"label":'Material'} - } - write_property=f'{algorithm_name}-embedding-ec-element-chemenv' - - if use_weights: - write_property+='-weighted' - if use_node_properties: - write_property+='-node_properties' - - return cls( - name="EC Element ChemEnv", - node_projections=node_projections, - relationship_projections=relationship_projections, - write_property=write_property, - use_weights=use_weights, - use_node_properties=use_node_properties - ) - - @classmethod - def gec_element_chemenv_material_properties(cls, use_weights=True, use_node_properties=True): - relationship_projections={ - "`CHEMENV-ELECTRIC_CONNECTS-CHEMENV`": {"orientation": 'UNDIRECTED'}, - "`ELEMENT-ELECTRIC_CONNECTS-ELEMENT`": {"orientation": 'UNDIRECTED'}, - "`CHEMENV-CAN_OCCUR-ELEMENT`": {"orientation": 'UNDIRECTED'}, - "`MATERIAL-HAS-ELEMENT`": {"orientation": 'UNDIRECTED'}, - "`MATERIAL-HAS-CHEMENV`": {"orientation": 'UNDIRECTED'} - } - node_projections={ - "Chemenv": {"label":'Chemenv'}, - "Element": {"label":'Element' , - "properties":{ - 'atomic_number':{'default_value':0.0}, - 'X':{'default_value':0.0}, - 'atomic_radius':{'default_value':0.0}, - 'group':{'default_value':0.0}, - 'row':{'default_value':0.0}, - 'atomic_mass':{'default_value':0.0} - } - }, - "Material":{"label":'Material', "properties":['band_gap','formation_energy_per_atom','energy_per_atom','energy_above_hull','k_vrh','g_vrh']} - } - write_property='fastrp-embedding-gec-element-chemenv' - - if use_weights: - for key, value in relationship_projections.items(): - relationship_projections[key]['properties'] = 'weight' - write_property+='-weighted' - - if use_node_properties: - write_property+='-node_properties' - - return cls( - name="GEC Element ChemEnv", - node_projections=node_projections, - relationship_projections=relationship_projections, - write_property=write_property, - use_weights=use_weights, - use_node_properties=use_node_properties - ) - - def get_config(self): - return { - "node_projections": self.node_projections, - "relationship_projections": self.relationship_projections, - "write_property": self.write_property, - "use_weights": self.use_weights, - "use_node_properties": self.use_node_properties - } - - -# ['nsites','nelements','volume','density','density_atomic','composition_reduced','formula_pretty' - # 'e_electronic','e_ionic','e_total','energy_per_atom','energy_above_hull','formation_energy_per_atom', - # 'band_gap','vbm','cbm','efermi', - # 'crystal_system','space_group','point_group','hall_symbol' - # 'is_gap_direct','is_metal','is_magnetic','is_stable', - # 'ordering','total_magnetization','total_magnetization_normalized_vol','num_magnetic_sites','num_unique_magnetic_sites' - # 'g_ruess','g_voigt','g_vrh','k_reuss','k_voigt','k_vrh','homogeneous_poisson','universal_anisotropy'] - -class Neo4jExperimentManager(): - def __init__(self): - pass - - def run_fastRP_algorithm(self, - database_name, - graph_name, - node_projections, - relationship_projections, - algorithm_params): - with 
Neo4jManager() as neo4j_manager: - manager=Neo4jGDSManager(neo4j_manager) - - try: - manager.load_graph_into_memory(database_name=database_name, - graph_name=graph_name, - node_projections=node_projections, - relationship_projections=relationship_projections) - - print(manager.is_graph_in_memory(database_name,graph_name)) - results=manager.run_fastRP_algorithm(database_name=database_name, - graph_name=graph_name, - **algorithm_params) - manager.drop_graph(database_name=database_name,graph_name=graph_name) - print(manager.is_graph_in_memory(database_name,graph_name)) - return results - except Exception as e: - print(f"Database {database_name} : {e}") - pass - - def run_hashGNN_algorithm(self, - database_name, - graph_name, - node_projections, - relationship_projections, - algorithm_params): - with Neo4jManager() as neo4j_manager: - manager=Neo4jGDSManager(neo4j_manager) - node_property=algorithm_params['mutate_property'] - try: - manager.load_graph_into_memory(database_name=database_name, - graph_name=graph_name, - node_projections=node_projections, - relationship_projections=relationship_projections) - - print(manager.is_graph_in_memory(database_name,graph_name)) - results=manager.run_hashGNN_algorithm(database_name=database_name, - graph_name=graph_name, - **algorithm_params) - - manager.write_graph(database_name=database_name, - graph_name=graph_name, - node_properties=node_property, - concurrency=4) - manager.drop_graph(database_name=database_name,graph_name=graph_name) - - - print(manager.is_graph_in_memory(database_name,graph_name)) - return results - except Exception as e: - print(f"Database {database_name} : {e}") - pass - - def generate_all_hashGNN(self, - database_names): - use_weights_list=[True,False] - use_node_properties_list=[True,False] - - for database_name in database_names: - LOGGER.info(f"Generating node embeddings for {database_name}") - - - - gc_graph_projection=GraphProjection.gc_element_chemenv(use_weights=False, - use_node_properties=False, - algorithm_name='hashGNN') - ec_graph_projection=GraphProjection.ec_element_chemenv(use_weights=False, - use_node_properties=False, - algorithm_name='hashGNN') - gec_graph_projection=GraphProjection.gec_element_chemenv(use_weights=False, - use_node_properties=False, - algorithm_name='hashGNN') - projections=[gc_graph_projection,ec_graph_projection,gec_graph_projection] - for graph_projection in projections: - - LOGGER.info(f"Using {graph_projection.name} with {False} weights and {False} node properties") - - relationship_projections=graph_projection.relationship_projections - node_projections=graph_projection.node_projections - write_property=graph_projection.write_property - - algorithm_params={'algorithm_mode': 'mutate', - 'embedding_density':128, - 'iterations':3, - 'heterogeneous':True, - 'mutate_property': write_property} - # if use_node_properties: - # algorithm_params['feature_properties']=graph_projection.node_properties - # algorithm_params['binarize_features']={'dimension': 128, 'densityLevel': 2} - # else: - algorithm_params['generate_features']={'dimension': 128, 'densityLevel': 2} - - for key,value in algorithm_params.items(): - LOGGER.info(f"Algorithm param: {key} : {value}") - - self.run_hashGNN_algorithm(database_name=database_name, - graph_name='main', - node_projections=node_projections, - relationship_projections=relationship_projections, - algorithm_params=algorithm_params) - - # self.run_fastRP_algorithm(database_name=database_name, - # graph_name='main', - # node_projections=node_projections, - # 
relationship_projections=relationship_projections, - # algorithm_params=algorithm_params) \ No newline at end of file diff --git a/matgraphdb/graph_kit/neo4j/neo4j_gds_manager.py b/matgraphdb/graph_kit/neo4j/neo4j_gds_manager.py deleted file mode 100644 index 2e5d018..0000000 --- a/matgraphdb/graph_kit/neo4j/neo4j_gds_manager.py +++ /dev/null @@ -1,2749 +0,0 @@ -import os -import json -from typing import List, Tuple, Union -from glob import glob - -from neo4j import GraphDatabase -import pandas as pd - -from matgraphdb.utils import (PASSWORD, USER, LOCATION, DBMSS_DIR, GRAPH_DIR, LOGGER,MP_DIR) -from matgraphdb.utils.general_utils import get_os -from matgraphdb.graph_kit.neo4j.utils import format_projection, format_dictionary, format_list, format_string - - -class Neo4jGDSManager: - - def __init__(self, neo4j_manager): - self.neo4j_manager = neo4j_manager - self.algorithm_modes=['stream','stats','write','mutate'] - self.link_prediction_algorithms=['adamicAdar','commonNeighbors','preferentialAttachment', - 'resourceAllocation','sameCommunity','totalNeighbors'] - if self.neo4j_manager.driver is None: - raise Exception("Graph database is not connected. Please ccreate a driver") - - def list_graphs(self,database_name): - """ - Lists the graphs in a database. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-list/ - - Args: - database_name (str): The name of the database. - Returns: - list: A list of the graphs in the database. - """ - cypher_statement=f""" - CALL gds.graph.list() - YIELD graphName - RETURN graphName; - """ - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - graph_names=[result['graphName'] for result in results] - return graph_names - - def is_graph_in_memory(self,database_name,graph_name): - """ - Checks if the graph exists in memory. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-list/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - - Returns: - bool: True if the graph exists in memory, False otherwise. - """ - - cypher_statement=f""" - CALL gds.graph.list("{graph_name}") - YIELD graphName - RETURN graphName; - """ - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - if len(results)!=0: - return True - return False - - def get_graph_info(self,database_name,graph_name): - """ - Gets the graph information. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-info/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - - Returns: - dict: The graph information. - """ - - cypher_statement=f"CALL gds.graph.list(\"{graph_name}\")" - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - outputs=[] - for result in results: - output={ key:value for key, value in result.items()} - outputs.append(output) - return outputs - - def drop_graph(self,database_name,graph_name): - """ - Drops a graph from a database. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-drop/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - - Returns: - None - """ - cypher_statement=f""" - CALL gds.graph.drop("{graph_name}") - """ - self.neo4j_manager.query(cypher_statement,database_name=database_name) - return None - - def list_graph_data_science_algorithms(self,database_name, save=False): - """ - Lists the algorithms in a database. 
- https://neo4j.com/docs/graph-data-science/current/management-ops/graph-list-algorithms/ - - Args: - database_name (str): The name of the database. - - Returns: - list: A list of the algorithms in the database. - """ - cypher_statement=f""" - CALL gds.list() - YIELD name,description - RETURN name,description; - """ - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - algorithm_names={result['name']:result['description'] for result in results} - if save: - print("Saving algorithms to : ",os.path.join(MP_DIR,'neo4j_graph_data_science_algorithms.txt')) - with open(os.path.join(MP_DIR,'neo4j_graph_data_science_algorithms.json'), "w") as file: - json.dump(algorithm_names, file, indent=4) - # Decide on a fixed width for the name column - with open(os.path.join(MP_DIR,'neo4j_graph_data_science_algorithms.txt'), "w") as file: - fmt = '{0:75s}{1:200s}\n' - for result in results: - file.write(fmt.format(result['name'],result['description'])) - file.write('_'*300) - file.write('\n') - return algorithm_names - - def get_graph_data_science_algorithms_docs_url(self): - """ - Lists the algorithms in a database. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-list-algorithms/ - - Args: - database_name (str): The name of the database. - - Returns: - list: A list of the algorithms in the database. - """ - website_url=f"https://neo4j.com/docs/graph-data-science/current/algorithms/" - return website_url - - def load_graph_into_memory(self,database_name:str, - graph_name:str, - node_projections:Union[str,List,dict], - relationship_projections:Union[str,List,dict], - config:dict=None): - """ - Loads a graph into memory. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/graph-project/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - nodes_names (list): A list of node names. - relationships_names (list): A list of relationship names. - - Returns: - None - """ - if self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - LOGGER.info(f"Graph {graph_name} already in memory") - return None - formated_node_str=format_projection(projections=node_projections) - formated_relationship_str=format_projection(projections=relationship_projections) - - cypher_statement=f"""CALL gds.graph.project( - "{graph_name}", - {formated_node_str}, - {formated_relationship_str}""" - - if config: - cypher_statement+=", " - cypher_statement+=format_dictionary(projections=config) - cypher_statement+=")" - - self.neo4j_manager.query(cypher_statement,database_name=database_name) - return None - - def read_graph(self, - database_name:str, - graph_name:str, - node_properties:Union[str,List[str]], - node_labels:str=None, - concurrency:int=4, - list_node_labels:bool=False, - ): - """ - Reads a graph from memory. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/read-graph/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - node_properties (str): The name of the node properties. - algorithm_mode (str, optional): The algorithm mode. Defaults to 'stream'. - node_labels (str, optional): The name of the node labels. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - list_node_labels (bool, optional): The list node labels. Defaults to False. 
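For orientation, a stand-alone sketch of the statement this method assembles from those arguments (graph and property names are made up; the exact quoting in the removed code comes from the ``format_string``/``format_list`` helpers, which are not shown here)::

    # Illustration only: roughly the Cypher that read_graph() builds and passes to
    # Neo4jManager.query() when streaming node properties out of an in-memory graph.
    graph_name = "main"
    node_properties = ["atomic_number", "fastRP-embedding"]

    cypher = (
        "CALL gds.graph.nodeProperties.stream("
        f'"{graph_name}",'
        f"{node_properties})"
    )
    print(cypher)
    # CALL gds.graph.nodeProperties.stream("main",['atomic_number', 'fastRP-embedding'])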
- - Returns: - None - """ - if not self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - LOGGER.info(f"Graph {graph_name} not in memory") - return None - config={} - config['concurrency']=concurrency - config['listNodeLabels']=list_node_labels - - cypher_statement=f"CALL gds.graph.nodeProperties.stream(" - cypher_statement+=f"{format_string(graph_name)}," - if isinstance(node_properties,str): - cypher_statement+=f"{format_string(node_properties)}" - elif isinstance(node_properties,list): - cypher_statement+=f"{format_list(node_properties)}" - if node_labels: - if isinstance(node_labels,str): - cypher_statement+=f", {format_string(node_labels)}" - elif isinstance(node_labels,list): - cypher_statement+=f", {format_list(node_labels)}" - cypher_statement+=f")" - print(cypher_statement) - results=self.neo4j_manager.query(cypher_statement,database_name=database_name) - return results - - def write_graph(self, - database_name:str, - graph_name:str, - node_properties:Union[str,List,dict]=None, - node_labels:Union[str,List[str]]=None, - relationship_type:str=None, - relationship_properties:Union[str,List[str]]=None, - concurrency:int=4): - """ - Write graph to neo4j database. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/graph-project/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - node_properties (str): The node properties. - node_labels (str): The node labels. - relationship_properties (str): The relationship properties. - relationship_labels (str): The relationship labels. - concurrency (int, optional): The concurrency. Defaults to 4. - - Returns: - list: Returns results of the algorithm. - """ - if not self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - raise Exception(f"Graph {graph_name} is not in memory") - node_outputs=[] - relationship_outputs=[] - - config={} - config['concurrency']=concurrency - config['writeConcurrency']=concurrency - if node_properties: - cypher_statement=f"CALL gds.graph.nodeProperties.write(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_projection(node_properties)}," - if node_labels: - cypher_statement+=f"{format_list(node_labels)}," - cypher_statement+=f"{format_dictionary(config)}" - cypher_statement+=")" - - results=self.neo4j_manager.query(cypher_statement,database_name=database_name) - for result in results: - node_output={property_name:property_value for property_name,property_value in result.items()} - node_outputs.append(node_output) - - if relationship_type: - cypher_statement=f"CALL gds.graph.relationshipProperties.write(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_string(relationship_type)}," - if relationship_properties: - cypher_statement+=f"{format_list(relationship_properties)}," - cypher_statement+=f"{format_dictionary(config)}" - cypher_statement+=")" - - results=self.neo4j_manager.query(cypher_statement,database_name=database_name) - for result in results: - relationship_output={property_name:property_value for property_name,property_value in result.items()} - relationship_outputs.append(relationship_output) - - return node_outputs,relationship_outputs - - def export_graph(self, - database_name:str, - graph_name:str, - db_name:str, - concurrency:int=4, - batch_size:int=100, - default_relationship_type:str='__ALL__', - additional_node_properties:Union[str,List,dict]=None, - ): - """ - Export graph to neo4j database. - - Useful urls. 
- https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/graph-project/ - - Useful scenarios. - - Avoid heavy write load on the operational system by exporting the data instead of writing back. - - Create an analytical view of the operational system that can be used as a basis for running algorithms. - - Produce snapshots of analytical results and persistent them for archiving and inspection. - - Share analytical results within the organization. - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - db_name (str): The name of the database. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - batch_size (int, optional): The batch size. Defaults to 100. - default_relationship_type (str, optional): The default relationship type. Defaults to '__ALL__'. - additional_node_properties (str, optional): The additional node properties. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if not self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - raise Exception(f"Graph {graph_name} is not in memory") - if db_name not in self.neo4j_manager.list_databases(): - raise Exception(f"Database {db_name} does not exist") - - config={} - config['dbName']=db_name - config['concurrency']=concurrency - config['writeConcurrency']=concurrency - config['batchSize']=batch_size - config['defaultRelationshipType']=default_relationship_type - if additional_node_properties: - config['additionalNodeProperties']=format_projection(additional_node_properties) - cypher_statement=f"CALL gds.graph.export(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_dictionary(config)}" - cypher_statement+=")" - - - - results=self.neo4j_manager.query(cypher_statement,database_name=database_name) - outputs=[] - for result in results: - output={property_name:property_value for property_name,property_value in result.items()} - outputs.append(output) - return outputs - - def export_graph_csv(self, - database_name:str, - graph_name:str, - export_name:str, - concurrency:int=4, - default_relationship_type:str='__ALL__', - additional_node_properties:Union[str,List,dict]=None, - use_label_mapping:bool=False, - sampling_factor:float=0.001, - estimate_memory:bool=False, - ): - """ - Export graph to neo4j database. - - Useful urls. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/graph-project/ - - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - export_name (str): The export name. Absolute directory path is required. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - default_relationship_type (str, optional): The default relationship type. Defaults to '__ALL__'. - additional_node_properties (str, optional): The additional node properties. Defaults to None. - use_label_mapping (bool, optional): The use label mapping. Defaults to False. - - Returns: - list: Returns results of the algorithm. 
- """ - if not self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - raise Exception(f"Graph {graph_name} is not in memory") - if export_name is None: - raise Exception("export_name must be provided") - - config={} - if estimate_memory: - config['exportName']=export_name - config['samplingFactor']=sampling_factor - config['writeConcurrency']=concurrency - config['defaultRelationshipType']=default_relationship_type - cypher_statement=f"CALL gds.graph.export.csv.estimate(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_dictionary(config)}" - cypher_statement+=")" - else: - config['exportName']=export_name - config['writeConcurrency']=concurrency - config['defaultRelationshipType']=default_relationship_type - if additional_node_properties: - config['additionalNodeProperties']=format_projection(additional_node_properties) - config['useLabelMapping']=use_label_mapping - - cypher_statement=f"CALL gds.graph.export.csv(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_dictionary(config)}" - cypher_statement+=")" - - - results=self.neo4j_manager.query(cypher_statement,database_name=database_name) - outputs=[] - for result in results: - output={property_name:property_value for property_name,property_value in result.items()} - outputs.append(output) - return outputs - - def estimate_memeory_for_algorithm(self,database_name:str, - graph_name:str, - algorithm_name:str, - algorithm_mode:str='stream', - algorithm_config:dict=None): - """ - Estimates the memory required for a given algorithm. - f"https://neo4j.com/docs/graph-data-science/current/algorithms/" - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-estimate-memory/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_name (str): The name of the algorithm. - - Returns: - float: The estimated memory required for the algorithm. - """ - if not self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - raise Exception(f"Graph {graph_name} is not in memory") - cypher_statement=f"CALL gds.{algorithm_name}.{algorithm_mode}.estimate(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_dictionary(algorithm_config)}" - cypher_statement+=")" - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - outputs=[] - for result in results: - output={property_name:property_value for property_name,property_value in result.items()} - outputs.append(output) - - return outputs - - def run_algorithm(self,database_name:str, - graph_name:str, - algorithm_name:str, - algorithm_mode:str='stream', - algorithm_config:dict=None): - """ - Estimates the memory required for a given algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/algorithms/ - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-estimate-memory/ - neo4j.com/docs/graph-data-science/current/operations-reference/algorithm-references/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_name (str): The name of the algorithm. - - Returns: - float: The estimated memory required for the algorithm. 
- """ - if not self.is_graph_in_memory(database_name=database_name,graph_name=graph_name): - raise Exception(f"Graph {graph_name} is not in memory") - cypher_statement=f"CALL gds.{algorithm_name}.{algorithm_mode}(" - cypher_statement+=f"{format_string(graph_name)}," - cypher_statement+=f"{format_dictionary(algorithm_config)}" - cypher_statement+=")" - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - outputs=[] - for result in results: - output={property_name:property_value for property_name,property_value in result.items()} - outputs.append(output) - return outputs - - def run_fastRP_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str='stream', - embedding_dimension:int=128, - random_seed:int=42, - concurrency:int=4, - property_ratio:float=0.0, - feature_properties:List[str]=[], - iteration_weights:List[float]=[0.0,1.0,1.0], - node_self_influence:float=0.0, - normalization_strength:float=0.0, - relationship_weight_property:str=None, - write_property:str=None, - mutate_property:str=None, - ): - """ - Estimates the memory required for a given algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/node-embeddings/fastrp/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - embedding_dimension (int, optional): The dimension of the embedding. Defaults to 128. - random_seed (int, optional): The random seed. Defaults to 42. - concurrency (int, optional): The concurrency. Defaults to 4. - property_ratio (float, optional): The property ratio. Defaults to 0.0. - feature_properties (list, optional): The feature properties. Defaults to []. - iteration_weights (list, optional): The iteration weights. Defaults to [0.0,1.0,1.0]. - node_self_influence (float, optional): The node self influence. Defaults to 0.0. - normalization_strength (float, optional): The normalization strength. Defaults to 0.0. - relationship_weight_property (str, optional): The relationship weight property. Defaults to 'null'. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - if embedding_dimension is None: - raise Exception("Embedding_dimension must be provided") - algorithm_config={} - algorithm_config['embeddingDimension']=embedding_dimension - algorithm_config['randomSeed']=random_seed - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - algorithm_config['propertyRatio']=property_ratio - algorithm_config['featureProperties']=feature_properties - # algorithm_config['iterationWeights']=iteration_weights - algorithm_config['nodeSelfInfluence']=node_self_influence - algorithm_config['normalizationStrength']=normalization_strength - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(database_name=database_name, - graph_name=graph_name, - algorithm_name='fastRP', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_node2vec_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str='stream', - embedding_dimension:int=128, - embedding_initializer:str='NORMALIZED', - walk_length:int=80, - walks_per_node:int=10, - in_out_factor:float=1.0, - return_factor:float=1.0, - relationship_weight_property:str='null', - window_size:int=10, - negative_sample_rate:int=5, - positive_sampling_factor:float=0.001, - negative_sampling_exponent:float=0.75, - iterations:int=1, - initial_learning_rate:float=0.01, - min_learning_rate:float=0.0001, - walk_buffer_size:int=100, - random_seed:int=42, - concurrency:int=4, - write_property:str='n/a', - mutate_property:str=None, - ): - """ - Runs the node2vec algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/node-embeddings/node2vec/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - embedding_dimension (int, optional): The dimension of the embedding. Defaults to 128. - embedding_initializer (str, optional): The initializer of the embedding. Defaults to 'NORMALIZED'. - walk_length (int, optional): The length of the walks. Defaults to 80. - walks_per_node (int, optional): The number of walks per node. Defaults to 10. - in_out_factor (float, optional): The factor for the in-out ratio. Defaults to 1.0. - return_factor (float, optional): The factor for the return ratio. Defaults to 1.0. - relationship_weight_property (str, optional): The relationship weight property. Defaults to 'null'. - window_size (int, optional): The window size. Defaults to 10. - negative_sample_rate (int, optional): The negative sample rate. Defaults to 5. - positive_sampling_factor (float, optional): The positive sampling factor. Defaults to 0.001. - negative_sampling_exponent (float, optional): The negative sampling exponent. Defaults to 0.75. - iterations (int, optional): The number of iterations. Defaults to 1. - initial_learning_rate (float, optional): The initial learning rate. Defaults to 0.01. 
- min_learning_rate (float, optional): The minimum learning rate. Defaults to 0.0001. - walk_buffer_size (int, optional): The walk buffer size. Defaults to 100. - random_seed (int, optional): The random seed. Defaults to 42. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to 'n/a'. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - if embedding_dimension is None: - raise Exception("Embedding_dimension must be provided") - - algorithm_config={} - algorithm_config['embeddingDimension']=embedding_dimension - algorithm_config['embeddingInitializer']=embedding_initializer - algorithm_config['randomSeed']=random_seed - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - algorithm_config['walkLength']=walk_length - algorithm_config['walksPerNode']=walks_per_node - algorithm_config['inOutFactor']=in_out_factor - algorithm_config['returnFactor']=return_factor - - algorithm_config['windowSize']=window_size - algorithm_config['negativeSampleRate']=negative_sample_rate - algorithm_config['positiveSamplingFactor']=positive_sampling_factor - algorithm_config['negativeSamplingExponent']=negative_sampling_exponent - algorithm_config['iterations']=iterations - algorithm_config['initialLearningRate']=initial_learning_rate - algorithm_config['minLearningRate']=min_learning_rate - algorithm_config['walkBufferSize']=walk_buffer_size - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if write_property: - algorithm_config['write_property']=write_property - if mutate_property: - algorithm_config['mutate_property']=mutate_property - - results=self.run_algorithm( - database_name=database_name, - graph_name=graph_name, - algorithm_name='node2vec', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - - return results - - def run_hashGNN_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str='stream', - feature_properties:List[str]=[], - iterations:int=None, - embedding_density:int=None, - heterogeneous:bool=False, - neighbor_influence:float=1.0, - binarize_features:dict=None, - generate_features:dict=None, - output_dimension:int=None, - random_seed:int=42, - mutate_property:str=None): - """ - Runs the hashGNN algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/hashgnn/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - algorithm_config (dict, optional): The configuration of the algorithm. Defaults to None. - embedding_density (int, optional): The embedding density. Defaults to None. - heterogeneous (bool, optional): Whether the graph is heterogeneous. Defaults to False. - neighbor_influence (float, optional): The neighbor influence. Defaults to 1.0. 
- binarize_features (dict, optional): The binarize features. Defaults to None. - generate_features (dict, optional): The generate features. Defaults to None. - output_dimension (int, optional): The output dimension. Defaults to None. - random_seed (int, optional): The random seed. Defaults to 42. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - if iterations is None: - raise Exception("Iterations must be provided") - if embedding_density is None: - raise Exception("Embedding density must be provided") - if generate_features and feature_properties is not None: - raise Exception("Feature properties must be None when generate_features is given") - - algorithm_config={} - algorithm_config['featureProperties']=feature_properties - algorithm_config['iterations']=iterations - algorithm_config['embeddingDensity']=embedding_density - algorithm_config['heterogeneous']=heterogeneous - algorithm_config['neighborInfluence']=neighbor_influence - algorithm_config['randomSeed']=random_seed - if binarize_features: - algorithm_config['binarizeFeatures']=binarize_features - if generate_features: - algorithm_config['generateFeatures']=generate_features - if output_dimension: - algorithm_config['outputDimension']=output_dimension - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm( - database_name=database_name, - graph_name=graph_name, - algorithm_name='hashgnn', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_graphSAGE_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str='stream', - model_name:str=None, - concurrency:int=4, - batch_size:int=100, - mutate_property:str=None, - write_property:str=None - ): - """ - Runs the graphSAGE algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/graph-sage/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - model_name (str, optional): The model name. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - batch_size (int, optional): The batch size. Defaults to 100. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - if model_name is None: - raise Exception("Model name must be provided") - if write_property: - raise Exception("write_property must be None") - - algorithm_config={} - algorithm_config['modelName']=model_name - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - algorithm_config['batchSize']=batch_size - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='beta.graphSage', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_topological_link_prediction_algorithm(self, - database_name:str, - algorithm_name:str, - node_a_name:str, - node_a_type:str, - node_b_name:str, - node_b_type:str, - relationship_query:str=None, - direction:str='BOTH', - comunity_property:str='community', - ): - """ - Runs the topological link prediction algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/topological-link-prediction/ - - Args: - database_name (str): The name of the database. - node_a_name (str): The name of the node A. - node_a_type (str): The type of the node A. - node_b_name (str): The name of the node B. - node_b_type (str): The type of the node B. - relationship_query (str): The relationship query. - direction (str, optional): The direction. Defaults to 'BOTH'. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_name not in self.link_prediction_algorithms: - raise Exception(f"Algorithm name {algorithm_name} is not supported. Use one of {self.link_prediction_algorithms}") - - config={} - config['direction']=direction - if relationship_query: - config['relationshipQuery']=relationship_query - if algorithm_name=='SameCommunity': - config['communityProperty']=comunity_property - - cypher_statement=f"MATCH (p1:{node_a_type}" + "{name:" + f"{format_string(node_a_name)}" + "})\n" - cypher_statement=f"MATCH (p1:{node_b_type}" + "{name:" + f"{format_string(node_b_name)}" + "})\n" - cypher_statement=f"RETURN gds.alpha.linkprediction.{algorithm_name}(p1,p2,{format_dictionary(config)})" - results = self.neo4j_manager.query(cypher_statement,database_name=database_name) - outputs=[] - for result in results: - output={property_name:property_value for property_name,property_value in result.items()} - outputs.append(output) - return outputs - - def run_dijkstra_source_target_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - source_node:int, - target_node:int, - relationship_weight_property:str='weight', - concurrency:int=1, - write_node_ids:bool=False, - write_costs:bool=False, - write_relationship_type:str=None, - mutate_relationship_type:str=None, - ): - """ - Runs the dijkstra source target algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/dijkstra/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. 
- source_node (int, optional): The source node. Defaults to None. - target_node (int, optional): The target node. Defaults to None. - relationship_weight_property (str, optional): The relationship weight property. Defaults to 'weight'. - concurrency (int, optional): The concurrency. Defaults to 4. - write_node_ids (bool, optional): The write node ids. Defaults to False. - write_costs (bool, optional): The write costs. Defaults to False. - write_relationship_type (bool, optional): The write relationship type. Defaults to False. - mutate_relationship_type (str, optional): The mutate relationship type. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_relationship_type is None: - raise Exception("write_relationship_type must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_relationship_type is None: - raise Exception("mutate_relationship_type must be provided when algorithm_mode is mutate") - if not isinstance(source_node,int): - raise Exception("Source node must be an integer") - if not isinstance(target_node,int): - raise Exception("Target node must be an integer") - - algorithm_config={} - algorithm_config['sourceNode']=source_node - algorithm_config['targetNode']=target_node - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_node_ids: - algorithm_config['writeNodeIds']=write_node_ids - if write_costs: - algorithm_config['writeCosts']=write_costs - if write_relationship_type: - algorithm_config['writeRelationshipType']=write_relationship_type - if mutate_relationship_type: - algorithm_config['mutateRelationshipType']=mutate_relationship_type - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='shortestPath.dijkstra', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_a_star_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - source_node:int, - target_node:int, - lattitude_property:float, - longitude_property:float, - write_relationship_type:str=None, - relationship_weight_property:str='weight', - concurrency:int=1, - write_node_ids:bool=False, - write_costs:bool=False, - mutate_relationship_type:str=None, - ): - """ - Runs the a star algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/a-star/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - source_node (int, optional): The source node. Defaults to None. - target_node (int, optional): The target node. Defaults to None. - lattitude_property (float, optional): The lattitude property. Defaults to None. - longitude_property (float, optional): The longitude property. Defaults to None. - relationship_weight_property (str, optional): The relationship weight property. Defaults to 'weight'. - concurrency (int, optional): The concurrency. Defaults to 4. - write_node_ids (bool, optional): The write node ids. Defaults to False. - write_costs (bool, optional): The write costs. Defaults to False. 
- write_relationship_type (bool, optional): The write relationship type. Defaults to False. - mutate_relationship_type (str, optional): The mutate relationship type. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_relationship_type is None: - raise Exception("write_relationship_type must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_relationship_type is None: - raise Exception("mutate_relationship_type must be provided when algorithm_mode is mutate") - if not isinstance(source_node,int): - raise Exception("Source node must be an integer") - if not isinstance(target_node,int): - raise Exception("Target node must be an integer") - if not isinstance(lattitude_property,float): - raise Exception("Lattitude property must be a float") - if not isinstance(longitude_property,float): - raise Exception("Longitude property must be a float") - - algorithm_config={} - algorithm_config['sourceNode']=source_node - algorithm_config['targetNode']=target_node - algorithm_config['lattitudeProperty']=lattitude_property - algorithm_config['longitudeProperty']=longitude_property - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_node_ids: - algorithm_config['writeNodeIds']=write_node_ids - if write_costs: - algorithm_config['writeCosts']=write_costs - if write_relationship_type: - algorithm_config['writeRelationshipType']=write_relationship_type - if mutate_relationship_type: - algorithm_config['mutateRelationshipType']=mutate_relationship_type - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='shortestPath.astar', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_yens_shortest_path_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - source_node:int, - target_node:int, - k:int=1, - relationship_weight_property:str='weight', - concurrency:int=1, - write_relationship_type:str=None, - write_node_ids:bool=False, - write_costs:bool=False, - mutate_relationship_type:str=None, - ): - """ - Runs the yen's shortest path algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/yen-shortest-path/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - source_node (int, optional): The source node. Defaults to None. - target_node (int, optional): The target node. Defaults to None. - k (int, optional): The k. Defaults to 1. - relationship_weight_property (str, optional): The relationship weight property. Defaults to 'weight'. - concurrency (int, optional): The concurrency. Defaults to 4. - write_node_ids (bool, optional): The write node ids. Defaults to False. - write_costs (bool, optional): The write costs. Defaults to False. - mutate_relationship_type (str, optional): The mutate relationship type. Defaults to None. - - Returns: - list: Returns results of the algorithm. 
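A hedged usage sketch for run_yens_shortest_path_algorithm, as the wrapper is intended to be called. The instance, database, graph name, and node ids are placeholders; source_node and target_node must be the internal node ids of the projected graph.

    # Find the 3 shortest weighted paths between two node ids
    # (gds.shortestPath.yens.stream under the hood).
    paths = gds_manager.run_yens_shortest_path_algorithm(
        database_name="neo4j",
        graph_name="materials-graph",
        algorithm_mode="stream",
        source_node=0,
        target_node=42,
        k=3,
        relationship_weight_property="weight",
    )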
- """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_relationship_type is None: - raise Exception("write_relationship_type must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_relationship_type is None: - raise Exception("mutate_relationship_type must be provided when algorithm_mode is mutate") - if not isinstance(source_node,int): - raise Exception("Source node must be an integer") - if not isinstance(target_node,int): - raise Exception("Target node must be an integer") - - algorithm_config={} - algorithm_config['sourceNode']=source_node - algorithm_config['targetNode']=target_node - algorithm_config['k']=k - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_node_ids: - algorithm_config['writeNodeIds']=write_node_ids - if write_costs: - algorithm_config['writeCosts']=write_costs - if mutate_relationship_type: - algorithm_config['mutateRelationshipType']=mutate_relationship_type - if write_relationship_type: - algorithm_config['writeRelationshipType']=write_relationship_type - - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='shortestPath.yens', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_node_similarity_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str='stream', - similarity_cutoff:float=10**-42, - degree_cutoff:int=1, - upper_degree_cutoff:int=2147483647, - topK:int=10, - bottomK:int=10, - topN:int=0, - bottomN:int=0, - relationship_weight_property:str=None, - similarity_metric:str='JACCARD', - use_components:bool=False, - write_relationship_type:str=None, - write_property:str=None, - mutate_relationship_type:str=None, - mutate_property:str=None, - ): - """ - Runs the node similarity algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/node-similarity/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - similarity_cutoff (float, optional): The similarity cutoff. Defaults to 10**-42. - degree_cutoff (int, optional): The degree cutoff. Defaults to 1. - upper_degree_cutoff (int, optional): The upper degree cutoff. Defaults to 2147483647. - topK (int, optional): The top K. Defaults to 10. - bottomK (int, optional): The bottom K. Defaults to 10. - topN (int, optional): The top N. Defaults to 0. - bottomN (int, optional): The bottom N. Defaults to 0. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - similarity_metric (str, optional): The similarity metric. Defaults to 'JACCARD'. - use_components (bool, optional): The use components. Defaults to False. - write_relationship_type (str, optional): The write relationship type. Defaults to None. - write_property (str, optional): The write property. Defaults to None. - mutate_relationship_type (str, optional): The mutate relationship type. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. 
- """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and (write_property is None or write_relationship_type is None): - raise Exception("write_property and write_relationship_type must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and (mutate_property is None or mutate_relationship_type is None): - raise Exception("mutate_property and mutate_relationship_type must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['similarityCutoff']=similarity_cutoff - algorithm_config['degreeCutoff']=degree_cutoff - algorithm_config['upperDegreeCutoff']=upper_degree_cutoff - algorithm_config['topK']=topK - algorithm_config['bottomK']=bottomK - algorithm_config['topN']=topN - algorithm_config['bottomN']=bottomN - algorithm_config['similarityMetric']=similarity_metric - algorithm_config['useComponents']=use_components - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if write_relationship_type: - algorithm_config['writeRelationshipType']=write_relationship_type - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_relationship_type: - algorithm_config['mutateRelationshipType']=mutate_relationship_type - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='nodeSimilairty', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_k_nearest_neighbors_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - node_properties:Union[str,List,dict], - topK:int=10, - sample_rate:float=0.5, - delta_threshold:float=0.001, - max_iterations:int=100, - random_joins:int=10, - initial_sampler:str='uniform', - random_seed:int=42, - similarity_cutoff:float=0, - perturbation_rate:float=0.0, - concurrency:int=1, - write_relationship_type:str=None, - write_property:str=None, - write_node_ids:bool=False, - write_costs:bool=False, - mutate_relationship_type:str=None, - mutate_property:str=None, - ): - """ - Runs the k nearest neighbors algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/k-nearest-neighbors/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - node_properties (Union[str,List,dict], optional): The node properties. Defaults to None. - topK (int, optional): The top K. Defaults to 10. - sample_rate (float, optional): The sample rate. Defaults to 0.5. - delta_threshold (float, optional): The delta threshold. Defaults to 0.001. - max_iterations (int, optional): The max iterations. Defaults to 100. - random_joins (int, optional): The random joins. Defaults to 10. - initial_sampler (str, optional): The initial sampler. Defaults to 'uniform'. - random_seed (int, optional): The random seed. Defaults to 42. - similarity_cutoff (float, optional): The similarity cutoff. Defaults to 0. - perturbation_rate (float, optional): The perturbation rate. Defaults to 0.0. - concurrency (int, optional): The concurrency. Defaults to 1. - write_relationship_type (str, optional): The write relationship type. Defaults to None. - write_property (str, optional): The write property. 
Defaults to None. - write_node_ids (bool, optional): The write node ids. Defaults to False. - write_costs (bool, optional): The write costs. Defaults to False. - mutate_relationship_type (str, optional): The mutate relationship type. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and (write_property is None or write_relationship_type is None): - raise Exception("write_property and write_relationship_type must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and (mutate_property is None or mutate_relationship_type is None): - raise Exception("mutate_property and mutate_relationship_type must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['nodeProperties']=node_properties - algorithm_config['topK']=topK - algorithm_config['sampleRate']=sample_rate - algorithm_config['deltaThreshold']=delta_threshold - algorithm_config['maxIterations']=max_iterations - algorithm_config['randomJoins']=random_joins - algorithm_config['initialSampler']=initial_sampler - algorithm_config['randomSeed']=random_seed - algorithm_config['similarityCutoff']=similarity_cutoff - algorithm_config['perturbationRate']=perturbation_rate - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_node_ids: - algorithm_config['writeNodeIds']=write_node_ids - if write_costs: - algorithm_config['writeCosts']=write_costs - if write_relationship_type: - algorithm_config['writeRelationshipType']=write_relationship_type - if mutate_relationship_type: - algorithm_config['mutateRelationshipType']=mutate_relationship_type - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='knn', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_conductance_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - relationship_weight_property:str=None, - community_property:str=None, - concurrency:int=1, - ): - """ - Runs the conductance algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/conductance/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - community_property (str, optional): The community property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 1. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - - algorithm_config={} - algorithm_config['concurrency']=concurrency - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if community_property: - algorithm_config['communityProperty']=community_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='conductance', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_k_core_decomposition_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - concurrency:int=1, - mutate_property:str=None, - write_property:str=None): - """ - Runs the k core decomposition algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/k-core-decomposition/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - concurrency (int, optional): The concurrency. Defaults to 1. - write_property (str, optional): The write property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['concurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='kcore', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_k1_coloring_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - min_community_size:int=0, - max_iterations:int=10, - concurrency:int=1, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the k1 coloring algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/k1-coloring/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - min_community_size (int, optional): The min community size. Defaults to 0. - max_iterations (int, optional): The max iterations. Defaults to 10. - concurrency (int, optional): The concurrency. Defaults to 1. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and (write_property is None or mutate_property is None): - raise Exception("write_property and mutate_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and (mutate_property is None or write_property is None): - raise Exception("mutate_property and write_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['minCommunitySize']=min_community_size - algorithm_config['maxIterations']=max_iterations - algorithm_config['concurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='k1coloring', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_label_propagation_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - max_iterations:int=10, - node_weight_property:str=None, - relationship_weight_property:str=None, - seed_property:str=None, - consecutive_ids:bool=False, - min_community_size:int=0, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the label propagation algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/label-propagation/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - max_iterations (int, optional): The max iterations. Defaults to 10. - node_weight_property (str, optional): The node weight property. Defaults to None. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - seed_property (str, optional): The seed property. Defaults to None. - consecutive_ids (bool, optional): The consecutive ids. Defaults to False. - min_community_size (int, optional): The min community size. Defaults to 0. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property and mutate_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property and write_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['maxIterations']=max_iterations - if node_weight_property: - algorithm_config['nodeWeightProperty']=node_weight_property - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['seedProperty']=seed_property - algorithm_config['consecutiveIds']=consecutive_ids - algorithm_config['minCommunitySize']=min_community_size - algorithm_config['concurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='labelPropagation', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_leiden_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - max_levels:int=10, - gamma:float=1.0, - theta:float=0.01, - tolerance:float=0.0001, - include_intermediate_communities:bool=False, - seed_property:str=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the leiden algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/leiden/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - max_levels (int, optional): The max levels. Defaults to 10. - gamma (float, optional): The gamma. Defaults to 1.0. - theta (float, optional): The theta. Defaults to 0.01. - tolerance (float, optional): The tolerance. Defaults to 0.0001. - include_intermediate_communities (bool, optional): The include intermediate communities. Defaults to False. - seed_property (str, optional): The seed property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['maxLevels']=max_levels - algorithm_config['gamma']=gamma - algorithm_config['theta']=theta - algorithm_config['tolerance']=tolerance - algorithm_config['includeIntermediateCommunities']=include_intermediate_communities - algorithm_config['seedProperty']=seed_property - algorithm_config['concurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='leiden', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_local_clustering_coefficient_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - triangle_count_property:str=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the local clustering coefficient algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/local-clustering-coefficient/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - triangle_count_property (str, optional): The triangle count property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['triangleCountProperty']=triangle_count_property - algorithm_config['concurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='localClusteringCoefficient', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_louvain_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - relationship_weight_property:str=None, - seed_property:str=None, - max_levels:int=10, - tolerance:float=0.0001, - include_intermediate_communities:bool=False, - consecutive_ids:bool=False, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the louvain algorithm. 
- Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/louvain/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - seed_property (str, optional): The seed property. Defaults to None. - max_levels (int, optional): The max levels. Defaults to 10. - tolerance (float, optional): The tolerance. Defaults to 0.0001. - include_intermediate_communities (bool, optional): The include intermediate communities. Defaults to False. - consecutive_ids (bool, optional): The consecutive ids. Defaults to False. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['seedProperty']=seed_property - algorithm_config['maxLevels']=max_levels - algorithm_config['tolerance']=tolerance - algorithm_config['includeIntermediateCommunities']=include_intermediate_communities - algorithm_config['consecutiveIds']=consecutive_ids - algorithm_config['concurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='louvain', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_modularity_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - relationship_weight_property:str=None, - community_property:str=None, - concurrency:int=4): - """ - Runs the modularity algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/modularity/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - community_property (str, optional): The community property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if community_property is None: - raise Exception("Community property must be provided") - - algorithm_config={} - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if community_property: - algorithm_config['communityProperty']=community_property - algorithm_config['concurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='modularity', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_modularity_optimization_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - max_iterations:int=10, - tolerance:float=0.0001, - seed_property:str=None, - consecutive_ids:bool=False, - relationship_weight_property:str=None, - min_community_size:int=0, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the modularity optimization algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/modularity-optimization/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - max_iterations (int, optional): The max iterations. Defaults to 10. - tolerance (float, optional): The tolerance. Defaults to 0.0001. - seed_property (str, optional): The seed property. Defaults to None. - consecutive_ids (bool, optional): The consecutive ids. Defaults to False. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - min_community_size (int, optional): The min community size. Defaults to 0. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['maxIterations']=max_iterations - algorithm_config['tolerance']=tolerance - algorithm_config['seedProperty']=seed_property - algorithm_config['consecutiveIds']=consecutive_ids - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['minCommunitySize']=min_community_size - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='modularityOptimization', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_strongly_connected_components_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - consecutive_ids:bool=False, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the strongly connected components algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/strongly-connected-components/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - consecutive_ids (bool, optional): The consecutive ids. Defaults to False. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['consecutiveIds']=consecutive_ids - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='scc', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_triangle_count_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - max_degree:int=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the triangle count algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/triangle-count/ - - Args: - database_name (str): The name of the database. 
- graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - max_degree (int, optional): The max degree. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - if max_degree: - algorithm_config['maxDegree']=max_degree - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='triangleCount', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_weakly_connected_components_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - relationship_weight_property:str=None, - seed_property:str=None, - threshold:float=0.001, - consecutive_ids:bool=False, - min_component_size:int=0, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the weakly connected components algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/weakly-connected-components/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - seed_property (str, optional): The seed property. Defaults to None. - threshold (float, optional): The threshold. Defaults to 0.001. - consecutive_ids (bool, optional): The consecutive ids. Defaults to False. - min_component_size (int, optional): The min component size. Defaults to 0. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['seedProperty']=seed_property - algorithm_config['threshold']=threshold - algorithm_config['consecutiveIds']=consecutive_ids - algorithm_config['minComponentSize']=min_component_size - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='weaklyConnectedComponents', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_approximate_max_k_cut_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - k:int=2, - iterations:int=8, - vns_max_neighborhood_order:int=None, - random_seed:int=42, - relationship_weight_property:str=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the approximate max k cut algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/approximate-max-k-cut/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - k (int, optional): The k. Defaults to 2. - iterations (int, optional): The iterations. Defaults to 8. - vns_max_neighborhood_order (int, optional): The vns max neighborhood order. Defaults to None. - random_seed (int, optional): The random seed. Defaults to 42. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['k']=k - algorithm_config['iterations']=iterations - if vns_max_neighborhood_order: - algorithm_config['vnsMaxNeighborhoodOrder']=vns_max_neighborhood_order - algorithm_config['randomSeed']=random_seed - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='maxkcut', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_speaker_listener_label_propagation_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - max_iterations:int=None, - min_assocation_strength:float=0.2, - partitioning:str="RANGE", - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the speaker listener label propagation algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/speaker-listener-label-propagation/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - max_iterations (int, optional): The max iterations. Defaults to None. - min_assocation_strength (float, optional): The min assocation strength. Defaults to 0.2. - partitioning (str, optional): The partitioning. Defaults to "RANGE". - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - if max_iterations is None: - raise Exception("Max iterations must be provided") - - algorithm_config={} - algorithm_config['maxIterations']=max_iterations - algorithm_config['minAssocationStrength']=min_assocation_strength - algorithm_config['partitioning']=partitioning - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='sllpa', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_article_ranking_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - damping_factor:float=0.85, - max_iterations:int=20, - tolerance:float=0.0000001, - relationship_weight_property:str=None, - source_nodes:List=None, - scaler:Union[str,dict]=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the article ranking algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/article-ranking/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - damping_factor (float, optional): The damping factor. Defaults to 0.85. - max_iterations (int, optional): The max iterations. Defaults to 20. - tolerance (float, optional): The tolerance. Defaults to 0.0000001. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - source_nodes (List, optional): The source nodes. Defaults to None. - scaler (Union[str,dict], optional): The scaler. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['dampingFactor']=damping_factor - algorithm_config['maxIterations']=max_iterations - algorithm_config['tolerance']=tolerance - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if source_nodes: - algorithm_config['sourceNodes']=source_nodes - if scaler: - algorithm_config['scaler']=scaler - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='articleRank', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_betweenness_centrality_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - sampling_size:int=None, - sampling_seed:int=None, - relationship_weight_property:str=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the betweenness centrality algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/betweenness-centrality/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - sampling_size (int, optional): The sampling size. Defaults to None. - sampling_seed (int, optional): The sampling seed. Defaults to None. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - if sampling_size: - algorithm_config['samplingSize']=sampling_size - if sampling_seed: - algorithm_config['samplingSeed']=sampling_seed - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='betweenness', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_celf_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - seed_set_size:int=None, - monte_carlo_simulations:int=100, - propagation_probability:float=0.1, - random_seed:int=None, - concurrency:int=4,): - """ - Runs the celf algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/celf/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - seed_set_size (int, optional): The seed set size. Defaults to None. - monte_carlo_simulations (int, optional): The monte carlo simulations. Defaults to 100. - propagation_probability (float, optional): The propagation probability. Defaults to 0.1. - random_seed (int, optional): The random seed. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - - algorithm_config={} - if seed_set_size: - algorithm_config['seedSetSize']=seed_set_size - if monte_carlo_simulations: - algorithm_config['monteCarloSimulations']=monte_carlo_simulations - if propagation_probability: - algorithm_config['propagationProbability']=propagation_probability - if random_seed: - algorithm_config['randomSeed']=random_seed - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='influenceMaximization', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_closeness_centrality_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - use_wasserman_faust:bool=False, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the closeness centrality algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/closeness-centrality/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - use_wasserman_faust (bool, optional): The use wasserman faust. 
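# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of driving the community/centrality wrappers defined above.
# It assumes a running MatGraphDB DBMS, that the projected graph already exists
# in the GDS catalogue, and that Neo4jGDSManager is defined in this module (as
# in the __main__ block below); database and graph names are placeholders.
from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager

with Neo4jManager() as db:
    gds = Neo4jGDSManager(db)  # GDS wrapper defined in this file
    # 'stream' returns rows; 'write'/'mutate' additionally require a property name
    components = gds.run_weakly_connected_components_algorithm(
        database_name="elements-no-fe",
        graph_name="materials_chemenvElements",
        algorithm_mode="stream",
        min_component_size=2,
    )
    gds.run_betweenness_centrality_algorithm(
        database_name="elements-no-fe",
        graph_name="materials_chemenvElements",
        algorithm_mode="write",
        write_property="betweenness",
    )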
Defaults to False. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - - algorithm_config['useWassermanFaust']=use_wasserman_faust - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='closeness', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_degree_centrality_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - orientation:str='NATURAL', - relationship_weight_property:str=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the degree centrality algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/degree-centrality/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - orientation (str, optional): The orientation. Defaults to 'NATURAL'. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['orientation']=orientation - algorithm_config['concurrency']=concurrency - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='degree', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_eigenvector_centrality_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - max_iterations:int=20, - tolerance:float=0.0000001, - relationship_weight_property:str=None, - source_nodes:List=None, - scaler:Union[str,dict]=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the eigenvector centrality algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/eigenvector-centrality/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - max_iterations (int, optional): The max iterations. Defaults to 20. - tolerance (float, optional): The tolerance. Defaults to 0.0000001. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - source_nodes (List, optional): The source nodes. Defaults to None. - scaler (Union[str,dict], optional): The scaler. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['maxIterations']=max_iterations - algorithm_config['tolerance']=tolerance - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if source_nodes: - algorithm_config['sourceNodes']=source_nodes - if scaler: - algorithm_config['scaler']=scaler - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='eigenvector', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_page_rank_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - damping_factor:float=0.85, - max_iterations:int=20, - tolerance:float=0.0000001, - relationship_weight_property:str=None, - source_nodes:List=None, - scaler:Union[str,dict]=None, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the page rank algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/page-rank/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - damping_factor (float, optional): The damping factor. Defaults to 0.85. - max_iterations (int, optional): The max iterations. Defaults to 20. - tolerance (float, optional): The tolerance. Defaults to 0.0000001. - relationship_weight_property (str, optional): The relationship weight property. Defaults to None. - source_nodes (List, optional): The source nodes. Defaults to None. - scaler (Union[str,dict], optional): The scaler. Defaults to None. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. 
Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['dampingFactor']=damping_factor - algorithm_config['maxIterations']=max_iterations - algorithm_config['tolerance']=tolerance - if relationship_weight_property: - algorithm_config['relationshipWeightProperty']=relationship_weight_property - if source_nodes: - algorithm_config['sourceNodes']=source_nodes - if scaler: - algorithm_config['scaler']=scaler - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='pageRank', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_harmonic_centrality_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the harmonic centrality algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/harmonic-centrality/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. Defaults to 'stream'. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='closeness.harmonic', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - - def run_hits_centrality_algorithm(self, - database_name:str, - graph_name:str, - algorithm_mode:str, - hits_iterations:int=20, - auth_property:str=None, - hub_property:str=None, - partitioning:str='AUTO', - concurrency:int=4, - write_property:str=None, - mutate_property:str=None, - ): - """ - Runs the hits centrality algorithm. - Useful urls: - https://neo4j.com/docs/graph-data-science/current/machine-learning/hits-centrality/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - algorithm_mode (str, optional): The mode of the algorithm. 
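# --- Usage sketch (not part of the original module) ---------------------------
# The PageRank-family wrappers above (articleRank / eigenvector / pageRank) all
# funnel into run_algorithm; a 'write' call like the one below expands to
# roughly:
#   CALL gds.pageRank.write('materials_chemenvElements',
#       {dampingFactor: 0.85, maxIterations: 20, tolerance: 1e-7,
#        writeProperty: 'pagerank', concurrency: 4, writeConcurrency: 4})
# Names are placeholders; `gds` is the Neo4jGDSManager from the earlier sketch.
scores = gds.run_page_rank_algorithm(
    database_name="elements-no-fe",
    graph_name="materials_chemenvElements",
    algorithm_mode="write",
    write_property="pagerank",
)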
Defaults to 'stream'. - hits_iterations (int, optional): The hits iterations. Defaults to 20. - auth_property (str, optional): The auth property. Defaults to None. - hub_property (str, optional): The hub property. Defaults to None. - partitioning (str, optional): The partitioning. Defaults to 'AUTO'. - concurrency (int, optional): The concurrency. Defaults to 4. - write_property (str, optional): The write property. Defaults to None. - mutate_property (str, optional): The mutate property. Defaults to None. - - Returns: - list: Returns results of the algorithm. - """ - - if algorithm_mode not in self.algorithm_modes: - raise Exception(f"Algorithm mode {algorithm_mode} is not supported. Use one of {self.algorithm_modes}") - if algorithm_mode=='write' and write_property is None: - raise Exception("write_property must be provided when algorithm_mode is write") - if algorithm_mode=='mutate' and mutate_property is None: - raise Exception("mutate_property must be provided when algorithm_mode is mutate") - - algorithm_config={} - algorithm_config['hitsIterations']=hits_iterations - algorithm_config['authProperty']=auth_property - algorithm_config['hubProperty']=hub_property - algorithm_config['partitioning']=partitioning - algorithm_config['concurrency']=concurrency - if algorithm_mode=='write': - algorithm_config['writeConcurrency']=concurrency - if write_property: - algorithm_config['writeProperty']=write_property - if mutate_property: - algorithm_config['mutateProperty']=mutate_property - - results=self.run_algorithm(self, - database_name=database_name, - graph_name=graph_name, - algorithm_name='hits', - algorithm_mode=algorithm_mode, - algorithm_config=algorithm_config) - return results - -if __name__ == "__main__": - - from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager - - # from matgraphdb.graph.graph_generator import GraphGenerator - # generator=GraphGenerator(from_scratch=False) - # main_graph_dir=generator.main_graph_dir - # sub_graph_names=generator.list_sub_graphs() - # print(sub_graph_names) - # sub_graph_paths=[os.path.join(main_graph_dir,'sub_graphs',sub_graph_name) for sub_graph_name in sub_graph_names] - - # with Neo4jGraphDatabase(from_scratch=True) as manager: - # for path in sub_graph_paths: - # # print(manager.list_databases()) - # results=manager.load_graph_database_into_neo4j(path) - # print(results) - - - # with Neo4jGraphDatabase() as matgraphdb: - # database_name='nelements-no-fe' - # graph_dir=os.path.join('data','production','materials_project','graph_database') - # # settings={'apoc.export.file.enabled':'true'} - # # matgraphdb.set_apoc_environment_variables(settings=settings) - # matgraphdb.export_database(graph_dir,database_name=database_name) - # with Neo4jGraphDatabase() as matgraphdb: - # # database_name='nelements-1-2' - # manager=Neo4jGDSManager(matgraphdb) - # # print(manager.list_graphs(database_name)) - # # print(manager.is_graph_in_memory(database_name,'materials_chemenvElements')) - # # # print(manager.list_graph_data_science_algorithms(database_name)) - # database_name='elements-no-fe' - # graph_name='materials_chemenvElements' - # # node_projections=['ChemenvElement','Material'] - # # relationship_projections={ - # # "GEOMETRIC_ELECTRIC_CONNECTS": { - # # "orientation": 'UNDIRECTED', - # # "properties": 'weight' - # # }, - # # "COMPOSED_OF": { - # # "orientation": 'UNDIRECTED', - # # "properties": 'weight' - # # } - # # } - - # # # print(format_dictionary(relationship_projections)) - # # 
manager.load_graph_into_memory(database_name=database_name, - # # graph_name=graph_name, - # # node_projections=node_projections, - # # relationship_projections=relationship_projections) - # # print(manager.get_graph_info(database_name=database_name,graph_name=graph_name)) - # print(manager.list_graphs(database_name)) - # print(manager.is_graph_in_memory(database_name,graph_name)) - # print(manager.drop_graph(database_name,graph_name)) - # print(manager.list_graphs(database_name)) - # print(manager.is_graph_in_memory(database_name,graph_name)) - - # print(manager.get_graph_info(database_name=database_name,graph_name=graph_name)) - # results=manager.read_graph(database_name=database_name, - # graph_name=graph_name, - # node_properties=["fastrp-embedding"], - # node_labels=['Material']) - - # print(results) - - - # result=manager.estimate_memeory_for_algorithm(database_name=database_name, - # graph_name=graph_name, - # algorithm_name='fastRP', - # model='stream', - # algorithm_config={'embeddingDimension':128}) - # result=manager.run_algorithm(database_name=database_name, - # graph_name=graph_name, - # algorithm_name='fastRP', - # algorithm_mode='stats', - # algorithm_config={'embeddingDimension':128}) - # print(result) - - - - - - - # with GraphDatabase() as session: - # # result = matgraphdb.execute_query(query, parameters) - # schema_list=session.list_schema() - - # results=session.read_material(material_ids=['mp-1000','mp-1001'], - # elements=['Te','Ba']) - # results=session.read_material(material_ids=['mp-1000','mp-1001'], - # elements=['Te','Ba'], - # crystal_systems=['cubic']) - # results=session.read_material(material_ids=['mp-1000','mp-1001'], - # elements=['Te','Ba'], - # crystal_systems=['hexagonal']) - # results=session.read_material(material_ids=['mp-1000','mp-1001'], - # elements=['Te','Ba'], - # hall_symbols=['Fm-3m']) - # results=session.read_material(material_ids=['mp-1000','mp-1001'], - # elements=['Te','Ba'], - # band_gap=[(1.0,'>')]) - - # results=session.create_material(composition="BaTe") - - # print(results) - # print(schema_list) -# prompt = "What are materials similar to the composition TiAu" -# # prompt = "What are materials with TiAu" -# # prompt = "What are materials with TiAu" - -# # prompt = "What are some cubic materials" -# # # prompt = "What are some materials with a large band gap?" -# prompt = "What are materials with a band_gap greater than 1.0?" 
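# --- Usage sketch (not part of the original module) ---------------------------
# A condensed version of the commented walkthrough above: project a graph into
# the GDS catalogue, run FastRP, then drop the projection.  Database, graph and
# projection names mirror the commented example and are placeholders.
from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager

with Neo4jManager() as db:
    gds = Neo4jGDSManager(db)
    gds.load_graph_into_memory(
        database_name="elements-no-fe",
        graph_name="materials_chemenvElements",
        node_projections=["ChemenvElement", "Material"],
        relationship_projections={
            "COMPOSED_OF": {"orientation": "UNDIRECTED", "properties": "weight"},
        },
    )
    embeddings = gds.run_algorithm(
        database_name="elements-no-fe",
        graph_name="materials_chemenvElements",
        algorithm_name="fastRP",
        algorithm_mode="stream",
        algorithm_config={"embeddingDimension": 128},
    )
    gds.drop_graph("elements-no-fe", "materials_chemenvElements")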
-# results=session.execute_llm_query(prompt,n_results=10) - -# for result in results: -# print(result['sm']["name"]) -# print(result['sm']["formula_pretty"]) -# print(result['sm']["symmetry"]) -# print(result['score']) -# # print(results['sm']["band_gap"]) -# print("_"*200) - - - - - ################################################################################################## - # # Loading and unloading graphs into GDS - ################################################################################################## - with Neo4jManager() as matgraphdb: - database_name='elements-no-fe' - graph_name='materials_chemenvElements' - manager=Neo4jGDSManager(matgraphdb) - print(manager.list_graphs(database_name)) - print(manager.is_graph_in_memory(database_name,'materials_chemenvElements')) - print(manager.drop_graph(database_name,graph_name)) - # graph_name='materials_chemenvElements' - # node_projections={ - # "ChemenvElement":{ - # "label":'ChemenvElement', - # }, - # "Material":{ - # "label":'Material', - # "properties":['band_gap','formation_energy_per_atom','energy_per_atom','energy_above_hull','k_vrh','g_vrh'] - # } - # } - # relationship_projections = { - # "GEOMETRIC_ELECTRIC_CONNECTS": { - # "orientation": 'UNDIRECTED', - # "properties": 'weight' - # }, - # "COMPOSED_OF": { - # "orientation": 'UNDIRECTED', - # "properties": 'weight' - # } - # } - # manager.load_graph_into_memory(database_name=database_name, - # graph_name=graph_name, - # node_projections=node_projections, - # relationship_projections=relationship_projections) - - ################################################################################################## - # # Testing Neo4jDLManager - ################################################################################################## - # with MatGraphDB() as matgraphdb: - # database_name='nelements-1-2' - # manager=Neo4jDLManager(matgraphdb) - # print(manager.list_graphs(database_name)) - # print(manager.is_graph_in_memory(database_name,'materials_chemenvElements')) - # results=manager.list_graph_data_science_algorithms(database_name,save=True) - # for result in results: - # # print the reuslts in two columns - # print(result[0],'|||||||||||||||',result[1]) \ No newline at end of file diff --git a/matgraphdb/graph_kit/neo4j/neo4j_manager.py b/matgraphdb/graph_kit/neo4j/neo4j_manager.py deleted file mode 100644 index a7a2c12..0000000 --- a/matgraphdb/graph_kit/neo4j/neo4j_manager.py +++ /dev/null @@ -1,889 +0,0 @@ -import os -import json -from typing import List, Tuple, Union -from glob import glob - -from neo4j import GraphDatabase -import pandas as pd - -from matgraphdb.data.manager import DBManager -from matgraphdb.utils import (PASSWORD,USER,LOCATION,DBMSS_DIR, GRAPH_DIR, LOGGER,MP_DIR) -from matgraphdb.utils.general_utils import get_os -from matgraphdb.graph_kit.neo4j.utils import get_similarity_query, format_projection,format_dictionary,format_list,format_string - - -# TODO: Think of way to store new node and relationship properties. -# TODO: For material nodes, we can use DB Manager to store properties back into json database -# TODO: Created method to export node and relationship properties to csv file it does not have the exact same format as the original file -# TODO: FIX FastRP algorithm - - -class Neo4jManager: - - def __init__(self, graph_dir=None, db_manager=DBManager(), uri=LOCATION, user=USER, password=PASSWORD, from_scratch=False): - """ - Initializes a MatGraphDB object. - - Args: - graph_dir (str): The directory where the graph is located. 
- uri (str): The URI of the graph database. - user (str): The username for authentication. - password (str): The password for authentication. - from_scratch (bool): Whether to create a new database or use an existing one. - """ - - self.uri = uri - self.user = user - self.password = password - self.driver = None - self.dbms_dir = None - self.dbms_json = None - self.db_manager = db_manager - self.graph_dir=graph_dir - self.from_scratch=from_scratch - self.get_dbms_dir() - self.neo4j_admin_path=None - self.neo4j_cypher_shell_path=None - self.get_neo4j_tools_path() - if graph_dir: - self.load_graph_database_into_neo4j(graph_dir) - - def __enter__(self): - """ - Enter method for using the graph database as a context manager. - - This method is called when entering a `with` statement and is responsible for setting up any necessary resources - or connections. In this case, it creates a driver for the graph database and returns the current instance. - - Returns: - self: The current instance of the `GraphDatabase` class. - - Example: - with GraphDatabase() as graph_db: - # Perform operations on the graph database - """ - self.create_driver() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Context manager exit method. - - This method is called when exiting the context manager. It is responsible for closing the database connection. - - Args: - exc_type (type): The type of the exception raised, if any. - exc_val (Exception): The exception raised, if any. - exc_tb (traceback): The traceback object associated with the exception, if any. - """ - self.close() - - def create_driver(self): - """ - Creates a driver object for connecting to the graph database. - - Returns: - None - """ - self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) - - def close(self): - """ - Closes the database connection. - - This method closes the connection to the database if it is open. - """ - if self.driver: - self.driver.close() - self.driver=None - - def list_schema(self): - """ - Retrieves the schema of the graph database. - - Returns: - schema_list (list): A list of strings representing the schema of the graph database. 
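# --- Usage sketch (not part of the original module) ---------------------------
# Neo4jManager is meant to be used as a context manager: __enter__ creates the
# driver and __exit__ closes it, with connection details defaulting to the
# package-level LOCATION/USER/PASSWORD constants imported above.  The printed
# contents depend on whatever databases exist in the DBMS.
from matgraphdb.graph_kit.neo4j.neo4j_manager import Neo4jManager

with Neo4jManager() as db:
    print(db.list_databases())  # user databases in the MatGraphDB DBMS
    print(db.list_schema())     # node/relationship schema strings for the graph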
- """ - schema_list=[] - - # Query for node labels and properties - labels_query = "CALL db.schema.nodeTypeProperties()" - labels_results = self.execute_query(labels_query) - node_and_properties = {} - for record in labels_results: - - node_type = record["nodeType"] - propert_name = record["propertyName"] - - if isinstance(record["propertyTypes"], list): - property_type = record["propertyTypes"][0] - else: - property_type = record["propertyTypes"] - property_type=property_type.replace("`",'').replace("String",'str').replace('Integer','int').replace('Float','float') - if node_type not in node_and_properties: - node_and_properties[node_type] = {propert_name : property_type} - else: - node_and_properties[node_type].update({propert_name : property_type}) - - # Query for relationship types - labels_query = "CALL db.schema.visualization()" - labels_results = self.execute_query(labels_query) - - for i,record in enumerate(labels_results): - - # Getting node types and names - try: - node_1_type=record["nodes"][0]['name'] - node_2_type=record["nodes"][1]['name'] - node_1_name=f':`{node_1_type}`' - node_2_name=f':`{node_2_type}`' - except: - raise Exception("Only one node in this graph gb") - - # Adding indexes and contraints - # node_and_properties[node_1_name].update({'indexes' : type(record["nodes"][0]._properties['indexes']).__name__}) - # node_and_properties[node_2_name].update({'indexes' : type(record["nodes"][1]._properties['indexes']).__name__}) - - # node_and_properties[node_1_name].update({'constraints' : type(record["nodes"][0]._properties['constraints']).__name__}) - # node_and_properties[node_2_name].update({'constraints' : type(record["nodes"][1]._properties['constraints']).__name__}) - - node_and_properties[node_1_name].update({'indexes' : record["nodes"][0]._properties['indexes']}) - node_and_properties[node_2_name].update({'indexes' : record["nodes"][1]._properties['indexes']}) - - node_and_properties[node_1_name].update({'constraints' : record["nodes"][0]._properties['constraints']}) - node_and_properties[node_2_name].update({'constraints' : record["nodes"][1]._properties['constraints']}) - - # Get relationship infor for all relationships - for relationship in record["relationships"]: - - # Get start and end node names - start_node=relationship.start_node - end_node=relationship.end_node - - start_node_name=f':`{start_node._properties["name"]}`' - end_node_name=f':`{end_node._properties["name"]}`' - - # Get relationship type - relationship_type = relationship.type - - # Get the relationship properties - query_relationship=f'MATCH ({start_node_name})-[r:`{relationship_type}`]-({end_node_name}) RETURN r LIMIT 1' - try: - relationship = self.execute_query(query_relationship)[0][0] - - relationship_properties = {} - for key, value in relationship._properties.items(): - relationship_properties[key] = type(value).__name__ - - # Create the final schema - query_relationship=f'({start_node_name} {node_and_properties[node_1_name]} )-[r:`{relationship_type}` {relationship_properties}]-({end_node_name} {node_and_properties[node_2_name]}) ' - schema_list.append(query_relationship) - except: - continue - - return schema_list - - def get_dbms_dir(self): - """ - Returns the directory where the database management system (DBMS) is located. - - Returns: - str: The directory where the DBMS is located. 
- """ - dbmss_dirs=glob(os.path.join(DBMSS_DIR,'*')) - for dbms_dir in dbmss_dirs: - relate_json=os.path.join(dbms_dir,'relate.dbms.json') - with open(relate_json,'r') as f: - dbms_info=json.loads(f.read()) - dbms_name=dbms_info['name'].split()[0] - if dbms_name=='MatGraphDB': - self.dbms_dir=dbms_dir - self.dbms_json=relate_json - return self.dbms_dir - if self.dbms_dir is None: - raise Exception("MatGraphDB DBMS is not found. Please create a new DBMS with the name 'MatGraphDB'") - - def get_neo4j_tools_path(self): - """ - Returns the path to the Neo4j tools. - - Returns: - str: The path to the Neo4j tools. - """ - if get_os()=='Windows': - self.neo4j_admin_path=os.path.join(self.dbms_dir,'bin','neo4j-admin.bat') - self.neo4j_cypher_shell_path=os.path.join(self.dbms_dir,'bin','cypher-shell.bat') - else: - self.neo4j_admin_path=os.path.join(self.dbms_dir,'bin','neo4j-admin') - self.neo4j_cypher_shell_path=os.path.join(self.dbms_dir,'bin','cypher-shell') - return self.neo4j_admin_path,self.neo4j_cypher_shell_path - - def get_load_statments(self,database_path): - """ - Returns the load statement. - - Returns: - str: The load statement for nodes. - """ - node_statement=" --nodes" - node_files=glob(os.path.join(database_path,'nodes','*.csv')) - for node_file in node_files: - node_statement+=f" \"{node_file}\"" - relationship_files=glob(os.path.join(database_path,'relationships','*.csv')) - relationship_statement=" --relationships" - for relationship_file in relationship_files: - relationship_statement+=f" \"{relationship_file}\"" - statement=node_statement+relationship_statement - return statement - - def list_databases(self): - """ - Returns a list of databases in the graph database. - - Returns: - bool: True if the graph database exists, False otherwise. - """ - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - results=self.execute_query("SHOW DATABASES",database_name='system') - names=[] - for result in results: - graph_name=result['name'] - if graph_name not in ['neo4j','system']: - names.append(graph_name.lower()) - return names - - def create_database(self,database_name): - """ - Creates a new database in the graph database. - - Args: - database_name (str): The name of the database to create. - """ - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - self.execute_query(f"CREATE DATABASE `{database_name}`",database_name='system') - - def remove_database(self,database_name): - """ - Removes a database from the graph database. - - Args: - database_name (str): The name of the database to remove. - """ - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - self.execute_query(f"DROP DATABASE `{database_name}`",database_name='system') - - def remove_node(self,database_name,node_type,node_properties=None): - """ - Removes a node from the graph database. - - Args: - node_type (str): The type of the node to remove. - node_properties (dict, optional): The properties of the node to remove. Defaults to None. - """ - - if self.driver is None: - raise Exception("Graph database is not connected. 
Please connect to the database first.") - - if node_properties is None: - node_properties="" - else: - node_properties=format_dictionary(node_properties) - - cypher_statement=f"MATCH (n:{node_type} {node_properties}) " - cypher_statement+=f"DETACH DELETE n" - results=self.query(cypher_statement,database_name=database_name) - return results - - def remove_relationship(self,database_name,relationship_type,relationship_properties=None): - """ - Removes a relationship from the graph database. - - Args: - relationship_type (str): The type of the relationship to remove. - relationship_properties (dict, optional): The properties of the relationship to remove. Defaults to None. - """ - - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - - if relationship_properties is None: - relationship_properties="" - else: - relationship_properties=format_dictionary(relationship_properties) - - cypher_statement=f"MATCH ()-[r:{relationship_type} {relationship_properties}]-() " - cypher_statement+=f"DETACH DELETE r" - results=self.query(cypher_statement,database_name=database_name) - return results - - def load_graph_database_into_neo4j(self,database_path,new_database_name=None): - """ - Loads a graph database into Neo4j. - - Args: - graph_datbase_path (str): The path to the graph database to load. - """ - if new_database_name: - database_name=new_database_name - else: - database_name=os.path.basename(database_path) - database_path=os.path.join(database_path) - db_names=self.list_databases() - database_name=database_name.lower() - if self.from_scratch and database_name in db_names: - print(f"Removing database {database_name}") - self.remove_database(database_name) - db_names=self.list_databases() - - if database_name in db_names: - raise Exception(f"Graph database {database_name} already exists. " - "It must be removed before loading. " - "Set from_scratch=True to force a new database to be created.") - - - import_statment=f'{self.neo4j_admin_path} database import full' - load_statment=self.get_load_statments(database_path) - import_statment+=load_statment - import_statment+=f" --overwrite-destination {database_name}" - # Execute the import statement - os.system(import_statment) - - self.create_database(database_name) - return None - - def does_property_exist(self,database_name,property_name,node_type=None,relationship_type=None): - """ - Checks if a property exists in a graph database. - - Args: - database_name (str): The name of the database. - node_type (str): The type of the node. - property_name (str): The name of the property. - - Returns: - bool: True if the property exists, False otherwise. 
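# --- Usage sketch (not part of the original module) ---------------------------
# The deletion helpers above build Cypher of the form
#   MATCH (n:<node_type> {<node_properties>}) DETACH DELETE n
# (and the analogous pattern for relationships).  Reusing the Neo4jManager
# context manager from the earlier sketch; labels and values are placeholders.
with Neo4jManager() as db:
    db.remove_node(
        database_name="elements-no-fe",
        node_type="Material",
        node_properties={"material_id": "mp-1000"},
    )
    db.remove_relationship(
        database_name="elements-no-fe",
        relationship_type="COMPOSED_OF",
    )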
- """ - if node_type and relationship_type: - raise Exception("Both node_type and relationship_type cannot be provided at the same time") - if node_type is None and relationship_type is None: - raise Exception("Either node_type or relationship_type must be provided") - if node_type: - cypher_statement=f"MATCH (n:`{node_type}`)\n" - cypher_statement+=f"WHERE n.`{property_name}` IS NOT NULL\n" - cypher_statement+=f"RETURN n LIMIT 1" - if relationship_type: - cypher_statement=f"MATCH ()-[r:`{relationship_type}`]-()\n" - cypher_statement+=f"WHERE r.`{property_name}` IS NOT NULL\n" - cypher_statement+=f"RETURN r LIMIT 1" - results=self.query(cypher_statement,database_name=database_name) - if len(results)==0: - return False - return True - - def remove_property(self,database_name,property_name,node_type=None,relationship_type=None): - """ - Removes a property from a graph database. - - Args: - database_name (str): The name of the database. - node_type (str): The type of the node. - property_name (str): The name of the property. - - Returns: - None - """ - if node_type and relationship_type: - raise Exception("Both node_type and relationship_type cannot be provided at the same time") - if node_type is None and relationship_type is None: - raise Exception("Either node_type or relationship_type must be provided") - - if node_type: - cypher_statement=f"MATCH (n:{node_type})" - cypher_statement+=f"DETACH DELETE n.`{property_name}`" - cypher_statement+=f"RETURN n" - if relationship_type: - cypher_statement=f"MATCH ()-[r:{relationship_type}]-()" - cypher_statement+=f"DETACH DELETE r.`{property_name}`" - cypher_statement+=f"RETURN r" - self.query(cypher_statement,database_name=database_name) - return None - - def execute_query(self, query, database_name, parameters=None): - """ - Executes a query on the graph database. - - Args: - query (str): The Cypher query to execute. - database_name (str): The name of the database to execute the query on. Defaults to None. - parameters (dict, optional): Parameters to pass to the query. Defaults to None. - - Returns: - list: A list of records returned by the query. - """ - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - - with self.driver.session(database=database_name) as session: - results = session.run(query, parameters) - return [record for record in results] - - def query(self, query, database_name, parameters=None): - """ - Executes a query on the graph database. - - Args: - query (str): The Cypher query to execute. - database_name (str): The name of the database to execute the query on. Defaults to None. - parameters (dict, optional): Parameters to pass to the query. Defaults to None. - - Returns: - list: A list of records returned by the query. - """ - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - with self.driver.session(database=database_name) as session: - results = session.run(query, parameters) - return [record for record in results] - - def execute_llm_query(self, prompt, database_name, n_results=5): - """ - Executes a query in the graph database using the LLM (Language-Modeling) approach. - - Args: - prompt (str): The prompt for the query. - database_name (str): The name of the database to execute the query on. - n_results (int, optional): The number of results to return. Defaults to 5. - - Returns: - list: A list of records returned by the query. 
- """ - if self.driver is None: - raise Exception("Graph database is not connected. Please connect to the database first.") - embedding, execute_statement = get_similarity_query(prompt) - parameters = {"embedding": embedding, "nresults": n_results} - - with self.driver.session(database=database_name) as session: - results = session.run(execute_statement, parameters) - - return [record for record in results] - - def read_material(self, - material_ids:List[str]=None, - elements:List[str]=None, - crystal_systems:List[str]=None, - magentic_states:List[str]=None, - hall_symbols:List[str]=None, - point_groups:List[str]=None, - band_gap:List[Tuple[float,str]]=None, - vbm:List[Tuple[float,str]]=None, - k_vrh:List[Tuple[float,str]]=None, - k_voigt:List[Tuple[float,str]]=None, - k_reuss:List[Tuple[float,str]]=None, - - g_vrh:List[Tuple[float,str]]=None, - g_voigt:List[Tuple[float,str]]=None, - g_reuss:List[Tuple[float,str]]=None, - universal_anisotropy:List[Tuple[float,str]]=None, - - density_atomic:List[Tuple[float,str]]=None, - density:List[Tuple[float,str]]=None, - e_ionic:List[Tuple[float,str]]=None, - - e_total:List[Tuple[float,str]]=None, - - energy_per_atom:List[Tuple[float,str]]=None, - - compositons:List[str]=None): - """ - Retrieves materials from the database based on specified criteria. - - Args: - material_ids (List[str], optional): List of material IDs to filter the results. Defaults to None. - elements (List[str], optional): List of elements to filter the results. Defaults to None. - crystal_systems (List[str], optional): List of crystal systems to filter the results. Defaults to None. - magentic_states (List[str], optional): List of magnetic states to filter the results. Defaults to None. - hall_symbols (List[str], optional): List of Hall symbols to filter the results. Defaults to None. - point_groups (List[str], optional): List of point groups to filter the results. Defaults to None. - band_gap (List[Tuple[float,str]], optional): List of tuples representing the band gap values and comparison operators to filter the results. Defaults to None. - vbm (List[Tuple[float,str]], optional): List of tuples representing the valence band maximum values and comparison operators to filter the results. Defaults to None. - k_vrh (List[Tuple[float,str]], optional): List of tuples representing the K_vrh values and comparison operators to filter the results. Defaults to None. - k_voigt (List[Tuple[float,str]], optional): List of tuples representing the K_voigt values and comparison operators to filter the results. Defaults to None. - k_reuss (List[Tuple[float,str]], optional): List of tuples representing the K_reuss values and comparison operators to filter the results. Defaults to None. - g_vrh (List[Tuple[float,str]], optional): List of tuples representing the G_vrh values and comparison operators to filter the results. Defaults to None. - g_voigt (List[Tuple[float,str]], optional): List of tuples representing the G_voigt values and comparison operators to filter the results. Defaults to None. - g_reuss (List[Tuple[float,str]], optional): List of tuples representing the G_reuss values and comparison operators to filter the results. Defaults to None. - universal_anisotropy (List[Tuple[float,str]], optional): List of tuples representing the universal anisotropy values and comparison operators to filter the results. Defaults to None. - density_atomic (List[Tuple[float,str]], optional): List of tuples representing the atomic density values and comparison operators to filter the results. 
Defaults to None. - density (List[Tuple[float,str]], optional): List of tuples representing the density values and comparison operators to filter the results. Defaults to None. - e_ionic (List[Tuple[float,str]], optional): List of tuples representing the ionic energy values and comparison operators to filter the results. Defaults to None. - e_total (List[Tuple[float,str]], optional): List of tuples representing the total energy values and comparison operators to filter the results. Defaults to None. - energy_per_atom (List[Tuple[float,str]], optional): List of tuples representing the energy per atom values and comparison operators to filter the results. Defaults to None. - compositons (List[str], optional): List of compositions to filter the results. Defaults to None. - - Returns: - results: The materials that match the specified criteria. - """ - - query = f"MATCH (m:Material) WHERE " - conditional_query = "" - - if material_ids: - conditional_query += f"m.material_id IN {material_ids}" - if elements: - for i,element in enumerate(elements): - if len(conditional_query)!=0: - conditional_query += " AND " - conditional_query += f"'{element}' IN m.elements" - if crystal_systems: - if len(conditional_query)!=0: - conditional_query += " AND " - conditional_query += f"m.crystal_system IN {crystal_systems}" - if magentic_states: - if len(conditional_query)!=0: - conditional_query += " AND " - conditional_query += f"m.ordering IN {magentic_states}" - if hall_symbols: - if len(conditional_query)!=0: - conditional_query += " AND " - conditional_query += f"m.hall_symbol IN {hall_symbols}" - if point_groups: - if len(conditional_query)!=0: - conditional_query += " AND " - conditional_query += f"m.point_group IN {point_groups}" - - if band_gap: - for bg in band_gap: - if len(conditional_query)!=0: - conditional_query += " AND " - value=bg[0] - comparison_operator=bg[1] - condition_string=f"m.band_gap {comparison_operator} {value}" - conditional_query += condition_string - - query += conditional_query - query +=" RETURN m" - results = self.execute_query(query) - return results - - def create_vector_index(self, - database_name:str, - property_dimensions:int, - similarity_function:str, - node_type:str=None, - node_property_name:str=None, - relationship_type:str=None, - relationship_property_name:str=None, - index_name:str=None, - ): - """ - Creates a vector index on a graph for either a node or a relationship. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/vector-index/ - - Args: - database_name (str): The name of the database. - property_dimensions (int): The number of dimensions in the vector. - similarity_function (str): The similarity function to use. - node_type (str,optional): The type of the nodes. This is optional and if not provided, - node_property_name (str, optional): The name of the node property. This is optional and if not provided, - relationship_type (str, optional): The type of the relationships. This is optional and if not provided, - relationship_property_name (str, optional): The name of the relationship property. This is optional and if not provided, - index_name (str, optional): The name of the index. This is optional and if not provided, - the default property name is used. 
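# --- Usage sketch (not part of the original module) ---------------------------
# Creating a node vector index for FastRP embeddings.  Index, label and
# property names are placeholders; the call expands to roughly:
#   CREATE VECTOR INDEX `material-fastrp` IF NOT EXISTS
#   FOR (n:`Material`) ON (n.`fastrp-embedding`)
#   OPTIONS {indexConfig: {`vector.dimensions`: 128,
#                          `vector.similarity_function`: 'cosine'}}
with Neo4jManager() as db:
    db.create_vector_index(
        database_name="elements-no-fe",
        property_dimensions=128,
        similarity_function="cosine",
        node_type="Material",
        node_property_name="fastrp-embedding",
        index_name="material-fastrp",
    )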
- - Returns: - None - """ - - if node_type is None and node_property_name is None and relationship_type is None and relationship_property_name is None: - raise Exception("Either node_type and node_property_name or relationship_type and relationship_property_name must be provided") - if node_type is None and node_property_name is not None: - raise Exception("node_type must be provided if node_property_name is provided") - if relationship_type is None and relationship_property_name is not None: - raise Exception("relationship_type must be provided if relationship_property_name is provided") - if node_property_name is None and node_type is not None: - raise Exception("node_property_name must be provided if node_type is provided") - if relationship_property_name is None and relationship_type is not None: - raise Exception("relationship_property_name must be provided if relationship_type is provided") - - config={} - config['vector.dimensions']=property_dimensions - config['vector.similarity_function']=similarity_function - - cypher_statement=f"CREATE VECTOR INDEX {format_string(index_name)} IF NOT EXISTS" - - if node_type is not None: - cypher_statement+=f"FOR (n :{format_string(node_type)}) ON (n.{format_string(node_property_name)})" - - if relationship_type is not None: - cypher_statement+=f"FOR ()-[r :{format_string(relationship_type)}]-() ON (r.{format_string(relationship_property_name)})" - - cypher_statement+=" OPTIONS {indexConfig:" - cypher_statement+=f"{format_dictionary(config)}" - cypher_statement+="}" - - results=self.query(cypher_statement,database_name=database_name) - - outputs=[] - for result in results: - output={ key:value for key, value in result.items()} - outputs.append(output) - return outputs - - def check_vector_index(self, - database_name:str, - index_name:str, - ): - """ - Checks if a vector index exists on a graph. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/vector-index/ - - Args: - database_name (str): The name of the database. - index_name (str): The name of the index. - - Returns: - bool: True if the index exists, False otherwise. - """ - cypher_statement=f"CALL db.index.list('{format_string(index_name)}')" - results=self.query(cypher_statement,database_name=database_name) - if len(results)!=0: - return True - return False - - def drop_vector_index(self, - database_name:str, - index_name:str, - ): - """ - Drops a vector index on a graph. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/vector-index/ - - Args: - database_name (str): The name of the database. - index_name (str): The name of the index. - - Returns: - None - """ - cypher_statement=f"CALL db.index.drop('{format_string(index_name)}')" - self.query(cypher_statement,database_name=database_name) - return None - - def query_vector_index(self, - database_name:str, - graph_name:str, - index_name:str, - nearest_neighbors:int=10, - node_type:str=None, - node_property_name:str=None, - node_properties:dict=None, - relationship_type:str=None, - relationship_property_name:str=None, - relationship_properties:dict=None, - ): - """ - Queries a vector index on a graph for either a node or a relationship. - https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/vector-index/ - - Args: - database_name (str): The name of the database. - graph_name (str): The name of the graph. - index_name (str): The name of the index. - nearest_neighbors (int, optional): The number of nearest neighbors to return. Defaults to 10. 
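# --- Usage sketch (not part of the original module) ---------------------------
# Nearest-neighbour lookup against the vector index created above, backed by
# db.index.vector.queryNodes.  Names and the anchor material are placeholders.
with Neo4jManager() as db:
    neighbours = db.query_vector_index(
        database_name="elements-no-fe",
        graph_name="materials_chemenvElements",
        index_name="material-fastrp",
        nearest_neighbors=10,
        node_type="Material",
        node_property_name="fastrp-embedding",
        node_properties={"material_id": "mp-1000"},
    )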
- node_type (str,optional): The type of the nodes. This is optional and if not provided, - node_property_name (str, optional): The name of the node property for which the vector index is queried. - relationship_type (str, optional): The type of the relationships. This is optional and if not provided, - relationship_property_name (str, optional): The name of the relationship property. This is optional and if not provided, - - Returns: - list: A list of tuples representing the query results. - """ - if node_type is None and node_property_name is None and relationship_type is None and relationship_property_name is None: - raise Exception("Either node_type and node_property_name or relationship_type and relationship_property_name must be provided") - if node_type is None and node_property_name is not None: - raise Exception("node_type must be provided if node_property_name is provided") - if relationship_type is None and relationship_property_name is not None: - raise Exception("relationship_type must be provided if relationship_property_name is provided") - if node_property_name is None and node_type is not None: - raise Exception("node_property_name must be provided if node_type is provided") - if relationship_property_name is None and relationship_type is not None: - raise Exception("relationship_property_name must be provided if relationship_type is provided") - if node_type is not None and node_properties is None: - raise Exception("node_properties must be provided. This is to find a specific node") - if relationship_type is not None and relationship_properties is None: - raise Exception("relationship_properties must be provided. This is to find a specific relationship") - - if node_type: - cypher_statement=f"MATCH (m :{format_string(node_type)} {format_dictionary(node_properties)})" - cypher_statement+=f"CALL db.index.vector.queryNodes({format_string(index_name)}, {nearest_neighbors}, m.{format_string(node_property_name)})" - if relationship_type: - cypher_statement=f"MATCH ()-[r:{format_string(relationship_type)} {format_dictionary(relationship_properties)}]-()" - cypher_statement+=f"CALL db.index.vector.queryRelationships({format_string(index_name)}, {nearest_neighbors}, r.{format_string(relationship_property_name)})" - - outputs=[] - for result in self.query(cypher_statement,database_name=database_name): - output={ key:value for key, value in result.items()} - outputs.append(output) - return outputs - - def format_exported_node_file(self,file): - """ - Formats an exported node file. - - Args: - file (str): The file to format. - - Returns: - str: The formatted file. 
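As a concrete reference for the statement query_vector_index assembles, a hand-written vector-index query in Cypher (the index, label, and property names here are hypothetical) looks roughly like this:

    cypher = """
    MATCH (m:Material {material_id: 'mp-149'})
    CALL db.index.vector.queryNodes('material-embedding-index', 10, m.embedding)
    YIELD node, score
    RETURN node.material_id AS material_id, score
    """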
- """ - df=pd.read_csv(file) - # df.drop(df.columns[0], axis=1, inplace=True) - # column_names=list(df.columns) - # id_col_index=0 - # label_index=0 - # for i,col_name in enumerate(column_names): - # if "Id" in col_name: - # id_col_index=i - # if ':LABEL' in col_name: - # label_index=i - - # # Move the id column to the first position and the label to the second position in the dataframe - # # Get the 'Id' and ':LABEL' columns - # id_col = df.pop(column_names[id_col_index]) - # label_col = df.pop(column_names[label_index]) # Adjusting index because we've already popped id_col - # # Insert 'Id' column at the first position - # df.insert(0, column_names[id_col_index], id_col) - # # Insert ':LABEL' column at the second position - # df.insert(1, column_names[label_index], label_col) - # # Rename id column to 'id' - # node_name=column_names[id_col_index].split('Id')[0] - # df.rename(columns={column_names[id_col_index]:f'{column_names[id_col_index]}:ID({node_name}-ID)'}, inplace=True) - - return df - - def format_exported_relationship_file(self,file): - """ - Formats an exported relationship file. - - Args: - file (str): The file to format. - - Returns: - str: The formatted file. - """ - df=pd.read_csv(file) - # df.drop(df.columns[0], axis=1, inplace=True) - # column_names=list(df.columns) - # id_col_index=0 - # label_index=0 - # for i,col_name in enumerate(column_names): - # if "Id" in col_name: - # id_col_index=i - # if ':LABEL' in col_name: - # label_index=i - - return df - - def export_database(self, - graph_dir:str, - database_name:str, - batch_size:int=20000, - delimiter:str=',', - array_delimiter:str=';', - quotes:str="always", - ): - """ - Exports a database to a file. - https://neo4j.com/docs/apoc/current/export/csv/ - - Args: - graph_dir (str): The directory of the original graph. - database_name (str): The name of the database. - batch_size (int, optional): The batch size. Defaults to 20000. - delimiter (str, optional): The delimiter. Defaults to ','. - array_delimiter (str, optional): The array delimiter. Defaults to ';'. - quotes (str, optional): The quotes. Defaults to 'always'. 
- - Returns: - None - """ - mutated_graphs_dir=os.path.join(graph_dir,'mutated_neo4j_graphs') - mutated_graph_dir=os.path.join(mutated_graphs_dir,database_name,'neo4j_csv') - node_dir=os.path.join(mutated_graph_dir,'nodes') - relationship_dir=os.path.join(mutated_graph_dir,'relationships') - os.makedirs(node_dir,exist_ok=True) - os.makedirs(relationship_dir,exist_ok=True) - - config={} - config['bulkImport']=True - config['useTypes']=True - config['batchSize']=batch_size - config['delimiter']=delimiter - config['arrayDelimiter']=array_delimiter - config['quotes']=quotes - - cypher_statement=f"""CALL apoc.export.csv.all("tmp.csv", {format_dictionary(config)})""" - results=self.query(cypher_statement,database_name=database_name) - - import_dir=os.path.join(self.dbms_dir,"import") - files=glob(os.path.join(import_dir,'*.csv')) - for file in files: - filename=os.path.basename(file) - _, graph_element, element_type,_=filename.split('.') - element_type=element_type.lower() - if 'nodes' == graph_element: - df=self.format_exported_node_file(file) - new_file=os.path.join(node_dir,f'{graph_element}.csv') - df.to_csv(new_file) - elif 'relationships' == graph_element: - df=self.format_exported_node_file(file) - new_file=os.path.join(node_dir,f'{graph_element}.csv') - df.to_csv(new_file) - - for file in files: - os.remove(file) - - return results - - def set_apoc_environment_variables(self,settings:dict=None, overwrite:bool=False): - """ - Sets the apoc export file. - https://neo4j.com/docs/apoc/current/export/csv/ - - Args: - database_name (str): The name of the database. - - Returns: - None - """ - - conf_file=os.path.join(self.dbms_dir,'conf','apoc.conf') - # Reading the existing config - with open(conf_file, 'r') as file: - lines = file.readlines() - - # Modifying the config - with open(conf_file, 'w') as file: - for key, value in settings.items(): - new_line = f'{key}={value}\n' - for line in lines: - if key in line and overwrite: - new_line = f'{key}={value}\n' - else: - new_line = line - - file.write(new_line) - return None - - -# if __name__=='__main__': - # with Neo4jManager() as manager: - # results=manager.does_property_exist('elements-no-fe','Material','fastrp-embedding') - # print(results) - - # with Neo4jManager() as manager: - # results=manager.does_property_exist('elements-no-fe','Material','fastrp-embedding') - # print(results) \ No newline at end of file diff --git a/matgraphdb/graph_kit/neo4j/utils.py b/matgraphdb/graph_kit/neo4j/utils.py deleted file mode 100644 index 6ae2f9e..0000000 --- a/matgraphdb/graph_kit/neo4j/utils.py +++ /dev/null @@ -1,147 +0,0 @@ - -from typing import List, Union - -import os -import json - -import openai -import tiktoken - -from dotenv import load_dotenv -load_dotenv() - -from matgraphdb.utils import OPENAI_API_KEY - -def format_list(prop_list): - """ - Formats a list into a string for use in Cypher queries. - - Args: - prop_list (list): A list containing the properties. - - Returns: - str: A string representation of the properties. - """ - return [f"{prop}" for prop in prop_list] - -def format_string(prop_string): - """ - Formats a string into a string for use in Cypher queries. - - Args: - prop_string (str): A string containing the properties. - - Returns: - str: A string representation of the properties. - """ - return f"'{prop_string}'" - -def format_dictionary(prop_dict): - """ - Formats a dictionary into a string for use in Cypher queries. - - Args: - prop_dict (dict): A dictionary containing the properties. 
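To make the intended output of these Cypher-formatting helpers concrete (format_dictionary is defined just below; the values are arbitrary):

    format_string('Material')
    # -> "'Material'"
    format_dictionary({'spg': 225, 'symbol': 'Fm-3m'})
    # -> "{spg: 225, symbol: 'Fm-3m'}"

Both results are meant to be spliced directly into Cypher statements rather than passed as query parameters.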
- - Returns: - str: A string representation of the properties. - """ - formatted_properties="{" - n_props=len(prop_dict) - for i,(prop_name,prop_params) in enumerate(prop_dict.items()): - if isinstance(prop_params,str): - formatted_properties+=f"{prop_name}: {format_string(prop_params)}" - elif isinstance(prop_params,int): - formatted_properties+=f"{prop_name}: {prop_params}" - elif isinstance(prop_params,float): - formatted_properties+=f"{prop_name}: {prop_params}" - elif isinstance(prop_params,List): - formatted_properties+=f"{prop_name}: {format_list(prop_params)}" - elif isinstance(prop_params,dict): - formatted_properties+=f"{prop_name}: {format_dictionary(prop_params)}" - - if i!=n_props-1: - formatted_properties+=", " - - formatted_properties+="}" - return formatted_properties - -def format_projection(projections:Union[str,List,dict]): - formatted_projections="" - if isinstance(projections,List): - formatted_projections=format_list(projections) - elif isinstance(projections,dict): - formatted_projections=format_dictionary(projections) - elif isinstance(projections,str): - formatted_projections=format_string(projections) - return formatted_projections - - - - - -def num_tokens_from_string(string: str, encoding_name: str) -> int: - """ - Returns the number of tokens in a text string. - - Parameters: - string (str): The input text string. - encoding_name (str): The name of the encoding to use. - - Returns: - int: The number of tokens in the text string. - """ - encoding = tiktoken.get_encoding(encoding_name) - num_tokens = len(encoding.encode(string)) - return num_tokens - -def get_embedding(text, client, model="text-embedding-3-small"): - """ - Get the embedding for a given text using the specified model. - - Args: - text (str): The input text to be embedded. - client: The client object used for embedding. - model (str, optional): The name of the model to use for embedding. Defaults to "text-embedding-3-small". - - Returns: - list: The embedding vector for the input text. - """ - text = text.replace("\n", " ") - return client.embeddings.create(input=[text], model=model).data[0].embedding -def get_embedding(text, client, model="text-embedding-3-small"): - - text = text.replace("\n", " ") - return client.embeddings.create(input = [text], model=model).data[0].embedding - -def get_similarity_query(prompt): - """ - Retrieves the similarity query for a given prompt. - - Args: - prompt (str): The prompt for which the similarity query is generated. - - Returns: - tuple: A tuple containing the embedding and the execute statement. - - Example: - >>> prompt = "What is the melting point of gold?" 
- >>> embedding, execute_statement = get_similarity_query(prompt) - """ - models=["text-embedding-3-small","text-embedding-3-large","ada v2"] - cost_per_token=[0.00000002,0.00000013,0.00000010] - model_index=0 - - MODEL=models[model_index] - - client = openai.OpenAI() - - embedding=get_embedding(prompt,client, model=MODEL) - - execute_statement=""" - CALL db.index.vector.queryNodes('material-text-embedding-3-small-embeddings', $nresults, $embedding) - YIELD node as sm, score - RETURN sm, score - """ - - return embedding, execute_statement \ No newline at end of file diff --git a/matgraphdb/graph_kit/nodes.py b/matgraphdb/graph_kit/nodes.py deleted file mode 100644 index 36d7ba1..0000000 --- a/matgraphdb/graph_kit/nodes.py +++ /dev/null @@ -1,853 +0,0 @@ -from glob import glob -import os -import warnings -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq -import logging - -from matgraphdb import config -# from matgraphdb.utils.chem_utils.coord_geometry import mp_coord_encoding -from matgraphdb.graph_kit.metadata import get_node_schema -from matgraphdb.graph_kit.metadata import NodeTypes - -logger = logging.getLogger(__name__) - -class Nodes: - """ - A base class to manage node operations, including creating, loading, and saving nodes as Parquet files, - with options to format data as either Pandas or PyArrow DataFrames. Subclasses should implement custom - logic for node creation and schema generation. - """ - - def __init__(self, node_type, node_dir, output_format='pandas'): - """ - Initializes a Nodes object with the given node type, directory, and output format. - - Parameters: - ----------- - node_type : str - The type of node to manage. - node_dir : str - Directory where node files will be stored. - output_format : str, optional - Format for loading data, either 'pandas' (default) or 'pyarrow'. Must be one of these two options. - - Raises: - ------- - ValueError - If output_format is not 'pandas' or 'pyarrow'. - """ - if output_format not in ['pandas','pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - - self.node_type = node_type - self.node_dir = node_dir - os.makedirs(self.node_dir,exist_ok=True) - - self.output_format = output_format - self.file_type = 'parquet' - self.filepath = os.path.join(self.node_dir, f'{self.node_type}.{self.file_type}') - self.schema = self.create_schema() - - self.get_dataframe() - - def get_dataframe(self, columns=None, include_cols=True, from_scratch=False, **kwargs): - """ - Loads or creates a node dataframe. If the node file exists and from_scratch is False, - the existing file will be loaded. Otherwise, it will call the create_nodes() method to - create the nodes and save them. - - Parameters: - ----------- - columns : list of str, optional - A list of column names to load. If None, all columns will be loaded. - include_cols : bool, optional - If True (default), the specified columns will be included. If False, - they will be excluded from the loaded dataframe. - from_scratch : bool, optional - If True, forces the creation of a new node dataframe even if a file exists. Default is False. - **kwargs : dict - Additional arguments passed to the create_nodes() method. - - Returns: - -------- - pandas.DataFrame or pyarrow.Table - The loaded or newly created node data. - - Raises: - ------- - ValueError - If the 'name' field is missing from the created nodes dataframe. 
- """ - - if os.path.exists(self.filepath) and not from_scratch: - logger.info(f"Trying to load {self.node_type} nodes from {self.filepath}") - df = self.load_dataframe(filepath=self.filepath, columns=columns, include_cols=include_cols, **kwargs) - return df - - logger.info(f"No node file found. Attempting to create {self.node_type} nodes") - df = self.create_nodes(**kwargs) # Subclasses will define this - - # Ensure the 'name' field is present - if 'name' not in df.columns: - raise ValueError(f"The 'name' field must be defined for {self.node_type} nodes. Define this in the create_nodes.") - df['name'] = df['name'] # Ensure 'name' is set - df['type'] = self.node_type # Ensure 'type' is set - if columns: - df = df[columns] - - if not self.schema: - logger.error(f"No schema set for {self.node_type} nodes") - return None - - self.save_dataframe(df, self.filepath) - return df - - def get_property_names(self): - """ - Retrieves and logs the column names (properties) of the node data from the Parquet file. - - Returns: - -------- - list of str - A list of column names in the node file. - """ - properties = Nodes.get_column_names(self.filepath) - for property in properties: - logger.info(f"Property: {property}") - return properties - - def create_nodes(self, **kwargs): - """ - Abstract method for creating nodes. Must be implemented by subclasses to define the logic - for creating nodes specific to the node type. - - Raises: - ------- - NotImplementedError - If this method is not implemented in a subclass. - """ - if self.__class__.__name__ != 'Nodes': - raise NotImplementedError("Subclasses must implement this method.") - else: - pass - - def create_schema(self, **kwargs): - """ - Abstract method for creating a Parquet schema. Must be implemented by subclasses to define - the schema for the node data. - - Raises: - ------- - NotImplementedError - If this method is not implemented in a subclass. - """ - if self.__class__.__name__ != 'Nodes': - raise NotImplementedError("Subclasses must implement this method.") - else: - pass - - def load_dataframe(self, filepath, columns=None, include_cols=True, **kwargs): - """ - Loads node data from a Parquet file, optionally filtering by columns. - - Parameters: - ----------- - filepath : str - Path to the Parquet file. - columns : list of str, optional - A list of column names to load. If None, all columns will be loaded. - include_cols : bool, optional - If True (default), the specified columns will be included. If False, - they will be excluded from the loaded dataframe. - **kwargs : dict - Additional arguments for reading the Parquet file. - - Returns: - -------- - pandas.DataFrame or pyarrow.Table - The loaded node data. - """ - if not include_cols: - metadata = pq.read_metadata(filepath) - all_columns = [] - for filed_schema in metadata.schema: - - # Only want top column names - max_defintion_level=filed_schema.max_definition_level - if max_defintion_level!=1: - continue - - all_columns.append(filed_schema.name) - - columns = [col for col in all_columns if col not in columns] - - try: - if self.output_format=='pandas': - df = pd.read_parquet(filepath, columns=columns) - elif self.output_format=='pyarrow': - df = pq.read_table(filepath, columns=columns) - - return df - except Exception as e: - logger.error(f"Error loading {self.node_type} nodes from {filepath}: {e}") - return None - - def save_dataframe(self, df, filepath): - """ - Saves the given dataframe to a Parquet file at the specified filepath. 
- - Parameters: - ----------- - df : pandas.DataFrame or pyarrow.Table - The node data to save. - filepath : str - The path where the Parquet file should be saved. - - Raises: - ------- - Exception - If there is an error during the save process. - """ - try: - parquet_table = pa.Table.from_pandas(df, self.schema) - pq.write_table(parquet_table, filepath) - logger.info(f"Finished saving {self.node_type} nodes to {filepath}") - except Exception as e: - logger.error(f"Error converting dataframe to parquet table for saving: {e}") - - def to_neo4j(self, save_dir): - """ - Converts the node data to a CSV file for importing into Neo4j. Saves the file in the given directory. - - Parameters: - ----------- - save_dir : str - Directory where the CSV file will be saved. - """ - logger.info(f"Converting node to Neo4j : {self.filepath}") - node_type=os.path.basename(self.filepath).split('.')[0] - - logger.debug(f"Node type: {node_type}") - - metadata = pq.read_metadata(self.filepath) - column_types = {} - neo4j_column_name_mapping={} - for filed_schema in metadata.schema: - # Only want top column names - type=filed_schema.physical_type - - field_path=filed_schema.path.split('.') - name=field_path[0] - - is_list=False - if len(field_path)>1: - is_list=field_path[1] == 'list' - - column_types[name] = {} - column_types[name]['type']=type - column_types[name]['is_list']=is_list - - if type=='BYTE_ARRAY': - neo4j_type ='string' - if type=='BOOLEAN': - neo4j_type='boolean' - if type=='DOUBLE': - neo4j_type='float' - if type=='INT64': - neo4j_type='int' - - if is_list: - neo4j_type+='[]' - - column_types[name]['neo4j_type'] = f'{name}:{neo4j_type}' - column_types[name]['neo4j_name'] = f'{name}:{neo4j_type}' - - neo4j_column_name_mapping[name]=f'{name}:{neo4j_type}' - - neo4j_column_name_mapping['type']=':LABEL' - - df=self.load_nodes(filepath=self.filepath) - df.rename(columns=neo4j_column_name_mapping, inplace=True) - df.index.name = f'{node_type}:ID({node_type}-ID)' - - os.makedirs(save_dir,exist_ok=True) - - save_file=os.path.join(save_dir,f'{node_type}.csv') - logger.info(f"Saving {node_type} nodes to {save_file}") - - - df.to_csv(save_file, index=True) - - logger.info(f"Finished converting node to Neo4j : {node_type}") - - @staticmethod - def get_column_names(filepath): - """ - Extracts and returns the top-level column names from a Parquet file. - - This method reads the metadata of a Parquet file and extracts the names of the top-level columns. - It filters out nested columns or columns with a `max_definition_level` other than 1, ensuring that - only primary, non-nested columns are included in the output. - - Args: - filepath (str): The file path to the Parquet (.parquet) file. - - Returns: - list of str: A list containing the names of the top-level columns in the Parquet file. - - Example: - columns = Nodes.get_column_names('data/example.parquet') - print(columns) - # Output: ['column1', 'column2', 'column3'] - - """ - metadata = pq.read_metadata(filepath) - all_columns = [] - for filed_schema in metadata.schema: - - # Only want top column names - max_defintion_level=filed_schema.max_definition_level - if max_defintion_level!=1: - continue - - all_columns.append(filed_schema.name) - return all_columns - - -class MaterialNodes(Nodes): - """ - A specialized class for handling Material nodes within the node management system. - - This class inherits from the `Nodes` base class and is designed to manage nodes of type 'Material'. 
- It defines the schema for Material nodes and provides functionality to create these nodes from - an external Parquet file. - - Attributes: - node_dir (str): The directory where the Material node data is stored. - output_format (str): The format for returning node data, defaulting to 'pandas'. Options might include - 'pandas' for a DataFrame or other formats supported by the system. - - Methods: - create_schema(): - Defines and returns the schema for Material nodes. The schema is fetched based on the Material node type. - - create_nodes(**kwargs): - Reads the material node data from a Parquet file, processes it, and returns the nodes as a pandas DataFrame. - In case of an error during the reading process, it logs the error and returns `None`. - - Example: - material_nodes = MaterialNodes(node_dir='path/to/material_nodes') - """ - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.MATERIAL.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Material nodes - return get_node_schema(NodeTypes.MATERIAL) - - def create_nodes(self, **kwargs): - # The logic for creating material nodes - try: - df = pd.read_parquet(MATERIAL_PARQUET_FILE) - df['name'] = df.index + 1 # Assign 'name' field - except Exception as e: - logger.error(f"Error reading material parquet file: {e}") - return None - return df - -class ElementNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - """ - Initializes the `ElementNodes` class. - - Args: - node_dir (str): The directory where the node files are located. - output_format (str): The format in which the nodes will be outputted. Default is 'pandas'. - - Inherits the initialization from the parent `Nodes` class, setting the node type - as 'Element' and configuring the output format for the node data. - """ - super().__init__(node_type=NodeTypes.ELEMENT.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - """ - Defines and returns the schema for the `ElementNodes` class. - - Returns: - dict: A dictionary representing the schema for element nodes, - which includes all necessary fields for storing element information - such as oxidation states, ionization energies, and other properties. - - This method uses the `get_node_schema` function to generate the schema - specific to the 'Element' node type. - """ - # Define and return the schema for Element nodes - return get_node_schema(NodeTypes.ELEMENT) - - def create_nodes(self, base_element_csv='imputed_periodic_table_values.csv', **kwargs): - """ - Reads the element data from a CSV file and processes it for node creation. - - Args: - base_element_csv (str): The filename of the CSV containing element data. - Defaults to 'imputed_periodic_table_values.csv'. - **kwargs: Additional arguments for flexibility, if needed. - - Returns: - pandas.DataFrame: A DataFrame containing the processed element data, - ready to be used as nodes in the application. 
- """ - # Ensure the CSV file exists - csv_files = glob(os.path.join(PKG_DIR, 'utils', "*.csv")) - csv_filenames = [os.path.basename(file) for file in csv_files] - if base_element_csv not in csv_filenames: - raise ValueError(f"base_element_csv must be one of the following: {csv_filenames}") - - # Suppress warnings during node creation - warnings.filterwarnings("ignore", category=UserWarning) - - try: - df = pd.read_csv(os.path.join(PKG_DIR, 'utils', base_element_csv), index_col=0) - df['oxidation_states']=df['oxidation_states'].apply(lambda x: x.replace(']', '').replace('[', '')) - df['oxidation_states']=df['oxidation_states'].apply(lambda x: ','.join(x.split()) ) - df['oxidation_states']=df['oxidation_states'].apply(lambda x: eval('['+x+']') ) - df['experimental_oxidation_states']=df['experimental_oxidation_states'].apply(lambda x: eval(x) ) - df['ionization_energies']=df['ionization_energies'].apply(lambda x: eval(x) ) - - df['name'] = df['symbol'] # Assign 'name' field - except Exception as e: - logger.error(f"Error reading element CSV file: {e}") - return None - - return df - - -class CrystalSystemNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.CRYSTAL_SYSTEM.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for CrystalSystem nodes - return get_node_schema(NodeTypes.CRYSTAL_SYSTEM) - - def create_nodes(self, **kwargs): - """ - Creates Crystal System nodes if no file exists, otherwise loads them from a file. - """ - try: - crystal_systems = ['triclinic', 'monoclinic', 'orthorhombic', 'tetragonal', 'trigonal', 'hexagonal', 'cubic'] - crystal_systems_properties = [{"crystal_system": cs} for cs in crystal_systems] - - df = pd.DataFrame(crystal_systems_properties) - df['name'] = df['crystal_system'] # Assign 'name' field - except Exception as e: - logger.error(f"Error creating crystal system nodes: {e}") - return None - - return df - -class MagneticStatesNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.MAGNETIC_STATE.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Magnetic State nodes - return get_node_schema(NodeTypes.MAGNETIC_STATE) - - def create_nodes(self, **kwargs): - """ - Creates Magnetic State nodes if no file exists, otherwise loads them from a file. - """ - # Define magnetic states - try: - magnetic_states = ['NM', 'FM', 'FiM', 'AFM', 'Unknown'] - magnetic_states_properties = [{"magnetic_state": ms} for ms in magnetic_states] - - df = pd.DataFrame(magnetic_states_properties) - df['name'] = df['magnetic_state'] # Assign 'name' field - except Exception as e: - logger.error(f"Error creating magnetic state nodes: {e}") - return None - return df - -class OxidationStatesNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.OXIDATION_STATE.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Oxidation State nodes - return get_node_schema(NodeTypes.OXIDATION_STATE) - - def create_nodes(self, **kwargs): - """ - Creates Oxidation State nodes if no file exists, otherwise loads them from a file. 
- """ - - # Retrieve material nodes with possible valences - try: - # material_df = self.get_material_nodes(columns=['oxidation_states-possible_valences']) - # possible_oxidation_state_names = [] - # possible_oxidation_state_valences = [] - - # # Iterate through the material DataFrame to collect possible valences - # for _, row in material_df.iterrows(): - # possible_valences = row['oxidation_states-possible_valences'] - # if possible_valences is None: - # continue - # for possible_valence in possible_valences: - # oxidation_state_name = f'ox_{possible_valence}' - # if oxidation_state_name not in possible_oxidation_state_names: - # possible_oxidation_state_names.append(oxidation_state_name) - # possible_oxidation_state_valences.append(possible_valence) - - # # Create DataFrame with the collected oxidation state names and valences - # data = { - # 'oxidation_state': possible_oxidation_state_names, - # 'valence': possible_oxidation_state_valences - # } - - oxidation_states = np.arange(-9, 10) - oxidation_states_names = [f'ox_{i}' for i in oxidation_states] - data={ - 'oxidation_state': oxidation_states_names, - 'value': oxidation_states - } - df = pd.DataFrame(data) - df['name'] = df['oxidation_state'] # Assign 'name' field - except Exception as e: - logger.error(f"Error creating oxidation state nodes: {e}") - return None - return df - -class SpaceGroupNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.SPACE_GROUP.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Space Group nodes - return get_node_schema(NodeTypes.SPACE_GROUP) - - def create_nodes(self, **kwargs): - """ - Creates Space Group nodes if no file exists, otherwise loads them from a file. - """ - - # Generate space group numbers from 1 to 230 - try: - space_groups = [f'spg_{i}' for i in np.arange(1, 231)] - space_groups_properties = [{"spg": int(space_group.split('_')[1])} for space_group in space_groups] - - # Create DataFrame with the space group properties - df = pd.DataFrame(space_groups_properties) - df['name'] = df['spg'].astype(str) # Assign 'name' field as string version of 'spg' - except Exception as e: - logger.error(f"Error creating space group nodes: {e}") - return None - - return df - -class ChemEnvNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.CHEMENV.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for ChemEnv nodes - return get_node_schema(NodeTypes.CHEMENV) - - def create_nodes(self, **kwargs): - """ - Creates ChemEnv nodes if no file exists, otherwise loads them from a file. 
- """ - # Get the chemical environment names from a dictionary (mp_coord_encoding) - try: - chemenv_names = list(mp_coord_encoding.keys()) - chemenv_names_properties = [] - - # Create a list of dictionaries with 'chemenv_name' and 'coordination' - for chemenv_name in chemenv_names: - coordination = int(chemenv_name.split(':')[1]) - chemenv_names_properties.append({ - "chemenv_name": chemenv_name, - "coordination": coordination - }) - - # Create DataFrame with the chemical environment names and coordination numbers - df = pd.DataFrame(chemenv_names_properties) - df['name'] = df['chemenv_name'].str.replace(':', '_') # Replace ':' with '_' for 'name' field - except Exception as e: - logger.error(f"Error creating chemical environment nodes: {e}") - return None - - return df - -class WyckoffPositionsNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.SPG_WYCKOFF.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Wyckoff Positions nodes - return get_node_schema(NodeTypes.SPG_WYCKOFF) - - def create_nodes(self, **kwargs): - """ - Creates Wyckoff Position nodes if no file exists, otherwise loads them from a file. - """ - logger.info(f"No node file found. Attempting to create {self.node_type} nodes") - - # Generate space group names from 1 to 230 - try: - space_groups = [f'spg_{i}' for i in np.arange(1, 231)] - # Define Wyckoff letters - wyckoff_letters = ['a', 'b', 'c', 'd', 'e', 'f'] - - # Create a list of space group-Wyckoff position combinations - spg_wyckoffs = [f"{spg}_{wyckoff_letter}" for wyckoff_letter in wyckoff_letters for spg in space_groups] - - # Create a list of dictionaries with 'spg_wyckoff' - spg_wyckoff_properties = [{"spg_wyckoff": spg_wyckoff} for spg_wyckoff in spg_wyckoffs] - - # Create DataFrame with Wyckoff positions - df = pd.DataFrame(spg_wyckoff_properties) - df['name'] = df['spg_wyckoff'] # Assign 'name' field - except Exception as e: - logger.error(f"Error creating Wyckoff position nodes: {e}") - return None - - return df - -class MaterialLatticeNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.LATTICE.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Lattice nodes - return get_node_schema(NodeTypes.LATTICE) - - def create_nodes(self,**kwargs): - """ - Creates Lattice nodes if no file exists, otherwise loads them from a file. - """ - - # Retrieve material nodes with lattice properties - try: - df = pd.read_parquet(MATERIAL_PARQUET_FILE,columns=['material_id', 'lattice', 'a', 'b', 'c', - 'alpha', 'beta', 'gamma', 'crystal_system', 'volume']) - - # Set the 'name' field as 'material_id' - df['name'] = df['material_id'] - - except Exception as e: - logger.error(f"Error creating lattice nodes: {e}") - return None - - return df - -class MaterialSiteNodes(Nodes): - def __init__(self, node_dir, output_format='pandas'): - super().__init__(node_type=NodeTypes.SITE.value, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for Site nodes - return get_node_schema(NodeTypes.SITE) - - def create_nodes(self, **kwargs): - """ - Creates Site nodes if no file exists, otherwise loads them from a file. 
- """ - - # Retrieve material nodes with relevant site properties - try: - df = pd.read_parquet(MATERIAL_PARQUET_FILE,columns=['material_id', 'lattice', 'frac_coords', 'species']) - - all_species = [] - all_coords = [] - all_lattices = [] - all_ids = [] - - # Iterate through each row of the DataFrame - for irow, row in df.iterrows(): - if irow % 10000 == 0: - logger.info(f"Processing row {irow}") - if row['species'] is None: - continue - - # Collect species, fractional coordinates, lattices, and material IDs - for frac_coord, specie in zip(row['frac_coords'], row['species']): - all_species.append(specie) - all_coords.append(frac_coord) - all_lattices.append(row['lattice']) - all_ids.append(row['material_id']) - - # Create DataFrame for Site nodes - df = pd.DataFrame({ - 'species': all_species, - 'frac_coords': all_coords, - 'lattice': all_lattices, - 'material_id': all_ids - }) - - df['name'] = df['material_id'] # Assign 'name' field as 'material_id' - except Exception as e: - logger.error(f"Error creating site nodes: {e}") - return None - return df - - -class NodeManager: - def __init__(self, node_dir, output_format='pandas'): - """ - Initialize the NodesManager with the directory where nodes are stored. - """ - if output_format not in ['pandas','pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - - self.node_dir = node_dir - os.makedirs(self.node_dir, exist_ok=True) - self.file_type = 'parquet' - self.get_existing_nodes() - - def get_existing_nodes(self): - self.nodes = set(self.list_nodes()) - return self.nodes - - def list_nodes(self): - """ - List all node files available in the node directory. - """ - node_files = [f for f in os.listdir(self.node_dir) if f.endswith(f'.{self.file_type}')] - node_types = [os.path.splitext(f)[0] for f in node_files] # Extract file names without extension - logger.info(f"Found the following node types: {node_types}") - return node_types - - def get_node(self, node_type, output_format='pandas'): - """ - Load a node dataframe by its type (which corresponds to the filename without extension). - """ - if output_format not in ['pandas','pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - - filepath = os.path.join(self.node_dir, f'{node_type}.{self.file_type}') - - if not os.path.exists(filepath): - logger.error(f"No node file found for type: {node_type}") - return None - - nodes=Nodes(node_type=node_type, node_dir=self.node_dir,output_format=output_format) - - return nodes - - def get_node_dataframe(self, node_type, columns=None, include_cols=True, output_format='pandas', **kwargs): - """ - Return the node dataframe if it has already been loaded; otherwise, load it from file. - """ - if output_format not in ['pandas','pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - return self.get_node(node_type, output_format=output_format).get_dataframe(columns=columns, include_cols=include_cols) - - def add_node(self, node_class): - """ - Add a new node by providing a custom node class (must inherit from the base Node class). - The node class must implement its own creation logic. - """ - if not issubclass(node_class, Nodes): - raise TypeError("The provided class must inherit from the Nodes class.") - - node = node_class(node_dir=self.node_dir) # Initialize the node class - node.get_dataframe() # Get or create the node dataframe - - self.get_existing_nodes() - - def delete_node(self, node_type): - """ - Delete a node type. 
This method will remove the parquet file and the node from the self.nodes set. - """ - filepath = os.path.join(self.node_dir, f'{node_type}.{self.file_type}') - - if os.path.exists(filepath): - try: - os.remove(filepath) - self.nodes.discard(node_type) # Remove from the set of nodes - logger.info(f"Deleted node of type {node_type} and removed it from the node set.") - except Exception as e: - logger.error(f"Error deleting node of type {node_type}: {e}") - else: - logger.warning(f"No node file found for type {node_type} to delete.") - - def convert_all_to_neo4j(self, save_dir): - """ - Convert all Parquet node files in the node directory to Neo4j CSV format. - """ - os.makedirs(save_dir, exist_ok=True) - for node_type in self.nodes: - logger.info(f"Converting {node_type} to Neo4j CSV format.") - try: - node = self.get_node(node_type) # Load the node - if node is not None: - node.to_neo4j(save_dir) # Convert to Neo4j format - logger.info(f"Successfully converted {node_type} to Neo4j CSV.") - else: - logger.warning(f"Skipping {node_type} as it could not be loaded.") - except Exception as e: - logger.error(f"Error converting {node_type} to Neo4j CSV: {e}") - - - - -if __name__ == "__main__": - node_dir = os.path.join('data','raw','nodes') - node=ElementNodes(node_dir=node_dir) - print(node.get_property_names()) - - nodes=SpaceGroupNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=MagneticStatesNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=MaterialNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=CrystalSystemNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=OxidationStatesNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=ChemEnvNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=WyckoffPositionsNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=MaterialLatticeNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - nodes=MaterialSiteNodes(node_dir=node_dir) - print(nodes.get_property_names()) - - - - - # node=Nodes(node_type='ELEMENT', node_dir=node_dir) - - # print(node.get_property_names()) - - - - # df = node.get_nodes() - # print(df.head()) - - - # node=MaterialNodes(node_dir=node_dir) - # df = node.get_nodes() - # print(df.head()) - - manager=NodeManager(node_dir=node_dir) - - print(manager.nodes) - - df=manager.get_node_dataframe('ELEMENT') - print(df.head()) diff --git a/matgraphdb/graph_kit/pyg/__init__.py b/matgraphdb/graph_kit/pyg/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/graph_kit/pyg/algo.py b/matgraphdb/graph_kit/pyg/algo.py deleted file mode 100644 index 5a46489..0000000 --- a/matgraphdb/graph_kit/pyg/algo.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -from torch_geometric.data import Data -from torch_geometric.transforms import FeaturePropagation - - -def feature_propagation(data=None, x=None, edge_index=None, **kwargs): - """ - Perform feature propagation on a graph to impute missing values. - - This function applies the FeaturePropagation transformation to a given - graph's node features, filling in missing values based on the graph - structure and available features. The function accepts either a PyTorch - Geometric `Data` object, or separate `x` (node features) and `edge_index` - (graph edges) arrays. 
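A short usage sketch of the (now-removed) NodeManager above, mirroring the module's __main__ block; the directory path is a placeholder:

    import os

    from matgraphdb.graph_kit.nodes import NodeManager, ElementNodes  # legacy module

    node_dir = os.path.join('data', 'raw', 'nodes')
    manager = NodeManager(node_dir=node_dir)

    manager.add_node(ElementNodes)                  # build/register the ELEMENT node table
    print(manager.nodes)                            # node types discovered on disk
    df = manager.get_node_dataframe('ELEMENT')      # load one node table as pandas
    manager.convert_all_to_neo4j(os.path.join(node_dir, 'neo4j_csv'))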
- - Parameters: - ----------- - data : torch_geometric.data.Data, optional - A PyTorch Geometric Data object containing the node features `x` and - the adjacency information `edge_index`. If provided, `x` and - `edge_index` are extracted from this object. - - x : torch.Tensor, optional - A tensor containing the node features. Must be provided if `data` - is not given. - - edge_index : torch.Tensor, optional - A tensor defining the edge connections in the graph, represented - as a list of source and target node indices. Must be provided if - `data` is not given. - - **kwargs : dict - Additional keyword arguments passed to the `FeaturePropagation` - class, which controls the behavior of the feature propagation - algorithm (e.g., the propagation method, number of iterations, etc.). - - Returns: - -------- - numpy.ndarray - The imputed node features as a NumPy array, where missing values - have been filled through feature propagation. - - Raises: - ------- - ValueError - If neither `data` nor both `x` and `edge_index` are provided. - - Example: - -------- - >>> imputed_features = feature_propagation(x=my_node_features, - edge_index=my_edge_index) - """ - if data is None and x is None and edge_index is None: - raise ValueError("Either data or x and edge_index must be provided") - if data is None: - data = Data(x=x, edge_index=edge_index) - else: - x=data.x - transform = FeaturePropagation(missing_mask=torch.isnan(x), **kwargs) - homo_graph_transformed = transform(data) - return homo_graph_transformed.x.numpy() diff --git a/matgraphdb/graph_kit/pyg/callbacks.py b/matgraphdb/graph_kit/pyg/callbacks.py deleted file mode 100644 index e2b77a2..0000000 --- a/matgraphdb/graph_kit/pyg/callbacks.py +++ /dev/null @@ -1,219 +0,0 @@ -import os -import copy -import json - -import torch - -from matgraphdb.graph_kit.pyg.metrics import RegressionMetrics, ClassificationMetrics - - - -class EarlyStopping(): - def __init__(self, patience=5, min_delta=0, restore_best_weights=True): - """The early stopping callback - - Parameters - ---------- - patience : int, optional - The number of epochs to wait to see improvement on loss, by default 5 - min_delta : float, optional - The difference theshold for determining if a result is better, by default 0 - restore_best_weights : bool, optional - Boolean to restore weights, by default True - """ - self.patience = patience - self.min_delta = min_delta - self.restore_best_weights = restore_best_weights - self.best_model = None - self.best_loss = None - self.best_mape_loss = None - self.best_mae_loss = None - self.counter = 0 - self.status = 0 - - def __call__(self, model, test_loss:float, mae_loss): - """The class calling method - - Parameters - ---------- - model : torch.nn.Module - The pytorch model - test_loss : float - The validation loss - mape_val_loss : float - The map_val_loss - - Returns - ------- - _type_ - _description_ - """ - if self.best_loss == None: - self.best_loss = test_loss - self.best_mae_loss = mae_loss - - self.best_model = copy.deepcopy(model) - elif self.best_loss - test_loss > self.min_delta: - self.best_loss = test_loss - self.best_mae_loss = mae_loss - - self.counter = 0 - self.best_model.load_state_dict(model.state_dict()) - elif self.best_loss - test_loss < self.min_delta: - self.counter +=1 - if self.counter >= self.patience: - self.status = f'Stopped on {self.counter}' - if self.restore_best_weights: - model.load_state_dict(self.best_model.state_dict()) - return True - self.status = f"{self.counter}/{self.patience}" - # print( self.status ) - 
return False - - -class MetricsTacker(): - def __init__(self,save_path, is_regression=False): - self.metrics_dict={} - self.n_metrics=0 - self.is_regression=is_regression - - self.metrics_dict['train']={} - self.metrics_dict['test']={} - self.save_path=save_path - - if self.is_regression: - self.get_regression_metrics(split='train') - self.get_regression_metrics(split='test') - else: - self.get_classification_metrics(split='train') - self.get_classification_metrics(split='test') - - def get_regression_metrics(self,split): - self.metrics_dict[split]['mse']=[] - self.metrics_dict[split]['mae']=[] - self.metrics_dict[split]['rmse']=[] - self.metrics_dict[split]['msle']=[] - self.metrics_dict[split]['r2']=[] - self.metrics_dict[split]['adjusted_r2']=[] - self.metrics_dict[split]['explained_variance_score']=[] - self.metrics_dict[split]['mape']=[] - self.metrics_dict[split]['huber_loss']=[] - self.metrics_dict[split]['batch_loss']=[] - self.metrics_dict[split]['epoch']=[] - - self.metric_names=list(self.metrics_dict.keys()) - - def get_classification_metrics(self,split): - self.metrics_dict[split]['accuracy']=[] - self.metrics_dict[split]['class_weights']=[] - self.metrics_dict[split]['confusion_matrix']=[] - self.metrics_dict[split]['batch_loss']=[] - self.metrics_dict[split]['epoch']=[] - - self.metric_names=list(self.metrics_dict.keys()) - - def calculate_metrics(self,y_pred,y_true, batch_loss, epoch, n_features,num_classes, split): - """ - Calculates the metrics for a given set of predictions and true values. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - batch_loss (torch.Tensor): The trained loss. - epoch (int): The current epoch. - n_features (int): The number of features. - split (str): The split for which the metrics are being calculated. 
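A hedged sketch of how the EarlyStopping callback above is meant to sit inside a training loop; the model, loaders, and helper functions are placeholders:

    from matgraphdb.graph_kit.pyg.callbacks import EarlyStopping  # legacy module

    early_stopping = EarlyStopping(patience=10, min_delta=0.0, restore_best_weights=True)
    num_epochs = 100

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader)    # hypothetical helper
        val_loss, val_mae = validate(model, val_loader)       # hypothetical helper

        # Returns True once val_loss has failed to improve for `patience` epochs;
        # with restore_best_weights=True the best weights are restored before stopping.
        if early_stopping(model, val_loss, val_mae):
            break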
- - Returns: - None - """ - if self.is_regression: - self.metrics_dict[split]['mse'].append(RegressionMetrics.mean_squared_error(y_pred, y_true)) - self.metrics_dict[split]['mae'].append(RegressionMetrics.mean_absolute_error(y_pred, y_true)) - self.metrics_dict[split]['rmse'].append(RegressionMetrics.root_mean_squared_error(y_pred, y_true)) - self.metrics_dict[split]['mape'].append(RegressionMetrics.mean_absolute_percentage_error(y_pred, y_true)) - self.metrics_dict[split]['msle'].append(RegressionMetrics.mean_squared_logarithmic_error(y_pred, y_true)) - self.metrics_dict[split]['r2'].append(RegressionMetrics.r_squared(y_pred, y_true)) - self.metrics_dict[split]['adjusted_r2'].append(RegressionMetrics.adjusted_r_squared(y_pred, y_true,n_features)) - self.metrics_dict[split]['explained_variance_score'].append(RegressionMetrics.explained_variance_score(y_pred, y_true)) - self.metrics_dict[split]['huber_loss'].append(RegressionMetrics.huber_loss(y_pred, y_true)) - self.metrics_dict[split]['batch_loss'].append(batch_loss) - else: - self.metrics_dict[split]['accuracy'].append(ClassificationMetrics.accuracy(y_pred, y_true)) - confusion_matrix=ClassificationMetrics.confusion_matrix(y_pred, y_true, num_classes) - self.metrics_dict[split]['confusion_matrix'].append(confusion_matrix) - self.metrics_dict[split]['batch_loss'].append(batch_loss) - self.metrics_dict[split]['class_weights'].append(ClassificationMetrics.class_weights(y_true)) - self.metrics_dict[split]['epoch'].append(epoch) - self.n_metrics+=1 - - def get_metrics_dict(self): - return self.metrics_dict - - def get_metric_names(self): - return self.metric_names - - def format_for_json(self): - """ - Formats all metrics for JSON serialization, converting tensors to lists. - """ - formatted_dict = {} - for key, value in self.metrics_dict.items(): - print(key) - if isinstance(value, dict): - formatted_dict[key] = {k: self._tensor_to_list(v) for k, v in value.items()} - else: - formatted_dict[key] = self._tensor_to_list(value) - return formatted_dict - - def _tensor_to_list(self, item): - """ - Converts a tensor to a list or returns the item if it's not a tensor. - """ - if isinstance(item, torch.Tensor): - return item.tolist() # Convert tensors to lists - elif isinstance(item, list): - return [self._tensor_to_list(x) for x in item] # Recursively process lists - else: - return item # Return the item as is if not a tensor or list - - def save_metrics(self): - """ - Saves the metrics to a file. - - Args: - path (str): The path to save the metrics to. - """ - formatted_data = self.format_for_json() - with open(os.path.join(self.save_path,'metrics.json'),'w') as f: - json.dump(formatted_data, f) - - -class Checkpointer: - def __init__(self, save_path, verbose=1): - """ - Initializes the ModelCheckpoint callback. - - Args: - save_path (str): Directory where the model checkpoints will be saved. - verbose (int): Verbosity mode, 0 or 1. - - """ - self.save_path = save_path - self.verbose = verbose - - - def save_model(self, model, epoch, checkpoint_name=None): - """ - Saves the model to the specified path. 
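A minimal sketch tying MetricsTacker and Checkpointer together inside an epoch loop; the paths, tensors, feature counts, and the run_epoch helper are placeholders:

    from matgraphdb.graph_kit.pyg.callbacks import MetricsTacker, Checkpointer  # legacy module

    tracker = MetricsTacker(save_path='runs/exp1', is_regression=True)
    checkpointer = Checkpointer(save_path='runs/exp1/checkpoints')

    for epoch in range(num_epochs):
        y_pred, y_true, loss = run_epoch(model, train_loader)   # hypothetical helper
        tracker.calculate_metrics(y_pred, y_true, batch_loss=loss, epoch=epoch,
                                  n_features=n_features, num_classes=None, split='train')
        checkpointer.save_model(model, epoch)

    tracker.save_metrics()   # writes metrics.json under save_path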
- """ - if not os.path.exists(self.save_path): - os.makedirs(self.save_path) - if checkpoint_name is None: - filename = f'model_epoch_{epoch:04d}.pth' - else: - filename = f'{checkpoint_name}.pth' - filepath = os.path.join(self.save_path, filename) - torch.save(model.state_dict(), filepath) - if self.verbose: - print(f"Model checkpoint saved: {filepath}") \ No newline at end of file diff --git a/matgraphdb/graph_kit/pyg/datasets/__init__.py b/matgraphdb/graph_kit/pyg/datasets/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/matgraphdb/graph_kit/pyg/encoders.py b/matgraphdb/graph_kit/pyg/encoders.py deleted file mode 100644 index 3e58fbe..0000000 --- a/matgraphdb/graph_kit/pyg/encoders.py +++ /dev/null @@ -1,240 +0,0 @@ -import math -import os - -import torch - -import numpy as np -import pandas as pd - -class CategoricalEncoder: - def __init__(self, sep='|'): - self.sep = sep - - def __call__(self, df): - genres = set(g for col in df.values for g in col.split(self.sep)) - mapping = {genre: i for i, genre in enumerate(genres)} - - x = torch.zeros(len(df), len(mapping)) - for i, col in enumerate(df.values): - for genre in col.split(self.sep): - x[i, mapping[genre]] = 1 - return x - -class ClassificationEncoder: - """Converts a column of of unique itentities into a torch tensor. One hot encoding""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - # Find unique values in the column - unique_values = df.unique() - # Create a dictionary mapping unique values to integers - value_to_index = {value: i for i, value in enumerate(unique_values)} - tensor=torch.zeros(len(df),len(unique_values)) - - for irow,elements in enumerate(df): - tensor[irow,value_to_index[elements]]=1 - return tensor - -class BooleanEncoder: - """Converts a column of boolean values into a torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - # Convert boolean values to integers (True to 1, False to 0) - boolean_integers = df.astype(int) - # Create a Torch tensor from the numpy array, ensure it has the correct dtype - return torch.from_numpy(boolean_integers.values).view(-1, 1).type(self.dtype) - -class IdentityEncoder: - """Converts a column of numbers into torch tensor.""" - def __init__(self, dtype=torch.float32): - self.dtype = dtype - - def __call__(self, df): - tensor=torch.from_numpy(df.values).view(-1, 1).to(self.dtype) - return tensor - -class ListIdentityEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - values=[] - for irow,row in enumerate(df): - values.append(row) - values=np.array(values) - - tensor=torch.from_numpy(values).to(self.dtype) - return tensor - -class IonizationEnergiesEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - self.column_names=['mean_ionization_energy', - 'standard_deviation_ionization_energy', - 'min_ionization_energy', - 'max_ionization_energy', - 'median_ionization_energy'] - - def __call__(self, df): - values=[] - for irow,row in enumerate(df): - - if len(row)==0: - embedding=[0,0,0,0,0] - continue - - mean=calculate_mean(row) - std=calculate_standard_deviation(row) - min_val=calculate_min(row) - max_val=calculate_max(row) - median=calculate_median(row) - embedding=[mean,std,min_val,max_val,median] - values.append(embedding) - - values=np.array(values) - - 
tensor=torch.from_numpy(values).to(self.dtype) - return tensor - -class OxidationStatesEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - self.column_names=['mean_ionization_energy', - 'standard_deviation_ionization_energy', - 'min_ionization_energy', - 'max_ionization_energy', - 'median_ionization_energy'] - - def __call__(self, df): - values=[] - for irow,row in enumerate(df): - - if len(row)==0: - embedding=[0,0,0,0,0] - values.append(embedding) - continue - - mean=calculate_mean(row) - std=calculate_standard_deviation(row) - min_val=calculate_min(row) - max_val=calculate_max(row) - median=calculate_median(row) - embedding=[mean,std,min_val,max_val,median] - values.append(embedding) - - values=np.array(values) - - tensor=torch.from_numpy(values).to(self.dtype) - return tensor - -class ElementsEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - from matgraphdb.utils.chem_utils.periodic import atomic_symbols - tensor=torch.zeros(len(df),118) - element_to_z={element:i-1 for i,element in enumerate(atomic_symbols)} - for irow,elements in enumerate(df): - elemnt_indices=[element_to_z[e] for e in elements.split(';')] - tensor[irow,elemnt_indices]+=1 - return tensor - -class CompositionEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - from matgraphdb.utils.chem_utils.periodic import atomic_symbols - import ast - tensor=torch.zeros(len(df),118) - element_to_z={element:i-1 for i,element in enumerate(atomic_symbols)} - for irow,comp_string in enumerate(df): - comp_mapping=ast.literal_eval(comp_string) - for element,comp_val in comp_mapping.items(): - element_index=element_to_z[element] - tensor[irow,element_index]+=comp_val - # Normalize tensor by row - tensor=tensor/tensor.sum(axis=1, keepdims=True) - return tensor - -class SpaceGroupOneHotEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - tensor=torch.zeros(len(df),230) - for irow,space_group in enumerate(df): - tensor[irow,space_group-1]+=1 - return tensor - -class IntegerOneHotEncoder: - """Converts a column of list of numbers into torch tensor.""" - def __init__(self, dtype=None): - self.dtype = dtype - - def __call__(self, df): - - possible_values=[] - for irow,value in enumerate(df): - possible_values.append(value) - - tensor=torch.zeros(len(df),len(possible_values)) - for irow,value in enumerate(df): - index_value=value-1 - tensor[irow,index_value]+=1 - return tensor - - -def calculate_mean(numbers): - return sum(numbers) / len(numbers) - -def calculate_standard_deviation(numbers): - mean = calculate_mean(numbers) - variance = sum((x - mean) ** 2 for x in numbers) / len(numbers) - return math.sqrt(variance) - -def calculate_min(numbers): - return min(numbers) - -def calculate_max(numbers): - return max(numbers) - -def calculate_median(numbers): - sorted_numbers = sorted(numbers) - n = len(sorted_numbers) - mid = n // 2 - if n % 2 == 0: - return (sorted_numbers[mid - 1] + sorted_numbers[mid]) / 2 - else: - return sorted_numbers[mid] - -# if __name__ == "__main__": - # import pandas as pd - # import os - # import matplotlib.pyplot as plt - # from matgraphdb.graph.material_graph import MaterialGraph - # from matgraphdb.mlcore.transforms import 
min_max_normalize, standardize_tensor - - # material_graph=MaterialGraph() - # graph_dir = material_graph.graph_dir - # nodes_dir = material_graph.node_dir - # relationship_dir = material_graph.relationship_dir - - - # node_names=material_graph.list_nodes() - # relationship_names=material_graph.list_relationships() - - # node_files=material_graph.get_node_filepaths() - # relationship_files=material_graph.get_relationship_filepaths() diff --git a/matgraphdb/graph_kit/pyg/graph_models.py b/matgraphdb/graph_kit/pyg/graph_models.py deleted file mode 100644 index bceb993..0000000 --- a/matgraphdb/graph_kit/pyg/graph_models.py +++ /dev/null @@ -1,301 +0,0 @@ -# Creating a GraphSAGE model -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.nn import GCNConv, SAGEConv, GINConv,to_hetero -from torch_geometric.nn import global_mean_pool, global_add_pool -from torch_geometric.data import DataLoader -from torch_geometric.utils import negative_sampling -from torch_geometric.datasets import Planetoid -from torch_geometric.loader import LinkNeighborLoader - - -from matgraphdb.graph_kit.pyg.models import MultiLayerPerceptron -# Define the model -class StackedSAGELayers(nn.Module): - def __init__(self, in_channels, hidden_channels, num_layers, dropout=0.2 ): - super(StackedSAGELayers, self).__init__() - self.dropout=dropout - self.conv1 = SAGEConv(in_channels, hidden_channels) - self.convs = nn.ModuleList([ - SAGEConv(hidden_channels, hidden_channels) - for _ in range(num_layers - 2) - ]) - - - def forward(self, x: torch.Tensor, edge_index: torch.Tensor,training=False): - x = F.relu(self.conv1(x, edge_index)) - # x = F.dropout(x, p=self.dropout,training=training) - for conv in self.convs: - x = F.relu(conv(x, edge_index)) - # x = F.dropout(x, p=self.dropout,training=training) - - return x - - - - -class SupervisedHeteroSAGEModel(nn.Module): - def __init__(self, data, - hidden_channels:int, - out_channels:int, - pred_node_type:str, - num_layers, - device='cuda:0'): - super(SupervisedHeteroSAGEModel, self).__init__() - - self.embs = nn.ModuleDict() - self.data_lins = nn.ModuleDict() - self.node_type=pred_node_type - for node_type in data.node_types: - num_nodes = data[node_type].num_nodes - num_features = data[node_type].num_node_features - self.embs[node_type]=nn.Embedding(num_nodes,hidden_channels,device=device) - if num_features != 0: - self.data_lins[node_type]=nn.Linear(num_features, hidden_channels,device=device) - - self.output_layer = nn.Linear(hidden_channels, out_channels) - - # Initialize and convert GraphSAGE to heterogeneous - self.graph_sage = StackedSAGELayers(hidden_channels,hidden_channels,num_layers) - self.graph_sage = to_hetero(self.graph_sage, metadata=data.metadata()) - - def forward(self, data): - x_dict={} - for node_type, emb_layer in self.embs.items(): - # Handling nodes based on feature availability - if node_type in self.data_lins: - x_dict[node_type] = self.data_lins[node_type](data[node_type].x) + emb_layer(data[node_type].node_id) - else: - x_dict[node_type] = emb_layer(data[node_type].node_id) - - x_dict=self.graph_sage(x_dict, data.edge_index_dict) - - out=self.output_layer(x_dict[self.node_type]) - return out - - - - -def get_node_dataloaders(data,shuffle=False): - input_loaders={} - for node_item in data.node_items(): - node_type=node_item[0] - node_dict=node_item[1] - node_ids=node_dict['node_id'] - input_nodes=(node_type, node_ids) - test_loader = NeighborLoader( - data, - # Sample 15 neighbors for each node and each edge type for 2 
iterations: - num_neighbors=[15] * 2, - replace=False, - subgraph_type="bidirectional", - disjoint=False, - weight_attr = None, - transform=None, - transform_sampler_output = None, - - input_nodes=input_nodes, - shuffle=shuffle, - batch_size=128, - ) - input_loaders[node_type]=test_loader - - return input_loaders - -def train(model, optimizer, dataloader_dict, loss_fn=nn.CrossEntropyLoss()): - model.train() - # optimizer.zero_grad() - # node_train_loss = 0.0 - - - - - # train_loss.backward() - # optimizer.step() - # optimizer.zero_grad() - - # batch_train_loss = batch_train_loss / num_batches - - # print(f"Loss: {batch_train_loss}") - # node_train_loss += batch_train_loss - # return batch_train_loss - -def evaluate(model, dataloader_dict): - model.eval() - with torch.no_grad(): - for node_type,dataloader in dataloader_dict.items(): - for data in dataloader: - data.to(device) - z_dict = model(data) - z_dict = model(data) - loss = negative_sampling_hetero_loss(z_dict, data.edge_index_dict) - return loss.item() - -if __name__ == "__main__": - import torch - import torch.nn as nn - import torch_geometric.transforms as T - from torch_geometric.sampler import NegativeSampling - from matgraphdb.mlcore.datasets import MaterialGraphDataset - from matgraphdb.mlcore.loss import negative_sampling_hetero_loss,positive_sampling_hetero_random_walk_loss - from torch_geometric.loader import NeighborLoader - - - - - graph_dataset=MaterialGraphDataset.ec_element_chemenv( - use_weights=True, - use_node_properties=True, - properties=['atomic_number','group','row','atomic_mass'] - #,properties=['group','atomic_number'] - ) - print(graph_dataset.data) - # print(dir(graph_dataset.data)) - - - rev_edge_types=[] - edge_types=[] - for edge_type in graph_dataset.data.edge_types: - rel_type=edge_type[1] - if 'rev' in rel_type: - rev_edge_types.append(edge_type) - else: - edge_types.append(edge_type) - print(edge_types) - print(rev_edge_types) - transform = T.RandomLinkSplit( - num_val=0.1, - num_test=0.1, - disjoint_train_ratio=0.3, - neg_sampling_ratio=2.0, - add_negative_train_samples=False, - edge_types=edge_types, - rev_edge_types=rev_edge_types, - ) - - # transform = T.RandomNodeSplit(split="random") # Or another appropriate method - - train_data, val_data, test_data = transform(graph_dataset.data) - print(dir(test_data)) - # training_graph, _ = InMemoryDataset.collate(train_data_list) - # test_graph, _ = InMemoryDataset.collate(test_data_list) - - # print(test_data['material', 'has', 'element'].edge_label_index) - # print(test_data['material', 'has', 'element'].edge_label) - # print(test_data['material', 'has', 'element'].edge_label.shape) - # print(test_data['material', 'has', 'chemenv'].edge_label) - # print(test_data['material', 'has', 'chemenv'].edge_label.shape) - - - # print(test_data['element', 'electric_connects', 'element'].edge_label) - # print(test_data['element', 'electric_connects', 'element'].edge_label.shape) - # for x in test_data['material', 'has', 'chemenv'].edge_label: - # print(x) - # print(test_data.edge_index) - # edge_label_index = train_data["user", "rates", "movie"].edge_label_index - # edge_label = train_data["user", "rates", "movie"].edge_la - - # train_loader = LinkNeighborLoader( - # data=train_data, - # num_neighbors=[20, 10], - # neg_sampling_ratio=2.0, - # edge_label_index=(("user", "rates", "movie"), edge_label_index), - # edge_label=edge_label, - # batch_size=128, - # shuffle=True, - # ) - - - # print("Train Data") - # print("-"*200) - # print(train_data) - # print("Val Data") - 
# print("-"*200) - # print(val_data) - # print("Test Data") - # print("-"*200) - # print(test_data) - - - device= "cuda:0" if torch.cuda.is_available() else torch.device("cpu") - print(device) - # Define the model - model = SupervisedHeteroSAGEModel(test_data, - hidden_channels=128, - out_channels=1, - num_layers=1, - device=device) - model.to(device) - - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - - train_loader_dict=get_node_dataloaders(train_data,shuffle=True) - val_loader_dict=get_node_dataloaders(val_data,shuffle=False) - test_loader_dict=get_node_dataloaders(test_data,shuffle=False) - - - - batch=next(iter(test_loader_dict['material'])) - batch.to(device) - z_dict=model(batch) - # print(batch) - # print(z_dict) - - loss_fn=nn.CrossEntropyLoss() - - # batch=next(iter(test_loader_dict['element'])) - # print(batch) - - # batch=next(iter(test_loader_dict['chemenv'])) - # print(batch) - - - - # print(batch['material'].node_id) - # print(model(batch)) - # Train and evaluate model - # for epoch in range(1): - # train_loss = train(model, optimizer, train_loader_dict, loss_fn=loss_fn) - # test_loss = train(model, optimizer, test_loader_dict, loss_fn=loss_fn) - # val_loss = evaluate(model, val_loader_dict) - - # print(f'Epoch {epoch+1}: Train Loss: {train_loss}, Validation Loss: {val_loss}') - - - # Evaluate on test data - # test_loss = evaluate(model, test_data) - # print(f'Test Loss: {test_loss}') - - - - - - - - - - - # for epoch in range(1): - # model.train() - # optimizer.zero_grad() - - # # - # x_dict = model(data) - # loss = negative_sampling_hetero_loss(x_dict, data.edge_index_dict) - # loss.backward() - # optimizer.step() - - # for epoch in range(1, 6): - # total_loss = total_examples = 0 - # for sampled_data in tqdm.tqdm(train_loader): - # optimizer.zero_grad() - # sampled_data.to(device) - # pred = model(sampled_data) - # ground_truth = sampled_data["user", "rates", "movie"].edge_label - # loss = F.binary_cross_entropy_with_logits(pred, ground_truth) - # loss.backward() - # optimizer.step() - # total_loss += float(loss) * pred.numel() - # total_examples += pred.numel() - # print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}") \ No newline at end of file diff --git a/matgraphdb/graph_kit/pyg/graph_trainer.py b/matgraphdb/graph_kit/pyg/graph_trainer.py deleted file mode 100644 index 96ce717..0000000 --- a/matgraphdb/graph_kit/pyg/graph_trainer.py +++ /dev/null @@ -1,303 +0,0 @@ -import os -import copy - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader -import torch_geometric.transforms as T - -from matgraphdb.graph_kit.pyg.callbacks import EarlyStopping, MetricsTacker, Checkpointer -from matgraphdb.graph_kit.pyg.models import WeightedRandomClassifier, MajorityClassClassifier -from matgraphdb.utils import ML_DIR,ML_SCRATCH_RUNS_DIR -from matgraphdb.utils import LOGGER - -class Trainer: - def __init__(self,train_dataset, test_dataset, model, loss_fn, optimizer, device, - run_path=ML_SCRATCH_RUNS_DIR, - run_name='scratch', - batch_size=64, - max_iters=100, - early_stopping_patience=5, - eval_interval=4 - ): - self.train_loader=DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) - self.test_loader=DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) - self.model=model - self.loss_fn=loss_fn - self.optimizer=optimizer - self.device=device - self.eval_interval=eval_interval - self.max_iters=max_iters - self.early_stopping_patience=early_stopping_patience - - self.run_path=run_path - 
self.run_dir=os.path.join(self.run_path,run_name) - os.makedirs(self.run_dir,exist_ok=True) - - self.es=EarlyStopping(patience = early_stopping_patience) - self.metrics_tacker=MetricsTacker(save_path=self.run_dir,is_regression=self.is_regression()) - self.checkpointer=Checkpointer(save_path=self.run_dir,verbose=1) - - self.meta_data={} - self.best_loss=None - self.best_model=None - - - # self.model.to(device) - - def initialize_meta_data(self): - self.meta_data['model_name']=self.model.__class__.__name__ - self.meta_data['loss_fn']=self.loss_fn.__class__.__name__ - self.meta_data['optimizer']=self.optimizer.__class__.__name__ - self.meta_data['eval_interval']=self.eval_interval - self.meta_data['max_iters']=self.max_iters - self.meta_data['early_stopping_patience']=self.early_stopping_patience - self.meta_data['eval_iterval']=self.eval_interval - self.meta_data['device']=self.device - - def get_class_weights(self,number_of_classes): - class_counts = torch.zeros(number_of_classes) # Replace `number_of_classes` with your actual number of classes - - for _,labels in self.train_loader: - labels=labels.to(self.device) - for label in labels: - class_counts[label] += 1 - - - # Prevent division by zero in case some class is not present at all - class_weights = 1. / (class_counts + 1e-5) - # Normalize weights so that the smallest weight is 1.0 - self.class_weights = class_weights / class_weights.min() - return class_counts,class_weights - - def is_regression(self): - return isinstance(self.loss_fn,nn.MSELoss) - - def get_num_classes(self): - if self.is_regression(): - return 1 - else: - return self.model.output_dim - - def calculate_loss(self,logits,y_true): - if self.is_regression(): - logits = logits[:,0] - train_loss = self.loss_fn(logits, y_true) - else: - train_loss = self.loss_fn(logits, y_true) - return train_loss - - def train_step(self,dataloader): - """ - Trains the model on the given dataloader. - - Args: - dataloader (DataLoader): The dataloader to train on. - - Returns: - float: The average loss per batch on the training data. - """ - num_batches = len(dataloader) - batch_train_loss = 0.0 - for i_batch, (X,y) in enumerate(dataloader): - X, y = X.to(self.device), y.to(self.device) - - logits = self.model(X) - - train_loss = self.calculate_loss(logits, y) - batch_train_loss += train_loss.item() - - # Backpropagation - self.optimizer.zero_grad(set_to_none=True) - train_loss.backward() - self.optimizer.step() - self.optimizer.zero_grad() - - batch_train_loss = batch_train_loss / num_batches - return batch_train_loss - - def test_step(self,dataloader): - """ - Tests the model on the given dataloader. - - Args: - dataloader (DataLoader): The dataloader to test on. - - Returns: - float: The average loss per batch on the test data. - """ - num_batches = len(dataloader) - self.model.eval() - batch_test_loss = 0.0 - with torch.no_grad(): - for i_batch,(X, y) in enumerate(dataloader): - X, y = X.to(self.device), y.to(self.device) - - logits = self.model(X) - batch_test_loss = self.calculate_loss(logits, y) - - batch_test_loss /= num_batches - return batch_test_loss - - def predict(self, dataloader, return_probabilities=False): - """ - Predicts the labels for the given dataloader. - - Args: - dataloader (DataLoader): The dataloader to predict on. - return_probabilities (bool, optional): Whether to return the probabilities of the predictions. Defaults to False. - - Returns: - list: A list of predicted labels. 
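# Minimal numeric check of the inverse-frequency weighting used in
# get_class_weights above (standalone sketch; the counts are invented):
import torch

class_counts = torch.tensor([2.0, 4.0, 6.0])          # three classes seen 2, 4 and 6 times
class_weights = 1.0 / (class_counts + 1e-5)           # rarer classes get larger weights
class_weights = class_weights / class_weights.min()   # most frequent class rescaled to 1.0
print(class_weights)                                  # ~tensor([3.0000, 1.5000, 1.0000])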
- """ - total_samples = len(dataloader.dataset) - batch_size = dataloader.batch_size - num_classes=self.get_num_classes() - - # Determine the size and type of the predictions tensor - if return_probabilities and not self.is_regression(): - predictions = torch.zeros(total_samples, num_classes, dtype=torch.float, device=self.device) - actual=torch.zeros(total_samples, num_classes, dtype=torch.float, device=self.device) - elif not self.is_regression(): - predictions = torch.zeros(total_samples, dtype=torch.long, device=self.device) - actual=torch.zeros(total_samples, dtype=torch.long, device=self.device) - else: - predictions = torch.zeros(total_samples, dtype=torch.float, device=self.device) - actual=torch.zeros(total_samples, dtype=torch.float, device=self.device) - - self.model.eval() - with torch.no_grad(): - sample_idx = 0 - for i_batch,(X, y) in enumerate(dataloader): - X, y = X.to(self.device), y.to(self.device) - - logits = self.model(X) - - if self.is_regression(): - batch_predictions = logits.squeeze() - else: - probailities = torch.sigmoid(logits) - if return_probabilities: - batch_predictions=probailities - else: - batch_predictions=probailities.argmax(1) - - # Calculate the number of predictions to store - batch_size_actual = X.size(0) - predictions[sample_idx:sample_idx + batch_size_actual] = batch_predictions - actual[sample_idx:sample_idx + batch_size_actual] = y - sample_idx += batch_size_actual - return actual,predictions - - def train(self): - LOGGER.info(f"___Starting training___") - - if self.model.__class__.__name__ in ['WeightedRandomClassifier','MajorityClassClassifier']: - batch_test_loss=self.test_step(self.test_loader) - self.max_iters=1 - # Main training loop - for iter in range(self.max_iters): - # Train step - batch_train_loss=self.train_step(self.train_loader) - # Test step - batch_test_loss=self.test_step(self.test_loader) - - - if iter%self.eval_interval==0: - - train_actual,train_predictions=self.predict(self.train_loader,return_probabilities=False) - test_actual,test_predictions=self.predict(self.test_loader,return_probabilities=False) - self.metrics_tacker.calculate_metrics(y_pred=train_predictions,y_true=train_actual, - batch_loss=batch_train_loss, - epoch=iter, - n_features=self.model.input_dim, - num_classes=self.model.output_dim, - split='train') - self.metrics_tacker.calculate_metrics(y_pred=test_predictions,y_true=test_actual, - batch_loss=batch_test_loss, - epoch=iter, - n_features=self.model.input_dim, - num_classes=self.model.output_dim, - split='test') - LOGGER.info(f"Train Loss: {batch_train_loss} | Test Loss: {batch_test_loss}") - - if self.es is not None: - if self.es(model=self.model, val_loss=batch_test_loss): - self.best_loss=batch_test_loss - self.best_model=copy.deepcopy(self.model) - LOGGER.info("Early stopping") - LOGGER.info('_______________________') - LOGGER.info(f'Stopping : {iter - self.es.counter}') - LOGGER.info(f'Best loss: {self.es.best_loss}') - - self.checkpointer.save_model(model=self.best_model,epoch=iter) - break - elif iter==self.max_iters-1: - self.checkpointer.save_model(model=self.model,epoch=iter) - LOGGER.info(f"___Ending training___") - - LOGGER.info(f"___Starting Saving metrics___") - self.metrics_tacker.save_metrics() - LOGGER.info(f"___Ending saving metrics___") - - -if __name__=='__main__': - from matgraphdb.mlcore.models import MultiLayerPerceptron - - - from matgraphdb.mlcore.datasets import MaterialGraphDataset - - graph_dataset=MaterialGraphDataset.ec_element_chemenv( - use_weights=False, - 
use_node_properties=True, - #,properties=['group','atomic_number'] - ) - print(graph_dataset.data) - - - rev_edge_types=[] - edge_types=[] - for edge_type in graph_dataset.data.edge_types: - rel_type=edge_type[1] - if 'rev' in rel_type: - rev_edge_types.append(edge_type) - else: - edge_types.append(edge_type) - print(edge_types) - print(rev_edge_types) - transform = T.RandomLinkSplit( - num_val=0.1, - num_test=0.1, - disjoint_train_ratio=0.3, - neg_sampling_ratio=2.0, - add_negative_train_samples=False, - edge_types=edge_types, - rev_edge_types=rev_edge_types, - ) - train_data, val_data, test_data = transform(graph_dataset.data) - print("Train Data") - print("-"*200) - print(train_data) - print("Val Data") - print("-"*200) - print(val_data) - print("Test Data") - print("-"*200) - print(test_data) - - - # device= "cuda:0" if torch.cuda.is_available() else torch.device("cpu") - - # model=MultiLayerPerceptron(input_dim=28*28,output_dim=10,num_layers=1,n_embd=128) - # loss_fn=nn.CrossEntropyLoss() - - # optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) - # trainer=Trainer(train_dataset,test_dataset,model,loss_fn,optimizer,device, - # eval_interval=2, - # max_iters=4, - # batch_size=128, - # early_stopping_patience=5) - # trainer.train() - # # print(trainer.get_num_classes()) - # # predictions=trainer.predict(test_loader,return_probabilities=False) - # # print(predictions.shape) diff --git a/matgraphdb/graph_kit/pyg/loss.py b/matgraphdb/graph_kit/pyg/loss.py deleted file mode 100644 index be827f8..0000000 --- a/matgraphdb/graph_kit/pyg/loss.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -from torch_geometric.data import Data -from torch_geometric.nn import SAGEConv, global_mean_pool -from torch_geometric.utils import negative_sampling -from torch_cluster import random_walk -import torch.nn.functional as F - -import random - -def negative_sampling_loss(z, edge_index, num_neg_samples=10): - src, pos = edge_index[0], edge_index[1] - neg = torch.randint(0, z.size(0), (num_neg_samples * src.size(0),), dtype=torch.long) - - pos_loss = -torch.log(torch.sigmoid((z[src] * z[pos]).sum(dim=-1))).mean() - neg_loss = -torch.log(1 - torch.sigmoid((z[src] * z[neg]).sum(dim=-1))).mean() - - return pos_loss + neg_loss - - -def positive_sampling_hetero_random_walk_loss(z_dict, edge_index_dict, num_walks=10, walk_length=3, weighted=True,device=None): - type_loss=torch.zeros(size=(len(z_dict),),device=device) - - for i_node_type,(node_type, z_src) in enumerate(z_dict.items()): - for node in z_src[:1,:]: - - - embd_size=node.shape[0] - neighbors=torch.zeros(size=(num_walks,embd_size),device=device) - for i_walk in range(num_walks): - # Resetting the current node type and node to initial node - current_node_type=node_type - current_node=node - - # Iterating over the walk length - for i_step in range(walk_length): - edge_types=[] - weights=[] - for edge_type, edge_index in edge_index_dict.items(): - src_type, rel_type, dst_type = edge_type[0],edge_type[1], edge_type[2] - if current_node_type == src_type: - edge_types.append(edge_type) - weights.append(edge_index.shape[1]) - - # Randomly selecting the step type - if weighted: - weights = torch.tensor(weights, dtype=torch.float) - else: - weights = torch.ones(len(weights), dtype=torch.float) - random_step_type_index = torch.multinomial(weights, 1).item() - step_edge_type = edge_types[random_step_type_index] - src_type, rel_type, dst_type = step_edge_type[0],step_edge_type[1], step_edge_type[2] - step_edge_index = edge_index_dict[step_edge_type] - - - # 
Randomly selecting the step from the selected edge type - random_step_weights=torch.ones(size=(step_edge_index.shape[1],)) - random_step_index = torch.multinomial(random_step_weights, 1).item() - src_step_index,dst_step_index=step_edge_index[:,random_step_index] - - current_node_type=dst_type - current_node=z_dict[current_node_type][dst_step_index] - - neighbors[i_walk,:]=current_node - - matmul=torch.matmul(node,neighbors.T) - matmul = torch.nan_to_num(matmul, nan=0.0) - - type_loss[i_node_type] += -torch.log(torch.sigmoid( matmul.mean() ) ) - pos_loss=type_loss.sum() - return pos_loss - - - -def negative_sampling_hetero_loss(z_dict, edge_index_dict, num_neg_samples=10, method = "sparse"): - total_loss = 0 - - for node_type, z_src in z_dict.items(): - # pos_batch = random_walk(row, col, batch, - # walk_length=1, - # coalesced=False)[:, 1] - - for edge_type, edge_index in edge_index_dict.items(): - src_type, rel_type, dst_type = edge_type[0],edge_type[1], edge_type[2] - if node_type == src_type: - pos_src, pos_dst = edge_index[0], edge_index[1] - # print(pos_src.shape) - - print(z_src.shape) - - pass - # for edge_type, edge_index in edge_index_dict.items(): - # pos_src, pos_dst = edge_index[0], edge_index[1] - # src_type, rel_type, dst_type = edge_type[0],edge_type[1], edge_type[2] - # # src_z, dst_z = z_dict[src_type][pos_src], z_dict[dst_type][pos_dst] - # # print(src_z.shape) - # # print(dst_z.shape) - # # torch.einsum('aj,bj->i', src_z, dst_z) - # # result=src_z* dst_z - # # print(result.shape) - # neg_sample_edge_index=negative_sampling( - # edge_index=edge_index, - # # num_nodes = test_data['material', 'has', 'element'].num_nodes, - # num_neg_samples = num_neg_samples, - # # method = method, - # # force_undirected = True, - # ) - # neg_dst, neg_src = neg_sample_edge_index[0], neg_sample_edge_index[1] - # neg_src_z, neg_dst_z = z_dict[src_type][neg_src], z_dict[dst_type][neg_dst] - # print(neg_dst_z.shape) - - - # neg_dst = torch.randint(0, dst_z.size(0), (num_neg_samples * src.size(0),), dtype=torch.long) - - # pos_loss = -torch.log(torch.sigmoid((src_z * dst_z).sum(dim=-1))).mean() - # neg_loss = -torch.log(1 - torch.sigmoid((src_z * dst_z[neg_dst]).sum(dim=-1))).mean() - # total_loss += pos_loss + neg_loss - # return total_loss / len(edge_index_dict) \ No newline at end of file diff --git a/matgraphdb/graph_kit/pyg/metrics.py b/matgraphdb/graph_kit/pyg/metrics.py deleted file mode 100644 index 71a970f..0000000 --- a/matgraphdb/graph_kit/pyg/metrics.py +++ /dev/null @@ -1,261 +0,0 @@ -import torch -import torcheval.metrics.functional as FEVAL - -class RegressionMetrics(): - def mean_absolute_error(y_pred, y_true): - """ - Calculates the mean absolute error. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The mean absolute error. - """ - return torch.mean(torch.abs(y_pred - y_true)) - - def root_mean_squared_error(y_pred, y_true): - """ - Calculates the root mean squared error. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The root mean squared error. - """ - return torch.sqrt(torch.mean((y_pred - y_true) ** 2)) - - def mean_squared_logarithmic_error(y_pred, y_true): - """ - Calculates the mean squared logarithmic error. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The mean squared logarithmic error. 
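# Quick sanity check of the regression metrics defined above on invented
# values. Note that the log1p-based MSLE implicitly assumes the inputs are
# greater than -1 (in practice, non-negative targets and predictions).
import torch

y_true = torch.tensor([1.0, 2.0, 4.0])
y_pred = torch.tensor([1.5, 2.0, 3.0])

mae  = torch.mean(torch.abs(y_pred - y_true))              # 0.5
rmse = torch.sqrt(torch.mean((y_pred - y_true) ** 2))      # ~0.645
msle = torch.mean((torch.log1p(y_pred) - torch.log1p(y_true)) ** 2)
print(mae.item(), rmse.item(), msle.item())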
- """ - return torch.mean((torch.log1p(y_pred) - torch.log1p(y_true)) ** 2) - - def r_squared(y_pred, y_true): - """ - Calculates the R-squared. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The R-squared. - """ - ss_res = torch.sum((y_true - y_pred) ** 2) - ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2) - return 1 - ss_res / ss_tot - - def adjusted_r_squared(y_pred, y_true, n_features): - """ - Calculates the adjusted R-squared. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - n_features (int): The number of features. - - Returns: - torch.Tensor: The adjusted R-squared. - """ - n = len(y_true) - r2 = RegressionMetrics.r_squared(y_pred, y_true) - return 1 - (1 - r2) * (n - 1) / (n - n_features - 1) - - def mean_absolute_percentage_error(y_pred, y_true): - """ - Calculates the mean absolute percentage error. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The mean absolute percentage error. - """ - return torch.mean(torch.abs((y_true - y_pred) / y_true)) * 100 - - def median_absolute_error(y_pred, y_true): - """ - Calculates the median absolute error. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The median absolute error. - """ - return torch.median(torch.abs(y_pred - y_true)) - - def explained_variance_score(y_pred, y_true): - """ - Calculates the explained variance score. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - - Returns: - torch.Tensor: The explained variance score. - """ - variance_y_true = torch.var(y_true) - variance_y_pred = torch.var(y_pred) - return 1 - (variance_y_true - variance_y_pred) / variance_y_true - - def huber_loss(y_pred, y_true, delta=1.0): - """ - Calculates the Huber loss. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - delta (float, optional): The delta value. Defaults to 1.0. The smaller the delta, the more the loss is penalized for large errors. - - Returns: - torch.Tensor: The Huber loss. - """ - error = y_true - y_pred - is_small_error = torch.abs(error) < delta - squared_loss = 0.5 * error ** 2 - linear_loss = delta * (torch.abs(error) - 0.5 * delta) - return torch.where(is_small_error, squared_loss, linear_loss).mean() - - def quantile_loss(y_pred, y_true, quantile=0.5): - """ - Calculates the quantile loss. - - Args: - y_pred (torch.Tensor): The predicted values. - y_true (torch.Tensor): The true values. - quantile (float, optional): The quantile value. Defaults to 0.5. - - Returns: - torch.Tensor: The quantile loss. - """ - error = y_true - y_pred - loss = torch.max((quantile - 1) * error, quantile * error) - return torch.mean(loss) - - - -class ClassificationMetrics(): - - def class_weights( y_true): - class_counts = torch.bincount(y_true) - class_weights = 1. / class_counts.float() # Convert to float to perform division - return class_weights - - def accuracy(y_pred, y_true): - """ Computes the accuracy of the classifier. """ - correct = y_pred.eq(y_true).sum() - return correct.float() / y_true.numel() - - def precision(y_pred, y_true): - """ Computes the precision of the classifier for binary classification. 
""" - true_positives = (y_pred * y_true).sum().float() - predicted_positives = y_pred.sum().float() - return true_positives / predicted_positives if predicted_positives != 0 else 0.0 - - def multi_class_accuracy(confusion_matrix=None,y_pred=None, y_true=None, num_classes=None): - """ Computes the accuracy of the classifier for multi-class classification. """ - # Precision: Diagonal elements / sum of respective column elements - if confusion_matrix is None: - conf_matrix=ClassificationMetrics.confusion_matrix(y_pred, y_true, num_classes) - else: - conf_matrix=confusion_matrix - - # Calculate correct predictions per class - correct_predictions = torch.diag(conf_matrix) - - # Total actual instances for each class (sum over each row) - total_true = conf_matrix.sum(dim=1) - per_class_acc = correct_predictions / total_true - - return per_class_acc - - - def multiclass_precision(confusion_matrix=None,y_pred=None, y_true=None, num_classes=None): - """ Computes the precision of the classifier for multi-class classification. """ - # Precision: Diagonal elements / sum of respective column elements - if confusion_matrix is None: - conf_matrix=ClassificationMetrics.confusion_matrix(y_pred, y_true, num_classes) - else: - conf_matrix=confusion_matrix - precision = torch.diag(conf_matrix) / conf_matrix.sum(0) - precision[torch.isnan(precision)] = 0 # handle NaNs due to division by zero - return precision - - def multiclass_recall(confusion_matrix=None,y_pred=None, y_true=None, num_classes=None): - """ Computes the recall of the classifier for multi-class classification. """ - if confusion_matrix is None: - conf_matrix=ClassificationMetrics.confusion_matrix(y_pred, y_true, num_classes) - else: - conf_matrix=confusion_matrix - recall = torch.diag(conf_matrix) / conf_matrix.sum(1) - recall[torch.isnan(recall)] = 0 # handle NaNs - return recall - - def multiclass_f1_score(confusion_matrix=None,y_pred=None, y_true=None, num_classes=None): - """ Computes the F1 score of the classifier for multi-class classification. """ - if confusion_matrix is None: - conf_matrix=ClassificationMetrics.confusion_matrix(y_pred, y_true, num_classes) - else: - conf_matrix=confusion_matrix - precision = ClassificationMetrics.multiclass_precision(conf_matrix) - recall = ClassificationMetrics.multiclass_recall(conf_matrix) - # F1 Score: Harmonic mean of precision and recall - f1 = 2 * (precision * recall) / (precision + recall) - f1[torch.isnan(f1)] = 0 # handle NaNs - return f1 - - def recall(y_pred, y_true): - """ Computes the recall of the classifier for binary classification. """ - true_positives = (y_pred * y_true).sum().float() - actual_positives = y_true.sum().float() - return true_positives / actual_positives if actual_positives != 0 else 0.0 - - def f1_score(y_pred, y_true): - """ Computes the F1 score of the classifier for binary classification. """ - prec = ClassificationMetrics.precision(y_pred, y_true) - rec = ClassificationMetrics.recall(y_pred, y_true) - return 2 * (prec * rec) / (prec + rec) if (prec + rec) != 0 else 0.0 - - def confusion_matrix(y_pred, y_true, num_classes): - - return FEVAL.multiclass_confusion_matrix(y_pred, y_true, num_classes) - - - def roc_auc_score(y_pred, y_true): - """ Computes ROC AUC score for binary classification. 
""" - # This requires sklearn, as PyTorch does not have native AUC computation - from sklearn.metrics import roc_auc_score - y_true_np = y_true.cpu().detach().numpy() - y_scores_np = y_pred.cpu().detach().numpy()[:, 1] # Probabilities for the positive class - return roc_auc_score(y_true_np, y_scores_np) - - def log_loss(y_pred, y_true): - """ Computes the log loss (cross-entropy loss), assuming y_pred are probabilities. """ - return torch.nn.functional.binary_cross_entropy(y_pred, y_true) - - def matthews_corrcoef(y_pred, y_true): - """ Computes the Matthews correlation coefficient for binary classification. """ - conf_matrix = ClassificationMetrics.confusion_matrix(y_pred, y_true, 2) - tp = conf_matrix[1, 1].float() - tn = conf_matrix[0, 0].float() - fp = conf_matrix[0, 1].float() - fn = conf_matrix[1, 0].float() - - numerator = (tp * tn - fp * fn) - denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) - return numerator / denominator if denominator != 0 else 0.0 diff --git a/matgraphdb/graph_kit/pyg/models.py b/matgraphdb/graph_kit/pyg/models.py deleted file mode 100644 index f51e0aa..0000000 --- a/matgraphdb/graph_kit/pyg/models.py +++ /dev/null @@ -1,201 +0,0 @@ -# Define the MLP class -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.nn import SAGEConv -class FeedFoward(nn.Module): - """ a simple linear layer followed by a non-linearity """ - - def __init__(self, n_embd, dropout=0.0): - super().__init__() - - - if dropout>0.0: - self.net = nn.Sequential( - nn.Linear(n_embd, 4 * n_embd), - nn.ReLU(), - nn.Linear(4 * n_embd, n_embd), - nn.Dropout(dropout), - ) - else: - self.net = nn.Sequential( - nn.Linear(n_embd, 4 * n_embd), - nn.ReLU(), - nn.Linear(4 * n_embd, n_embd), - ) - self.ln=nn.LayerNorm(n_embd) - - def forward(self, x): - return self.net(self.ln(x)) - -class InputLayer(nn.Module): - def __init__(self,input_dim,n_embd): - super().__init__() - self.flatten = nn.Flatten() - self.proj=nn.Linear(input_dim, n_embd) - - def forward(self, x): - out=self.flatten(x) - return self.proj(out) - -class MultiLayerPerceptron(nn.Module): - - def __init__(self, input_dim, output_dim, num_layers, n_embd): - super().__init__() - self.input_dim=input_dim - self.output_dim=output_dim - - self.input_layer=InputLayer(input_dim,n_embd) - self.layers = nn.ModuleList([FeedFoward(n_embd) for _ in range(num_layers)]) - - self.ln_f=nn.LayerNorm(n_embd) - self.output_layer=nn.Linear(n_embd,self.output_dim) - - - def forward(self, x): - out=self.input_layer(x) - for layer in self.layers: - out = out + layer(out) - out=self.ln_f(out) - out=self.output_layer(out) - return out - - - - - -class Block(nn.Module): - """Conv Block: communication followed by computation""" - def __init__(self, in_channels, out_channels, - graph_conv=SAGEConv, - ffwd_params={ - 'dropout': 0.0, - }, - conv_params={ - 'aggr': 'add', - 'normalize': False, - 'root_weight': False, - 'project': False, - 'bias': True}, - ): - super().__init__() - - self.comm=graph_conv(in_channels, out_channels, **conv_params) - self.ffwd = FeedFoward(out_channels, **ffwd_params) - - def forward(self, x): - x = self.comm(x, x) - x = self.ffwd(x) - return x - -class StackedConv(nn.Module): - def __init__(self, n_embd, - num_layers=1, - graph_conv=SAGEConv, - conv_params={ - 'aggr': 'add', - 'normalize': False, - 'root_weight': False, - 'project': False, - 'bias': True}, - ): - super().__init__() - - self.activtion=F.relu - self.convs = nn.ModuleList([ - graph_conv(n_embd, n_embd, **conv_params) 
- for _ in range(num_layers - 2) - ]) - - def forward(self, x, edge_index, **kwargs): - for conv in self.convs: - x = self.activtion(conv(x, edge_index, **kwargs)) - return x - - -class LinearRegressor(nn.Module): - def __init__(self, input_dim, output_dim): - super(LinearRegressor, self).__init__() - # Define the parameters / weights of the model - self.linear = nn.Linear( input_dim, output_dim) # Assuming x and y are single-dimensional - - def forward(self, x): - return self.linear(x) - -class WeightedRandomClassifier(nn.Module): - def __init__(self, class_counts): - super().__init__() - self.class_counts = class_counts - - def forward(self, x): - # Generate random guesses according to the class weights for each example in the batch - random_guesses = torch.multinomial(self.class_counts, x.size(0), replacement=True) - # Convert indices to one-hot encoding - - return F.one_hot(random_guesses, num_classes=len(self.class_counts)).to(torch.float32) - - -class MajorityClassClassifier(nn.Module): - def __init__(self, majority_class, num_classes): - super().__init__() - self.majority_class = majority_class - self.num_classes = num_classes - - def forward(self, x): - # Return the majority class for each example in the batch - majority_class_tensor = torch.full((x.size(0),), self.majority_class, dtype=torch.long) - # Convert indices to one-hot encoding - return F.one_hot(majority_class_tensor, num_classes=self.num_classes).to(torch.float32) - - -def test_baseline_classifiers(): - from matgraphdb.graph_kit.pyg.metrics import ClassificationMetrics - # Assuming class_weights for 3 classes - class_counts= torch.tensor([2.0,4.0,6.0]) - - # Assuming the majority class is class index 1 - majority_class = torch.argmax(class_counts) - - # Initialize classifiers - weighted_random_classifier = WeightedRandomClassifier(class_counts) - majority_class_classifier = MajorityClassClassifier(majority_class,num_classes=len(class_counts)) - - # Dummy input (batch size of 10, features size of 5) - dummy_input = torch.randn(10, 5) - - # Get predictions - weighted_random_logits = weighted_random_classifier(dummy_input) - majority_class_logits = majority_class_classifier(dummy_input) - - - weighted_random_probailities = torch.sigmoid(weighted_random_logits) - majority_class_probailities = torch.sigmoid(majority_class_logits) - - weighted_random_predictions = weighted_random_probailities.argmax(1) - majority_class_predictions = majority_class_probailities.argmax(1) - - - - print("Random Predictions:", weighted_random_predictions) - print("Majority Class Predictions:", majority_class_predictions) - - -def pytorch_geometric_test(): - # Check to see if pytorch geometric gpu is available - if torch.cuda.is_available(): - print("GPU is available") - device = torch.device("cuda") - else: - print("GPU is not available") - device = torch.device("cpu") - - - -if __name__ == "__main__": - - test_baseline_classifiers() - pytorch_geometric_test() - - - - \ No newline at end of file diff --git a/matgraphdb/graph_kit/pyg/plot_metrics.py b/matgraphdb/graph_kit/pyg/plot_metrics.py deleted file mode 100644 index 94f3170..0000000 --- a/matgraphdb/graph_kit/pyg/plot_metrics.py +++ /dev/null @@ -1,88 +0,0 @@ -import os - -import matplotlib.pyplot as plt -import numpy as np - -def plot_multiclass_metrics(save_dir,metrics_data,metric,epochs): - num_classes = len(metrics_data['train'][metric][0]) # Number of classes inferred from the length of the first list - plt.figure(figsize=(10, 6)) - for class_idx in range(num_classes): - - train_values 
= [epoch[class_idx] for epoch in metrics_data['train'][metric]] - test_values = [epoch[class_idx] for epoch in metrics_data['test'][metric]] - - plt.plot(epochs, train_values, label=f'Train {metric} Class {class_idx}') - plt.plot(epochs, test_values, label=f'Test {metric} Class {class_idx}') - - plt.title(f'{metric.capitalize()}') - plt.xlabel('Epochs') - plt.ylabel(metric.capitalize()) - plt.legend() - plt.grid(True) - - plot_filename = os.path.join(save_dir, f"{metric}.png") - plt.savefig(plot_filename) - plt.close() - - - -def plot_metrics(metrics_data, save_dir): - """ - Plots the training and test loss curves for each metric in the given metrics data, - and saves the plots to the specified directory. - - Args: - metrics_data (dict): A dictionary containing 'train' and 'test' metrics data. - save_dir (str): Path to the directory where the plots will be saved. - """ - if not os.path.exists(save_dir): - os.makedirs(save_dir) # Create the directory if it does not exist - - # Assuming there are epochs data and the same metrics in train and test - epochs = range(len(metrics_data['train']['accuracy'])) # Example to generate epoch data if not provided - - # Handling multiclass metrics - multiclass_metrics = ['precision', 'recall', 'f1'] - - for metric in metrics_data['train']: - if metric not in ['confusion_matrix', 'epoch']: # Exclude non-numeric metrics - # Single value metrics like accuracy and batch_loss - train_values = np.array(metrics_data['train'][metric]) - test_values = np.array(metrics_data['test'][metric]) - - std_train=None - std_test=None - if metric in multiclass_metrics: - # Compute the mean and standard deviation for each class - std_train = np.std(train_values, axis=1) - train_values = np.mean(train_values, axis=1) - std_test = np.std(test_values, axis=1) - test_values = np.mean(test_values, axis=1) - - # Single value metrics like accuracy and batch_loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, train_values, label=f'Train {metric}') - plt.plot(epochs, test_values, label=f'Test {metric}') - if std_train is not None: - plt.errorbar(epochs, train_values, yerr=std_train, fmt='o', color='blue', ecolor='lightgray', elinewidth=3, capsize=0) - plt.errorbar(epochs, test_values, yerr=std_test, fmt='o', color='blue', ecolor='lightgray', elinewidth=3, capsize=0) - - plt.title(metric.capitalize()) - plt.xlabel('Epochs') - plt.ylabel(metric.capitalize()) - plt.legend() - plt.grid(True) - - plot_filename = os.path.join(save_dir, f"{metric}_plot.png") - plt.savefig(plot_filename) - plt.close() - -if __name__=='__main__': - import json - - save_dir='data/production/materials_project/ML/scratch_runs/scratch' - with open(os.path.join(save_dir,'metrics.json'), 'r') as f: - metrics_data = json.load(f) - plot_metrics(metrics_data,os.path.join(save_dir,'plots')) - - \ No newline at end of file diff --git a/matgraphdb/graph_kit/pyg/trainer.py b/matgraphdb/graph_kit/pyg/trainer.py deleted file mode 100644 index 274aa75..0000000 --- a/matgraphdb/graph_kit/pyg/trainer.py +++ /dev/null @@ -1,285 +0,0 @@ -import os -import copy - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader - -from matgraphdb.graph_kit.pyg.callbacks import EarlyStopping, MetricsTacker, Checkpointer -from matgraphdb.graph_kit.pyg.models import WeightedRandomClassifier, MajorityClassClassifier -from matgraphdb.utils import ML_DIR,ML_SCRATCH_RUNS_DIR -from matgraphdb.utils import LOGGER - -class Trainer: - def __init__(self,train_dataset, test_dataset, model, loss_fn, optimizer, device, 
run_path=ML_SCRATCH_RUNS_DIR, run_name='scratch', - batch_size=64, - max_iters=100, - early_stopping_patience=5, - eval_interval=4 - ): - self.train_loader=DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) - self.test_loader=DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) - self.model=model - self.loss_fn=loss_fn - self.optimizer=optimizer - self.device=device - self.eval_interval=eval_interval - self.max_iters=max_iters - self.early_stopping_patience=early_stopping_patience - - self.run_path=run_path - self.run_dir=os.path.join(self.run_path,run_name) - os.makedirs(self.run_dir,exist_ok=True) - - self.es=EarlyStopping(patience = early_stopping_patience) - self.metrics_tacker=MetricsTacker(save_path=self.run_dir,is_regression=self.is_regression()) - self.checkpointer=Checkpointer(save_path=self.run_dir,verbose=1) - - - self.meta_data={} - self.best_loss=None - self.best_model=None - - - # self.model.to(device) - - def initialize_meta_data(self): - self.meta_data['model_name']=self.model.__class__.__name__ - self.meta_data['loss_fn']=self.loss_fn.__class__.__name__ - self.meta_data['optimizer']=self.optimizer.__class__.__name__ - self.meta_data['eval_interval']=self.eval_interval - self.meta_data['max_iters']=self.max_iters - self.meta_data['early_stopping_patience']=self.early_stopping_patience - self.meta_data['eval_iterval']=self.eval_interval - self.meta_data['device']=self.device - - def get_class_weights(self,number_of_classes): - class_counts = torch.zeros(number_of_classes) # Replace `number_of_classes` with your actual number of classes - - for _,labels in self.train_loader: - labels=labels.to(self.device) - for label in labels: - class_counts[label] += 1 - - - # Prevent division by zero in case some class is not present at all - class_weights = 1. / (class_counts + 1e-5) - # Normalize weights so that the smallest weight is 1.0 - self.class_weights = class_weights / class_weights.min() - return class_counts,class_weights - - - - - if not self.is_regression(): - pass - - def is_regression(self): - return isinstance(self.loss_fn,nn.MSELoss) - - def get_num_classes(self): - if self.is_regression(): - return 1 - else: - return self.model.output_dim - - def calculate_loss(self,logits,y_true): - if self.is_regression(): - logits = logits[:,0] - train_loss = self.loss_fn(logits, y_true) - else: - train_loss = self.loss_fn(logits, y_true) - return train_loss - - def train_step(self,dataloader): - """ - Trains the model on the given dataloader. - - Args: - dataloader (DataLoader): The dataloader to train on. - - Returns: - float: The average loss per batch on the training data. - """ - num_batches = len(dataloader) - batch_train_loss = 0.0 - for i_batch, (X,y) in enumerate(dataloader): - X, y = X.to(self.device), y.to(self.device) - - logits = self.model(X) - - train_loss = self.calculate_loss(logits, y) - batch_train_loss += train_loss.item() - - # Backpropagation - self.optimizer.zero_grad(set_to_none=True) - train_loss.backward() - self.optimizer.step() - self.optimizer.zero_grad() - - batch_train_loss = batch_train_loss / num_batches - return batch_train_loss - - def test_step(self,dataloader): - """ - Tests the model on the given dataloader. - - Args: - dataloader (DataLoader): The dataloader to test on. - - Returns: - float: The average loss per batch on the test data. 
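# get_class_weights above returns per-class weights, but in this file the
# loss function is supplied by the caller. If weighting were desired, torch's
# CrossEntropyLoss accepts such a tensor directly (sketch with invented values):
import torch
import torch.nn as nn

class_weights = torch.tensor([3.0, 1.5, 1.0])   # e.g. the normalized weights computed above
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(4, 3)                      # batch of 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 0])
print(loss_fn(logits, targets))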
- """ - num_batches = len(dataloader) - self.model.eval() - batch_test_loss = 0.0 - with torch.no_grad(): - for i_batch,(X, y) in enumerate(dataloader): - X, y = X.to(self.device), y.to(self.device) - - logits = self.model(X) - batch_test_loss = self.calculate_loss(logits, y) - - batch_test_loss /= num_batches - return batch_test_loss - - def predict(self, dataloader, return_probabilities=False): - """ - Predicts the labels for the given dataloader. - - Args: - dataloader (DataLoader): The dataloader to predict on. - return_probabilities (bool, optional): Whether to return the probabilities of the predictions. Defaults to False. - - Returns: - list: A list of predicted labels. - """ - total_samples = len(dataloader.dataset) - batch_size = dataloader.batch_size - num_classes=self.get_num_classes() - - # Determine the size and type of the predictions tensor - if return_probabilities and not self.is_regression(): - predictions = torch.zeros(total_samples, num_classes, dtype=torch.float, device=self.device) - actual=torch.zeros(total_samples, num_classes, dtype=torch.float, device=self.device) - elif not self.is_regression(): - predictions = torch.zeros(total_samples, dtype=torch.long, device=self.device) - actual=torch.zeros(total_samples, dtype=torch.long, device=self.device) - else: - predictions = torch.zeros(total_samples, dtype=torch.float, device=self.device) - actual=torch.zeros(total_samples, dtype=torch.float, device=self.device) - - self.model.eval() - with torch.no_grad(): - sample_idx = 0 - for i_batch,(X, y) in enumerate(dataloader): - X, y = X.to(self.device), y.to(self.device) - - logits = self.model(X) - - if self.is_regression(): - batch_predictions = logits.squeeze() - else: - probailities = torch.sigmoid(logits) - if return_probabilities: - batch_predictions=probailities - else: - batch_predictions=probailities.argmax(1) - - # Calculate the number of predictions to store - batch_size_actual = X.size(0) - predictions[sample_idx:sample_idx + batch_size_actual] = batch_predictions - actual[sample_idx:sample_idx + batch_size_actual] = y - sample_idx += batch_size_actual - return actual,predictions - - def train(self): - LOGGER.info(f"___Starting training___") - - if self.model.__class__.__name__ in ['WeightedRandomClassifier','MajorityClassClassifier']: - batch_test_loss=self.test_step(self.test_loader) - self.max_iters=1 - # Main training loop - for iter in range(self.max_iters): - # Train step - batch_train_loss=self.train_step(self.train_loader) - # Test step - batch_test_loss=self.test_step(self.test_loader) - - - if iter%self.eval_interval==0: - - train_actual,train_predictions=self.predict(self.train_loader,return_probabilities=False) - test_actual,test_predictions=self.predict(self.test_loader,return_probabilities=False) - self.metrics_tacker.calculate_metrics(y_pred=train_predictions,y_true=train_actual, - batch_loss=batch_train_loss, - epoch=iter, - n_features=self.model.input_dim, - num_classes=self.model.output_dim, - split='train') - self.metrics_tacker.calculate_metrics(y_pred=test_predictions,y_true=test_actual, - batch_loss=batch_test_loss, - epoch=iter, - n_features=self.model.input_dim, - num_classes=self.model.output_dim, - split='test') - LOGGER.info(f"Train Loss: {batch_train_loss} | Test Loss: {batch_test_loss}") - - if self.es is not None: - if self.es(model=self.model, val_loss=batch_test_loss): - self.best_loss=batch_test_loss - self.best_model=copy.deepcopy(self.model) - LOGGER.info("Early stopping") - LOGGER.info('_______________________') - 
LOGGER.info(f'Stopping : {iter - self.es.counter}') - LOGGER.info(f'Best loss: {self.es.best_loss}') - - self.checkpointer.save_model(model=self.best_model,epoch=iter) - break - elif iter==self.max_iters-1: - self.checkpointer.save_model(model=self.model,epoch=iter) - LOGGER.info(f"___Ending training___") - - LOGGER.info(f"___Starting Saving metrics___") - self.metrics_tacker.save_metrics() - LOGGER.info(f"___Ending saving metrics___") - - -if __name__=='__main__': - from matgraphdb.mlcore.models import MultiLayerPerceptron - from matgraphdb.mlcore.trainer import Trainer - - from torchvision import transforms - import torchvision - from torch.utils.data import DataLoader - - - import torch - import torch.nn as nn - import torch.optim as optim - - - - - # Data preprocessing and loading - transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) - - train_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=True, transform=transform, download=True) - test_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=False, transform=transform) - - - device= "cuda:0" if torch.cuda.is_available() else torch.device("cpu") - - model=MultiLayerPerceptron(input_dim=28*28,output_dim=10,num_layers=1,n_embd=128) - loss_fn=nn.CrossEntropyLoss() - - optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) - trainer=Trainer(train_dataset,test_dataset,model,loss_fn,optimizer,device, - eval_interval=2, - max_iters=4, - batch_size=128, - early_stopping_patience=5) - trainer.train() - # print(trainer.get_num_classes()) - # predictions=trainer.predict(test_loader,return_probabilities=False) - # print(predictions.shape) diff --git a/matgraphdb/graph_kit/pyg/transforms.py b/matgraphdb/graph_kit/pyg/transforms.py deleted file mode 100644 index 97072b2..0000000 --- a/matgraphdb/graph_kit/pyg/transforms.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from torch_geometric.data import HeteroData - -def min_max_normalize(tensor:torch.Tensor, normalization_range=(0,1), tensor_min=None, tensor_max=None): - if tensor_min is None: - tensor_min = tensor.min() - if tensor_max is None: - tensor_max = tensor.max() - - - values_minus_min = tensor - tensor_min - old_range_diff = tensor_max - tensor_min - new_range_diff = normalization_range[1] - normalization_range[0] - new_min = normalization_range[0] - return ( values_minus_min / old_range_diff ) * new_range_diff + new_min, tensor_min, tensor_max - -def standardize_tensor(tensor, epsilon=1e-8, mean=None, std=None): - if mean is None: - mean = torch.mean(tensor) - if std is None: - std = torch.std(tensor) - return (tensor - mean) / (std + epsilon), mean, std - - -def robust_scale(tensor, q_min=0.25, q_max=0.75, epsilon=1e-8, median=None, iqr=None): - q1 = torch.quantile(tensor, q_min) - q3 = torch.quantile(tensor, q_max) - if iqr is None: - iqr = q3 - q1 - if median is None: - median = torch.median(tensor) - return (tensor - median) / (iqr + epsilon), median, iqr \ No newline at end of file diff --git a/matgraphdb/graph_kit/relationships.py b/matgraphdb/graph_kit/relationships.py deleted file mode 100644 index 6f2d2d4..0000000 --- a/matgraphdb/graph_kit/relationships.py +++ /dev/null @@ -1,1400 +0,0 @@ -from glob import glob -import os - -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq -import logging - -from matgraphdb.utils.chem_utils.periodic import get_group_period_edge_index -from matgraphdb.graph_kit.metadata import 
get_relationship_schema -from matgraphdb.graph_kit.metadata import RelationshipTypes -from matgraphdb.graph_kit.nodes import NodeManager - -logger = logging.getLogger(__name__) - -class Relationships: - """ - A class for managing and creating relationships between nodes in a graph database. This class handles - relationship creation, loading, validation, and exporting relationships into different formats, such as - Parquet and Neo4j CSV. It utilizes a NodeManager instance to handle node-related operations and can be - extended to implement custom relationship creation logic. - - Attributes - ---------- - relationship_type : str - A string representing the relationship type in the format 'start_node-relationship-end_node'. - relationship_dir : str - The directory where the relationships are stored. - node_manager : NodeManager - An instance of NodeManager responsible for handling node operations. - output_format : str - The format used for reading and writing data ('pandas' or 'pyarrow'). Default is 'pandas'. - file_type : str - The file format used for saving relationships. Currently, 'parquet' is used. - filepath : str - The full path of the file where the relationships are stored. - schema : object - The schema definition for the relationships, created by the `create_schema` method. - """ - def __init__(self, relationship_type, relationship_dir, node_dir, output_format='pandas'): - """ - Initializes the Relationships class with the given relationship type, directories, and output format. - - Parameters - ---------- - relationship_type : str - The type of relationship, formatted as 'start_node-relationship-end_node'. - relationship_dir : str - Directory where the relationship files will be stored. - node_dir : str - Directory where node information is stored, managed by the NodeManager. - output_format : str, optional - Format for reading and writing data. Options are 'pandas' or 'pyarrow' (default is 'pandas'). - - Raises - ------ - ValueError - If `output_format` is not 'pandas' or 'pyarrow'. - """ - if output_format not in ['pandas', 'pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - - self.relationship_type = relationship_type - self.relationship_dir = relationship_dir - self.node_manager = NodeManager(node_dir=node_dir) # Store the NodeManager instance - os.makedirs(self.relationship_dir, exist_ok=True) - - self.output_format = output_format - self.file_type = 'parquet' - self.filepath = os.path.join(self.relationship_dir, f'{self.relationship_type}.{self.file_type}') - self.schema = self.create_schema() - - self.get_dataframe() - - def get_dataframe(self, columns=None, include_cols=True, from_scratch=False, remove_duplicates=True, **kwargs): - """ - Loads or creates the relationship data based on the specified parameters. - - Parameters - ---------- - columns : list, optional - A list of columns to include or exclude from the DataFrame. - include_cols : bool, optional - Whether to include or exclude the specified columns (default is True). - from_scratch : bool, optional - If True, forces the creation of the relationship data from scratch (default is False). - remove_duplicates : bool, optional - Whether to remove duplicate relationships (default is True). - - Returns - ------- - pd.DataFrame or pyarrow.Table - The loaded or newly created relationship DataFrame. - - Raises - ------ - ValueError - If required fields ('start_node_id' and 'end_node_id') are missing in the relationship DataFrame. 
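# Sketch of the subclassing contract described in the class docstring above:
# a subclass defines create_schema() and create_relationships(), and the
# returned dataframe must carry '<start_type>-START_ID' and '<end_type>-END_ID'
# columns, which get_dataframe checks before saving. The relationship type,
# columns, and values below are hypothetical.
import pandas as pd
import pyarrow as pa

class ExampleRelationships(Relationships):
    def __init__(self, relationship_dir, node_dir, output_format='pandas'):
        super().__init__(relationship_type='material-example-element',
                         relationship_dir=relationship_dir,
                         node_dir=node_dir,
                         output_format=output_format)

    def create_schema(self):
        # Schema must cover every column written out, including the added 'TYPE'.
        return pa.schema([('material-START_ID', pa.int64()),
                          ('element-END_ID', pa.int64()),
                          ('weight', pa.float64()),
                          ('TYPE', pa.string())])

    def create_relationships(self, **kwargs):
        # Normally derived from self.node_manager dataframes; toy rows here.
        return pd.DataFrame({'material-START_ID': [0, 1],
                             'element-END_ID': [3, 7],
                             'weight': [1.0, 1.0]})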
- """ - - start_node_type,connection_name,end_node_type=self.relationship_type.split('-') - start_node_name=f'{start_node_type}-START_ID' - end_node_name=f'{end_node_type}-END_ID' - - - if os.path.exists(self.filepath) and not from_scratch: - logger.info(f"Trying to load {self.relationship_type} relationships from {self.filepath}") - df = self.load_dataframe(filepath=self.filepath, columns=columns, include_cols=include_cols, **kwargs) - return df - - logger.info(f"No relationship file found. Attempting to create {self.relationship_type} relationships") - df = self.create_relationships(**kwargs) # Subclasses will define this - - # Ensure the 'start_node_id' and 'end_node_id' fields are present - if start_node_name not in df.columns or end_node_name not in df.columns: - raise ValueError(f"'{start_node_name}' and '{end_node_name}' fields must be defined for {self.relationship_type} relationships.") - - # If 'weight' is not in the dataframe, add it or if remove_duplicates is True, remove duplicates - if 'weight' not in df.columns: - if remove_duplicates: - df=self.remove_duplicate_relationships(df) - - df['TYPE'] = self.relationship_type - - if columns: - df = df[columns] - - if not self.schema: - logger.error(f"No schema set for {self.relationship_type} relationships") - return None - - self.save_dataframe(df, self.filepath) - return df - - def get_property_names(self): - """ - Retrieves and logs the names of properties (columns) in the relationship file. - - Returns - ------- - list - A list of property names in the relationship file. - """ - properties = Relationships.get_column_names(self.filepath) - for property in properties: - logger.info(f"Property: {property}") - return properties - - def create_relationships(self, **kwargs): - """ - Abstract method for creating relationships. Must be implemented by subclasses. - - Raises - ------ - NotImplementedError - If this method is called from the base class instead of a subclass. - """ - if self.__class__.__name__ != 'Relationships': - raise NotImplementedError("Subclasses must implement this method.") - else: - pass - - def create_schema(self, **kwargs): - """ - Abstract method for creating the schema for relationships. Must be implemented by subclasses. - - Raises - ------ - NotImplementedError - If this method is called from the base class instead of a subclass. - """ - if self.__class__.__name__ != 'Relationships': - raise NotImplementedError("Subclasses must implement this method.") - else: - pass - - def load_dataframe(self, filepath, columns=None, include_cols=True, **kwargs): - """ - Loads a DataFrame from a parquet file. - - Parameters - ---------- - filepath : str - The path to the parquet file. - columns : list, optional - List of columns to include or exclude when loading the file. - include_cols : bool, optional - Whether to include or exclude the specified columns (default is True). - - Returns - ------- - pd.DataFrame or pyarrow.Table - The loaded DataFrame or table, depending on the output format. - - Raises - ------ - Exception - If an error occurs while loading the file. - """ - try: - if self.output_format == 'pandas': - df = pd.read_parquet(filepath, columns=columns) - elif self.output_format == 'pyarrow': - df = pq.read_table(filepath, columns=columns) - return df - except Exception as e: - logger.error(f"Error loading {self.relationship_type} relationships from {filepath}: {e}") - return None - - def save_dataframe(self, df, filepath): - """ - Saves a DataFrame to a parquet file. 
- - Parameters - ---------- - df : pd.DataFrame - The DataFrame to save. - filepath : str - The path where the DataFrame will be saved. - - Raises - ------ - Exception - If an error occurs while saving the DataFrame to a parquet file. - """ - try: - parquet_table = pa.Table.from_pandas(df, self.schema) - pq.write_table(parquet_table, filepath) - logger.info(f"Finished saving {self.relationship_type} relationships to {filepath}") - except Exception as e: - logger.error(f"Error converting dataframe to parquet table for saving: {e}") - - def to_neo4j(self, save_dir): - """ - Converts the relationship data to Neo4j-compatible CSV format. - - Parameters - ---------- - save_dir : str - The directory where the Neo4j CSV file will be saved. - """ - logger.info(f"Converting relationship to Neo4j : {self.filepath}") - - relationship_type=os.path.basename(self.filepath).split('.')[0] - node_a_type,connection_name,node_b_type=relationship_type.split('-') - - logger.debug(f"Relationship type: {relationship_type}") - - metadata = pq.read_metadata(self.filepath) - column_types = {} - neo4j_column_name_mapping={} - for filed_schema in metadata.schema: - # Only want top column names - type=filed_schema.physical_type - - field_path=filed_schema.path.split('.') - name=field_path[0] - - is_list=False - if len(field_path)>1: - is_list=field_path[1] == 'list' - - column_types[name] = {} - column_types[name]['type']=type - column_types[name]['is_list']=is_list - - if type=='BYTE_ARRAY': - neo4j_type ='string' - if type=='BOOLEAN': - neo4j_type='boolean' - if type=='DOUBLE': - neo4j_type='float' - if type=='INT64': - neo4j_type='int' - - if is_list: - neo4j_type+='[]' - - column_types[name]['neo4j_type'] = f'{name}:{neo4j_type}' - column_types[name]['neo4j_name'] = f'{name}:{neo4j_type}' - - neo4j_column_name_mapping[name]=f'{name}:{neo4j_type}' - - neo4j_column_name_mapping['TYPE']=':LABEL' - - neo4j_column_name_mapping[f'{node_a_type}-START_ID']=f':START_ID({node_a_type}-ID)' - neo4j_column_name_mapping[f'{node_b_type}-END_ID']=f'END_ID({node_a_type}-ID)' - - df=self.load_relationships(filepath=self.filepath) - - - df.rename(columns=neo4j_column_name_mapping, inplace=True) - - os.makedirs(save_dir,exist_ok=True) - - - save_file=os.path.join(save_dir,f'{relationship_type}.csv') - - logger.debug(f"Saving {relationship_type} relationship_path to {save_file}") - - df.to_csv(save_file, index=False) - - logger.info(f"Finished converting relationship to Neo4j : {relationship_type}") - - def validate_nodes(self): - """ - Validate that the nodes used in the relationships exist in the node manager. 
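# For reference, the header layout Neo4j's bulk import expects for a
# relationship CSV, which the column renaming in to_neo4j above builds toward
# (relationship files use ':TYPE'; ':LABEL' is the node-file counterpart).
# ID spaces, property names, and values here are invented.
import pandas as pd

df = pd.DataFrame({
    ':START_ID(material-ID)': [0, 1],
    ':END_ID(spg-ID)':        [12, 7],
    'weight:float':           [1.0, 1.0],
    ':TYPE':                  ['HAS_EXAMPLE', 'HAS_EXAMPLE'],
})
print(df.to_csv(index=False))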
- """ - start_nodes = set(self.get_dataframe()['start_node'].unique()) - end_nodes = set(self.get_dataframe()['end_node'].unique()) - - existing_nodes = self.node_manager.get_existing_nodes() - missing_start_nodes = start_nodes - existing_nodes - missing_end_nodes = end_nodes - existing_nodes - - if missing_start_nodes: - logger.warning(f"Missing start nodes: {missing_start_nodes}") - if missing_end_nodes: - logger.warning(f"Missing end nodes: {missing_end_nodes}") - - return not missing_start_nodes and not missing_end_nodes - - @staticmethod - def get_column_names(filepath): - metadata = pq.read_metadata(filepath) - all_columns = [] - for filed_schema in metadata.schema: - - # Only want top column names - max_defintion_level=filed_schema.max_definition_level - if max_defintion_level!=1: - continue - - all_columns.append(filed_schema.name) - return all_columns - - @staticmethod - def remove_duplicate_relationships(df): - """Expects only two columns with the that represent the id of the nodes - - Parameters - ---------- - df : pandas.DataFrame - """ - column_names=list(df.columns) - - df['id_tuple'] = df.apply(lambda x: tuple(sorted([x[column_names[0]], x[column_names[1]]])), axis=1) - # Group by the sorted tuple and count occurrences - grouped = df.groupby('id_tuple') - weights = grouped.size().reset_index(name='weight') - - # Drop duplicates based on the id_tuple - df_weighted = df.drop_duplicates(subset='id_tuple') - - # Merge with weights - df_weighted = pd.merge(df_weighted, weights, on='id_tuple', how='left') - - # Drop the id_tuple column - df_weighted = df_weighted.drop(columns='id_tuple') - return df_weighted - - -# Example subclass for specific relationship types -class MaterialSPGRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.MATERIAL_SPG.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for material relationships - return get_relationship_schema(RelationshipTypes.MATERIAL_SPG) - - def create_relationships(self, **kwargs): - # The logic for creating material relationships - try: - relationship_type=RelationshipTypes.MATERIAL_SPG.value - start_node_type,connection_name,end_node_type=relationship_type.split('-') - - # Example: Create relationships between materials from nodes in node_manager - start_nodes_df = self.node_manager.get_node_dataframe(start_node_type, columns=['space_group']) - end_nodes_df = self.node_manager.get_node_dataframe(end_node_type, columns=['name']) - - # Mapping name to index - name_to_index_mapping_b = {int(name): index for index, name in end_nodes_df['name'].items()} - - # Creating dataframe - df = start_nodes_df.copy() - - # Removing NaN values - df = df.dropna() - - # Making current index a column and reindexing - df = df.reset_index().rename(columns={'index': start_node_type+'-START_ID'}) - - # Adding node b ID with the mapping - df[end_node_type+'-END_ID'] = df['space_group'].map(name_to_index_mapping_b).astype(int) - - df.drop(columns=['space_group'], inplace=True) - - df['weight'] = 1.0 - - except Exception as e: - logger.error(f"Error creating material relationships: {e}") - return None - return df - -class MaterialCrystalSystemRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.MATERIAL_CRYSTAL_SYSTEM.value, 
relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for material-crystal system relationships - return get_relationship_schema(RelationshipTypes.MATERIAL_CRYSTAL_SYSTEM) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.MATERIAL_CRYSTAL_SYSTEM.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for material and crystal system - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['crystal_system']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - - # Mapping name to index for the crystal system nodes - name_to_index_mapping_b = {name: index for index, name in node_b_df['name'].items()} - - # Creating relationships dataframe by copying node_a dataframe - df = node_a_df.copy() - - # Converting the 'crystal_system' column to lowercase to standardize - df['crystal_system'] = df['crystal_system'].str.lower() - - # Removing rows with missing 'crystal_system' values - df = df.dropna() - - # Resetting the index and renaming it to follow the START_ID convention - df = df.reset_index().rename(columns={'index': node_a_type + '-START_ID'}) - - # Adding the END_ID for the crystal system nodes by mapping the 'crystal_system' to the index - df[node_b_type + '-END_ID'] = df['crystal_system'].map(name_to_index_mapping_b) - - # Dropping the 'crystal_system' column as it's no longer needed - df.drop(columns=['crystal_system'], inplace=True) - - df['weight'] = 1.0 - except Exception as e: - logger.error(f"Error creating material relationships: {e}") - return None - - - return df - -class MaterialLatticeRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.MATERIAL_LATTICE.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for material-lattice relationships - return get_relationship_schema(RelationshipTypes.MATERIAL_LATTICE) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.MATERIAL_LATTICE.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for material and lattice - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - - # Mapping name to index for the lattice nodes - name_to_index_mapping_b = {name: index for index, name in node_b_df['name'].items()} - - # Creating relationships dataframe by copying node_a dataframe - df = node_a_df.copy() - - # Removing rows with missing 'name' values - df = df.dropna() - - # Resetting the index and renaming it to follow the START_ID convention - df = df.reset_index().rename(columns={'index': node_a_type + '-START_ID'}) - - # Adding the END_ID for the lattice nodes by mapping the 'name' to the index - df[node_b_type + '-END_ID'] = df['name'].map(name_to_index_mapping_b).astype(int) - - # Dropping the 'name' column as it's no longer needed - df.drop(columns=['name'], inplace=True) - - # Setting the 
relationship type and a default weight of 1.0 - df['weight'] = 1.0 - - except Exception as e: - logger.error(f"Error creating material lattice relationships: {e}") - return None - - return df - -class MaterialSiteRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.MATERIAL_SITE.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for material-site relationships - return get_relationship_schema(RelationshipTypes.MATERIAL_SITE) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.MATERIAL_SITE.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for material and site - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - - # Mapping name to index for the material nodes - name_to_index_mapping_a = {name: index for index, name in node_a_df['name'].items()} - - # Creating relationships dataframe by copying node_b dataframe - df = node_b_df.copy() - - # Removing rows with missing 'name' values - df = df.dropna() - - # Resetting the index and renaming it to follow the START_ID convention - df = df.reset_index().rename(columns={'index': node_a_type + '-START_ID'}) - - # Adding the END_ID for the site nodes by mapping the 'name' to the material nodes index - df[node_b_type + '-END_ID'] = df['name'].map(name_to_index_mapping_a) - - # Dropping the 'name' column as it's no longer needed - df.drop(columns=['name'], inplace=True) - - df['weight'] = 1.0 - - except Exception as e: - logger.error(f"Error creating material site relationships: {e}") - return None - - return df - -class ElementOxidationStateRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.ELEMENT_OXIDATION_STATE.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for element-oxidation state relationships - return get_relationship_schema(RelationshipTypes.ELEMENT_OXIDATION_STATE) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.ELEMENT_OXIDATION_STATE.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for elements and oxidation states - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name', 'experimental_oxidation_states']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe("MATERIAL", columns=['name', 'species', 'oxidation_states-possible_valences']) - - # Mapping names to indices for element and oxidation state nodes - name_to_index_mapping_a = {name: index for index, name in node_a_df['name'].items()} - name_to_index_mapping_b = {name: index for index, name in node_b_df['name'].items()} - - # Connecting oxidation states to elements derived from material nodes - oxidation_state_names = [] - 
element_names = [] - for _, row in node_material_df.iterrows(): - possible_valences = row['oxidation_states-possible_valences'] - elements = row['species'] - if possible_valences is None or elements is None: - continue - for possible_valence, element in zip(possible_valences, elements): - oxidation_state_name = f'ox_{possible_valence}' - oxidation_state_names.append(oxidation_state_name) - element_names.append(element) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': element_names, - f'{node_b_type}-END_ID': oxidation_state_names - } - df = pd.DataFrame(data) - - # Convert element names to indices and oxidation state names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_a) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_b) - - - except Exception as e: - logger.error(f"Error creating element oxidation state relationships: {e}") - return None - - return df - -class MaterialElementRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.MATERIAL_ELEMENT.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for material-element relationships - return get_relationship_schema(RelationshipTypes.MATERIAL_ELEMENT) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.MATERIAL_ELEMENT.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for materials and elements - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name', 'species']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - - # Mapping names to indices for materials and element nodes - name_to_index_mapping_a = {name: index for index, name in node_a_df['name'].items()} - name_to_index_mapping_b = {name: index for index, name in node_b_df['name'].items()} - - # Connecting materials to elements derived from material nodes - material_names = [] - element_names = [] - for _, row in node_a_df.iterrows(): - elements = row['species'] - material_name = row['name'] - if elements is None: - continue - - # Append the material name for each element in the species list - material_names.extend([material_name] * len(elements)) - element_names.extend(elements) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': material_names, - f'{node_b_type}-END_ID': element_names - } - df = pd.DataFrame(data) - - # Convert material names to indices and element names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_a) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_b) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - - except Exception as e: - logger.error(f"Error creating material-element relationships: {e}") - return None - - return df - -class MaterialChemEnvRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.MATERIAL_CHEMENV.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def 
create_schema(self): - # Define and return the schema for material-chemenv relationships - return get_relationship_schema(RelationshipTypes.MATERIAL_CHEMENV) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.MATERIAL_CHEMENV.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for materials and chemenv - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name', 'coordination_environments_multi_weight']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - - # Mapping names to indices for materials and chemenv nodes - name_to_index_mapping_a = {name: index for index, name in node_a_df['name'].items()} - name_to_index_mapping_b = {name: index for index, name in node_b_df['name'].items()} - - # Connecting materials to chemenv derived from material nodes - material_names = [] - chemenv_names = [] - for _, row in node_a_df.iterrows(): - bond_connections = row['coordination_environments_multi_weight'] - material_name = row['name'] - if bond_connections is None: - continue - - # Extract chemenv name from bond connections - for coord_env in bond_connections: - try: - chemenv_name = coord_env[0]['ce_symbol'].replace(':', '_') - except: - continue - - material_names.append(material_name) - chemenv_names.append(chemenv_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': material_names, - f'{node_b_type}-END_ID': chemenv_names - } - df = pd.DataFrame(data) - - # Convert material and chemenv names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_a) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_b) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - - except Exception as e: - logger.error(f"Error creating material-chemenv relationships: {e}") - return None - - return df - -class ElementChemEnvRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.ELEMENT_CHEMENV.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for element-chemenv relationships - return get_relationship_schema(RelationshipTypes.ELEMENT_CHEMENV) - - def create_relationships(self, columns=None, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.ELEMENT_CHEMENV.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for elements and chemenv - node_a_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_b_df = self.node_manager.get_node_dataframe(node_b_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'species', 'coordination_environments_multi_weight']) - - # Mapping names to indices for element and chemenv nodes - name_to_index_mapping_a = {name: index for index, name in node_a_df['name'].items()} - name_to_index_mapping_b = {name: index for index, name in node_b_df['name'].items()} - - # Connecting materials to chemenv derived from material nodes - element_names = [] - 
chemenv_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['coordination_environments_multi_weight'] - elements = row['species'] - if bond_connections is None: - continue - - # Extract chemenv name and corresponding element - for i, coord_env in enumerate(bond_connections): - try: - chemenv_name = coord_env[0]['ce_symbol'].replace(':', '_') - except: - continue - element_name = elements[i] - - chemenv_names.append(chemenv_name) - element_names.append(element_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': element_names, - f'{node_b_type}-END_ID': chemenv_names - } - df = pd.DataFrame(data) - - # Convert element names and chemenv names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_a) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_b) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - except Exception as e: - logger.error(f"Error creating element-chemenv relationships: {e}") - return None - - return df - -class ElementGeometricElectricElementRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.ELEMENT_GEOMETRIC_ELECTRIC_CONNECTS_ELEMENT.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for element-geometric electric element relationships - return get_relationship_schema(RelationshipTypes.ELEMENT_GEOMETRIC_ELECTRIC_CONNECTS_ELEMENT) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.ELEMENT_GEOMETRIC_ELECTRIC_CONNECTS_ELEMENT.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for elements and material data - element_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'species', 'geometric_electric_consistent_bond_connections']) - - # Mapping names to indices for element nodes - name_to_index_mapping_element = {name: index for index, name in element_df['name'].items()} - - # Connecting materials to elements based on bond connections - site_element_names = [] - neighbor_element_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['geometric_electric_consistent_bond_connections'] - elements = row['species'] - - if bond_connections is None: - continue - - for i, site_connections in enumerate(bond_connections): - site_element_name = elements[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_element_name = elements[i_neighbor_element] - - site_element_names.append(site_element_name) - neighbor_element_names.append(neighbor_element_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': site_element_names, - f'{node_b_type}-END_ID': neighbor_element_names - } - df = pd.DataFrame(data) - - # Convert element names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_element) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_element) - - # Removing NaN values - df = 
df.dropna() - - - except Exception as e: - logger.error(f"Error creating element-geometric electric element relationships: {e}") - return None - - return df - -class ElementGeometricElementRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.ELEMENT_GEOMETRIC_CONNECTS_ELEMENT.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for element-geometric element relationships - return get_relationship_schema(RelationshipTypes.ELEMENT_GEOMETRIC_CONNECTS_ELEMENT) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.ELEMENT_GEOMETRIC_CONNECTS_ELEMENT.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for elements and material data - element_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'species', 'geometric_consistent_bond_connections']) - - # Mapping names to indices for element nodes - name_to_index_mapping_element = {name: index for index, name in element_df['name'].items()} - - # Connecting materials to elements based on bond connections - site_element_names = [] - neighbor_element_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['geometric_consistent_bond_connections'] - elements = row['species'] - - if bond_connections is None: - continue - - for i, site_connections in enumerate(bond_connections): - site_element_name = elements[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_element_name = elements[i_neighbor_element] - - site_element_names.append(site_element_name) - neighbor_element_names.append(neighbor_element_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': site_element_names, - f'{node_b_type}-END_ID': neighbor_element_names - } - df = pd.DataFrame(data) - - # Convert element names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_element) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_element) - - # Removing NaN values - df = df.dropna() - - except Exception as e: - logger.error(f"Error creating element-geometric element relationships: {e}") - return None - - return df - - -class ElementElectricElementRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.ELEMENT_ELECTRIC_CONNECTS_ELEMENT.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for element-electric element relationships - return get_relationship_schema(RelationshipTypes.ELEMENT_ELECTRIC_CONNECTS_ELEMENT) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.ELEMENT_ELECTRIC_CONNECTS_ELEMENT.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # 
Loading nodes for elements and material data - element_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'species', 'electric_consistent_bond_connections']) - - # Mapping names to indices for element nodes - name_to_index_mapping_element = {name: index for index, name in element_df['name'].items()} - - # Connecting materials to elements based on electric bond connections - site_element_names = [] - neighbor_element_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['electric_consistent_bond_connections'] - elements = row['species'] - - if bond_connections is None: - continue - - for i, site_connections in enumerate(bond_connections): - site_element_name = elements[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_element_name = elements[i_neighbor_element] - - site_element_names.append(site_element_name) - neighbor_element_names.append(neighbor_element_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': site_element_names, - f'{node_b_type}-END_ID': neighbor_element_names - } - df = pd.DataFrame(data) - - # Convert element names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_element) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_element) - - # Removing NaN values - df = df.dropna() - - except Exception as e: - logger.error(f"Error creating element-electric element relationships: {e}") - return None - - return df - -class ChemEnvGeometricElectricChemEnvRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.CHEMENV_GEOMETRIC_ELECTRIC_CONNECTS_CHEMENV.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for chemenv-geometric electric chemenv relationships - return get_relationship_schema(RelationshipTypes.CHEMENV_GEOMETRIC_ELECTRIC_CONNECTS_CHEMENV) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.CHEMENV_GEOMETRIC_ELECTRIC_CONNECTS_CHEMENV.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for chemenv and material data - chemenv_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'coordination_environments_multi_weight', 'geometric_electric_consistent_bond_connections']) - - # Mapping names to indices for chemenv nodes - name_to_index_mapping_chemenv = {name: index for index, name in chemenv_df['name'].items()} - - # Connecting materials to chemenv based on bond connections - site_chemenv_names = [] - neighbor_chemenv_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['geometric_electric_consistent_bond_connections'] - chemenv_info = row['coordination_environments_multi_weight'] - - if bond_connections is None or chemenv_info is None: - continue - - # Extract chemenv names from the coordination environments - chemenv_names = [] - for coord_env in chemenv_info: - try: - chemenv_name = 
coord_env[0]['ce_symbol'].replace(':', '_') - chemenv_names.append(chemenv_name) - except: - continue - - # Creating connections between site chemenv and neighbor chemenv - for i, site_connections in enumerate(bond_connections): - site_chemenv_name = chemenv_names[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_chemenv_name = chemenv_names[i_neighbor_element] - - site_chemenv_names.append(site_chemenv_name) - neighbor_chemenv_names.append(neighbor_chemenv_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': site_chemenv_names, - f'{node_b_type}-END_ID': neighbor_chemenv_names - } - df = pd.DataFrame(data) - - # Convert chemenv names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_chemenv) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_chemenv) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - - except Exception as e: - logger.error(f"Error creating chemenv-geometric electric chemenv relationships: {e}") - return None - - return df - - -class ChemEnvGeometricChemEnvRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.CHEMENV_GEOMETRIC_CONNECTS_CHEMENV.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for chemenv-geometric chemenv relationships - return get_relationship_schema(RelationshipTypes.CHEMENV_GEOMETRIC_CONNECTS_CHEMENV) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.CHEMENV_GEOMETRIC_CONNECTS_CHEMENV.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for chemenv and material data - chemenv_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'coordination_environments_multi_weight', 'geometric_consistent_bond_connections']) - - # Mapping names to indices for chemenv nodes - name_to_index_mapping_chemenv = {name: index for index, name in chemenv_df['name'].items()} - - # Connecting materials to chemenv based on bond connections - site_chemenv_names = [] - neighbor_chemenv_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['geometric_consistent_bond_connections'] - chemenv_info = row['coordination_environments_multi_weight'] - - if bond_connections is None or chemenv_info is None: - continue - - # Extract chemenv names from the coordination environments - chemenv_names = [] - for coord_env in chemenv_info: - try: - chemenv_name = coord_env[0]['ce_symbol'].replace(':', '_') - chemenv_names.append(chemenv_name) - except: - continue - - # Creating connections between site chemenv and neighbor chemenv - for i, site_connections in enumerate(bond_connections): - site_chemenv_name = chemenv_names[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_chemenv_name = chemenv_names[i_neighbor_element] - - site_chemenv_names.append(site_chemenv_name) - neighbor_chemenv_names.append(neighbor_chemenv_name) - - # Creating the relationships dataframe - 
data = { - f'{node_a_type}-START_ID': site_chemenv_names, - f'{node_b_type}-END_ID': neighbor_chemenv_names - } - df = pd.DataFrame(data) - - # Convert chemenv names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_chemenv) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_chemenv) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - except Exception as e: - logger.error(f"Error creating chemenv-geometric chemenv relationships: {e}") - return None - - return df - -class ChemEnvElectricChemEnvRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.CHEMENV_ELECTRIC_CONNECTS_CHEMENV.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for chemenv-electric chemenv relationships - return get_relationship_schema(RelationshipTypes.CHEMENV_ELECTRIC_CONNECTS_CHEMENV) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.CHEMENV_ELECTRIC_CONNECTS_CHEMENV.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for chemenv and material data - chemenv_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name']) - node_material_df = self.node_manager.get_node_dataframe('material', columns=['name', 'coordination_environments_multi_weight', 'electric_consistent_bond_connections']) - - # Mapping names to indices for chemenv nodes - name_to_index_mapping_chemenv = {name: index for index, name in chemenv_df['name'].items()} - - # Connecting materials to chemenv based on electric bond connections - site_chemenv_names = [] - neighbor_chemenv_names = [] - for _, row in node_material_df.iterrows(): - bond_connections = row['electric_consistent_bond_connections'] - chemenv_info = row['coordination_environments_multi_weight'] - - if bond_connections is None or chemenv_info is None: - continue - - # Extract chemenv names from the coordination environments - chemenv_names = [] - for coord_env in chemenv_info: - try: - chemenv_name = coord_env[0]['ce_symbol'].replace(':', '_') - chemenv_names.append(chemenv_name) - except: - continue - - # Creating connections between site chemenv and neighbor chemenv - for i, site_connections in enumerate(bond_connections): - site_chemenv_name = chemenv_names[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_chemenv_name = chemenv_names[i_neighbor_element] - - site_chemenv_names.append(site_chemenv_name) - neighbor_chemenv_names.append(neighbor_chemenv_name) - - # Creating the relationships dataframe - data = { - f'{node_a_type}-START_ID': site_chemenv_names, - f'{node_b_type}-END_ID': neighbor_chemenv_names - } - df = pd.DataFrame(data) - - # Convert chemenv names to indices - df[f'{node_a_type}-START_ID'] = df[f'{node_a_type}-START_ID'].map(name_to_index_mapping_chemenv) - df[f'{node_b_type}-END_ID'] = df[f'{node_b_type}-END_ID'].map(name_to_index_mapping_chemenv) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - except Exception as e: - logger.error(f"Error creating chemenv-electric chemenv relationships: {e}") - return None - - return df - 
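The chemenv and element connection builders in the removed module above all share one pattern: map node names to integer row indices, emit one START_ID/END_ID pair per bonded site, then, when remove_duplicates is requested, collapse repeated undirected pairs into a single edge carrying a weight column (the base-class remove_duplicate_relationships step). The sketch below illustrates only that deduplication step under those assumptions; dedupe_undirected and the toy column values are illustrative names, not part of this changeset.

    import pandas as pd

    def dedupe_undirected(df: pd.DataFrame) -> pd.DataFrame:
        # Collapse duplicate undirected (start, end) pairs into a 'weight' column,
        # in the spirit of Relationships.remove_duplicate_relationships: each pair
        # is sorted so that (a, b) and (b, a) count as the same edge.
        start_col, end_col = df.columns[:2]
        key = df.apply(lambda r: tuple(sorted((r[start_col], r[end_col]))), axis=1)
        out = df.assign(pair_key=key)
        weights = out.groupby("pair_key").size().rename("weight")
        out = out.drop_duplicates("pair_key").join(weights, on="pair_key")
        return out.drop(columns="pair_key").reset_index(drop=True)

    # Toy usage: (0, 1) and (1, 0) collapse into one edge with weight 2.
    edges = pd.DataFrame({"ELEMENT-START_ID": [0, 1, 0],
                          "ELEMENT-END_ID":   [1, 0, 2]})
    print(dedupe_undirected(edges))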
-class ElementGroupPeriodRelationships(Relationships): - def __init__(self, relationship_dir, node_dir, output_format='pandas'): - super().__init__(relationship_type=RelationshipTypes.ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT.value, relationship_dir=relationship_dir, node_dir=node_dir, output_format=output_format) - - def create_schema(self): - # Define and return the schema for element-group-period relationships - return get_relationship_schema(RelationshipTypes.ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT) - - def create_relationships(self, columns=None, remove_duplicates=True, **kwargs): - # Defining node types and relationship type - try: - relationship_type = RelationshipTypes.ELEMENT_GROUP_PERIOD_CONNECTS_ELEMENT.value - node_a_type, connection_name, node_b_type = relationship_type.split('-') - - logger.info(f"Getting {relationship_type} relationships") - - # Loading nodes for elements - element_df = self.node_manager.get_node_dataframe(node_a_type, columns=['name', 'atomic_number', 'extended_group', 'period', 'symbol']) - - # Mapping names to indices for elements - name_to_index_mapping = {name: index for index, name in element_df['name'].items()} - - # Getting group-period edge index - edge_index = get_group_period_edge_index(element_df) - - # Creating the relationships dataframe - df = pd.DataFrame(edge_index, columns=[f'{node_a_type}-START_ID', f'{node_b_type}-END_ID']) - - # Removing NaN values and converting to int - df = df.dropna().astype(int) - - - except Exception as e: - logger.error(f"Error creating element-group-period relationships: {e}") - return None - - return df - - -class RelationshipManager: - def __init__(self, relationship_dir, output_format='pandas'): - """ - Initialize the RelationshipManager with the directory where relationships are stored. - """ - if output_format not in ['pandas', 'pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - - self.relationship_dir = relationship_dir - os.makedirs(self.relationship_dir, exist_ok=True) - - self.file_type = 'parquet' - self.get_existing_relationships() - - def get_existing_relationships(self): - self.relationships = set(self.list_relationships()) - return self.relationships - - def list_relationships(self): - """ - List all relationship files available in the relationship directory. - """ - relationship_files = [f for f in os.listdir(self.relationship_dir) if f.endswith(f'.{self.file_type}')] - relationship_types = [os.path.splitext(f)[0] for f in relationship_files] # Extract file names without extension - logger.info(f"Found the following relationship types: {relationship_types}") - return relationship_types - - def get_relationship(self, relationship_type, output_format='pandas'): - """ - Load a relationship dataframe by its type (which corresponds to the filename without extension). 
- """ - if output_format not in ['pandas', 'pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - - filepath = os.path.join(self.relationship_dir, f'{relationship_type}.{self.file_type}') - - if not os.path.exists(filepath): - logger.error(f"No relationship file found for type: {relationship_type}") - return None - - relationship = Relationships(relationship_type=relationship_type, relationship_dir=self.relationship_dir, output_format=output_format) - return relationship - - def get_relationship_dataframe(self, relationship_type, columns=None, include_cols=True, output_format='pandas', **kwargs): - """ - Return the relationship dataframe if it has already been loaded; otherwise, load it from file. - """ - if output_format not in ['pandas', 'pyarrow']: - raise ValueError("output_format must be either 'pandas' or 'pyarrow'") - return self.get_relationship(relationship_type, output_format=output_format).get_dataframe(columns=columns, include_cols=include_cols) - - def add_relationship(self, relationship_class): - """ - Add a new relationship by providing a custom relationship class (must inherit from the base Relationships class). - The relationship class must implement its own creation logic. - """ - if not issubclass(relationship_class, Relationships): - raise TypeError("The provided class must inherit from the Relationships class.") - - relationship = relationship_class(relationship_dir=self.relationship_dir) # Initialize the relationship class - relationship.get_dataframe() # Get or create the relationship dataframe - - self.get_existing_relationships() - - def delete_relationship(self, relationship_type): - """ - Delete a relationship type. This method will remove the parquet file and the relationship from the self.relationships set. - """ - filepath = os.path.join(self.relationship_dir, f'{relationship_type}.{self.file_type}') - - if os.path.exists(filepath): - try: - os.remove(filepath) - self.relationships.discard(relationship_type) # Remove from the set of relationships - logger.info(f"Deleted relationship of type {relationship_type} and removed it from the relationship set.") - except Exception as e: - logger.error(f"Error deleting relationship of type {relationship_type}: {e}") - else: - logger.warning(f"No relationship file found for type {relationship_type} to delete.") - - def convert_all_to_neo4j(self, save_dir): - """ - Convert all Parquet relationship files in the relationship directory to Neo4j CSV format. 
- """ - os.makedirs(save_dir, exist_ok=True) - for relationship_type in self.relationships: - logger.info(f"Converting {relationship_type} to Neo4j CSV format.") - try: - relationship = self.get_relationship(relationship_type) # Load the relationship - if relationship is not None: - relationship.to_neo4j(save_dir) # Convert to Neo4j format - logger.info(f"Successfully converted {relationship_type} to Neo4j CSV.") - else: - logger.warning(f"Skipping {relationship_type} as it could not be loaded.") - except Exception as e: - logger.error(f"Error converting {relationship_type} to Neo4j CSV: {e}") - - -if __name__ == "__main__": - - node_dir = os.path.join('data','raw','nodes') - relationship_dir = os.path.join('data','raw','relationships') - - # relationships=ElementOxidationStateRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - # print(relationships.get_dataframe().head()) - # print(relationships.get_property_names()) - - relationships=ElementGroupPeriodRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=ElementGeometricElectricElementRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=ElementElectricElementRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=ElementGeometricElementRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=ChemEnvGeometricElectricChemEnvRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=ChemEnvGeometricChemEnvRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships = ChemEnvElectricChemEnvRelationships(node_dir=node_dir, relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - - - relationships= MaterialChemEnvRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships = MaterialElementRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships = MaterialLatticeRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=MaterialSiteRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=MaterialCrystalSystemRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - relationships=MaterialSPGRelationships(node_dir=node_dir,relationship_dir=relationship_dir) - print(relationships.get_dataframe().head()) - print(relationships.get_property_names()) - - - - diff --git a/matgraphdb/graph_kit/utils.py 
b/matgraphdb/graph_kit/utils.py deleted file mode 100644 index 0fc4e4e..0000000 --- a/matgraphdb/graph_kit/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List, Union - - -def is_in_range(val:Union[float, int],min_val:Union[float, int],max_val:Union[float, int], negation:bool=True): - """ - Screens a list of floats to keep only those that are within a given range. - - Args: - floats (Union[float, int]): A list of floats to be screened. - min_val (float): The minimum value to keep. - max_val (float): The maximum value to keep. - negation (bool, optional): If True, returns True if the value is within the range. - If False, returns True if the value is outside the range. - Defaults to True. - - Returns: - bool: A boolean indicating whether the value is within the given range. - """ - if negation: - return min_val <= val <= max_val - else: - return not (min_val <= val <= max_val) - -def is_in_list(val, string_list: List, negation: bool = True) -> bool: - """ - Checks if a value is (or is not, based on the inverse_check flag) in a given list. - - Args: - val: The value to be checked. - string_list (List): The list to check against. - negation (bool, optional): If True, returns True if the value is in the list. - If False, returns True if the value is not in the list. - Defaults to True. - - Returns: - bool: A boolean indicating whether the value is (or is not) in the list based on 'inverse_check'. - """ - return (val in string_list) if negation else (val not in string_list) \ No newline at end of file diff --git a/matgraphdb/utils/config.py b/matgraphdb/utils/config.py index b6b21bf..969a8c0 100644 --- a/matgraphdb/utils/config.py +++ b/matgraphdb/utils/config.py @@ -1,9 +1,12 @@ +import logging import os from pathlib import Path -import numpy as np from dotenv import load_dotenv -from variconfig import LoggingConfig +from platformdirs import user_config_dir +from variconfig import ConfigDict + +logger = logging.getLogger(__name__) load_dotenv() @@ -12,15 +15,33 @@ UTILS_DIR = str(FILE.parents[0]) DATA_DIR = os.getenv("DATA_DIR") -config = LoggingConfig.from_yaml(os.path.join(UTILS_DIR, "config.yml")) +DEFAULT_CFG = os.path.join(UTILS_DIR, "config.yml") + + +def load_config() -> ConfigDict: + """ + Load and merge configuration files, highest-priority first. + """ + user_cfg_dir = Path(user_config_dir("parquetdb")) + user_cfg = user_cfg_dir / "config.yml" + if not user_cfg.exists(): + user_cfg_dir.mkdir(parents=True, exist_ok=True) + if Path(DEFAULT_CFG).exists() and not user_cfg.exists(): + import shutil + + shutil.copy2(DEFAULT_CFG, user_cfg) + logger.info(f"Created user config file at {user_cfg}") + + user_cfg = Path(user_config_dir("parquetdb")) / "config.yml" + + logger.info(f"Config file: {user_cfg}") + cfg = ConfigDict.from_yaml(str(user_cfg)) -if DATA_DIR: - config.data_dir = DATA_DIR + if DATA_DIR: + logger.info(f"Setting data_dir to {DATA_DIR}") + cfg.data_dir = DATA_DIR + return cfg -# if config.log_dir: -# os.makedirs(config.log_dir, exist_ok=True) -# if config.data_dir: -# os.makedirs(config.data_dir, exist_ok=True) -np.set_printoptions(**config.numpy_config.np_printoptions.to_dict()) +config = load_config() diff --git a/matgraphdb/utils/config.yml b/matgraphdb/utils/config.yml index 1f689b9..3c79f32 100644 --- a/matgraphdb/utils/config.yml +++ b/matgraphdb/utils/config.yml @@ -2,7 +2,7 @@ root_dir: "." 
data_dir: "{{ root_dir }}/data" external_data_dir: "{{ data_dir }}/external" log_dir: "{{ root_dir }}/logs" -db_name: 'MatGraphDB' +tests_dir: "{{ root_dir }}/tests" n_cores: use_multiprocessing: True @@ -11,78 +11,3 @@ numpy_config: np_printoptions: linewidth: 400 precision: 3 - -matgraphdb_config: - normalize_kwargs: - load_kwargs: - batch_readahead: 16 - fragment_readahead: 4 - batch_size: 131072 - use_threads: True - max_partitions: 1024 - max_open_files: 1024 - max_rows_per_file: 1000000 - min_rows_per_group: 0 - max_rows_per_group: 1000000 - existing_data_behavior: overwrite_or_ignore - create_dir: True - load_kwargs: - batch_readahead: 16 - fragment_readahead: 4 - batch_size: 131072 - - - -logging_config: - version: 1 - disable_existing_loggers: False - - formatters: - simple: - format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - datefmt: '%Y-%m-%d %H:%M:%S' - - user: - format: '%(message)s' - datefmt: '%Y-%m-%d %H:%M:%S' - - handlers: - console: - class: logging.StreamHandler - formatter: simple - stream: ext://sys.stdout - - user_console: - class: logging.StreamHandler - formatter: user - stream: ext://sys.stdout - - # file: - # class: logging.FileHandler - # formatter: simple - # filename: "{{ log_dir }}/matgraphdb.log" - # mode: a - - loggers: - matgraphdb: - level: CRITICAL - handlers: [console] - propagate: no - parquetdb: - level: CRITICAL - handlers: [console] - propagate: no - timing: - level: DEBUG - handlers: [console] - propagate: no - tests: - level: DEBUG - handlers: [console] - propagate: no - - - # root: - # level: INFO - # handlers: [console] - # propagate: no diff --git a/matgraphdb/utils/general_utils.py b/matgraphdb/utils/general_utils.py index e1eb148..459011c 100644 --- a/matgraphdb/utils/general_utils.py +++ b/matgraphdb/utils/general_utils.py @@ -78,35 +78,6 @@ def get_function_args(func: Callable): return args, kwargs -def set_verbosity(verbose: int): - """ - Sets the verbosity level for the logger. - - Args: - verbose (int): The verbosity level. 0 is no logging, 1 is INFO level logging, and 2 is DEBUG level logging. - """ - if not isinstance(verbose, int): - raise TypeError( - "Verbose must be an integer. The higher the number, the more verbose the logging." - ) - if verbose == 0: - config.logging_config.loggers.matgraphdb.level = logging.CRITICAL - elif verbose == 1: - config.logging_config.loggers.matgraphdb.level = logging.ERROR - elif verbose == 2: - config.logging_config.loggers.matgraphdb.level = logging.WARNING - elif verbose == 3: - config.logging_config.loggers.matgraphdb.level = logging.INFO - elif verbose == 4: - config.logging_config.loggers.matgraphdb.level = logging.DEBUG - config.logging_config.loggers.parquetdb.level = logging.DEBUG - else: - raise ValueError( - "Verbose must be an integer between 0 and 4. The higher the number, the more verbose the logging." 
- ) - config.apply() - - def download_test_data(save_path: str): import requests diff --git a/matgraphdb/utils/log_utils.py b/matgraphdb/utils/log_utils.py new file mode 100644 index 0000000..bf1098e --- /dev/null +++ b/matgraphdb/utils/log_utils.py @@ -0,0 +1,89 @@ +import logging +import logging.config +import os +from datetime import datetime + + +def set_verbose_level(verbose: int): + user_logger = logging.getLogger("user") + package_logger = logging.getLogger("matgraphdb") + + if verbose == 0: + user_logger.setLevel(logging.CRITICAL) + package_logger.setLevel(logging.CRITICAL) + elif verbose == 1: + user_logger.setLevel(logging.DEBUG) + package_logger.setLevel(logging.CRITICAL) + elif verbose >= 2: + user_logger.setLevel(logging.DEBUG) + package_logger.setLevel(logging.DEBUG) + + +class UserFriendlyFormatter(logging.Formatter): + """Custom formatter that makes warnings and errors more noticeable to users""" + + # ANSI color codes for terminal output + YELLOW = "\033[93m" # Warning + RED = "\033[91m" # Error/Critical + BOLD = "\033[1m" # Bold text + RESET = "\033[0m" # Reset formatting + + def format(self, record): + # Default format for regular messages + self._style._fmt = "%(message)s" + + # Special formatting for warnings and errors + if record.levelno >= logging.ERROR: + self._style._fmt = f"{self.RED}{self.BOLD}ERROR: %(message)s{self.RESET}" + elif record.levelno >= logging.WARNING: + self._style._fmt = ( + f"{self.YELLOW}{self.BOLD}WARNING: %(message)s{self.RESET}" + ) + + return super().format(record) + + +logging_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "simple": { + "format": "[%(levelname)s] %(asctime)s - %(name)s[%(lineno)d][%(funcName)s] - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + "user": { + "()": UserFriendlyFormatter, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "simple", + "stream": "ext://sys.stdout", + }, + "user_console": { + "class": "logging.StreamHandler", + "formatter": "user", + "stream": "ext://sys.stdout", + }, + "file": { + "class": "logging.FileHandler", + "formatter": "simple", + "filename": "parquetdb.log", + "mode": "a", + }, + }, + "loggers": { + "matgraphdb": { + "level": "INFO", + "handlers": ["console"], # , "file"], + "propagate": True, + }, + "user": {"level": "DEBUG", "handlers": ["user_console"], "propagate": False}, + "tests": {"level": "DEBUG", "handlers": ["console"], "propagate": False}, + }, +} + + +def setup_logging(): + logging.config.dictConfig(logging_config) diff --git a/tests/test_data/material/material_0.parquet b/tests/test_data/material/material_0.parquet index 47d3de0..4d52118 100644 Binary files a/tests/test_data/material/material_0.parquet and b/tests/test_data/material/material_0.parquet differ diff --git a/tests/test_material_nodes.py b/tests/test_material_nodes.py index dff2094..184c064 100644 --- a/tests/test_material_nodes.py +++ b/tests/test_material_nodes.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from matgraphdb.core.nodes.materials import MaterialStore +from matgraphdb.core.material_store import MaterialStore TEMP_DIR = Path(tempfile.mkdtemp()) diff --git a/tests/test_matgraphdb.py b/tests/test_matgraphdb.py index 8fdcb1f..a21d60a 100644 --- a/tests/test_matgraphdb.py +++ b/tests/test_matgraphdb.py @@ -3,12 +3,12 @@ import tempfile from pathlib import Path +import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import pytest -from matgraphdb.core.edges import * -from matgraphdb.core.matgraphdb 
import MatGraphDB -from matgraphdb.core.nodes import * +from matgraphdb import MaterialStore, MatGraphDB, generators current_dir = Path(__file__).parent TEST_DATA_DIR = current_dir / "test_data" @@ -55,19 +55,19 @@ def material_store(): @pytest.fixture def node_generator_data(matgraphdb): node_generators = [ - {"generator_func": element}, - {"generator_func": chemenv}, - {"generator_func": crystal_system}, - {"generator_func": magnetic_state}, - {"generator_func": oxidation_state}, - {"generator_func": space_group}, - {"generator_func": wyckoff}, + {"generator_func": generators.element}, + {"generator_func": generators.chemenv}, + {"generator_func": generators.crystal_system}, + {"generator_func": generators.magnetic_state}, + {"generator_func": generators.oxidation_state}, + {"generator_func": generators.space_group}, + {"generator_func": generators.wyckoff}, { - "generator_func": material_site, + "generator_func": generators.material_site, "generator_args": {"material_store": matgraphdb.node_stores["material"]}, }, { - "generator_func": material_lattice, + "generator_func": generators.material_lattice, "generator_args": {"material_store": matgraphdb.node_stores["material"]}, }, ] @@ -89,53 +89,53 @@ def edge_generator_data(node_generator_data): edge_generators = [ { - "generator_func": element_element_neighborsByGroupPeriod, + "generator_func": generators.element_element_neighborsByGroupPeriod, "generator_args": {"element_store": matgraphdb.node_stores["element"]}, }, { - "generator_func": element_oxiState_canOccur, + "generator_func": generators.element_oxiState_canOccur, "generator_args": { "element_store": matgraphdb.node_stores["element"], "oxiState_store": matgraphdb.node_stores["oxidation_state"], }, }, { - "generator_func": material_chemenv_containsSite, + "generator_func": generators.material_chemenv_containsSite, "generator_args": { "material_store": matgraphdb.node_stores["material"], "chemenv_store": matgraphdb.node_stores["chemenv"], }, }, { - "generator_func": material_crystalSystem_has, + "generator_func": generators.material_crystalSystem_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "crystal_system_store": matgraphdb.node_stores["crystal_system"], }, }, { - "generator_func": material_element_has, + "generator_func": generators.material_element_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "element_store": matgraphdb.node_stores["element"], }, }, { - "generator_func": material_lattice_has, + "generator_func": generators.material_lattice_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "lattice_store": matgraphdb.node_stores["material_lattice"], }, }, { - "generator_func": material_spg_has, + "generator_func": generators.material_spg_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "spg_store": matgraphdb.node_stores["space_group"], }, }, { - "generator_func": element_chemenv_canOccur, + "generator_func": generators.element_chemenv_canOccur, "generator_args": { "element_store": matgraphdb.node_stores["element"], "chemenv_store": matgraphdb.node_stores["chemenv"], @@ -335,53 +335,53 @@ def test_moving_matgraphdb(edge_generator_data): matgraphdb = MatGraphDB(storage_path=new_dir) edge_generators = [ { - "generator_func": element_element_neighborsByGroupPeriod, + "generator_func": generators.element_element_neighborsByGroupPeriod, "generator_args": {"element_store": matgraphdb.node_stores["element"]}, }, { - "generator_func": element_oxiState_canOccur, 
+ "generator_func": generators.element_oxiState_canOccur, "generator_args": { "element_store": matgraphdb.node_stores["element"], "oxiState_store": matgraphdb.node_stores["oxidation_state"], }, }, { - "generator_func": material_chemenv_containsSite, + "generator_func": generators.material_chemenv_containsSite, "generator_args": { "material_store": matgraphdb.node_stores["material"], "chemenv_store": matgraphdb.node_stores["chemenv"], }, }, { - "generator_func": material_crystalSystem_has, + "generator_func": generators.material_crystalSystem_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "crystal_system_store": matgraphdb.node_stores["crystal_system"], }, }, { - "generator_func": material_element_has, + "generator_func": generators.material_element_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "element_store": matgraphdb.node_stores["element"], }, }, { - "generator_func": material_lattice_has, + "generator_func": generators.material_lattice_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "lattice_store": matgraphdb.node_stores["material_lattice"], }, }, { - "generator_func": material_spg_has, + "generator_func": generators.material_spg_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "spg_store": matgraphdb.node_stores["space_group"], }, }, { - "generator_func": element_chemenv_canOccur, + "generator_func": generators.element_chemenv_canOccur, "generator_args": { "element_store": matgraphdb.node_stores["element"], "chemenv_store": matgraphdb.node_stores["chemenv"], @@ -421,7 +421,7 @@ def test_dependency_updates(matgraphdb, node_generator_data): edge_generators = [ { - "generator_func": material_crystalSystem_has, + "generator_func": generators.material_crystalSystem_has, "generator_args": { "material_store": matgraphdb.node_stores["material"], "crystal_system_store": matgraphdb.node_stores["crystal_system"],