diff --git a/pyproject.toml b/pyproject.toml
index 8a277b6a..63172be3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "cloudscraper>=1.2.71",
     "BeautifulSoup4>=4.13.0",
     "jsonpath-ng>=1.7.0",
+    "zarr>=2.18.2",
 ]
 
 [tool.flit.module]
diff --git a/tgrag/dataset/zarr_rni_dataset.py b/tgrag/dataset/zarr_rni_dataset.py
new file mode 100644
index 00000000..d209bae1
--- /dev/null
+++ b/tgrag/dataset/zarr_rni_dataset.py
@@ -0,0 +1,200 @@
+import logging
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import pandas as pd
+import torch
+import zarr
+from sklearn.model_selection import train_test_split
+from torch_geometric.data import Data, InMemoryDataset
+
+from tgrag.utils.dataset_loading import (
+    construct_zarr_rni_backend,
+    load_large_edge_csv,
+)
+from tgrag.utils.load_labels import get_full_dict
+from tgrag.utils.target_generation import generate_exact_targets_csv
+
+EMB_DIM = 64
+
+
+class ZarrDataset(InMemoryDataset):
+    def __init__(
+        self,
+        root: str,
+        database_folder: str,
+        node_file: str = 'features.csv',
+        edge_file: str = 'edges.csv',
+        target_file: str = 'target.csv',
+        target_col: str = 'score',
+        target_index_name: str = 'nid',
+        target_index_col: int = 0,
+        edge_src_col: str = 'src',
+        edge_dst_col: str = 'dst',
+        index_col: int = 1,
+        index_name: str = 'node_id',
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        seed: int = 42,
+        processed_dir: Optional[str] = None,
+    ):
+        self.node_file = node_file
+        self.edge_file = edge_file
+        self.target_file = target_file
+        self.target_col = target_col
+        self.edge_src_col = edge_src_col
+        self.edge_dst_col = edge_dst_col
+        self.index_col = index_col
+        self.index_name = index_name
+        self.target_index_name = target_index_name
+        self.target_index_col = target_index_col
+        self.seed = seed
+        self._custom_processed_dir = processed_dir
+        self.database_folder = database_folder
+        super().__init__(root, transform, pre_transform)
+        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)
+
+    @property
+    def raw_dir(self) -> str:
+        return os.path.join(self.root)
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return [self.node_file, self.edge_file]
+
+    @property
+    def processed_dir(self) -> str:
+        if self._custom_processed_dir is not None:
+            return self._custom_processed_dir
+        return super().processed_dir
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ['data.pt']
+
+    def download(self) -> None:
+        pass
+
+    def process(self) -> None:
+        node_path = os.path.join(self.raw_dir, self.node_file)
+        edge_path = os.path.join(self.raw_dir, self.edge_file)
+        target_path = os.path.join(self.raw_dir, self.target_file)
+        if os.path.exists(target_path):
+            logging.info('Target file already exists.')
+        else:
+            logging.info('Generating target file.')
+            dqr = get_full_dict()
+            generate_exact_targets_csv(node_path, target_path, dqr)
+
+        # If we use domains as our unique ID, we can construct the Zarr database in a separate script.
+
+        zarr_path = Path(self.database_folder) / 'embeddings.zarr'
+        zarr_path.parent.mkdir(parents=True, exist_ok=True)
+        zarr_storage = zarr.DirectoryStore(str(zarr_path))
+
+        logging.info('***Constructing Mapping***')
+        mapping, full_index = construct_zarr_rni_backend(
+            path=node_path,
+            zarr_storage=zarr_storage,
+            index_col=0,
+        )
+
+        df_target = pd.read_csv(target_path)
+        logging.info(f'Size of target dataframe: {df_target.shape}')
+
+        mapping_index = [mapping[domain.strip()] for domain in df_target['domain']]
+        df_target.index = mapping_index
+        logging.info(f'Size of mapped target dataframe: {df_target.shape}')
+
+        missing_idx = full_index.difference(mapping_index)
+        filler = pd.DataFrame(
+            {col: np.nan for col in df_target.columns}, index=missing_idx
+        )
+        df_target = pd.concat([df_target, filler])
+        df_target.sort_index(inplace=True)
+        logging.info(f'Size of filled target dataframe: {df_target.shape}')
+        score = torch.tensor(
+            df_target[self.target_col].astype('float32').fillna(-1).values,
+            dtype=torch.float,
+        )
+        logging.info(f'Size of score vector: {score.size()}')
+
+        labeled_mask = score != -1.0
+
+        labeled_idx = torch.nonzero(labeled_mask, as_tuple=True)[0]
+        labeled_scores = score[labeled_idx].squeeze().numpy()
+
+        if labeled_scores.size == 0:
+            raise ValueError(
+                f"No labeled nodes found in target column '{self.target_col}'"
+            )
+
+        logging.info('***Constructing Edge Matrix***')
+        edge_index, edge_attr = load_large_edge_csv(
+            path=edge_path,
+            src_index_col=self.edge_src_col,
+            dst_index_col=self.edge_dst_col,
+            mapping=mapping,
+            encoders=None,
+        )
+        logging.info('***Edge Matrix Constructed***')
+
+        data = Data(y=score, edge_index=edge_index, edge_attr=edge_attr)
+        data.num_nodes = len(mapping)
+
+        data.labeled_mask = labeled_mask.detach().clone().bool()
+
+        # Terciles: split the labeled scores into three equal-frequency bins
+        # so the train/valid/test splits can be stratified.
+        tercile_bins = np.quantile(labeled_scores, [1 / 3, 2 / 3])
+        tercile_labels = np.digitize(labeled_scores, bins=tercile_bins)
+
+        train_idx, temp_idx, _, tercile_labels_temp = train_test_split(
+            labeled_idx,
+            tercile_labels,
+            train_size=0.6,
+            stratify=tercile_labels,
+            random_state=self.seed,
+        )
+
+        valid_idx, test_idx = train_test_split(
+            temp_idx,
+            train_size=0.5,
+            stratify=tercile_labels_temp,
+            random_state=self.seed,
+        )
+
+        train_idx = torch.as_tensor(train_idx)
+        logging.info(f'Train size: {train_idx.size()}')
+        valid_idx = torch.as_tensor(valid_idx)
+        logging.info(f'Valid size: {valid_idx.size()}')
+        test_idx = torch.as_tensor(test_idx)
+        logging.info(f'Test size: {test_idx.size()}')
+
+        # Set global indices for our transductive nodes:
+        num_nodes = data.num_nodes
+        data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
+        data.train_mask[train_idx] = True
+        data.valid_mask = torch.zeros(num_nodes, dtype=torch.bool)
+        data.valid_mask[valid_idx] = True
+        data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
+        data.test_mask[test_idx] = True
+        data.idx_dict = {
+            'train': train_idx,
+            'valid': valid_idx,
+            'test': test_idx,
+        }
+
+        torch.save(mapping, self.processed_dir + '/mapping.pt')
+        torch.save(self.collate([data]), self.processed_paths[0])
+
+    def get_idx_split(self) -> Dict:
+        data = self[0]
+        if hasattr(data, 'idx_dict') and data.idx_dict is not None:
+            return data.idx_dict
+        raise TypeError('idx split is empty.')
+
+    def get_mapping(self) -> Dict:
+        if not hasattr(self, '_mapping'):
+            self._mapping = torch.load(self.processed_dir + '/mapping.pt')
+        return self._mapping
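Reviewer note: a minimal usage sketch of the dataset above, to make the intended call pattern concrete. The root and database_folder paths are illustrative assumptions, not values from this patch; the file names come from the constructor defaults.

    from tgrag.dataset.zarr_rni_dataset import ZarrDataset

    dataset = ZarrDataset(
        root='data/',                    # assumed location of features.csv / edges.csv
        database_folder='data/zarr_db',  # where process() writes embeddings.zarr
    )
    data = dataset[0]                    # one torch_geometric Data graph
    split = dataset.get_idx_split()      # {'train': ..., 'valid': ..., 'test': ...}
    print(data.num_nodes, split['train'].numel())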
diff --git a/tgrag/encoders/text_encoder.py b/tgrag/encoders/text_encoder.py
index fb80b5ab..6ef4d739 100644
--- a/tgrag/encoders/text_encoder.py
+++ b/tgrag/encoders/text_encoder.py
@@ -12,9 +12,9 @@ class TextEncoder(Encoder):
 
     def __init__(
         self,
-        model_name: str = 'Qwen/Qwen3-Embedding-0.6B',
+        model_name: str = 'Qwen/Qwen3-Embedding-0.6B',  # TODO: Too small; consider the 4B/8B parameter sizes
         device: str | None = None,
-        max_seq_length: int = 256,
+        max_seq_length: int = 1024,
         batch_size: int = 128,
         use_fp16: bool = True,
     ):
@@ -22,8 +22,8 @@ def __init__(
         self.model = SentenceTransformer(model_name, device=device)
         self.model.max_seq_length = max_seq_length
 
-        if use_fp16 and self.device is not None and 'cuda' in self.device:
-            self.model = self.model.half()
+        # if use_fp16 and self.device is not None and "cuda" in self.device:
+        #     self.model = self.model.half()
 
         self.batch_size = batch_size
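Reviewer note: a sketch of what the bumped default implies, assuming the standard sentence-transformers API (not part of the patch). Raising max_seq_length from 256 to 1024 truncates far less text per document, at the cost of substantially more attention compute per batch.

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B', device='cpu')
    model.max_seq_length = 1024  # mirrors the new TextEncoder default
    vecs = model.encode(['example document text'], batch_size=128)
    print(vecs.shape)  # (1, embedding_dim)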
diff --git a/tgrag/experiments/gnn_experiments/gnn_experiment_zarr_extension.py b/tgrag/experiments/gnn_experiments/gnn_experiment_zarr_extension.py
new file mode 100644
index 00000000..eb6b7919
--- /dev/null
+++ b/tgrag/experiments/gnn_experiments/gnn_experiment_zarr_extension.py
@@ -0,0 +1,363 @@
+import logging
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+import zarr
+from torch import Tensor
+from torch_geometric.loader import NeighborLoader
+from torcheval.metrics.functional import r2_score
+from tqdm import tqdm
+
+from tgrag.dataset.zarr_rni_dataset import ZarrDataset
+from tgrag.gnn.model import Model
+from tgrag.utils.args import DataArguments, ModelArguments
+from tgrag.utils.logger import Logger
+from tgrag.utils.plot import (
+    Scoring,
+    mean_across_lists,
+    plot_avg_loss,
+    plot_avg_loss_r2,
+    plot_pred_target_distributions_bin_list,
+)
+from tgrag.utils.prob import ragged_mean_by_index
+from tgrag.utils.save import save_loss_results
+
+
+def train(
+    model: torch.nn.Module,
+    train_loader: NeighborLoader,
+    optimizer: torch.optim.AdamW,
+    embeddings: zarr.Array,
+) -> Tuple[float, float, Tensor, Tensor]:
+    model.train()
+    device = next(model.parameters()).device
+
+    total_loss = 0
+    total_batches = 0
+    all_preds = []
+    all_targets = []
+
+    total_zarr_time = 0.0
+    total_to_tensor_time = 0.0
+    total_forward_time = 0.0
+
+    for batch in tqdm(train_loader, desc='Batches', leave=False):
+        optimizer.zero_grad()
+        batch = batch.to(device)
+
+        t0 = time.perf_counter()
+        batch_indices = batch.n_id.to(torch.int64).cpu().numpy()
+        x_batch_numpy = embeddings[batch_indices]
+        t1 = time.perf_counter()
+        zarr_time = t1 - t0
+        total_zarr_time += zarr_time
+
+        t2 = time.perf_counter()
+        x_batch_tensor = torch.from_numpy(x_batch_numpy).to(device)
+        t3 = time.perf_counter()
+        to_tensor_time = t3 - t2
+        total_to_tensor_time += to_tensor_time
+
+        t4 = time.perf_counter()
+        preds = model(x_batch_tensor, batch.edge_index).squeeze()
+        t5 = time.perf_counter()
+        forward_time = t5 - t4
+        total_forward_time += forward_time
+
+        targets = batch.y
+        train_mask = batch.train_mask
+
+        if train_mask.sum() == 0:
+            continue
+
+        loss = F.l1_loss(preds[train_mask], targets[train_mask])
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()
+        total_batches += 1
+        all_preds.append(preds[train_mask])
+        all_targets.append(targets[train_mask])
+
+    # Report the accumulated timings so the instrumentation above is not dead code.
+    logging.debug(
+        f'zarr read: {total_zarr_time:.2f}s, to-tensor: {total_to_tensor_time:.2f}s, '
+        f'forward: {total_forward_time:.2f}s'
+    )
+    r2 = r2_score(torch.cat(all_preds), torch.cat(all_targets)).item()
+    avg_preds = ragged_mean_by_index(all_preds)
+    avg_targets = ragged_mean_by_index(all_targets)
+    mae = total_loss / total_batches  # F.l1_loss, so this is an MAE, not an MSE
+
+    return (mae, r2, avg_preds, avg_targets)
+
+
+def train_(
+    model: torch.nn.Module,
+    train_loader: NeighborLoader,
+    optimizer: torch.optim.AdamW,
+    embeddings: zarr.Array,
+) -> Tuple[float, float, List[float], List[float]]:
+    model.train()
+    device = next(model.parameters()).device
+    total_loss = 0
+    total_batches = 0
+    all_preds = []
+    all_targets = []
+    pred_scores = []
+    target_scores = []
+    total_zarr_time = 0.0
+    total_to_tensor_time = 0.0
+    total_forward_time = 0.0
+    for batch in tqdm(train_loader, desc='Batches', leave=False):
+        optimizer.zero_grad()
+        batch = batch.to(device)
+
+        t0 = time.perf_counter()
+        batch_indices = batch.n_id.to(torch.int64).cpu().numpy()
+        x_batch_numpy = embeddings[batch_indices]
+        t1 = time.perf_counter()
+        zarr_time = t1 - t0
+        total_zarr_time += zarr_time
+
+        t2 = time.perf_counter()
+        x_batch_tensor = torch.from_numpy(x_batch_numpy).to(device)
+        t3 = time.perf_counter()
+        to_tensor_time = t3 - t2
+        total_to_tensor_time += to_tensor_time
+
+        t4 = time.perf_counter()
+        preds = model(x_batch_tensor, batch.edge_index).squeeze()
+        t5 = time.perf_counter()
+        forward_time = t5 - t4
+        total_forward_time += forward_time
+        targets = batch.y
+        train_mask = batch.train_mask
+        if train_mask.sum() == 0:
+            continue
+
+        loss = F.l1_loss(preds[train_mask], targets[train_mask])
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item()
+        total_batches += 1
+        all_preds.append(preds[train_mask])
+        all_targets.append(targets[train_mask])
+        for pred in preds[train_mask]:
+            pred_scores.append(pred.item())
+        for targ in targets[train_mask]:
+            target_scores.append(targ.item())
+
+    logging.debug(
+        f'zarr read: {total_zarr_time:.2f}s, to-tensor: {total_to_tensor_time:.2f}s, '
+        f'forward: {total_forward_time:.2f}s'
+    )
+    r2 = r2_score(torch.cat(all_preds), torch.cat(all_targets)).item()
+    mae = total_loss / total_batches  # F.l1_loss, so this is an MAE, not an MSE
+    return (mae, r2, pred_scores, target_scores)
+
+
+@torch.no_grad()
+def evaluate(
+    model: torch.nn.Module,
+    loader: NeighborLoader,
+    mask_name: str,
+    embeddings: zarr.Array,
+) -> Tuple[float, float, float, float]:
+    model.eval()
+    device = next(model.parameters()).device
+    total_loss = 0
+    total_mean_loss = 0
+    total_random_loss = 0
+    total_batches = 0
+    all_preds = []
+    all_targets = []
+    for batch in loader:
+        batch = batch.to(device)
+        batch_indices = (
+            batch.n_id.to(torch.int64).cpu().numpy()
+        )  # Will this slow down? GPU -> index to cpu?
+        x_batch_numpy = embeddings[batch_indices]  # Zarr read
+        x_batch_tensor = torch.from_numpy(x_batch_numpy).to(
+            device
+        )  # Slow? numpy -> tensor
+        preds = model(x_batch_tensor, batch.edge_index).squeeze()
+        targets = batch.y
+        mask = getattr(batch, mask_name)
+        if mask.sum() == 0:
+            continue
+        # Constant and random baselines to compare the model's error against.
+        mean_preds = torch.full(batch.y[mask].size(), 0.546).to(device)
+        random_preds = torch.rand(batch.y[mask].size(0)).to(device)
+        loss = F.l1_loss(preds[mask], targets[mask])
+        mean_loss = F.l1_loss(mean_preds, targets[mask])
+        random_loss = F.l1_loss(random_preds, targets[mask])
+
+        total_loss += loss.item()
+        total_mean_loss += mean_loss.item()
+        total_random_loss += random_loss.item()
+        total_batches += 1
+
+        all_preds.append(preds[mask])
+        all_targets.append(targets[mask])
+
+    r2 = r2_score(torch.cat(all_preds), torch.cat(all_targets)).item()
+    mae = total_loss / total_batches  # F.l1_loss, so these are MAEs, not MSEs
+    mae_mean = total_mean_loss / total_batches
+    mae_random = total_random_loss / total_batches
+    return (mae, mae_mean, mae_random, r2)
+
+
+def run_gnn_baseline_zarr_backend(
+    data_arguments: DataArguments,
+    model_arguments: ModelArguments,
+    weight_directory: Path,
+    dataset: ZarrDataset,
+    embeddings: zarr.Array,
+) -> None:
+    data = dataset[0]
+    split_idx = dataset.get_idx_split()
+    logging.info(
+        'Setting up training for task of: %s on model: %s',
+        data_arguments.task_name,
+        model_arguments.model,
+    )
+    device = f'cuda:{model_arguments.device}' if torch.cuda.is_available() else 'cpu'
+    device = torch.device(device)
+
+    logging.info(f'Device found: {device}')
+
+    logging.info(f'Training set size: {split_idx["train"].size()}')
+    logging.info(f'Validation set size: {split_idx["valid"].size()}')
+    logging.info(f'Testing set size: {split_idx["test"].size()}')
+
+    train_loader = NeighborLoader(
+        data,
+        input_nodes=split_idx['train'],
+        num_neighbors=model_arguments.num_neighbors,
+        batch_size=model_arguments.batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        persistent_workers=True,
+    )
+    logging.info('Train loader created')
+
+    val_loader = NeighborLoader(
+        data,
+        input_nodes=split_idx['valid'],
+        num_neighbors=model_arguments.num_neighbors,
+        batch_size=model_arguments.batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        persistent_workers=True,
+    )
+
+    logging.info('Valid loader created')
+    test_loader = NeighborLoader(
+        data,
+        input_nodes=split_idx['test'],
+        num_neighbors=model_arguments.num_neighbors,
+        batch_size=model_arguments.batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        persistent_workers=True,
+    )
+    logging.info('Test loader created')
+
+    logger = Logger(model_arguments.runs)
+    loss_tuple_run_mse: List[List[Tuple[float, float, float, float, float]]] = []
+    loss_tuple_run_r2: List[List[Tuple[float, float, float]]] = []
+    final_avg_preds: List[List[float]] = []
+    final_avg_targets: List[List[float]] = []
+    global_best_val_loss = float('inf')
+    best_state_dict = None
+    logging.info('*** Training ***')
+    for run in tqdm(range(model_arguments.runs), desc='Runs'):
+        model = Model(
+            model_name=model_arguments.model,
+            normalization=model_arguments.normalization,
+            in_channels=embeddings.shape[1],
+            hidden_channels=model_arguments.hidden_channels,
+            out_channels=model_arguments.embedding_dimension,
+            num_layers=model_arguments.num_layers,
+            dropout=model_arguments.dropout,
+        ).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=model_arguments.lr)
+        loss_tuple_epoch_mse: List[Tuple[float, float, float, float, float]] = []
+        loss_tuple_epoch_r2: List[Tuple[float, float, float]] = []
+        epoch_avg_preds: List[List[float]] = []
+        epoch_avg_targets: List[List[float]] = []
+        for _ in tqdm(range(1, 1 + model_arguments.epochs), desc='Epochs'):
+            _, _, batch_preds, batch_targets = train_(
+                model, train_loader, optimizer, embeddings
+            )
+            epoch_avg_preds.append(batch_preds)
+            epoch_avg_targets.append(batch_targets)
+            train_loss, _, _, train_r2 = evaluate(
+                model, train_loader, 'train_mask', embeddings
+            )
+            valid_loss, valid_mean_baseline_loss, _, valid_r2 = evaluate(
+                model, val_loader, 'valid_mask', embeddings
+            )
+            test_loss, test_mean_baseline_loss, test_random_baseline_loss, test_r2 = (
+                evaluate(model, test_loader, 'test_mask', embeddings)
+            )
+            result = (
+                train_loss,
+                valid_loss,
+                test_loss,
+                test_mean_baseline_loss,
+                test_random_baseline_loss,
+            )
+            result_r2 = (train_r2, valid_r2, test_r2)
+            loss_tuple_epoch_mse.append(result)
+            loss_tuple_epoch_r2.append(result_r2)
+            logger.add_result(
+                run, (train_loss, valid_loss, test_loss, valid_mean_baseline_loss)
+            )
+            if valid_loss < global_best_val_loss:
+                global_best_val_loss = valid_loss
+                best_state_dict = model.state_dict()
+
+        final_avg_preds.append(mean_across_lists(epoch_avg_preds))
+        final_avg_targets.append(mean_across_lists(epoch_avg_targets))
+        loss_tuple_run_mse.append(loss_tuple_epoch_mse)
+        loss_tuple_run_r2.append(loss_tuple_epoch_r2)
+
+    best_model_dir = weight_directory / f'{model_arguments.model}'
+    best_model_dir.mkdir(parents=True, exist_ok=True)
+    best_model_path = best_model_dir / 'best_model.pt'
+    if best_state_dict is not None:  # stays None when no epoch ever ran
+        torch.save(best_state_dict, best_model_path)
+        logging.info(f'Model: {model_arguments.model} weights saved to: {best_model_path}')
+    logging.info('*** Statistics ***')
+    logging.info(logger.get_statistics())
+    logging.info(logger.get_avg_statistics())
+    logging.info(
+        logger.per_run_within_error(
+            preds=final_avg_preds, targets=final_avg_targets, percent=10
+        )
+    )
+    logging.info(
+        logger.per_run_within_error(
+            preds=final_avg_preds, targets=final_avg_targets, percent=5
+        )
+    )
+    logging.info(
+        logger.per_run_within_error(
+            preds=final_avg_preds, targets=final_avg_targets, percent=1
+        )
+    )
+    logging.info('Constructing plots')
+    plot_pred_target_distributions_bin_list(
+        preds=final_avg_preds,
+        targets=final_avg_targets,
+        model_name=model_arguments.model,
+        bins=100,
+    )
+    plot_avg_loss(
+        loss_tuple_run_mse, model_arguments.model, Scoring.mae, 'loss_plot.png'
+    )
+    plot_avg_loss_r2(
+        loss_tuple_run_r2, model_arguments.model, Scoring.r2, 'r2_plot.png'
+    )
+    logging.info('Saving pkl of results')
+    save_loss_results(loss_tuple_run_mse, model_arguments.model, 'TODO')
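Reviewer note: the core pattern in this file is gathering per-node features out of the on-disk Zarr array using NeighborLoader's global n_id vector, instead of holding the full feature matrix in memory. A self-contained toy sketch of that gather (shapes and the store path are illustrative, not from this patch):

    import numpy as np
    import torch
    import zarr

    # Toy stand-in for embeddings.zarr: 1000 nodes x 64 dims.
    z = zarr.open_array('toy_embeddings.zarr', mode='w', shape=(1000, 64),
                        chunks=(256, 64), dtype='float32')
    z[:] = np.random.normal(size=(1000, 64)).astype('float32')

    # NeighborLoader exposes the sampled subgraph's global ids as batch.n_id;
    # integer-array indexing reads only those rows from disk.
    n_id = torch.tensor([3, 17, 999, 42])        # stands in for batch.n_id
    x_batch = torch.from_numpy(z[n_id.numpy()])  # float32 tensor of shape [4, 64]
    assert x_batch.shape == (4, 64)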
diff --git a/tgrag/experiments/gnn_experiments/main_zarr.py b/tgrag/experiments/gnn_experiments/main_zarr.py
new file mode 100644
index 00000000..e44fcf50
--- /dev/null
+++ b/tgrag/experiments/gnn_experiments/main_zarr.py
@@ -0,0 +1,84 @@
+import argparse
+import logging
+from typing import cast
+
+import zarr
+
+from tgrag.dataset.zarr_rni_dataset import ZarrDataset
+from tgrag.experiments.gnn_experiments.gnn_experiment_zarr_extension import (
+    run_gnn_baseline_zarr_backend,
+)
+from tgrag.utils.args import parse_args
+from tgrag.utils.logger import setup_logging
+from tgrag.utils.path import get_root_dir, get_scratch
+from tgrag.utils.plot import (
+    load_all_loss_tuples,
+    plot_metric_across_models,
+    plot_metric_per_encoder,
+    plot_model_per_encoder,
+)
+from tgrag.utils.seed import seed_everything
+
+parser = argparse.ArgumentParser(
+    description='GNN Experiments.',
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+parser.add_argument(
+    '--config-file',
+    type=str,
+    default='configs/gnn/base.yaml',
+    help='Path to yaml configuration file to use',
+)
+
+
+def main() -> None:
+    root = get_root_dir()
+    scratch = get_scratch()
+    args = parser.parse_args()
+    config_file_path = root / args.config_file
+    meta_args, experiment_args = parse_args(config_file_path)
+    setup_logging(meta_args.log_file_path)
+    seed_everything(meta_args.global_seed)
+
+    logging.info(f'Scratch Location: {scratch}')
+
+    dataset = ZarrDataset(
+        root=f'{root}/data/',
+        node_file=cast(str, meta_args.node_file),
+        edge_file=cast(str, meta_args.edge_file),
+        target_file=cast(str, meta_args.target_file),
+        target_col=meta_args.target_col,
+        edge_src_col=meta_args.edge_src_col,
+        edge_dst_col=meta_args.edge_dst_col,
+        index_col=meta_args.index_col,
+        seed=meta_args.global_seed,
+        processed_dir=cast(str, meta_args.processed_location),
+        database_folder=cast(str, meta_args.database_folder),
+    )
+    logging.info('In-Memory Zarr Dataset loaded.')
+    zarr_path = scratch / cast(str, meta_args.database_folder) / 'embeddings.zarr'
+    logging.info(f'Reading Zarr storage from: {zarr_path}')
+    embeddings = zarr.open_array(str(zarr_path))
+
+    for experiment, experiment_arg in experiment_args.exp_args.items():
+        logging.info(f'\n**Running**: {experiment}')
+        run_gnn_baseline_zarr_backend(
+            data_arguments=experiment_arg.data_args,
+            model_arguments=experiment_arg.model_args,
+            weight_directory=root
+            / cast(str, meta_args.weights_directory)
+            / f'{meta_args.target_col}',
+            dataset=dataset,
+            embeddings=embeddings,
+        )
+    results = load_all_loss_tuples()
+    logging.info('Constructing Plots, across models')
+    plot_metric_across_models(results)
+    logging.info('Constructing Plots, metric per-encoder')
+    plot_metric_per_encoder(results)
+    logging.info('Constructing Plots, model per-encoder')
+    plot_model_per_encoder(results)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tgrag/utils/dataset_loading.py b/tgrag/utils/dataset_loading.py
index c2052180..7c762edc 100644
--- a/tgrag/utils/dataset_loading.py
+++ b/tgrag/utils/dataset_loading.py
@@ -1,9 +1,11 @@
 import logging
 import pickle
-from typing import Dict, Tuple
+from typing import Dict, List, Tuple
 
+import numpy as np
 import pandas as pd
 import torch
+import zarr
 from torch import Tensor
 from tqdm import tqdm
 
@@ -151,3 +153,75 @@ def get_seed_embeddings() -> Dict[str, torch.Tensor]:
     }
 
     return embeddings_lookup
+
+
+def get_mapping(
+    path: str, index_col: int, chunk_size: int = 500_000
+) -> Tuple[Dict, pd.Index]:
+    dfs = []
+    with open(path) as f:  # count rows without leaking the file handle
+        total_rows = sum(1 for _ in f) - 1
+    with pd.read_csv(path, index_col=index_col, chunksize=chunk_size) as reader:
+        for chunk in tqdm(
+            reader, total=total_rows // chunk_size + 1, desc='Reading node CSV'
+        ):
+            dfs.append(chunk)
+
+    df = pd.concat(dfs, axis=0)
+    mapping = {
+        index: i for i, index in tqdm(enumerate(df.index.unique()), desc='Indexing')
+    }
+
+    return mapping, pd.RangeIndex(len(mapping))
+
+
+def construct_zarr_rni_backend(
+    path: str,
+    zarr_storage: zarr.DirectoryStore,
+    index_col: int,
+    chunk_size: int = 500_000,
+) -> Tuple[dict, pd.Index]:
+    with open(path) as f:  # count rows without leaking the file handle
+        total_rows = sum(1 for _ in f) - 1
+
+    z_embeddings = zarr.open_array(
+        store=zarr_storage,
+        mode='w',
+        shape=(total_rows, 64),  # 64 must match EMB_DIM in zarr_rni_dataset.py
+        chunks=(16384, 64),
+        dtype='float32',
+    )
+
+    mapping = {}
+    current_row = 0
+
+    reader = pd.read_csv(path, index_col=index_col, chunksize=chunk_size)
+
+    for chunk in tqdm(
+        reader, total=total_rows // chunk_size + 1, desc='Populating Zarr storage'
+    ):
+        num_rows = len(chunk)
+
+        domains = chunk.index.to_numpy()
+        start = current_row
+        end = current_row + num_rows
+
+        mapping.update({domain: i for domain, i in zip(domains, range(start, end))})
+
+        random_vecs = np.random.normal(0.0, 1.0, size=(num_rows, 64)).astype('float32')
+
+        z_embeddings[start:end, :] = random_vecs
+
+        current_row += num_rows
+
+    return mapping, pd.RangeIndex(len(mapping))
+
+
+def extract_text_from_row(txt_list: List[str], k: int = 5) -> List[str]:
+    if not isinstance(txt_list, list) or len(txt_list) == 0:
+        return ['']
+
+    if len(txt_list) <= k:
+        chosen = txt_list
+    else:
+        chosen = np.random.choice(txt_list, size=k, replace=False).tolist()
+
+    return chosen
diff --git a/uv.lock b/uv.lock
index 0552f312..e107857f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -141,6 +141,12 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 },
 ]
 
+[[package]]
+name = "asciitree"
+version = "0.3.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/2d/6a/885bc91484e1aa8f618f6f0228d76d0e67000b0fdd6090673b777e311913/asciitree-0.3.3.tar.gz", hash = "sha256:4aa4b9b649f85e3fcb343363d97564aa1fb62e249677f2e18a96765145cc0f6e", size = 3951 }
+
 [[package]]
 name = "asttokens"
 version = "3.0.0"
@@ -554,6 +560,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
 ]
 
+[[package]]
+name = "fasteners"
+version = "0.20"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/2d/18/7881a99ba5244bfc82f06017316ffe93217dbbbcfa52b887caa1d4f2a6d3/fasteners-0.20.tar.gz", hash = "sha256:55dce8792a41b56f727ba6e123fcaee77fd87e638a6863cec00007bfea84c8d8", size = 25087 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/51/ac/e5d886f892666d2d1e5cb8c1a41146e1d79ae8896477b1153a21711d3b44/fasteners-0.20-py3-none-any.whl", hash = "sha256:9422c40d1e350e4259f509fb2e608d6bc43c0136f79a00db1b49046029d0b3b7", size = 18702 },
+]
+
 [[package]]
 name = "filelock"
 version = "3.18.0"
@@ -1434,6 +1449,33 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 },
 ]
 
+[[package]]
+name = "numcodecs"
+version = "0.12.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b7/1b/1f1d880e29e719c7c6205065d1afbc91114c0d91935ac419faa43e5e08b0/numcodecs-0.12.1.tar.gz", hash = "sha256:05d91a433733e7eef268d7e80ec226a0232da244289614a8f3826901aec1098e", size = 4091415 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/9f/66/08744c9007f1d02476dd97f3c23032f3555dbb8e9a32b0f0ea4724e6b2a2/numcodecs-0.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d37f628fe92b3699e65831d5733feca74d2e33b50ef29118ffd41c13c677210e", size = 1696843 },
+    { url = "https://files.pythonhosted.org/packages/b8/6f/a04a33c5edb8fa9ba63783d34ff5768ba6b562ebe11078c07848e283f4ad/numcodecs-0.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:941b7446b68cf79f089bcfe92edaa3b154533dcbcd82474f994b28f2eedb1c60", size = 1422578 },
+    { url = "https://files.pythonhosted.org/packages/1e/b8/1040f299803eacc9c522fdc69a4dafc42ad0e8722bb48aa43d2310cf195b/numcodecs-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e79bf9d1d37199ac00a60ff3adb64757523291d19d03116832e600cac391c51", size = 7709402 },
+    { url = "https://files.pythonhosted.org/packages/8c/fa/da0637e1a6db74361a2875425021957859749166c0174ddedbb629518970/numcodecs-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:82d7107f80f9307235cb7e74719292d101c7ea1e393fe628817f0d635b7384f5", size = 790204 },
+    { url = "https://files.pythonhosted.org/packages/10/63/a50f4113a2bb1decfaedeffc448c5f8b26ded1c583247c893120fcd25e3e/numcodecs-0.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eeaf42768910f1c6eebf6c1bb00160728e62c9343df9e2e315dc9fe12e3f6071", size = 1696786 },
+    { url = "https://files.pythonhosted.org/packages/92/77/0fde34bf3a8402d696218a565230097d904c9eebb62cd952923b1155b7f7/numcodecs-0.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:135b2d47563f7b9dc5ee6ce3d1b81b0f1397f69309e909f1a35bb0f7c553d45e", size = 1422330 },
+    { url = "https://files.pythonhosted.org/packages/14/e6/8f9d4a498a06f11a06297f0b02af9968844d2e40ee79d372ccee33595285/numcodecs-0.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a191a8e347ecd016e5c357f2bf41fbcb026f6ffe78fff50c77ab12e96701d155", size = 7949787 },
+    { url = "https://files.pythonhosted.org/packages/08/f3/44597198c2cfb0d808d68583445b60b0d0ae057f20f0caf2a1200405655e/numcodecs-0.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:21d8267bd4313f4d16f5b6287731d4c8ebdab236038f29ad1b0e93c9b2ca64ee", size = 790313 },
+    { url = "https://files.pythonhosted.org/packages/d7/b2/7842675a798e79686d14a20baa554b165aab86feac28f32695266ab42b7e/numcodecs-0.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2f84df6b8693206365a5b37c005bfa9d1be486122bde683a7b6446af4b75d862", size = 1697725 },
+    { url = "https://files.pythonhosted.org/packages/fc/1f/e3b033181a28ce153fd0c9acd3ed978ee9c424de7cc3d8e97fc60647eddf/numcodecs-0.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:760627780a8b6afdb7f942f2a0ddaf4e31d3d7eea1d8498cf0fd3204a33c4618", size = 1423927 },
+    { url = "https://files.pythonhosted.org/packages/3b/88/fb3186f944b9586e9c4c54bd1d1899947b88465ad3ab1ff1111066871644/numcodecs-0.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c258bd1d3dfa75a9b708540d23b2da43d63607f9df76dfa0309a7597d1de3b73", size = 7944856 },
+    { url = "https://files.pythonhosted.org/packages/f4/03/54e22e273d584e83100ffa60c47c29cae905015ecb1f693918072c3595b9/numcodecs-0.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e04649ea504aff858dbe294631f098fbfd671baf58bfc04fc48d746554c05d67", size = 787000 },
+    { url = "https://files.pythonhosted.org/packages/dd/3c/950f816b837fc7714102b45491e2612b10757106f9a8e3785d7b3806acd4/numcodecs-0.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2fbb12a6a1abe95926f25c65e283762d63a9bf9e43c0de2c6a1a798347dfcb40", size = 1700073 },
+    { url = "https://files.pythonhosted.org/packages/76/2f/19f4f012f253ff33948a024e0a814c758ea137e3ba86118daac83a8d9123/numcodecs-0.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f2207871868b2464dc11c513965fd99b958a9d7cde2629be7b2dc84fdaab013b", size = 1425835 },
+    { url = "https://files.pythonhosted.org/packages/6d/0f/0442e80d707b5dd2e177a9490c25b89aa6a6c44579de8ec223e78a8884da/numcodecs-0.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abff3554a6892a89aacf7b642a044e4535499edf07aeae2f2e6e8fc08c9ba07f", size = 7722207 },
+    { url = "https://files.pythonhosted.org/packages/77/b6/345f8648874a81232bc1a87e55a771430488a832c68f873aa6ed23a1dedf/numcodecs-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef964d4860d3e6b38df0633caf3e51dc850a6293fd8e93240473642681d95136", size = 792870 },
+]
+
 [[package]]
 name = "numpy"
 version = "2.0.2"
@@ -2254,6 +2296,11 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b0/1e/c6a28a142f14e735088534cc92951c3f48cccd77cdd4f3b10d7996be420f/pyqt6_sip-13.10.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:3dde8024d055f496eba7d44061c5a1ba4eb72fc95e5a9d7a0dbc908317e0888b", size = 303833 },
     { url = "https://files.pythonhosted.org/packages/89/63/e5adf350c1c3123d4865c013f164c5265512fa79f09ad464fb2fdf9f9e61/pyqt6_sip-13.10.2-cp313-cp313-win_amd64.whl", hash = "sha256:0b097eb58b4df936c4a2a88a2f367c8bb5c20ff049a45a7917ad75d698e3b277", size = 53527 },
     { url = "https://files.pythonhosted.org/packages/58/74/2df4195306d050fbf4963fb5636108a66e5afa6dc05fd9e81e51ec96c384/pyqt6_sip-13.10.2-cp313-cp313-win_arm64.whl", hash = "sha256:cc6a1dfdf324efaac6e7b890a608385205e652845c62130de919fd73a6326244", size = 45373 },
+    { url = "https://files.pythonhosted.org/packages/23/57/74b4eb7a51b9133958daa8409b55de95e44feb694d4e2e3eba81a070ca20/pyqt6_sip-13.10.2-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8a76a06a8e5c5b1f17a3f6f3c834ca324877e07b960b18b8b9bbfd9c536ec658", size = 112354 },
+    { url = "https://files.pythonhosted.org/packages/f2/cb/fdef02e0d6ee8443a9683a43650d61c6474b634b6ae6e1c6f097da6310bf/pyqt6_sip-13.10.2-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9128d770a611200529468397d710bc972f1dcfe12bfcbb09a3ccddcd4d54fa5b", size = 323488 },
+    { url = "https://files.pythonhosted.org/packages/8c/5b/8ede8d6234c3ea884cbd097d7d47ff9910fb114efe041af62b4453acd23b/pyqt6_sip-13.10.2-cp314-cp314-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d820a0fae7315932c08f27dc0a7e33e0f50fe351001601a8eb9cf6f22b04562e", size = 303881 },
+    { url = "https://files.pythonhosted.org/packages/be/44/b5e78b072d1594643b0f1ff348f2bf54d4adb5a3f9b9f0989c54e33238d6/pyqt6_sip-13.10.2-cp314-cp314-win_amd64.whl", hash = "sha256:3213bb6e102d3842a3bb7e59d5f6e55f176c80880ff0b39d0dac0cfe58313fb3", size = 55098 },
+    { url = "https://files.pythonhosted.org/packages/e2/91/357e9fcef5d830c3d50503d35e0357818aca3540f78748cc214dfa015d00/pyqt6_sip-13.10.2-cp314-cp314-win_arm64.whl", hash = "sha256:ce33ff1f94960ad4b08035e39fa0c3c9a67070bec39ffe3e435c792721504726", size = 46088 },
     { url = "https://files.pythonhosted.org/packages/d1/39/4693dfad856ee9613fbf325916d980a76d5823f4da87fed76f00b48ee8ee/pyqt6_sip-13.10.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38b5823dca93377f8a4efac3cbfaa1d20229aa5b640c31cf6ebbe5c586333808", size = 110676 },
     { url = "https://files.pythonhosted.org/packages/f0/42/6f7c2006871b20cf3e5073e3ffaa0bede0f8e2f8ccc2105c02e8d523c7d7/pyqt6_sip-13.10.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5506b9a795098df3b023cc7d0a37f93d3224a9c040c43804d4bc06e0b2b742b0", size = 303064 },
     { url = "https://files.pythonhosted.org/packages/00/1c/38068f79d583fc9c2992553445634171e8b0bee6682be22cb8d4d18e7da6/pyqt6_sip-13.10.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e455a181d45a28ee8d18d42243d4f470d269e6ccdee60f2546e6e71218e05bb4", size = 281774 },
@@ -2855,6 +2902,7 @@ dependencies = [
     { name = "torcheval" },
     { name = "ujson" },
     { name = "warcio" },
+    { name = "zarr" },
 ]
 
 [package.dev-dependencies]
@@ -2895,6 +2943,7 @@ requires-dist = [
     { name = "torcheval", specifier = ">=0.0.7" },
     { name = "ujson", specifier = ">=5.10.0" },
     { name = "warcio", specifier = ">=1.7.5" },
+    { name = "zarr", specifier = ">=2.18.2" },
 ]
 
 [package.metadata.requires-dev]
@@ -3422,6 +3471,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542 },
 ]
 
+[[package]]
+name = "zarr"
+version = "2.18.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "asciitree" },
+    { name = "fasteners", marker = "sys_platform != 'emscripten'" },
+    { name = "numcodecs" },
+    { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b3/00/ac5c518ff1c1b1cc87a62f86ad9d19c647c19d969a91faa40d3b6342ccaa/zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47", size = 3603055 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/5d/bd/8d881d8ca6d80fcb8da2b2f94f8855384daf649499ddfba78ffd1ee2caa3/zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38", size = 210228 },
+]
+
 [[package]]
 name = "zipp"
 version = "3.23.0"
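Reviewer note: a quick end-to-end sanity check for the new Zarr backend, assuming a small CSV with one unique domain per row ('nodes.csv' and the tmp/ path are illustrative, not from this patch):

    import zarr
    from tgrag.utils.dataset_loading import construct_zarr_rni_backend

    store = zarr.DirectoryStore('tmp/embeddings.zarr')
    mapping, full_index = construct_zarr_rni_backend(
        path='nodes.csv', zarr_storage=store, index_col=0
    )
    z = zarr.open_array(store=store, mode='r')
    # With unique domains, every CSV row maps to exactly one embedding row.
    assert z.shape == (len(full_index), 64)
    assert max(mapping.values()) == z.shape[0] - 1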