diff --git a/.gitignore b/.gitignore index 5330055..0b149e5 100644 --- a/.gitignore +++ b/.gitignore @@ -167,7 +167,7 @@ poly_graphs_lib/cfg/mp_api.yml /matgl/ /config.yml /private_config.yml -tmp +tmp.py /logs/ /figures/ @@ -187,7 +187,7 @@ dev_scripts/ /notes/ /my_tests/ /data -*data/ + _version.py diff --git a/README.md b/README.md index b89f4c6..4709ad3 100644 --- a/README.md +++ b/README.md @@ -23,18 +23,49 @@ Check out the [docs](https://romerogroup.github.io/MatGraphDB/) ## Installing -### Installing via pip +### Regular install + +#### Install via pip + ```bash pip install matgraphdb ``` -### Installing from github + +#### Install from GitHub + ```bash git clone https://github.com/romerogroup/MatGraphDB.git cd MatGraphDB pip install -e . ``` + +### Install with ML dependencies + +You may want to install the package with its ML dependencies. This installs PyTorch and the PyTorch Geometric packages; the builds you need depend on the CUDA version installed on your system. + +#### Easy install (CPU) + +The easiest way to install the package with ML dependencies is to use the `[ml]` extra. +```bash +pip install "matgraphdb[ml]" +``` + +#### Manual install (GPU) + +Here is an example of how to install the package with GPU support for CUDA 11.8. If you have a different CUDA version installed, replace `cu118` with the appropriate version for your system. + + +```bash +pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118 + +pip install torch_geometric + +pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu118.html +``` + + ## Usage ### Interacting with the materials database. diff --git a/docs/source/01_tutorials/01 - Getting Started.ipynb b/docs/source/01_tutorials/01 - Getting Started.ipynb index 1b179da..209f256 100644 --- a/docs/source/01_tutorials/01 - Getting Started.ipynb +++ b/docs/source/01_tutorials/01 - Getting Started.ipynb @@ -1006,7 +1006,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.21" + "version": "3.9.21", + "nbsphinx": { + "execute": "never" + } } }, "nbformat": 4, diff --git a/docs/source/01_tutorials/02 - Managing Graphs in MatGraphDB.ipynb b/docs/source/01_tutorials/02 - Managing Graphs in MatGraphDB.ipynb index 0b39af1..fd8869e 100644 --- a/docs/source/01_tutorials/02 - Managing Graphs in MatGraphDB.ipynb +++ b/docs/source/01_tutorials/02 - Managing Graphs in MatGraphDB.ipynb @@ -969,7 +969,10 @@ "pygments_lexer": "ipython3", "version": "3.9.21" }, - "name": "Example 2 - Managing Graphs in MatGraphDB" + "name": "Example 2 - Managing Graphs in MatGraphDB", + "nbsphinx": { + "execute": "never" + } }, "nbformat": 4, "nbformat_minor": 4 diff --git a/docs/source/01_tutorials/03 - Graph Generators in MatgraphDB.ipynb b/docs/source/01_tutorials/03 - Graph Generators in MatgraphDB.ipynb index 6450eeb..8e635a1 100644 --- a/docs/source/01_tutorials/03 - Graph Generators in MatgraphDB.ipynb +++ b/docs/source/01_tutorials/03 - Graph Generators in MatgraphDB.ipynb @@ -483,7 +483,7 @@ ], "source": [ "import pandas as pd\n", - "from matgraphdb import node_generator\n", + "from parquetdb import node_generator\n", "from matgraphdb.utils.config import PKG_DIR\n", "\n", "BASE_ELEMENT_FILE = os.path.join(\n", @@ -1193,7 +1193,7 @@ "metadata": {}, "outputs": [], "source": [ - "from matgraphdb import edge_generator\n", + "from parquetdb import edge_generator\n",
"import pyarrow as pa\n", "\n", "\n", @@ -1905,6 +1905,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", + "nbsphinx": { + "execute": "never" + }, "pygments_lexer": "ipython3", "version": "3.9.21" } diff --git a/docs/source/03_api/core/_autosummary/matgraphdb.core.edges.EdgeStore.rst b/docs/source/03_api/core/_autosummary/matgraphdb.core.edges.EdgeStore.rst deleted file mode 100644 index 0357b49..0000000 --- a/docs/source/03_api/core/_autosummary/matgraphdb.core.edges.EdgeStore.rst +++ /dev/null @@ -1,92 +0,0 @@ -matgraphdb.core.edges.EdgeStore -=============================== - -.. currentmodule:: matgraphdb.core.edges - -.. autoclass:: EdgeStore - - - .. automethod:: __init__ - - - .. rubric:: Methods - - .. autosummary:: - - ~EdgeStore.__init__ - ~EdgeStore.backup_database - ~EdgeStore.construct_table - ~EdgeStore.copy_dataset - ~EdgeStore.create - ~EdgeStore.create_edges - ~EdgeStore.dataset_exists - ~EdgeStore.delete - ~EdgeStore.delete_edges - ~EdgeStore.drop_dataset - ~EdgeStore.export_dataset - ~EdgeStore.export_partitioned_dataset - ~EdgeStore.get_current_files - ~EdgeStore.get_field_metadata - ~EdgeStore.get_field_names - ~EdgeStore.get_file_sizes - ~EdgeStore.get_metadata - ~EdgeStore.get_n_rows_per_row_group_per_file - ~EdgeStore.get_number_of_row_groups_per_file - ~EdgeStore.get_number_of_rows_per_file - ~EdgeStore.get_parquet_column_metadata_per_file - ~EdgeStore.get_parquet_file_metadata_per_file - ~EdgeStore.get_parquet_file_row_group_metadata_per_file - ~EdgeStore.get_row_group_sizes_per_file - ~EdgeStore.get_schema - ~EdgeStore.get_serialized_metadata_size_per_file - ~EdgeStore.import_dataset - ~EdgeStore.is_empty - ~EdgeStore.merge_datasets - ~EdgeStore.normalize - ~EdgeStore.normalize_edges - ~EdgeStore.preprocess_table - ~EdgeStore.process_data_with_python_objects - ~EdgeStore.read - ~EdgeStore.read_edges - ~EdgeStore.rename_dataset - ~EdgeStore.rename_fields - ~EdgeStore.restore_database - ~EdgeStore.set_field_metadata - ~EdgeStore.set_metadata - ~EdgeStore.setup - ~EdgeStore.sort_fields - ~EdgeStore.summary - ~EdgeStore.to_nested - ~EdgeStore.transform - ~EdgeStore.update - ~EdgeStore.update_edges - ~EdgeStore.update_schema - ~EdgeStore.validate_edges - - - - - - .. rubric:: Attributes - - .. autosummary:: - - ~EdgeStore.basename_template - ~EdgeStore.columns - ~EdgeStore.dataset_name - ~EdgeStore.db_path - ~EdgeStore.edge_metadata_keys - ~EdgeStore.edge_type - ~EdgeStore.n_columns - ~EdgeStore.n_edges - ~EdgeStore.n_features - ~EdgeStore.n_files - ~EdgeStore.n_row_groups_per_file - ~EdgeStore.n_rows - ~EdgeStore.n_rows_per_file - ~EdgeStore.n_rows_per_row_group_per_file - ~EdgeStore.required_fields - ~EdgeStore.serialized_metadata_size_per_file - ~EdgeStore.storage_path - - \ No newline at end of file diff --git a/docs/source/03_api/core/_autosummary/matgraphdb.core.edges.edge_generator.rst b/docs/source/03_api/core/_autosummary/matgraphdb.core.edges.edge_generator.rst deleted file mode 100644 index 3aae209..0000000 --- a/docs/source/03_api/core/_autosummary/matgraphdb.core.edges.edge_generator.rst +++ /dev/null @@ -1,6 +0,0 @@ -matgraphdb.core.edges.edge\_generator -===================================== - -.. currentmodule:: matgraphdb.core.edges - -.. 
autofunction:: edge_generator \ No newline at end of file diff --git a/docs/source/03_api/core/_autosummary/matgraphdb.core.generator_store.GeneratorStore.rst b/docs/source/03_api/core/_autosummary/matgraphdb.core.generator_store.GeneratorStore.rst deleted file mode 100644 index a684fcf..0000000 --- a/docs/source/03_api/core/_autosummary/matgraphdb.core.generator_store.GeneratorStore.rst +++ /dev/null @@ -1,91 +0,0 @@ -matgraphdb.core.generator\_store.GeneratorStore -=============================================== - -.. currentmodule:: matgraphdb.core.generator_store - -.. autoclass:: GeneratorStore - - - .. automethod:: __init__ - - - .. rubric:: Methods - - .. autosummary:: - - ~GeneratorStore.__init__ - ~GeneratorStore.backup_database - ~GeneratorStore.construct_table - ~GeneratorStore.copy_dataset - ~GeneratorStore.create - ~GeneratorStore.dataset_exists - ~GeneratorStore.delete - ~GeneratorStore.delete_generator - ~GeneratorStore.drop_dataset - ~GeneratorStore.export_dataset - ~GeneratorStore.export_partitioned_dataset - ~GeneratorStore.get_current_files - ~GeneratorStore.get_field_metadata - ~GeneratorStore.get_field_names - ~GeneratorStore.get_file_sizes - ~GeneratorStore.get_metadata - ~GeneratorStore.get_n_rows_per_row_group_per_file - ~GeneratorStore.get_number_of_row_groups_per_file - ~GeneratorStore.get_number_of_rows_per_file - ~GeneratorStore.get_parquet_column_metadata_per_file - ~GeneratorStore.get_parquet_file_metadata_per_file - ~GeneratorStore.get_parquet_file_row_group_metadata_per_file - ~GeneratorStore.get_row_group_sizes_per_file - ~GeneratorStore.get_schema - ~GeneratorStore.get_serialized_metadata_size_per_file - ~GeneratorStore.import_dataset - ~GeneratorStore.is_empty - ~GeneratorStore.is_in - ~GeneratorStore.list_generators - ~GeneratorStore.load_generator - ~GeneratorStore.load_generator_data - ~GeneratorStore.merge_datasets - ~GeneratorStore.normalize - ~GeneratorStore.preprocess_table - ~GeneratorStore.process_data_with_python_objects - ~GeneratorStore.read - ~GeneratorStore.rename_dataset - ~GeneratorStore.rename_fields - ~GeneratorStore.restore_database - ~GeneratorStore.run_generator - ~GeneratorStore.set_field_metadata - ~GeneratorStore.set_metadata - ~GeneratorStore.sort_fields - ~GeneratorStore.store_generator - ~GeneratorStore.summary - ~GeneratorStore.to_nested - ~GeneratorStore.transform - ~GeneratorStore.update - ~GeneratorStore.update_schema - - - - - - .. rubric:: Attributes - - .. autosummary:: - - ~GeneratorStore.basename_template - ~GeneratorStore.columns - ~GeneratorStore.dataset_name - ~GeneratorStore.db_path - ~GeneratorStore.generator_names - ~GeneratorStore.metadata_keys - ~GeneratorStore.n_columns - ~GeneratorStore.n_files - ~GeneratorStore.n_generators - ~GeneratorStore.n_row_groups_per_file - ~GeneratorStore.n_rows - ~GeneratorStore.n_rows_per_file - ~GeneratorStore.n_rows_per_row_group_per_file - ~GeneratorStore.required_fields - ~GeneratorStore.serialized_metadata_size_per_file - ~GeneratorStore.storage_path - - \ No newline at end of file diff --git a/docs/source/03_api/core/_autosummary/matgraphdb.core.graph_db.GraphDB.rst b/docs/source/03_api/core/_autosummary/matgraphdb.core.graph_db.GraphDB.rst deleted file mode 100644 index 823c1a1..0000000 --- a/docs/source/03_api/core/_autosummary/matgraphdb.core.graph_db.GraphDB.rst +++ /dev/null @@ -1,70 +0,0 @@ -matgraphdb.core.graph\_db.GraphDB -================================= - -.. currentmodule:: matgraphdb.core.graph_db - -.. autoclass:: GraphDB - - - .. automethod:: __init__ - - - .. 
rubric:: Methods - - .. autosummary:: - - ~GraphDB.__init__ - ~GraphDB.add_edge_generator - ~GraphDB.add_edge_store - ~GraphDB.add_edge_type - ~GraphDB.add_edges - ~GraphDB.add_generator_dependency - ~GraphDB.add_node_generator - ~GraphDB.add_node_store - ~GraphDB.add_node_type - ~GraphDB.add_nodes - ~GraphDB.construct_table - ~GraphDB.delete_edges - ~GraphDB.delete_nodes - ~GraphDB.edge_exists - ~GraphDB.edge_is_empty - ~GraphDB.generator_consistency_check - ~GraphDB.get_edge_store - ~GraphDB.get_generator_dependency_graph - ~GraphDB.get_generator_type - ~GraphDB.get_node_store - ~GraphDB.get_nodes - ~GraphDB.list_edge_types - ~GraphDB.list_node_types - ~GraphDB.node_exists - ~GraphDB.node_is_empty - ~GraphDB.normalize_all_edges - ~GraphDB.normalize_all_nodes - ~GraphDB.normalize_edges - ~GraphDB.normalize_nodes - ~GraphDB.read_edges - ~GraphDB.read_nodes - ~GraphDB.remove_edge_store - ~GraphDB.remove_edge_type - ~GraphDB.remove_node_store - ~GraphDB.remove_node_type - ~GraphDB.run_edge_generator - ~GraphDB.run_node_generator - ~GraphDB.summary - ~GraphDB.update_edges - ~GraphDB.update_nodes - - - - - - .. rubric:: Attributes - - .. autosummary:: - - ~GraphDB.n_edge_types - ~GraphDB.n_edges_per_type - ~GraphDB.n_node_types - ~GraphDB.n_nodes_per_type - - \ No newline at end of file diff --git a/docs/source/03_api/core/_autosummary/matgraphdb.core.nodes.NodeStore.rst b/docs/source/03_api/core/_autosummary/matgraphdb.core.nodes.NodeStore.rst deleted file mode 100644 index cba89bd..0000000 --- a/docs/source/03_api/core/_autosummary/matgraphdb.core.nodes.NodeStore.rst +++ /dev/null @@ -1,90 +0,0 @@ -matgraphdb.core.nodes.NodeStore -=============================== - -.. currentmodule:: matgraphdb.core.nodes - -.. autoclass:: NodeStore - - - .. automethod:: __init__ - - - .. rubric:: Methods - - .. autosummary:: - - ~NodeStore.__init__ - ~NodeStore.backup_database - ~NodeStore.construct_table - ~NodeStore.copy_dataset - ~NodeStore.create - ~NodeStore.create_nodes - ~NodeStore.dataset_exists - ~NodeStore.delete - ~NodeStore.delete_nodes - ~NodeStore.drop_dataset - ~NodeStore.export_dataset - ~NodeStore.export_partitioned_dataset - ~NodeStore.get_current_files - ~NodeStore.get_field_metadata - ~NodeStore.get_field_names - ~NodeStore.get_file_sizes - ~NodeStore.get_metadata - ~NodeStore.get_n_rows_per_row_group_per_file - ~NodeStore.get_number_of_row_groups_per_file - ~NodeStore.get_number_of_rows_per_file - ~NodeStore.get_parquet_column_metadata_per_file - ~NodeStore.get_parquet_file_metadata_per_file - ~NodeStore.get_parquet_file_row_group_metadata_per_file - ~NodeStore.get_row_group_sizes_per_file - ~NodeStore.get_schema - ~NodeStore.get_serialized_metadata_size_per_file - ~NodeStore.import_dataset - ~NodeStore.initialize - ~NodeStore.is_empty - ~NodeStore.merge_datasets - ~NodeStore.normalize - ~NodeStore.normalize_nodes - ~NodeStore.preprocess_table - ~NodeStore.process_data_with_python_objects - ~NodeStore.read - ~NodeStore.read_nodes - ~NodeStore.rename_dataset - ~NodeStore.rename_fields - ~NodeStore.restore_database - ~NodeStore.set_field_metadata - ~NodeStore.set_metadata - ~NodeStore.sort_fields - ~NodeStore.summary - ~NodeStore.to_nested - ~NodeStore.transform - ~NodeStore.update - ~NodeStore.update_nodes - ~NodeStore.update_schema - - - - - - .. rubric:: Attributes - - .. 
autosummary:: - - ~NodeStore.basename_template - ~NodeStore.columns - ~NodeStore.dataset_name - ~NodeStore.db_path - ~NodeStore.n_columns - ~NodeStore.n_features - ~NodeStore.n_files - ~NodeStore.n_nodes - ~NodeStore.n_row_groups_per_file - ~NodeStore.n_rows - ~NodeStore.n_rows_per_file - ~NodeStore.n_rows_per_row_group_per_file - ~NodeStore.name_column - ~NodeStore.node_metadata_keys - ~NodeStore.serialized_metadata_size_per_file - ~NodeStore.storage_path - - \ No newline at end of file diff --git a/docs/source/03_api/core/_autosummary/matgraphdb.core.nodes.node_generator.rst b/docs/source/03_api/core/_autosummary/matgraphdb.core.nodes.node_generator.rst deleted file mode 100644 index f7c21e0..0000000 --- a/docs/source/03_api/core/_autosummary/matgraphdb.core.nodes.node_generator.rst +++ /dev/null @@ -1,6 +0,0 @@ -matgraphdb.core.nodes.node\_generator -===================================== - -.. currentmodule:: matgraphdb.core.nodes - -.. autofunction:: node_generator \ No newline at end of file diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.core.MatGraphDB.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.core.MatGraphDB.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.core.MatGraphDB.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.core.MatGraphDB.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_chemenv_canOccur.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_chemenv_canOccur.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_chemenv_canOccur.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_chemenv_canOccur.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_element_bonds.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_element_bonds.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_element_bonds.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_element_bonds.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_element_neighborsByGroupPeriod.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_element_neighborsByGroupPeriod.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_element_neighborsByGroupPeriod.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_element_neighborsByGroupPeriod.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_oxiState_canOccur.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_oxiState_canOccur.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.element_oxiState_canOccur.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.element_oxiState_canOccur.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_chemenv_containsSite.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_chemenv_containsSite.rst similarity index 100% rename from 
docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_chemenv_containsSite.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_chemenv_containsSite.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_crystalSystem_has.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_crystalSystem_has.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_crystalSystem_has.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_crystalSystem_has.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_element_has.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_element_has.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_element_has.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_element_has.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_lattice_has.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_lattice_has.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_lattice_has.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_lattice_has.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_spg_has.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_spg_has.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.material_spg_has.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.material_spg_has.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.spg_crystalSystem_isApart.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.spg_crystalSystem_isApart.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.edges.spg_crystalSystem_isApart.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.edges.spg_crystalSystem_isApart.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.chemenv.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.chemenv.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.chemenv.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.chemenv.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.crystal_system.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.crystal_system.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.crystal_system.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.crystal_system.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.element.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.element.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.element.rst rename 
to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.element.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.magnetic_state.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.magnetic_state.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.magnetic_state.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.magnetic_state.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.oxidation_state.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.oxidation_state.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.oxidation_state.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.oxidation_state.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.space_group.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.space_group.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.space_group.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.space_group.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.wyckoff.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.wyckoff.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.generators.wyckoff.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.generators.wyckoff.rst diff --git a/docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.materials.MaterialStore.rst b/docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.materials.MaterialStore.rst similarity index 100% rename from docs/source/03_api/materials/_autosummary/matgraphdb.materials.nodes.materials.MaterialStore.rst rename to docs/source/03_api/core/_autosummary/matgraphdb.materials.nodes.materials.MaterialStore.rst diff --git a/docs/source/03_api/core/edge_generator.rst b/docs/source/03_api/core/edge_generator.rst deleted file mode 100644 index 3e1b9cd..0000000 --- a/docs/source/03_api/core/edge_generator.rst +++ /dev/null @@ -1,9 +0,0 @@ -EdgeGenerator -======================== - -- :func:`edge_generator ` - A decorator that validates the input arguments of a function and converts them into a dataframe. - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.core.edges.edge_generator diff --git a/docs/source/03_api/core/edge_generators.rst b/docs/source/03_api/core/edge_generators.rst new file mode 100644 index 0000000..b156a27 --- /dev/null +++ b/docs/source/03_api/core/edge_generators.rst @@ -0,0 +1,38 @@ +Edge Generators +======================== + +- :func:`element_element_neighborsByGroupPeriod ` - A function that generates the neighbors of an element by group and period. + +- :func:`element_element_bonds ` - A function that generates the bonds of an element. + +- :func:`element_oxiState_canOccur ` - A function that generates the possible oxidation states of an element. + +- :func:`material_chemenv_containsSite ` - A function that generates the sites of a material. 
+ +- :func:`material_crystalSystem_has ` - A function that generates the crystal system of a material. + +- :func:`material_element_has ` - A function that generates the elements of a material. + +- :func:`material_lattice_has ` - A function that generates the lattice of a material. + +- :func:`material_spg_has ` - A function that generates the space group of a material. + +- :func:`element_chemenv_canOccur ` - A function that generates the possible chemical environments of an element. + +- :func:`spg_crystalSystem_isApart ` - A function that generates the crystal system that a space group is a part of. + + + +.. autosummary:: + :toctree: _autosummary + + matgraphdb.core.edges.element_element_neighborsByGroupPeriod + matgraphdb.core.edges.element_element_bonds + matgraphdb.core.edges.element_oxiState_canOccur + matgraphdb.core.edges.material_chemenv_containsSite + matgraphdb.core.edges.material_crystalSystem_has + matgraphdb.core.edges.material_element_has + matgraphdb.core.edges.material_lattice_has + matgraphdb.core.edges.material_spg_has + matgraphdb.core.edges.element_chemenv_canOccur + matgraphdb.core.edges.spg_crystalSystem_isApart diff --git a/docs/source/03_api/core/edge_store.rst b/docs/source/03_api/core/edge_store.rst deleted file mode 100644 index c98dfc3..0000000 --- a/docs/source/03_api/core/edge_store.rst +++ /dev/null @@ -1,9 +0,0 @@ -Edge Store -======================== - -- :class:`EdgeStore ` - The main interface class that provides database-like operations over Parquet files. This class handles data storage, retrieval, querying, schema evolution, and complex data type management through an intuitive API that wraps PyArrow's functionality. - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.core.edges.EdgeStore diff --git a/docs/source/03_api/core/generator_store.rst b/docs/source/03_api/core/generator_store.rst deleted file mode 100644 index 205741a..0000000 --- a/docs/source/03_api/core/generator_store.rst +++ /dev/null @@ -1,9 +0,0 @@ -GeneratorStore -======================== - -- :class:`GeneratorStore ` - A store for managing generator functions in a graph database. This class handles serialization, storage, and loading of functions that generate edges between nodes. - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.core.generator_store.GeneratorStore diff --git a/docs/source/03_api/core/graphdb.rst b/docs/source/03_api/core/graphdb.rst deleted file mode 100644 index 18ae065..0000000 --- a/docs/source/03_api/core/graphdb.rst +++ /dev/null @@ -1,9 +0,0 @@ -GraphDB -======================== - -- :class:`GraphDB ` - A manager for a graph storing multiple node types and edge types. Each node type and edge type is backed by a separate ParquetDB instance (wrapped by NodeStore or EdgeStore). - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.core.graph_db.GraphDB diff --git a/docs/source/03_api/core/index.rst b/docs/source/03_api/core/index.rst index 2e0d17d..b422126 100644 --- a/docs/source/03_api/core/index.rst +++ b/docs/source/03_api/core/index.rst @@ -1,29 +1,62 @@ -.. _core-api-index: +.. _materials-api-index: -Core API +Materials API =================================== -The Core API provides the fundamental functionality of MatGraphDB, offering a robust interface for managing a graph database. This module contains the essential classes and methods that enable database-like operations +The Materials API provides the fundamental functionality of MatGraphDB, offering a robust interface for managing a graph database. 
This module contains the essential classes and methods that enable database-like operations. The core components include: -- :class:`EdgeStore ` - The main interface class that provides database-like operations over Parquet files. This class handles data storage, retrieval, querying, schema evolution, and complex data type management through an intuitive API that wraps PyArrow's functionality. +- :class:`MatGraphDB ` - The main interface class that provides database-like operations over Parquet files. This class handles data storage, retrieval, querying, schema evolution, and complex data type management through an intuitive API that wraps PyArrow's functionality. -- :func:`edge_generator ` - A decorator that validates the input arguments of a function and converts them into a dataframe. +- :class:`MaterialStore ` - A store for managing materials in a graph database. -- :class:`GeneratorStore ` - A store for managing generator functions in a graph database. This class handles serialization, storage, and loading of functions that generate edges between nodes. -- :class:`NodeStore ` - A store for managing node features in a graph database. This class handles data storage, retrieval, querying, schema evolution, and complex data type management through an intuitive API that wraps PyArrow's functionality. +Node Generators +======================== -- :func:`node_generator ` - A decorator that validates the input arguments of a function and converts them into a dataframe. +- :func:`element ` - A function that generates element nodes. + +- :func:`chemenv ` - A function that generates chemical environment nodes. + +- :func:`crystal_system ` - A function that generates crystal system nodes. + +- :func:`magnetic_state ` - A function that generates magnetic state nodes. + +- :func:`oxidation_state ` - A function that generates oxidation state nodes. + +- :func:`space_group ` - A function that generates space group nodes. + +- :func:`wyckoff ` - A function that generates Wyckoff position nodes. + + +Edge Generators +======================== + +- :func:`element_element_neighborsByGroupPeriod ` - A function that generates the neighbors of an element by group and period. + +- :func:`element_element_bonds ` - A function that generates the bonds of an element. + +- :func:`element_oxiState_canOccur ` - A function that generates the possible oxidation states of an element. + +- :func:`material_chemenv_containsSite ` - A function that generates the sites of a material. + +- :func:`material_crystalSystem_has ` - A function that generates the crystal system of a material. + +- :func:`material_element_has ` - A function that generates the elements of a material. + +- :func:`material_lattice_has ` - A function that generates the lattice of a material. + +- :func:`material_spg_has ` - A function that generates the space group of a material. + +- :func:`element_chemenv_canOccur ` - A function that generates the possible chemical environments of an element. + +- :func:`spg_crystalSystem_isApart ` - A function that generates the crystal system that a space group is a part of. -- :class:`GraphDB ` - A manager for a graph storing multiple node types and edge types. Each node type and edge type is backed by a separate ParquetDB instance (wrapped by NodeStore or EdgeStore). .. 
toctree:: :maxdepth: 2 - node_store - node_generator - edge_store - edge_generator - generator_store - graphdb + matgraphdb + edge_generators + node_generators + material_store diff --git a/docs/source/03_api/core/material_store.rst b/docs/source/03_api/core/material_store.rst new file mode 100644 index 0000000..cf7be9d --- /dev/null +++ b/docs/source/03_api/core/material_store.rst @@ -0,0 +1,9 @@ +MaterialStore +======================== + +- :class:`MaterialStore ` - A store for managing materials in a graph database. + +.. autosummary:: + :toctree: _autosummary + + matgraphdb.core.nodes.materials.MaterialStore diff --git a/docs/source/03_api/core/matgraphdb.rst b/docs/source/03_api/core/matgraphdb.rst new file mode 100644 index 0000000..3adf586 --- /dev/null +++ b/docs/source/03_api/core/matgraphdb.rst @@ -0,0 +1,9 @@ +MatGraphDB +======================== + +- :class:`MatGraphDB ` - An extension of the `ParquetGraphDB ` class that provides additional functionality for managing materials data. + +.. autosummary:: + :toctree: _autosummary + + matgraphdb.core.matgraphdb.MatGraphDB diff --git a/docs/source/03_api/core/node_generator.rst b/docs/source/03_api/core/node_generator.rst deleted file mode 100644 index acf5990..0000000 --- a/docs/source/03_api/core/node_generator.rst +++ /dev/null @@ -1,9 +0,0 @@ -NodeGenerator -======================== - -- :func:`node_generator ` - A decorator that validates the input arguments of a function and converts them into a dataframe. - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.core.nodes.node_generator diff --git a/docs/source/03_api/core/node_generators.rst b/docs/source/03_api/core/node_generators.rst new file mode 100644 index 0000000..c699279 --- /dev/null +++ b/docs/source/03_api/core/node_generators.rst @@ -0,0 +1,29 @@ +Node Generators +======================== + + +- :func:`element ` - A function that generates element nodes. + +- :func:`chemenv ` - A function that generates chemical environment nodes. + +- :func:`crystal_system ` - A function that generates crystal system nodes. + +- :func:`magnetic_state ` - A function that generates magnetic state nodes. + +- :func:`oxidation_state ` - A function that generates oxidation state nodes. + +- :func:`space_group ` - A function that generates space group nodes. + +- :func:`wyckoff ` - A function that generates Wyckoff position nodes. + + +.. autosummary:: + :toctree: _autosummary + + matgraphdb.core.nodes.generators.element + matgraphdb.core.nodes.generators.chemenv + matgraphdb.core.nodes.generators.crystal_system + matgraphdb.core.nodes.generators.magnetic_state + matgraphdb.core.nodes.generators.oxidation_state + matgraphdb.core.nodes.generators.space_group + matgraphdb.core.nodes.generators.wyckoff diff --git a/docs/source/03_api/core/node_store.rst b/docs/source/03_api/core/node_store.rst deleted file mode 100644 index fa6da09..0000000 --- a/docs/source/03_api/core/node_store.rst +++ /dev/null @@ -1,9 +0,0 @@ -NodeStore -======================== - -- :class:`NodeStore ` - A store for managing node features in a graph database. This class handles data storage, retrieval, querying, schema evolution, and complex data type management through an intuitive API that wraps PyArrow's functionality. - -.. 
autosummary:: - :toctree: _autosummary - - matgraphdb.core.nodes.NodeStore diff --git a/docs/source/03_api/index.rst b/docs/source/03_api/index.rst index 6677719..e77e136 100644 --- a/docs/source/03_api/index.rst +++ b/docs/source/03_api/index.rst @@ -6,7 +6,7 @@ API Reference :hidden: core/index - materials/index + In this section, you can explore the API documentation for MatGraphDB. @@ -19,10 +19,3 @@ In this section, you can explore the API documentation for MatGraphDB. Learn more about MatGraphDB's core functionality. -.. card:: materials API - :link: materials-api-index - :link-type: ref - :class-title: matgraphdb-card-title - - Learn more about MatGraphDB's materials functionality. - diff --git a/docs/source/03_api/materials/core.rst b/docs/source/03_api/materials/core.rst deleted file mode 100644 index 3ccfbf5..0000000 --- a/docs/source/03_api/materials/core.rst +++ /dev/null @@ -1,9 +0,0 @@ -MatGraphDB -======================== - -- :class:`MatGraphDB ` - An extension of the `GraphDB ` class that provides additional functionality for managing materials data. - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.materials.core.MatGraphDB diff --git a/docs/source/03_api/materials/edge_generators.rst b/docs/source/03_api/materials/edge_generators.rst deleted file mode 100644 index 295b517..0000000 --- a/docs/source/03_api/materials/edge_generators.rst +++ /dev/null @@ -1,38 +0,0 @@ -Edge Generators -======================== - -- :func:`element_element_neighborsByGroupPeriod ` - A function that generates the neighbors of an element by group and period. - -- :func:`element_element_bonds ` - A function that generates the bonds of an element. - -- :func:`element_oxiState_canOccur ` - A function that generates the possible oxidation states of an element. - -- :func:`material_chemenv_containsSite ` - A function that generates the sites of a material. - -- :func:`material_crystalSystem_has ` - A function that generates the crystal system of a material. - -- :func:`material_element_has ` - A function that generates the elements of a material. - -- :func:`material_lattice_has ` - A function that generates the lattice of a material. - -- :func:`material_spg_has ` - A function that generates the space group of a material. - -- :func:`element_chemenv_canOccur ` - A function that generates the possible oxidation states of an element. - -- :func:`spg_crystalSystem_isApart ` - A function that generates the crystal system of a material. - - - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.materials.edges.element_element_neighborsByGroupPeriod - matgraphdb.materials.edges.element_element_bonds - matgraphdb.materials.edges.element_oxiState_canOccur - matgraphdb.materials.edges.material_chemenv_containsSite - matgraphdb.materials.edges.material_crystalSystem_has - matgraphdb.materials.edges.material_element_has - matgraphdb.materials.edges.material_lattice_has - matgraphdb.materials.edges.material_spg_has - matgraphdb.materials.edges.element_chemenv_canOccur - matgraphdb.materials.edges.spg_crystalSystem_isApart diff --git a/docs/source/03_api/materials/index.rst b/docs/source/03_api/materials/index.rst deleted file mode 100644 index f57c057..0000000 --- a/docs/source/03_api/materials/index.rst +++ /dev/null @@ -1,62 +0,0 @@ -.. _materials-api-index: - -Materials API -=================================== - -The Materials API provides the fundamental functionality of MatGraphDB, offering a robust interface for managing a graph database. 
This module contains the essential classes and methods that enable database-like operations -The core components include: - -- :class:`MatGraphDB ` - The main interface class that provides database-like operations over Parquet files. This class handles data storage, retrieval, querying, schema evolution, and complex data type management through an intuitive API that wraps PyArrow's functionality. - -- :class:`MaterialStore ` - A store for managing materials in a graph database. - - -Node Generators -======================== - -- :func:`element ` - A function that generates the elements of a material. - -- :func:`chemenv ` - A function that generates the chemical environments of a material. - -- :func:`crystal_system ` - A function that generates the crystal systems of a material. - -- :func:`magnetic_state ` - A function that generates the magnetic states of a material. - -- :func:`oxidation_state ` - A function that generates the oxidation states of a material. - -- :func:`space_group ` - A function that generates the space groups of a material. - -- :func:`wyckoff ` - A function that generates the wyckoffs of a material. - - -Edge Generators -======================== - -- :func:`element_element_neighborsByGroupPeriod ` - A function that generates the neighbors of an element by group and period. - -- :func:`element_element_bonds ` - A function that generates the bonds of an element. - -- :func:`element_oxiState_canOccur ` - A function that generates the possible oxidation states of an element. - -- :func:`material_chemenv_containsSite ` - A function that generates the sites of a material. - -- :func:`material_crystalSystem_has ` - A function that generates the crystal system of a material. - -- :func:`material_element_has ` - A function that generates the elements of a material. - -- :func:`material_lattice_has ` - A function that generates the lattice of a material. - -- :func:`material_spg_has ` - A function that generates the space group of a material. - -- :func:`element_chemenv_canOccur ` - A function that generates the possible oxidation states of an element. - -- :func:`spg_crystalSystem_isApart ` - A function that generates the crystal system of a material. - - -.. toctree:: - :maxdepth: 2 - - core - edge_generators - node_generators - material_store diff --git a/docs/source/03_api/materials/material_store.rst b/docs/source/03_api/materials/material_store.rst deleted file mode 100644 index 8b40419..0000000 --- a/docs/source/03_api/materials/material_store.rst +++ /dev/null @@ -1,9 +0,0 @@ -MaterialStore -======================== - -- :class:`MaterialStore ` - A store for managing materials in a graph database. - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.materials.nodes.materials.MaterialStore diff --git a/docs/source/03_api/materials/node_generators.rst b/docs/source/03_api/materials/node_generators.rst deleted file mode 100644 index 1cd8c3c..0000000 --- a/docs/source/03_api/materials/node_generators.rst +++ /dev/null @@ -1,29 +0,0 @@ -Node Generators -======================== - - -- :func:`element ` - A function that generates the elements of a material. - -- :func:`chemenv ` - A function that generates the chemical environments of a material. - -- :func:`crystal_system ` - A function that generates the crystal systems of a material. - -- :func:`magnetic_state ` - A function that generates the magnetic states of a material. - -- :func:`oxidation_state ` - A function that generates the oxidation states of a material. 
- -- :func:`space_group ` - A function that generates the space groups of a material. - -- :func:`wyckoff ` - A function that generates the wyckoffs of a material. - - -.. autosummary:: - :toctree: _autosummary - - matgraphdb.materials.nodes.generators.element - matgraphdb.materials.nodes.generators.chemenv - matgraphdb.materials.nodes.generators.crystal_system - matgraphdb.materials.nodes.generators.magnetic_state - matgraphdb.materials.nodes.generators.oxidation_state - matgraphdb.materials.nodes.generators.space_group - matgraphdb.materials.nodes.generators.wyckoff diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst deleted file mode 100644 index 76a82ca..0000000 --- a/docs/source/examples/index.rst +++ /dev/null @@ -1,25 +0,0 @@ -Examples for the MatGraphDB package -========================================== - -Welcome to the MatGraphDB examples! This collection of notebooks demonstrates use cases and practical applications of MatGraphDB. - -These examples are automatically generated from the `examples -directory`_ of the package and showcase how to effectively use MatGraphDB's features for data storage, querying, and management. Feel free to download and run these notebooks to explore the functionality firsthand. - -.. _examples directory: https://github.com/romerogroup/MatGraphDB/tree/main/examples/notebooks - - -.. nblinkgallery:: - :caption: Example Gallery - :name: rst-link-gallery - - notebooks/01 - Creating MatGraphDB Instance - -Contents --------- - -.. toctree:: - :maxdepth: 3 - :caption: Example Gallery - - notebooks/01 - Creating MatGraphDB Instance diff --git a/docs/source/examples/notebooks/01 - Creating MatGraphDB Instance.ipynb b/docs/source/examples/notebooks/01 - Creating MatGraphDB Instance.ipynb deleted file mode 100644 index 42468bb..0000000 --- a/docs/source/examples/notebooks/01 - Creating MatGraphDB Instance.ipynb +++ /dev/null @@ -1,1220 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Building a MatGraphDB Example with MPNearHull Data\n", - "\n", - "In this notebook, we demonstrate how to build a materials graph database using the\n", - "[MatGraphDB](https://github.com/your/matgraphdb) framework with the MPNearHull dataset.\n", - "\n", - "The steps include:\n", - "1. Importing required libraries and setting up configuration paths.\n", - "2. Downloading and extracting the dataset (and raw materials data if needed).\n", - "3. Creating a MatGraphDB instance.\n", - "4. Initializing node generators.\n", - "5. Initializing edge generators.\n", - "6. Verifying the database setup.\n", - "\n", - "Follow along and run each cell to see how the database is constructed." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Library imports and paths are set.\n" - ] - } - ], - "source": [ - "import os\n", - "import shutil\n", - "import zipfile\n", - "import gdown\n", - "\n", - "# Get the data directory from the config. 
You can change this to your own data directory.\n", - "DATA_DIR = os.path.join(\"..\",\"..\",\"data\",\"examples\",\"01\")\n", - "# Define the path to store the raw materials data.\n", - "MATERIALS_PATH = os.path.join(DATA_DIR, \"material\")\n", - "\n", - "MATGRAPHDB_PATH = os.path.join(DATA_DIR, \"MatGraphDB\")\n", - "\n", - "# Define the dataset URLs.\n", - "DATASET_URL = \"https://drive.google.com/uc?id=1zSmEQbV8pNvjWdhFuCwOeoOzvfoS5XKP\"\n", - "\n", - "# Define the URL for the raw materials data.\n", - "RAW_DATASET_URL = \"https://drive.google.com/uc?id=14guJqEK242XgRGEZA-zIrWyg4b-gX5zk\" # (Not used below but available)\n", - "\n", - "# # Define the path to store the raw materials data.\n", - "# RAW_DATASET_ZIP = os.path.join(config.data_dir, \"raw\", \"MPNearHull_v0.0.1_raw.zip\")\n", - "\n", - "# # Define the path to store the dataset.\n", - "# DATASET_ZIP = os.path.join(config.data_dir, \"datasets\", \"MPNearHull_v0.0.1.zip\")\n", - "\n", - "print(\"Library imports and paths are set.\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define Function for Downloading and Extracting Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading raw materials data...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading...\n", - "From (original): https://drive.google.com/uc?id=1zSmEQbV8pNvjWdhFuCwOeoOzvfoS5XKP\n", - "From (redirected): https://drive.google.com/uc?id=1zSmEQbV8pNvjWdhFuCwOeoOzvfoS5XKP&confirm=t&uuid=5bcba796-ff8e-4bb3-bc09-39d3f1136dc1\n", - "To: c:\\Users\\lllang\\Desktop\\Current_Projects\\MatGraphDB\\examples\\notebooks\\materials\\MPNearHull_v0.0.1_raw.zip\n", - "100%|██████████| 632M/632M [00:11<00:00, 53.6MB/s] \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting raw materials data...\n", - "Raw materials data ready!\n" - ] - } - ], - "source": [ - "def download_raw_materials(mp_materials_path):\n", - " \"\"\"\n", - " Download and extract the raw materials data if it is not already present.\n", - " \"\"\"\n", - " if not os.path.exists(mp_materials_path):\n", - " \n", - " os.makedirs(mp_materials_path, exist_ok=True)\n", - " print(\"Downloading raw materials data...\")\n", - " \n", - " raw_dataset_zip = os.path.join(mp_materials_path, \"MPNearHull_v0.0.1_raw.zip\")\n", - " # Note: Here we use DATASET_URL as in the original code.\n", - " gdown.download(DATASET_URL, output=raw_dataset_zip, quiet=False)\n", - " \n", - " print(\"Extracting raw materials data...\")\n", - " with zipfile.ZipFile(raw_dataset_zip, \"r\") as zip_ref:\n", - " zip_ref.extractall(mp_materials_path)\n", - " \n", - " \n", - " files=os.listdir(mp_materials_path)\n", - " os.remove(raw_dataset_zip)\n", - " mp_nearhull_path = os.path.join(mp_materials_path, \"MPNearHull\")\n", - " tmp_materials_path = os.path.join(mp_nearhull_path, \"nodes\", \"material\")\n", - " materials_files = os.listdir(tmp_materials_path)\n", - " for file in materials_files:\n", - " shutil.move(os.path.join(tmp_materials_path, file), os.path.join(mp_materials_path, file))\n", - " \n", - " shutil.rmtree(mp_nearhull_path)\n", - " print(\"Raw materials data ready!\")\n", - " \n", - "# Optionally, download the raw materials data if you plan to initialize from raw files.\n", - "if not os.path.exists(MATERIALS_PATH):\n", - " download_raw_materials(MATERIALS_PATH)\n", - "else:\n", - " print(\"Raw materials data 
already exists.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialization" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initialize a Materials Store" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - "NODE STORE SUMMARY\n", - "============================================================\n", - "Node type: material\n", - "• Number of nodes: 80643\n", - "• Number of features: 136\n", - "Storage path: ..\\..\\data\\examples\\01\\material\n", - "\n", - "\n", - "############################################################\n", - "METADATA\n", - "############################################################\n", - "• class: MaterialStore\n", - "• class_module: matgraphdb.materials.nodes.materials\n", - "• node_type: material\n", - "• name_column: id\n", - "\n", - "############################################################\n", - "NODE DETAILS\n", - "############################################################\n", - "\n" - ] - } - ], - "source": [ - "from matgraphdb import MaterialStore\n", - "\n", - "materials_store = MaterialStore(storage_path=MATERIALS_PATH)\n", - "print(materials_store)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialize a MatGraphDB Instance" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - "GRAPH DATABASE SUMMARY\n", - "============================================================\n", - "Name: MatGraphDB\n", - "Storage path: ..\\..\\data\\examples\\01\\MatGraphDB\n", - "└── Repository structure:\n", - " ├── nodes/ (..\\..\\data\\examples\\01\\MatGraphDB\\nodes)\n", - " ├── edges/ (..\\..\\data\\examples\\01\\MatGraphDB\\edges)\n", - " ├── edge_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\edge_generators)\n", - " ├── node_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\node_generators)\n", - " └── graph/ (..\\..\\data\\examples\\01\\MatGraphDB\\graph)\n", - "\n", - "############################################################\n", - "NODE DETAILS\n", - "############################################################\n", - "Total node types: 1\n", - "------------------------------------------------------------\n", - "• Node type: material\n", - " - Number of nodes: 80643\n", - " - Number of features: 136\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE DETAILS\n", - "############################################################\n", - "Total edge types: 0\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "NODE GENERATOR DETAILS\n", - "############################################################\n", - "Total node generators: 0\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE GENERATOR DETAILS\n", - "############################################################\n", - "Total edge generators: 0\n", - 
"------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "\n", - "from matgraphdb import MatGraphDB\n", - "\n", - "if not os.path.exists(MATGRAPHDB_PATH):\n", - " shutil.rmtree(MATGRAPHDB_PATH)\n", - "mdb = MatGraphDB(storage_path=MATGRAPHDB_PATH,materials_store=materials_store)\n", - "\n", - "print(mdb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Adding Nodes\n", - "\n", - "In this section, we will add the nodes to the MatGraphDB instance. We will be using some of the built-in node generators to add the nodes to the MatGraphDB instance." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from matgraphdb.materials.nodes import (\n", - " element, chemenv, crystal_system, magnetic_state, \n", - " oxidation_state, space_group, wyckoff, material_site, material_lattice\n", - ")\n", - "\n", - "# Here we define the generator functions and arguments if they are needed. \n", - "# For instance, to get the materials sites and lattices, we need to pass the materials store to the generator function.\n", - "node_generators = [\n", - " {\"generator_func\": element},\n", - " {\"generator_func\": chemenv},\n", - " {\"generator_func\": crystal_system},\n", - " {\"generator_func\": magnetic_state},\n", - " {\"generator_func\": oxidation_state},\n", - " {\"generator_func\": space_group},\n", - " {\"generator_func\": wyckoff},\n", - " {\n", - " \"generator_func\": material_site,\n", - " \"generator_args\": {\"material_store\": mdb.node_stores[\"material\"]},\n", - " },\n", - " {\n", - " \"generator_func\": material_lattice,\n", - " \"generator_args\": {\"material_store\": mdb.node_stores[\"material\"]},\n", - " },\n", - "]\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can add the node generators to the MatGraphDB instance. When we add the generator, it will immediately execute and add the nodes to the database." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Adding node generator: element\n", - "Adding node generator: chemenv\n", - "Adding node generator: crystal_system\n", - "Adding node generator: magnetic_state\n", - "Adding node generator: oxidation_state\n", - "Adding node generator: space_group\n", - "Adding node generator: wyckoff\n", - "Adding node generator: material_site\n", - "Adding node generator: material_lattice\n", - "Node generators have been initialized.\n", - "============================================================\n", - "GRAPH DATABASE SUMMARY\n", - "============================================================\n", - "Name: MatGraphDB\n", - "Storage path: ..\\..\\data\\examples\\01\\MatGraphDB\n", - "└── Repository structure:\n", - " ├── nodes/ (..\\..\\data\\examples\\01\\MatGraphDB\\nodes)\n", - " ├── edges/ (..\\..\\data\\examples\\01\\MatGraphDB\\edges)\n", - " ├── edge_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\edge_generators)\n", - " ├── node_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\node_generators)\n", - " └── graph/ (..\\..\\data\\examples\\01\\MatGraphDB\\graph)\n", - "\n", - "############################################################\n", - "NODE DETAILS\n", - "############################################################\n", - "Total node types: 10\n", - "------------------------------------------------------------\n", - "• Node type: material\n", - " - Number of nodes: 80643\n", - " - Number of features: 136\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "------------------------------------------------------------\n", - "• Node type: element\n", - " - Number of nodes: 118\n", - " - Number of features: 99\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - "------------------------------------------------------------\n", - "• Node type: chemenv\n", - " - Number of nodes: 67\n", - " - Number of features: 15\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - "------------------------------------------------------------\n", - "• Node type: crystal_system\n", - " - Number of nodes: 7\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - "------------------------------------------------------------\n", - "• Node type: magnetic_state\n", - " - Number of nodes: 5\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\magnetic_state\n", - "------------------------------------------------------------\n", - "• Node type: oxidation_state\n", - " - Number of nodes: 19\n", - " - Number of features: 3\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\oxidation_state\n", - "------------------------------------------------------------\n", - "• Node type: space_group\n", - " - Number of nodes: 230\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "------------------------------------------------------------\n", - "• Node type: wyckoff\n", - " - Number of nodes: 1380\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\wyckoff\n", - "------------------------------------------------------------\n", - "• Node type: material_site\n", - " - Number of nodes: 2545026\n", - " - Number of features: 15\n", - " - db_path: 
..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_site\n", - "------------------------------------------------------------\n", - "• Node type: material_lattice\n", - " - Number of nodes: 80643\n", - " - Number of features: 12\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_lattice\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE DETAILS\n", - "############################################################\n", - "Total edge types: 0\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "NODE GENERATOR DETAILS\n", - "############################################################\n", - "Total node generators: 9\n", - "------------------------------------------------------------\n", - "• Generator: element\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_kwargs.base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\imputed_periodic_table_values.parquet']\n", - " - generator_name: ['element']\n", - " - id: [0]\n", - "Generator Kwargs:\n", - " - base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\imputed_periodic_table_values.parquet']\n", - "------------------------------------------------------------\n", - "• Generator: chemenv\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_kwargs.base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\coordination_geometries.parquet']\n", - " - generator_name: ['chemenv']\n", - " - id: [1]\n", - "Generator Kwargs:\n", - " - base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\coordination_geometries.parquet']\n", - "------------------------------------------------------------\n", - "• Generator: crystal_system\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['crystal_system']\n", - " - id: [2]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: magnetic_state\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['magnetic_state']\n", - " - id: [3]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: oxidation_state\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['oxidation_state']\n", - " - id: [4]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: space_group\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['space_group']\n", - " - id: [5]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: wyckoff\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['wyckoff']\n", - " - id: [6]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_site\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - generator_func: []\n", - " - generator_name: ['material_site']\n", - " - id: 
[7]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_lattice\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - generator_func: []\n", - " - generator_name: ['material_lattice']\n", - " - id: [8]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE GENERATOR DETAILS\n", - "############################################################\n", - "Total edge generators: 0\n", - "------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "# Add each node generator to the database.\n", - "for generator in node_generators:\n", - " generator_func = generator.get(\"generator_func\")\n", - " generator_args = generator.get(\"generator_args\", None)\n", - " print(f\"Adding node generator: {generator_func.__name__}\")\n", - " mdb.add_node_generator(generator_func=generator_func, generator_args=generator_args)\n", - "\n", - "print(\"Node generators have been initialized.\")\n", - "\n", - "print(mdb)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Adding Edges\n", - "\n", - "In this section, we will add the edges to the MatGraphDB instance. We will be using some of the built-in edge generators to add the edges to the MatGraphDB instance." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Adding edge generator: element_element_neighborsByGroupPeriod\n", - "Adding edge generator: element_oxiState_canOccur\n", - "Adding edge generator: material_chemenv_containsSite\n", - "Adding edge generator: material_crystalSystem_has\n", - "Adding edge generator: material_element_has\n", - "Adding edge generator: material_lattice_has\n", - "Adding edge generator: material_spg_has\n", - "Adding edge generator: element_chemenv_canOccur\n", - "Adding edge generator: spg_crystalSystem_isApart\n", - "Adding edge generator: element_element_bonds\n", - "Edge generators have been initialized.\n", - "============================================================\n", - "GRAPH DATABASE SUMMARY\n", - "============================================================\n", - "Name: MatGraphDB\n", - "Storage path: ..\\..\\data\\examples\\01\\MatGraphDB\n", - "└── Repository structure:\n", - " ├── nodes/ (..\\..\\data\\examples\\01\\MatGraphDB\\nodes)\n", - " ├── edges/ (..\\..\\data\\examples\\01\\MatGraphDB\\edges)\n", - " ├── edge_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\edge_generators)\n", - " ├── node_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\node_generators)\n", - " └── graph/ (..\\..\\data\\examples\\01\\MatGraphDB\\graph)\n", - "\n", - "############################################################\n", - "NODE DETAILS\n", - "############################################################\n", - "Total node types: 10\n", - "------------------------------------------------------------\n", - "• Node type: material\n", - " - Number of nodes: 80643\n", - " - Number of features: 136\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "------------------------------------------------------------\n", - "• Node type: element\n", - " - Number of nodes: 118\n", - " - Number of features: 99\n", - " - db_path: 
..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - "------------------------------------------------------------\n", - "• Node type: chemenv\n", - " - Number of nodes: 67\n", - " - Number of features: 15\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - "------------------------------------------------------------\n", - "• Node type: crystal_system\n", - " - Number of nodes: 7\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - "------------------------------------------------------------\n", - "• Node type: magnetic_state\n", - " - Number of nodes: 5\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\magnetic_state\n", - "------------------------------------------------------------\n", - "• Node type: oxidation_state\n", - " - Number of nodes: 19\n", - " - Number of features: 3\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\oxidation_state\n", - "------------------------------------------------------------\n", - "• Node type: space_group\n", - " - Number of nodes: 230\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "------------------------------------------------------------\n", - "• Node type: wyckoff\n", - " - Number of nodes: 1380\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\wyckoff\n", - "------------------------------------------------------------\n", - "• Node type: material_site\n", - " - Number of nodes: 2545026\n", - " - Number of features: 15\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_site\n", - "------------------------------------------------------------\n", - "• Node type: material_lattice\n", - " - Number of nodes: 80643\n", - " - Number of features: 12\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_lattice\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE DETAILS\n", - "############################################################\n", - "Total edge types: 10\n", - "------------------------------------------------------------\n", - "• Edge type: element_element_neighborsByGroupPeriod\n", - " - Number of edges: 391\n", - " - Number of features: 14\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_element_neighborsByGroupPeriod\n", - "------------------------------------------------------------\n", - "• Edge type: element_oxiState_canOccur\n", - " - Number of edges: 162\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_oxiState_canOccur\n", - "------------------------------------------------------------\n", - "• Edge type: material_chemenv_containsSite\n", - " - Number of edges: 2542897\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_chemenv_containsSite\n", - "------------------------------------------------------------\n", - "• Edge type: material_crystalSystem_has\n", - " - Number of edges: 80643\n", - " - Number of features: 10\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_crystalSystem_has\n", - "------------------------------------------------------------\n", - "• Edge type: material_element_has\n", - " - Number of edges: 270902\n", - " - Number of 
features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_element_has\n", - "------------------------------------------------------------\n", - "• Edge type: material_lattice_has\n", - " - Number of edges: 80643\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_lattice_has\n", - "------------------------------------------------------------\n", - "• Edge type: material_spg_has\n", - " - Number of edges: 80643\n", - " - Number of features: 10\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_spg_has\n", - "------------------------------------------------------------\n", - "• Edge type: element_chemenv_canOccur\n", - " - Number of edges: 270474\n", - " - Number of features: 7\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_chemenv_canOccur\n", - "------------------------------------------------------------\n", - "• Edge type: spg_crystalSystem_isApart\n", - " - Number of edges: 230\n", - " - Number of features: 7\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\spg_crystalSystem_isApart\n", - "------------------------------------------------------------\n", - "• Edge type: element_element_bonds\n", - " - Number of edges: 3069943\n", - " - Number of features: 7\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_element_bonds\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "NODE GENERATOR DETAILS\n", - "############################################################\n", - "Total node generators: 9\n", - "------------------------------------------------------------\n", - "• Generator: element\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_kwargs.base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\imputed_periodic_table_values.parquet']\n", - " - generator_name: ['element']\n", - " - id: [0]\n", - "Generator Kwargs:\n", - " - base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\imputed_periodic_table_values.parquet']\n", - "------------------------------------------------------------\n", - "• Generator: chemenv\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_kwargs.base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\coordination_geometries.parquet']\n", - " - generator_name: ['chemenv']\n", - " - id: [1]\n", - "Generator Kwargs:\n", - " - base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\coordination_geometries.parquet']\n", - "------------------------------------------------------------\n", - "• Generator: crystal_system\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['crystal_system']\n", - " - id: [2]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: magnetic_state\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['magnetic_state']\n", - " - id: [3]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: oxidation_state\n", - "Generator Args:\n", - " - generator_func: []\n", - " 
- generator_name: ['oxidation_state']\n", - " - id: [4]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: space_group\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['space_group']\n", - " - id: [5]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: wyckoff\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['wyckoff']\n", - " - id: [6]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_site\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - generator_func: []\n", - " - generator_name: ['material_site']\n", - " - id: [7]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_lattice\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - generator_func: []\n", - " - generator_name: ['material_lattice']\n", - " - id: [8]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE GENERATOR DETAILS\n", - "############################################################\n", - "Total edge generators: 10\n", - "------------------------------------------------------------\n", - "• Generator: element_oxiState_canOccur\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - oxiState_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\oxidation_state\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_chemenv_containsSite\n", - "Generator Args:\n", - " - chemenv_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_crystalSystem_has\n", - "Generator Args:\n", - " - crystal_system_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: element_element_neighborsByGroupPeriod\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_element_has\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_lattice_has\n", - "Generator Args:\n", - " - lattice_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_lattice\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_spg_has\n", - "Generator Args:\n", - " - material_store: 
..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - spg_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: element_chemenv_canOccur\n", - "Generator Args:\n", - " - chemenv_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: spg_crystalSystem_isApart\n", - "Generator Args:\n", - " - crystal_system_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - " - spg_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: element_element_bonds\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "from matgraphdb.materials.edges import (\n", - " material_element_has,\n", - " material_lattice_has,\n", - " material_spg_has,\n", - " element_element_neighborsByGroupPeriod,\n", - " element_element_bonds,\n", - " element_oxiState_canOccur,\n", - " material_chemenv_containsSite,\n", - " material_crystalSystem_has,\n", - " element_chemenv_canOccur,\n", - " spg_crystalSystem_isApart,\n", - ")\n", - "\n", - "\n", - "\n", - "# List of edge generator configurations.\n", - "edge_generators = [\n", - " {\n", - " \"generator_func\": element_element_neighborsByGroupPeriod,\n", - " \"generator_args\": {\"element_store\": mdb.node_stores[\"element\"]},\n", - " },\n", - " {\n", - " \"generator_func\": element_oxiState_canOccur,\n", - " \"generator_args\": {\n", - " \"element_store\": mdb.node_stores[\"element\"],\n", - " \"oxiState_store\": mdb.node_stores[\"oxidation_state\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": material_chemenv_containsSite,\n", - " \"generator_args\": {\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " \"chemenv_store\": mdb.node_stores[\"chemenv\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": material_crystalSystem_has,\n", - " \"generator_args\": {\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " \"crystal_system_store\": mdb.node_stores[\"crystal_system\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": material_element_has,\n", - " \"generator_args\": {\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " \"element_store\": mdb.node_stores[\"element\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": material_lattice_has,\n", - " \"generator_args\": {\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " \"lattice_store\": mdb.node_stores[\"material_lattice\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": material_spg_has,\n", - " \"generator_args\": {\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " \"spg_store\": mdb.node_stores[\"space_group\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": element_chemenv_canOccur,\n", - " \"generator_args\": {\n", - " \"element_store\": 
mdb.node_stores[\"element\"],\n", - " \"chemenv_store\": mdb.node_stores[\"chemenv\"],\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": spg_crystalSystem_isApart,\n", - " \"generator_args\": {\n", - " \"spg_store\": mdb.node_stores[\"space_group\"],\n", - " \"crystal_system_store\": mdb.node_stores[\"crystal_system\"],\n", - " },\n", - " },\n", - " {\n", - " \"generator_func\": element_element_bonds,\n", - " \"generator_args\": {\n", - " \"element_store\": mdb.node_stores[\"element\"],\n", - " \"material_store\": mdb.node_stores[\"material\"],\n", - " },\n", - " },\n", - "]\n", - "\n", - "\n", - "# Add each edge generator to the database and run them immediately.\n", - "for generator in edge_generators:\n", - " generator_func = generator.get(\"generator_func\")\n", - " generator_args = generator.get(\"generator_args\", None)\n", - " print(f\"Adding edge generator: {generator_func.__name__}\")\n", - " mdb.add_edge_generator(generator_func=generator_func, generator_args=generator_args, run_immediately=True)\n", - "\n", - "print(\"Edge generators have been initialized.\")\n", - "print(mdb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Verifying the Database\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - "GRAPH DATABASE SUMMARY\n", - "============================================================\n", - "Name: MatGraphDB\n", - "Storage path: ..\\..\\data\\examples\\01\\MatGraphDB\n", - "└── Repository structure:\n", - " ├── nodes/ (..\\..\\data\\examples\\01\\MatGraphDB\\nodes)\n", - " ├── edges/ (..\\..\\data\\examples\\01\\MatGraphDB\\edges)\n", - " ├── edge_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\edge_generators)\n", - " ├── node_generators/ (..\\..\\data\\examples\\01\\MatGraphDB\\node_generators)\n", - " └── graph/ (..\\..\\data\\examples\\01\\MatGraphDB\\graph)\n", - "\n", - "############################################################\n", - "NODE DETAILS\n", - "############################################################\n", - "Total node types: 10\n", - "------------------------------------------------------------\n", - "• Node type: material\n", - " - Number of nodes: 80643\n", - " - Number of features: 136\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "------------------------------------------------------------\n", - "• Node type: element\n", - " - Number of nodes: 118\n", - " - Number of features: 99\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - "------------------------------------------------------------\n", - "• Node type: chemenv\n", - " - Number of nodes: 67\n", - " - Number of features: 15\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - "------------------------------------------------------------\n", - "• Node type: crystal_system\n", - " - Number of nodes: 7\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - "------------------------------------------------------------\n", - "• Node type: magnetic_state\n", - " - Number of nodes: 5\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\magnetic_state\n", - "------------------------------------------------------------\n", - "• 
Node type: oxidation_state\n", - " - Number of nodes: 19\n", - " - Number of features: 3\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\oxidation_state\n", - "------------------------------------------------------------\n", - "• Node type: space_group\n", - " - Number of nodes: 230\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "------------------------------------------------------------\n", - "• Node type: wyckoff\n", - " - Number of nodes: 1380\n", - " - Number of features: 2\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\wyckoff\n", - "------------------------------------------------------------\n", - "• Node type: material_site\n", - " - Number of nodes: 2545026\n", - " - Number of features: 15\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_site\n", - "------------------------------------------------------------\n", - "• Node type: material_lattice\n", - " - Number of nodes: 80643\n", - " - Number of features: 12\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_lattice\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE DETAILS\n", - "############################################################\n", - "Total edge types: 10\n", - "------------------------------------------------------------\n", - "• Edge type: element_element_neighborsByGroupPeriod\n", - " - Number of edges: 391\n", - " - Number of features: 14\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_element_neighborsByGroupPeriod\n", - "------------------------------------------------------------\n", - "• Edge type: element_oxiState_canOccur\n", - " - Number of edges: 162\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_oxiState_canOccur\n", - "------------------------------------------------------------\n", - "• Edge type: material_chemenv_containsSite\n", - " - Number of edges: 2542897\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_chemenv_containsSite\n", - "------------------------------------------------------------\n", - "• Edge type: material_crystalSystem_has\n", - " - Number of edges: 80643\n", - " - Number of features: 10\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_crystalSystem_has\n", - "------------------------------------------------------------\n", - "• Edge type: material_element_has\n", - " - Number of edges: 270902\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_element_has\n", - "------------------------------------------------------------\n", - "• Edge type: material_lattice_has\n", - " - Number of edges: 80643\n", - " - Number of features: 8\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_lattice_has\n", - "------------------------------------------------------------\n", - "• Edge type: material_spg_has\n", - " - Number of edges: 80643\n", - " - Number of features: 10\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\material_spg_has\n", - "------------------------------------------------------------\n", - "• Edge type: element_chemenv_canOccur\n", - " - Number of edges: 270474\n", - " - Number of features: 7\n", - " - db_path: 
..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_chemenv_canOccur\n", - "------------------------------------------------------------\n", - "• Edge type: spg_crystalSystem_isApart\n", - " - Number of edges: 230\n", - " - Number of features: 7\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\spg_crystalSystem_isApart\n", - "------------------------------------------------------------\n", - "• Edge type: element_element_bonds\n", - " - Number of edges: 3069943\n", - " - Number of features: 7\n", - " - db_path: ..\\..\\data\\examples\\01\\MatGraphDB\\edges\\element_element_bonds\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "NODE GENERATOR DETAILS\n", - "############################################################\n", - "Total node generators: 9\n", - "------------------------------------------------------------\n", - "• Generator: element\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_kwargs.base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\imputed_periodic_table_values.parquet']\n", - " - generator_name: ['element']\n", - " - id: [0]\n", - "Generator Kwargs:\n", - " - base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\imputed_periodic_table_values.parquet']\n", - "------------------------------------------------------------\n", - "• Generator: chemenv\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_kwargs.base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\coordination_geometries.parquet']\n", - " - generator_name: ['chemenv']\n", - " - id: [1]\n", - "Generator Kwargs:\n", - " - base_file: ['C:\\\\Users\\\\lllang\\\\Desktop\\\\Current_Projects\\\\MatGraphDB\\\\matgraphdb\\\\utils\\\\chem_utils\\\\resources\\\\coordination_geometries.parquet']\n", - "------------------------------------------------------------\n", - "• Generator: crystal_system\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['crystal_system']\n", - " - id: [2]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: magnetic_state\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['magnetic_state']\n", - " - id: [3]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: oxidation_state\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['oxidation_state']\n", - " - id: [4]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: space_group\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['space_group']\n", - " - id: [5]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: wyckoff\n", - "Generator Args:\n", - " - generator_func: []\n", - " - generator_name: ['wyckoff']\n", - " - id: [6]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_site\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - generator_func: []\n", - " - generator_name: 
['material_site']\n", - " - id: [7]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_lattice\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - generator_func: []\n", - " - generator_name: ['material_lattice']\n", - " - id: [8]\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "\n", - "############################################################\n", - "EDGE GENERATOR DETAILS\n", - "############################################################\n", - "Total edge generators: 10\n", - "------------------------------------------------------------\n", - "• Generator: element_oxiState_canOccur\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - oxiState_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\oxidation_state\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_chemenv_containsSite\n", - "Generator Args:\n", - " - chemenv_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_crystalSystem_has\n", - "Generator Args:\n", - " - crystal_system_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: element_element_neighborsByGroupPeriod\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_element_has\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_lattice_has\n", - "Generator Args:\n", - " - lattice_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material_lattice\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: material_spg_has\n", - "Generator Args:\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - " - spg_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: element_chemenv_canOccur\n", - "Generator Args:\n", - " - chemenv_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\chemenv\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: spg_crystalSystem_isApart\n", - "Generator Args:\n", - " - crystal_system_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\crystal_system\n", - " - spg_store: 
..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\space_group\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "• Generator: element_element_bonds\n", - "Generator Args:\n", - " - element_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\element\n", - " - material_store: ..\\..\\data\\examples\\01\\MatGraphDB\\nodes\\material\n", - "Generator Kwargs:\n", - "------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "print(mdb)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "matgraphdb_dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.21" - }, - "nbsphinx": { - "execute": "never" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/notebooks/01 - Getting Started.ipynb b/examples/notebooks/01 - Getting Started.ipynb index 42468bb..bfae0c2 100644 --- a/examples/notebooks/01 - Getting Started.ipynb +++ b/examples/notebooks/01 - Getting Started.ipynb @@ -284,7 +284,7 @@ "metadata": {}, "outputs": [], "source": [ - "from matgraphdb.materials.nodes import (\n", + "from matgraphdb.core.nodes import (\n", " element, chemenv, crystal_system, magnetic_state, \n", " oxidation_state, space_group, wyckoff, material_site, material_lattice\n", ")\n", @@ -805,7 +805,7 @@ } ], "source": [ - "from matgraphdb.materials.edges import (\n", + "from matgraphdb.core.edges import (\n", " material_element_has,\n", " material_lattice_has,\n", " material_spg_has,\n", diff --git a/matgraphdb/__init__.py b/matgraphdb/__init__.py index 4d44a1d..efcf75e 100644 --- a/matgraphdb/__init__.py +++ b/matgraphdb/__init__.py @@ -1,11 +1,3 @@ from matgraphdb._version import __version__ -from matgraphdb.core import ( - EdgeStore, - GeneratorStore, - GraphDB, - NodeStore, - edge_generator, - node_generator, -) -from matgraphdb.materials import MaterialStore, MatGraphDB +from matgraphdb.core import MaterialStore, MatGraphDB from matgraphdb.utils.config import PKG_DIR, config diff --git a/matgraphdb/core/__init__.py b/matgraphdb/core/__init__.py index 66a11bd..cd96b05 100644 --- a/matgraphdb/core/__init__.py +++ b/matgraphdb/core/__init__.py @@ -1,4 +1,2 @@ -from matgraphdb.core.edges import EdgeStore, edge_generator -from matgraphdb.core.generator_store import GeneratorStore -from matgraphdb.core.graph_db import GraphDB -from matgraphdb.core.nodes import NodeStore, node_generator +from matgraphdb.core.matgraphdb import MatGraphDB +from matgraphdb.core.nodes import MaterialStore diff --git a/matgraphdb/core/datasets/__init__.py b/matgraphdb/core/datasets/__init__.py new file mode 100644 index 0000000..aa0e707 --- /dev/null +++ b/matgraphdb/core/datasets/__init__.py @@ -0,0 +1 @@ +from matgraphdb.core.datasets.mp_near_hull import MPNearHull diff --git a/matgraphdb/materials/datasets/mp_near_hull.py b/matgraphdb/core/datasets/mp_near_hull.py similarity index 97% rename from matgraphdb/materials/datasets/mp_near_hull.py rename to matgraphdb/core/datasets/mp_near_hull.py index 9ce909e..c857231 100644 --- a/matgraphdb/materials/datasets/mp_near_hull.py +++ b/matgraphdb/core/datasets/mp_near_hull.py @@ -3,9 +3,9 @@ from huggingface_hub import snapshot_download -from matgraphdb.materials import MatGraphDB -from matgraphdb.materials.edges import * -from 
matgraphdb.materials.nodes import * +from matgraphdb.core import MatGraphDB +from matgraphdb.core.edges import * +from matgraphdb.core.nodes import * from matgraphdb.utils.config import config MPNEARHULL_PATH = os.path.join(config.data_dir, "datasets", "MPNearHull") diff --git a/matgraphdb/core/edges.py b/matgraphdb/core/edges.py index 477c3be..e447842 100644 --- a/matgraphdb/core/edges.py +++ b/matgraphdb/core/edges.py @@ -1,441 +1,715 @@ import logging -import os -from functools import wraps -from typing import Dict, List, Union +import shutil +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc -from parquetdb import ParquetDB -from parquetdb.core.parquetdb import LoadConfig, NormalizeConfig +from parquetdb import EdgeStore, NodeStore, ParquetDB, edge_generator +from parquetdb.utils import pyarrow_utils -from matgraphdb.core.utils import get_dataframe_column_names +from matgraphdb.utils.chem_utils.periodic import get_group_period_edge_index logger = logging.getLogger(__name__) -REQUIRED_EDGE_COLUMNS_FIELDS = set( - ["source_id", "source_type", "target_id", "target_type", "edge_type"] -) - - -def validate_edge_dataframe(df): - column_names = get_dataframe_column_names(df) - fields = set(column_names) - missing_fields = REQUIRED_EDGE_COLUMNS_FIELDS - fields - if missing_fields: - raise ValueError( - f"Edge dataframe is missing required fields: {missing_fields}. Edge dataframe must contain the following columns: {REQUIRED_EDGE_COLUMNS_FIELDS}" - ) - return df - - -def edge_generator(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Perform pre-execution checks - logger.debug(f"Executing {func.__name__} with args: {args}, kwargs: {kwargs}") - df = func(*args, **kwargs) - validate_edge_dataframe(df) - return df - - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - return wrapper - - -class EdgeStore(ParquetDB): - """ - A wrapper around ParquetDB specifically for storing edge features - of a given edge type. - """ - - required_fields = REQUIRED_EDGE_COLUMNS_FIELDS - edge_metadata_keys = ["class", "class_module"] - - def __init__(self, storage_path: str, setup_kwargs: dict = None): - """ - Parameters - ---------- - storage_path : str - The path where ParquetDB files for this edge type are stored. 
- """ - - super().__init__( - db_path=storage_path, - initial_fields=[ - pa.field("source_id", pa.int64()), - pa.field("source_type", pa.string()), - pa.field("target_id", pa.int64()), - pa.field("target_type", pa.string()), - pa.field("edge_type", pa.string()), - ], - ) - - self._initialize_metadata() - self._initialize_field_metadata() - - logger.debug(f"Initialized EdgeStore at {storage_path}") - if self.is_empty(): - if setup_kwargs is None: - setup_kwargs = {} - self._setup(**setup_kwargs) - - def __repr__(self): - return self.summary(show_column_names=False) - - @property - def storage_path(self): - return self._db_path - - @storage_path.setter - def storage_path(self, value): - self._db_path = value - self.edge_type = os.path.basename(value) - - @property - def edge_type(self): - return os.path.basename(self.storage_path) - - @edge_type.setter - def edge_type(self, value): - self._edge_type = value - - @property - def n_edges(self): - return self.read_edges(columns=["id"]).num_rows - - @property - def n_features(self): - return len(self.get_schema().names) - - @property - def columns(self): - return self.get_schema().names - - def summary(self, show_column_names: bool = False): - fields_metadata = self.get_field_metadata() - metadata = self.get_metadata() - # Header section - tmp_str = f"{'=' * 60}\n" - tmp_str += f"EDGE STORE SUMMARY\n" - tmp_str += f"{'=' * 60}\n" - tmp_str += f"Edge type: {self.edge_type}\n" - tmp_str += f"• Number of edges: {self.n_edges}\n" - tmp_str += f"• Number of features: {self.n_features}\n" - tmp_str += f"Storage path: {os.path.relpath(self.storage_path)}\n\n" - - # Metadata section - tmp_str += f"\n{'#' * 60}\n" - tmp_str += f"METADATA\n" - tmp_str += f"{'#' * 60}\n" - for key, value in metadata.items(): - tmp_str += f"• {key}: {value}\n" - - # Node details - tmp_str += f"\n{'#' * 60}\n" - tmp_str += f"EDGE DETAILS\n" - tmp_str += f"{'#' * 60}\n" - if show_column_names: - tmp_str += f"• Columns:\n" - for col in self.columns: - tmp_str += f" - {col}\n" - - if fields_metadata[col]: - tmp_str += f" - Field metadata\n" - for key, value in fields_metadata[col].items(): - tmp_str += f" - {key}: {value}\n" - - return tmp_str - - def _setup(self, **kwargs): - data = self.setup(**kwargs) - if data is not None: - self.create_edges(data=data) - self.set_metadata(kwargs) - - def setup(self, **kwargs): - return None - - def _initialize_metadata(self, **kwargs): - metadata = self.get_metadata() - update_metadata = False - for key in self.edge_metadata_keys: - if key not in metadata: - update_metadata = update_metadata or key not in metadata - - if update_metadata: - self.set_metadata( - { - "class": f"{self.__class__.__name__}", - "class_module": f"{self.__class__.__module__}", - } - ) - - def _initialize_field_metadata(self, **kwargs): - pass - - def create_edges( - self, - data: Union[List[dict], dict, pd.DataFrame], - schema: pa.Schema = None, - metadata: dict = None, - fields_metadata: dict = None, - treat_fields_as_ragged: List[str] = None, - convert_to_fixed_shape: bool = True, - normalize_dataset: bool = False, - normalize_config: dict = NormalizeConfig(), - ): - """ - Adds new data to the database. - - Parameters - ---------- - data : dict, list of dict, or pandas.DataFrame - The data to be added to the database. - schema : pyarrow.Schema, optional - The schema for the incoming data. - metadata : dict, optional - Metadata to be attached to the table. - fields_metadata : dict, optional - A dictionary containing the metadata to be set for the fields. 
- normalize_dataset : bool, optional - If True, the dataset will be normalized after the data is added (default is True). - treat_fields_as_ragged : list of str, optional - A list of fields to treat as ragged arrays. - convert_to_fixed_shape : bool, optional - If True, the ragged arrays will be converted to fixed shape arrays. - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - Examples - -------- - >>> db.create_nodes(data=my_data, schema=my_schema, metadata={'source': 'api'}, normalize_dataset=True) - """ - create_kwargs = dict( - data=data, - schema=schema, - metadata=metadata, - fields_metadata=fields_metadata, - treat_fields_as_ragged=treat_fields_as_ragged, - convert_to_fixed_shape=convert_to_fixed_shape, - normalize_dataset=normalize_dataset, - normalize_config=normalize_config, - ) - self.create(**create_kwargs) - - def create(self, **kwargs): - logger.debug(f"Creating edges") - - if not self.validate_edges(kwargs["data"]): - logger.error("Edge data validation failed - missing required fields") - raise ValueError( - "Edge data is missing required fields. Must include: " - + ", ".join(EdgeStore.required_fields) - ) - - super().create(**kwargs) - - logger.info(f"Successfully created edges") - - def read_edges( - self, - ids: List[int] = None, - columns: List[str] = None, - filters: List[pc.Expression] = None, - load_format: str = "table", - batch_size: int = None, - include_cols: bool = True, - rebuild_nested_struct: bool = False, - rebuild_nested_from_scratch: bool = False, - load_config: LoadConfig = LoadConfig(), - normalize_config: NormalizeConfig = NormalizeConfig(), - ): - """ - Reads data from the database. - - Parameters - ---------- - - ids : list of int, optional - A list of IDs to read. If None, all data is read (default is None). - columns : list of str, optional - The columns to include in the output. If None, all columns are included (default is None). - filters : list of pyarrow.compute.Expression, optional - Filters to apply to the data (default is None). - load_format : str, optional - The format of the returned data: 'table' or 'batches' (default is 'table'). - batch_size : int, optional - The batch size to use for loading data in batches. If None, data is loaded as a whole (default is None). - include_cols : bool, optional - If True, includes only the specified columns. If False, excludes the specified columns (default is True). - rebuild_nested_struct : bool, optional - If True, rebuilds the nested structure (default is False). - rebuild_nested_from_scratch : bool, optional - If True, rebuilds the nested structure from scratch (default is False). - load_config : LoadConfig, optional - Configuration for loading data, optimizing performance by managing memory usage. - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - - Returns - ------- - pa.Table, generator, or dataset - The data read from the database. The output can be in table format or as a batch generator. 
- - Examples - -------- - >>> data = db.read_edges(ids=[1, 2, 3], columns=['name', 'age'], filters=[pc.field('age') > 18]) - """ - logger.debug(f"Reading edges with ids: {ids}, columns: {columns}") - - read_kwargs = dict( - ids=ids, - columns=columns, - filters=filters, - load_format=load_format, - batch_size=batch_size, - include_cols=include_cols, - rebuild_nested_struct=rebuild_nested_struct, - rebuild_nested_from_scratch=rebuild_nested_from_scratch, - load_config=load_config, - normalize_config=normalize_config, - ) - return self.read(**read_kwargs) - - def update(self, **kwargs): - logger.debug(f"Updating edges") - # if not self.validate_edges(kwargs["data"]): - # logger.error("Edge data validation failed - missing required fields") - # raise ValueError( - # "Edge data is missing required fields. Must include: " - # + ", ".join(EdgeStore.required_fields) - # ) - - super().update(**kwargs) - logger.info("Successfully updated edges") - - def update_edges( - self, - data: Union[List[dict], dict, pd.DataFrame], - schema: pa.Schema = None, - metadata: dict = None, - fields_metadata: dict = None, - update_keys: Union[List[str], str] = "id", - treat_fields_as_ragged=None, - convert_to_fixed_shape: bool = True, - normalize_config: NormalizeConfig = NormalizeConfig(), - ): - """ - Updates existing records in the database. - - Parameters - ---------- - data : dict, list of dicts, or pandas.DataFrame - The data to be updated in the database. Each record must contain an 'id' key - corresponding to the record to be updated. - schema : pyarrow.Schema, optional - The schema for the data being added. If not provided, it will be inferred. - metadata : dict, optional - Additional metadata to store alongside the data. - fields_metadata : dict, optional - A dictionary containing the metadata to be set for the fields. - update_keys : list of str or str, optional - The keys to use for updating the data. If a list, the data must contain a value for each key. - treat_fields_as_ragged : list of str, optional - A list of fields to treat as ragged arrays. - convert_to_fixed_shape : bool, optional - If True, the ragged arrays will be converted to fixed shape arrays. - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - - Examples - -------- - >>> db.update(data=[{'id': 1, 'name': 'John', 'age': 30}, {'id': 2, 'name': 'Jane', 'age': 25}]) - """ - logger.debug(f"Updating edges") - - # if not self.validate_edges(data): - # logger.error("Edge data validation failed - missing required fields") - # raise ValueError( - # "Edge data is missing required fields. Must include: " - # + ", ".join(EdgeStore.required_fields) - # ) - - update_kwargs = dict( - data=data, - schema=schema, - metadata=metadata, - update_keys=update_keys, - treat_fields_as_ragged=treat_fields_as_ragged, - convert_to_fixed_shape=convert_to_fixed_shape, - normalize_config=normalize_config, - ) - self.update(**update_kwargs) - logger.info("Successfully updated edges") - - def delete_edges( - self, - ids: List[int] = None, - columns: List[str] = None, - normalize_config: NormalizeConfig = NormalizeConfig(), - ): - """ - Deletes records from the database. - - Parameters - ---------- - ids : list of int - A list of record IDs to delete from the database. - columns : list of str, optional - A list of column names to delete from the dataset. If not provided, it will be inferred from the existing data (default: None). 
- normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - - Returns - ------- - None - - Examples - -------- - >>> db.delete(ids=[1, 2, 3]) - """ - logger.debug(f"Deleting edges with ids: {ids}, columns: {columns}") - self.delete(ids=ids, columns=columns, normalize_config=normalize_config) - logger.info(f"Successfully deleted edges") - - def normalize_edges(self, normalize_config: NormalizeConfig = NormalizeConfig()): - """ - Triggers file restructuring and compaction to optimize edge storage. - """ - logger.info("Starting edge store normalization") - self.normalize(normalize_config=normalize_config) - logger.info("Completed edge store normalization") - - def validate_edges( - self, data: Union[List[dict], dict, pd.DataFrame, pa.Table, pa.RecordBatch] - ): - """ - Validates the edges to ensure they contain the required fields. - """ - logger.debug("Validating edge data") - - data = ParquetDB.construct_table(data) - - if isinstance(data, pa.Table) or isinstance(data, pa.RecordBatch): - fields = data.schema.names - else: - logger.error(f"Invalid data type for edge validation: {type(data)}") - raise ValueError("Invalid data type for edge validation") - - is_valid = True - missing_fields = [] - for required_field in EdgeStore.required_fields: - if required_field not in fields: - is_valid = False - missing_fields.append(required_field) - - if not is_valid: - logger.warning(f"Edge validation failed. Missing fields: {missing_fields}") - else: - logger.debug("Edge validation successful") - - return is_valid + +@edge_generator +def element_element_neighborsByGroupPeriod(element_store): + + try: + connection_name = "neighborsByGroupPeriod" + table = element_store.read_nodes( + columns=["atomic_number", "extended_group", "period", "symbol"] + ) + element_df = table.to_pandas(split_blocks=True, self_destruct=True) + + # Getting group-period edge index + edge_index = get_group_period_edge_index(element_df) + + # Creating the relationships dataframe + df = pd.DataFrame(edge_index, columns=["source_id", "target_id"]) + + # Dropping rows with NaN values and casting to int64 + df = df.dropna().astype(np.int64) + + # Add source and target type columns + df["source_type"] = element_store.node_type + df["target_type"] = element_store.node_type + df["edge_type"] = connection_name + df["weight"] = 1.0 + + table = ParquetDB.construct_table(df) + + reduced_table = element_store.read( + columns=["symbol", "id", "extended_group", "period"] + ) + reduced_source_table = reduced_table.rename_columns( + { + "symbol": "source_name", + "extended_group": "source_extended_group", + "period": "source_period", + } + ) + reduced_target_table = reduced_table.rename_columns( + { + "symbol": "target_name", + "extended_group": "target_extended_group", + "period": "target_period", + } + ) + + table = pyarrow_utils.join_tables( + table, + reduced_source_table, + left_keys=["source_id"], + right_keys=["id"], + join_type="left outer", + ) + + table = pyarrow_utils.join_tables( + table, + reduced_target_table, + left_keys=["target_id"], + right_keys=["id"], + join_type="left outer", + ) + + names = pc.binary_join_element_wise( + pc.cast(table["source_name"], pa.string()), + pc.cast(table["target_name"], pa.string()), + f"_{connection_name}_", + ) + + table = table.append_column("name", names) + + logger.debug( + f"Created element-group-period relationships. 
Shape: {table.shape}" + ) + except Exception as e: + logger.exception(f"Error creating element-group-period relationships: {e}") + raise e + + return table + + +@edge_generator +def element_element_bonds(element_store, material_store): + try: + connection_name = "canBondTo" + material_table = material_store.read_nodes( + columns=[ + "id", + "core.material_id", + "core.species", + "chemenv.coordination_environments_multi_weight", + "bonding.geometric_consistent.bond_connections", + ] + ) + + element_table = element_store.read_nodes(columns=["id", "symbol"]) + + element_table = element_table.rename_columns({"symbol": "name"}) + element_table = element_table.append_column( + "source_type", pa.array([element_store.node_type] * element_table.num_rows) + ) + + material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) + element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) + + element_target_id_map = { + row["name"]: row["id"] for _, row in element_df.iterrows() + } + + table_dict = { + "source_id": [], + "source_type": [], + "target_id": [], + "target_type": [], + "edge_type": [], + "name": [], + } + + for _, row in material_df.iterrows(): + bond_connections = row["bonding.geometric_consistent.bond_connections"] + + if bond_connections is None: + continue + + elements = row["core.species"] + element_graph = {} + for i, site_connections in enumerate(bond_connections): + site_element_name = elements[i] + for i_neighbor_element in site_connections: + i_neighbor_element = int(i_neighbor_element) + neighbor_element_name = elements[i_neighbor_element] + + source_id = element_target_id_map[site_element_name] + target_id = element_target_id_map[neighbor_element_name] + + table_dict["source_id"].append(source_id) + table_dict["source_type"].append(element_store.node_type) + table_dict["target_id"].append(target_id) + table_dict["target_type"].append(element_store.node_type) + table_dict["edge_type"].append(connection_name) + + name = ( + f"{site_element_name}_{connection_name}_{neighbor_element_name}" + ) + table_dict["name"].append(name) + + edge_table = ParquetDB.construct_table(table_dict) + + logger.debug( + f"Created element-chemenv-canOccur relationships. 
Shape: {edge_table.shape}" + ) + + except Exception as e: + logger.exception(f"Error creating element-chemenv-canOccur relationships: {e}") + raise e + + return edge_table + + +@edge_generator +def element_oxiState_canOccur(element_store, oxiState_store): + try: + connection_name = "canOccur" + + element_table = element_store.read_nodes( + columns=["id", "experimental_oxidation_states", "symbol"] + ) + oxiState_table = oxiState_store.read_nodes( + columns=["id", "oxidation_state", "value"] + ) + + # element_table=element_table.rename_columns({'id':'source_id'}) + element_table = element_table.append_column( + "source_type", pa.array([element_store.node_type] * element_table.num_rows) + ) + + # oxiState_table=oxiState_table.rename_columns({'id':'target_id'}) + oxiState_table = oxiState_table.append_column( + "target_type", + pa.array([oxiState_store.node_type] * oxiState_table.num_rows), + ) + + element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) + oxiState_df = oxiState_table.to_pandas(split_blocks=True, self_destruct=True) + table_dict = { + "source_id": [], + "source_type": [], + "target_id": [], + "target_type": [], + "edge_type": [], + "name": [], + "weight": [], + } + + oxiState_id_map = {} + id_oxidationState_map = {} + for i, oxiState_row in oxiState_df.iterrows(): + oxiState_id_map[oxiState_row["value"]] = oxiState_row["id"] + id_oxidationState_map[oxiState_row["id"]] = oxiState_row["oxidation_state"] + + for i, element_row in element_df.iterrows(): + exp_oxidation_states = element_row["experimental_oxidation_states"] + source_id = element_row["id"] + source_type = element_store.node_type + symbol = element_row["symbol"] + for exp_oxidation_state in exp_oxidation_states: + target_id = oxiState_id_map[exp_oxidation_state] + target_type = oxiState_store.node_type + oxi_state_name = id_oxidationState_map[target_id] + + table_dict["source_id"].append(source_id) + table_dict["source_type"].append(source_type) + table_dict["target_id"].append(target_id) + table_dict["target_type"].append(target_type) + table_dict["edge_type"].append(connection_name) + table_dict["weight"].append(1.0) + table_dict["name"].append( + f"{symbol}_{connection_name}_{oxi_state_name}" + ) + + edge_table = ParquetDB.construct_table(table_dict) + + logger.debug( + f"Created element-oxiState-canOccur relationships. 
Shape: {edge_table.shape}" + ) + except Exception as e: + logger.exception(f"Error creating element-oxiState-canOccur relationships: {e}") + raise e + + return edge_table + + +@edge_generator +def material_chemenv_containsSite(material_store, chemenv_store): + try: + connection_name = "containsSite" + + material_table = material_store.read_nodes( + columns=[ + "id", + "core.material_id", + "chemenv.coordination_environments_multi_weight", + ] + ) + chemenv_table = chemenv_store.read_nodes(columns=["id", "mp_symbol"]) + + material_table = material_table.rename_columns( + {"id": "source_id", "core.material_id": "material_name"} + ) + material_table = material_table.append_column( + "source_type", + pa.array([material_store.node_type] * material_table.num_rows), + ) + + chemenv_table = chemenv_table.rename_columns( + {"id": "target_id", "mp_symbol": "chemenv_name"} + ) + chemenv_table = chemenv_table.append_column( + "target_type", pa.array([chemenv_store.node_type] * chemenv_table.num_rows) + ) + + material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) + chemenv_df = chemenv_table.to_pandas(split_blocks=True, self_destruct=True) + chemenv_target_id_map = { + row["chemenv_name"]: row["target_id"] for _, row in chemenv_df.iterrows() + } + + table_dict = { + "source_id": [], + "source_type": [], + "target_id": [], + "target_type": [], + "edge_type": [], + "name": [], + "weight": [], + } + + for _, row in material_df.iterrows(): + coord_envs = row["chemenv.coordination_environments_multi_weight"] + if coord_envs is None: + continue + + source_id = row["source_id"] + material_name = row["material_name"] + + for coord_env in coord_envs: + try: + chemenv_name = coord_env[0]["ce_symbol"] + target_id = chemenv_target_id_map[chemenv_name] + except: + continue + + table_dict["source_id"].append(source_id) + table_dict["source_type"].append(material_store.node_type) + table_dict["target_id"].append(target_id) + table_dict["target_type"].append(chemenv_store.node_type) + table_dict["edge_type"].append(connection_name) + + name = f"{material_name}_{connection_name}_{chemenv_name}" + table_dict["name"].append(name) + table_dict["weight"].append(1.0) + + edge_table = ParquetDB.construct_table(table_dict) + + logger.debug( + f"Created material-chemenv-containsSite relationships. 
Shape: {edge_table.shape}" + ) + except Exception as e: + logger.exception( + f"Error creating material-chemenv-containsSite relationships: {e}" + ) + raise e + + return edge_table + + +@edge_generator +def material_crystalSystem_has(material_store, crystal_system_store): + try: + connection_name = "has" + + material_table = material_store.read_nodes( + columns=["id", "core.material_id", "symmetry.crystal_system"] + ) + crystal_system_table = crystal_system_store.read_nodes( + columns=["id", "crystal_system"] + ) + + material_table = material_table.rename_columns( + {"id": "source_id", "symmetry.crystal_system": "crystal_system"} + ) + material_table = material_table.append_column( + "source_type", + pa.array([material_store.node_type] * material_table.num_rows), + ) + + crystal_system_table = crystal_system_table.rename_columns({"id": "target_id"}) + crystal_system_table = crystal_system_table.append_column( + "target_type", + pa.array([crystal_system_store.node_type] * crystal_system_table.num_rows), + ) + + edge_table = pyarrow_utils.join_tables( + material_table, + crystal_system_table, + left_keys=["crystal_system"], + right_keys=["crystal_system"], + join_type="left outer", + ) + edge_table = edge_table.append_column( + "edge_type", pa.array([connection_name] * edge_table.num_rows) + ) + edge_table = edge_table.append_column( + "weight", pa.array([1.0] * edge_table.num_rows) + ) + + names = pc.binary_join_element_wise( + pc.cast(edge_table["core.material_id"], pa.string()), + pc.cast(edge_table["crystal_system"], pa.string()), + f"_{connection_name}_", + ) + + edge_table = edge_table.append_column("name", names) + + logger.debug( + f"Created material-crystalSystem-has relationships. Shape: {edge_table.shape}" + ) + except Exception as e: + logger.exception( + f"Error creating material-crystalSystem-has relationships: {e}" + ) + raise e + + return edge_table + + +@edge_generator +def material_element_has(material_store, element_store): + try: + connection_name = "has" + + material_table = material_store.read_nodes( + columns=["id", "core.material_id", "core.elements"] + ) + element_table = element_store.read_nodes(columns=["id", "symbol"]) + + material_table = material_table.rename_columns( + {"id": "source_id", "core.material_id": "material_name"} + ) + material_table = material_table.append_column( + "source_type", pa.array(["material"] * material_table.num_rows) + ) + + element_table = element_table.rename_columns({"id": "target_id"}) + element_table = element_table.append_column( + "target_type", pa.array(["elements"] * element_table.num_rows) + ) + + material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) + element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) + element_target_id_map = { + row["symbol"]: row["target_id"] for _, row in element_df.iterrows() + } + + table_dict = { + "source_id": [], + "source_type": [], + "target_id": [], + "target_type": [], + "edge_type": [], + "name": [], + "weight": [], + } + + for _, row in material_df.iterrows(): + elements = row["core.elements"] + source_id = row["source_id"] + material_name = row["material_name"] + if elements is None: + continue + + # Append the material name for each element in the species list + for element in elements: + + target_id = element_target_id_map[element] + table_dict["source_id"].append(source_id) + table_dict["source_type"].append(material_store.node_type) + table_dict["target_id"].append(target_id) + table_dict["target_type"].append(element_store.node_type) + 
table_dict["edge_type"].append(connection_name) + + name = f"{material_name}_{connection_name}_{element}" + table_dict["name"].append(name) + table_dict["weight"].append(1.0) + + edge_table = ParquetDB.construct_table(table_dict) + + logger.debug( + f"Created material-element-has relationships. Shape: {edge_table.shape}" + ) + except Exception as e: + logger.exception(f"Error creating material-element-has relationships: {e}") + raise e + + return edge_table + + +@edge_generator +def material_lattice_has(material_store, lattice_store): + try: + connection_name = "has" + + material_table = material_store.read_nodes(columns=["id", "core.material_id"]) + lattice_table = lattice_store.read_nodes(columns=["material_node_id"]) + + material_table = material_table.rename_columns( + {"id": "source_id", "core.material_id": "material_id"} + ) + material_table = material_table.append_column( + "source_type", + pa.array([material_store.node_type] * material_table.num_rows), + ) + + lattice_table = lattice_table.append_column( + "target_id", lattice_table["material_node_id"].combine_chunks() + ) + lattice_table = lattice_table.append_column( + "target_type", pa.array([lattice_store.node_type] * lattice_table.num_rows) + ) + + edge_table = pyarrow_utils.join_tables( + material_table, + lattice_table, + left_keys=["source_id"], + right_keys=["material_node_id"], + join_type="left outer", + ) + edge_table = edge_table.append_column( + "edge_type", pa.array([connection_name] * edge_table.num_rows) + ) + edge_table = edge_table.append_column( + "weight", pa.array([1.0] * edge_table.num_rows) + ) + + logger.debug( + f"Created material-lattice-has relationships. Shape: {edge_table.shape}" + ) + except Exception as e: + logger.exception(f"Error creating material-lattice-has relationships: {e}") + raise e + + return edge_table + + +@edge_generator +def material_spg_has(material_store, spg_store): + try: + connection_name = "has" + + material_table = material_store.read_nodes( + columns=["id", "core.material_id", "symmetry.number"] + ) + spg_table = spg_store.read_nodes(columns=["id", "spg"]) + + material_table = material_table.rename_columns( + {"id": "source_id", "symmetry.number": "spg"} + ) + material_table = material_table.append_column( + "source_type", + pa.array([material_store.node_type] * material_table.num_rows), + ) + + spg_table = spg_table.rename_columns({"id": "target_id"}) + spg_table = spg_table.append_column( + "target_type", pa.array([spg_store.node_type] * spg_table.num_rows) + ) + + edge_table = pyarrow_utils.join_tables( + material_table, + spg_table, + left_keys=["spg"], + right_keys=["spg"], + join_type="left outer", + ) + + edge_table = edge_table.append_column( + "edge_type", pa.array([connection_name] * edge_table.num_rows) + ) + + edge_table = edge_table.append_column( + "weight", pa.array([1.0] * edge_table.num_rows) + ) + + names = pc.binary_join_element_wise( + pc.cast(edge_table["core.material_id"], pa.string()), + pc.cast(edge_table["spg"], pa.string()), + f"_{connection_name}_SpaceGroup", + ) + + edge_table = edge_table.append_column("name", names) + + logger.debug( + f"Created material-spg-has relationships. 
Shape: {edge_table.shape}" + ) + except Exception as e: + logger.exception(f"Error creating material-spg-has relationships: {e}") + raise e + + return edge_table + + +@edge_generator +def element_chemenv_canOccur(element_store, chemenv_store, material_store): + try: + connection_name = "canOccur" + material_table = material_store.read_nodes( + columns=[ + "id", + "core.material_id", + "core.elements", + "chemenv.coordination_environments_multi_weight", + ] + ) + + chemenv_table = chemenv_store.read_nodes(columns=["id", "mp_symbol"]) + element_table = element_store.read_nodes(columns=["id", "symbol"]) + + chemenv_table = chemenv_table.rename_columns({"mp_symbol": "name"}) + chemenv_table = chemenv_table.append_column( + "target_type", pa.array([chemenv_store.node_type] * chemenv_table.num_rows) + ) + + element_table = element_table.rename_columns({"symbol": "name"}) + element_table = element_table.append_column( + "source_type", pa.array([element_store.node_type] * element_table.num_rows) + ) + + material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) + chemenv_df = chemenv_table.to_pandas(split_blocks=True, self_destruct=True) + element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) + + chemenv_target_id_map = { + row["name"]: row["id"] for _, row in chemenv_df.iterrows() + } + element_target_id_map = { + row["name"]: row["id"] for _, row in element_df.iterrows() + } + + table_dict = { + "source_id": [], + "source_type": [], + "target_id": [], + "target_type": [], + "edge_type": [], + "name": [], + } + + for _, row in material_df.iterrows(): + coord_envs = row["chemenv.coordination_environments_multi_weight"] + + if coord_envs is None: + continue + + elements = row["core.elements"] + + for i, coord_env in enumerate(coord_envs): + try: + chemenv_name = coord_env[0]["ce_symbol"] + element_name = elements[i] + + source_id = element_target_id_map[element_name] + target_id = chemenv_target_id_map[chemenv_name] + except: + continue + + table_dict["source_id"].append(source_id) + table_dict["source_type"].append(element_store.node_type) + table_dict["target_id"].append(target_id) + table_dict["target_type"].append(chemenv_store.node_type) + table_dict["edge_type"].append(connection_name) + + name = f"{element_name}_{connection_name}_{chemenv_name}" + table_dict["name"].append(name) + + edge_table = ParquetDB.construct_table(table_dict) + + logger.debug( + f"Created element-chemenv-canOccur relationships. 
Shape: {edge_table.shape}" + ) + + except Exception as e: + logger.exception(f"Error creating element-chemenv-canOccur relationships: {e}") + raise e + + return edge_table + + +@edge_generator +def spg_crystalSystem_isApart(spg_store, crystal_system_store): + try: + connection_name = "isApart" + + except Exception as e: + logger.exception(f"Error creating spg-crystalSystem-isApart relationships: {e}") + raise e + + spg_table = spg_store.read_nodes(columns=["id", "spg"]) + crystal_system_table = crystal_system_store.read_nodes( + columns=["id", "crystal_system"] + ) + + spg_df = spg_table.to_pandas(split_blocks=True, self_destruct=True) + crystal_system_df = crystal_system_table.to_pandas( + split_blocks=True, self_destruct=True + ) + + spg_target_id_map = {row["spg"]: row["id"] for _, row in spg_df.iterrows()} + crystal_system_target_id_map = { + row["crystal_system"]: row["id"] for _, row in crystal_system_df.iterrows() + } + + crys_spg_map = { + "Triclinic": np.arange(1, 3), + "Monoclinic": np.arange(3, 16), + "Orthorhombic": np.arange(16, 75), + "Tetragonal": np.arange(75, 143), + "Trigonal": np.arange(143, 168), + "Hexagonal": np.arange(168, 195), + "Cubic": np.arange(195, 231), + } + table_dict = { + "source_id": [], + "source_type": [], + "target_id": [], + "target_type": [], + "edge_type": [], + "name": [], + } + try: + for crystal_system, spg_range in crys_spg_map.items(): + for spg in spg_range: + source_id = spg_target_id_map[spg] + target_id = crystal_system_target_id_map[crystal_system] + + table_dict["source_id"].append(source_id) + table_dict["source_type"].append(spg_store.node_type) + table_dict["target_id"].append(target_id) + table_dict["target_type"].append(crystal_system_store.node_type) + table_dict["edge_type"].append(connection_name) + table_dict["name"].append(f"{crystal_system}_{connection_name}_{spg}") + + edge_table = ParquetDB.construct_table(table_dict) + + logger.debug( + f"Created spg-crystalSystem-isApart relationships. Shape: {edge_table.shape}" + ) + + except Exception as e: + logger.exception(f"Error creating element-chemenv-canOccur relationships: {e}") + raise e + + return edge_table diff --git a/matgraphdb/core/generator_store.py b/matgraphdb/core/generator_store.py deleted file mode 100644 index fd19b71..0000000 --- a/matgraphdb/core/generator_store.py +++ /dev/null @@ -1,339 +0,0 @@ -import inspect -import logging -import os -import sys -from typing import Callable, Dict, List, Union - -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc -from parquetdb import ParquetDB -from parquetdb.core import types -from parquetdb.core.parquetdb import NormalizeConfig -from parquetdb.utils import data_utils - -logger = logging.getLogger(__name__) - - -def validate_generator_inputs(args): - from matgraphdb.core.edges import EdgeStore - from matgraphdb.core.nodes import NodeStore - - ALLOWED_GENERATOR_INPUTS = (EdgeStore, NodeStore) - - for arg in args: - if not isinstance(arg, ALLOWED_GENERATOR_INPUTS): - raise ValueError( - f"Generator input must be a {ALLOWED_GENERATOR_INPUTS}, {arg} is {type(arg)}" - ) - - -class GeneratorStore(ParquetDB): - """ - A store for managing generator functions in a graph database. - This class handles serialization, storage, and loading of functions - that generate edges between nodes. 
- """ - - required_fields = ["generator_name", "generator_func"] - metadata_keys = ["class", "class_module"] - - def __init__(self, storage_path: str, initial_fields: List[pa.Field] = None): - """ - Initialize the EdgeGeneratorStore. - - Parameters - ---------- - storage_path : str - Path where the generator functions will be stored - - """ - if initial_fields is None: - initial_fields = [] - - initial_fields.extend( - [ - pa.field("generator_name", pa.string()), - pa.field("generator_func", types.PythonObjectArrowType()), - ] - ) - super().__init__(db_path=storage_path, initial_fields=initial_fields) - - self._initialize_metadata() - logger.debug(f"Initialized GeneratorStore at {storage_path}") - - def __repr__(self): - return self.summary(show_column_names=True) - - def _initialize_metadata(self): - """Initialize store metadata if not present.""" - metadata = self.get_metadata() - update_metadata = False - for key in self.metadata_keys: - if key not in metadata: - update_metadata = True - break - - if update_metadata: - self.set_metadata( - { - "class": f"{self.__class__.__name__}", - "class_module": f"{self.__class__.__module__}", - } - ) - - @property - def storage_path(self): - return self._db_path - - @property - def n_generators(self): - return self.read(columns=["generator_name"]).num_rows - - @property - def generator_names(self): - return ( - self.read(columns=["generator_name"]).to_pandas()["generator_name"].tolist() - ) - - def summary(self, show_column_names: bool = False): - fields_metadata = self.get_field_metadata() - metadata = self.get_metadata() - - # Header section - tmp_str = f"{'=' * 60}\n" - tmp_str += f"GENERATOR STORE SUMMARY\n" - tmp_str += f"{'=' * 60}\n" - tmp_str += f"• Number of generators: {self.n_generators}\n" - tmp_str += f"Storage path: {os.path.relpath(self.storage_path)}\n\n" - - # Metadata section - tmp_str += f"\n{'#' * 60}\n" - tmp_str += f"METADATA\n" - tmp_str += f"{'#' * 60}\n" - for key, value in metadata.items(): - tmp_str += f"• {key}: {value}\n" - - # Generator details - tmp_str += f"\n{'#' * 60}\n" - tmp_str += f"GENERATOR DETAILS\n" - tmp_str += f"{'#' * 60}\n" - if show_column_names: - tmp_str += f"• Columns:\n" - for col in self.get_schema().names: - tmp_str += f" - {col}\n" - - if fields_metadata[col]: - tmp_str += f" - Field metadata\n" - for key, value in fields_metadata[col].items(): - tmp_str += f" - {key}: {value}\n" - - # Show generator names - tmp_str += f"\n• Generator names:\n" - for name in self.generator_names: - tmp_str += f" - {name}\n" - - return tmp_str - - def store_generator( - self, - generator_func: Callable, - generator_name: str, - generator_args: Dict = None, - generator_kwargs: Dict = None, - create_kwargs: Dict = None, - ) -> None: - """ - Store an edge generator function. 
- - Parameters - ---------- - generator_func : Callable - The function that generates edges - generator_name : str - Name to identify the generator function - generator_args : Dict - Arguments to pass to the generator function - generator_kwargs : Dict - Keyword arguments to pass to the generator function - create_kwargs : Dict - Keyword arguments to pass to the create method - """ - if create_kwargs is None: - create_kwargs = {} - if generator_args is None: - generator_args = {} - if generator_kwargs is None: - generator_kwargs = get_function_kwargs(generator_func) - try: - df = self.read(columns=["generator_name"]).to_pandas() - - if generator_name in df["generator_name"].values: - logger.warning(f"Generator '{generator_name}' already exists") - return None - - # Serialize the function using dill - - # Create data record - extra_fields = {} - for key, value in generator_args.items(): - extra_fields[f"generator_args.{key}"] = value - - for key, value in generator_kwargs.items(): - extra_fields[f"generator_kwargs.{key}"] = value - - data = [ - { - "generator_name": generator_name, - "generator_func": generator_func, - **extra_fields, - } - ] - # Store the function data - self.create(data=data, **create_kwargs) - logger.info(f"Successfully stored generator '{generator_name}'") - - except Exception as e: - logger.error(f"Failed to store generator '{generator_name}': {str(e)}") - raise - - def load_generator_data(self, generator_name: str) -> pd.DataFrame: - filters = [pc.field("generator_name") == generator_name] - table = self.read(filters=filters) - - for column_name in table.column_names: - logger.debug(f"Loading generator data for column: {column_name}") - col_array = table[column_name].drop_null() - if len(col_array) == 0: - table = table.drop(column_name) - - if len(table) == 0: - raise ValueError(f"No generator found with name '{generator_name}'") - return table.to_pandas() - - def is_in(self, generator_name: str) -> bool: - filters = [pc.field("generator_name") == generator_name] - table = self.read(filters=filters) - return len(table) > 0 - - def load_generator(self, generator_name: str) -> Callable: - """ - Load an edge generator function by name. - - Parameters - ---------- - generator_name : str - Name of the generator function to load - - Returns - ------- - Callable - The loaded generator function - """ - try: - df = self.load_generator_data(generator_name) - generator_func = df["generator_func"].iloc[0] - return generator_func - - except Exception as e: - logger.error(f"Failed to load generator '{generator_name}': {str(e)}") - raise - - def list_generators(self) -> List[Dict]: - """ - List all stored edge generators. - - Returns - ------- - List[Dict] - List of dictionaries containing generator information - """ - try: - result = self.read(columns=["generator_name"]) - return result.to_pylist() - except Exception as e: - logger.error(f"Failed to list generators: {str(e)}") - raise - - def delete_generator(self, generator_name: str) -> None: - """ - Delete a generator by name. 
- - Parameters - ---------- - generator_name : str - Name of the generator to delete - """ - try: - filters = [pc.field("generator_name") == generator_name] - self.delete(filters=filters) - logger.info(f"Successfully deleted generator '{generator_name}'") - except Exception as e: - logger.error(f"Failed to delete generator '{generator_name}': {str(e)}") - raise - - def run_generator( - self, - generator_name: str, - generator_args: Dict = None, - generator_kwargs: Dict = None, - ) -> None: - """ - Run a generator function by name. - """ - - if generator_args is None: - generator_args = {} - if generator_kwargs is None: - generator_kwargs = {} - - df = self.load_generator_data(generator_name) - for column_name in df.columns: - value = df[column_name].iloc[0] - if "generator_args" in column_name: - arg_name = column_name.split(".")[-1] - - # Do not overwrite user-provided args - if arg_name not in generator_args: - - generator_args[arg_name] = value - elif "generator_kwargs" in column_name and value is not None: - kwarg_name = column_name.split(".")[-1] - - # Do not overwrite user-provided kwargs - if kwarg_name not in generator_kwargs: - - generator_kwargs[kwarg_name] = value - - generator_func = df["generator_func"].iloc[0] - - logger.debug(f"Generator func: {generator_func}") - logger.debug(f"Generator args: {generator_args}") - logger.debug(f"Generator kwargs: {generator_kwargs}") - - arg_names = get_function_arg_names(generator_func) - generator_args = [generator_args[k] for k in arg_names] - - logger.debug(f"Running {generator_func.__name__} with args: {generator_args}") - logger.debug( - f"Running {generator_func.__name__} with kwargs: {generator_kwargs}" - ) - return generator_func(*generator_args, **generator_kwargs) - - -def get_function_arg_names(func): - sig = inspect.signature(func) - return [ - name - for name, param in sig.parameters.items() - if param.default == inspect.Parameter.empty - ] - - -def get_function_kwargs(func): - sig = inspect.signature(func) - return { - name: param.default - for name, param in sig.parameters.items() - if param.default != inspect.Parameter.empty - } diff --git a/matgraphdb/core/graph_db.py b/matgraphdb/core/graph_db.py deleted file mode 100644 index 2fe0861..0000000 --- a/matgraphdb/core/graph_db.py +++ /dev/null @@ -1,944 +0,0 @@ -import importlib -import json -import logging -import os -import shutil -import time -from glob import glob -from typing import Callable, Dict, List, Union - -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc -from parquetdb import ParquetDB -from parquetdb.utils import pyarrow_utils -from pyarrow import parquet as pq - -from matgraphdb.core.edges import EdgeStore -from matgraphdb.core.generator_store import GeneratorStore -from matgraphdb.core.nodes import NodeStore - -logger = logging.getLogger(__name__) - - -class GraphDB: - """ - A manager for a graph storing multiple node types and edge types. - Each node type and edge type is backed by a separate ParquetDB instance - (wrapped by NodeStore or EdgeStore). - """ - - def __init__(self, storage_path: str, load_custom_stores: bool = True): - """ - Parameters - ---------- - storage_path : str - The root path for this graph, e.g. '/path/to/my_graph'. - Subdirectories 'nodes/' and 'edges/' will be used. 
- """ - logger.info(f"Initializing GraphDB at root path: {storage_path}") - self.storage_path = os.path.abspath(storage_path) - - self.nodes_path = os.path.join(self.storage_path, "nodes") - self.edges_path = os.path.join(self.storage_path, "edges") - self.edge_generators_path = os.path.join(self.storage_path, "edge_generators") - self.node_generators_path = os.path.join(self.storage_path, "node_generators") - self.graph_path = os.path.join(self.storage_path, "graph") - self.generator_dependency_json = os.path.join( - self.storage_path, "generator_dependency.json" - ) - - self.graph_name = os.path.basename(self.storage_path) - - # Create directories if they don't exist - os.makedirs(self.nodes_path, exist_ok=True) - os.makedirs(self.edges_path, exist_ok=True) - os.makedirs(self.edge_generators_path, exist_ok=True) - os.makedirs(self.graph_path, exist_ok=True) - - logger.debug(f"Node directory: {self.nodes_path}") - logger.debug(f"Edge directory: {self.edges_path}") - logger.debug(f"Graph directory: {self.graph_path}") - - # Initialize empty dictionaries for stores, load existing stores - self.node_stores = self._load_existing_node_stores(load_custom_stores) - self.edge_stores = self._load_existing_edge_stores(load_custom_stores) - - self.edge_generator_store = GeneratorStore( - storage_path=self.edge_generators_path - ) - self.node_generator_store = GeneratorStore( - storage_path=self.node_generators_path - ) - - # This is here to make sure the node and edges paths - # listed in the generator stores align where the GraphDB is, - # this allows user to easily move the directory and the generators will still work - self._load_generator_dependency_graph() - self.generator_consistency_check() - - def __repr__(self): - return self.summary(show_column_names=False) - - # def __getitem__(self, *args: QueryType) -> Any: - # # `data[*]` => Link to either `_global_store`, _node_store_dict` or - # # `_edge_store_dict`. - # # If neither is present, we create a new `Storage` object for the given - # # node/edge-type. - # key = self._to_canonical(*args) - - # out = self._global_store.get(key, None) - # if out is not None: - # return out - - # if isinstance(key, tuple): - # return self.get_edge_store(*key) - # else: - # return self.get_node_store(key) - - # def __setitem__(self, key: str, value: Any): - # if key in self.node_types: - # raise AttributeError(f"'{key}' is already present as a node type") - # elif key in self.edge_types: - # raise AttributeError(f"'{key}' is already present as an edge type") - # self._global_store[key] = value - - # def __delitem__(self, *args: QueryType): - # # `del data[*]` => Link to `_node_store_dict` or `_edge_store_dict`. 
- # key = self._to_canonical(*args) - # if key in self.edge_types: - # del self._edge_store_dict[key] - # elif key in self.node_types: - # del self._node_store_dict[key] - # else: - # del self._global_store[key] - - @property - def n_node_types(self): - return len(self.node_stores) - - @property - def n_edge_types(self): - return len(self.edge_stores) - - @property - def n_nodes_per_type(self): - return { - node_type: node_store.n_nodes - for node_type, node_store in self.node_stores.items() - } - - @property - def n_edges_per_type(self): - return { - edge_type: edge_store.n_edges - for edge_type, edge_store in self.edge_stores.items() - } - - def generator_consistency_check(self): - logger.info("Checking directory consistency") - self._generator_check(self.node_generator_store) - self._generator_check(self.edge_generator_store) - - def _generator_check(self, generator_store): - df = generator_store.read().to_pandas() - for i, row in df.iterrows(): - generator_name = row["generator_name"] - for col_name in df.columns: - if col_name.startswith("generator_kwargs.") or col_name.startswith( - "generator_args." - ): - col_value = row[col_name] - if isinstance(col_value, (EdgeStore, NodeStore)): - store = col_value - - if hasattr(store, "node_type"): - current_path = self.get_node_store( - store.node_type - ).storage_path - generator_store_path = store.storage_path - if current_path != generator_store_path: - df.at[i, col_name] = current_path - - elif hasattr(store, "edge_type"): - current_path = self.get_edge_store( - store.edge_type - ).storage_path - generator_store_path = store.storage_path - if current_path != generator_store_path: - df.at[i, col_name] = current_path - - generator_store.update(data=df) - - def _load_existing_node_stores(self, load_custom_stores: bool = True): - logger.info(f"Loading existing node stores") - return self._load_existing_stores( - self.nodes_path, - default_store_class=NodeStore, - load_custom_stores=load_custom_stores, - ) - - def _load_existing_edge_stores(self, load_custom_stores: bool = True): - logger.info(f"Loading existing edge stores") - return self._load_existing_stores( - self.edges_path, - default_store_class=EdgeStore, - load_custom_stores=load_custom_stores, - ) - - def _load_existing_stores( - self, - stores_path, - default_store_class: Union[NodeStore, EdgeStore] = None, - load_custom_stores: bool = True, - ): - - if load_custom_stores: - default_store_class = None - - logger.debug(f"Load custom stores: {load_custom_stores}") - - store_dict = {} - store_types = os.listdir(stores_path) - logger.info(f"Found {len(store_types)} store types") - for store_type in store_types: - logger.debug(f"Attempting to load store: {store_type}") - - store_path = os.path.join(stores_path, store_type) - if os.path.isdir(store_path): - store_dict[store_type] = load_store(store_path, default_store_class) - else: - raise ValueError( - f"Store path {store_path} is not a directory. Likely does not exist." 
-                )
-
-        return store_dict
-
-    def _load_generator_dependency_graph(self):
-        if os.path.exists(self.generator_dependency_json):
-            with open(self.generator_dependency_json, "r") as f:
-                self.generator_dependency_graph = json.load(f)
-                for key, value in self.generator_dependency_graph["nodes"].items():
-                    self.generator_dependency_graph["nodes"][key] = set(value)
-                for key, value in self.generator_dependency_graph["edges"].items():
-                    self.generator_dependency_graph["edges"][key] = set(value)
-        else:
-            self.generator_dependency_graph = {"nodes": {}, "edges": {}}
-
-    def summary(self, show_column_names: bool = False):
-        # Header section
-        tmp_str = f"{'=' * 60}\n"
-        tmp_str += f"GRAPH DATABASE SUMMARY\n"
-        tmp_str += f"{'=' * 60}\n"
-        tmp_str += f"Name: {self.graph_name}\n"
-        tmp_str += f"Storage path: {os.path.relpath(self.storage_path)}\n"
-        tmp_str += "└── Repository structure:\n"
-        tmp_str += (
-            f"    ├── nodes/ ({os.path.relpath(self.nodes_path)})\n"
-        )
-        tmp_str += (
-            f"    ├── edges/ ({os.path.relpath(self.edges_path)})\n"
-        )
-        tmp_str += f"    ├── edge_generators/ ({os.path.relpath(self.edge_generators_path)})\n"
-        tmp_str += f"    ├── node_generators/ ({os.path.relpath(self.node_generators_path)})\n"
-        tmp_str += (
-            f"    └── graph/ ({os.path.relpath(self.graph_path)})\n\n"
-        )
-
-        # Node section header
-        tmp_str += f"{'#' * 60}\n"
-        tmp_str += f"NODE DETAILS\n"
-        tmp_str += f"{'#' * 60}\n"
-        tmp_str += f"Total node types: {len(self.node_stores)}\n"
-        tmp_str += f"{'-' * 60}\n"
-
-        # Node details
-        for node_type, node_store in self.node_stores.items():
-            tmp_str += f"• Node type: {node_type}\n"
-            tmp_str += f"  - Number of nodes: {node_store.n_nodes}\n"
-            tmp_str += f"  - Number of features: {node_store.n_features}\n"
-            if show_column_names:
-                tmp_str += f"  - Columns:\n"
-                for col in node_store.columns:
-                    tmp_str += f"    - {col}\n"
-            tmp_str += f"  - db_path: {os.path.relpath(node_store.storage_path)}\n"
-            tmp_str += f"{'-' * 60}\n"
-
-        # Edge section header
-        tmp_str += f"\n{'#' * 60}\n"
-        tmp_str += f"EDGE DETAILS\n"
-        tmp_str += f"{'#' * 60}\n"
-        tmp_str += f"Total edge types: {len(self.edge_stores)}\n"
-        tmp_str += f"{'-' * 60}\n"
-
-        # Edge details
-        for edge_type, edge_store in self.edge_stores.items():
-            tmp_str += f"• Edge type: {edge_type}\n"
-            tmp_str += f"  - Number of edges: {edge_store.n_edges}\n"
-            tmp_str += f"  - Number of features: {edge_store.n_features}\n"
-            if show_column_names:
-                tmp_str += f"  - Columns:\n"
-                for col in edge_store.columns:
-                    tmp_str += f"    - {col}\n"
-            tmp_str += f"  - db_path: {os.path.relpath(edge_store.storage_path)}\n"
-            tmp_str += f"{'-' * 60}\n"
-
-        # Node generator header
-        tmp_str += f"\n{'#' * 60}\n"
-        tmp_str += f"NODE GENERATOR DETAILS\n"
-        tmp_str += f"{'#' * 60}\n"
-        tmp_str += f"Total node generators: {self.node_generator_store.n_generators}\n"
-        tmp_str += f"{'-' * 60}\n"
-
-        # Node generator details
-        for generator_name in self.node_generator_store.generator_names:
-            df = self.node_generator_store.load_generator_data(generator_name)
-            tmp_str += f"• Generator: {generator_name}\n"
-            tmp_str += f"Generator Args:\n"
-            for col in df.columns:
-                col_name = col.replace("generator_args.", "")
-                if isinstance(df[col].iloc[0], (NodeStore, EdgeStore)):
-                    tmp_str += f"  - {col_name}: {os.path.relpath(df[col].iloc[0].storage_path)}\n"
-                else:
-                    tmp_str += f"  - {col_name}: {df[col].tolist()}\n"
-            tmp_str += f"Generator Kwargs:\n"
-            for col in df.columns:
-                col_name = col.replace("generator_kwargs.", "")
-                if col.startswith("generator_kwargs."):
-                    tmp_str += f"  - {col_name}: {df[col].tolist()}\n"
-            tmp_str += f"{'-' * 60}\n"
-
-        # Edge generator header
-        tmp_str += f"\n{'#' * 60}\n"
-        tmp_str += f"EDGE GENERATOR DETAILS\n"
-        tmp_str += f"{'#' * 60}\n"
-        tmp_str += f"Total edge generators: {self.edge_generator_store.n_generators}\n"
-        tmp_str += f"{'-' * 60}\n"
-
-        # Edge generator details
-        for generator_name in self.edge_generator_store.generator_names:
-            df = self.edge_generator_store.load_generator_data(generator_name)
-            tmp_str += f"• Generator: {generator_name}\n"
-            tmp_str += f"Generator Args:\n"
-            for col in df.columns:
-                col_name = col.replace("generator_args.", "")
-                if col.startswith("generator_args."):
-                    if isinstance(df[col].iloc[0], (NodeStore, EdgeStore)):
-                        tmp_str += f"  - {col_name}: {os.path.relpath(df[col].iloc[0].storage_path)}\n"
-                    else:
-                        tmp_str += f"  - {col_name}: {df[col].tolist()}\n"
-            tmp_str += f"Generator Kwargs:\n"
-            for col in df.columns:
-                col_name = col.replace("generator_kwargs.", "")
-                if col.startswith("generator_kwargs."):
-                    tmp_str += f"  - {col_name}: {df[col].tolist()}\n"
-            tmp_str += f"{'-' * 60}\n"
-
-        return tmp_str
-
-    # ------------------
-    # Node-level methods
-    # ------------------
-    def add_nodes(self, node_type: str, data, **kwargs):
-        logger.info(f"Creating nodes of type '{node_type}'")
-        store = self.add_node_type(node_type)
-        store.create_nodes(data, **kwargs)
-
-        self._run_dependent_generators(node_type)
-        logger.debug(f"Successfully created nodes of type '{node_type}'")
-
-    def add_node_type(self, node_type: str) -> NodeStore:
-        """
-        Create (or load) a NodeStore for the specified node_type.
-        """
-        if node_type in self.node_stores:
-            logger.debug(f"Returning existing NodeStore for type: {node_type}")
-            return self.node_stores[node_type]
-
-        logger.info(f"Creating new NodeStore for type: {node_type}")
-        storage_path = os.path.join(self.nodes_path, node_type)
-        self.node_stores[node_type] = NodeStore(storage_path=storage_path)
-        return self.node_stores[node_type]
-
-    def add_node_store(
-        self,
-        node_store: NodeStore,
-        overwrite: bool = False,
-        remove_original: bool = False,
-    ):
-        logger.info(f"Adding node store of type {node_store.node_type}")
-
-        # Check if node store already exists
-        if node_store.node_type in self.node_stores:
-            if overwrite:
-                logger.warning(
-                    f"Node store of type {node_store.node_type} already exists, overwriting"
-                )
-                self.remove_node_store(node_store.node_type)
-            else:
-                raise ValueError(
-                    f"Node store of type {node_store.node_type} already exists, and overwrite is False"
-                )
-
-        # Move node store to the nodes directory
-        new_path = os.path.join(self.nodes_path, node_store.node_type)
-        if node_store.storage_path != new_path:
-            logger.debug(
-                f"Moving node store from {node_store.storage_path} to {new_path}"
-            )
-            shutil.copytree(node_store.storage_path, new_path)
-
-        if remove_original:
-            shutil.rmtree(node_store.storage_path)
-        node_store.storage_path = new_path
-        self.node_stores[node_store.node_type] = node_store
-
-        self._run_dependent_generators(node_store.node_type)
-
-    def get_nodes(
-        self, node_type: str, ids: List[int] = None, columns: List[str] = None, **kwargs
-    ):
-        logger.info(f"Reading nodes of type '{node_type}'")
-        if ids:
-            logger.debug(f"Filtering by {len(ids)} node IDs")
-        if columns:
-            logger.debug(f"Selecting columns: {columns}")
-        store = self.get_node_store(node_type)
-        return store.read_nodes(ids=ids, columns=columns, **kwargs)
-
-    def read_nodes(
-        self, node_type: str, ids: List[int] = None, columns: List[str] = None, **kwargs
-    ):
-        store = self.get_node_store(node_type)
-        return store.read_nodes(ids=ids, columns=columns, **kwargs)
-
-    def get_node_store(self, node_type: str):
-        # if node_type not in self.node_stores:
-        node_store = self.node_stores.get(node_type, None)
-        if node_store is None:
-            raise ValueError(f"Node store of type {node_type} does not exist")
-        return node_store
-
-    def update_nodes(self, node_type: str, data, **kwargs):
-        store = self.get_node_store(node_type)
-        store.update_nodes(data, **kwargs)
-
-        self._run_dependent_generators(node_type)
-
-    def delete_nodes(
-        self, node_type: str, ids: List[int] = None, columns: List[str] = None
-    ):
-        store = self.get_node_store(node_type)
-        store.delete_nodes(ids=ids, columns=columns)
-
-        self._run_dependent_generators(node_type)
-
-    def remove_node_store(self, node_type: str):
-        logger.info(f"Removing node store of type {node_type}")
-        store = self.get_node_store(node_type)
-        shutil.rmtree(store.storage_path)
-        self.node_stores.pop(node_type)
-
-        self._run_dependent_generators(node_type)
-
-    def remove_node_type(self, node_type: str):
-        self.remove_node_store(node_type)
-
-    def normalize_nodes(self, node_type: str, normalize_kwargs: Dict = None):
-        store = self.add_node_type(node_type)
-        store.normalize_nodes(**normalize_kwargs)
-
-    def normalize_all_nodes(self, normalize_kwargs):
-        for node_type in self.node_stores:
-            self.normalize_nodes(node_type, normalize_kwargs)
-
-    def list_node_types(self):
-        return list(self.node_stores.keys())
-
-    def node_exists(self, node_type: str):
-        logger.debug(f"Node type: {node_type}")
-        logger.debug(f"Node stores: {self.node_stores}")
-
-        return node_type in self.node_stores
-
-    def node_is_empty(self, node_type: str):
-        store = self.get_node_store(node_type)
-        return store.is_empty()
-
-    def add_node_generator(
-        self,
-        generator_func: Callable,
-        generator_args: Dict = None,
-        generator_kwargs: Dict = None,
-        create_kwargs: Dict = None,
-        run_immediately: bool = True,
-        run_generator_kwargs: Dict = None,
-        depends_on: List[str] = None,
-        add_dependency: bool = True,
-    ) -> None:
-        generator_name = generator_func.__name__
-        self.node_generator_store.store_generator(
-            generator_func=generator_func,
-            generator_name=generator_name,
-            generator_args=generator_args,
-            generator_kwargs=generator_kwargs,
-            create_kwargs=create_kwargs,
-        )
-        self.generator_consistency_check()
-
-        if run_immediately:
-            if run_generator_kwargs is None:
-                run_generator_kwargs = dict(generator_name=generator_name)
-            else:
-                run_generator_kwargs["generator_name"] = generator_name
-
-            self.run_node_generator(**run_generator_kwargs)
-
-        if add_dependency and depends_on:
-            self.add_generator_dependency(generator_name, depends_on, node_type="nodes")
-        elif add_dependency and not depends_on:
-            self.add_generator_dependency(generator_name)
-
-    def run_node_generator(
-        self,
-        generator_name: str,
-        generator_args: Dict = None,
-        generator_kwargs: Dict = None,
-        create_kwargs: Dict = None,
-    ) -> None:
-        """
-        Execute a previously registered custom node-generation function by name.
-
-        Parameters
-        ----------
-        generator_name : str
-            The unique name used when registering the function.
-        generator_args : Dict
-            Additional arguments passed to the generator function.
-        generator_kwargs : Dict
-            Additional keyword arguments passed to the generator function.
-
-        Raises
-        ------
-        ValueError
-            If there is no generator function with the given name.
-        """
-        if create_kwargs is None:
-            create_kwargs = {}
-
-        if generator_args is None:
-            generator_args = {}
-
-        if generator_kwargs is None:
-            generator_kwargs = {}
-
-        table = self.node_generator_store.run_generator(
-            generator_name,
-            generator_args=generator_args,
-            generator_kwargs=generator_kwargs,
-        )
-
-        storage_path = os.path.join(self.nodes_path, generator_name)
-        if os.path.exists(storage_path):
-            logger.info(f"Removing existing node store: {generator_name}")
-            self.remove_node_store(generator_name)
-
-        self.add_nodes(node_type=generator_name, data=table, **create_kwargs)
-        return table
-
-    # ------------------
-    # Edge-level methods
-    # ------------------
-    def add_edge_type(self, edge_type: str) -> EdgeStore:
-        """
-        Create (or load) an EdgeStore for the specified edge_type.
-        """
-        if edge_type in self.edge_stores:
-            logger.debug(f"Returning existing EdgeStore for type: {edge_type}")
-            return self.edge_stores[edge_type]
-
-        logger.info(f"Creating new EdgeStore for type: {edge_type}")
-        storage_path = os.path.join(self.edges_path, edge_type)
-        self.edge_stores[edge_type] = EdgeStore(storage_path=storage_path)
-        return self.edge_stores[edge_type]
-
-    def add_edges(self, edge_type: str, data, **kwargs):
-        logger.info(f"Creating edges of type '{edge_type}'")
-        incoming_table = ParquetDB.construct_table(data)
-        self._validate_edge_references(incoming_table)
-        store = self.add_edge_type(edge_type)
-        store.create_edges(incoming_table, **kwargs)
-        self._run_dependent_generators(edge_type)
-        logger.debug(f"Successfully created edges of type '{edge_type}'")
-
-    def add_edge_store(self, edge_store: EdgeStore):
-        logger.info(f"Adding edge store of type {edge_store.edge_type}")
-
-        # Move edge store to the edges directory
-        new_path = os.path.join(self.edges_path, edge_store.edge_type)
-        if edge_store.storage_path != new_path:
-            logger.debug(
-                f"Moving edge store from {edge_store.storage_path} to {new_path}"
-            )
-            os.makedirs(new_path, exist_ok=True)
-            for file in glob(os.path.join(edge_store.storage_path, "*")):
-                new_file = os.path.join(new_path, os.path.basename(file))
-                os.rename(file, new_file)
-            edge_store.storage_path = new_path
-        self.edge_stores[edge_store.edge_type] = edge_store
-
-        self._run_dependent_generators(edge_store.edge_type)
-
-    def read_edges(
-        self, edge_type: str, ids: List[int] = None, columns: List[str] = None, **kwargs
-    ):
-        store = self.add_edge_type(edge_type)
-        return store.read_edges(ids=ids, columns=columns, **kwargs)
-
-    def update_edges(self, edge_type: str, data, **kwargs):
-        store = self.add_edge_type(edge_type)
-        store.update_edges(data, **kwargs)
-
-        self._run_dependent_generators(edge_type)
-
-    def delete_edges(
-        self, edge_type: str, ids: List[int] = None, columns: List[str] = None
-    ):
-        store = self.add_edge_type(edge_type)
-        store.delete_edges(ids=ids, columns=columns)
-
-        self._run_dependent_generators(edge_type)
-
-    def remove_edge_store(self, edge_type: str):
-        logger.info(f"Removing edge store of type {edge_type}")
-        store = self.get_edge_store(edge_type)
-        shutil.rmtree(store.storage_path)
-        self.edge_stores.pop(edge_type)
-
-        self._run_dependent_generators(edge_type)
-
-    def remove_edge_type(self, edge_type: str):
-        self.remove_edge_store(edge_type)
-
-    def normalize_edges(self, edge_type: str):
-        store = self.add_edge_type(edge_type)
-        store.normalize_edges()
-
-    def normalize_all_edges(self, normalize_kwargs):
-        for edge_type in self.edge_stores:
-            self.normalize_edges(edge_type, normalize_kwargs)
-
-    def get_edge_store(self, edge_type: str):
-        edge_store = self.edge_stores.get(edge_type, None)
-        if edge_store is None:
-            raise ValueError(f"Edge store of type {edge_type} does not exist")
-        return edge_store
-
-    def list_edge_types(self):
-        return list(self.edge_stores.keys())
-
-    def edge_exists(self, edge_type: str):
-        return edge_type in self.edge_stores
-
-    def edge_is_empty(self, edge_type: str):
-        store = self.get_edge_store(edge_type)
-        return store.is_empty()
-
-    def _validate_edge_references(self, table: pa.Table) -> None:
-        """
-        Checks whether source_id and target_id in each edge record exist
-        in the corresponding node stores.
-
-        Parameters
-        ----------
-        table : pa.Table
-            A table containing 'source_id' and 'target_id' columns.
-        source_node_type : str
-            The node type for the source nodes (e.g., 'user').
-        target_node_type : str
-            The node type for the target nodes (e.g., 'item').
-
-        Raises
-        ------
-        ValueError
-            If any source_id/target_id is not found in the corresponding node store.
-        """
-        # logger.debug(f"Validating edge references: {source_node_type} -> {target_node_type}")
-        edge_table = table
-        # 1. Retrieve the NodeStores
-        names = edge_table.column_names
-        logger.debug(f"Column names: {names}")
-
-        assert "source_type" in names, "source_type column not found in table"
-        assert "target_type" in names, "target_type column not found in table"
-        assert "source_id" in names, "source_id column not found in table"
-        assert "target_id" in names, "target_id column not found in table"
-        assert "edge_type" in names, "edge_type column not found in table"
-
-        node_types = pc.unique(table["source_type"]).to_pylist()
-
-        for node_type in node_types:
-            store = self.node_stores.get(node_type, None)
-            if store is None:
-                logger.error(f"No node store found for node_type='{node_type}'")
-                raise ValueError(f"No node store found for node_type='{node_type}'.")
-
-            # Read all existing source IDs from store_1
-            node_table = store.read_nodes(columns=["id"])
-
-            # Filter all source_ids and target_ids that are of the same type as store_1
-            source_id_array = edge_table.filter(
-                pc.field("source_type") == store.node_type
-            )["source_id"].combine_chunks()
-            target_id_array = edge_table.filter(
-                pc.field("target_type") == store.node_type
-            )["target_id"].combine_chunks()
-
-            all_source_type_ids = pa.concat_arrays([source_id_array, target_id_array])
-
-            # Check if all source_ids and target_ids are in the node_store
-            is_source_ids_in_source_store = pc.index_in(
-                all_source_type_ids, node_table["id"]
-            )
-            invalid_source_ids = is_source_ids_in_source_store.filter(
-                pc.is_null(is_source_ids_in_source_store)
-            )
-
-            if len(invalid_source_ids) > 0:
-                raise ValueError(
-                    f"Source IDs not found in source_store of type {store.node_type}: {invalid_source_ids}"
-                )
-
-        logger.debug("Edge reference validation completed successfully")
-
-    def construct_table(self, data, schema=None, metadata=None, fields_metadata=None):
-        logger.info("Validating data")
-        return ParquetDB.construct_table(
-            data, schema=schema, metadata=metadata, fields_metadata=fields_metadata
-        )
-
-    def add_edge_generator(
-        self,
-        generator_func: Callable,
-        generator_args: Dict = None,
-        generator_kwargs: Dict = None,
-        create_kwargs: Dict = None,
-        run_immediately: bool = True,
-        run_generator_create_kwargs: Dict = None,
-        depends_on: List[str] = None,
-        add_dependency: bool = True,
-    ) -> None:
-        """
-        Register a user-defined callable that can read from node stores,
-        and then create or update edges as it sees fit.
- - Parameters - ---------- - name : str - A unique identifier for this generator function. - generator_func : Callable - """ - generator_name = generator_func.__name__ - self.edge_generator_store.store_generator( - generator_func=generator_func, - generator_name=generator_name, - generator_args=generator_args, - generator_kwargs=generator_kwargs, - create_kwargs=create_kwargs, - ) - logger.info(f"Added new edge generator: {generator_name}") - - if run_immediately: - if run_generator_create_kwargs is None: - run_generator_create_kwargs = {} - self.run_edge_generator( - generator_name=generator_name, create_kwargs=run_generator_create_kwargs - ) - - if add_dependency and depends_on: - self.add_generator_dependency(generator_name, depends_on, node_type="edges") - elif add_dependency and not depends_on: - self.add_generator_dependency(generator_name) - - def run_edge_generator( - self, - generator_name: str, - generator_args: Dict = None, - generator_kwargs: Dict = None, - create_kwargs: Dict = None, - ) -> None: - if create_kwargs is None: - create_kwargs = {} - - table = self.edge_generator_store.run_generator( - generator_name, - generator_args=generator_args, - generator_kwargs=generator_kwargs, - ) - - storage_path = os.path.join(self.edges_path, generator_name) - if os.path.exists(storage_path): - logger.info(f"Removing existing edge store: {generator_name}") - self.remove_edge_store(generator_name) - - self.add_edges(edge_type=generator_name, data=table, **create_kwargs) - return table - - def get_generator_type(self, generator_name: str): - if self.edge_generator_store.is_in(generator_name): - return "edges" - elif self.node_generator_store.is_in(generator_name): - return "nodes" - else: - raise ValueError(f"Generator {generator_name} not in node or edge store") - - def get_generator_dependency_graph(self, generator_name: str): - - generator_type = self.get_generator_type(generator_name) - generator_store = ( - self.edge_generator_store - if generator_type == "edges" - else self.node_generator_store - ) - dependency_graph = {generator_type: {}} - df = generator_store.load_generator_data(generator_name=generator_name) - logger.debug(f"Generator data: {df.columns}") - - for i, row in df.iterrows(): - if generator_name not in dependency_graph[generator_type]: - dependency_graph[generator_type][generator_name] = set() - - for col_name in df.columns: - if col_name.startswith("generator_kwargs.") or col_name.startswith( - "generator_args." - ): - col_value = row[col_name] - if isinstance(col_value, (EdgeStore, NodeStore)): - store = col_value - - if hasattr(store, "node_type"): - current_path = self.get_node_store( - store.node_type - ).storage_path - dependency_graph[generator_type][generator_name].add( - store.node_type - ) - - elif hasattr(store, "edge_type"): - current_path = self.get_edge_store( - store.edge_type - ).storage_path - - dependency_graph[generator_type][generator_name].add( - store.edge_type - ) - return dependency_graph - - def add_generator_dependency( - self, - generator_name: str, - depends_on: List[str] = None, - store_type: str = None, - ): - """ - Add dependencies for a generator. When any of the dependencies are updated, - the generator will automatically run. 
-
-        Parameters
-        ----------
-        generator_name : str
-            Name of the generator that has dependencies
-        depends_on : List[str]
-            List of store names that this generator depends on
-        store_type : str
-            Either 'nodes' or 'edges', indicating the type of store the generator creates
-        """
-        dependencies = None
-        if depends_on:
-            logger.info(f"Adding dependencies for {generator_name}: {depends_on}")
-            if store_type not in ["nodes", "edges"]:
-                raise ValueError("store_type must be either 'nodes' or 'edges'")
-            if generator_name not in self.generator_dependency_graph[store_type]:
-                self.generator_dependency_graph[store_type][generator_name] = set()
-            dependencies = depends_on
-            self.generator_dependency_graph[store_type][generator_name].update(
-                dependencies
-            )
-        else:
-            logger.info(f"Adding all dependencies for {generator_name}")
-            dependencies = self.get_generator_dependency_graph(generator_name)
-            self.generator_dependency_graph.update(dependencies)
-
-        with open(self.generator_dependency_json, "w") as f:
-            for key, value in self.generator_dependency_graph["nodes"].items():
-                self.generator_dependency_graph["nodes"][key] = list(value)
-            for key, value in self.generator_dependency_graph["edges"].items():
-                self.generator_dependency_graph["edges"][key] = list(value)
-            json.dump(self.generator_dependency_graph, f)
-
-        if dependencies:
-            logger.info(f"Added dependencies for {generator_name}: {dependencies}")
-
-    def _run_dependent_generators(self, store_name: str):
-        """
-        Run all generators that depend on a specific store.
-
-        Parameters
-        ----------
-        store_type : str
-            Either 'nodes' or 'edges'
-        store_name : str
-            Name of the store that was updated
-        """
-        logger.info(f"Running dependent generators: {store_name}")
-
-        # Find all generators that depend on this store
-        dependent_generators = set()
-        for generator_name, dependencies in self.generator_dependency_graph[
-            "nodes"
-        ].items():
-            if store_name in dependencies:
-                dependent_generators.add(("nodes", generator_name))
-
-        for generator_name, dependencies in self.generator_dependency_graph[
-            "edges"
-        ].items():
-            if store_name in dependencies:
-                dependent_generators.add(("edges", generator_name))
-
-        logger.debug(f"Dependent generators: {dependent_generators}")
-        # Run each dependent generator
-        for dep_store_type, generator_name in dependent_generators:
-            logger.info(f"Running dependent generator: {generator_name}")
-            try:
-                if dep_store_type == "nodes":
-                    self.run_node_generator(generator_name)
-                    # Recursively run generators that depend on this result
-                    self._run_dependent_generators(generator_name)
-                else:
-                    self.run_edge_generator(generator_name=generator_name)
-                    # Recursively run generators that depend on this result
-                    self._run_dependent_generators(generator_name)
-            except Exception as e:
-                logger.error(
-                    f"Failed to run dependent generator {generator_name}: {str(e)}"
-                )
-
-
-def load_store(store_path: str, default_store_class=None):
-    store_metadata = ParquetDB(store_path).get_metadata()
-    class_module = store_metadata.get("class_module", None)
-    class_name = store_metadata.get("class", None)
-
-    logger.debug(f"Class module: {class_module}")
-    logger.debug(f"Class: {class_name}")
-
-    if class_module and class_name and default_store_class is None:
-        logger.debug(f"Importing class from module: {class_module}")
-        module = importlib.import_module(class_module)
-        class_obj = getattr(module, class_name)
-        store = class_obj(storage_path=store_path)
-    else:
-        logger.debug(f"Using default store class: {default_store_class.__name__}")
-        store = default_store_class(storage_path=store_path)
-
-    return store
diff --git a/matgraphdb/materials/core.py b/matgraphdb/core/matgraphdb.py
similarity index 91%
rename from matgraphdb/materials/core.py
rename to matgraphdb/core/matgraphdb.py
index ebeea8f..8f43f7c 100644
--- a/matgraphdb/materials/core.py
+++ b/matgraphdb/core/matgraphdb.py
@@ -3,14 +3,14 @@
 from typing import Dict, List, Union
 
 import pyarrow as pa
+from parquetdb import ParquetGraphDB
 
-from matgraphdb.core import GraphDB
-from matgraphdb.materials.nodes import MaterialStore
+from matgraphdb.core.nodes import MaterialStore
 
 logger = logging.getLogger(__name__)
 
 
-class MatGraphDB(GraphDB):
+class MatGraphDB(ParquetGraphDB):
     """
     The main entry point for advanced material analysis and graph storage.
 
@@ -24,6 +24,7 @@ def __init__(
         storage_path: str,
         materials_store: MaterialStore = None,
         load_custom_stores: bool = True,
+        **kwargs,
     ):
         """
         Parameters
@@ -37,7 +38,9 @@ def __init__(
         """
         self.storage_path = os.path.abspath(storage_path)
         super().__init__(
-            storage_path=self.storage_path, load_custom_stores=load_custom_stores
+            storage_path=self.storage_path,
+            load_custom_stores=load_custom_stores,
+            **kwargs,
         )
 
         logger.info(f"Initializing MatGraphDB at: {self.storage_path}")
diff --git a/matgraphdb/core/nodes.py b/matgraphdb/core/nodes.py
deleted file mode 100644
index ac8c976..0000000
--- a/matgraphdb/core/nodes.py
+++ /dev/null
@@ -1,427 +0,0 @@
-import logging
-import os
-from functools import wraps
-from typing import List, Union
-
-import pandas as pd
-import pyarrow as pa
-import pyarrow.compute as pc
-from parquetdb import ParquetDB
-from parquetdb.core.parquetdb import LoadConfig, NormalizeConfig
-
-from matgraphdb.core.generator_store import validate_generator_inputs
-from matgraphdb.core.utils import get_dataframe_column_names
-
-logger = logging.getLogger(__name__)
-
-
-REQUIRED_NODE_COLUMNS_FIELDS = set()
-
-
-def validate_node_dataframe(df):
-    column_names = get_dataframe_column_names(df)
-    fields = set(column_names)
-    missing_fields = REQUIRED_NODE_COLUMNS_FIELDS - fields
-    if missing_fields:
-        raise ValueError(
-            f"Node dataframe is missing required fields: {missing_fields}. Node dataframe must contain the following columns: {REQUIRED_NODE_COLUMNS_FIELDS}"
-        )
-    return df
-
-
-def node_generator(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        # Perform pre-execution checks
-        logger.debug(f"Executing {func.__name__} with args: {args}, kwargs: {kwargs}")
-        validate_generator_inputs(args)
-
-        df = func(*args, **kwargs)
-        validate_node_dataframe(df)
-        return df
-
-    return wrapper
-
-
-class NodeStore(ParquetDB):
-    """
-    A wrapper around ParquetDB specifically for storing node features
-    of a given node type.
-    """
-
-    node_metadata_keys = ["class", "class_module", "node_type", "name_column"]
-
-    def __init__(self, storage_path: str, initialize_kwargs: dict = None):
-        """
-        Parameters
-        ----------
-        storage_path : str
-            The path where ParquetDB files for this node type are stored.
- """ - - self.node_type = os.path.basename(storage_path) - - if initialize_kwargs is None: - initialize_kwargs = {} - - super().__init__(db_path=storage_path) - - metadata = self.get_metadata() - - update_metadata = False - for key in self.node_metadata_keys: - if key not in metadata: - update_metadata = update_metadata or key not in metadata - if update_metadata: - self.set_metadata( - { - "class": f"{self.__class__.__name__}", - "class_module": f"{self.__class__.__module__}", - "node_type": self.node_type, - "name_column": "id", - } - ) - - if self.is_empty(): - self._initialize(**initialize_kwargs) - - logger.debug(f"Initialized NodeStore at {storage_path}") - - def __repr__(self): - return self.summary(show_column_names=False) - - @property - def storage_path(self): - return self._db_path - - @storage_path.setter - def storage_path(self, value): - self._db_path = value - self.node_type = os.path.basename(value) - - def _initialize(self, **kwargs): - data = self.initialize(**kwargs) - if data is not None: - self.create_nodes(data=data) - - def initialize(self, **kwargs): - return None - - @property - def name_column(self): - return self.get_metadata()["name_column"] - - @name_column.setter - def name_column(self, value): - self.set_metadata({"name_column": value}) - - @property - def n_nodes(self): - return self.read_nodes(columns=["id"]).num_rows - - @property - def n_features(self): - return len(self.get_schema().names) - - @property - def columns(self): - return self.get_schema().names - - def summary(self, show_column_names: bool = False): - fields_metadata = self.get_field_metadata() - metadata = self.get_metadata() - # Header section - tmp_str = f"{'=' * 60}\n" - tmp_str += f"NODE STORE SUMMARY\n" - tmp_str += f"{'=' * 60}\n" - tmp_str += f"Node type: {self.node_type}\n" - tmp_str += f"• Number of nodes: {self.n_nodes}\n" - tmp_str += f"• Number of features: {self.n_features}\n" - tmp_str += f"Storage path: {os.path.relpath(self.storage_path)}\n\n" - - # Metadata section - tmp_str += f"\n{'#' * 60}\n" - tmp_str += f"METADATA\n" - tmp_str += f"{'#' * 60}\n" - for key, value in metadata.items(): - tmp_str += f"• {key}: {value}\n" - - # Node details - tmp_str += f"\n{'#' * 60}\n" - tmp_str += f"NODE DETAILS\n" - tmp_str += f"{'#' * 60}\n" - if show_column_names: - tmp_str += f"• Columns:\n" - for col in self.columns: - tmp_str += f" - {col}\n" - - if fields_metadata[col]: - tmp_str += f" - Field metadata\n" - for key, value in fields_metadata[col].items(): - tmp_str += f" - {key}: {value}\n" - - return tmp_str - - def create_nodes( - self, - data: Union[List[dict], dict, pd.DataFrame, pa.Table], - schema: pa.Schema = None, - metadata: dict = None, - fields_metadata: dict = None, - treat_fields_as_ragged: List[str] = None, - convert_to_fixed_shape: bool = True, - normalize_dataset: bool = False, - normalize_config: dict = NormalizeConfig(), - ): - """ - Adds new data to the database. - - Parameters - ---------- - data : dict, list of dict, or pandas.DataFrame - The data to be added to the database. - schema : pyarrow.Schema, optional - The schema for the incoming data. - metadata : dict, optional - Metadata to be attached to the table. - fields_metadata : dict, optional - A dictionary containing the metadata to be set for the fields. - normalize_dataset : bool, optional - If True, the dataset will be normalized after the data is added (default is True). - treat_fields_as_ragged : list of str, optional - A list of fields to treat as ragged arrays. 
- convert_to_fixed_shape : bool, optional - If True, the ragged arrays will be converted to fixed shape arrays. - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - Examples - -------- - >>> db.create_nodes(data=my_data, schema=my_schema, metadata={'source': 'api'}, normalize_dataset=True) - """ - create_kwargs = dict( - data=data, - schema=schema, - metadata=metadata, - fields_metadata=fields_metadata, - treat_fields_as_ragged=treat_fields_as_ragged, - convert_to_fixed_shape=convert_to_fixed_shape, - normalize_dataset=normalize_dataset, - normalize_config=normalize_config, - ) - num_records = len(data) if isinstance(data, (list, pd.DataFrame)) else 1 - logger.info(f"Creating {num_records} node records") - try: - self.create(**create_kwargs) - logger.debug("Node creation successful") - except Exception as e: - logger.error(f"Failed to create nodes: {str(e)}") - raise - - def read_nodes( - self, - ids: List[int] = None, - columns: List[str] = None, - filters: List[pc.Expression] = None, - load_format: str = "table", - batch_size: int = None, - include_cols: bool = True, - rebuild_nested_struct: bool = False, - rebuild_nested_from_scratch: bool = False, - load_config: LoadConfig = LoadConfig(), - normalize_config: NormalizeConfig = NormalizeConfig(), - ): - """ - Reads data from the database. - - Parameters - ---------- - - ids : list of int, optional - A list of IDs to read. If None, all data is read (default is None). - columns : list of str, optional - The columns to include in the output. If None, all columns are included (default is None). - filters : list of pyarrow.compute.Expression, optional - Filters to apply to the data (default is None). - load_format : str, optional - The format of the returned data: 'table' or 'batches' (default is 'table'). - batch_size : int, optional - The batch size to use for loading data in batches. If None, data is loaded as a whole (default is None). - include_cols : bool, optional - If True, includes only the specified columns. If False, excludes the specified columns (default is True). - rebuild_nested_struct : bool, optional - If True, rebuilds the nested structure (default is False). - rebuild_nested_from_scratch : bool, optional - If True, rebuilds the nested structure from scratch (default is False). - load_config : LoadConfig, optional - Configuration for loading data, optimizing performance by managing memory usage. - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - - Returns - ------- - pa.Table, generator, or dataset - The data read from the database. The output can be in table format or as a batch generator. - - Examples - -------- - >>> data = db.read_nodes(ids=[1, 2, 3], columns=['name', 'age'], filters=[pc.field('age') > 18]) - """ - id_msg = f"for IDs {ids[:5]}..." 
if ids else "for all nodes" - col_msg = f" columns: {columns}" if columns else "" - logger.debug(f"Reading nodes {id_msg}{col_msg}") - - read_kwargs = dict( - ids=ids, - columns=columns, - filters=filters, - load_format=load_format, - batch_size=batch_size, - include_cols=include_cols, - rebuild_nested_struct=rebuild_nested_struct, - rebuild_nested_from_scratch=rebuild_nested_from_scratch, - load_config=load_config, - normalize_config=normalize_config, - ) - try: - result = self.read(**read_kwargs) - logger.debug( - f"Successfully read {len(result) if hasattr(result, '__len__') else 'unknown'} records" - ) - return result - except Exception as e: - logger.error(f"Failed to read nodes: {str(e)}") - raise - - def update_nodes( - self, - data: Union[List[dict], dict, pd.DataFrame], - schema: pa.Schema = None, - metadata: dict = None, - fields_metadata: dict = None, - update_keys: Union[str, List[str]] = "id", - treat_fields_as_ragged=None, - convert_to_fixed_shape: bool = True, - normalize_config: NormalizeConfig = NormalizeConfig(), - ): - """ - Updates existing records in the database. - - Parameters - ---------- - data : dict, list of dicts, or pandas.DataFrame - The data to be updated in the database. Each record must contain an 'id' key - corresponding to the record to be updated. - schema : pyarrow.Schema, optional - The schema for the data being added. If not provided, it will be inferred. - metadata : dict, optional - Additional metadata to store alongside the data. - fields_metadata : dict, optional - A dictionary containing the metadata to be set for the fields. - update_keys : str or list of str, optional - The keys to use for updating the data. If a list, the data must contain a row for each key. - treat_fields_as_ragged : list of str, optional - A list of fields to treat as ragged arrays. - convert_to_fixed_shape : bool, optional - If True, the ragged arrays will be converted to fixed shape arrays. - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - - Examples - -------- - >>> db.update_nodes(data=[{'id': 1, 'name': 'John', 'age': 30}, {'id': 2, 'name': 'Jane', 'age': 25}]) - """ - - num_records = len(data) if isinstance(data, (list, pd.DataFrame)) else 1 - logger.info(f"Updating {num_records} node records") - - update_kwargs = dict( - data=data, - update_keys=update_keys, - schema=schema, - metadata=metadata, - fields_metadata=fields_metadata, - normalize_config=normalize_config, - treat_fields_as_ragged=treat_fields_as_ragged, - convert_to_fixed_shape=convert_to_fixed_shape, - ) - try: - self.update(**update_kwargs) - logger.debug("Node update successful") - except Exception as e: - logger.error(f"Failed to update nodes: {str(e)}") - raise - - def delete_nodes( - self, - ids: List[int] = None, - columns: List[str] = None, - normalize_config: NormalizeConfig = NormalizeConfig(), - ): - """ - Deletes records from the database. - - Parameters - ---------- - ids : list of int - A list of record IDs to delete from the database. - columns : list of str, optional - A list of column names to delete from the dataset. If not provided, it will be inferred from the existing data (default: None). - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. 
- - Returns - ------- - None - - Examples - -------- - >>> db.delete(ids=[1, 2, 3]) - """ - if ids: - logger.info(f"Deleting {len(ids)} nodes") - if columns: - logger.info(f"Deleting columns: {columns}") - try: - self.delete(ids=ids, columns=columns) - logger.debug("Node deletion successful") - except Exception as e: - logger.error(f"Failed to delete nodes: {str(e)}") - raise - - def normalize_nodes(self, normalize_config: NormalizeConfig = NormalizeConfig()): - """ - Normalize the dataset by restructuring files for consistent row distribution. - - This method optimizes performance by ensuring that files in the dataset directory have a consistent number of rows. - It first creates temporary files from the current dataset and rewrites them, ensuring that no file has significantly - fewer rows than others, which can degrade performance. This is particularly useful after a large data ingestion, - as it enhances the efficiency of create, read, update, and delete operations. - - Parameters - ---------- - normalize_config : NormalizeConfig, optional - Configuration for the normalization process, optimizing performance by managing row distribution and file structure. - - Returns - ------- - None - This function does not return anything but modifies the dataset directory in place. - - Examples - -------- - from parquetdb.core.parquetdb import NormalizeConfig - normalize_config=NormalizeConfig(load_format='batches', - max_rows_per_file=5000, - min_rows_per_group=500, - max_rows_per_group=5000, - existing_data_behavior='overwrite_or_ignore', - max_partitions=512) - >>> db.normalize_nodes(normalize_config=normalize_config) - """ - logger.info("Starting node store normalization") - try: - self.normalize(normalize_config=normalize_config) - logger.debug("Node store normalization completed") - except Exception as e: - logger.error(f"Failed to normalize node store: {str(e)}") - raise diff --git a/matgraphdb/materials/nodes/__init__.py b/matgraphdb/core/nodes/__init__.py similarity index 63% rename from matgraphdb/materials/nodes/__init__.py rename to matgraphdb/core/nodes/__init__.py index 5dafdb1..d80f699 100644 --- a/matgraphdb/materials/nodes/__init__.py +++ b/matgraphdb/core/nodes/__init__.py @@ -1,4 +1,4 @@ -from matgraphdb.materials.nodes.generators import ( +from matgraphdb.core.nodes.generators import ( chemenv, crystal_system, element, @@ -7,7 +7,7 @@ space_group, wyckoff, ) -from matgraphdb.materials.nodes.materials import ( +from matgraphdb.core.nodes.materials import ( MaterialStore, material_lattice, material_site, diff --git a/matgraphdb/materials/nodes/generators.py b/matgraphdb/core/nodes/generators.py similarity index 98% rename from matgraphdb/materials/nodes/generators.py rename to matgraphdb/core/nodes/generators.py index 29ac038..7f963c0 100644 --- a/matgraphdb/materials/nodes/generators.py +++ b/matgraphdb/core/nodes/generators.py @@ -7,11 +7,10 @@ import pandas as pd import pyarrow as pa import pyarrow.compute as pc -from parquetdb import ParquetDB +from parquetdb import ParquetDB, node_generator from parquetdb.utils import pyarrow_utils -from matgraphdb.core.nodes import node_generator -from matgraphdb.materials.nodes import * +from matgraphdb.core.nodes import * from matgraphdb.utils.config import PKG_DIR logger = logging.getLogger(__name__) diff --git a/matgraphdb/materials/nodes/materials.py b/matgraphdb/core/nodes/materials.py similarity index 97% rename from matgraphdb/materials/nodes/materials.py rename to matgraphdb/core/nodes/materials.py index db7e174..12a8581 100644 --- 
a/matgraphdb/materials/nodes/materials.py +++ b/matgraphdb/core/nodes/materials.py @@ -12,12 +12,11 @@ import pyarrow.dataset as ds import pyarrow.parquet as pq import spglib -from parquetdb import ParquetDB +from parquetdb import NodeStore, node_generator from parquetdb.core.parquetdb import LoadConfig, NormalizeConfig from pymatgen.core import Composition, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from matgraphdb.core.nodes import NodeStore, node_generator from matgraphdb.utils.general_utils import set_verbosity from matgraphdb.utils.mp_utils import multiprocess_task @@ -44,7 +43,6 @@ def create_material( convert_to_fixed_shape: bool = True, normalize_dataset: bool = False, normalize_config: NormalizeConfig = NormalizeConfig(), - verbose: int = 3, save_db: bool = True, **kwargs, ): @@ -83,8 +81,6 @@ def create_material( If True, normalizes the dataset. normalize_config : NormalizeConfig, optional The normalize configuration to be applied to the data. This is the NormalizeConfig object from Parquet - verbose : int, optional - The verbosity level for logging (default is 3). save_db : bool, optional If True, saves the material to the database. **kwargs @@ -95,7 +91,6 @@ def create_material( dict A dictionary containing the material's data, including calculated properties and additional information. """ - set_verbosity(verbose) # Generating entry data entry_data = {} @@ -237,7 +232,6 @@ def create_materials( convert_to_fixed_shape: bool = True, normalize_dataset: bool = False, normalize_config: NormalizeConfig = NormalizeConfig(), - verbose: int = 3, **kwargs, ): """ @@ -266,8 +260,6 @@ def create_materials( If True, normalizes the dataset. normalize_config : NormalizeConfig, optional The normalize configuration to be applied to the data. This is the NormalizeConfig object from Parquet - verbose : int, optional - The verbosity level for logging (default is 3). **kwargs Additional keyword arguments passed to the ParquetDB `create` method. 
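Note on the import changes in the hunks above: after this refactor, `NodeStore` and the `node_generator` decorator come from `parquetdb` instead of `matgraphdb.core.nodes`. A minimal sketch of a custom node generator under the new import path — assuming parquetdb's decorator keeps the same contract as the removed matgraphdb version shown earlier (validate inputs, return a dataframe); the generator name and column are illustrative, not part of the package:

```python
import pandas as pd
from parquetdb import node_generator  # moved out of matgraphdb.core.nodes


@node_generator
def crystal_system():
    """Return one row per crystal system; each row becomes a node."""
    systems = [
        "triclinic", "monoclinic", "orthorhombic",
        "tetragonal", "trigonal", "hexagonal", "cubic",
    ]
    return pd.DataFrame({"crystal_system": systems})
```
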
@@ -276,7 +268,6 @@ def create_materials(
     None
     """
 
-    set_verbosity(verbose)
     logger.info(f"Adding {len(materials)} materials to the database.")
 
     add_kwargs = dict(
@@ -285,7 +276,6 @@ def create_materials(
         fields_metadata=fields_metadata,
         normalize_dataset=normalize_dataset,
         normalize_config=normalize_config,
-        verbose=verbose,
         treat_fields_as_ragged=treat_fields_as_ragged,
         convert_to_fixed_shape=convert_to_fixed_shape,
     )
@@ -600,7 +590,6 @@ def material_site(material_store: NodeStore):
     )
     # table=material_nodes.read(columns=['structure.sites', *id_names])#, *lattice_names])
     material_sites = table["structure.sites"].combine_chunks()
-
     flatten_material_sites = pc.list_flatten(material_sites)
 
     material_sites_length_list = pc.list_value_length(material_sites).to_numpy()
@@ -612,7 +601,11 @@ def material_site(material_store: NodeStore):
     table = None
 
     arrays = flatten_material_sites.flatten()
-    names = flatten_material_sites.type.names
+
+    if hasattr(flatten_material_sites.type, "names"):
+        names = flatten_material_sites.type.names
+    else:
+        names = [val.name for val in flatten_material_sites.type]
 
     flatten_material_sites = None
     material_sites_length_list = None
@@ -635,5 +628,5 @@ def material_site(material_store: NodeStore):
 
     except Exception as e:
         logger.error(f"Error creating site nodes: {e}")
-        return None
+        raise e
 
     return table
diff --git a/matgraphdb/core/utils.py b/matgraphdb/core/utils.py
deleted file mode 100644
index acda5dc..0000000
--- a/matgraphdb/core/utils.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import pandas as pd
-import pyarrow as pa
-
-ALLOWED_DATAFRAME_TYPES = [pd.DataFrame, pa.Table, pa.RecordBatch]
-
-
-def get_dataframe_column_names(df):
-    if isinstance(df, pd.DataFrame):
-        column_names = df.columns
-    elif isinstance(df, pa.Table):
-        column_names = df.schema.names
-    elif isinstance(df, pa.RecordBatch):
-        column_names = df.schema.names
-    else:
-        raise ValueError(
-            f"Invalid data type for dataframe validation. 
Must be one of: {ALLOWED_DATAFRAME_TYPES}" - ) - return column_names diff --git a/matgraphdb/materials/__init__.py b/matgraphdb/materials/__init__.py deleted file mode 100644 index c5cd27b..0000000 --- a/matgraphdb/materials/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from matgraphdb.materials.core import MatGraphDB -from matgraphdb.materials.nodes import MaterialStore diff --git a/matgraphdb/materials/datasets/__init__.py b/matgraphdb/materials/datasets/__init__.py deleted file mode 100644 index 71fc383..0000000 --- a/matgraphdb/materials/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull diff --git a/matgraphdb/materials/edges.py b/matgraphdb/materials/edges.py deleted file mode 100644 index 43cc1e5..0000000 --- a/matgraphdb/materials/edges.py +++ /dev/null @@ -1,719 +0,0 @@ -import logging -import shutil - -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc -from parquetdb import ParquetDB -from parquetdb.utils import pyarrow_utils - -from matgraphdb.core.edges import EdgeStore, edge_generator -from matgraphdb.core.nodes import NodeStore - -# from matgraphdb.materials.nodes import * -from matgraphdb.utils.chem_utils.periodic import get_group_period_edge_index - -logger = logging.getLogger(__name__) - - -@edge_generator -def element_element_neighborsByGroupPeriod(element_store): - - try: - connection_name = "neighborsByGroupPeriod" - table = element_store.read_nodes( - columns=["atomic_number", "extended_group", "period", "symbol"] - ) - element_df = table.to_pandas(split_blocks=True, self_destruct=True) - - # Getting group-period edge index - edge_index = get_group_period_edge_index(element_df) - - # Creating the relationships dataframe - df = pd.DataFrame(edge_index, columns=[f"source_id", f"target_id"]) - - # Dropping rows with NaN values and casting to int64 - df = df.dropna().astype(np.int64) - - # Add source and target type columns - df["source_type"] = element_store.node_type - df["target_type"] = element_store.node_type - df["edge_type"] = connection_name - df["weight"] = 1.0 - - table = ParquetDB.construct_table(df) - - reduced_table = element_store.read( - columns=["symbol", "id", "extended_group", "period"] - ) - reduced_source_table = reduced_table.rename_columns( - { - "symbol": "source_name", - "extended_group": "source_extended_group", - "period": "source_period", - } - ) - reduced_target_table = reduced_table.rename_columns( - { - "symbol": "target_name", - "extended_group": "target_extended_group", - "period": "target_period", - } - ) - - table = pyarrow_utils.join_tables( - table, - reduced_source_table, - left_keys=["source_id"], - right_keys=["id"], - join_type="left outer", - ) - - table = pyarrow_utils.join_tables( - table, - reduced_target_table, - left_keys=["target_id"], - right_keys=["id"], - join_type="left outer", - ) - - names = pc.binary_join_element_wise( - pc.cast(table["source_name"], pa.string()), - pc.cast(table["target_name"], pa.string()), - f"_{connection_name}_", - ) - - table = table.append_column("name", names) - - logger.debug( - f"Created element-group-period relationships. 
Shape: {table.shape}" - ) - except Exception as e: - logger.exception(f"Error creating element-group-period relationships: {e}") - raise e - - return table - - -@edge_generator -def element_element_bonds(element_store, material_store): - try: - connection_name = "canBondTo" - material_table = material_store.read_nodes( - columns=[ - "id", - "core.material_id", - "core.species", - "chemenv.coordination_environments_multi_weight", - "bonding.geometric_consistent.bond_connections", - ] - ) - - element_table = element_store.read_nodes(columns=["id", "symbol"]) - - element_table = element_table.rename_columns({"symbol": "name"}) - element_table = element_table.append_column( - "source_type", pa.array([element_store.node_type] * element_table.num_rows) - ) - - material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) - element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) - - element_target_id_map = { - row["name"]: row["id"] for _, row in element_df.iterrows() - } - - table_dict = { - "source_id": [], - "source_type": [], - "target_id": [], - "target_type": [], - "edge_type": [], - "name": [], - } - - for _, row in material_df.iterrows(): - bond_connections = row["bonding.geometric_consistent.bond_connections"] - - if bond_connections is None: - continue - - elements = row["core.species"] - element_graph = {} - for i, site_connections in enumerate(bond_connections): - site_element_name = elements[i] - for i_neighbor_element in site_connections: - i_neighbor_element = int(i_neighbor_element) - neighbor_element_name = elements[i_neighbor_element] - - source_id = element_target_id_map[site_element_name] - target_id = element_target_id_map[neighbor_element_name] - - table_dict["source_id"].append(source_id) - table_dict["source_type"].append(element_store.node_type) - table_dict["target_id"].append(target_id) - table_dict["target_type"].append(element_store.node_type) - table_dict["edge_type"].append(connection_name) - - name = ( - f"{site_element_name}_{connection_name}_{neighbor_element_name}" - ) - table_dict["name"].append(name) - - edge_table = ParquetDB.construct_table(table_dict) - - logger.debug( - f"Created element-chemenv-canOccur relationships. 
Shape: {edge_table.shape}" - ) - - except Exception as e: - logger.exception(f"Error creating element-chemenv-canOccur relationships: {e}") - raise e - - return edge_table - - -@edge_generator -def element_oxiState_canOccur(element_store, oxiState_store): - try: - connection_name = "canOccur" - - element_table = element_store.read_nodes( - columns=["id", "experimental_oxidation_states", "symbol"] - ) - oxiState_table = oxiState_store.read_nodes( - columns=["id", "oxidation_state", "value"] - ) - - # element_table=element_table.rename_columns({'id':'source_id'}) - element_table = element_table.append_column( - "source_type", pa.array([element_store.node_type] * element_table.num_rows) - ) - - # oxiState_table=oxiState_table.rename_columns({'id':'target_id'}) - oxiState_table = oxiState_table.append_column( - "target_type", - pa.array([oxiState_store.node_type] * oxiState_table.num_rows), - ) - - element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) - oxiState_df = oxiState_table.to_pandas(split_blocks=True, self_destruct=True) - table_dict = { - "source_id": [], - "source_type": [], - "target_id": [], - "target_type": [], - "edge_type": [], - "name": [], - "weight": [], - } - - oxiState_id_map = {} - id_oxidationState_map = {} - for i, oxiState_row in oxiState_df.iterrows(): - oxiState_id_map[oxiState_row["value"]] = oxiState_row["id"] - id_oxidationState_map[oxiState_row["id"]] = oxiState_row["oxidation_state"] - - for i, element_row in element_df.iterrows(): - exp_oxidation_states = element_row["experimental_oxidation_states"] - source_id = element_row["id"] - source_type = element_store.node_type - symbol = element_row["symbol"] - for exp_oxidation_state in exp_oxidation_states: - target_id = oxiState_id_map[exp_oxidation_state] - target_type = oxiState_store.node_type - oxi_state_name = id_oxidationState_map[target_id] - - table_dict["source_id"].append(source_id) - table_dict["source_type"].append(source_type) - table_dict["target_id"].append(target_id) - table_dict["target_type"].append(target_type) - table_dict["edge_type"].append(connection_name) - table_dict["weight"].append(1.0) - table_dict["name"].append( - f"{symbol}_{connection_name}_{oxi_state_name}" - ) - - edge_table = ParquetDB.construct_table(table_dict) - - logger.debug( - f"Created element-oxiState-canOccur relationships. 
Shape: {edge_table.shape}" - ) - except Exception as e: - logger.exception(f"Error creating element-oxiState-canOccur relationships: {e}") - raise e - - return edge_table - - -@edge_generator -def material_chemenv_containsSite(material_store, chemenv_store): - try: - connection_name = "containsSite" - - material_table = material_store.read_nodes( - columns=[ - "id", - "core.material_id", - "chemenv.coordination_environments_multi_weight", - ] - ) - chemenv_table = chemenv_store.read_nodes(columns=["id", "mp_symbol"]) - - material_table = material_table.rename_columns( - {"id": "source_id", "core.material_id": "material_name"} - ) - material_table = material_table.append_column( - "source_type", - pa.array([material_store.node_type] * material_table.num_rows), - ) - - chemenv_table = chemenv_table.rename_columns( - {"id": "target_id", "mp_symbol": "chemenv_name"} - ) - chemenv_table = chemenv_table.append_column( - "target_type", pa.array([chemenv_store.node_type] * chemenv_table.num_rows) - ) - - material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) - chemenv_df = chemenv_table.to_pandas(split_blocks=True, self_destruct=True) - chemenv_target_id_map = { - row["chemenv_name"]: row["target_id"] for _, row in chemenv_df.iterrows() - } - - table_dict = { - "source_id": [], - "source_type": [], - "target_id": [], - "target_type": [], - "edge_type": [], - "name": [], - "weight": [], - } - - for _, row in material_df.iterrows(): - coord_envs = row["chemenv.coordination_environments_multi_weight"] - if coord_envs is None: - continue - - source_id = row["source_id"] - material_name = row["material_name"] - - for coord_env in coord_envs: - try: - chemenv_name = coord_env[0]["ce_symbol"] - target_id = chemenv_target_id_map[chemenv_name] - except: - continue - - table_dict["source_id"].append(source_id) - table_dict["source_type"].append(material_store.node_type) - table_dict["target_id"].append(target_id) - table_dict["target_type"].append(chemenv_store.node_type) - table_dict["edge_type"].append(connection_name) - - name = f"{material_name}_{connection_name}_{chemenv_name}" - table_dict["name"].append(name) - table_dict["weight"].append(1.0) - - edge_table = ParquetDB.construct_table(table_dict) - - logger.debug( - f"Created material-chemenv-containsSite relationships. 
Shape: {edge_table.shape}" - ) - except Exception as e: - logger.exception( - f"Error creating material-chemenv-containsSite relationships: {e}" - ) - raise e - - return edge_table - - -@edge_generator -def material_crystalSystem_has(material_store, crystal_system_store): - try: - connection_name = "has" - - material_table = material_store.read_nodes( - columns=["id", "core.material_id", "symmetry.crystal_system"] - ) - crystal_system_table = crystal_system_store.read_nodes( - columns=["id", "crystal_system"] - ) - - material_table = material_table.rename_columns( - {"id": "source_id", "symmetry.crystal_system": "crystal_system"} - ) - material_table = material_table.append_column( - "source_type", - pa.array([material_store.node_type] * material_table.num_rows), - ) - - crystal_system_table = crystal_system_table.rename_columns({"id": "target_id"}) - crystal_system_table = crystal_system_table.append_column( - "target_type", - pa.array([crystal_system_store.node_type] * crystal_system_table.num_rows), - ) - - edge_table = pyarrow_utils.join_tables( - material_table, - crystal_system_table, - left_keys=["crystal_system"], - right_keys=["crystal_system"], - join_type="left outer", - ) - edge_table = edge_table.append_column( - "edge_type", pa.array([connection_name] * edge_table.num_rows) - ) - edge_table = edge_table.append_column( - "weight", pa.array([1.0] * edge_table.num_rows) - ) - - names = pc.binary_join_element_wise( - pc.cast(edge_table["core.material_id"], pa.string()), - pc.cast(edge_table["crystal_system"], pa.string()), - f"_{connection_name}_", - ) - - edge_table = edge_table.append_column("name", names) - - logger.debug( - f"Created material-crystalSystem-has relationships. Shape: {edge_table.shape}" - ) - except Exception as e: - logger.exception( - f"Error creating material-crystalSystem-has relationships: {e}" - ) - raise e - - return edge_table - - -@edge_generator -def material_element_has(material_store, element_store): - try: - connection_name = "has" - - material_table = material_store.read_nodes( - columns=["id", "core.material_id", "core.elements"] - ) - element_table = element_store.read_nodes(columns=["id", "symbol"]) - - material_table = material_table.rename_columns( - {"id": "source_id", "core.material_id": "material_name"} - ) - material_table = material_table.append_column( - "source_type", pa.array(["material"] * material_table.num_rows) - ) - - element_table = element_table.rename_columns({"id": "target_id"}) - element_table = element_table.append_column( - "target_type", pa.array(["elements"] * element_table.num_rows) - ) - - material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) - element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) - element_target_id_map = { - row["symbol"]: row["target_id"] for _, row in element_df.iterrows() - } - - table_dict = { - "source_id": [], - "source_type": [], - "target_id": [], - "target_type": [], - "edge_type": [], - "name": [], - "weight": [], - } - - for _, row in material_df.iterrows(): - elements = row["core.elements"] - source_id = row["source_id"] - material_name = row["material_name"] - if elements is None: - continue - - # Append the material name for each element in the species list - for element in elements: - - target_id = element_target_id_map[element] - table_dict["source_id"].append(source_id) - table_dict["source_type"].append(material_store.node_type) - table_dict["target_id"].append(target_id) - table_dict["target_type"].append(element_store.node_type) - 
table_dict["edge_type"].append(connection_name) - - name = f"{material_name}_{connection_name}_{element}" - table_dict["name"].append(name) - table_dict["weight"].append(1.0) - - edge_table = ParquetDB.construct_table(table_dict) - - logger.debug( - f"Created material-element-has relationships. Shape: {edge_table.shape}" - ) - except Exception as e: - logger.exception(f"Error creating material-element-has relationships: {e}") - raise e - - return edge_table - - -@edge_generator -def material_lattice_has(material_store, lattice_store): - try: - connection_name = "has" - - material_table = material_store.read_nodes(columns=["id", "core.material_id"]) - lattice_table = lattice_store.read_nodes(columns=["material_node_id"]) - - material_table = material_table.rename_columns( - {"id": "source_id", "core.material_id": "material_id"} - ) - material_table = material_table.append_column( - "source_type", - pa.array([material_store.node_type] * material_table.num_rows), - ) - - lattice_table = lattice_table.append_column( - "target_id", lattice_table["material_node_id"].combine_chunks() - ) - lattice_table = lattice_table.append_column( - "target_type", pa.array([lattice_store.node_type] * lattice_table.num_rows) - ) - - edge_table = pyarrow_utils.join_tables( - material_table, - lattice_table, - left_keys=["source_id"], - right_keys=["material_node_id"], - join_type="left outer", - ) - edge_table = edge_table.append_column( - "edge_type", pa.array([connection_name] * edge_table.num_rows) - ) - edge_table = edge_table.append_column( - "weight", pa.array([1.0] * edge_table.num_rows) - ) - - logger.debug( - f"Created material-lattice-has relationships. Shape: {edge_table.shape}" - ) - except Exception as e: - logger.exception(f"Error creating material-lattice-has relationships: {e}") - raise e - - return edge_table - - -@edge_generator -def material_spg_has(material_store, spg_store): - try: - connection_name = "has" - - material_table = material_store.read_nodes( - columns=["id", "core.material_id", "symmetry.number"] - ) - spg_table = spg_store.read_nodes(columns=["id", "spg"]) - - material_table = material_table.rename_columns( - {"id": "source_id", "symmetry.number": "spg"} - ) - material_table = material_table.append_column( - "source_type", - pa.array([material_store.node_type] * material_table.num_rows), - ) - - spg_table = spg_table.rename_columns({"id": "target_id"}) - spg_table = spg_table.append_column( - "target_type", pa.array([spg_store.node_type] * spg_table.num_rows) - ) - - edge_table = pyarrow_utils.join_tables( - material_table, - spg_table, - left_keys=["spg"], - right_keys=["spg"], - join_type="left outer", - ) - - edge_table = edge_table.append_column( - "edge_type", pa.array([connection_name] * edge_table.num_rows) - ) - - edge_table = edge_table.append_column( - "weight", pa.array([1.0] * edge_table.num_rows) - ) - - names = pc.binary_join_element_wise( - pc.cast(edge_table["core.material_id"], pa.string()), - pc.cast(edge_table["spg"], pa.string()), - f"_{connection_name}_SpaceGroup", - ) - - edge_table = edge_table.append_column("name", names) - - logger.debug( - f"Created material-spg-has relationships. 
Shape: {edge_table.shape}" - ) - except Exception as e: - logger.exception(f"Error creating material-spg-has relationships: {e}") - raise e - - return edge_table - - -@edge_generator -def element_chemenv_canOccur(element_store, chemenv_store, material_store): - try: - connection_name = "canOccur" - material_table = material_store.read_nodes( - columns=[ - "id", - "core.material_id", - "core.elements", - "chemenv.coordination_environments_multi_weight", - ] - ) - - chemenv_table = chemenv_store.read_nodes(columns=["id", "mp_symbol"]) - element_table = element_store.read_nodes(columns=["id", "symbol"]) - - chemenv_table = chemenv_table.rename_columns({"mp_symbol": "name"}) - chemenv_table = chemenv_table.append_column( - "target_type", pa.array([chemenv_store.node_type] * chemenv_table.num_rows) - ) - - element_table = element_table.rename_columns({"symbol": "name"}) - element_table = element_table.append_column( - "source_type", pa.array([element_store.node_type] * element_table.num_rows) - ) - - material_df = material_table.to_pandas(split_blocks=True, self_destruct=True) - chemenv_df = chemenv_table.to_pandas(split_blocks=True, self_destruct=True) - element_df = element_table.to_pandas(split_blocks=True, self_destruct=True) - - chemenv_target_id_map = { - row["name"]: row["id"] for _, row in chemenv_df.iterrows() - } - element_target_id_map = { - row["name"]: row["id"] for _, row in element_df.iterrows() - } - - table_dict = { - "source_id": [], - "source_type": [], - "target_id": [], - "target_type": [], - "edge_type": [], - "name": [], - } - - for _, row in material_df.iterrows(): - coord_envs = row["chemenv.coordination_environments_multi_weight"] - - if coord_envs is None: - continue - - elements = row["core.elements"] - - for i, coord_env in enumerate(coord_envs): - try: - chemenv_name = coord_env[0]["ce_symbol"] - element_name = elements[i] - - source_id = element_target_id_map[element_name] - target_id = chemenv_target_id_map[chemenv_name] - except: - continue - - table_dict["source_id"].append(source_id) - table_dict["source_type"].append(element_store.node_type) - table_dict["target_id"].append(target_id) - table_dict["target_type"].append(chemenv_store.node_type) - table_dict["edge_type"].append(connection_name) - - name = f"{element_name}_{connection_name}_{chemenv_name}" - table_dict["name"].append(name) - - edge_table = ParquetDB.construct_table(table_dict) - - logger.debug( - f"Created element-chemenv-canOccur relationships. 
Shape: {edge_table.shape}" - ) - - except Exception as e: - logger.exception(f"Error creating element-chemenv-canOccur relationships: {e}") - raise e - - return edge_table - - -@edge_generator -def spg_crystalSystem_isApart(spg_store, crystal_system_store): - try: - connection_name = "isApart" - - except Exception as e: - logger.exception(f"Error creating spg-crystalSystem-isApart relationships: {e}") - raise e - - spg_table = spg_store.read_nodes(columns=["id", "spg"]) - crystal_system_table = crystal_system_store.read_nodes( - columns=["id", "crystal_system"] - ) - - spg_df = spg_table.to_pandas(split_blocks=True, self_destruct=True) - crystal_system_df = crystal_system_table.to_pandas( - split_blocks=True, self_destruct=True - ) - - spg_target_id_map = {row["spg"]: row["id"] for _, row in spg_df.iterrows()} - crystal_system_target_id_map = { - row["crystal_system"]: row["id"] for _, row in crystal_system_df.iterrows() - } - - crys_spg_map = { - "Triclinic": np.arange(1, 3), - "Monoclinic": np.arange(3, 16), - "Orthorhombic": np.arange(16, 75), - "Tetragonal": np.arange(75, 143), - "Trigonal": np.arange(143, 168), - "Hexagonal": np.arange(168, 195), - "Cubic": np.arange(195, 231), - } - table_dict = { - "source_id": [], - "source_type": [], - "target_id": [], - "target_type": [], - "edge_type": [], - "name": [], - } - try: - for crystal_system, spg_range in crys_spg_map.items(): - for spg in spg_range: - source_id = spg_target_id_map[spg] - target_id = crystal_system_target_id_map[crystal_system] - - table_dict["source_id"].append(source_id) - table_dict["source_type"].append(spg_store.node_type) - table_dict["target_id"].append(target_id) - table_dict["target_type"].append(crystal_system_store.node_type) - table_dict["edge_type"].append(connection_name) - table_dict["name"].append(f"{crystal_system}_{connection_name}_{spg}") - - edge_table = ParquetDB.construct_table(table_dict) - - logger.debug( - f"Created spg-crystalSystem-isApart relationships. 
Shape: {edge_table.shape}" - ) - - except Exception as e: - logger.exception(f"Error creating element-chemenv-canOccur relationships: {e}") - raise e - - return edge_table diff --git a/matgraphdb/pyg/builders/__init__.py b/matgraphdb/pyg/builders/__init__.py new file mode 100644 index 0000000..2ea0a5b --- /dev/null +++ b/matgraphdb/pyg/builders/__init__.py @@ -0,0 +1,2 @@ +from matgraphdb.pyg.builders.crystal_graph import CrystalGraphBuilder +from matgraphdb.pyg.builders.hetero_graph import HeteroGraphBuilder diff --git a/matgraphdb/pyg/data/crystal_graph.py b/matgraphdb/pyg/builders/crystal_graph.py similarity index 99% rename from matgraphdb/pyg/data/crystal_graph.py rename to matgraphdb/pyg/builders/crystal_graph.py index bcc50f9..bb13420 100644 --- a/matgraphdb/pyg/data/crystal_graph.py +++ b/matgraphdb/pyg/builders/crystal_graph.py @@ -4,7 +4,7 @@ import torch from torch_geometric.data import Data -from matgraphdb.materials import MatGraphDB +from matgraphdb import MatGraphDB logger = logging.getLogger(__name__) diff --git a/matgraphdb/pyg/data/hetero_graph.py b/matgraphdb/pyg/builders/hetero_graph.py similarity index 96% rename from matgraphdb/pyg/data/hetero_graph.py rename to matgraphdb/pyg/builders/hetero_graph.py index f776a03..bfd42b5 100644 --- a/matgraphdb/pyg/data/hetero_graph.py +++ b/matgraphdb/pyg/builders/hetero_graph.py @@ -5,11 +5,10 @@ import pyarrow as pa import pyarrow.compute as pc import torch +from parquetdb import ParquetGraphDB from parquetdb.utils import pyarrow_utils from torch_geometric.data import HeteroData -from matgraphdb.core.graph_db import GraphDB - logger = logging.getLogger(__name__) @@ -21,7 +20,7 @@ class HeteroGraphBuilder: into PyTorch Geometric HeteroData objects for machine learning. """ - def __init__(self, graph_db: GraphDB): + def __init__(self, graph_db: ParquetGraphDB): """ Initialize the graph builder. 
@@ -78,7 +77,7 @@ def add_node_type( node_type: str, columns: Optional[List[str]] = None, filters: Optional[Dict] = None, - embedding_vectors:bool=False, + embedding_vectors: bool = False, label_column: Optional[str] = None, drop_null: bool = True, encoders: Optional[Dict] = None, @@ -101,9 +100,13 @@ def add_node_type( logger.info(f"Adding {node_type} nodes to graph") ids, torch_tensor, feature_names, labels = self._process_node_type( - node_type=node_type, columns=columns, filters=filters, - encoders=encoders, label_column=label_column, - read_kwargs=read_kwargs, drop_null=drop_null + node_type=node_type, + columns=columns, + filters=filters, + encoders=encoders, + label_column=label_column, + read_kwargs=read_kwargs, + drop_null=drop_null, ) logger.info(f"ids: {ids.shape}") @@ -111,7 +114,7 @@ def add_node_type( self.hetero_data[node_type].node_ids = torch.tensor(ids, dtype=torch.int64) if labels is not None: self.hetero_data[node_type].labels = labels - + if torch_tensor is not None: logger.info(f"torch_tensor: {torch_tensor.shape}") logger.info(f"feature_names: {feature_names}") @@ -119,7 +122,7 @@ def add_node_type( self.hetero_data[node_type].x = torch_tensor logger.info(f"x: {self.hetero_data[node_type].x.shape}") self.hetero_data[node_type].feature_names = feature_names - + if embedding_vectors: num_nodes = len(self.hetero_data[node_type].node_ids) self.hetero_data[node_type].x = torch.eye(num_nodes) @@ -140,12 +143,12 @@ def add_target_node_property( raise ValueError(f"Node type {node_type} has not been added to the graph") ids, torch_tensor, feature_names, labels = self._process_node_type( - node_type=node_type, - columns=columns, - filters=filters, + node_type=node_type, + columns=columns, + filters=filters, encoders=encoders, label_column=label_column, - read_kwargs=read_kwargs, + read_kwargs=read_kwargs, drop_null=drop_null, ) @@ -154,7 +157,7 @@ def add_target_node_property( # logger.info(f"feature_names: {feature_names}") target_feature_ids = torch.tensor(ids, dtype=torch.int64) - + all_feature_ids = self.hetero_data[node_type].node_ids.clone().detach() # all_feature_ids = torch.tensor( # self.hetero_data[node_type].node_id, dtype=torch.int64 @@ -225,7 +228,7 @@ def _process_node_type( labels = table[label_column].combine_chunks().to_pylist() else: labels = None - + torch_tensor = None feature_names = None if columns: @@ -487,8 +490,8 @@ def save(self, path: str): torch.save(self.hetero_data, path) @classmethod - def load(cls, graph_db: GraphDB, path: str): + def load(cls, graph_db: ParquetGraphDB, path: str): """Load a saved graph.""" builder = cls(graph_db) - builder.hetero_data = torch.load(path) + builder.hetero_data = torch.load(path, weights_only=False) return builder diff --git a/matgraphdb/pyg/data/__init__.py b/matgraphdb/pyg/data/__init__.py deleted file mode 100644 index 8fd75bb..0000000 --- a/matgraphdb/pyg/data/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from matgraphdb.pyg.data.crystal_graph import CrystalGraphBuilder -from matgraphdb.pyg.data.hetero_graph import HeteroGraphBuilder diff --git a/matgraphdb/pyg/models/cg_bond_order/train.py b/matgraphdb/pyg/models/cg_bond_order/train.py index ce1a637..c96f3c3 100644 --- a/matgraphdb/pyg/models/cg_bond_order/train.py +++ b/matgraphdb/pyg/models/cg_bond_order/train.py @@ -14,8 +14,8 @@ from torch_geometric.loader import DataLoader from torch_geometric.nn import CGConv, SAGEConv, global_mean_pool -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data.crystal_graph import 
CrystalGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders.crystal_graph import CrystalGraphBuilder from matgraphdb.pyg.models.cg_bond_order.model import BondOrderPredictor os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" diff --git a/matgraphdb/pyg/models/cg_target/run.py b/matgraphdb/pyg/models/cg_target/run.py index cd9c2e1..cc29e38 100644 --- a/matgraphdb/pyg/models/cg_target/run.py +++ b/matgraphdb/pyg/models/cg_target/run.py @@ -5,10 +5,10 @@ from torch_geometric.nn import CGConv, global_mean_pool from matgraphdb import config -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import CrystalGraphBuilder from matgraphdb.pyg.core import BaseTrainer from matgraphdb.pyg.core.experiment import run_experiment -from matgraphdb.pyg.data import CrystalGraphBuilder from matgraphdb.pyg.models.cg_target.model import CGConvModel print(torch.__version__) diff --git a/matgraphdb/pyg/models/cg_target/train.py b/matgraphdb/pyg/models/cg_target/train.py index 9dbc699..065c910 100644 --- a/matgraphdb/pyg/models/cg_target/train.py +++ b/matgraphdb/pyg/models/cg_target/train.py @@ -14,8 +14,8 @@ from torch_geometric.loader import DataLoader from torch_geometric.nn import CGConv, SAGEConv, global_mean_pool -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data.crystal_graph import CrystalGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders.crystal_graph import CrystalGraphBuilder from matgraphdb.pyg.models.cg_target.model import CGConvModel os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" diff --git a/matgraphdb/pyg/models/cgae/train.py b/matgraphdb/pyg/models/cgae/train.py index 6835007..a8719b3 100644 --- a/matgraphdb/pyg/models/cgae/train.py +++ b/matgraphdb/pyg/models/cgae/train.py @@ -21,8 +21,8 @@ unbatch_edge_index, ) -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data.crystal_graph import CrystalGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders.crystal_graph import CrystalGraphBuilder from .model import CGAE diff --git a/matgraphdb/pyg/models/grami/train.py b/matgraphdb/pyg/models/grami/train.py index 84ecba1..377b765 100644 --- a/matgraphdb/pyg/models/grami/train.py +++ b/matgraphdb/pyg/models/grami/train.py @@ -14,8 +14,8 @@ from omegaconf import OmegaConf from torch_geometric import nn as pyg_nn -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.grami.model import GraMI from matgraphdb.pyg.models.grami.trainer import ( Trainer, @@ -45,13 +45,15 @@ }, }, "model": { - "hidden_channels": 128, + "hidden_channels": 128, "out_channels": 1, "num_heads": 8, "num_layers": 3, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder" + ), "learning_rate": 0.001, "num_epochs": 40001, "eval_interval": 2000, @@ -64,7 +66,7 @@ "mlflow_experiment_name": "heterograph_encoder", "mlflow_tracking_uri": "${training.training_dir}/mlflow", "mlflow_record_system_metrics": True, - } + }, } ) @@ -233,7 +235,6 @@ def to_log(x): 
test_val_data = original_test_val_data - print(train_data) print(train_val_data) print(test_data) @@ -260,11 +261,13 @@ def to_log(x): #################################################################################################### # Model #################################################################################################### -model = HGT(hidden_channels=CONFIG.model.hidden_channels, - out_channels=CONFIG.model.out_channels, - num_heads=CONFIG.model.num_heads, - num_layers=CONFIG.model.num_layers, - data=train_data).to(device) +model = HGT( + hidden_channels=CONFIG.model.hidden_channels, + out_channels=CONFIG.model.out_channels, + num_heads=CONFIG.model.num_heads, + num_layers=CONFIG.model.num_layers, + data=train_data, +).to(device) print(model) @@ -283,8 +286,8 @@ def train(): model.train() optimizer.zero_grad() out = model(data.x_dict, data.edge_index_dict) - mask = data['author'].train_mask - loss = F.cross_entropy(out[mask], data['author'].y[mask]) + mask = data["author"].train_mask + loss = F.cross_entropy(out[mask], data["author"].y[mask]) loss.backward() optimizer.step() return float(loss) @@ -296,9 +299,9 @@ def test(): pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) accs = [] - for split in ['train_mask', 'val_mask', 'test_mask']: - mask = data['author'][split] - acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum() + for split in ["train_mask", "val_mask", "test_mask"]: + mask = data["author"][split] + acc = (pred[mask] == data["author"].y[mask]).sum() / mask.sum() accs.append(float(acc)) return accs @@ -306,5 +309,7 @@ def test(): for epoch in range(1, 101): loss = train() train_acc, val_acc, test_acc = test() - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' - f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') \ No newline at end of file + print( + f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, " + f"Val: {val_acc:.4f}, Test: {test_acc:.4f}" + ) diff --git a/matgraphdb/pyg/models/han/train.py b/matgraphdb/pyg/models/han/train.py index 4e64750..37d2bc4 100644 --- a/matgraphdb/pyg/models/han/train.py +++ b/matgraphdb/pyg/models/han/train.py @@ -13,9 +13,10 @@ import torch_geometric.transforms as T from omegaconf import OmegaConf from torch_geometric import nn as pyg_nn +from torch_geometric.nn import MetaPath2Vec -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.han.model import HAN from matgraphdb.pyg.models.han.trainer import ( Trainer, @@ -24,9 +25,6 @@ roc_curve, ) -from torch_geometric.nn import MetaPath2Vec - - ######################################################################################################################## @@ -58,7 +56,9 @@ "use_shallow_embedding_for_materials": False, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder" + ), "learning_rate": 0.001, "num_epochs": 40001, "eval_interval": 2000, @@ -71,7 +71,7 @@ "mlflow_experiment_name": "heterograph_encoder", "mlflow_tracking_uri": "${training.training_dir}/mlflow", "mlflow_record_system_metrics": True, - } + }, } ) @@ -179,15 +179,19 @@ def to_log(x): data = None - metapaths = [ - [('materials', 'has', 'elements'), ('elements', 'rev_has', 'materials')], - [('materials', 'has', 'elements'), 
('elements', 'neighborsByGroupPeriod', 'elements'), ('elements', 'rev_has', 'materials')], + [("materials", "has", "elements"), ("elements", "rev_has", "materials")], + [ + ("materials", "has", "elements"), + ("elements", "neighborsByGroupPeriod", "elements"), + ("elements", "rev_has", "materials"), + ], # [('elements', 'neighborsByGroupPeriod', 'elements'), ('elements', 'neighborsByGroupPeriod', 'elements')], # [('elements', 'rev_has', 'materials'), ('materials', 'has', 'elements')], - ] -transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True, - drop_unconnected_node_types=True) +] +transform = T.AddMetaPaths( + metapaths=metapaths, drop_orig_edge_types=True, drop_unconnected_node_types=True +) parent_data = transform(parent_data) @@ -263,7 +267,6 @@ def to_log(x): test_val_data = original_test_val_data - print(train_data) print(train_val_data) print(test_data) @@ -290,7 +293,14 @@ def to_log(x): # #################################################################################################### # # Model # #################################################################################################### -model = HAN(in_channels=-1, out_channels=16, hidden_channels=128, heads=8, out_node_name='materials', data=parent_data).to(device) +model = HAN( + in_channels=-1, + out_channels=16, + hidden_channels=128, + heads=8, + out_node_name="materials", + data=parent_data, +).to(device) print(model) @@ -298,4 +308,3 @@ def to_log(x): # #################################################################################################### # # Training # #################################################################################################### - diff --git a/matgraphdb/pyg/models/hetero_crystal_cl/data.py b/matgraphdb/pyg/models/hetero_crystal_cl/data.py index 8ecc9cf..95db572 100644 --- a/matgraphdb/pyg/models/hetero_crystal_cl/data.py +++ b/matgraphdb/pyg/models/hetero_crystal_cl/data.py @@ -6,8 +6,8 @@ from torch import optim from torch_geometric import nn as pyg_nn -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.utils.config import DATA_DIR device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/matgraphdb/pyg/models/hetero_encoder/train.py b/matgraphdb/pyg/models/hetero_encoder/train.py index 6ec80eb..2c4d572 100644 --- a/matgraphdb/pyg/models/hetero_encoder/train.py +++ b/matgraphdb/pyg/models/hetero_encoder/train.py @@ -14,8 +14,8 @@ from omegaconf import OmegaConf from torch_geometric import nn as pyg_nn -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.hetero_encoder.model import HeteroEncoder from matgraphdb.pyg.models.hetero_encoder.trainer import ( Trainer, @@ -55,7 +55,9 @@ "use_shallow_embedding_for_materials": False, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder" + ), "learning_rate": 0.001, "num_epochs": 40001, "eval_interval": 2000, @@ -68,7 +70,7 @@ "mlflow_experiment_name": "heterograph_encoder", "mlflow_tracking_uri": "${training.training_dir}/mlflow", "mlflow_record_system_metrics": 
True, - } + }, } ) @@ -246,10 +248,10 @@ def to_log(x): # print(type(random_link_split_args)) # for i,edge_type in enumerate(random_link_split_args['edge_types']): # random_link_split_args['edge_types'][i] = tuple(edge_type) - + # for i,edge_type in enumerate(random_link_split_args['rev_edge_types']): # random_link_split_args['rev_edge_types'][i] = tuple(edge_type) - + # print(random_link_split_args) # # Perform a link-level split into training, validation, and test edges: # train_data, _, _ = T.RandomLinkSplit(**random_link_split_args)(original_train_data) @@ -267,7 +269,6 @@ def to_log(x): test_val_data = original_test_val_data - print(train_data) print(train_val_data) print(test_data) @@ -360,7 +361,10 @@ def to_log(x): if CONFIG.training.use_scheduler: scheduler = lr_scheduler.MultiStepLR( - optimizer, milestones=CONFIG.training.scheduler_milestones, gamma=0.1, verbose=False + optimizer, + milestones=CONFIG.training.scheduler_milestones, + gamma=0.1, + verbose=False, ) else: scheduler = None diff --git a/matgraphdb/pyg/models/heterograph_encoder/train.py b/matgraphdb/pyg/models/heterograph_encoder/train.py index 0f593fb..456a9be 100644 --- a/matgraphdb/pyg/models/heterograph_encoder/train.py +++ b/matgraphdb/pyg/models/heterograph_encoder/train.py @@ -14,8 +14,8 @@ from omegaconf import OmegaConf from torch_geometric import nn as pyg_nn -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.heterograph_encoder.model import MaterialEdgePredictor from matgraphdb.pyg.models.heterograph_encoder.trainer import ( Trainer, @@ -55,7 +55,9 @@ "use_shallow_embedding_for_materials": False, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder" + ), "learning_rate": 0.001, "num_epochs": 10001, "eval_interval": 1000, @@ -67,7 +69,7 @@ "mlflow_experiment_name": "heterograph_encoder", "mlflow_tracking_uri": "${training.training_dir}/mlflow", "mlflow_record_system_metrics": True, - } + }, } ) @@ -241,14 +243,16 @@ def to_log(x): print(CONFIG.data.random_link_split_args) -random_link_split_args = OmegaConf.to_container(CONFIG.data.random_link_split_args, resolve=True) +random_link_split_args = OmegaConf.to_container( + CONFIG.data.random_link_split_args, resolve=True +) print(type(random_link_split_args)) -for i,edge_type in enumerate(random_link_split_args['edge_types']): - random_link_split_args['edge_types'][i] = tuple(edge_type) - -for i,edge_type in enumerate(random_link_split_args['rev_edge_types']): - random_link_split_args['rev_edge_types'][i] = tuple(edge_type) - +for i, edge_type in enumerate(random_link_split_args["edge_types"]): + random_link_split_args["edge_types"][i] = tuple(edge_type) + +for i, edge_type in enumerate(random_link_split_args["rev_edge_types"]): + random_link_split_args["rev_edge_types"][i] = tuple(edge_type) + print(random_link_split_args) # Perform a link-level split into training, validation, and test edges: train_data, _, _ = T.RandomLinkSplit(**random_link_split_args)(original_train_data) @@ -261,7 +265,6 @@ def to_log(x): ) - print(train_data) print(train_val_data) print(test_data) @@ -399,12 +402,14 @@ def weighted_binary_cross_entropy(pred, target, weights=None): trainer.train(metrics_to_record=["loss", "accuracy", "precision", "recall"]) 
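Context for the `tuple(edge_type)` loops in the hunk above: OmegaConf round-trips tuples as lists, while PyG's `T.RandomLinkSplit` expects each heterogeneous edge type as a `(src, relation, dst)` tuple. A standalone sketch of the same conversion (config values are illustrative):

```python
import torch_geometric.transforms as T

# Edge types as they come back from OmegaConf.to_container(): lists, not tuples.
args = {
    "num_val": 0.1,
    "num_test": 0.0,
    "is_undirected": True,
    "edge_types": [["materials", "has", "elements"]],
    "rev_edge_types": [["elements", "rev_has", "materials"]],
}
args["edge_types"] = [tuple(e) for e in args["edge_types"]]
args["rev_edge_types"] = [tuple(e) for e in args["rev_edge_types"]]

transform = T.RandomLinkSplit(**args)
# train_data, val_data, test_data = transform(hetero_data)
```
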
+out = model.encode( + test_val_data.x_dict, + test_val_data.edge_index_dict, + node_ids={ + "materials": test_val_data["materials"].node_ids, + "elements": test_val_data["elements"].node_ids, + }, +) -out = model.encode(test_val_data.x_dict, test_val_data.edge_index_dict, - node_ids={'materials':test_val_data['materials'].node_ids, - 'elements': test_val_data['elements'].node_ids}) - - - -print(out) \ No newline at end of file +print(out) diff --git a/matgraphdb/pyg/models/heterograph_encoder_general/train.py b/matgraphdb/pyg/models/heterograph_encoder_general/train.py index fea82f5..432d8d9 100644 --- a/matgraphdb/pyg/models/heterograph_encoder_general/train.py +++ b/matgraphdb/pyg/models/heterograph_encoder_general/train.py @@ -1,10 +1,11 @@ +import copy import json import os import time +from collections import defaultdict import matplotlib.pyplot as plt import numpy as np -import copy import pandas as pd import pyarrow.compute as pc import torch @@ -13,20 +14,7 @@ import torch_geometric as pyg import torch_geometric.transforms as T from omegaconf import OmegaConf -from torch_geometric import nn as pyg_nn - -from collections import defaultdict - -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder -from matgraphdb.pyg.models.heterograph_encoder_general.model import MaterialEdgePredictor -from matgraphdb.utils.colors import DEFAULT_COLORS -from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( - Trainer, - learning_curve, - pca_plots, - roc_curve, -) +from sklearn import linear_model from sklearn.metrics import ( mean_absolute_error, mean_squared_error, @@ -34,12 +22,25 @@ roc_auc_score, roc_curve, ) -from sklearn import linear_model +from torch_geometric import nn as pyg_nn + +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.heterograph_encoder_general.metrics import ( LearningCurve, ROCCurve, plot_pca, ) +from matgraphdb.pyg.models.heterograph_encoder_general.model import ( + MaterialEdgePredictor, +) +from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( + Trainer, + learning_curve, + pca_plots, + roc_curve, +) +from matgraphdb.utils.colors import DEFAULT_COLORS ######################################################################################################################## @@ -57,8 +58,16 @@ "num_test": 0.0, "neg_sampling_ratio": 1.0, "is_undirected": True, - "edge_types": [("materials", "has", "elements"), ("materials", "has", "space_groups"), ("materials", "has", "crystal_systems")], - "rev_edge_types": [("elements", "rev_has", "materials"), ("space_groups", "rev_has", "materials"), ("crystal_systems", "rev_has", "materials")], + "edge_types": [ + ("materials", "has", "elements"), + ("materials", "has", "space_groups"), + ("materials", "has", "crystal_systems"), + ], + "rev_edge_types": [ + ("elements", "rev_has", "materials"), + ("space_groups", "rev_has", "materials"), + ("crystal_systems", "rev_has", "materials"), + ], }, }, "model": { @@ -73,12 +82,14 @@ "use_shallow_embedding_for_materials": False, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder_general"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder_general" + ), "learning_rate": 0.001, "num_epochs": 1001, "eval_interval": 100, "scheduler_milestones": [4000, 20000], - } + }, } ) @@ -177,6 +188,7 @@ def binning(x): 
builder.add_edge_type("material_spg_has") builder.add_edge_type("material_crystalSystem_has") + def to_log(x): return torch.tensor(np.log10(x), dtype=torch.float32) @@ -197,7 +209,7 @@ def to_log(x): # Set random feature vector for materials if CONFIG.data.create_random_features: n_materials = data["materials"].num_nodes - + data["materials"].x = torch.normal( mean=0.0, std=1.0, size=(n_materials, CONFIG.data.n_material_dim) ) @@ -244,7 +256,6 @@ def to_log(x): test_val_materials = total_test_materials[:test_val_size] - n_train = len(train_materials) # Print percentages of each split print("\nSplit percentages:") @@ -258,7 +269,6 @@ def to_log(x): ) - # Create subgraphs for each split train_dict = {"materials": train_materials} train_val_dict = {"materials": train_val_materials} @@ -266,13 +276,21 @@ def to_log(x): test_val_dict = {"materials": test_val_materials} original_train_data = parent_data.subgraph(train_dict) -original_train_data["materials"].node_ids = parent_data["materials"].node_ids[train_dict["materials"]] +original_train_data["materials"].node_ids = parent_data["materials"].node_ids[ + train_dict["materials"] +] original_train_val_data = parent_data.subgraph(train_val_dict) -original_train_val_data["materials"].node_ids = parent_data["materials"].node_ids[train_val_dict["materials"]] +original_train_val_data["materials"].node_ids = parent_data["materials"].node_ids[ + train_val_dict["materials"] +] original_test_data = parent_data.subgraph(test_dict) -original_test_data["materials"].node_ids = parent_data["materials"].node_ids[test_dict["materials"]] +original_test_data["materials"].node_ids = parent_data["materials"].node_ids[ + test_dict["materials"] +] original_test_val_data = parent_data.subgraph(test_val_dict) -original_test_val_data["materials"].node_ids = parent_data["materials"].node_ids[test_val_dict["materials"]] +original_test_val_data["materials"].node_ids = parent_data["materials"].node_ids[ + test_val_dict["materials"] +] print(original_train_data["materials"].node_ids) print(f"Train materials: {len(train_materials)}") @@ -281,11 +299,21 @@ def to_log(x): print(f"Test val materials: {len(test_val_materials)}") # Reduce the target values for each split. Also record the the y_node_ids and the index of the split. -y_id_map = {int(y_id): float(y) for y_id, y in zip(parent_data['materials'].y_index, parent_data['materials'].y)} -for i, data in enumerate([original_train_data, original_train_val_data, original_test_data, original_test_val_data]): - y_vals=[] - ids=[] - node_ids=[] +y_id_map = { + int(y_id): float(y) + for y_id, y in zip(parent_data["materials"].y_index, parent_data["materials"].y) +} +for i, data in enumerate( + [ + original_train_data, + original_train_val_data, + original_test_data, + original_test_val_data, + ] +): + y_vals = [] + ids = [] + node_ids = [] for i, node_id in enumerate(data["materials"].node_ids): if int(node_id) in y_id_map: y_vals.append(y_id_map[int(node_id)]) @@ -294,7 +322,7 @@ def to_log(x): data["materials"].y = torch.tensor(y_vals) data["materials"].y_node_ids = torch.tensor(node_ids) data["materials"].y_split_index = torch.tensor(ids) - + data = None builder = None @@ -303,20 +331,26 @@ def to_log(x): # omefga config cannot handle list of tuples. 
Must convert back to list of tuples # print(CONFIG.data.random_link_split_args) -random_link_split_args = OmegaConf.to_container(CONFIG.data.random_link_split_args, resolve=True) +random_link_split_args = OmegaConf.to_container( + CONFIG.data.random_link_split_args, resolve=True +) + +for i, edge_type in enumerate(random_link_split_args["edge_types"]): + random_link_split_args["edge_types"][i] = tuple(edge_type) + +for i, edge_type in enumerate(random_link_split_args["rev_edge_types"]): + random_link_split_args["rev_edge_types"][i] = tuple(edge_type) -for i,edge_type in enumerate(random_link_split_args['edge_types']): - random_link_split_args['edge_types'][i] = tuple(edge_type) - -for i,edge_type in enumerate(random_link_split_args['rev_edge_types']): - random_link_split_args['rev_edge_types'][i] = tuple(edge_type) - # Perform a link-level split into training, validation, and test edges: train_data, _, _ = T.RandomLinkSplit(**random_link_split_args)(original_train_data) -train_val_data, _, _ = T.RandomLinkSplit(**random_link_split_args)(original_train_val_data) +train_val_data, _, _ = T.RandomLinkSplit(**random_link_split_args)( + original_train_val_data +) test_data, _, _ = T.RandomLinkSplit(**random_link_split_args)(original_test_data) -test_val_data, _, _ = T.RandomLinkSplit(**random_link_split_args)(original_test_val_data) +test_val_data, _, _ = T.RandomLinkSplit(**random_link_split_args)( + original_test_val_data +) # print("Train data:") # print(train_data) @@ -337,12 +371,24 @@ def to_log(x): # Random link split does not add edge labels and index to the reverse edges. Must add them manually. for split_label, data in split_data.items(): - data['elements', 'rev_has', 'materials'].edge_label_index = data['materials', 'has', 'elements'].edge_label_index[[1,0]] - data['elements', 'rev_has', 'materials'].edge_label = data['materials', 'has', 'elements'].edge_label - data['space_groups', 'rev_has', 'materials'].edge_label_index = data['materials', 'has', 'space_groups'].edge_label_index[[1,0]] - data['space_groups', 'rev_has', 'materials'].edge_label = data['materials', 'has', 'space_groups'].edge_label - data['crystal_systems', 'rev_has', 'materials'].edge_label_index = data['materials', 'has', 'crystal_systems'].edge_label_index[[1,0]] - data['crystal_systems', 'rev_has', 'materials'].edge_label = data['materials', 'has', 'crystal_systems'].edge_label + data["elements", "rev_has", "materials"].edge_label_index = data[ + "materials", "has", "elements" + ].edge_label_index[[1, 0]] + data["elements", "rev_has", "materials"].edge_label = data[ + "materials", "has", "elements" + ].edge_label + data["space_groups", "rev_has", "materials"].edge_label_index = data[ + "materials", "has", "space_groups" + ].edge_label_index[[1, 0]] + data["space_groups", "rev_has", "materials"].edge_label = data[ + "materials", "has", "space_groups" + ].edge_label + data["crystal_systems", "rev_has", "materials"].edge_label_index = data[ + "materials", "has", "crystal_systems" + ].edge_label_index[[1, 0]] + data["crystal_systems", "rev_has", "materials"].edge_label = data[ + "materials", "has", "crystal_systems" + ].edge_label train_data = train_data.to(device) train_val_data = train_val_data.to(device) test_data = test_data.to(device) @@ -401,17 +447,18 @@ def weighted_binary_cross_entropy(pred, target, weights=None): optimizer, milestones=CONFIG.training.scheduler_milestones, gamma=0.1, verbose=False ) results = { - "train": {"mae": [],"epochs": []}, - "train_val": {"mae": [],"epochs": []}, - "test": {"mae": 
[],"epochs": []}, - "test_val": {"mae": [],"epochs": []}, - } + "train": {"mae": [], "epochs": []}, + "train_val": {"mae": [], "epochs": []}, + "test": {"mae": [], "epochs": []}, + "test_val": {"mae": [], "epochs": []}, +} results_original = { - "train": {"mae": [],"epochs": []}, - "train_val": {"mae": [],"epochs": []}, - "test": {"mae": [],"epochs": []}, - "test_val": {"mae": [],"epochs": []}, - } + "train": {"mae": [], "epochs": []}, + "train_val": {"mae": [], "epochs": []}, + "test": {"mae": [], "epochs": []}, + "test_val": {"mae": [], "epochs": []}, +} + def train_step(data_batch): model.train() @@ -423,11 +470,11 @@ def train_step(data_batch): for edge_type, key in model.edge_types_to_decoder_keys.items(): src, rel, dst = edge_type pred = pred_edge_dict[key] - + if not hasattr(data_batch[src, dst], "edge_label"): continue - - target = data_batch[src,rel,dst].edge_label + + target = data_batch[src, rel, dst].edge_label loss = F.binary_cross_entropy(pred, target) total_loss += loss loss_dict[key] = loss @@ -440,7 +487,7 @@ def train_step(data_batch): @torch.no_grad() def validation_step(data_batch): model.eval() - + loss_dict = {} prediction_dict = {} total_loss = 0 @@ -449,31 +496,31 @@ def validation_step(data_batch): for edge_type, key in model.edge_types_to_decoder_keys.items(): src, rel, dst = edge_type pred = pred_edge_dict[key] - + if not hasattr(data_batch[src, dst], "edge_label"): continue - - target = data_batch[src,rel,dst].edge_label + + target = data_batch[src, rel, dst].edge_label loss = F.binary_cross_entropy(pred, target) total_loss += loss loss_dict[key] = float(loss.cpu()) if key not in prediction_dict: prediction_dict[key] = {} - prediction_dict[key]['predictions'] = pred - prediction_dict[key]['targets'] = target + prediction_dict[key]["predictions"] = pred + prediction_dict[key]["targets"] = target return float(total_loss.cpu()), loss_dict, prediction_dict + @torch.no_grad() def regression_eval(data_batch_per_split): model.eval() z_material_per_split = {} z_original_per_split = {} y_per_split = {} - tmp_str='' - node_type='materials' + tmp_str = "" + node_type = "materials" for split_name, data_batch in data_batch_per_split.items(): z_dict = model.encode(data_batch) - y_split_index = data_batch[node_type].y_split_index y = data_batch[node_type].y @@ -483,48 +530,48 @@ def regression_eval(data_batch_per_split): z_material_per_split[split_name] = z.cpu().numpy() y_per_split[split_name] = y.cpu().numpy() z_original_per_split[split_name] = z_original.cpu().numpy() - - tmp_str += f'|{split_name}: {len(z)}|' + + tmp_str += f"|{split_name}: {len(z)}|" print(tmp_str) - reg = linear_model.LinearRegression() - reg.fit(z_material_per_split['train'], y_per_split['train']) - + reg.fit(z_material_per_split["train"], y_per_split["train"]) + reg_original = linear_model.LinearRegression() - reg_original.fit(z_original_per_split['train'], y_per_split['train']) - - test_splits = ['train', 'train_val', 'test', 'test_val'] - tmp_str = '' + reg_original.fit(z_original_per_split["train"], y_per_split["train"]) + + test_splits = ["train", "train_val", "test", "test_val"] + tmp_str = "" for test_split_name in test_splits: y_pred = reg.predict(z_material_per_split[test_split_name]) y_real = y_per_split[test_split_name] - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) - + rmse = np.sqrt(np.mean((y_pred - y_real) ** 2)) mae = np.mean(np.abs(y_pred - y_real)) results[test_split_name]["mae"].append(mae) - tmp_str += f'|{test_split_name}: 
RMSE: {rmse:.4f}, MAE: {mae:.4f}|' + tmp_str += f"|{test_split_name}: RMSE: {rmse:.4f}, MAE: {mae:.4f}|" print(tmp_str) - - test_splits = ['train', 'train_val', 'test', 'test_val'] - tmp_str = '' + + test_splits = ["train", "train_val", "test", "test_val"] + tmp_str = "" for test_split_name in test_splits: y_pred = reg_original.predict(z_original_per_split[test_split_name]) y_real = y_per_split[test_split_name] - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) - + rmse = np.sqrt(np.mean((y_pred - y_real) ** 2)) mae = np.mean(np.abs(y_pred - y_real)) results_original[test_split_name]["mae"].append(mae) - - tmp_str += f'|{test_split_name}: RMSE: {rmse:.4f}, MAE: {mae:.4f}|' + + tmp_str += f"|{test_split_name}: RMSE: {rmse:.4f}, MAE: {mae:.4f}|" print(tmp_str) - + + def eval_metrics(preds, targets, **kwargs): # Calculate metrics pred_binary = (preds > 0.5).float() @@ -543,7 +590,9 @@ def eval_metrics(preds, targets, **kwargs): recall = true_positives / (actual_positives + 1e-10) f1 = 2 * (precision * recall) / (precision + recall + 1e-10) - auc_score = roc_auc_score(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) + auc_score = roc_auc_score( + targets.cpu().detach().numpy(), preds.cpu().detach().numpy() + ) results = { "accuracy": float(accuracy), @@ -561,11 +610,12 @@ def eval_metrics(preds, targets, **kwargs): return results + def calculate_embeddings(data_batch): model.eval() with torch.no_grad(): z_dict = model.encode(data_batch) - + for key, value in z_dict.items(): z_dict[key] = value.cpu().detach().numpy() return z_dict @@ -575,8 +625,8 @@ def roc_curve(metrics_per_split_dict, epoch_save_path=None, total_save_path=None roc_curve_plot = ROCCurve() for split_label, metrics_dict in metrics_per_split_dict.items(): - pred = metrics_dict['current_predictions'] - target = metrics_dict['current_targets'] + pred = metrics_dict["current_predictions"] + target = metrics_dict["current_targets"] # Add main model curve roc_curve_plot.add_curve( @@ -586,12 +636,15 @@ def roc_curve(metrics_per_split_dict, epoch_save_path=None, total_save_path=None roc_curve_plot.plot() if epoch_save_path is not None: roc_curve_plot.save(epoch_save_path) - + if total_save_path is not None: roc_curve_plot.save(total_save_path) roc_curve_plot.close() -def learning_curve(metrics_per_split, metric_name, epoch_save_path=None, total_save_path=None): + +def learning_curve( + metrics_per_split, metric_name, epoch_save_path=None, total_save_path=None +): learning_curve = LearningCurve() for split_label, metrics_dict in metrics_per_split.items(): @@ -612,27 +665,30 @@ def learning_curve(metrics_per_split, metric_name, epoch_save_path=None, total_s learning_curve.save(total_save_path) learning_curve.close() -def pca_plots(embeddings_per_node_type, - n_nodes_per_node_type, - node_labels_per_node_type, - pca_dir): + +def pca_plots( + embeddings_per_node_type, n_nodes_per_node_type, node_labels_per_node_type, pca_dir +): os.makedirs(pca_dir, exist_ok=True) # 3. 
Combine embeddings only for the selected node types - z_all = np.concatenate([embeddings for embeddings in embeddings_per_node_type.values()], axis=0) - + z_all = np.concatenate( + [embeddings for embeddings in embeddings_per_node_type.values()], axis=0 + ) + plot_pca( z_all, save_dir=pca_dir, - save_name=f'embeddings_pca_grid.png', + save_name=f"embeddings_pca_grid.png", n_nodes_per_type=n_nodes_per_node_type, node_labels_per_type=node_labels_per_node_type, n_components=2, figsize=(10, 8), - close=True + close=True, ) -def plot_learning_curves(results, save_path, measure='mae'): + +def plot_learning_curves(results, save_path, measure="mae"): """ Plots the learning curves for a specified measure from the results dictionary. @@ -643,23 +699,25 @@ def plot_learning_curves(results, save_path, measure='mae'): measure (str): The measure to plot (e.g., 'loss' or 'mae'). Default is 'loss'. """ plt.figure(figsize=(10, 6)) - + # Iterate over the splits in the results dictionary for idx, split in enumerate(results): split_data = results[split] - + # Check if the desired measure is available in this split's data if measure not in split_data: - print(f"Warning: Measure '{measure}' not found for split '{split}'. Skipping.") + print( + f"Warning: Measure '{measure}' not found for split '{split}'. Skipping." + ) continue # Use the provided 'epochs' list if available, otherwise create a range based on the measure length epochs = split_data.get("epochs", list(range(len(split_data[measure])))) values = split_data[measure] - + # Select a color for this plot color = DEFAULT_COLORS[idx % len(DEFAULT_COLORS)] - + # Plot the curve for this split plt.plot(epochs, values, label=split, color=color, linewidth=2) @@ -672,6 +730,7 @@ def plot_learning_curves(results, save_path, measure='mae'): plt.savefig(save_path) plt.close() + runs_dir = os.path.join(CONFIG.training.training_dir, "runs") os.makedirs(runs_dir, exist_ok=True) n_runs = len(os.listdir(runs_dir)) @@ -691,47 +750,45 @@ def plot_learning_curves(results, save_path, measure='mae'): os.makedirs(run_pca_dir, exist_ok=True) - loss, loss_dict, rel_prediction_dict = validation_step(test_val_data) -rel_names = [] +rel_names = [] split_names = list(split_data.keys()) for rel_name, prediction_dict in rel_prediction_dict.items(): - preds = prediction_dict['predictions'] - targets = prediction_dict['targets'] + preds = prediction_dict["predictions"] + targets = prediction_dict["targets"] metrics = eval_metrics(preds, targets) - metric_names=list(metrics.keys()) + metric_names = list(metrics.keys()) rel_names.append(rel_name) - -metric_names.append('loss') -metric_names.append('epochs') + +metric_names.append("loss") +metric_names.append("epochs") values_to_record = copy.deepcopy(metric_names) -values_to_record.append('current_predictions') -values_to_record.append('current_targets') -values_to_record.append('current_loss') +values_to_record.append("current_predictions") +values_to_record.append("current_targets") +values_to_record.append("current_loss") metrics_per_rel_per_split = defaultdict( - lambda: defaultdict( - lambda: {value_name: [] for value_name in values_to_record} - ) + lambda: defaultdict(lambda: {value_name: [] for value_name in values_to_record}) ) node_types = test_val_data.metadata()[0] n_nodes_per_split_per_node_type = defaultdict( - lambda: defaultdict( - lambda: {node_type: [] for node_type in node_types} - ) + lambda: defaultdict(lambda: {node_type: [] for node_type in node_types}) ) node_labels_per_split_per_node_type = defaultdict( lambda: 
defaultdict( - lambda: {node_type: [] for node_type in node_types if node_type != 'materials'} + lambda: {node_type: [] for node_type in node_types if node_type != "materials"} ) ) for split_name, data_batch in split_data.items(): for node_type in data_batch.metadata()[0]: - n_nodes_per_split_per_node_type[split_name][node_type] = data_batch[node_type].num_nodes - if node_type != 'materials': - node_labels_per_split_per_node_type[split_name][node_type] = data_batch[node_type].labels - + n_nodes_per_split_per_node_type[split_name][node_type] = data_batch[ + node_type + ].num_nodes + if node_type != "materials": + node_labels_per_split_per_node_type[split_name][node_type] = data_batch[ + node_type + ].labels for epoch in range(CONFIG.training.num_epochs): @@ -742,54 +799,57 @@ def plot_learning_curves(results, save_path, measure='mae'): current_epoch = epoch epoch_dir = os.path.join(epochs_dir, f"epoch_{epoch}") os.makedirs(epoch_dir, exist_ok=True) - + eval_str = f"Epoch: {epoch} :" loss_per_rel_str = "" - - - embeddings_per_split_per_node_type= {} + + embeddings_per_split_per_node_type = {} for split_name, data_batch in split_data.items(): loss, loss_dict, rel_prediction_dict = validation_step(data_batch) - + eval_str += f" |{split_name}_loss: {loss} " - + embeddings_dict = calculate_embeddings(data_batch) embeddings_per_split_per_node_type[split_name] = embeddings_dict - - + for rel_name in rel_names: prediction_dict = rel_prediction_dict[rel_name] - rel_loss=loss_dict[rel_name] - preds = prediction_dict['predictions'] - targets = prediction_dict['targets'] + rel_loss = loss_dict[rel_name] + preds = prediction_dict["predictions"] + targets = prediction_dict["targets"] metrics = eval_metrics(preds, targets) - metrics_per_rel_per_split[rel_name][split_name]['current_loss'] = rel_loss - metrics_per_rel_per_split[rel_name][split_name]['current_predictions'] = preds - metrics_per_rel_per_split[rel_name][split_name]['current_targets'] = targets - - metrics_per_rel_per_split[rel_name][split_name]['loss'].append(rel_loss) - metrics_per_rel_per_split[rel_name][split_name]['epochs'].append(epoch) + metrics_per_rel_per_split[rel_name][split_name][ + "current_loss" + ] = rel_loss + metrics_per_rel_per_split[rel_name][split_name][ + "current_predictions" + ] = preds + metrics_per_rel_per_split[rel_name][split_name][ + "current_targets" + ] = targets + + metrics_per_rel_per_split[rel_name][split_name]["loss"].append(rel_loss) + metrics_per_rel_per_split[rel_name][split_name]["epochs"].append(epoch) for metric_name, metric_value in metrics.items(): - metrics_per_rel_per_split[rel_name][split_name][metric_name].append(metric_value) - - + metrics_per_rel_per_split[rel_name][split_name][metric_name].append( + metric_value + ) + results[split_name]["epochs"].append(epoch) results_original[split_name]["epochs"].append(epoch) - - + eval_str += f"|" print(eval_str) - + for rel_name in rel_names: loss_list = [] for split_name in split_names: - loss = metrics_per_rel_per_split[rel_name][split_name]['current_loss'] + loss = metrics_per_rel_per_split[rel_name][split_name]["current_loss"] loss_list.append(str(round(loss, 4))) loss_per_rel_str = f"|{rel_name}: {':'.join(loss_list)}|" print(f" {loss_per_rel_str}") - - + regression_eval(split_data) # epoch_roc_curve_dir = os.path.join(epoch_dir, "roc_curves") @@ -799,7 +859,7 @@ def plot_learning_curves(results, save_path, measure='mae'): # epoch_save_path=os.path.join(epoch_roc_curve_dir, f"{rel_name}_roc_curve.png") # total_save_path=os.path.join(run_roc_curve_dir, 
f"{rel_name}_roc_curve.png") # roc_curve(metrics_per_split, epoch_save_path, total_save_path) - + # for metric_name in metric_names: # if metric_name == 'epochs': # continue @@ -817,26 +877,30 @@ def plot_learning_curves(results, save_path, measure='mae'): # n_nodes_per_node_type = n_nodes_per_split_per_node_type[split_name] # node_labels_per_type = node_labels_per_split_per_node_type[split_name] # pca_plots( - # embeddings_per_node_type, - # n_nodes_per_node_type, + # embeddings_per_node_type, + # n_nodes_per_node_type, # node_labels_per_type, # pca_dir=os.path.join(run_pca_dir, split_name)) - -plot_learning_curves(results, save_path=os.path.join(run_learning_curve_dir, "learning_curves.png")) -plot_learning_curves(results_original, save_path=os.path.join(run_learning_curve_dir, "learning_curves_original.png")) +plot_learning_curves( + results, save_path=os.path.join(run_learning_curve_dir, "learning_curves.png") +) +plot_learning_curves( + results_original, + save_path=os.path.join(run_learning_curve_dir, "learning_curves_original.png"), +) epoch_roc_curve_dir = os.path.join(epoch_dir, "roc_curves") os.makedirs(epoch_roc_curve_dir, exist_ok=True) for rel_name in rel_names: metrics_per_split = metrics_per_rel_per_split[rel_name] # epoch_save_path=os.path.join(epoch_roc_curve_dir, f"{rel_name}_roc_curve.png") - total_save_path=os.path.join(run_roc_curve_dir, f"{rel_name}_roc_curve.png") + total_save_path = os.path.join(run_roc_curve_dir, f"{rel_name}_roc_curve.png") roc_curve(metrics_per_split, total_save_path=total_save_path) for metric_name in metric_names: - if metric_name == 'epochs': + if metric_name == "epochs": continue for rel_name in rel_names: metrics_per_split = metrics_per_rel_per_split[rel_name] @@ -845,14 +909,17 @@ def plot_learning_curves(results, save_path, measure='mae'): os.makedirs(rel_run_learning_curve_dir, exist_ok=True) os.makedirs(rel_epoch_learning_curve_dir, exist_ok=True) # epoch_save_path=os.path.join(rel_epoch_learning_curve_dir, f"{metric_name}_learning_curve.png") - total_save_path=os.path.join(rel_run_learning_curve_dir, f"{metric_name}_learning_curve.png") + total_save_path = os.path.join( + rel_run_learning_curve_dir, f"{metric_name}_learning_curve.png" + ) learning_curve(metrics_per_split, metric_name, total_save_path=total_save_path) for split_name, embeddings_per_node_type in embeddings_per_split_per_node_type.items(): n_nodes_per_node_type = n_nodes_per_split_per_node_type[split_name] node_labels_per_type = node_labels_per_split_per_node_type[split_name] pca_plots( - embeddings_per_node_type, - n_nodes_per_node_type, + embeddings_per_node_type, + n_nodes_per_node_type, node_labels_per_type, - pca_dir=os.path.join(run_pca_dir, split_name)) \ No newline at end of file + pca_dir=os.path.join(run_pca_dir, split_name), + ) diff --git a/matgraphdb/pyg/models/heterograph_encoder_general/train_linear.py b/matgraphdb/pyg/models/heterograph_encoder_general/train_linear.py index 31e1f39..9ee5125 100644 --- a/matgraphdb/pyg/models/heterograph_encoder_general/train_linear.py +++ b/matgraphdb/pyg/models/heterograph_encoder_general/train_linear.py @@ -1,10 +1,11 @@ +import copy import json import os import time +from collections import defaultdict import matplotlib.pyplot as plt import numpy as np -import copy import pandas as pd import pyarrow.compute as pc import torch @@ -13,19 +14,7 @@ import torch_geometric as pyg import torch_geometric.transforms as T from omegaconf import OmegaConf -from torch_geometric import nn as pyg_nn - -from collections import 
defaultdict - -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder -from matgraphdb.pyg.models.heterograph_encoder_general.model import MaterialEdgePredictor -from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( - Trainer, - learning_curve, - pca_plots, - roc_curve, -) +from sklearn import linear_model from sklearn.metrics import ( mean_absolute_error, mean_squared_error, @@ -33,12 +22,24 @@ roc_auc_score, roc_curve, ) -from sklearn import linear_model +from torch_geometric import nn as pyg_nn + +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.heterograph_encoder_general.metrics import ( LearningCurve, ROCCurve, plot_pca, ) +from matgraphdb.pyg.models.heterograph_encoder_general.model import ( + MaterialEdgePredictor, +) +from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( + Trainer, + learning_curve, + pca_plots, + roc_curve, +) ######################################################################################################################## @@ -56,8 +57,16 @@ "num_test": 0.0, "neg_sampling_ratio": 1.0, "is_undirected": True, - "edge_types": [("materials", "has", "elements"), ("materials", "has", "space_groups"), ("materials", "has", "crystal_systems")], - "rev_edge_types": [("elements", "rev_has", "materials"), ("space_groups", "rev_has", "materials"), ("crystal_systems", "rev_has", "materials")], + "edge_types": [ + ("materials", "has", "elements"), + ("materials", "has", "space_groups"), + ("materials", "has", "crystal_systems"), + ], + "rev_edge_types": [ + ("elements", "rev_has", "materials"), + ("space_groups", "rev_has", "materials"), + ("crystal_systems", "rev_has", "materials"), + ], }, }, "model": { @@ -71,12 +80,14 @@ "use_shallow_embedding_for_materials": False, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder_general"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder_general" + ), "learning_rate": 0.001, "num_epochs": 20001, "eval_interval": 1000, "scheduler_milestones": [4000, 20000], - } + }, } ) @@ -107,15 +118,21 @@ material_store = mdb.material_store -df = material_store.read(columns=["elasticity.g_vrh", "elasticity.k_vrh", - "core.volume", "core.density", - "core.density_atomic", "core.nelements", - "core.nsites"], - filters=[ +df = material_store.read( + columns=[ + "elasticity.g_vrh", + "elasticity.k_vrh", + "core.volume", + "core.density", + "core.density_atomic", + "core.nelements", + "core.nsites", + ], + filters=[ pc.field("elasticity.g_vrh") > 0, pc.field("elasticity.g_vrh") < 400, - ]).to_pandas() - + ], +).to_pandas() print("-" * 100) @@ -127,19 +144,24 @@ # y_index = parent_data['materials'].y_index -z = df[["core.volume", "core.density", "core.density_atomic", - "core.nelements", +z = df[ + [ + "core.volume", + "core.density", + "core.density_atomic", + "core.nelements", # "core.nsites" - ]] + ] +] y = df["elasticity.g_vrh"] z = torch.tensor(z.values, dtype=torch.float32) y = torch.tensor(y.values, dtype=torch.float32) perm = torch.randperm(z.size(0)) -train_perm = perm[:int(z.size(0) * CONFIG.data.train_ratio)] -test_perm = perm[int(z.size(0) * CONFIG.data.train_ratio):] -print(f'N train: {len(train_perm)}, N test: {len(test_perm)}') +train_perm = perm[: int(z.size(0) * CONFIG.data.train_ratio)] +test_perm = perm[int(z.size(0) * CONFIG.data.train_ratio) :] +print(f"N train: 
{len(train_perm)}, N test: {len(test_perm)}") reg = linear_model.LinearRegression() reg.fit(z[train_perm].cpu().numpy(), y[train_perm].cpu().numpy()) @@ -150,5 +172,5 @@ # y_real = np.array([10**value for value in y_real]) rmse = np.sqrt(np.mean((y_pred - y_real) ** 2)) mae = np.mean(np.abs(y_pred - y_real)) -tmp_str = f'RMSE: {rmse:.4f}, MAE: {mae:.4f}|' +tmp_str = f"RMSE: {rmse:.4f}, MAE: {mae:.4f}|" print(tmp_str) diff --git a/matgraphdb/pyg/models/hgt/train.py b/matgraphdb/pyg/models/hgt/train.py index ca6834e..5ade1c4 100644 --- a/matgraphdb/pyg/models/hgt/train.py +++ b/matgraphdb/pyg/models/hgt/train.py @@ -14,8 +14,8 @@ from omegaconf import OmegaConf from torch_geometric import nn as pyg_nn -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.hgt.model import HGT from matgraphdb.pyg.models.hgt.trainer import ( Trainer, @@ -45,13 +45,15 @@ }, }, "model": { - "hidden_channels": 128, + "hidden_channels": 128, "out_channels": 1, "num_heads": 8, "num_layers": 3, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder" + ), "learning_rate": 0.001, "num_epochs": 40001, "eval_interval": 2000, @@ -64,7 +66,7 @@ "mlflow_experiment_name": "heterograph_encoder", "mlflow_tracking_uri": "${training.training_dir}/mlflow", "mlflow_record_system_metrics": True, - } + }, } ) @@ -233,7 +235,6 @@ def to_log(x): test_val_data = original_test_val_data - print(train_data) print(train_val_data) print(test_data) @@ -260,11 +261,13 @@ def to_log(x): #################################################################################################### # Model #################################################################################################### -model = HGT(hidden_channels=CONFIG.model.hidden_channels, - out_channels=CONFIG.model.out_channels, - num_heads=CONFIG.model.num_heads, - num_layers=CONFIG.model.num_layers, - data=train_data).to(device) +model = HGT( + hidden_channels=CONFIG.model.hidden_channels, + out_channels=CONFIG.model.out_channels, + num_heads=CONFIG.model.num_heads, + num_layers=CONFIG.model.num_layers, + data=train_data, +).to(device) print(model) @@ -283,8 +286,8 @@ def train(): model.train() optimizer.zero_grad() out = model(data.x_dict, data.edge_index_dict) - mask = data['author'].train_mask - loss = F.cross_entropy(out[mask], data['author'].y[mask]) + mask = data["author"].train_mask + loss = F.cross_entropy(out[mask], data["author"].y[mask]) loss.backward() optimizer.step() return float(loss) @@ -296,9 +299,9 @@ def test(): pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) accs = [] - for split in ['train_mask', 'val_mask', 'test_mask']: - mask = data['author'][split] - acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum() + for split in ["train_mask", "val_mask", "test_mask"]: + mask = data["author"][split] + acc = (pred[mask] == data["author"].y[mask]).sum() / mask.sum() accs.append(float(acc)) return accs @@ -306,5 +309,7 @@ def test(): for epoch in range(1, 101): loss = train() train_acc, val_acc, test_acc = test() - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' - f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') \ No newline at end of file + print( + f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: 
{train_acc:.4f}, " + f"Val: {val_acc:.4f}, Test: {test_acc:.4f}" + ) diff --git a/matgraphdb/pyg/models/metapath2vec/train.py b/matgraphdb/pyg/models/metapath2vec/train.py index 1d19eac..f0da761 100644 --- a/matgraphdb/pyg/models/metapath2vec/train.py +++ b/matgraphdb/pyg/models/metapath2vec/train.py @@ -1,8 +1,9 @@ import json +import logging import os import time -import logging +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -12,19 +13,19 @@ import torch.optim.lr_scheduler as lr_scheduler import torch_geometric as pyg import torch_geometric.transforms as T + +######################################################################################################################## +import umap from omegaconf import OmegaConf +from sklearn import linear_model from torch_geometric import nn as pyg_nn +from torch_geometric.nn import MetaPath2Vec -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.metapath2vec.metrics import plot_pca -import matplotlib.patches as mpatches -from sklearn import linear_model -from matgraphdb.utils.colors import DEFAULT_COLORS, DEFAULT_CMAP +from matgraphdb.utils.colors import DEFAULT_CMAP, DEFAULT_COLORS from matgraphdb.utils.config import config -from torch_geometric.nn import MetaPath2Vec -######################################################################################################################## -import umap LOGGER = logging.getLogger(__name__) @@ -45,33 +46,37 @@ # LOGGER.addHandler(logging.StreamHandler()) - - - def to_log(x): return torch.tensor(np.log10(x), dtype=torch.float32) - + + DATA_CONFIG = OmegaConf.create( { - "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), - "nodes" : - {"materials": {"columns": ["core.density_atomic"], 'drop_null': True}, - "elements": {"columns": ["atomic_mass", "radius_covalent", "radius_vanderwaals"], 'drop_null':True, 'label_column': 'symbol'}, - "space_groups": {'drop_null': True, 'label_column': 'spg'}, - "crystal_systems": {'drop_null': True, 'label_column': 'crystal_system'} + "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), + "nodes": { + "materials": {"columns": ["core.density_atomic"], "drop_null": True}, + "elements": { + "columns": ["atomic_mass", "radius_covalent", "radius_vanderwaals"], + "drop_null": True, + "label_column": "symbol", + }, + "space_groups": {"drop_null": True, "label_column": "spg"}, + "crystal_systems": {"drop_null": True, "label_column": "crystal_system"}, }, - "edges" : - { - "element_element_neighborsByGroupPeriod": {}, - "material_element_has": {}, - "material_spg_has": {}, - "material_crystalSystem_has": {} + "edges": { + "element_element_neighborsByGroupPeriod": {}, + "material_element_has": {}, + "material_spg_has": {}, + "material_crystalSystem_has": {}, + }, + "target": { + "materials": { + "columns": ["elasticity.g_vrh"], + "drop_null": True, + "filters": "[pc.field('elasticity.g_vrh') > 0, pc.field('elasticity.g_vrh') < 400]", + "encoders": "{'elasticity.g_vrh': to_log}", + } }, - "target":{ - "materials": {"columns": ["elasticity.g_vrh"], 'drop_null': True, - 'filters': "[pc.field('elasticity.g_vrh') > 0, pc.field('elasticity.g_vrh') < 400]", - 'encoders': "{'elasticity.g_vrh': to_log}"} - } } ) @@ -89,7 +94,10 @@ def to_log(x): "sparse": True, # "metapath": [('materials', 
'has', 'elements'), ('elements', 'rev_has', 'materials')] # "metapath": [('materials', 'has', 'crystal_systems'), ('crystal_systems', 'rev_has', 'materials')] - "metapath": [('materials', 'has', 'space_groups'), ('space_groups', 'rev_has', 'materials')] + "metapath": [ + ("materials", "has", "space_groups"), + ("space_groups", "rev_has", "materials"), + ], }, "training": { "train_dir": os.path.join("data", "training_runs", "metapath2vec"), @@ -102,23 +110,25 @@ def to_log(x): "log_steps": 100, "eval_steps": 2000, "test_train_ratio": 0.8, - "test_max_iter": 150 - } + "test_max_iter": 150, + }, } ) -MLP_CONFIG = OmegaConf.create({ - "data": dict(DATA_CONFIG), - "model": { - "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline - }, - "training": { - "learning_rate": 0.001, - "train_ratio": 0.8, - "val_ratio": 0.1, - "epochs": 2000, +MLP_CONFIG = OmegaConf.create( + { + "data": dict(DATA_CONFIG), + "model": { + "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline + }, + "training": { + "learning_rate": 0.001, + "train_ratio": 0.8, + "val_ratio": 0.1, + "epochs": 2000, + }, } -}) +) #################################################################################################### #################################################################################################### @@ -139,18 +149,18 @@ def to_log(x): #################################################################################################### def build_heterograph(): """Build the initial heterogeneous graph from the materials database. - + Returns: torch_geometric.data.HeteroData: The constructed heterogeneous graph """ mdb = MPNearHull(DATA_CONFIG.dataset_dir) builder = HeteroGraphBuilder(mdb) - + # Define the "materials" node type (only a subset of columns is used here) for node_type, node_config in DATA_CONFIG.nodes.items(): node_config = OmegaConf.to_container(node_config) builder.add_node_type(node_type, **node_config) - + for edge_type, edge_config in DATA_CONFIG.edges.items(): edge_config = OmegaConf.to_container(edge_config) builder.add_edge_type(edge_type, **edge_config) @@ -164,12 +174,16 @@ def build_heterograph(): if "encoders" in target_config: encoders = target_config.pop("encoders") encoders = eval(encoders) - - builder.add_target_node_property(target_type, filters=filters, encoders=encoders, **target_config) - + + builder.add_target_node_property( + target_type, filters=filters, encoders=encoders, **target_config + ) + heterodata = builder.hetero_data LOGGER.info(f"HeteroData: {heterodata}") - heterodata["materials"].original_x = heterodata["materials"].x # Save original features + heterodata["materials"].original_x = heterodata[ + "materials" + ].x # Save original features return heterodata @@ -177,23 +191,21 @@ def heterograph_preprocessing(): """ Build the heterograph, apply transformations, partition the graph, and split the 'materials' nodes into training/validation/test subgraphs. - + Args: config (OmegaConf): A configuration object with the keys: - data: data-related parameters (e.g., dataset_dir, create_random_features, n_material_dim, train_ratio, val_ratio) - model: model-related parameters (e.g., n_partitions) - training: training-related parameters - + Returns: split_data (dict): A dictionary with keys "train", "train_val", "test", "test_val", each containing a subgraph for the corresponding split. """ # 1. Build the heterogeneous graph from the materials database - - + original_heterograph = build_heterograph() - - + # 2. Apply transformation: make the graph undirected. 
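# Note: for HeteroData, T.ToUndirected() adds a reverse edge type for every
# forward one, e.g. ('space_groups', 'rev_has', 'materials') for
# ('materials', 'has', 'space_groups'); these 'rev_has' relations are the ones
# referenced by the metapath in METAPATH2VEC_CONFIG above.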
source_data = T.ToUndirected()(original_heterograph) # Free up memory. @@ -205,38 +217,41 @@ def heterograph_preprocessing(): # # Model # #################################################################################################### + class MLPBaseline(torch.nn.Module): def __init__(self, input_dim, hidden_dim, output_dim=1): super(MLPBaseline, self).__init__() self.net = torch.nn.Sequential( torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU(), - torch.nn.Linear(hidden_dim, output_dim) + torch.nn.Linear(hidden_dim, output_dim), ) - + def forward(self, x): return self.net(x) def train_mlp_baseline(heterodata, metapath2vec_model): - z = metapath2vec_model('materials', batch=heterodata['materials'].y_index.to(DEVICE)) - y = heterodata['materials'].y.to(DEVICE).squeeze() - + z = metapath2vec_model( + "materials", batch=heterodata["materials"].y_index.to(DEVICE) + ) + y = heterodata["materials"].y.to(DEVICE).squeeze() + material_indices = torch.randperm(z.size(0)) - + n_materials = z.size(0) train_ratio = MLP_CONFIG.training.train_ratio val_ratio = MLP_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + train_materials = total_train_materials[:train_val_size] test_materials = total_test_materials[:test_val_size] @@ -255,19 +270,21 @@ def train_mlp_baseline(heterodata, metapath2vec_model): input_dim = z.shape[1] hidden_dim = MLP_CONFIG.model.mlp_hidden_dim model = MLPBaseline(input_dim=input_dim, hidden_dim=hidden_dim).to(DEVICE) - optimizer = torch.optim.Adam(model.parameters(), lr=MLP_CONFIG.training.learning_rate) + optimizer = torch.optim.Adam( + model.parameters(), lr=MLP_CONFIG.training.learning_rate + ) loss_fn = torch.nn.L1Loss() - + # Initialize results storage results = { - "train": {"loss": [], "mae": [], "epochs": []}, + "train": {"loss": [], "mae": [], "epochs": []}, "train_val": {"loss": [], "mae": [], "epochs": []}, - "test": {"loss": [], "mae": [], "epochs": []}, - "test_val": {"loss": [], "mae": [], "epochs": []}, + "test": {"loss": [], "mae": [], "epochs": []}, + "test_val": {"loss": [], "mae": [], "epochs": []}, } def train_step(): - + model.train() optimizer.zero_grad() # Move this here, before the forward pass @@ -276,27 +293,25 @@ def train_step(): loss = loss_fn(y_pred, y[split_data["train"]]) loss.backward(retain_graph=True) optimizer.step() - + total_loss += loss.item() - + results["train"]["loss"].append(float(total_loss)) @torch.no_grad() def test_step(): model.eval() - + for split_name, split_materials in split_data.items(): y_pred = model(z[split_materials]).squeeze().cpu().numpy() y_real = y[split_materials].cpu().numpy() - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) mae = np.mean(np.abs(y_pred - y_real)) results[split_name]["mae"].append(float(mae)) - - for epoch in range(MLP_CONFIG.training.epochs): train_step() test_step() @@ -304,42 +319,51 @@ def test_step(): results["train_val"]["epochs"].append(epoch) results["test"]["epochs"].append(epoch) results["test_val"]["epochs"].append(epoch) - + loss_str = f"Epoch: {epoch}," for split_name, split_results in results.items(): loss_str += f"{split_name}: {split_results['mae'][-1]:.4f} " print(loss_str) - + return results + def train_metapath2vec(heterodata): - 
metapath=[] + metapath = [] for path in METAPATH2VEC_CONFIG.model.metapath: metapath.append(tuple(path)) - - num_nodes_dict = {node_type: heterodata[node_type].num_nodes for node_type in heterodata.node_types} - model = MetaPath2Vec(heterodata.edge_index_dict, - embedding_dim=METAPATH2VEC_CONFIG.model.embedding_dim, - metapath=metapath, - walk_length=METAPATH2VEC_CONFIG.model.walk_length, - context_size=METAPATH2VEC_CONFIG.model.context_size, - walks_per_node=METAPATH2VEC_CONFIG.model.walks_per_node, - num_negative_samples=METAPATH2VEC_CONFIG.model.num_negative_samples, - sparse=METAPATH2VEC_CONFIG.model.sparse, - num_nodes_dict=num_nodes_dict).to(DEVICE) - - loader = model.loader(batch_size=METAPATH2VEC_CONFIG.training.batch_size, - shuffle=True, - num_workers=METAPATH2VEC_CONFIG.training.num_workers) + + num_nodes_dict = { + node_type: heterodata[node_type].num_nodes + for node_type in heterodata.node_types + } + model = MetaPath2Vec( + heterodata.edge_index_dict, + embedding_dim=METAPATH2VEC_CONFIG.model.embedding_dim, + metapath=metapath, + walk_length=METAPATH2VEC_CONFIG.model.walk_length, + context_size=METAPATH2VEC_CONFIG.model.context_size, + walks_per_node=METAPATH2VEC_CONFIG.model.walks_per_node, + num_negative_samples=METAPATH2VEC_CONFIG.model.num_negative_samples, + sparse=METAPATH2VEC_CONFIG.model.sparse, + num_nodes_dict=num_nodes_dict, + ).to(DEVICE) + + loader = model.loader( + batch_size=METAPATH2VEC_CONFIG.training.batch_size, + shuffle=True, + num_workers=METAPATH2VEC_CONFIG.training.num_workers, + ) print(model) - optimizer = torch.optim.SparseAdam(list(model.parameters()), - lr=METAPATH2VEC_CONFIG.training.learning_rate) + optimizer = torch.optim.SparseAdam( + list(model.parameters()), lr=METAPATH2VEC_CONFIG.training.learning_rate + ) results = { - "train": {"mae": [], "loss": [], "epochs": []}, + "train": {"mae": [], "loss": [], "epochs": []}, "train_val": {"mae": [], "loss": [], "epochs": []}, - "test": {"mae": [], "loss": [], "epochs": []}, - "test_val": {"mae": [], "loss": [], "epochs": []}, + "test": {"mae": [], "loss": [], "epochs": []}, + "test_val": {"mae": [], "loss": [], "epochs": []}, } def train_step(): @@ -353,31 +377,31 @@ def train_step(): optimizer.step() total_loss += loss.item() - + results["train"]["loss"].append(float(total_loss / len(loader))) @torch.no_grad() def test_step(): model.eval() - z = model('materials', batch=heterodata['materials'].y_index.to(DEVICE)) - y = heterodata['materials'].y + z = model("materials", batch=heterodata["materials"].y_index.to(DEVICE)) + y = heterodata["materials"].y material_indices = torch.randperm(z.size(0)) - + n_materials = z.size(0) train_ratio = METAPATH2VEC_CONFIG.training.train_ratio val_ratio = METAPATH2VEC_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + train_materials = total_train_materials[:train_val_size] test_materials = total_test_materials[:test_val_size] @@ -392,18 +416,20 @@ def test_step(): "test": test_materials, "test_val": test_val_materials, } - + reg = linear_model.LinearRegression() - reg.fit(z[split_data["train"]].cpu().numpy(), y[split_data["train"]].cpu().numpy()) - + reg.fit( + z[split_data["train"]].cpu().numpy(), y[split_data["train"]].cpu().numpy() + ) + for split_name, split_materials in 
split_data.items(): y_pred = reg.predict(z[split_materials].cpu().numpy()) y_real = y[split_materials].cpu().numpy() - + if split_name != "train": loss = np.mean(np.abs(y_pred - y_real)) results[split_name]["loss"].append(float(loss)) - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) @@ -419,14 +445,13 @@ def test_step(): results["train_val"]["epochs"].append(epoch) results["test"]["epochs"].append(epoch) results["test_val"]["epochs"].append(epoch) - + loss_str = f"Epoch: {epoch}," for split_name, split_results in results.items(): loss_str += f"{split_name}: {split_results['mae'][-1]:.4f} " print(loss_str) - - return model, results + return model, results def main(): @@ -434,9 +459,9 @@ def main(): model, linear_results = train_metapath2vec(heterodata) mlp_results = train_mlp_baseline(heterodata, model) - + training_dir = METAPATH2VEC_CONFIG.training.train_dir - + runs_dir = os.path.join(training_dir, "runs") os.makedirs(runs_dir, exist_ok=True) n_runs = len(os.listdir(runs_dir)) @@ -448,37 +473,39 @@ def main(): with open(os.path.join(results_dir, "mlp_config.json"), "w") as f: json.dump(OmegaConf.to_container(MLP_CONFIG), f) - + with open(os.path.join(results_dir, "linear_results.json"), "w") as f: json.dump(linear_results, f) - + with open(os.path.join(results_dir, "mlp_results.json"), "w") as f: json.dump(mlp_results, f) - - plot_learning_curves(linear_results, os.path.join(results_dir, "linear_learning_curves.png")) - plot_learning_curves(mlp_results, os.path.join(results_dir, "mlp_learning_curves.png")) - - - + + plot_learning_curves( + linear_results, os.path.join(results_dir, "linear_learning_curves.png") + ) + plot_learning_curves( + mlp_results, os.path.join(results_dir, "mlp_learning_curves.png") + ) + z_per_type = { - "materials": model('materials'), + "materials": model("materials"), # "elements": model('elements'), - "space_groups": model('space_groups'), + "space_groups": model("space_groups"), # "crystal_systems": model('crystal_systems'), } targets_per_type = { - "materials": 10 ** heterodata['materials'].y.cpu().numpy(), + "materials": 10 ** heterodata["materials"].y.cpu().numpy(), } targets_labels_per_type = { - "materials": heterodata['materials'].y_label_name[0], + "materials": heterodata["materials"].y_label_name[0], } targets_index_per_type = { - "materials": heterodata['materials'].y_index.cpu().numpy(), + "materials": heterodata["materials"].y_index.cpu().numpy(), } LOGGER.info(f"Targets index per type: {len(heterodata['elements'].labels)}") labels_per_type = { - "elements": heterodata['elements'].labels, - "space_groups": heterodata['space_groups'].labels, + "elements": heterodata["elements"].labels, + "space_groups": heterodata["space_groups"].labels, # "crystal_systems": heterodata['crystal_systems'].labels, } color_per_type = { @@ -486,25 +513,27 @@ def main(): "space_groups": "black", # "crystal_systems": "black", } - - create_umap_plot(z_per_type, - targets_per_type=targets_per_type, - targets_index_per_type=targets_index_per_type, - targets_labels_per_type=targets_labels_per_type, - labels_per_type=labels_per_type, - color_per_type=color_per_type, - save_path=os.path.join(results_dir, "umap.png"), - n_neighbors=30) - # create_umap_plot3d(z_per_type, + + create_umap_plot( + z_per_type, + targets_per_type=targets_per_type, + targets_index_per_type=targets_index_per_type, + targets_labels_per_type=targets_labels_per_type, + labels_per_type=labels_per_type, + color_per_type=color_per_type, + 
save_path=os.path.join(results_dir, "umap.png"), + n_neighbors=30, + ) + # create_umap_plot3d(z_per_type, # targets_per_type=targets_per_type, # targets_index_per_type=targets_index_per_type, # labels_per_type=labels_per_type, # color_per_type=color_per_type, # save_path=os.path.join(results_dir, "umap_materials_elements_3d.png"), # n_neighbors=30) - -def plot_learning_curves(results, save_path, measure='mae'): + +def plot_learning_curves(results, save_path, measure="mae"): """ Plots the learning curves for a specified measure from the results dictionary. @@ -515,23 +544,25 @@ def plot_learning_curves(results, save_path, measure='mae'): measure (str): The measure to plot (e.g., 'loss' or 'mae'). Default is 'loss'. """ plt.figure(figsize=(10, 6)) - + # Iterate over the splits in the results dictionary for idx, split in enumerate(results): split_data = results[split] - + # Check if the desired measure is available in this split's data if measure not in split_data: - print(f"Warning: Measure '{measure}' not found for split '{split}'. Skipping.") + print( + f"Warning: Measure '{measure}' not found for split '{split}'. Skipping." + ) continue # Use the provided 'epochs' list if available, otherwise create a range based on the measure length epochs = split_data.get("epochs", list(range(len(split_data[measure])))) values = split_data[measure] - + # Select a color for this plot color = DEFAULT_COLORS[idx % len(DEFAULT_COLORS)] - + # Plot the curve for this split plt.plot(epochs, values, label=split, color=color, linewidth=2) @@ -545,20 +576,20 @@ def plot_learning_curves(results, save_path, measure='mae'): plt.close() - - -def create_umap_plot(z_per_type, - targets_per_type:dict=None, - targets_index_per_type:dict=None, - targets_labels_per_type:dict=None, - filter_index_per_type:dict=None, - labels_per_type:dict=None, - color_per_type:dict=None, - save_path=".", - n_neighbors=50, - n_jobs=4): +def create_umap_plot( + z_per_type, + targets_per_type: dict = None, + targets_index_per_type: dict = None, + targets_labels_per_type: dict = None, + filter_index_per_type: dict = None, + labels_per_type: dict = None, + color_per_type: dict = None, + save_path=".", + n_neighbors=50, + n_jobs=4, +): node_types = list(z_per_type.keys()) - + if targets_per_type is None: targets_per_type = {} if targets_index_per_type is None: @@ -571,98 +602,100 @@ def create_umap_plot(z_per_type, color_per_type = {} if targets_labels_per_type is None: targets_labels_per_type = {} - - z_global_idx_per_type={} - z_local_idx_per_type={} - local_global_idx_mapping_per_type={} + + z_global_idx_per_type = {} + z_local_idx_per_type = {} + local_global_idx_mapping_per_type = {} z_node_type_mapping = {} - - z_all=None - total_n_nodes=0 + z_all = None + total_n_nodes = 0 for i, (node_type, z) in enumerate(z_per_type.items()): - z=z.detach().cpu().numpy() + z = z.detach().cpu().numpy() n_nodes = z.shape[0] - - + LOGGER.info(f"Node type: {node_type}, Number of nodes: {n_nodes}") - + z_node_type_mapping[node_type] = i - z_global_idx_per_type[node_type] = np.arange(total_n_nodes, total_n_nodes + n_nodes) + z_global_idx_per_type[node_type] = np.arange( + total_n_nodes, total_n_nodes + n_nodes + ) z_local_idx_per_type[node_type] = np.arange(n_nodes) - local_global_idx_mapping_per_type[node_type] = {i:j for i,j in zip(z_local_idx_per_type[node_type], z_global_idx_per_type[node_type])} + local_global_idx_mapping_per_type[node_type] = { + i: j + for i, j in zip( + z_local_idx_per_type[node_type], z_global_idx_per_type[node_type] + ) + } if 
z_all is None: z_all = z else: z_all = np.concatenate([z_all, z], axis=0) - - total_n_nodes+=n_nodes - - + + total_n_nodes += n_nodes + # Apply UMAP to reduce dimensions to 2. reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=2, n_jobs=n_jobs) embedding = reducer.fit_transform(z_all) # Create the scatter plot. plt.figure(figsize=(10, 8)) - - - handles=[] - scatter_handles=[] + + handles = [] + scatter_handles = [] for node_type in node_types: LOGGER.info(f"Plotting {node_type}") - + color = color_per_type.get(node_type, None) node_labels = labels_per_type.get(node_type, None) targets = targets_per_type.get(node_type, None) target_idx = targets_index_per_type.get(node_type, None) filter_idx = filter_index_per_type.get(node_type, None) - + node_idx = z_global_idx_per_type.get(node_type, None) LOGGER.info(f"Node index: {node_idx}") if target_idx is not None: LOGGER.info(f"Target index: {target_idx}") - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in target_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in target_idx + ] if node_labels is not None: - node_labels = node_labels[target_idx] # Needs to be local index + node_labels = node_labels[target_idx] # Needs to be local index if filter_idx is not None: - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in filter_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in filter_idx + ] if node_labels is not None: - node_labels = node_labels[filter_idx] # Needs to be local index - - + node_labels = node_labels[filter_idx] # Needs to be local index + if targets is not None: c = targets - cmap=DEFAULT_CMAP + cmap = DEFAULT_CMAP elif color is not None: - c=color - cmap=None + c = color + cmap = None handles.append(mpatches.Patch(color=color, label=node_type)) - - - x = embedding[node_idx, 0] # Needs to be global index - y = embedding[node_idx, 1] # Needs to be global index - scatter = plt.scatter(x, y, s=10, alpha=0.8, - c=c, - cmap=cmap) - c=None - + + x = embedding[node_idx, 0] # Needs to be global index + y = embedding[node_idx, 1] # Needs to be global index + scatter = plt.scatter(x, y, s=10, alpha=0.8, c=c, cmap=cmap) + c = None + if targets is not None: LOGGER.info(f"Plotting {node_type} targets") scatter_handles.append(scatter) - + if node_labels is not None: LOGGER.info(f"Plotting {node_type} labels, n_labels: {len(node_labels)}") for i, label in enumerate(node_labels): plt.annotate(label, (x[i], y[i]), fontsize=8, alpha=1) - if targets_per_type: - label="" + label = "" for node_type in node_types: - label+=targets_labels_per_type.get(node_type, "") + label += targets_labels_per_type.get(node_type, "") plt.colorbar(scatter_handles[0], label=label) - plt.legend(handles=handles) + plt.legend(handles=handles) plt.title("UMAP Projection of Node Embeddings") plt.xlabel("UMAP 1") plt.ylabel("UMAP 2") @@ -670,17 +703,17 @@ def create_umap_plot(z_per_type, plt.close() - - -def create_umap_plot3d(z_per_type, - targets_per_type: dict = None, - targets_index_per_type: dict = None, - filter_index_per_type: dict = None, - labels_per_type: dict = None, - color_per_type: dict = None, - save_path="umap_3d_plot.png", - n_neighbors=50, - n_jobs=4): +def create_umap_plot3d( + z_per_type, + targets_per_type: dict = None, + targets_index_per_type: dict = None, + filter_index_per_type: dict = None, + labels_per_type: dict = None, + color_per_type: dict = None, + save_path="umap_3d_plot.png", + n_neighbors=50, + n_jobs=4, +): """ Creates a 3D UMAP scatter plot from 
node embeddings for multiple node types. @@ -694,9 +727,9 @@ def create_umap_plot3d(z_per_type, save_path (str): Path (including filename) to save the plot. n_jobs (int): Number of parallel jobs to run in UMAP. """ - + node_types = list(z_per_type.keys()) - + # Set default dictionaries if None. if targets_per_type is None: targets_per_type = {} @@ -708,7 +741,7 @@ def create_umap_plot3d(z_per_type, labels_per_type = {} if color_per_type is None: color_per_type = {} - + z_global_idx_per_type = {} z_local_idx_per_type = {} local_global_idx_mapping_per_type = {} @@ -721,13 +754,17 @@ def create_umap_plot3d(z_per_type, z = z.detach().cpu().numpy() n_nodes = z.shape[0] LOGGER.info(f"Node type: {node_type}, Number of nodes: {n_nodes}") - + z_node_type_mapping[node_type] = i - z_global_idx_per_type[node_type] = np.arange(total_n_nodes, total_n_nodes + n_nodes) + z_global_idx_per_type[node_type] = np.arange( + total_n_nodes, total_n_nodes + n_nodes + ) z_local_idx_per_type[node_type] = np.arange(n_nodes) local_global_idx_mapping_per_type[node_type] = { - local_idx: global_idx - for local_idx, global_idx in zip(z_local_idx_per_type[node_type], z_global_idx_per_type[node_type]) + local_idx: global_idx + for local_idx, global_idx in zip( + z_local_idx_per_type[node_type], z_global_idx_per_type[node_type] + ) } z_all = z if z_all is None else np.concatenate([z_all, z], axis=0) total_n_nodes += n_nodes @@ -738,37 +775,43 @@ def create_umap_plot3d(z_per_type, # Create the 3D scatter plot. fig = plt.figure(figsize=(10, 8)) - ax = fig.add_subplot(111, projection='3d') - + ax = fig.add_subplot(111, projection="3d") + handles = [] scatter_handles = [] - + for node_type in node_types: LOGGER.info(f"Plotting {node_type}") - + color = color_per_type.get(node_type, None) node_labels = labels_per_type.get(node_type, None) targets = targets_per_type.get(node_type, None) target_idx = targets_index_per_type.get(node_type, None) filter_idx = filter_index_per_type.get(node_type, None) - + # Start with all global indices for this node type. node_idx = z_global_idx_per_type.get(node_type, None) LOGGER.info(f"Node index: {node_idx}") - + # If a target index is specified, map local indices to global indices. if target_idx is not None: LOGGER.info(f"Target index: {target_idx}") - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in target_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in target_idx + ] if node_labels is not None: - node_labels = node_labels[target_idx] # Select labels based on local indices. - + node_labels = node_labels[ + target_idx + ] # Select labels based on local indices. + # Apply additional filtering if provided. if filter_idx is not None: - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in filter_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in filter_idx + ] if node_labels is not None: node_labels = node_labels[filter_idx] - + # Determine the color and colormap. if targets is not None: c = targets @@ -785,12 +828,12 @@ def create_umap_plot3d(z_per_type, x = embedding[node_idx, 0] y = embedding[node_idx, 1] z_coord = embedding[node_idx, 2] - + scatter = ax.scatter(x, y, z_coord, s=10, alpha=0.8, c=c, cmap=cmap) if targets is not None: LOGGER.info(f"Plotting {node_type} targets") scatter_handles.append(scatter) - + # Annotate points with labels if provided. 
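        # (node_labels was sliced with the same local indices used to rebuild
        # node_idx above, so the i-th label annotates the i-th plotted point.)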
if node_labels is not None: LOGGER.info(f"Plotting {node_type} labels, n_labels: {len(node_labels)}") @@ -800,44 +843,35 @@ def create_umap_plot3d(z_per_type, # Add a colorbar if targets were used for coloring. if targets_per_type and scatter_handles: cbar = plt.colorbar(scatter_handles[0], ax=ax, pad=0.1) - cbar.set_label('Label') - - plt.legend(handles=handles) + cbar.set_label("Label") + + plt.legend(handles=handles) plt.title("3D UMAP Projection of Node Embeddings") ax.set_xlabel("UMAP 1") ax.set_ylabel("UMAP 2") ax.set_zlabel("UMAP 3") - + plt.savefig(save_path) plt.close() - if __name__ == "__main__": - + logger = logging.getLogger("__main__") logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) handler.setLevel(logging.DEBUG) logger.addHandler(handler) - - + logger = logging.getLogger("matgraphdb.pyg.data.hetero_graph") logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) handler.setLevel(logging.DEBUG) logger.addHandler(handler) main() - - - - - - - - - - - diff --git a/matgraphdb/pyg/models/propinit/train.py b/matgraphdb/pyg/models/propinit/train.py index d593e19..2d614e8 100644 --- a/matgraphdb/pyg/models/propinit/train.py +++ b/matgraphdb/pyg/models/propinit/train.py @@ -1,10 +1,11 @@ +import copy import json import os import time +from collections import defaultdict import matplotlib.pyplot as plt import numpy as np -import copy import pandas as pd import pyarrow.compute as pc import torch @@ -15,19 +16,8 @@ import torch_geometric as pyg import torch_geometric.transforms as T from omegaconf import OmegaConf -from torch_geometric import nn as pyg_nn - -from collections import defaultdict - -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder -from matgraphdb.pyg.models.propinit.model import Model -from torch_sparse import SparseTensor from pyg_lib.partition import metis -from torch_geometric.index import index2ptr, ptr2index -from torch_geometric.loader import DataLoader - # from matgraphdb.pyg.models.heterograph_encoder_general.model import MaterialEdgePredictor # from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( # Trainer, @@ -43,21 +33,29 @@ # roc_curve, # ) from sklearn import linear_model +from torch_geometric import nn as pyg_nn +from torch_geometric.index import index2ptr, ptr2index +from torch_geometric.loader import DataLoader +from torch_sparse import SparseTensor + +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.heterograph_encoder_general.metrics import ( LearningCurve, ROCCurve, plot_pca, ) +from matgraphdb.pyg.models.propinit.model import Model ######################################################################################################################## DATA_CONFIG = OmegaConf.create( { - "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), - "create_random_features": False, - "n_material_dim": 4, - "train_ratio": 0.8, - "val_ratio": 0.1, + "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), + "create_random_features": False, + 
"n_material_dim": 4, + "train_ratio": 0.8, + "val_ratio": 0.1, } ) @@ -86,27 +84,31 @@ ) -MLP_CONFIG = OmegaConf.create({ - "data": dict(DATA_CONFIG), - "model": { - "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline - }, - "training": { - "learning_rate": 0.001, - "epochs": 4000, +MLP_CONFIG = OmegaConf.create( + { + "data": dict(DATA_CONFIG), + "model": { + "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline + }, + "training": { + "learning_rate": 0.001, + "epochs": 4000, + }, } -}) - -LINEAR_CONFIG = OmegaConf.create({ - "data": dict(DATA_CONFIG), - "model": { - "linear_hidden_dim": 32, # hidden dimension for the linear baseline - }, - "training": { - "learning_rate": 0.001, - "epochs": 4000, +) + +LINEAR_CONFIG = OmegaConf.create( + { + "data": dict(DATA_CONFIG), + "model": { + "linear_hidden_dim": 32, # hidden dimension for the linear baseline + }, + "training": { + "learning_rate": 0.001, + "epochs": 4000, + }, } -}) +) #################################################################################################### @@ -129,13 +131,13 @@ #################################################################################################### def build_heterograph(): """Build the initial heterogeneous graph from the materials database. - + Returns: torch_geometric.data.HeteroData: The constructed heterogeneous graph """ mdb = MPNearHull(DATA_CONFIG.dataset_dir) builder = HeteroGraphBuilder(mdb) - + # Define the "materials" node type (only a subset of columns is used here) builder.add_node_type( "materials", @@ -143,31 +145,33 @@ def build_heterograph(): "core.density_atomic", ], ) - + # Define additional node types. builder.add_node_type( "elements", columns=[ "atomic_mass", - "radius_covalent", + "radius_covalent", "radius_vanderwaals", ], drop_null=True, label_column="symbol", ) builder.add_node_type("space_groups", drop_null=True, label_column="spg") - builder.add_node_type("crystal_systems", drop_null=True, label_column="crystal_system") - + builder.add_node_type( + "crystal_systems", drop_null=True, label_column="crystal_system" + ) + # Define edge types. builder.add_edge_type("element_element_neighborsByGroupPeriod") builder.add_edge_type("material_element_has") builder.add_edge_type("material_spg_has") builder.add_edge_type("material_crystalSystem_has") - + # Define a helper function for target encoding. def to_log(x): return torch.tensor(np.log10(x), dtype=torch.float32) - + # Add a target property for the "materials" node type. builder.add_target_node_property( "materials", @@ -179,26 +183,27 @@ def to_log(x): encoders={"elasticity.g_vrh": to_log}, ) heterodata = builder.hetero_data - heterodata["materials"].original_x = heterodata["materials"].x # Save original features - + heterodata["materials"].original_x = heterodata[ + "materials" + ].x # Save original features + # Optionally, create random features if desired. if DATA_CONFIG.create_random_features: n_materials = heterodata["materials"].num_nodes heterodata["materials"].x = torch.normal( - mean=0.0, - std=1.0, - size=(n_materials, DATA_CONFIG.n_material_dim) + mean=0.0, std=1.0, size=(n_materials, DATA_CONFIG.n_material_dim) ) - + return heterodata + def partition_and_convert_graph(source_data): """Convert heterogeneous graph to homogeneous, partition it, and convert back. 
- + Args: source_data: The source heterogeneous graph data config: Configuration object containing model parameters - + Returns: heterodata: The processed heterogeneous graph with partitioning """ @@ -206,33 +211,32 @@ def partition_and_convert_graph(source_data): rowptr = index2ptr(homodata.edge_index[0]) col = homodata.edge_index[1] node_partitions = metis( - rowptr=rowptr, - col=col, - num_partitions=PROPINET_CONFIG.model.n_partitions + rowptr=rowptr, col=col, num_partitions=PROPINET_CONFIG.model.n_partitions ) homodata.partition = node_partitions homodata.node_type_id = homodata.node_type heterodata = homodata.to_heterogeneous() - + # Transfer target information and feature vectors. - heterodata['materials'].y = heterodata.y - heterodata['materials'].y_index = heterodata.y_index - heterodata['materials'].target_feature_mask = heterodata.target_feature_mask - + heterodata["materials"].y = heterodata.y + heterodata["materials"].y_index = heterodata.y_index + heterodata["materials"].target_feature_mask = heterodata.target_feature_mask + node_types = source_data.metadata()[0] for nt in node_types: if hasattr(source_data[nt], "x"): heterodata[nt].x = source_data[nt].x - + return heterodata + def split_by_material_nodes(parent_data): """Split material nodes into train/val/test sets and create corresponding subgraphs. - + Args: parent_data: The full heterograph containing all data config: Configuration object containing split ratios - + Returns: Dictionary containing the split subgraphs for train, validation and test sets """ @@ -240,60 +244,71 @@ def split_by_material_nodes(parent_data): n_materials = parent_data["materials"].num_nodes node_ids = parent_data["materials"].node_ids material_indices = torch.randperm(n_materials) - + train_ratio = DATA_CONFIG.train_ratio val_ratio = DATA_CONFIG.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + # Split train and test into their validation sets train_val_materials = total_train_materials[:train_val_size] train_materials = total_train_materials[train_val_size:] test_val_materials = total_test_materials[:test_val_size] test_materials = total_test_materials[test_val_size:] - + print("\nSplit percentages:") print(f"Total: {n_materials}") print(f"Train: {len(train_materials)/n_materials*100:.1f}%") print(f"Train val: {len(train_val_materials)/n_materials*100:.1f}%") print(f"Test: {len(test_materials)/n_materials*100:.1f}%") print(f"Test val: {len(test_val_materials)/n_materials*100:.1f}%") - total_pct = (len(train_materials) + len(train_val_materials) + len(test_materials) + len(test_val_materials)) / n_materials * 100 + total_pct = ( + ( + len(train_materials) + + len(train_val_materials) + + len(test_materials) + + len(test_val_materials) + ) + / n_materials + * 100 + ) print(f"Total: {total_pct:.1f}%\n") - + # Create subgraphs for each split split_dicts = { "train": {"materials": train_materials}, - "train_val": {"materials": train_val_materials}, + "train_val": {"materials": train_val_materials}, "test": {"materials": test_materials}, - "test_val": {"materials": test_val_materials} + "test_val": {"materials": test_val_materials}, } - + split_data = {} for split_name, split_dict in split_dicts.items(): data = parent_data.subgraph(split_dict) - data["materials"].node_ids = 
parent_data["materials"].node_ids[split_dict["materials"]] + data["materials"].node_ids = parent_data["materials"].node_ids[ + split_dict["materials"] + ] split_data[split_name] = data - + print(split_data["train"]["materials"].node_ids) print(f"Train materials: {len(train_materials)}") print(f"Train val materials: {len(train_val_materials)}") print(f"Test materials: {len(test_materials)}") print(f"Test val materials: {len(test_val_materials)}") - + # For each split, reduce the target values and record indices y_id_map = { int(y_id): float(y) - for y_id, y in zip(parent_data['materials'].y_index, parent_data['materials'].y) + for y_id, y in zip(parent_data["materials"].y_index, parent_data["materials"].y) } - + for data in split_data.values(): y_vals = [] ids = [] @@ -306,48 +321,45 @@ def split_by_material_nodes(parent_data): data["materials"].y = torch.tensor(y_vals) data["materials"].y_node_ids = torch.tensor(node_ids_list) data["materials"].y_split_index = torch.tensor(ids) - - return split_data + return split_data def heterograph_preprocessing(): """ Build the heterograph, apply transformations, partition the graph, and split the 'materials' nodes into training/validation/test subgraphs. - + Args: config (OmegaConf): A configuration object with the keys: - data: data-related parameters (e.g., dataset_dir, create_random_features, n_material_dim, train_ratio, val_ratio) - model: model-related parameters (e.g., n_partitions) - training: training-related parameters - + Returns: split_data (dict): A dictionary with keys "train", "train_val", "test", "test_val", each containing a subgraph for the corresponding split. """ # 1. Build the heterogeneous graph from the materials database - - + original_heterograph = build_heterograph() - - + # 2. Apply transformation: make the graph undirected. source_data = T.ToUndirected()(original_heterograph) # Free up memory. original_heterograph = None - + # (edge_types not used further here but available as metadata[1]) - + heterodata = partition_and_convert_graph(source_data) split_data = split_by_material_nodes(heterodata) - - return split_data, heterodata - + return split_data, heterodata -def learning_curve(metrics_per_split, metric_name, epoch_save_path=None, total_save_path=None): +def learning_curve( + metrics_per_split, metric_name, epoch_save_path=None, total_save_path=None +): learning_curve = LearningCurve() for split_label, metrics_dict in metrics_per_split.items(): @@ -369,7 +381,6 @@ def learning_curve(metrics_per_split, metric_name, epoch_save_path=None, total_s learning_curve.close() - ######################################## # 2. 
Model Definitions ######################################## @@ -377,56 +388,62 @@ class LinearBaseline(nn.Module): def __init__(self, input_dim, output_dim=1): super(LinearBaseline, self).__init__() self.linear = nn.Linear(input_dim, output_dim) - + def forward(self, x): return self.linear(x) + class MLPBaseline(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim=1): super(MLPBaseline, self).__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), - nn.Linear(hidden_dim, output_dim) + nn.Linear(hidden_dim, output_dim), ) - + def forward(self, x): return self.net(x) + def train_linear_baseline(split_data): """Train a simple linear model using PyTorch.""" input_dim = split_data["train"]["materials"].original_x.shape[1] model = LinearBaseline(input_dim=input_dim) optimizer = optim.Adam(model.parameters(), lr=LINEAR_CONFIG.training.learning_rate) loss_fn = nn.L1Loss() - + # Initialize results storage baseline_results = { "linear": { - "train": {"loss": [], "mae": [], "epochs": []}, + "train": {"loss": [], "mae": [], "epochs": []}, "train_val": {"loss": [], "mae": [], "epochs": []}, - "test": {"loss": [], "mae": [], "epochs": []}, - "test_val": {"loss": [], "mae": [], "epochs": []}, + "test": {"loss": [], "mae": [], "epochs": []}, + "test_val": {"loss": [], "mae": [], "epochs": []}, } } - + num_epochs = LINEAR_CONFIG.training.epochs for epoch in range(num_epochs): model.train() optimizer.zero_grad() # Training step on the training split - x = split_data["train"]["materials"].original_x[split_data["train"]["materials"].y_split_index] + x = split_data["train"]["materials"].original_x[ + split_data["train"]["materials"].y_split_index + ] y = split_data["train"]["materials"].y y_pred = model(x).squeeze() loss = loss_fn(y_pred, y) loss.backward() optimizer.step() - + # Evaluate on all splits with torch.no_grad(): for split, data_batch in split_data.items(): model.eval() - x_eval = data_batch["materials"].original_x[data_batch["materials"].y_split_index] + x_eval = data_batch["materials"].original_x[ + data_batch["materials"].y_split_index + ] y_eval = data_batch["materials"].y y_pred_eval = model(x_eval).squeeze() # Convert predictions and targets from log-scale to original scale (if applicable) @@ -437,11 +454,12 @@ def train_linear_baseline(split_data): baseline_results["linear"][split]["loss"].append(loss) baseline_results["linear"][split]["mae"].append(mae) baseline_results["linear"][split]["epochs"].append(epoch) - + print(f"[Linear] Epoch {epoch:3d} | Loss: {loss.item():.4f} | MAE: {mae}") - + return baseline_results["linear"] + def train_mlp_baseline(split_data): """Train an MLP baseline model using PyTorch.""" input_dim = split_data["train"]["materials"].original_x.shape[1] @@ -449,34 +467,38 @@ def train_mlp_baseline(split_data): model = MLPBaseline(input_dim=input_dim, hidden_dim=hidden_dim) optimizer = optim.Adam(model.parameters(), lr=MLP_CONFIG.training.learning_rate) loss_fn = nn.L1Loss() - + # Initialize results storage baseline_results = { "mlp": { - "train": {"loss": [], "mae": [], "epochs": []}, + "train": {"loss": [], "mae": [], "epochs": []}, "train_val": {"loss": [], "mae": [], "epochs": []}, - "test": {"loss": [], "mae": [], "epochs": []}, - "test_val": {"loss": [], "mae": [], "epochs": []}, + "test": {"loss": [], "mae": [], "epochs": []}, + "test_val": {"loss": [], "mae": [], "epochs": []}, } } - + num_epochs = MLP_CONFIG.training.epochs for epoch in range(num_epochs): model.train() optimizer.zero_grad() # Training step on the training 
split - x = split_data["train"]["materials"].original_x[split_data["train"]["materials"].y_split_index] + x = split_data["train"]["materials"].original_x[ + split_data["train"]["materials"].y_split_index + ] y = split_data["train"]["materials"].y y_pred = model(x).squeeze() loss = loss_fn(y_pred, y) loss.backward() optimizer.step() - + # Evaluate on all splits with torch.no_grad(): for split, data_batch in split_data.items(): model.eval() - x_eval = data_batch["materials"].original_x[data_batch["materials"].y_split_index] + x_eval = data_batch["materials"].original_x[ + data_batch["materials"].y_split_index + ] y_eval = data_batch["materials"].y y_pred_eval = model(x_eval).squeeze() y_pred_orig = 10 ** y_pred_eval.cpu().numpy() @@ -488,19 +510,20 @@ def train_mlp_baseline(split_data): baseline_results["mlp"][split]["loss"].append(loss) baseline_results["mlp"][split]["mae"].append(mae) baseline_results["mlp"][split]["epochs"].append(epoch) - + print(f"[MLP] Epoch {epoch:3d} | Loss: {loss.item():.4f} | MAE: {mae}") - + return baseline_results["mlp"] + def train_propinet(split_data, heterodata): """ Train the Propinet model using the given heterograph and data splits. - + The function trains the model for a specified number of epochs, evaluates performance (MAE and RMSE) on each split (train_val, test, test_val), and records the training loss for the train split. - + Args: split_data (dict): Dictionary with keys "train", "train_val", "test", "test_val" containing the corresponding subgraphs. @@ -531,7 +554,9 @@ def train_propinet(split_data, heterodata): ) print(model) - optimizer = torch.optim.Adam(model.parameters(), lr=PROPINET_CONFIG.training.learning_rate) + optimizer = torch.optim.Adam( + model.parameters(), lr=PROPINET_CONFIG.training.learning_rate + ) loss_fn = nn.L1Loss() def train_step(data_batch): @@ -540,12 +565,12 @@ def train_step(data_batch): """ model.train() optimizer.zero_grad() - + out = model(data_batch).squeeze() y_split_index = data_batch["materials"].y_split_index y_target = data_batch["materials"].y y_pred = out[y_split_index] - + loss = loss_fn(y_pred, y_target) loss.backward() optimizer.step() @@ -561,24 +586,24 @@ def evaluation_step(data_batch): y_split_index = data_batch["materials"].y_split_index y_target = data_batch["materials"].y y_pred = out[y_split_index] - loss=loss_fn(y_pred, y_target) - + loss = loss_fn(y_pred, y_target) + y_pred = y_pred.cpu().numpy() y_target = y_target.cpu().numpy() # Convert predictions from log-scale back to original scale. - y_pred_orig = 10 ** y_pred - y_target_orig = 10 ** y_target - + y_pred_orig = 10**y_pred + y_target_orig = 10**y_target + mae = np.mean(np.abs(y_pred_orig - y_target_orig)) - + return {"loss": loss, "mae": mae} # Initialize the results dictionary. results = { - "train": {"loss": [], "mae": [], "epochs": []}, + "train": {"loss": [], "mae": [], "epochs": []}, "train_val": {"loss": [], "mae": [], "epochs": []}, - "test": {"loss": [], "mae": [], "epochs": []}, - "test_val": {"loss": [], "mae": [], "epochs": []}, + "test": {"loss": [], "mae": [], "epochs": []}, + "test_val": {"loss": [], "mae": [], "epochs": []}, } num_epochs = PROPINET_CONFIG.training.num_epochs @@ -587,16 +612,16 @@ def evaluation_step(data_batch): train_loss = train_step(split_data["train"]) results["train"]["loss"].append(train_loss) results["train"]["epochs"].append(epoch) - + # Evaluate model on the evaluation splits. 
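        # (evaluation_step above undoes the log10 target encoding via 10**y
        # before computing MAE, so the recorded "mae" values are on the
        # original scale of elasticity.g_vrh.)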
for split in ["train", "train_val", "test", "test_val"]: metrics = evaluation_step(split_data[split]) results[split]["mae"].append(metrics["mae"]) results[split]["loss"].append(metrics["loss"]) results[split]["epochs"].append(epoch) - + # scheduler.step() - + # Print progress at the configured evaluation interval. if epoch % PROPINET_CONFIG.training.eval_interval == 0: print( @@ -609,56 +634,54 @@ def evaluation_step(data_batch): return results - def main(): # Create dummy splits: "train", "train_val", "test", "test_val" split_data, heterodata = heterograph_preprocessing() - - + # Run the separate training loops print("Training Linear Baseline...") linear_results = train_linear_baseline(split_data) - + print("Training MLP Baseline...") mlp_results = train_mlp_baseline(split_data) - + print("Training Propinet...") propinet_results = train_propinet(split_data, heterodata) - - - + runs_dir = os.path.join("data", "training_runs", "propinit", "runs") os.makedirs(runs_dir, exist_ok=True) - + n_runs = len(os.listdir(runs_dir)) run_dir = os.path.join(runs_dir, f"run_{n_runs+1}") os.makedirs(run_dir, exist_ok=True) - - + with open(os.path.join(run_dir, "results.json"), "w") as f: - json.dump({ - "linear": linear_results, - "mlp": mlp_results, - "propinet": propinet_results - }, f) - + json.dump( + { + "linear": linear_results, + "mlp": mlp_results, + "propinet": propinet_results, + }, + f, + ) + # Plot all learning curves using the LearningCurve class learning_curve_plot = LearningCurve() - for model_name, results in [("linear", linear_results), ("mlp", mlp_results), ("propinet", propinet_results)]: + for model_name, results in [ + ("linear", linear_results), + ("mlp", mlp_results), + ("propinet", propinet_results), + ]: for split, metrics in results.items(): label = f"{model_name}-{split}" learning_curve_plot.add_curve( - metrics["epochs"], - metrics["mae"], - label, - label, - is_baseline=True + metrics["epochs"], metrics["mae"], label, label, is_baseline=True ) learning_curve_plot.plot() learning_curve_plot.save(os.path.join(run_dir, "learning_curve.png")) learning_curve_plot.close() - - + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/matgraphdb/pyg/models/propinit/train_linear.py b/matgraphdb/pyg/models/propinit/train_linear.py index 31e1f39..9ee5125 100644 --- a/matgraphdb/pyg/models/propinit/train_linear.py +++ b/matgraphdb/pyg/models/propinit/train_linear.py @@ -1,10 +1,11 @@ +import copy import json import os import time +from collections import defaultdict import matplotlib.pyplot as plt import numpy as np -import copy import pandas as pd import pyarrow.compute as pc import torch @@ -13,19 +14,7 @@ import torch_geometric as pyg import torch_geometric.transforms as T from omegaconf import OmegaConf -from torch_geometric import nn as pyg_nn - -from collections import defaultdict - -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder -from matgraphdb.pyg.models.heterograph_encoder_general.model import MaterialEdgePredictor -from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( - Trainer, - learning_curve, - pca_plots, - roc_curve, -) +from sklearn import linear_model from sklearn.metrics import ( mean_absolute_error, mean_squared_error, @@ -33,12 +22,24 @@ roc_auc_score, roc_curve, ) -from sklearn import linear_model +from torch_geometric import nn as pyg_nn + +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import 
HeteroGraphBuilder from matgraphdb.pyg.models.heterograph_encoder_general.metrics import ( LearningCurve, ROCCurve, plot_pca, ) +from matgraphdb.pyg.models.heterograph_encoder_general.model import ( + MaterialEdgePredictor, +) +from matgraphdb.pyg.models.heterograph_encoder_general.trainer import ( + Trainer, + learning_curve, + pca_plots, + roc_curve, +) ######################################################################################################################## @@ -56,8 +57,16 @@ "num_test": 0.0, "neg_sampling_ratio": 1.0, "is_undirected": True, - "edge_types": [("materials", "has", "elements"), ("materials", "has", "space_groups"), ("materials", "has", "crystal_systems")], - "rev_edge_types": [("elements", "rev_has", "materials"), ("space_groups", "rev_has", "materials"), ("crystal_systems", "rev_has", "materials")], + "edge_types": [ + ("materials", "has", "elements"), + ("materials", "has", "space_groups"), + ("materials", "has", "crystal_systems"), + ], + "rev_edge_types": [ + ("elements", "rev_has", "materials"), + ("space_groups", "rev_has", "materials"), + ("crystal_systems", "rev_has", "materials"), + ], }, }, "model": { @@ -71,12 +80,14 @@ "use_shallow_embedding_for_materials": False, }, "training": { - "training_dir": os.path.join("data", "training_runs", "heterograph_encoder_general"), + "training_dir": os.path.join( + "data", "training_runs", "heterograph_encoder_general" + ), "learning_rate": 0.001, "num_epochs": 20001, "eval_interval": 1000, "scheduler_milestones": [4000, 20000], - } + }, } ) @@ -107,15 +118,21 @@ material_store = mdb.material_store -df = material_store.read(columns=["elasticity.g_vrh", "elasticity.k_vrh", - "core.volume", "core.density", - "core.density_atomic", "core.nelements", - "core.nsites"], - filters=[ +df = material_store.read( + columns=[ + "elasticity.g_vrh", + "elasticity.k_vrh", + "core.volume", + "core.density", + "core.density_atomic", + "core.nelements", + "core.nsites", + ], + filters=[ pc.field("elasticity.g_vrh") > 0, pc.field("elasticity.g_vrh") < 400, - ]).to_pandas() - + ], +).to_pandas() print("-" * 100) @@ -127,19 +144,24 @@ # y_index = parent_data['materials'].y_index -z = df[["core.volume", "core.density", "core.density_atomic", - "core.nelements", +z = df[ + [ + "core.volume", + "core.density", + "core.density_atomic", + "core.nelements", # "core.nsites" - ]] + ] +] y = df["elasticity.g_vrh"] z = torch.tensor(z.values, dtype=torch.float32) y = torch.tensor(y.values, dtype=torch.float32) perm = torch.randperm(z.size(0)) -train_perm = perm[:int(z.size(0) * CONFIG.data.train_ratio)] -test_perm = perm[int(z.size(0) * CONFIG.data.train_ratio):] -print(f'N train: {len(train_perm)}, N test: {len(test_perm)}') +train_perm = perm[: int(z.size(0) * CONFIG.data.train_ratio)] +test_perm = perm[int(z.size(0) * CONFIG.data.train_ratio) :] +print(f"N train: {len(train_perm)}, N test: {len(test_perm)}") reg = linear_model.LinearRegression() reg.fit(z[train_perm].cpu().numpy(), y[train_perm].cpu().numpy()) @@ -150,5 +172,5 @@ # y_real = np.array([10**value for value in y_real]) rmse = np.sqrt(np.mean((y_pred - y_real) ** 2)) mae = np.mean(np.abs(y_pred - y_real)) -tmp_str = f'RMSE: {rmse:.4f}, MAE: {mae:.4f}|' +tmp_str = f"RMSE: {rmse:.4f}, MAE: {mae:.4f}|" print(tmp_str) diff --git a/matgraphdb/pyg/models/rotate/train.py b/matgraphdb/pyg/models/rotate/train.py index e1c6a95..b107b69 100644 --- a/matgraphdb/pyg/models/rotate/train.py +++ b/matgraphdb/pyg/models/rotate/train.py @@ -1,8 +1,9 @@ import json +import logging import os import 
time -import logging +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -12,19 +13,19 @@ import torch.optim.lr_scheduler as lr_scheduler import torch_geometric as pyg import torch_geometric.transforms as T + +######################################################################################################################## +import umap from omegaconf import OmegaConf +from sklearn import linear_model from torch_geometric import nn as pyg_nn +from torch_geometric.nn import MetaPath2Vec -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.metapath2vec.metrics import plot_pca -import matplotlib.patches as mpatches -from sklearn import linear_model -from matgraphdb.utils.colors import DEFAULT_COLORS, DEFAULT_CMAP +from matgraphdb.utils.colors import DEFAULT_CMAP, DEFAULT_COLORS from matgraphdb.utils.config import config -from torch_geometric.nn import MetaPath2Vec -######################################################################################################################## -import umap LOGGER = logging.getLogger(__name__) @@ -45,33 +46,37 @@ # LOGGER.addHandler(logging.StreamHandler()) - - - def to_log(x): return torch.tensor(np.log10(x), dtype=torch.float32) - + + DATA_CONFIG = OmegaConf.create( { - "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), - "nodes" : - {"materials": {"columns": ["core.density_atomic"], 'drop_null': True}, - "elements": {"columns": ["atomic_mass", "radius_covalent", "radius_vanderwaals"], 'drop_null':True, 'label_column': 'symbol'}, - "space_groups": {'drop_null': True, 'label_column': 'spg'}, - "crystal_systems": {'drop_null': True, 'label_column': 'crystal_system'} + "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), + "nodes": { + "materials": {"columns": ["core.density_atomic"], "drop_null": True}, + "elements": { + "columns": ["atomic_mass", "radius_covalent", "radius_vanderwaals"], + "drop_null": True, + "label_column": "symbol", + }, + "space_groups": {"drop_null": True, "label_column": "spg"}, + "crystal_systems": {"drop_null": True, "label_column": "crystal_system"}, }, - "edges" : - { - "element_element_neighborsByGroupPeriod": {}, - "material_element_has": {}, - "material_spg_has": {}, - "material_crystalSystem_has": {} + "edges": { + "element_element_neighborsByGroupPeriod": {}, + "material_element_has": {}, + "material_spg_has": {}, + "material_crystalSystem_has": {}, + }, + "target": { + "materials": { + "columns": ["elasticity.g_vrh"], + "drop_null": True, + "filters": "[pc.field('elasticity.g_vrh') > 0, pc.field('elasticity.g_vrh') < 400]", + "encoders": "{'elasticity.g_vrh': to_log}", + } }, - "target":{ - "materials": {"columns": ["elasticity.g_vrh"], 'drop_null': True, - 'filters': "[pc.field('elasticity.g_vrh') > 0, pc.field('elasticity.g_vrh') < 400]", - 'encoders': "{'elasticity.g_vrh': to_log}"} - } } ) @@ -89,7 +94,10 @@ def to_log(x): "sparse": True, # "metapath": [('materials', 'has', 'elements'), ('elements', 'rev_has', 'materials')] # "metapath": [('materials', 'has', 'crystal_systems'), ('crystal_systems', 'rev_has', 'materials')] - "metapath": [('materials', 'has', 'space_groups'), ('space_groups', 'rev_has', 'materials')] + "metapath": [ + ("materials", "has", "space_groups"), + ("space_groups", "rev_has", "materials"), + ], }, 
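        # (All three candidate metapaths loop materials -> attribute ->
        # materials, so downstream code can query model("materials") for an
        # embedding of every material node.)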
"training": { "train_dir": os.path.join("data", "training_runs", "metapath2vec"), @@ -102,23 +110,25 @@ def to_log(x): "log_steps": 100, "eval_steps": 2000, "test_train_ratio": 0.8, - "test_max_iter": 150 - } + "test_max_iter": 150, + }, } ) -MLP_CONFIG = OmegaConf.create({ - "data": dict(DATA_CONFIG), - "model": { - "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline - }, - "training": { - "learning_rate": 0.001, - "train_ratio": 0.8, - "val_ratio": 0.1, - "epochs": 2000, +MLP_CONFIG = OmegaConf.create( + { + "data": dict(DATA_CONFIG), + "model": { + "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline + }, + "training": { + "learning_rate": 0.001, + "train_ratio": 0.8, + "val_ratio": 0.1, + "epochs": 2000, + }, } -}) +) #################################################################################################### #################################################################################################### @@ -139,18 +149,18 @@ def to_log(x): #################################################################################################### def build_heterograph(): """Build the initial heterogeneous graph from the materials database. - + Returns: torch_geometric.data.HeteroData: The constructed heterogeneous graph """ mdb = MPNearHull(DATA_CONFIG.dataset_dir) builder = HeteroGraphBuilder(mdb) - + # Define the "materials" node type (only a subset of columns is used here) for node_type, node_config in DATA_CONFIG.nodes.items(): node_config = OmegaConf.to_container(node_config) builder.add_node_type(node_type, **node_config) - + for edge_type, edge_config in DATA_CONFIG.edges.items(): edge_config = OmegaConf.to_container(edge_config) builder.add_edge_type(edge_type, **edge_config) @@ -164,12 +174,16 @@ def build_heterograph(): if "encoders" in target_config: encoders = target_config.pop("encoders") encoders = eval(encoders) - - builder.add_target_node_property(target_type, filters=filters, encoders=encoders, **target_config) - + + builder.add_target_node_property( + target_type, filters=filters, encoders=encoders, **target_config + ) + heterodata = builder.hetero_data LOGGER.info(f"HeteroData: {heterodata}") - heterodata["materials"].original_x = heterodata["materials"].x # Save original features + heterodata["materials"].original_x = heterodata[ + "materials" + ].x # Save original features return heterodata @@ -177,23 +191,21 @@ def heterograph_preprocessing(): """ Build the heterograph, apply transformations, partition the graph, and split the 'materials' nodes into training/validation/test subgraphs. - + Args: config (OmegaConf): A configuration object with the keys: - data: data-related parameters (e.g., dataset_dir, create_random_features, n_material_dim, train_ratio, val_ratio) - model: model-related parameters (e.g., n_partitions) - training: training-related parameters - + Returns: split_data (dict): A dictionary with keys "train", "train_val", "test", "test_val", each containing a subgraph for the corresponding split. """ # 1. Build the heterogeneous graph from the materials database - - + original_heterograph = build_heterograph() - - + # 2. Apply transformation: make the graph undirected. source_data = T.ToUndirected()(original_heterograph) # Free up memory. 
@@ -205,38 +217,41 @@ def heterograph_preprocessing(): # # Model # #################################################################################################### + class MLPBaseline(torch.nn.Module): def __init__(self, input_dim, hidden_dim, output_dim=1): super(MLPBaseline, self).__init__() self.net = torch.nn.Sequential( torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU(), - torch.nn.Linear(hidden_dim, output_dim) + torch.nn.Linear(hidden_dim, output_dim), ) - + def forward(self, x): return self.net(x) def train_mlp_baseline(heterodata, metapath2vec_model): - z = metapath2vec_model('materials', batch=heterodata['materials'].y_index.to(DEVICE)) - y = heterodata['materials'].y.to(DEVICE).squeeze() - + z = metapath2vec_model( + "materials", batch=heterodata["materials"].y_index.to(DEVICE) + ) + y = heterodata["materials"].y.to(DEVICE).squeeze() + material_indices = torch.randperm(z.size(0)) - + n_materials = z.size(0) train_ratio = MLP_CONFIG.training.train_ratio val_ratio = MLP_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + train_materials = total_train_materials[:train_val_size] test_materials = total_test_materials[:test_val_size] @@ -255,19 +270,21 @@ def train_mlp_baseline(heterodata, metapath2vec_model): input_dim = z.shape[1] hidden_dim = MLP_CONFIG.model.mlp_hidden_dim model = MLPBaseline(input_dim=input_dim, hidden_dim=hidden_dim).to(DEVICE) - optimizer = torch.optim.Adam(model.parameters(), lr=MLP_CONFIG.training.learning_rate) + optimizer = torch.optim.Adam( + model.parameters(), lr=MLP_CONFIG.training.learning_rate + ) loss_fn = torch.nn.L1Loss() - + # Initialize results storage results = { - "train": {"loss": [], "mae": [], "epochs": []}, + "train": {"loss": [], "mae": [], "epochs": []}, "train_val": {"loss": [], "mae": [], "epochs": []}, - "test": {"loss": [], "mae": [], "epochs": []}, - "test_val": {"loss": [], "mae": [], "epochs": []}, + "test": {"loss": [], "mae": [], "epochs": []}, + "test_val": {"loss": [], "mae": [], "epochs": []}, } def train_step(): - + model.train() optimizer.zero_grad() # Move this here, before the forward pass @@ -276,27 +293,25 @@ def train_step(): loss = loss_fn(y_pred, y[split_data["train"]]) loss.backward(retain_graph=True) optimizer.step() - + total_loss += loss.item() - + results["train"]["loss"].append(float(total_loss)) @torch.no_grad() def test_step(): model.eval() - + for split_name, split_materials in split_data.items(): y_pred = model(z[split_materials]).squeeze().cpu().numpy() y_real = y[split_materials].cpu().numpy() - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) mae = np.mean(np.abs(y_pred - y_real)) results[split_name]["mae"].append(float(mae)) - - for epoch in range(MLP_CONFIG.training.epochs): train_step() test_step() @@ -304,42 +319,51 @@ def test_step(): results["train_val"]["epochs"].append(epoch) results["test"]["epochs"].append(epoch) results["test_val"]["epochs"].append(epoch) - + loss_str = f"Epoch: {epoch}," for split_name, split_results in results.items(): loss_str += f"{split_name}: {split_results['mae'][-1]:.4f} " print(loss_str) - + return results + def train_metapath2vec(heterodata): - metapath=[] + metapath = [] for path in 
METAPATH2VEC_CONFIG.model.metapath: metapath.append(tuple(path)) - - num_nodes_dict = {node_type: heterodata[node_type].num_nodes for node_type in heterodata.node_types} - model = MetaPath2Vec(heterodata.edge_index_dict, - embedding_dim=METAPATH2VEC_CONFIG.model.embedding_dim, - metapath=metapath, - walk_length=METAPATH2VEC_CONFIG.model.walk_length, - context_size=METAPATH2VEC_CONFIG.model.context_size, - walks_per_node=METAPATH2VEC_CONFIG.model.walks_per_node, - num_negative_samples=METAPATH2VEC_CONFIG.model.num_negative_samples, - sparse=METAPATH2VEC_CONFIG.model.sparse, - num_nodes_dict=num_nodes_dict).to(DEVICE) - - loader = model.loader(batch_size=METAPATH2VEC_CONFIG.training.batch_size, - shuffle=True, - num_workers=METAPATH2VEC_CONFIG.training.num_workers) + + num_nodes_dict = { + node_type: heterodata[node_type].num_nodes + for node_type in heterodata.node_types + } + model = MetaPath2Vec( + heterodata.edge_index_dict, + embedding_dim=METAPATH2VEC_CONFIG.model.embedding_dim, + metapath=metapath, + walk_length=METAPATH2VEC_CONFIG.model.walk_length, + context_size=METAPATH2VEC_CONFIG.model.context_size, + walks_per_node=METAPATH2VEC_CONFIG.model.walks_per_node, + num_negative_samples=METAPATH2VEC_CONFIG.model.num_negative_samples, + sparse=METAPATH2VEC_CONFIG.model.sparse, + num_nodes_dict=num_nodes_dict, + ).to(DEVICE) + + loader = model.loader( + batch_size=METAPATH2VEC_CONFIG.training.batch_size, + shuffle=True, + num_workers=METAPATH2VEC_CONFIG.training.num_workers, + ) print(model) - optimizer = torch.optim.SparseAdam(list(model.parameters()), - lr=METAPATH2VEC_CONFIG.training.learning_rate) + optimizer = torch.optim.SparseAdam( + list(model.parameters()), lr=METAPATH2VEC_CONFIG.training.learning_rate + ) results = { - "train": {"mae": [], "loss": [], "epochs": []}, + "train": {"mae": [], "loss": [], "epochs": []}, "train_val": {"mae": [], "loss": [], "epochs": []}, - "test": {"mae": [], "loss": [], "epochs": []}, - "test_val": {"mae": [], "loss": [], "epochs": []}, + "test": {"mae": [], "loss": [], "epochs": []}, + "test_val": {"mae": [], "loss": [], "epochs": []}, } def train_step(): @@ -353,31 +377,31 @@ def train_step(): optimizer.step() total_loss += loss.item() - + results["train"]["loss"].append(float(total_loss / len(loader))) @torch.no_grad() def test_step(): model.eval() - z = model('materials', batch=heterodata['materials'].y_index.to(DEVICE)) - y = heterodata['materials'].y + z = model("materials", batch=heterodata["materials"].y_index.to(DEVICE)) + y = heterodata["materials"].y material_indices = torch.randperm(z.size(0)) - + n_materials = z.size(0) train_ratio = METAPATH2VEC_CONFIG.training.train_ratio val_ratio = METAPATH2VEC_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + train_materials = total_train_materials[:train_val_size] test_materials = total_test_materials[:test_val_size] @@ -392,18 +416,20 @@ def test_step(): "test": test_materials, "test_val": test_val_materials, } - + reg = linear_model.LinearRegression() - reg.fit(z[split_data["train"]].cpu().numpy(), y[split_data["train"]].cpu().numpy()) - + reg.fit( + z[split_data["train"]].cpu().numpy(), y[split_data["train"]].cpu().numpy() + ) + for split_name, split_materials in split_data.items(): y_pred = 
reg.predict(z[split_materials].cpu().numpy()) y_real = y[split_materials].cpu().numpy() - + if split_name != "train": loss = np.mean(np.abs(y_pred - y_real)) results[split_name]["loss"].append(float(loss)) - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) @@ -419,14 +445,13 @@ def test_step(): results["train_val"]["epochs"].append(epoch) results["test"]["epochs"].append(epoch) results["test_val"]["epochs"].append(epoch) - + loss_str = f"Epoch: {epoch}," for split_name, split_results in results.items(): loss_str += f"{split_name}: {split_results['mae'][-1]:.4f} " print(loss_str) - - return model, results + return model, results def main(): @@ -434,9 +459,9 @@ def main(): model, linear_results = train_metapath2vec(heterodata) mlp_results = train_mlp_baseline(heterodata, model) - + training_dir = METAPATH2VEC_CONFIG.training.train_dir - + runs_dir = os.path.join(training_dir, "runs") os.makedirs(runs_dir, exist_ok=True) n_runs = len(os.listdir(runs_dir)) @@ -448,37 +473,39 @@ def main(): with open(os.path.join(results_dir, "mlp_config.json"), "w") as f: json.dump(OmegaConf.to_container(MLP_CONFIG), f) - + with open(os.path.join(results_dir, "linear_results.json"), "w") as f: json.dump(linear_results, f) - + with open(os.path.join(results_dir, "mlp_results.json"), "w") as f: json.dump(mlp_results, f) - - plot_learning_curves(linear_results, os.path.join(results_dir, "linear_learning_curves.png")) - plot_learning_curves(mlp_results, os.path.join(results_dir, "mlp_learning_curves.png")) - - - + + plot_learning_curves( + linear_results, os.path.join(results_dir, "linear_learning_curves.png") + ) + plot_learning_curves( + mlp_results, os.path.join(results_dir, "mlp_learning_curves.png") + ) + z_per_type = { - "materials": model('materials'), + "materials": model("materials"), # "elements": model('elements'), - "space_groups": model('space_groups'), + "space_groups": model("space_groups"), # "crystal_systems": model('crystal_systems'), } targets_per_type = { - "materials": 10 ** heterodata['materials'].y.cpu().numpy(), + "materials": 10 ** heterodata["materials"].y.cpu().numpy(), } targets_labels_per_type = { - "materials": heterodata['materials'].y_label_name[0], + "materials": heterodata["materials"].y_label_name[0], } targets_index_per_type = { - "materials": heterodata['materials'].y_index.cpu().numpy(), + "materials": heterodata["materials"].y_index.cpu().numpy(), } LOGGER.info(f"Targets index per type: {len(heterodata['elements'].labels)}") labels_per_type = { - "elements": heterodata['elements'].labels, - "space_groups": heterodata['space_groups'].labels, + "elements": heterodata["elements"].labels, + "space_groups": heterodata["space_groups"].labels, # "crystal_systems": heterodata['crystal_systems'].labels, } color_per_type = { @@ -486,25 +513,27 @@ def main(): "space_groups": "black", # "crystal_systems": "black", } - - create_umap_plot(z_per_type, - targets_per_type=targets_per_type, - targets_index_per_type=targets_index_per_type, - targets_labels_per_type=targets_labels_per_type, - labels_per_type=labels_per_type, - color_per_type=color_per_type, - save_path=os.path.join(results_dir, "umap.png"), - n_neighbors=30) - # create_umap_plot3d(z_per_type, + + create_umap_plot( + z_per_type, + targets_per_type=targets_per_type, + targets_index_per_type=targets_index_per_type, + targets_labels_per_type=targets_labels_per_type, + labels_per_type=labels_per_type, + color_per_type=color_per_type, + save_path=os.path.join(results_dir, 
"umap.png"), + n_neighbors=30, + ) + # create_umap_plot3d(z_per_type, # targets_per_type=targets_per_type, # targets_index_per_type=targets_index_per_type, # labels_per_type=labels_per_type, # color_per_type=color_per_type, # save_path=os.path.join(results_dir, "umap_materials_elements_3d.png"), # n_neighbors=30) - -def plot_learning_curves(results, save_path, measure='mae'): + +def plot_learning_curves(results, save_path, measure="mae"): """ Plots the learning curves for a specified measure from the results dictionary. @@ -515,23 +544,25 @@ def plot_learning_curves(results, save_path, measure='mae'): measure (str): The measure to plot (e.g., 'loss' or 'mae'). Default is 'loss'. """ plt.figure(figsize=(10, 6)) - + # Iterate over the splits in the results dictionary for idx, split in enumerate(results): split_data = results[split] - + # Check if the desired measure is available in this split's data if measure not in split_data: - print(f"Warning: Measure '{measure}' not found for split '{split}'. Skipping.") + print( + f"Warning: Measure '{measure}' not found for split '{split}'. Skipping." + ) continue # Use the provided 'epochs' list if available, otherwise create a range based on the measure length epochs = split_data.get("epochs", list(range(len(split_data[measure])))) values = split_data[measure] - + # Select a color for this plot color = DEFAULT_COLORS[idx % len(DEFAULT_COLORS)] - + # Plot the curve for this split plt.plot(epochs, values, label=split, color=color, linewidth=2) @@ -545,20 +576,20 @@ def plot_learning_curves(results, save_path, measure='mae'): plt.close() - - -def create_umap_plot(z_per_type, - targets_per_type:dict=None, - targets_index_per_type:dict=None, - targets_labels_per_type:dict=None, - filter_index_per_type:dict=None, - labels_per_type:dict=None, - color_per_type:dict=None, - save_path=".", - n_neighbors=50, - n_jobs=4): +def create_umap_plot( + z_per_type, + targets_per_type: dict = None, + targets_index_per_type: dict = None, + targets_labels_per_type: dict = None, + filter_index_per_type: dict = None, + labels_per_type: dict = None, + color_per_type: dict = None, + save_path=".", + n_neighbors=50, + n_jobs=4, +): node_types = list(z_per_type.keys()) - + if targets_per_type is None: targets_per_type = {} if targets_index_per_type is None: @@ -571,98 +602,100 @@ def create_umap_plot(z_per_type, color_per_type = {} if targets_labels_per_type is None: targets_labels_per_type = {} - - z_global_idx_per_type={} - z_local_idx_per_type={} - local_global_idx_mapping_per_type={} + + z_global_idx_per_type = {} + z_local_idx_per_type = {} + local_global_idx_mapping_per_type = {} z_node_type_mapping = {} - - z_all=None - total_n_nodes=0 + z_all = None + total_n_nodes = 0 for i, (node_type, z) in enumerate(z_per_type.items()): - z=z.detach().cpu().numpy() + z = z.detach().cpu().numpy() n_nodes = z.shape[0] - - + LOGGER.info(f"Node type: {node_type}, Number of nodes: {n_nodes}") - + z_node_type_mapping[node_type] = i - z_global_idx_per_type[node_type] = np.arange(total_n_nodes, total_n_nodes + n_nodes) + z_global_idx_per_type[node_type] = np.arange( + total_n_nodes, total_n_nodes + n_nodes + ) z_local_idx_per_type[node_type] = np.arange(n_nodes) - local_global_idx_mapping_per_type[node_type] = {i:j for i,j in zip(z_local_idx_per_type[node_type], z_global_idx_per_type[node_type])} + local_global_idx_mapping_per_type[node_type] = { + i: j + for i, j in zip( + z_local_idx_per_type[node_type], z_global_idx_per_type[node_type] + ) + } if z_all is None: z_all = z else: z_all = 
np.concatenate([z_all, z], axis=0) - - total_n_nodes+=n_nodes - - + + total_n_nodes += n_nodes + # Apply UMAP to reduce dimensions to 2. reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=2, n_jobs=n_jobs) embedding = reducer.fit_transform(z_all) # Create the scatter plot. plt.figure(figsize=(10, 8)) - - - handles=[] - scatter_handles=[] + + handles = [] + scatter_handles = [] for node_type in node_types: LOGGER.info(f"Plotting {node_type}") - + color = color_per_type.get(node_type, None) node_labels = labels_per_type.get(node_type, None) targets = targets_per_type.get(node_type, None) target_idx = targets_index_per_type.get(node_type, None) filter_idx = filter_index_per_type.get(node_type, None) - + node_idx = z_global_idx_per_type.get(node_type, None) LOGGER.info(f"Node index: {node_idx}") if target_idx is not None: LOGGER.info(f"Target index: {target_idx}") - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in target_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in target_idx + ] if node_labels is not None: - node_labels = node_labels[target_idx] # Needs to be local index + node_labels = node_labels[target_idx] # Needs to be local index if filter_idx is not None: - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in filter_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in filter_idx + ] if node_labels is not None: - node_labels = node_labels[filter_idx] # Needs to be local index - - + node_labels = node_labels[filter_idx] # Needs to be local index + if targets is not None: c = targets - cmap=DEFAULT_CMAP + cmap = DEFAULT_CMAP elif color is not None: - c=color - cmap=None + c = color + cmap = None handles.append(mpatches.Patch(color=color, label=node_type)) - - - x = embedding[node_idx, 0] # Needs to be global index - y = embedding[node_idx, 1] # Needs to be global index - scatter = plt.scatter(x, y, s=10, alpha=0.8, - c=c, - cmap=cmap) - c=None - + + x = embedding[node_idx, 0] # Needs to be global index + y = embedding[node_idx, 1] # Needs to be global index + scatter = plt.scatter(x, y, s=10, alpha=0.8, c=c, cmap=cmap) + c = None + if targets is not None: LOGGER.info(f"Plotting {node_type} targets") scatter_handles.append(scatter) - + if node_labels is not None: LOGGER.info(f"Plotting {node_type} labels, n_labels: {len(node_labels)}") for i, label in enumerate(node_labels): plt.annotate(label, (x[i], y[i]), fontsize=8, alpha=1) - if targets_per_type: - label="" + label = "" for node_type in node_types: - label+=targets_labels_per_type.get(node_type, "") + label += targets_labels_per_type.get(node_type, "") plt.colorbar(scatter_handles[0], label=label) - plt.legend(handles=handles) + plt.legend(handles=handles) plt.title("UMAP Projection of Node Embeddings") plt.xlabel("UMAP 1") plt.ylabel("UMAP 2") @@ -670,17 +703,17 @@ def create_umap_plot(z_per_type, plt.close() - - -def create_umap_plot3d(z_per_type, - targets_per_type: dict = None, - targets_index_per_type: dict = None, - filter_index_per_type: dict = None, - labels_per_type: dict = None, - color_per_type: dict = None, - save_path="umap_3d_plot.png", - n_neighbors=50, - n_jobs=4): +def create_umap_plot3d( + z_per_type, + targets_per_type: dict = None, + targets_index_per_type: dict = None, + filter_index_per_type: dict = None, + labels_per_type: dict = None, + color_per_type: dict = None, + save_path="umap_3d_plot.png", + n_neighbors=50, + n_jobs=4, +): """ Creates a 3D UMAP scatter plot from node embeddings for multiple node types. 
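As a usage sketch for the 2D helper above (toy embedding shapes and a hypothetical output filename, not taken from this repository): every node type should come with either a `targets_per_type` entry or a `color_per_type` entry, because as written `c` is only assigned inside those two branches of the plotting loop.

```python
import torch

z_per_type = {
    "materials": torch.randn(200, 64),    # toy material embeddings
    "space_groups": torch.randn(30, 64),  # toy space-group embeddings
}
# Give every type an explicit color so the scatter call never sees an
# unassigned `c` (no targets are supplied in this sketch).
color_per_type = {"materials": "tab:blue", "space_groups": "black"}

create_umap_plot(
    z_per_type,
    color_per_type=color_per_type,
    save_path="umap_demo.png",
    n_neighbors=15,
)
```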
@@ -694,9 +727,9 @@ def create_umap_plot3d(z_per_type, save_path (str): Path (including filename) to save the plot. n_jobs (int): Number of parallel jobs to run in UMAP. """ - + node_types = list(z_per_type.keys()) - + # Set default dictionaries if None. if targets_per_type is None: targets_per_type = {} @@ -708,7 +741,7 @@ def create_umap_plot3d(z_per_type, labels_per_type = {} if color_per_type is None: color_per_type = {} - + z_global_idx_per_type = {} z_local_idx_per_type = {} local_global_idx_mapping_per_type = {} @@ -721,13 +754,17 @@ def create_umap_plot3d(z_per_type, z = z.detach().cpu().numpy() n_nodes = z.shape[0] LOGGER.info(f"Node type: {node_type}, Number of nodes: {n_nodes}") - + z_node_type_mapping[node_type] = i - z_global_idx_per_type[node_type] = np.arange(total_n_nodes, total_n_nodes + n_nodes) + z_global_idx_per_type[node_type] = np.arange( + total_n_nodes, total_n_nodes + n_nodes + ) z_local_idx_per_type[node_type] = np.arange(n_nodes) local_global_idx_mapping_per_type[node_type] = { - local_idx: global_idx - for local_idx, global_idx in zip(z_local_idx_per_type[node_type], z_global_idx_per_type[node_type]) + local_idx: global_idx + for local_idx, global_idx in zip( + z_local_idx_per_type[node_type], z_global_idx_per_type[node_type] + ) } z_all = z if z_all is None else np.concatenate([z_all, z], axis=0) total_n_nodes += n_nodes @@ -738,37 +775,43 @@ def create_umap_plot3d(z_per_type, # Create the 3D scatter plot. fig = plt.figure(figsize=(10, 8)) - ax = fig.add_subplot(111, projection='3d') - + ax = fig.add_subplot(111, projection="3d") + handles = [] scatter_handles = [] - + for node_type in node_types: LOGGER.info(f"Plotting {node_type}") - + color = color_per_type.get(node_type, None) node_labels = labels_per_type.get(node_type, None) targets = targets_per_type.get(node_type, None) target_idx = targets_index_per_type.get(node_type, None) filter_idx = filter_index_per_type.get(node_type, None) - + # Start with all global indices for this node type. node_idx = z_global_idx_per_type.get(node_type, None) LOGGER.info(f"Node index: {node_idx}") - + # If a target index is specified, map local indices to global indices. if target_idx is not None: LOGGER.info(f"Target index: {target_idx}") - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in target_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in target_idx + ] if node_labels is not None: - node_labels = node_labels[target_idx] # Select labels based on local indices. - + node_labels = node_labels[ + target_idx + ] # Select labels based on local indices. + # Apply additional filtering if provided. if filter_idx is not None: - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in filter_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in filter_idx + ] if node_labels is not None: node_labels = node_labels[filter_idx] - + # Determine the color and colormap. if targets is not None: c = targets @@ -785,12 +828,12 @@ def create_umap_plot3d(z_per_type, x = embedding[node_idx, 0] y = embedding[node_idx, 1] z_coord = embedding[node_idx, 2] - + scatter = ax.scatter(x, y, z_coord, s=10, alpha=0.8, c=c, cmap=cmap) if targets is not None: LOGGER.info(f"Plotting {node_type} targets") scatter_handles.append(scatter) - + # Annotate points with labels if provided. 
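        # (The colorbar added below reads only scatter_handles[0]; plotting
        # several target-colored node types therefore assumes they share a
        # single value scale.)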
if node_labels is not None: LOGGER.info(f"Plotting {node_type} labels, n_labels: {len(node_labels)}") @@ -800,44 +843,35 @@ def create_umap_plot3d(z_per_type, # Add a colorbar if targets were used for coloring. if targets_per_type and scatter_handles: cbar = plt.colorbar(scatter_handles[0], ax=ax, pad=0.1) - cbar.set_label('Label') - - plt.legend(handles=handles) + cbar.set_label("Label") + + plt.legend(handles=handles) plt.title("3D UMAP Projection of Node Embeddings") ax.set_xlabel("UMAP 1") ax.set_ylabel("UMAP 2") ax.set_zlabel("UMAP 3") - + plt.savefig(save_path) plt.close() - if __name__ == "__main__": - + logger = logging.getLogger("__main__") logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) handler.setLevel(logging.DEBUG) logger.addHandler(handler) - - + logger = logging.getLogger("matgraphdb.pyg.data.hetero_graph") logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) handler.setLevel(logging.DEBUG) logger.addHandler(handler) main() - - - - - - - - - - - diff --git a/matgraphdb/pyg/models/transe/train.py b/matgraphdb/pyg/models/transe/train.py index 4dbf775..d836118 100644 --- a/matgraphdb/pyg/models/transe/train.py +++ b/matgraphdb/pyg/models/transe/train.py @@ -1,8 +1,9 @@ import json +import logging import os import time -import logging +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -12,19 +13,19 @@ import torch.optim.lr_scheduler as lr_scheduler import torch_geometric as pyg import torch_geometric.transforms as T + +######################################################################################################################## +import umap from omegaconf import OmegaConf +from sklearn import linear_model from torch_geometric import nn as pyg_nn +from torch_geometric.nn import MetaPath2Vec, TransE -from matgraphdb.materials.datasets.mp_near_hull import MPNearHull -from matgraphdb.pyg.data import HeteroGraphBuilder +from matgraphdb.core.datasets.mp_near_hull import MPNearHull +from matgraphdb.pyg.builders import HeteroGraphBuilder from matgraphdb.pyg.models.metapath2vec.metrics import plot_pca -import matplotlib.patches as mpatches -from sklearn import linear_model -from matgraphdb.utils.colors import DEFAULT_COLORS, DEFAULT_CMAP +from matgraphdb.utils.colors import DEFAULT_CMAP, DEFAULT_COLORS from matgraphdb.utils.config import config -from torch_geometric.nn import MetaPath2Vec, TransE -######################################################################################################################## -import umap LOGGER = logging.getLogger(__name__) @@ -45,33 +46,37 @@ # LOGGER.addHandler(logging.StreamHandler()) - - - def to_log(x): return torch.tensor(np.log10(x), dtype=torch.float32) - + + DATA_CONFIG = OmegaConf.create( { - "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), - "nodes" : - {"materials": {"columns": ["core.density_atomic"], 'drop_null': True}, - "elements": {"columns": ["atomic_mass", "radius_covalent", "radius_vanderwaals"], 'drop_null':True, 'label_column': 'symbol'}, - "space_groups": {'drop_null': True, 'label_column': 'spg'}, - 
"crystal_systems": {'drop_null': True, 'label_column': 'crystal_system'} + "dataset_dir": os.path.join("data", "datasets", "MPNearHull"), + "nodes": { + "materials": {"columns": ["core.density_atomic"], "drop_null": True}, + "elements": { + "columns": ["atomic_mass", "radius_covalent", "radius_vanderwaals"], + "drop_null": True, + "label_column": "symbol", + }, + "space_groups": {"drop_null": True, "label_column": "spg"}, + "crystal_systems": {"drop_null": True, "label_column": "crystal_system"}, }, - "edges" : - { - "element_element_neighborsByGroupPeriod": {}, - "material_element_has": {}, - "material_spg_has": {}, - "material_crystalSystem_has": {} + "edges": { + "element_element_neighborsByGroupPeriod": {}, + "material_element_has": {}, + "material_spg_has": {}, + "material_crystalSystem_has": {}, + }, + "target": { + "materials": { + "columns": ["elasticity.g_vrh"], + "drop_null": True, + "filters": "[pc.field('elasticity.g_vrh') > 0, pc.field('elasticity.g_vrh') < 400]", + "encoders": "{'elasticity.g_vrh': to_log}", + } }, - "target":{ - "materials": {"columns": ["elasticity.g_vrh"], 'drop_null': True, - 'filters': "[pc.field('elasticity.g_vrh') > 0, pc.field('elasticity.g_vrh') < 400]", - 'encoders': "{'elasticity.g_vrh': to_log}"} - } } ) @@ -79,7 +84,7 @@ def to_log(x): { "data": dict(DATA_CONFIG), "model": { - "hidden_channels": 32, # "embedding_dim": 4, + "hidden_channels": 32, # "embedding_dim": 4, "margin": 1.0, "p_norm": 1.0, "sparse": False, @@ -95,23 +100,25 @@ def to_log(x): "log_steps": 100, "eval_steps": 2000, "test_train_ratio": 0.8, - "test_max_iter": 150 - } + "test_max_iter": 150, + }, } ) -MLP_CONFIG = OmegaConf.create({ - "data": dict(DATA_CONFIG), - "model": { - "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline - }, - "training": { - "learning_rate": 0.001, - "train_ratio": 0.8, - "val_ratio": 0.1, - "epochs": 2000, +MLP_CONFIG = OmegaConf.create( + { + "data": dict(DATA_CONFIG), + "model": { + "mlp_hidden_dim": 32, # hidden dimension for the MLP baseline + }, + "training": { + "learning_rate": 0.001, + "train_ratio": 0.8, + "val_ratio": 0.1, + "epochs": 2000, + }, } -}) +) #################################################################################################### #################################################################################################### @@ -132,18 +139,18 @@ def to_log(x): #################################################################################################### def build_heterograph(): """Build the initial heterogeneous graph from the materials database. 
- + Returns: torch_geometric.data.HeteroData: The constructed heterogeneous graph """ mdb = MPNearHull(DATA_CONFIG.dataset_dir) builder = HeteroGraphBuilder(mdb) - + # Define the "materials" node type (only a subset of columns is used here) for node_type, node_config in DATA_CONFIG.nodes.items(): node_config = OmegaConf.to_container(node_config) builder.add_node_type(node_type, **node_config) - + for edge_type, edge_config in DATA_CONFIG.edges.items(): edge_config = OmegaConf.to_container(edge_config) builder.add_edge_type(edge_type, **edge_config) @@ -157,21 +164,26 @@ def build_heterograph(): if "encoders" in target_config: encoders = target_config.pop("encoders") encoders = eval(encoders) - - builder.add_target_node_property(target_type, filters=filters, encoders=encoders, **target_config) - + + builder.add_target_node_property( + target_type, filters=filters, encoders=encoders, **target_config + ) + heterodata = builder.hetero_data LOGGER.info(f"HeteroData: {heterodata}") - heterodata["materials"].original_x = heterodata["materials"].x # Save original features + heterodata["materials"].original_x = heterodata[ + "materials" + ].x # Save original features return heterodata + def split_by_material_nodes(parent_data): """Split material nodes into train/val/test sets and create corresponding subgraphs. - + Args: parent_data: The full heterograph containing all data config: Configuration object containing split ratios - + Returns: Dictionary containing the split subgraphs for train, validation and test sets """ @@ -179,60 +191,71 @@ def split_by_material_nodes(parent_data): n_materials = parent_data["materials"].num_nodes node_ids = parent_data["materials"].node_ids material_indices = torch.randperm(n_materials) - + train_ratio = TRANSE_CONFIG.training.train_ratio val_ratio = TRANSE_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + # Split train and test into their validation sets train_val_materials = total_train_materials[:train_val_size] train_materials = total_train_materials[train_val_size:] test_val_materials = total_test_materials[:test_val_size] test_materials = total_test_materials[test_val_size:] - + print("\nSplit percentages:") print(f"Total: {n_materials}") print(f"Train: {len(train_materials)/n_materials*100:.1f}%") print(f"Train val: {len(train_val_materials)/n_materials*100:.1f}%") print(f"Test: {len(test_materials)/n_materials*100:.1f}%") print(f"Test val: {len(test_val_materials)/n_materials*100:.1f}%") - total_pct = (len(train_materials) + len(train_val_materials) + len(test_materials) + len(test_val_materials)) / n_materials * 100 + total_pct = ( + ( + len(train_materials) + + len(train_val_materials) + + len(test_materials) + + len(test_val_materials) + ) + / n_materials + * 100 + ) print(f"Total: {total_pct:.1f}%\n") - + # Create subgraphs for each split split_dicts = { "train": {"materials": train_materials}, - "train_val": {"materials": train_val_materials}, + "train_val": {"materials": train_val_materials}, "test": {"materials": test_materials}, - "test_val": {"materials": test_val_materials} + "test_val": {"materials": test_val_materials}, } - + split_data = {} for split_name, split_dict in split_dicts.items(): data = parent_data.subgraph(split_dict) - 
data["materials"].node_ids = parent_data["materials"].node_ids[split_dict["materials"]] + data["materials"].node_ids = parent_data["materials"].node_ids[ + split_dict["materials"] + ] split_data[split_name] = data - + print(split_data["train"]["materials"].node_ids) print(f"Train materials: {len(train_materials)}") print(f"Train val materials: {len(train_val_materials)}") print(f"Test materials: {len(test_materials)}") print(f"Test val materials: {len(test_val_materials)}") - + # For each split, reduce the target values and record indices y_id_map = { int(y_id): float(y) - for y_id, y in zip(parent_data['materials'].y_index, parent_data['materials'].y) + for y_id, y in zip(parent_data["materials"].y_index, parent_data["materials"].y) } - + for data in split_data.values(): y_vals = [] ids = [] @@ -245,42 +268,41 @@ def split_by_material_nodes(parent_data): data["materials"].y = torch.tensor(y_vals) data["materials"].y_node_ids = torch.tensor(node_ids_list) data["materials"].y_split_index = torch.tensor(ids) - + return split_data + def heterograph_preprocessing(): """ Build the heterograph, apply transformations, partition the graph, and split the 'materials' nodes into training/validation/test subgraphs. - + Args: config (OmegaConf): A configuration object with the keys: - data: data-related parameters (e.g., dataset_dir, create_random_features, n_material_dim, train_ratio, val_ratio) - model: model-related parameters (e.g., n_partitions) - training: training-related parameters - + Returns: split_data (dict): A dictionary with keys "train", "train_val", "test", "test_val", each containing a subgraph for the corresponding split. """ # 1. Build the heterogeneous graph from the materials database - - + original_heterograph = build_heterograph() - - + # 2. Apply transformation: make the graph undirected. - + source_data = T.ToUndirected()(original_heterograph) # Free up memory. 
original_heterograph = None # source_data = source_data.to_homogeneous() - + split_data = split_by_material_nodes(source_data) - + for split_name, split_data in split_data.items(): split_data = split_data.to_homogeneous() - + return split_data @@ -288,38 +310,41 @@ def heterograph_preprocessing(): # # Model # #################################################################################################### + class MLPBaseline(torch.nn.Module): def __init__(self, input_dim, hidden_dim, output_dim=1): super(MLPBaseline, self).__init__() self.net = torch.nn.Sequential( torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU(), - torch.nn.Linear(hidden_dim, output_dim) + torch.nn.Linear(hidden_dim, output_dim), ) - + def forward(self, x): return self.net(x) def train_mlp_baseline(heterodata, metapath2vec_model): - z = metapath2vec_model('materials', batch=heterodata['materials'].y_index.to(DEVICE)) - y = heterodata['materials'].y.to(DEVICE).squeeze() - + z = metapath2vec_model( + "materials", batch=heterodata["materials"].y_index.to(DEVICE) + ) + y = heterodata["materials"].y.to(DEVICE).squeeze() + material_indices = torch.randperm(z.size(0)) - + n_materials = z.size(0) train_ratio = MLP_CONFIG.training.train_ratio val_ratio = MLP_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + train_materials = total_train_materials[:train_val_size] test_materials = total_test_materials[:test_val_size] @@ -338,19 +363,21 @@ def train_mlp_baseline(heterodata, metapath2vec_model): input_dim = z.shape[1] hidden_dim = MLP_CONFIG.model.mlp_hidden_dim model = MLPBaseline(input_dim=input_dim, hidden_dim=hidden_dim).to(DEVICE) - optimizer = torch.optim.Adam(model.parameters(), lr=MLP_CONFIG.training.learning_rate) + optimizer = torch.optim.Adam( + model.parameters(), lr=MLP_CONFIG.training.learning_rate + ) loss_fn = torch.nn.L1Loss() - + # Initialize results storage results = { - "train": {"loss": [], "mae": [], "epochs": []}, + "train": {"loss": [], "mae": [], "epochs": []}, "train_val": {"loss": [], "mae": [], "epochs": []}, - "test": {"loss": [], "mae": [], "epochs": []}, - "test_val": {"loss": [], "mae": [], "epochs": []}, + "test": {"loss": [], "mae": [], "epochs": []}, + "test_val": {"loss": [], "mae": [], "epochs": []}, } def train_step(): - + model.train() optimizer.zero_grad() # Move this here, before the forward pass @@ -359,27 +386,25 @@ def train_step(): loss = loss_fn(y_pred, y[split_data["train"]]) loss.backward(retain_graph=True) optimizer.step() - + total_loss += loss.item() - + results["train"]["loss"].append(float(total_loss)) @torch.no_grad() def test_step(): model.eval() - + for split_name, split_materials in split_data.items(): y_pred = model(z[split_materials]).squeeze().cpu().numpy() y_real = y[split_materials].cpu().numpy() - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) mae = np.mean(np.abs(y_pred - y_real)) results[split_name]["mae"].append(float(mae)) - - for epoch in range(MLP_CONFIG.training.epochs): train_step() test_step() @@ -387,40 +412,44 @@ def test_step(): results["train_val"]["epochs"].append(epoch) results["test"]["epochs"].append(epoch) results["test_val"]["epochs"].append(epoch) - + loss_str = f"Epoch: {epoch}," for 
split_name, split_results in results.items(): loss_str += f"{split_name}: {split_results['mae'][-1]:.4f} " print(loss_str) - + return results + def train_transe(split_data): - + num_nodes = sum(num_nodes_dict.values()) num_relations = sum(num_relations_dict.values()) - model = TransE(num_nodes=num_nodes, - num_relations=num_relations, - hidden_channels=TRANSE_CONFIG.model.hidden_channels, - margin=TRANSE_CONFIG.model.margin, - p_norm=TRANSE_CONFIG.model.p_norm, - sparse=TRANSE_CONFIG.model.sparse).to(DEVICE) - + model = TransE( + num_nodes=num_nodes, + num_relations=num_relations, + hidden_channels=TRANSE_CONFIG.model.hidden_channels, + margin=TRANSE_CONFIG.model.margin, + p_norm=TRANSE_CONFIG.model.p_norm, + sparse=TRANSE_CONFIG.model.sparse, + ).to(DEVICE) print(model) - optimizer = torch.optim.Adam(model.parameters(), lr=TRANSE_CONFIG.training.learning_rate) + optimizer = torch.optim.Adam( + model.parameters(), lr=TRANSE_CONFIG.training.learning_rate + ) loader = model.loader( - head_index=train_data.edge_index[0], - rel_type=train_data.edge_type, - tail_index=train_data.edge_index[1], - batch_size=1000, - shuffle=True, - ) + head_index=train_data.edge_index[0], + rel_type=train_data.edge_type, + tail_index=train_data.edge_index[1], + batch_size=1000, + shuffle=True, + ) results = { - "train": {"mae": [], "loss": [], "epochs": []}, + "train": {"mae": [], "loss": [], "epochs": []}, "train_val": {"mae": [], "loss": [], "epochs": []}, - "test": {"mae": [], "loss": [], "epochs": []}, - "test_val": {"mae": [], "loss": [], "epochs": []}, + "test": {"mae": [], "loss": [], "epochs": []}, + "test_val": {"mae": [], "loss": [], "epochs": []}, } def train_step(): @@ -434,31 +463,31 @@ def train_step(): optimizer.step() total_loss += loss.item() - + results["train"]["loss"].append(float(total_loss / len(loader))) @torch.no_grad() def test_step(): model.eval() - z = model('materials', batch=heterodata['materials'].y_index.to(DEVICE)) - y = heterodata['materials'].y + z = model("materials", batch=heterodata["materials"].y_index.to(DEVICE)) + y = heterodata["materials"].y material_indices = torch.randperm(z.size(0)) - + n_materials = z.size(0) train_ratio = METAPATH2VEC_CONFIG.training.train_ratio val_ratio = METAPATH2VEC_CONFIG.training.val_ratio test_ratio = 1 - train_ratio - + train_size = int(train_ratio * n_materials) test_size = int(test_ratio * n_materials) train_val_size = int(val_ratio * train_size) test_val_size = int(val_ratio * test_size) - + total_train_materials = material_indices[:train_size] total_test_materials = material_indices[train_size:] - + train_materials = total_train_materials[:train_val_size] test_materials = total_test_materials[:test_val_size] @@ -473,18 +502,20 @@ def test_step(): "test": test_materials, "test_val": test_val_materials, } - + reg = linear_model.LinearRegression() - reg.fit(z[split_data["train"]].cpu().numpy(), y[split_data["train"]].cpu().numpy()) - + reg.fit( + z[split_data["train"]].cpu().numpy(), y[split_data["train"]].cpu().numpy() + ) + for split_name, split_materials in split_data.items(): y_pred = reg.predict(z[split_materials].cpu().numpy()) y_real = y[split_materials].cpu().numpy() - + if split_name != "train": loss = np.mean(np.abs(y_pred - y_real)) results[split_name]["loss"].append(float(loss)) - + y_pred = np.array([10**value for value in y_pred]) y_real = np.array([10**value for value in y_real]) @@ -500,12 +531,12 @@ def test_step(): results["train_val"]["epochs"].append(epoch) results["test"]["epochs"].append(epoch) 
results["test_val"]["epochs"].append(epoch) - + loss_str = f"Epoch: {epoch}," for split_name, split_results in results.items(): loss_str += f"{split_name}: {split_results['mae'][-1]:.4f} " print(loss_str) - + return model, results @@ -514,9 +545,9 @@ def main(): model, linear_results = train_metapath2vec(heterodata) mlp_results = train_mlp_baseline(heterodata, model) - + training_dir = METAPATH2VEC_CONFIG.training.train_dir - + runs_dir = os.path.join(training_dir, "runs") os.makedirs(runs_dir, exist_ok=True) n_runs = len(os.listdir(runs_dir)) @@ -528,37 +559,39 @@ def main(): with open(os.path.join(results_dir, "mlp_config.json"), "w") as f: json.dump(OmegaConf.to_container(MLP_CONFIG), f) - + with open(os.path.join(results_dir, "linear_results.json"), "w") as f: json.dump(linear_results, f) - + with open(os.path.join(results_dir, "mlp_results.json"), "w") as f: json.dump(mlp_results, f) - - plot_learning_curves(linear_results, os.path.join(results_dir, "linear_learning_curves.png")) - plot_learning_curves(mlp_results, os.path.join(results_dir, "mlp_learning_curves.png")) - - - + + plot_learning_curves( + linear_results, os.path.join(results_dir, "linear_learning_curves.png") + ) + plot_learning_curves( + mlp_results, os.path.join(results_dir, "mlp_learning_curves.png") + ) + z_per_type = { - "materials": model('materials'), + "materials": model("materials"), # "elements": model('elements'), - "space_groups": model('space_groups'), + "space_groups": model("space_groups"), # "crystal_systems": model('crystal_systems'), } targets_per_type = { - "materials": 10 ** heterodata['materials'].y.cpu().numpy(), + "materials": 10 ** heterodata["materials"].y.cpu().numpy(), } targets_labels_per_type = { - "materials": heterodata['materials'].y_label_name[0], + "materials": heterodata["materials"].y_label_name[0], } targets_index_per_type = { - "materials": heterodata['materials'].y_index.cpu().numpy(), + "materials": heterodata["materials"].y_index.cpu().numpy(), } LOGGER.info(f"Targets index per type: {len(heterodata['elements'].labels)}") labels_per_type = { - "elements": heterodata['elements'].labels, - "space_groups": heterodata['space_groups'].labels, + "elements": heterodata["elements"].labels, + "space_groups": heterodata["space_groups"].labels, # "crystal_systems": heterodata['crystal_systems'].labels, } color_per_type = { @@ -566,25 +599,27 @@ def main(): "space_groups": "black", # "crystal_systems": "black", } - - create_umap_plot(z_per_type, - targets_per_type=targets_per_type, - targets_index_per_type=targets_index_per_type, - targets_labels_per_type=targets_labels_per_type, - labels_per_type=labels_per_type, - color_per_type=color_per_type, - save_path=os.path.join(results_dir, "umap.png"), - n_neighbors=30) - # create_umap_plot3d(z_per_type, + + create_umap_plot( + z_per_type, + targets_per_type=targets_per_type, + targets_index_per_type=targets_index_per_type, + targets_labels_per_type=targets_labels_per_type, + labels_per_type=labels_per_type, + color_per_type=color_per_type, + save_path=os.path.join(results_dir, "umap.png"), + n_neighbors=30, + ) + # create_umap_plot3d(z_per_type, # targets_per_type=targets_per_type, # targets_index_per_type=targets_index_per_type, # labels_per_type=labels_per_type, # color_per_type=color_per_type, # save_path=os.path.join(results_dir, "umap_materials_elements_3d.png"), # n_neighbors=30) - -def plot_learning_curves(results, save_path, measure='mae'): + +def plot_learning_curves(results, save_path, measure="mae"): """ Plots the learning curves for a 
specified measure from the results dictionary. @@ -595,23 +630,25 @@ def plot_learning_curves(results, save_path, measure='mae'): measure (str): The measure to plot (e.g., 'loss' or 'mae'). Default is 'mae'. """ plt.figure(figsize=(10, 6)) - + # Iterate over the splits in the results dictionary for idx, split in enumerate(results): split_data = results[split] - + # Check if the desired measure is available in this split's data if measure not in split_data: - print(f"Warning: Measure '{measure}' not found for split '{split}'. Skipping.") + print( + f"Warning: Measure '{measure}' not found for split '{split}'. Skipping." + ) continue # Use the provided 'epochs' list if available, otherwise create a range based on the measure length epochs = split_data.get("epochs", list(range(len(split_data[measure])))) values = split_data[measure] - + # Select a color for this plot color = DEFAULT_COLORS[idx % len(DEFAULT_COLORS)] - + # Plot the curve for this split plt.plot(epochs, values, label=split, color=color, linewidth=2) @@ -625,20 +662,20 @@ def plot_learning_curves(results, save_path, measure='mae'): plt.close() - - -def create_umap_plot(z_per_type, - targets_per_type:dict=None, - targets_index_per_type:dict=None, - targets_labels_per_type:dict=None, - filter_index_per_type:dict=None, - labels_per_type:dict=None, - color_per_type:dict=None, - save_path=".", - n_neighbors=50, - n_jobs=4): +def create_umap_plot( + z_per_type, + targets_per_type: dict = None, + targets_index_per_type: dict = None, + targets_labels_per_type: dict = None, + filter_index_per_type: dict = None, + labels_per_type: dict = None, + color_per_type: dict = None, + save_path=".", + n_neighbors=50, + n_jobs=4, +): node_types = list(z_per_type.keys()) - + if targets_per_type is None: targets_per_type = {} if targets_index_per_type is None: @@ -651,98 +688,100 @@ def create_umap_plot(z_per_type, color_per_type = {} if targets_labels_per_type is None: targets_labels_per_type = {} - - z_global_idx_per_type={} - z_local_idx_per_type={} - local_global_idx_mapping_per_type={} + + z_global_idx_per_type = {} + z_local_idx_per_type = {} + local_global_idx_mapping_per_type = {} z_node_type_mapping = {} - - z_all=None - total_n_nodes=0 + z_all = None + total_n_nodes = 0 for i, (node_type, z) in enumerate(z_per_type.items()): - z=z.detach().cpu().numpy() + z = z.detach().cpu().numpy() n_nodes = z.shape[0] - - + LOGGER.info(f"Node type: {node_type}, Number of nodes: {n_nodes}") - + z_node_type_mapping[node_type] = i - z_global_idx_per_type[node_type] = np.arange(total_n_nodes, total_n_nodes + n_nodes) + z_global_idx_per_type[node_type] = np.arange( + total_n_nodes, total_n_nodes + n_nodes + ) z_local_idx_per_type[node_type] = np.arange(n_nodes) - local_global_idx_mapping_per_type[node_type] = {i:j for i,j in zip(z_local_idx_per_type[node_type], z_global_idx_per_type[node_type])} + local_global_idx_mapping_per_type[node_type] = { + i: j + for i, j in zip( + z_local_idx_per_type[node_type], z_global_idx_per_type[node_type] + ) + } if z_all is None: z_all = z else: z_all = np.concatenate([z_all, z], axis=0) - - total_n_nodes+=n_nodes - - + + total_n_nodes += n_nodes + # Apply UMAP to reduce dimensions to 2. reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=2, n_jobs=n_jobs) embedding = reducer.fit_transform(z_all) # Create the scatter plot.
plt.figure(figsize=(10, 8)) - - - handles=[] - scatter_handles=[] + + handles = [] + scatter_handles = [] for node_type in node_types: LOGGER.info(f"Plotting {node_type}") - + color = color_per_type.get(node_type, None) node_labels = labels_per_type.get(node_type, None) targets = targets_per_type.get(node_type, None) target_idx = targets_index_per_type.get(node_type, None) filter_idx = filter_index_per_type.get(node_type, None) - + node_idx = z_global_idx_per_type.get(node_type, None) LOGGER.info(f"Node index: {node_idx}") if target_idx is not None: LOGGER.info(f"Target index: {target_idx}") - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in target_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in target_idx + ] if node_labels is not None: - node_labels = node_labels[target_idx] # Needs to be local index + node_labels = node_labels[target_idx] # Needs to be local index if filter_idx is not None: - node_idx = [local_global_idx_mapping_per_type[node_type][i] for i in filter_idx] + node_idx = [ + local_global_idx_mapping_per_type[node_type][i] for i in filter_idx + ] if node_labels is not None: - node_labels = node_labels[filter_idx] # Needs to be local index - - + node_labels = node_labels[filter_idx] # Needs to be local index + if targets is not None: c = targets - cmap=DEFAULT_CMAP + cmap = DEFAULT_CMAP elif color is not None: - c=color - cmap=None + c = color + cmap = None handles.append(mpatches.Patch(color=color, label=node_type)) - - - x = embedding[node_idx, 0] # Needs to be global index - y = embedding[node_idx, 1] # Needs to be global index - scatter = plt.scatter(x, y, s=10, alpha=0.8, - c=c, - cmap=cmap) - c=None - + + x = embedding[node_idx, 0] # Needs to be global index + y = embedding[node_idx, 1] # Needs to be global index + scatter = plt.scatter(x, y, s=10, alpha=0.8, c=c, cmap=cmap) + c = None + if targets is not None: LOGGER.info(f"Plotting {node_type} targets") scatter_handles.append(scatter) - + if node_labels is not None: LOGGER.info(f"Plotting {node_type} labels, n_labels: {len(node_labels)}") for i, label in enumerate(node_labels): plt.annotate(label, (x[i], y[i]), fontsize=8, alpha=1) - if targets_per_type: - label="" + label = "" for node_type in node_types: - label+=targets_labels_per_type.get(node_type, "") + label += targets_labels_per_type.get(node_type, "") plt.colorbar(scatter_handles[0], label=label) - plt.legend(handles=handles) + plt.legend(handles=handles) plt.title("UMAP Projection of Node Embeddings") plt.xlabel("UMAP 1") plt.ylabel("UMAP 2") @@ -750,34 +789,23 @@ def create_umap_plot(z_per_type, plt.close() - - - if __name__ == "__main__": - + logger = logging.getLogger("__main__") logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) handler.setLevel(logging.DEBUG) logger.addHandler(handler) - - + logger = logging.getLogger("matgraphdb.pyg.data.hetero_graph") logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) handler.setLevel(logging.DEBUG) logger.addHandler(handler) main() - - - - - - - - - - - diff --git a/pyproject.toml b/pyproject.toml index 0709133..0644aa4 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -35,27 +35,24 @@ dependencies = [ "pymatgen", "parquetdb", "variconfig", - "huggingface_hub[cli,torch]", + "huggingface_hub[cli]", ] [project.optional-dependencies] -tests = [ - "pytest", - "pytest-cov", - "pymatgen" -] -torch = [ + +ml = [ "torch", "torchvision", "torchaudio", - "pyg_lib ; platform_system == 'Linux'" , "torch_geometric", - "torch_scatter", - "torch_sparse", - "torch_cluster", - "torch_spline_conv" +] + +tests = [ + "pytest", + "pytest-cov", + "matgraphdb[ml]" ] graph-tool = [ diff --git a/scripts/run_pytest_continuously.py b/scripts/run_pytest_continuously.py new file mode 100644 index 0000000..f1e41bd --- /dev/null +++ b/scripts/run_pytest_continuously.py @@ -0,0 +1,74 @@ +import datetime +import os +import signal +import subprocess +import sys +import time +from pathlib import Path + + +def run_pytest_continuously(): + # Create a directory for storing failure logs if it doesn't exist + + root_dir = Path(__file__).parent.parent + log_dir = root_dir / "logs" + pytest_failure_logs_dir = log_dir / "pytest_failure_logs" + pytest_failure_logs_dir.mkdir(parents=True, exist_ok=True) + + iteration = 1 + running = True + + def signal_handler(signum, frame): + nonlocal running + print("\nStopping test execution gracefully...") + running = False + + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + while running: + print(f"\nIteration {iteration}") + timestamp = datetime.datetime.now() + + # Run pytest and capture output + result = subprocess.run( + ["pytest", "tests/test_material_nodes.py", "-v"], + capture_output=True, + text=True, + ) + + # If the test failed + if result.returncode != 0: + # Create a failure log filename with timestamp + failure_time = timestamp.strftime("%Y%m%d_%H%M%S") + log_filename = f"{pytest_failure_logs_dir}/failure_{failure_time}_iteration_{iteration}.log" + + # Write the failure details to the log file + with open(log_filename, "w") as f: + f.write(f"Failure occurred at: {timestamp}\n") + f.write(f"Iteration: {iteration}\n") + f.write("\nSTDOUT:\n") + f.write(result.stdout) + f.write("\nSTDERR:\n") + f.write(result.stderr) + + print(f"Test failed! Details written to {log_filename}") + + # You might want to add additional failure handling here + # For example, you could break the loop or add a delay + + # Print a simple status update + print( + f"Iteration {iteration} completed with {'SUCCESS' if result.returncode == 0 else 'FAILURE'}" + ) + + iteration += 1 + + # Optional: Add a small delay between iterations to prevent overwhelming the system + time.sleep(0.1) + + +if __name__ == "__main__": + print("Starting continuous pytest execution. 
Press Ctrl+C to stop.") + run_pytest_continuously() diff --git a/tests/test_data/materials/materials_0.parquet b/tests/test_data/material/material_0.parquet similarity index 98% rename from tests/test_data/materials/materials_0.parquet rename to tests/test_data/material/material_0.parquet index debe442..47d3de0 100644 Binary files a/tests/test_data/materials/materials_0.parquet and b/tests/test_data/material/material_0.parquet differ diff --git a/tests/test_edge_store.py b/tests/test_edge_store.py deleted file mode 100644 index d9ffb3c..0000000 --- a/tests/test_edge_store.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -import shutil - -import pandas as pd -import pyarrow as pa -import pytest - -from matgraphdb.core import EdgeStore, NodeStore -from matgraphdb.materials.nodes import elements - - -@pytest.fixture -def temp_storage(tmp_path): - """Fixture to create and cleanup a temporary storage directory""" - storage_dir = tmp_path / "test_edge_store" - yield str(storage_dir) - if os.path.exists(storage_dir): - shutil.rmtree(storage_dir) - - -@pytest.fixture -def tmp_dir(tmp_path): - """Fixture for temporary directory.""" - tmp_dir = str(tmp_path) - yield tmp_dir - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) - - -@pytest.fixture -def edge_store(tmp_dir): - """Fixture to create an EdgeStore instance.""" - edge_store = EdgeStore(storage_path=os.path.join(tmp_dir, "edges")) - return edge_store - - -@pytest.fixture -def element_store(tmp_dir): - """Fixture to create an ElementNodes instance.""" - element_store = NodeStore(storage_path=os.path.join(tmp_dir, "elements")) - element_store.create_nodes(elements()) - return element_store - - -@pytest.fixture -def sample_edge_data(): - """Fixture providing sample edge data with required fields""" - return { - "source_id": [1, 2], - "target_id": [3, 4], - "source_type": ["node_a", "node_a"], - "target_type": ["node_b", "node_b"], - "edge_type": ["has", "has"], - "weight": [0.5, 0.7], - } - - -def test_edge_store_initialization(temp_storage): - """Test that EdgeStore initializes correctly and creates the storage directory""" - store = EdgeStore(temp_storage) - assert os.path.exists(temp_storage) - assert store is not None - - metadata = store.get_metadata() - assert metadata["class"] == "EdgeStore" - assert metadata["class_module"] == "matgraphdb.core.edges" - - -def test_create_edges_from_dict(edge_store, sample_edge_data): - """Test creating edges from a dictionary""" - edge_store.create_edges(sample_edge_data) - - # Read back and verify - result_table = edge_store.read_edges() - result_df = result_table.to_pandas() - assert len(result_df) == 2 - assert all(field in result_df.columns for field in EdgeStore.required_fields) - assert list(result_df["source_id"]) == [1, 2] - assert list(result_df["target_id"]) == [3, 4] - - -def test_create_edges_from_dataframe(edge_store, sample_edge_data): - """Test creating edges from a pandas DataFrame""" - df = pd.DataFrame(sample_edge_data) - edge_store.create_edges(df) - - result_table = edge_store.read_edges() - result_df = result_table.to_pandas() - assert len(result_df) == 2 - assert all(result_df["source_id"] == df["source_id"]) - assert all(result_df["target_id"] == df["target_id"]) - - -def test_read_edges_with_filters(edge_store, sample_edge_data): - """Test reading edges with specific filters""" - edge_store.create_edges(sample_edge_data) - - # Read with column filter - result_table = edge_store.read_edges(columns=["id", "source_id", "target_id"]) - result_df = result_table.to_pandas() - assert 
list(result_df.columns) == ["id", "source_id", "target_id"] - - # Read with ID filter - first_result_table = edge_store.read_edges() - first_result_df = first_result_table.to_pandas() - first_id = first_result_df["id"].iloc[0] - filtered_result_table = edge_store.read_edges(ids=[first_id]) - filtered_result_df = filtered_result_table.to_pandas() - - assert len(filtered_result_df) == 1 - assert filtered_result_df["id"].iloc[0] == first_id - - -def test_update_edges(edge_store, sample_edge_data): - """Test updating existing edges""" - edge_store.create_edges(sample_edge_data) - - # Get the IDs - existing_edges_table = edge_store.read_edges() - existing_edges_df = existing_edges_table.to_pandas() - first_id = existing_edges_df["id"].iloc[0] - - assert ( - existing_edges_df[existing_edges_df["id"] == first_id]["weight"].iloc[0] == 0.5 - ) - - # Update the first edge - update_data = { - "source_id": [1], - "target_id": [3], - "source_type": ["node_a"], - "target_type": ["node_b"], - "edge_type": ["has"], - "weight": [0.9], - } - - update_keys = ["source_id", "target_id"] - edge_store.update_edges(update_data, update_keys=update_keys) - - # Verify update - updated_edges_table = edge_store.read_edges() - updated_edges_df = updated_edges_table.to_pandas() - assert updated_edges_df[updated_edges_df["id"] == first_id]["weight"].iloc[0] == 0.9 - - -def test_delete_edges(edge_store, sample_edge_data): - """Test deleting edges""" - edge_store.create_edges(sample_edge_data) - - # Get the IDs - existing_edges_table = edge_store.read_edges() - existing_edges_df = existing_edges_table.to_pandas() - first_id = existing_edges_df["id"].iloc[0] - - # Delete one edge - edge_store.delete_edges(ids=[first_id]) - - # Verify deletion - remaining_edges_table = edge_store.read_edges() - remaining_edges_df = remaining_edges_table.to_pandas() - assert len(remaining_edges_df) == 1 - assert first_id not in remaining_edges_df["id"].values - - -def test_delete_columns(edge_store, sample_edge_data): - """Test deleting specific columns""" - edge_store.create_edges(sample_edge_data) - - # Delete the weight column - edge_store.delete_edges(columns=["weight"]) - - # Verify column deletion - result_table = edge_store.read_edges() - result_df = result_table.to_pandas() - assert "weight" not in result_df.columns - assert all(field in result_df.columns for field in EdgeStore.required_fields) diff --git a/tests/test_graphdb.py b/tests/test_graphdb.py deleted file mode 100644 index e06ac22..0000000 --- a/tests/test_graphdb.py +++ /dev/null @@ -1,617 +0,0 @@ -import os -import shutil - -import pandas as pd -import pyarrow as pa -import pytest - -from matgraphdb.core import EdgeStore, GraphDB, NodeStore -from matgraphdb.materials.edges import element_element_neighborsByGroupPeriod -from matgraphdb.materials.nodes import elements -from matgraphdb.utils.config import PKG_DIR, config - -config.logging_config.loggers.matgraphdb.level = "DEBUG" - -config.apply() - - -@pytest.fixture -def tmp_dir(tmp_path): - """Fixture for temporary directory.""" - tmp_dir = str(tmp_path) - yield tmp_dir - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) - - -@pytest.fixture -def graphdb(tmp_dir): - """Fixture to create a GraphDB instance.""" - return GraphDB(storage_path=tmp_dir) - - -@pytest.fixture -def element_store(tmp_dir): - """Fixture to create an ElementNodes instance.""" - element_store = NodeStore(storage_path=os.path.join(tmp_dir, "elements")) - element_store.create_nodes(elements()) - return element_store - - -@pytest.fixture -def 
test_data(): - nodes_1_data = [{"name": "Source1"}, {"name": "Source2"}] - nodes_2_data = [{"name": "Target1"}, {"name": "Target2"}] - node_1_type = "user" - node_2_type = "item" - edge_data = [ - { - "source_id": 0, - "source_type": node_1_type, - "target_id": 0, - "target_type": node_2_type, - "weight": 0.5, - }, - { - "source_id": 1, - "source_type": node_2_type, - "target_id": 1, - "target_type": node_1_type, - "weight": 0.7, - }, - { - "source_id": 0, - "source_type": node_1_type, - "target_id": 1, - "target_type": node_2_type, - "weight": 0.5, - }, - { - "source_id": 1, - "source_type": node_2_type, - "target_id": 0, - "target_type": node_1_type, - "weight": 0.7, - }, - ] - edge_type = "test_edge" - return nodes_1_data, node_1_type, nodes_2_data, node_2_type, edge_data, edge_type - - -@pytest.fixture -def wyckoff_generator(): - """Fixture to provide a sample node generator function.""" - - def generate_wyckoff_nodes(): - """Generate basic Wyckoff position nodes.""" - data = [ - {"symbol": "1a", "multiplicity": 1, "letter": "a", "site_symmetry": "1"}, - {"symbol": "2b", "multiplicity": 2, "letter": "b", "site_symmetry": "2"}, - ] - return pa.Table.from_pylist(data) - - return generate_wyckoff_nodes - - -@pytest.fixture -def node_generator_data(wyckoff_generator): - """Fixture providing test data for node generators.""" - generator_name = "test_wyckoff_generator" - generator_func = wyckoff_generator - generator_args = {} - generator_kwargs = {} - return generator_name, generator_func, generator_args, generator_kwargs - - - - - -def test_initialize_graphdb(graphdb): - """Test if GraphDB initializes with the correct directories.""" - assert os.path.exists(graphdb.nodes_path), "Nodes directory not created." - assert os.path.exists(graphdb.edges_path), "Edges directory not created." - assert os.path.exists( - graphdb.edge_generators_path - ), "Edge generators directory not created." - assert os.path.exists(graphdb.graph_path), "Graph directory not created." - - -def test_add_node_type(graphdb): - """Test adding a node type.""" - node_type = "test_node" - graphdb.add_node_type(node_type) - node_store_path = os.path.join(graphdb.nodes_path, node_type) - assert os.path.exists(node_store_path), "Node type directory not created." - assert node_type in graphdb.node_stores, "Node type not registered in node_stores." - - store = graphdb.node_stores[node_type] - assert isinstance(store, NodeStore), "Node store is not of type NodeStore." - table = store.read_nodes() - assert table.num_rows == 0, "Node store is not empty." - - store = graphdb.get_node_store(node_type) - assert ( - store == graphdb.node_stores[node_type] - ), "Node store not retrieved correctly." - - -def test_add_nodes(graphdb): - """Test adding nodes to a node type.""" - node_type_1 = "test_node_1" - node_type_2 = "test_node_2" - data_1 = [{"name": "Node1"}, {"name": "Node2"}] - data_2 = [{"name": "Node3"}, {"name": "Node4"}, {"name": "Node5"}] - graphdb.add_nodes(node_type_1, data_1) - graphdb.add_nodes(node_type_2, data_2) - - assert ( - node_type_1 in graphdb.node_stores - ), "Node type 1 not registered in node_stores." - assert ( - node_type_2 in graphdb.node_stores - ), "Node type 2 not registered in node_stores." - - node_store_1 = graphdb.get_node_store(node_type_1) - node_store_2 = graphdb.get_node_store(node_type_2) - - assert len(node_store_1.read_nodes()) == len( - data_1 - ), "Incorrect number of nodes added." - assert len(node_store_2.read_nodes()) == len( - data_2 - ), "Incorrect number of nodes added." 
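The deleted tests above all exercise the same create/read round trip: plain dicts go into a typed node store and reads come back as a `pyarrow.Table`. A minimal sketch of that pattern, assuming the same `GraphDB` API the tests call (the scratch path is a placeholder):

```python
import tempfile

from matgraphdb.core import GraphDB

# Placeholder scratch location; any empty directory works.
db = GraphDB(storage_path=tempfile.mkdtemp())

# Plain dicts in; the node type is registered on first use.
db.add_nodes("test_node", [{"name": "Node1"}, {"name": "Node2"}])

# Reads return a pyarrow.Table, so assertions go through to_pydict().
nodes = db.get_nodes("test_node")
assert nodes.to_pydict()["name"] == ["Node1", "Node2"]
```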
- - -def test_add_node_store(tmp_dir): - """Test adding a node store to the graph database.""" - # Create a temporary node store - temp_store_path = os.path.join(tmp_dir, "temp_store") - temp_store = NodeStore(storage_path=temp_store_path) - - # Add some test data to the temp store - test_data = [{"name": "Node1"}, {"name": "Node2"}] - temp_store.create_nodes(test_data) - - # Create a graph database - graph = GraphDB(storage_path=os.path.join(tmp_dir, "graph")) - - # Add the node store to the graph - graph.add_node_store(temp_store) - - # Verify the store was added correctly - assert temp_store.node_type in graph.node_stores - assert len(graph.get_nodes(temp_store.node_type)) == 2 - - # Test overwrite=False behavior - new_store = NodeStore(storage_path=os.path.join(tmp_dir, "new_store")) - new_store.node_type = temp_store.node_type - - with pytest.raises( - ValueError, match=f"Node store of type {temp_store.node_type} already exists" - ): - graph.add_node_store(new_store, overwrite=False) - - # Test overwrite=True behavior - new_data = [{"name": "Node3"}] - new_store.create_nodes(new_data) - graph.add_node_store(new_store, overwrite=True) - - # Verify the store was overwritten - assert len(graph.get_nodes(new_store.node_type)) == 1 - nodes = graph.get_nodes(new_store.node_type) - assert nodes.to_pydict()["name"] == ["Node3"] - - -def test_add_node_store_with_remove_original(tmp_dir): - """Test adding a node store with remove_original=True option.""" - # Create a temporary node store - node_type = "test_node" - temp_store_path = os.path.join(tmp_dir, node_type) - temp_store = NodeStore(storage_path=temp_store_path) - - # Add test data and ensure it's written to disk - test_data = [{"name": "Node1"}] - temp_store.create_nodes(test_data) - temp_store.normalize() # Add this line to ensure data is written - - # Create a graph database - graphdb = GraphDB(storage_path=os.path.join(tmp_dir, "graph")) - - # Add the node store with remove_original=True - graphdb.add_node_store(temp_store, remove_original=True) - - # Verify original store directory was removed - assert not os.path.exists(temp_store_path) - - # Verify data was transferred correctly - assert temp_store.node_type in graphdb.node_stores - - # Ensure the new store location exists before trying to read - new_store_path = os.path.join(graphdb.nodes_path, node_type) - assert os.path.exists(new_store_path), "New store location doesn't exist" - - nodes = graphdb.get_nodes(node_type) - assert nodes.to_pydict()["name"] == ["Node1"] - - -def test_nodes_persist_after_reload(tmp_dir): - """Test that nodes persist and can be loaded after recreating the GraphDB instance.""" - # Create initial graph instance and add nodes - graph = GraphDB(storage_path=tmp_dir, load_custom_stores=False) - node_type = "test_node" - test_data = [{"name": "Node1", "value": 10}, {"name": "Node2", "value": 20}] - graph.add_nodes(node_type, test_data) - - # Verify initial data - initial_nodes = graph.get_nodes(node_type) - assert len(initial_nodes) == 2, "Incorrect number of nodes added." - assert initial_nodes.to_pydict()["name"] == [ - "Node1", - "Node2", - ], "Incorrect node names." - - # Create new graph instance (simulating program restart) - new_graph = GraphDB(storage_path=tmp_dir, load_custom_stores=False) - - # Verify data persisted - loaded_nodes = new_graph.get_nodes(node_type) - assert len(loaded_nodes) == 2, "Incorrect number of nodes loaded." - loaded_dict = loaded_nodes.to_pydict() - assert loaded_dict["name"] == ["Node1", "Node2"], "Incorrect node names." 
- assert loaded_dict["value"] == [10, 20], "Incorrect node values." - assert ( - node_type in new_graph.node_stores - ), "Node type not registered in node_stores." - - -def test_node_exists(graphdb): - """Test checking if a node type exists.""" - node_type = "test_node" - - # Initially node type should not exist - assert not graphdb.node_exists(node_type), "Node type should not exist." - - # Add nodes and verify node type exists - test_data = [{"name": "Node1"}] - graphdb.add_nodes(node_type, test_data) - assert graphdb.node_exists(node_type), "Node type should exist." - - # Non-existent node type should return False - assert not graphdb.node_exists( - "nonexistent_type" - ), "Non-existent node type should not exist." - - -def test_node_is_empty(graphdb): - """Test checking if a node type is empty.""" - node_type = "test_node" - - assert not graphdb.node_exists(node_type), "Node type should not exist." - - # Add nodes and verify node type is not empty - test_data = [{"name": "Node1"}] - graphdb.add_nodes(node_type, test_data) - assert not graphdb.node_is_empty(node_type), "Node type should not be empty." - - -def test_add_edge_type(graphdb): - """Test adding an edge type.""" - edge_type = "test_edge" - graphdb.add_edge_type(edge_type) - edge_store_path = os.path.join(graphdb.edges_path, edge_type) - assert os.path.exists(edge_store_path), "Edge type directory not created." - assert edge_type in graphdb.edge_stores, "Edge type not registered in edge_stores." - - store = graphdb.edge_stores[edge_type] - assert isinstance(store, EdgeStore), "Edge store is not of type EdgeStore." - table = store.read_edges() - assert table.num_rows == 0, "Edge store is not empty." - - store = graphdb.get_edge_store(edge_type) - assert ( - store == graphdb.edge_stores[edge_type] - ), "Edge store not retrieved correctly." - - -def test_add_edges(graphdb, test_data): - """Test adding edges.""" - # First create some nodes that the edges will connect - nodes_1_data, node_1_type, nodes_2_data, node_2_type, edge_data, edge_type = ( - test_data - ) - graphdb.add_nodes(node_1_type, nodes_1_data) - graphdb.add_nodes(node_2_type, nodes_2_data) - - assert ( - node_1_type in graphdb.node_stores - ), "Node type 1 not registered in node_stores." - assert ( - node_2_type in graphdb.node_stores - ), "Node type 2 not registered in node_stores." - - graphdb.add_edges(edge_type, edge_data) - - assert edge_type in graphdb.edge_stores, "Edge type not registered in edge_stores." - - edges = graphdb.read_edges(edge_type) - assert len(edges) == 4, "Incorrect number of edges added." - - -def test_edge_exists_and_is_empty(graphdb, test_data): - """Test checking if an edge type exists and is empty.""" - nodes_1_data, node_1_type, nodes_2_data, node_2_type, edge_data, edge_type = ( - test_data - ) - graphdb.add_nodes(node_1_type, nodes_1_data) - graphdb.add_nodes(node_2_type, nodes_2_data) - - # Initially edge type should not exist - assert not graphdb.edge_exists(edge_type), "Edge type should not exist." - - # Add edge type and verify it exists but is empty - graphdb.add_edge_type(edge_type) - assert graphdb.edge_exists(edge_type), "Edge type should exist." - assert graphdb.edge_is_empty(edge_type), "Edge type should be empty." - - # Add some edges and verify it's no longer empty - graphdb.add_edges(edge_type, edge_data) - assert not graphdb.edge_is_empty(edge_type), "Edge type should not be empty." 
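Edges follow the same round trip, with the identity fields (`source_id`, `source_type`, `target_id`, `target_type`) required on every row. A minimal sketch of the add/read cycle these tests cover, again assuming the API exactly as used above:

```python
import tempfile

from matgraphdb.core import GraphDB

db = GraphDB(storage_path=tempfile.mkdtemp())  # placeholder path

# Edges reference nodes by type and id, so create the endpoints first.
db.add_nodes("user", [{"name": "Source1"}])
db.add_nodes("item", [{"name": "Target1"}])

db.add_edges(
    "test_edge",
    [
        {
            "source_id": 0,
            "source_type": "user",
            "target_id": 0,
            "target_type": "item",
            "weight": 0.5,
        }
    ],
)

edges = db.read_edges("test_edge")  # pyarrow.Table, as in the assertions above
assert edges.to_pydict()["weight"] == [0.5]
```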
- - -def test_remove_edge_store(graphdb, test_data): - """Test removing an edge store.""" - nodes_1_data, node_1_type, nodes_2_data, node_2_type, edge_data, edge_type = ( - test_data - ) - graphdb.add_nodes(node_1_type, nodes_1_data) - graphdb.add_nodes(node_2_type, nodes_2_data) - - # Add an edge type - graphdb.add_edge_type(edge_type) - assert edge_type in graphdb.edge_stores, "Edge type not added correctly." - - # Remove the edge store - graphdb.remove_edge_store(edge_type) - assert ( - edge_type not in graphdb.edge_stores - ), "Edge type not removed from edge_stores." - assert not os.path.exists( - os.path.join(graphdb.edges_path, edge_type) - ), "Edge store directory not removed." - - # Verify that trying to get the removed store raises an error - with pytest.raises( - ValueError, match=f"Edge store of type {edge_type} does not exist" - ): - graphdb.get_edge_store(edge_type) - - -def test_edges_persist_after_reload(tmp_dir, test_data): - """Test that edges persist and can be loaded after recreating the GraphDB instance.""" - # Create initial graph instance and add edges - graph = GraphDB(storage_path=tmp_dir) - nodes_1_data, node_1_type, nodes_2_data, node_2_type, edge_data, edge_type = ( - test_data - ) - graph.add_nodes(node_1_type, nodes_1_data) - graph.add_nodes(node_2_type, nodes_2_data) - graph.add_edges(edge_type, edge_data) - - # Create new graph instance (simulating program restart) - new_graph = GraphDB(storage_path=tmp_dir) - - # Verify edges persisted - assert edge_type in new_graph.edge_stores, "Edge type not loaded." - edges = new_graph.read_edges(edge_type) - assert len(edges) == 4, "Incorrect number of edges loaded." - edge_dict = edges.to_pydict() - assert edge_dict["weight"] == [0.5, 0.7, 0.5, 0.7], "Incorrect edge weights." - - -def test_add_edge_generator(graphdb, element_store): - """Test adding an edge generator to the GraphDB.""" - # Get the generator name from the function - - generator_name = element_element_neighborsByGroupPeriod.__name__ - - graphdb.add_node_store(element_store) - # Add the generator - graphdb.add_edge_generator( - element_element_neighborsByGroupPeriod, - generator_args={"element_store": element_store}, - ) - - # Verify the generator was added - assert graphdb.edge_generator_store.is_in(generator_name) - - -def test_run_edge_generator(graphdb, element_store): - """Test running an edge generator and verify its output.""" - - graphdb.add_node_store(element_store) - - generator_name = element_element_neighborsByGroupPeriod.__name__ - - # Add and run the generator - graphdb.add_edge_generator( - element_element_neighborsByGroupPeriod, - generator_args={"element_store": element_store}, - ) - - table = graphdb.run_edge_generator(generator_name) - - # Verify the output table has the expected structure - assert isinstance(table, pa.Table) - expected_columns = { - "source_id", - "target_id", - "source_type", - "target_type", - "weight", - "source_name", - "target_name", - "name", - "source_extended_group", - "source_period", - "target_extended_group", - "target_period", - } - assert set(table.column_names) == expected_columns - - # Convert to pandas for easier verification - df = table.to_pandas() - - # Basic validation checks - assert not df.empty, "Generator produced no edges" - assert all(df["source_type"] == "elements"), "Incorrect source_type" - assert all(df["target_type"] == "elements"), "Incorrect target_type" - assert all(df["weight"] == 1.0), "Incorrect weight values" - - # Verify edge names are properly formatted - assert all( - 
df["name"].str.contains("_neighborsByGroupPeriod_") - ), "Edge names not properly formatted" - - -def test_edge_generator_persistence(tmp_dir, element_store): - """Test that edge generators persist when reloading the GraphDB.""" - generator_name = element_element_neighborsByGroupPeriod.__name__ - - # Create initial graph instance and add generator - graph = GraphDB(storage_path=tmp_dir) - graph.add_node_store(element_store) - graph.add_edge_generator( - element_element_neighborsByGroupPeriod, - generator_args={"element_store": element_store}, - ) - - # Create new graph instance (simulating program restart) - new_graph = GraphDB(storage_path=tmp_dir) - - # Verify generator was loaded - assert new_graph.edge_generator_store.is_in(generator_name) - - # Verify generator still works - table = new_graph.run_edge_generator(generator_name) - assert isinstance(table, pa.Table), "Generator output is not a pyarrow Table" - assert len(table) > 0, "Generator produced no edges" - - assert table.shape == ( - 391, - 12, - ), "Generator for element_element_neighborsByGroupPeriod edge output shape is incorrect" - - -def test_invalid_generator_args(graphdb, element_store): - """Test that invalid generator arguments raise appropriate errors.""" - - # element_store = ElementNodes(storage_path=os.path.join(tmp_dir, 'elements')) - generator_name = element_element_neighborsByGroupPeriod.__name__ - graphdb.add_node_store(element_store) - - # Test missing required argument - with pytest.raises(Exception): - graphdb.add_edge_generator( - element_element_neighborsByGroupPeriod, - generator_args={}, # Missing element_store - run_immediately=False, - ) - graphdb.run_edge_generator(generator_name) - - # Test invalid element_store argument - with pytest.raises(Exception): - graphdb.add_edge_generator( - element_element_neighborsByGroupPeriod, - generator_args={"element_store": "invalid_store"}, - ) - - -def test_add_node_generator(graphdb, wyckoff_generator): - """Test adding a node generator to the GraphDB.""" - - generator_name = wyckoff_generator.__name__ - # Add the generator - graphdb.add_node_generator(wyckoff_generator) - - # Verify the generator was added - assert graphdb.node_generator_store.is_in(generator_name) - - wyckoff_node_store = graphdb.get_node_store(generator_name) - - nodes = wyckoff_node_store.read_nodes() - nodes_df = nodes.to_pandas() - - assert nodes_df.shape == (2, 5) - - -def test_run_node_generator(graphdb, wyckoff_generator): - """Test running a node generator and verify its output.""" - generator_name = wyckoff_generator.__name__ - - # Add and run the generator - graphdb.add_node_generator( - wyckoff_generator, - generator_args=None, - generator_kwargs=None, - run_immediately=False, - ) - - graphdb.run_node_generator(generator_name) - - wyckoff_node_store = graphdb.get_node_store(generator_name) - - nodes = wyckoff_node_store.read_nodes() - nodes_df = nodes.to_pandas() - - # Verify the output table has the expected structure - assert isinstance(nodes_df, pd.DataFrame) - expected_columns = {"symbol", "multiplicity", "letter", "site_symmetry", "id"} - assert set(nodes_df.columns) == expected_columns - - # Basic validation checks - assert not nodes_df.empty, "Generator produced no nodes" - assert nodes_df.shape == (2, 5), "Generator has the wrong shape" - - -def test_node_generator_persistence(tmp_dir, wyckoff_generator): - """Test that node generators persist when reloading the GraphDB.""" - generator_name = wyckoff_generator.__name__ - # Create initial graph instance and add generator - graph 
= GraphDB(storage_path=tmp_dir) - graph.add_node_generator( - wyckoff_generator, - run_immediately=False, - ) - - # Create new graph instance (simulating program restart) - new_graph = GraphDB(storage_path=tmp_dir) - - # Verify generator was loaded - assert new_graph.node_generator_store.is_in(generator_name) - - # Verify generator still works - new_graph.run_node_generator(generator_name) - - wyckoff_node_store = new_graph.get_node_store(generator_name) - - nodes = wyckoff_node_store.read_nodes() - nodes_df = nodes.to_pandas() - - assert isinstance( - nodes_df, pd.DataFrame - ), "Generator output is not a pandas DataFrame" - assert len(nodes_df) > 0, "Generator produced no nodes" - assert nodes_df.shape == (2, 5), "Generator output shape is incorrect" - - -def test_invalid_node_generator_args(graphdb): - """Test that invalid node generator arguments raise appropriate errors.""" - - def bad_generator(nonexistent_arg): - return pa.Table.from_pylist([{"test": 1}]) - - with pytest.raises(Exception): - graphdb.add_node_generator( - bad_generator, - ) - graphdb.run_node_generator("bad_generator") diff --git a/tests/test_material_nodes.py b/tests/test_material_nodes.py index d70e484..dff2094 100644 --- a/tests/test_material_nodes.py +++ b/tests/test_material_nodes.py @@ -1,22 +1,25 @@ import os import shutil import tempfile +from pathlib import Path import numpy as np import pytest -from matgraphdb.materials.nodes.materials import MaterialStore +from matgraphdb.core.nodes.materials import MaterialStore + +TEMP_DIR = Path(tempfile.mkdtemp()) @pytest.fixture def material_store(): """Fixture that creates a temporary MaterialStore instance.""" - temp_dir = tempfile.mkdtemp() - store = MaterialStore(storage_path=temp_dir) + material_store_path = TEMP_DIR / "material" + store = MaterialStore(storage_path=material_store_path, verbose=3) yield store # Cleanup after test - if os.path.exists(temp_dir): - shutil.rmtree(temp_dir) + if material_store_path.exists(): + shutil.rmtree(material_store_path) @pytest.fixture diff --git a/tests/test_matgraphdb.py b/tests/test_matgraphdb.py index dcc948a..8fdcb1f 100644 --- a/tests/test_matgraphdb.py +++ b/tests/test_matgraphdb.py @@ -1,71 +1,74 @@ import os import shutil +import tempfile +from pathlib import Path import pyarrow as pa import pytest -from matgraphdb.materials.core import MatGraphDB -from matgraphdb.materials.edges import * -from matgraphdb.materials.nodes import * -from matgraphdb.utils.config import DATA_DIR, PKG_DIR, config +from matgraphdb.core.edges import * +from matgraphdb.core.matgraphdb import MatGraphDB +from matgraphdb.core.nodes import * -config.logging_config.loggers.matgraphdb.level = "DEBUG" -config.apply() +current_dir = Path(__file__).parent +TEST_DATA_DIR = current_dir / "test_data" -current_dir = os.path.dirname(os.path.abspath(__file__)) -TEST_DATA_DIR = os.path.join(current_dir, "test_data") +VERBOSE = 1 - -@pytest.fixture -def tmp_dir(tmp_path): - """Fixture for temporary directory.""" - tmp_dir = str(tmp_path) - yield tmp_dir - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) +TMP_DIR = Path(tempfile.mkdtemp()) @pytest.fixture -def matgraphdb(tmp_dir, material_store): +def matgraphdb(material_store): """Fixture to create a MatGraphDB instance.""" - return MatGraphDB(storage_path=tmp_dir, materials_store=material_store) + matgraphdb_path = TMP_DIR / "MatGraphDB" + if matgraphdb_path.exists(): + shutil.rmtree(matgraphdb_path) + return MatGraphDB( + storage_path=matgraphdb_path, materials_store=material_store, verbose=VERBOSE + ) 
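The rewritten fixtures drop pytest's per-test `tmp_path` in favor of a single module-level scratch directory that each fixture wipes before reuse, so on-disk store state cannot leak between tests while still living at a stable, inspectable path. The same pattern in isolation (names here are illustrative):

```python
import shutil
import tempfile
from pathlib import Path

import pytest

# One scratch root per test module, created once at import time.
SCRATCH_DIR = Path(tempfile.mkdtemp())


@pytest.fixture
def store_path():
    """Hand each test a clean directory under the shared scratch root."""
    path = SCRATCH_DIR / "store"
    if path.exists():
        shutil.rmtree(path)  # drop leftovers from the previous test
    return path
```

The explicit `rmtree` guard is what keeps runs independent, since the module-level directory outlives any single test.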
@pytest.fixture -def empty_matgraphdb(tmp_dir): +def empty_matgraphdb(): """Fixture to create a MatGraphDB instance.""" - return MatGraphDB(storage_path=tmp_dir) + matgraphdb_path = TMP_DIR / "MatGraphDB" + if matgraphdb_path.exists(): + shutil.rmtree(matgraphdb_path) + return MatGraphDB(storage_path=matgraphdb_path, verbose=VERBOSE) @pytest.fixture -def empty_material_store(tmp_dir): - materials_path = os.path.join(tmp_dir, "materials") - return MaterialStore(storage_path=materials_path) +def empty_material_store(): + materials_path = TMP_DIR / "material" + if materials_path.exists(): + shutil.rmtree(materials_path) + return MaterialStore(storage_path=materials_path, verbose=VERBOSE) @pytest.fixture def material_store(): - materials_path = os.path.join(TEST_DATA_DIR, "materials") - return MaterialStore(storage_path=materials_path) + materials_path = TEST_DATA_DIR / "material" + return MaterialStore(storage_path=materials_path, verbose=VERBOSE) @pytest.fixture def node_generator_data(matgraphdb): node_generators = [ - {"generator_func": elements}, - {"generator_func": chemenvs}, - {"generator_func": crystal_systems}, - {"generator_func": magnetic_states}, - {"generator_func": oxidation_states}, - {"generator_func": space_groups}, - {"generator_func": wyckoffs}, + {"generator_func": element}, + {"generator_func": chemenv}, + {"generator_func": crystal_system}, + {"generator_func": magnetic_state}, + {"generator_func": oxidation_state}, + {"generator_func": space_group}, + {"generator_func": wyckoff}, { - "generator_func": material_sites, - "generator_args": {"material_store": matgraphdb.node_stores["materials"]}, + "generator_func": material_site, + "generator_args": {"material_store": matgraphdb.node_stores["material"]}, }, { - "generator_func": material_lattices, - "generator_args": {"material_store": matgraphdb.node_stores["materials"]}, + "generator_func": material_lattice, + "generator_args": {"material_store": matgraphdb.node_stores["material"]}, }, ] @@ -87,56 +90,56 @@ def edge_generator_data(node_generator_data): edge_generators = [ { "generator_func": element_element_neighborsByGroupPeriod, - "generator_args": {"element_store": matgraphdb.node_stores["elements"]}, + "generator_args": {"element_store": matgraphdb.node_stores["element"]}, }, { "generator_func": element_oxiState_canOccur, "generator_args": { - "element_store": matgraphdb.node_stores["elements"], - "oxiState_store": matgraphdb.node_stores["oxidation_states"], + "element_store": matgraphdb.node_stores["element"], + "oxiState_store": matgraphdb.node_stores["oxidation_state"], }, }, { "generator_func": material_chemenv_containsSite, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "chemenv_store": matgraphdb.node_stores["chemenvs"], + "material_store": matgraphdb.node_stores["material"], + "chemenv_store": matgraphdb.node_stores["chemenv"], }, }, { "generator_func": material_crystalSystem_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "crystal_system_store": matgraphdb.node_stores["crystal_systems"], + "material_store": matgraphdb.node_stores["material"], + "crystal_system_store": matgraphdb.node_stores["crystal_system"], }, }, { "generator_func": material_element_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "element_store": matgraphdb.node_stores["elements"], + "material_store": matgraphdb.node_stores["material"], + "element_store": matgraphdb.node_stores["element"], }, }, { "generator_func": material_lattice_has, 
"generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "lattice_store": matgraphdb.node_stores["material_lattices"], + "material_store": matgraphdb.node_stores["material"], + "lattice_store": matgraphdb.node_stores["material_lattice"], }, }, { "generator_func": material_spg_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "spg_store": matgraphdb.node_stores["space_groups"], + "material_store": matgraphdb.node_stores["material"], + "spg_store": matgraphdb.node_stores["space_group"], }, }, { "generator_func": element_chemenv_canOccur, "generator_args": { - "element_store": matgraphdb.node_stores["elements"], - "chemenv_store": matgraphdb.node_stores["chemenvs"], - "material_store": matgraphdb.node_stores["materials"], + "element_store": matgraphdb.node_stores["element"], + "chemenv_store": matgraphdb.node_stores["chemenv"], + "material_store": matgraphdb.node_stores["material"], }, }, ] @@ -173,14 +176,14 @@ def test_initialize_matgraphdb(empty_matgraphdb): # Check if materials node store exists matgraphdb = empty_matgraphdb assert ( - "materials" in matgraphdb.node_stores + "material" in matgraphdb.node_stores ), f"Materials node store not found in node_stores: {matgraphdb.node_stores}" assert isinstance( matgraphdb.material_store, MaterialStore ), f"MaterialStore instance not found in matgraphdb: {matgraphdb.material_store}" # Check if materials directory was created - materials_path = os.path.join(matgraphdb.nodes_path, "materials") + materials_path = os.path.join(matgraphdb.nodes_path, "material") assert os.path.exists( materials_path ), f"Materials directory not created: {materials_path}" @@ -212,6 +215,8 @@ def test_read_materials_with_filters(empty_matgraphdb, test_material_data): # Read specific columns columns = ["material_id", "elements"] + + print(matgraphdb.summary(show_column_names=True)) materials = matgraphdb.read_materials(columns=columns) assert isinstance(materials, pa.Table), f"Materials not read: {materials}" @@ -272,14 +277,18 @@ def test_delete_materials(empty_matgraphdb, test_material_data): ], f"Materials not deleted: {materials.to_pydict()['material_id']}" -def test_persistence(tmp_dir, test_material_data): +def test_persistence(test_material_data): """Test that materials persist when recreating the MatGraphDB instance.""" # Create initial graph instance and add materials - db = MatGraphDB(storage_path=tmp_dir) + matgraphdb_path = TMP_DIR / "MatGraphDB" + if matgraphdb_path.exists(): + shutil.rmtree(matgraphdb_path) + + db = MatGraphDB(storage_path=matgraphdb_path, verbose=VERBOSE) db.create_materials(test_material_data) # Create new graph instance (simulating program restart) - new_db = MatGraphDB(storage_path=tmp_dir) + new_db = MatGraphDB(storage_path=matgraphdb_path, verbose=VERBOSE) # Verify materials persisted materials = new_db.read_materials() @@ -296,6 +305,7 @@ def test_add_node_generators(node_generator_data): matgraphdb, node_generators_list = node_generator_data material_store = matgraphdb.material_store + for generator in node_generators_list[:]: generator_func = generator.get("generator_func") generator_args = generator.get("generator_args", None) @@ -312,70 +322,70 @@ def test_add_node_generators(node_generator_data): ), f"Node generators not found in node_stores: {generator_names}" -def test_moving_matgraphdb(tmp_dir, edge_generator_data): +def test_moving_matgraphdb(edge_generator_data): """Test adding an edge generator.""" matgraphdb, edge_generators = edge_generator_data - current_dir = 
matgraphdb.storage_path + current_dir = Path(matgraphdb.storage_path) - parent_dir = os.path.dirname(current_dir) - new_dir = os.path.join(parent_dir, "new_dir") + parent_dir = current_dir.parent + new_dir = parent_dir / "new_dir" shutil.move(current_dir, new_dir) matgraphdb = MatGraphDB(storage_path=new_dir) edge_generators = [ { "generator_func": element_element_neighborsByGroupPeriod, - "generator_args": {"element_store": matgraphdb.node_stores["elements"]}, + "generator_args": {"element_store": matgraphdb.node_stores["element"]}, }, { "generator_func": element_oxiState_canOccur, "generator_args": { - "element_store": matgraphdb.node_stores["elements"], - "oxiState_store": matgraphdb.node_stores["oxidation_states"], + "element_store": matgraphdb.node_stores["element"], + "oxiState_store": matgraphdb.node_stores["oxidation_state"], }, }, { "generator_func": material_chemenv_containsSite, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "chemenv_store": matgraphdb.node_stores["chemenvs"], + "material_store": matgraphdb.node_stores["material"], + "chemenv_store": matgraphdb.node_stores["chemenv"], }, }, { "generator_func": material_crystalSystem_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "crystal_system_store": matgraphdb.node_stores["crystal_systems"], + "material_store": matgraphdb.node_stores["material"], + "crystal_system_store": matgraphdb.node_stores["crystal_system"], }, }, { "generator_func": material_element_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "element_store": matgraphdb.node_stores["elements"], + "material_store": matgraphdb.node_stores["material"], + "element_store": matgraphdb.node_stores["element"], }, }, { "generator_func": material_lattice_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "lattice_store": matgraphdb.node_stores["material_lattices"], + "material_store": matgraphdb.node_stores["material"], + "lattice_store": matgraphdb.node_stores["material_lattice"], }, }, { "generator_func": material_spg_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "spg_store": matgraphdb.node_stores["space_groups"], + "material_store": matgraphdb.node_stores["material"], + "spg_store": matgraphdb.node_stores["space_group"], }, }, { "generator_func": element_chemenv_canOccur, "generator_args": { - "element_store": matgraphdb.node_stores["elements"], - "chemenv_store": matgraphdb.node_stores["chemenvs"], - "material_store": matgraphdb.node_stores["materials"], + "element_store": matgraphdb.node_stores["element"], + "chemenv_store": matgraphdb.node_stores["chemenv"], + "material_store": matgraphdb.node_stores["material"], }, }, ] @@ -413,8 +423,8 @@ def test_dependency_updates(matgraphdb, node_generator_data): { "generator_func": material_crystalSystem_has, "generator_args": { - "material_store": matgraphdb.node_stores["materials"], - "crystal_system_store": matgraphdb.node_stores["crystal_systems"], + "material_store": matgraphdb.node_stores["material"], + "crystal_system_store": matgraphdb.node_stores["crystal_system"], }, }, ] @@ -440,9 +450,9 @@ def test_dependency_updates(matgraphdb, node_generator_data): generator_name="material_crystalSystem_has", ) # Adding nodes - matgraphdb.add_nodes(node_type="materials", data=data) + matgraphdb.add_nodes(node_type="material", data=data) df = matgraphdb.read_nodes( - "materials", + "material", columns=["id", "symmetry.crystal_system"], filters=[pc.field("id") 
== 1000], ).to_pandas() @@ -454,7 +464,7 @@ def test_dependency_updates(matgraphdb, node_generator_data): df = df[df["source_id"] == 1000] assert df.iloc[0]["target_id"] == 6 # Cubic id - df = matgraphdb.read_nodes("materials", columns=["core.material_id"]).to_pandas() + df = matgraphdb.read_nodes("material", columns=["core.material_id"]).to_pandas() # Updating nodes data = pd.DataFrame( @@ -465,10 +475,10 @@ def test_dependency_updates(matgraphdb, node_generator_data): "symmetry.crystal_system": ["Hexagonal"], } ) - matgraphdb.update_nodes("materials", data) + matgraphdb.update_nodes("material", data) df = matgraphdb.read_nodes( - "materials", + "material", columns=["id", "symmetry.crystal_system"], filters=[pc.field("id") == 1000], ).to_pandas() @@ -480,10 +490,10 @@ def test_dependency_updates(matgraphdb, node_generator_data): df = df[df["source_id"] == 1000] assert df.iloc[0]["target_id"] == 5 # Hexagonal id - matgraphdb.delete_nodes("materials", ids=[1000]) + matgraphdb.delete_nodes("material", ids=[1000]) df = matgraphdb.read_nodes( - "materials", + "material", columns=["id", "symmetry.crystal_system"], filters=[pc.field("id") == 1000], ).to_pandas() diff --git a/tests/test_node_store.py b/tests/test_node_store.py deleted file mode 100644 index 4e9724f..0000000 --- a/tests/test_node_store.py +++ /dev/null @@ -1,184 +0,0 @@ -import os -import shutil - -import pandas as pd -import pyarrow as pa -import pytest - -from matgraphdb.core import NodeStore - - -@pytest.fixture -def temp_storage(tmp_path): - """Fixture to create and cleanup a temporary storage directory""" - storage_dir = tmp_path / "test_node_store" - yield str(storage_dir) - if os.path.exists(storage_dir): - shutil.rmtree(storage_dir) - -@pytest.fixture -def node_store(temp_storage): - """Fixture to create a NodeStore instance""" - return NodeStore(temp_storage) - -def test_node_store_initialization(temp_storage): - """Test that NodeStore initializes correctly and creates the storage directory""" - store = NodeStore(temp_storage) - assert os.path.exists(temp_storage) - assert store is not None - -def test_create_nodes_from_dict(node_store): - """Test creating nodes from a dictionary""" - test_data = { - 'name': ['node1', 'node2'], - 'value': [1.0, 2.0] - } - node_store.create_nodes(test_data) - - # Read back and verify - result_table = node_store.read_nodes() - result_df = result_table.to_pandas() - assert len(result_df) == 2 - assert 'name' in result_df.columns - assert 'value' in result_df.columns - assert list(result_df['name']) == ['node1', 'node2'] - assert list(result_df['value']) == [1.0, 2.0] - -def test_create_nodes_from_dataframe(node_store): - """Test creating nodes from a pandas DataFrame""" - df = pd.DataFrame({ - 'name': ['node1', 'node2'], - 'value': [1.0, 2.0] - }) - node_store.create_nodes(df) - - result_table = node_store.read_nodes() - result_df = result_table.to_pandas() - assert len(result_df) == 2 - assert all(result_df['name'] == df['name']) - assert all(result_df['value'] == df['value']) - -def test_read_nodes_with_filters(node_store): - """Test reading nodes with specific filters""" - test_data = { - 'name': ['node1', 'node2', 'node3'], - 'value': [1.0, 2.0, 3.0] - } - node_store.create_nodes(test_data) - - # Read with column filter - result_table = node_store.read_nodes(columns=['id','name']) - result_df = result_table.to_pandas() - assert list(result_df.columns) == ['id','name'] - - # Read with ID filter - first_result_table = node_store.read_nodes() - first_result_df = 
first_result_table.to_pandas() - first_id = first_result_df['id'].iloc[0] - filtered_result_table = node_store.read_nodes(ids=[first_id]) - filtered_result_df = filtered_result_table.to_pandas() - - assert len(filtered_result_df) == 1 - assert filtered_result_df['id'].iloc[0] == first_id - -def test_update_nodes(node_store): - """Test updating existing nodes""" - # Create initial data - initial_data = { - 'name': ['node1', 'node2'], - 'value': [1.0, 2.0] - } - node_store.create_nodes(initial_data) - - # Get the IDs - existing_nodes_table = node_store.read_nodes() - existing_nodes_df = existing_nodes_table.to_pandas() - first_id = existing_nodes_df['id'].iloc[0] - - # Update the first node - update_data = { - 'id': [first_id], - 'value': [10.0] - } - node_store.update_nodes(update_data) - - # Verify update - updated_nodes_table = node_store.read_nodes() - updated_nodes_df = updated_nodes_table.to_pandas() - assert updated_nodes_df[updated_nodes_df['id'] == first_id]['value'].iloc[0] == 10.0 - -def test_delete_nodes(node_store): - """Test deleting nodes""" - # Create initial data - initial_data = { - 'name': ['node1', 'node2', 'node3'], - 'value': [1.0, 2.0, 3.0] - } - node_store.create_nodes(initial_data) - - # Get the IDs - existing_nodes_table = node_store.read_nodes() - existing_nodes_df = existing_nodes_table.to_pandas() - first_id = existing_nodes_df['id'].iloc[0] - - # Delete one node - node_store.delete_nodes(ids=[first_id]) - - # Verify deletion - remaining_nodes_table = node_store.read_nodes() - remaining_nodes_df = remaining_nodes_table.to_pandas() - assert len(remaining_nodes_df) == 2 - assert first_id not in remaining_nodes_df['id'].values - -def test_delete_columns(node_store): - """Test deleting specific columns""" - test_data = { - 'name': ['node1', 'node2'], - 'value1': [1.0, 2.0], - 'value2': [3.0, 4.0] - } - node_store.create_nodes(test_data) - - # Delete one column - node_store.delete_nodes(columns=['value2']) - - # Verify column deletion - result_table = node_store.read_nodes() - result_df = result_table.to_pandas() - assert 'value2' not in result_df.columns - assert 'value1' in result_df.columns - assert 'name' in result_df.columns - -def test_create_nodes_with_schema(node_store): - """Test creating nodes with a specific schema""" - schema = pa.schema([ - ('name', pa.string()), - ('value', pa.float64()) - ]) - - test_data = { - 'name': ['node1', 'node2'], - 'value': [1.0, 2.0] - } - - node_store.create_nodes(test_data, schema=schema) - result_table = node_store.read_nodes() - result_df = result_table.to_pandas() - assert len(result_df) == 2 - assert result_df['value'].dtype == 'float64' - -def test_normalize_nodes(node_store): - """Test the normalize operation""" - test_data = { - 'name': ['node1', 'node2'], - 'value': [1.0, 2.0] - } - node_store.create_nodes(test_data) - - # This should not raise any errors - node_store.normalize_nodes() - - # Verify data is still accessible after normalization - result_table = node_store.read_nodes() - result_df = result_table.to_pandas() - assert len(result_df) == 2 diff --git a/tests/test_pyg_builder.py b/tests/test_pyg_builder.py index 2a1ffe1..93a5c8a 100644 --- a/tests/test_pyg_builder.py +++ b/tests/test_pyg_builder.py @@ -5,10 +5,10 @@ import pyarrow as pa import pytest import torch +from parquetdb import ParquetGraphDB from torch_geometric.data import HeteroData -from matgraphdb.core.graph_db import GraphDB -from matgraphdb.pyg.data.hetero_graph import GraphBuilder +from matgraphdb.pyg.builders import HeteroGraphBuilder 
@pytest.fixture @@ -23,7 +23,7 @@ def tmp_dir(tmp_path): @pytest.fixture def graph_db(tmp_dir): """Fixture to create a GraphDB instance with test data.""" - db = GraphDB(storage_path=tmp_dir) + db = ParquetGraphDB(storage_path=tmp_dir) # Add test materials materials_data = pd.DataFrame( @@ -58,7 +58,7 @@ def graph_db(tmp_dir): @pytest.fixture def graph_builder(graph_db): """Fixture to create a GraphBuilder instance.""" - return GraphBuilder(graph_db) + return HeteroGraphBuilder(graph_db) def test_init(graph_builder): @@ -139,7 +139,7 @@ def test_save_load(graph_builder, tmp_dir, graph_db): graph_builder.save(save_path) # Load the graph - loaded_builder = GraphBuilder.load(graph_db, save_path) + loaded_builder = HeteroGraphBuilder.load(graph_db, save_path) # Check if loaded graph matches original assert len(loaded_builder.node_types) == len(graph_builder.node_types)