From 1af8664b0b27488a2b0afc1ccaaab227fd47f67e Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 18:09:48 -0500 Subject: [PATCH 01/15] preliminary get all embeddings for api --- .../gnn_experiments/get_all_embeddings.py | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 tgrag/experiments/gnn_experiments/get_all_embeddings.py diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py new file mode 100644 index 00000000..1ff69e4b --- /dev/null +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -0,0 +1,152 @@ +import argparse +import logging +import pickle +from pathlib import Path +from typing import Dict, cast + +import torch +from torch.utils.data import DataLoader, TensorDataset +from torch_geometric.loader import NeighborLoader +from tqdm import tqdm + +from tgrag.dataset.temporal_dataset import TemporalDataset +from tgrag.encoders.encoder import Encoder +from tgrag.encoders.rni_encoding import RNIEncoder +from tgrag.gnn.model import Model +from tgrag.utils.args import ModelArguments, parse_args +from tgrag.utils.logger import setup_logging +from tgrag.utils.path import get_root_dir, get_scratch +from tgrag.utils.seed import seed_everything + +parser = argparse.ArgumentParser( + description='Return all embeddings per snapshot.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + '--config-file', + type=str, + default='configs/gnn/base.yaml', + help='Path to yaml configuration file to use', +) + + +def get_embeddings( + model_arguments: ModelArguments, + dataset: TemporalDataset, + weight_directory: Path, +) -> None: + get_root_dir() + data = dataset[0] + device = f'cuda:{model_arguments.device}' if torch.cuda.is_available() else 'cpu' + device = torch.device(device) + logging.info(f'Device found: {device}') + weight_path = weight_directory / f'{model_arguments.model}' / 'best_model.pt' + domain_to_idx_mapping = dataset.get_mapping() + idx_to_domain_mapping = { + v: k for k, v in domain_to_idx_mapping.items() + } # One time, to speed up lookup + tensor_idx = torch.tensor(list(domain_to_idx_mapping.values())) + logging.info('Mapping returned.') + model = Model( + model_name=model_arguments.model, + normalization=model_arguments.normalization, + in_channels=data.num_features, + hidden_channels=model_arguments.hidden_channels, + out_channels=model_arguments.embedding_dimension, + num_layers=model_arguments.num_layers, + dropout=model_arguments.dropout, + ).to(device) + model.load_state_dict(torch.load(weight_path, map_location=device)) + logging.info('Model Loaded.') + model.eval() + + tensor_dataset = TensorDataset(tensor_idx) + + idx_loader = DataLoader( + dataset=tensor_dataset, batch_size=124, shuffle=True, num_workers=4 + ) + + num_nodes = data.num_nodes + all_preds_embeddings = torch.zeros(num_nodes, 256) + + all_domain_to_embeddings = {} + + for idx_batch in tqdm(idx_loader, desc='Batching Domain Index'): + loader = NeighborLoader( + data, + input_nodes=idx_batch, + num_neighbors=[30, 30, 30], + batch_size=1024, + shuffle=False, + ) + + with torch.no_grad(): + for batch in tqdm(loader, desc=f'batch'): + batch = batch.to(device) + preds = model.get_embeddings(batch.x, batch.edge_index) + seed_nodes = batch.n_id[: batch.batch_size] + all_preds_embeddings[seed_nodes] = preds[: batch.batch_size].cpu() + + domain_names = [idx_to_domain_mapping[idx.item()] for idx in idx_batch] + embeddings = all_preds_embeddings[idx_batch].tolist() + assert len(domain_names) == len(embeddings) + batch_dict = dict(zip(domain_names, embeddings)) + + all_domain_to_embeddings.update(batch_dict) + + save_path = weight_directory / 'rni_embeddings.pkl' + save_path.parent.mkdir(parents=True, exist_ok=True) + + with open(save_path, 'wb') as f: + pickle.dump(all_domain_to_embeddings, f) + + logging.info(f'Saved domain embeddings to {save_path}') + + +def main() -> None: + root = get_root_dir() + scratch = get_scratch() + args = parser.parse_args() + config_file_path = root / args.config_file + meta_args, experiment_args = parse_args(config_file_path) + setup_logging(meta_args.log_file_path) + seed_everything(meta_args.global_seed) + + encoder_classes: Dict[str, Encoder] = { + 'RNI': RNIEncoder(64), # TODO: Set this a paramater + } + + encoding_dict: Dict[str, Encoder] = {} + for index, value in meta_args.encoder_dict.items(): + encoder_class = encoder_classes[value] + encoding_dict[index] = encoder_class + + dataset = TemporalDataset( + root=f'{root}/data/', + node_file=cast(str, meta_args.node_file), + edge_file=cast(str, meta_args.edge_file), + target_file=cast(str, meta_args.target_file), + target_col=meta_args.target_col, + edge_src_col=meta_args.edge_src_col, + edge_dst_col=meta_args.edge_dst_col, + index_col=meta_args.index_col, + encoding=encoding_dict, + seed=meta_args.global_seed, + processed_dir=f'{scratch}/{meta_args.processed_location}', + ) # Map to .to_cpu() + logging.info('In-Memory Dataset loaded.') + weight_directory = ( + root / cast(str, meta_args.weights_directory) / f'{meta_args.target_col}' + ) + + for experiment, experiment_arg in experiment_args.exp_args.items(): + logging.info(f'\n**Running**: {experiment}') + get_embeddings( + experiment_arg.model_args, + dataset, + weight_directory, + ) + + +if __name__ == '__main__': + main() From 2fe742426d3004222c1dc2e482c4b59de1a5f0e1 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 18:47:18 -0500 Subject: [PATCH 02/15] logging to debug --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 1ff69e4b..43cbd25c 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -72,6 +72,7 @@ def get_embeddings( all_domain_to_embeddings = {} for idx_batch in tqdm(idx_loader, desc='Batching Domain Index'): + logging.info(f'idx_batch size: {idx_batch.size()}') loader = NeighborLoader( data, input_nodes=idx_batch, From 692424ed66c19c3ce61378d84cf6cc461840a4de Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 18:53:19 -0500 Subject: [PATCH 03/15] converting batch list to tensor --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 43cbd25c..cddb8d83 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -72,6 +72,7 @@ def get_embeddings( all_domain_to_embeddings = {} for idx_batch in tqdm(idx_loader, desc='Batching Domain Index'): + idx_batch = torch.Tensor(idx_batch) logging.info(f'idx_batch size: {idx_batch.size()}') loader = NeighborLoader( data, From 53406877c883042573ff5c80d19782388441fa05 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 18:59:19 -0500 Subject: [PATCH 04/15] debugging logs --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index cddb8d83..9e5eaea2 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -72,7 +72,8 @@ def get_embeddings( all_domain_to_embeddings = {} for idx_batch in tqdm(idx_loader, desc='Batching Domain Index'): - idx_batch = torch.Tensor(idx_batch) + logging.info(idx_batch) + idx_batch = torch.tensor(idx_batch) logging.info(f'idx_batch size: {idx_batch.size()}') loader = NeighborLoader( data, From 67591f21764a7a4d3d77ae80643168ebef5dc105 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 19:06:12 -0500 Subject: [PATCH 05/15] fix: list -> tensor, batch is a list of tensors, removing logs --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 9e5eaea2..73a4c67b 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -71,10 +71,8 @@ def get_embeddings( all_domain_to_embeddings = {} - for idx_batch in tqdm(idx_loader, desc='Batching Domain Index'): - logging.info(idx_batch) - idx_batch = torch.tensor(idx_batch) - logging.info(f'idx_batch size: {idx_batch.size()}') + for batch in tqdm(idx_loader, desc='Batching Domain Index'): + idx_batch = batch[0] loader = NeighborLoader( data, input_nodes=idx_batch, From 86247a129afe78edfcb120b910ed6d052f70e20d Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 19:23:07 -0500 Subject: [PATCH 06/15] optimizing for 50 million nodes --- .../gnn_experiments/get_all_embeddings.py | 49 +++++++++---------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 73a4c67b..dc766970 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -1,9 +1,9 @@ import argparse import logging -import pickle from pathlib import Path from typing import Dict, cast +import numpy as np import torch from torch.utils.data import DataLoader, TensorDataset from torch_geometric.loader import NeighborLoader @@ -69,37 +69,32 @@ def get_embeddings( num_nodes = data.num_nodes all_preds_embeddings = torch.zeros(num_nodes, 256) - all_domain_to_embeddings = {} - - for batch in tqdm(idx_loader, desc='Batching Domain Index'): - idx_batch = batch[0] - loader = NeighborLoader( - data, - input_nodes=idx_batch, - num_neighbors=[30, 30, 30], - batch_size=1024, - shuffle=False, - ) - - with torch.no_grad(): - for batch in tqdm(loader, desc=f'batch'): - batch = batch.to(device) - preds = model.get_embeddings(batch.x, batch.edge_index) - seed_nodes = batch.n_id[: batch.batch_size] - all_preds_embeddings[seed_nodes] = preds[: batch.batch_size].cpu() + loader = NeighborLoader( + data, + input_nodes=tensor_idx, + num_neighbors=[30, 30, 30], + batch_size=4096, + shuffle=False, + num_workers=8, + persistent_workers=True, + ) - domain_names = [idx_to_domain_mapping[idx.item()] for idx in idx_batch] - embeddings = all_preds_embeddings[idx_batch].tolist() - assert len(domain_names) == len(embeddings) - batch_dict = dict(zip(domain_names, embeddings)) + model.eval() + with torch.no_grad(): + for batch in tqdm(loader, desc=f'Inference'): + batch = batch.to(device) + preds = model.get_embeddings(batch.x, batch.edge_index) + seed_nodes = batch.n_id[: batch.batch_size] + all_preds_embeddings[seed_nodes] = preds[: batch.batch_size].cpu() - all_domain_to_embeddings.update(batch_dict) + np.save(weight_directory / 'embeddings.npy', all_preds_embeddings.numpy()) - save_path = weight_directory / 'rni_embeddings.pkl' + save_path = weight_directory / 'domain_names.txt' save_path.parent.mkdir(parents=True, exist_ok=True) - with open(save_path, 'wb') as f: - pickle.dump(all_domain_to_embeddings, f) + with open(save_path, 'w') as f: + for i in range(num_nodes): + f.write(f'{idx_to_domain_mapping[i]}\n') logging.info(f'Saved domain embeddings to {save_path}') From 6c61b59868d4be8dbe48d1f5e157390411506d39 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 20:38:06 -0500 Subject: [PATCH 07/15] sharding the pickle dump of domain, embeddings dict --- .../gnn_experiments/get_all_embeddings.py | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index dc766970..d2e881d1 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -1,11 +1,10 @@ import argparse import logging +import pickle from pathlib import Path from typing import Dict, cast -import numpy as np import torch -from torch.utils.data import DataLoader, TensorDataset from torch_geometric.loader import NeighborLoader from tqdm import tqdm @@ -60,12 +59,6 @@ def get_embeddings( logging.info('Model Loaded.') model.eval() - tensor_dataset = TensorDataset(tensor_idx) - - idx_loader = DataLoader( - dataset=tensor_dataset, batch_size=124, shuffle=True, num_workers=4 - ) - num_nodes = data.num_nodes all_preds_embeddings = torch.zeros(num_nodes, 256) @@ -79,6 +72,13 @@ def get_embeddings( persistent_workers=True, ) + shard_size = 1_000_000 + current_shard_dict = {} + shard_count = 0 + + save_dir = weight_directory / 'shards' + save_dir.parent.mkdir(parents=True, exist_ok=True) + model.eval() with torch.no_grad(): for batch in tqdm(loader, desc=f'Inference'): @@ -87,16 +87,29 @@ def get_embeddings( seed_nodes = batch.n_id[: batch.batch_size] all_preds_embeddings[seed_nodes] = preds[: batch.batch_size].cpu() - np.save(weight_directory / 'embeddings.npy', all_preds_embeddings.numpy()) + for i, node_idx in enumerate(seed_nodes): + name = idx_to_domain_mapping[node_idx.item()] + embedding = preds[i].cpu().tolist() + current_shard_dict[name] = embedding + + if len(current_shard_dict) >= shard_size: + shard_path = save_dir / f'shard_{shard_count}.pkl' + with open(shard_path, 'wb') as f: + pickle.dump(current_shard_dict, f) + + logging.info( + f'Saved shard {shard_count} with {len(current_shard_dict)}' + ) - save_path = weight_directory / 'domain_names.txt' - save_path.parent.mkdir(parents=True, exist_ok=True) + current_shard_dict = {} + shard_count += 1 - with open(save_path, 'w') as f: - for i in range(num_nodes): - f.write(f'{idx_to_domain_mapping[i]}\n') + if current_shard_dict: + shard_path = save_dir / f'shard_{shard_count}.pkl' + with open(shard_path, 'wb') as f: + pickle.dump(current_shard_dict, f) - logging.info(f'Saved domain embeddings to {save_path}') + logging.info(f'Saved domain embeddings to {save_dir}') def main() -> None: From cce95b93dc629a7eb712f293aa7710b76ebf5212 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 21:00:17 -0500 Subject: [PATCH 08/15] fix: remove parent dir creation --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index d2e881d1..08cdf0f0 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -77,7 +77,7 @@ def get_embeddings( shard_count = 0 save_dir = weight_directory / 'shards' - save_dir.parent.mkdir(parents=True, exist_ok=True) + save_dir.mkdir(parents=True, exist_ok=True) model.eval() with torch.no_grad(): From 24ef3a6378bf48aac0fd367fd9dde92ec5ce2661 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Mon, 5 Jan 2026 21:24:03 -0500 Subject: [PATCH 09/15] increasing shard size --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 08cdf0f0..ddb2fa25 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -72,7 +72,7 @@ def get_embeddings( persistent_workers=True, ) - shard_size = 1_000_000 + shard_size = 10_000_000 current_shard_dict = {} shard_count = 0 From 5f0cf55f3e828851e81601db9ee16edfe7c99768 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Tue, 6 Jan 2026 22:37:13 -0500 Subject: [PATCH 10/15] changing path to scratch, returning required imports --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index ddb2fa25..bc8f7bd8 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -33,6 +33,7 @@ def get_embeddings( model_arguments: ModelArguments, dataset: TemporalDataset, weight_directory: Path, + scratch: Path, ) -> None: get_root_dir() data = dataset[0] @@ -76,7 +77,7 @@ def get_embeddings( current_shard_dict = {} shard_count = 0 - save_dir = weight_directory / 'shards' + save_dir = scratch / weight_directory / 'shards' save_dir.mkdir(parents=True, exist_ok=True) model.eval() @@ -154,6 +155,7 @@ def main() -> None: experiment_arg.model_args, dataset, weight_directory, + scratch, ) From f0b6eda56d20e9fe48eb529016d82d7e230cb43b Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Tue, 6 Jan 2026 22:48:16 -0500 Subject: [PATCH 11/15] adding logs --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index bc8f7bd8..479789d6 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -78,6 +78,7 @@ def get_embeddings( shard_count = 0 save_dir = scratch / weight_directory / 'shards' + logging.info(f'save directory: {save_dir}') save_dir.mkdir(parents=True, exist_ok=True) model.eval() From f48a0af4050fdde14ea99bc7f80ca1c35e477563 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Wed, 7 Jan 2026 14:36:35 -0500 Subject: [PATCH 12/15] adding txt file for domains, reducing to 5gb per shard --- .../experiments/gnn_experiments/get_all_embeddings.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 479789d6..3952f666 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -73,11 +73,12 @@ def get_embeddings( persistent_workers=True, ) - shard_size = 10_000_000 + shard_size = 3_000_000 current_shard_dict = {} + current_shard_name = [] shard_count = 0 - save_dir = scratch / weight_directory / 'shards' + save_dir = scratch / 'shards' logging.info(f'save directory: {save_dir}') save_dir.mkdir(parents=True, exist_ok=True) @@ -93,17 +94,23 @@ def get_embeddings( name = idx_to_domain_mapping[node_idx.item()] embedding = preds[i].cpu().tolist() current_shard_dict[name] = embedding + current_shard_name.append(name) if len(current_shard_dict) >= shard_size: shard_path = save_dir / f'shard_{shard_count}.pkl' with open(shard_path, 'wb') as f: pickle.dump(current_shard_dict, f) + names_path = save_dir / f'domains_{shard_count}.txt' + + with open(names_path, 'w') as f: + f.write('\n'.join(current_shard_name)) logging.info( f'Saved shard {shard_count} with {len(current_shard_dict)}' ) current_shard_dict = {} + current_shard_name = [] shard_count += 1 if current_shard_dict: From 0e1f0b7840e036f834761cc5fa26a56ad0ace9e6 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Wed, 7 Jan 2026 16:39:51 -0500 Subject: [PATCH 13/15] changing save path --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 3952f666..37d872fe 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -163,7 +163,7 @@ def main() -> None: experiment_arg.model_args, dataset, weight_directory, - scratch, + scratch / cast(str, meta_args.node_file), ) From 5f93ef364da18ed52e80cc0e9a13ab553ee562b5 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Wed, 7 Jan 2026 17:41:54 -0500 Subject: [PATCH 14/15] changing path to directory folder --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 37d872fe..58aaa5d0 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -163,7 +163,7 @@ def main() -> None: experiment_arg.model_args, dataset, weight_directory, - scratch / cast(str, meta_args.node_file), + scratch / cast(str, meta_args.database_folder), ) From eeea06ac70f043dc77904134b15117e53e45ad84 Mon Sep 17 00:00:00 2001 From: sebastian9991 Date: Wed, 7 Jan 2026 18:56:07 -0500 Subject: [PATCH 15/15] fix: text file is created for last shard --- tgrag/experiments/gnn_experiments/get_all_embeddings.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tgrag/experiments/gnn_experiments/get_all_embeddings.py b/tgrag/experiments/gnn_experiments/get_all_embeddings.py index 58aaa5d0..2ff1da86 100644 --- a/tgrag/experiments/gnn_experiments/get_all_embeddings.py +++ b/tgrag/experiments/gnn_experiments/get_all_embeddings.py @@ -115,9 +115,13 @@ def get_embeddings( if current_shard_dict: shard_path = save_dir / f'shard_{shard_count}.pkl' + names_path = save_dir / f'domains_{shard_count}.txt' with open(shard_path, 'wb') as f: pickle.dump(current_shard_dict, f) + with open(names_path, 'w') as f: + f.write('\n'.join(current_shard_name)) + logging.info(f'Saved domain embeddings to {save_dir}')