1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
"cloudscraper>=1.2.71",
"BeautifulSoup4>=4.13.0",
"jsonpath-ng>=1.7.0",
"zarr>=2.18.2",
]

[tool.flit.module]
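For context on the new zarr pin: the dataset added below persists node embeddings through a zarr.DirectoryStore. A minimal sketch of that storage pattern under zarr 2.x follows; the shape, chunking, and sizes here are assumptions for illustration, since the real layout is owned by construct_zarr_rni_backend:

import numpy as np
import zarr

# Illustrative only: construct_zarr_rni_backend owns the real array layout.
store = zarr.DirectoryStore('embeddings.zarr')
emb = zarr.open(
    store,
    mode='w',
    shape=(10_000, 64),  # (num_nodes, EMB_DIM)
    chunks=(1024, 64),   # read/write 1024 node rows per chunk
    dtype='float32',
)
emb[:] = np.random.randn(10_000, 64).astype('float32')

# Slicing reads back only the chunks that cover the requested rows,
# so the full embedding matrix never has to fit in RAM.
first_batch = zarr.open(store, mode='r')[:256]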
200 changes: 200 additions & 0 deletions tgrag/dataset/zarr_rni_dataset.py
@@ -0,0 +1,200 @@
import logging
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional

import numpy as np
import pandas as pd
import torch
import zarr
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, InMemoryDataset

from tgrag.utils.dataset_loading import (
construct_zarr_rni_backend,
load_large_edge_csv,
)
from tgrag.utils.load_labels import get_full_dict
from tgrag.utils.target_generation import generate_exact_targets_csv

EMB_DIM = 64


class ZarrDataset(InMemoryDataset):
def __init__(
self,
root: str,
database_folder: str,
node_file: str = 'features.csv',
edge_file: str = 'edges.csv',
target_file: str = 'target.csv',
target_col: str = 'score',
target_index_name: str = 'nid',
target_index_col: int = 0,
edge_src_col: str = 'src',
edge_dst_col: str = 'dst',
index_col: int = 1,
index_name: str = 'node_id',
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
seed: int = 42,
processed_dir: Optional[str] = None,
):
self.node_file = node_file
self.edge_file = edge_file
self.target_file = target_file
self.target_col = target_col
self.edge_src_col = edge_src_col
self.edge_dst_col = edge_dst_col
self.index_col = index_col
self.index_name = index_name
self.target_index_name = target_index_name
self.target_index_col = target_index_col
self.seed = seed
        self._custom_processed_dir = processed_dir
self.database_folder = database_folder
super().__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

@property
def raw_dir(self) -> str:
        return self.root

@property
def raw_file_names(self) -> List[str]:
return [self.node_file, self.edge_file]

@property
def processed_dir(self) -> str:
        if self._custom_processed_dir is not None:
            return self._custom_processed_dir
return super().processed_dir

@property
def processed_file_names(self) -> List[str]:
return ['data.pt']

def download(self) -> None:
pass

def process(self) -> None:
node_path = os.path.join(self.raw_dir, self.node_file)
edge_path = os.path.join(self.raw_dir, self.edge_file)
target_path = os.path.join(self.raw_dir, self.target_file)
if os.path.exists(target_path):
logging.info('Target file already exists.')
else:
logging.info('Generating target file.')
dqr = get_full_dict()
generate_exact_targets_csv(node_path, target_path, dqr)

        # If we use domains as the unique ID, the zarr database can be constructed in a separate script.

zarr_path = Path(self.database_folder) / 'embeddings.zarr'
zarr_path.parent.mkdir(parents=True, exist_ok=True)
zarr_storage = zarr.DirectoryStore(str(zarr_path))

logging.info('***Constructing Mapping***')
mapping, full_index = construct_zarr_rni_backend(
path=node_path,
zarr_storage=zarr_storage,
index_col=0,
)

df_target = pd.read_csv(target_path)
logging.info(f'Size of target dataframe: {df_target.shape}')

mapping_index = [mapping[domain.strip()] for domain in df_target['domain']]
df_target.index = mapping_index
logging.info(f'Size of mapped target dataframe: {df_target.shape}')

missing_idx = full_index.difference(mapping_index)
filler = pd.DataFrame(
{col: np.nan for col in df_target.columns}, index=missing_idx
)
df_target = pd.concat([df_target, filler])
df_target.sort_index(inplace=True)
logging.info(f'Size of filled target dataframe: {df_target.shape}')
score = torch.tensor(
df_target[self.target_col].astype('float32').fillna(-1).values,
dtype=torch.float,
)
logging.info(f'Size of score vector: {score.size()}')

        # Unlabeled nodes were filled with -1 above; mask them out.
        labeled_mask = score != -1.0

        labeled_idx = torch.nonzero(labeled_mask, as_tuple=True)[0]
labeled_scores = score[labeled_idx].squeeze().numpy()

if labeled_scores.size == 0:
raise ValueError(
f"No labeled nodes found in target column '{self.target_col}'"
)

logging.info('***Constructing Edge Matrix***')
edge_index, edge_attr = load_large_edge_csv(
path=edge_path,
src_index_col=self.edge_src_col,
dst_index_col=self.edge_dst_col,
mapping=mapping,
encoders=None,
)
logging.info('***Edge Matrix Constructed***')

data = Data(y=score, edge_index=edge_index, edge_attr=edge_attr)
data.num_nodes = len(mapping)

data.labeled_mask = labeled_mask.detach().clone().bool()

        # Bin labeled scores into terciles so that the train/valid/test
        # splits are stratified across the score distribution.
        quantiles = np.quantile(labeled_scores, [1 / 3, 2 / 3])
        tercile_labels = np.digitize(labeled_scores, bins=quantiles)

        train_idx, temp_idx, _, tercile_labels_temp = train_test_split(
            labeled_idx,
            tercile_labels,
            train_size=0.6,
            stratify=tercile_labels,
            random_state=self.seed,
        )

        valid_idx, test_idx = train_test_split(
            temp_idx,
            train_size=0.5,
            stratify=tercile_labels_temp,
            random_state=self.seed,
        )

train_idx = torch.as_tensor(train_idx)
logging.info(f'Train size: {train_idx.size()}')
valid_idx = torch.as_tensor(valid_idx)
logging.info(f'Valid size: {valid_idx.size()}')
test_idx = torch.as_tensor(test_idx)
logging.info(f'Test size: {test_idx.size()}')

# Set global indices for our transductive nodes:
num_nodes = data.num_nodes
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[train_idx] = True
data.valid_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.valid_mask[valid_idx] = True
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask[test_idx] = True
data.idx_dict = {
'train': train_idx,
'valid': valid_idx,
'test': test_idx,
}

        torch.save(mapping, os.path.join(self.processed_dir, 'mapping.pt'))
torch.save(self.collate([data]), self.processed_paths[0])

def get_idx_split(self) -> Dict:
data = self[0]
if hasattr(data, 'idx_dict') and data.idx_dict is not None:
return data.idx_dict
raise TypeError('idx split is empty.')

    def get_mapping(self) -> Dict:
        if not hasattr(self, '_mapping'):
            self._mapping = torch.load(
                os.path.join(self.processed_dir, 'mapping.pt'), weights_only=False
            )
        return self._mapping
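A minimal usage sketch for the class above, assuming the raw CSVs already sit under root and that database_folder is writable; both paths are illustrative, not part of the PR:

# Hypothetical paths; adjust to the actual data layout.
dataset = ZarrDataset(
    root='data/crawl',
    database_folder='data/crawl/zarr_backend',
)
data = dataset[0]

split = dataset.get_idx_split()  # {'train': ..., 'valid': ..., 'test': ...}
mapping = dataset.get_mapping()  # domain -> contiguous node index

print(data.num_nodes, int(data.train_mask.sum()))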
8 changes: 4 additions & 4 deletions tgrag/encoders/text_encoder.py
@@ -12,18 +12,18 @@
class TextEncoder(Encoder):
def __init__(
self,
model_name: str = 'Qwen/Qwen3-Embedding-0.6B',
        model_name: str = 'Qwen/Qwen3-Embedding-0.6B',  # TODO: too small; consider a 4B/8B-parameter embedding model
device: str | None = None,
max_seq_length: int = 256,
max_seq_length: int = 1024,
batch_size: int = 128,
use_fp16: bool = True,
):
self.device = device
self.model = SentenceTransformer(model_name, device=device)
self.model.max_seq_length = max_seq_length

if use_fp16 and self.device is not None and 'cuda' in self.device:
self.model = self.model.half()
# if use_fp16 and self.device is not None and "cuda" in self.device:
# self.model = self.model.half()

self.batch_size = batch_size
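For reference, a minimal sketch of what the raised max_seq_length means at the sentence-transformers level; the input text and batch size are illustrative, and the half() cast mirrors the branch commented out above:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B', device='cuda')
model.max_seq_length = 1024  # the tokenizer truncates longer inputs

# Optional fp16 cast, as in the commented-out branch; CUDA only.
# model = model.half()

embeddings = model.encode(
    ['example document text'],
    batch_size=128,
    convert_to_numpy=True,
)
print(embeddings.shape)  # (1, embedding_dim)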
