diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md
index 08a2d5fc19..4f7a4fa948 100755
--- a/docs/guides/grpo.md
+++ b/docs/guides/grpo.md
@@ -38,18 +38,34 @@ To support this, we need to know:
 
 #### Dataset
 
-By default, NeMo RL has support for [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py) and [DeepScaler](../../nemo_rl/data/datasets/response_datasets/deepscaler.py) datasets. Both of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
+By default, NeMo RL ships with several built-in datasets (e.g., [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), and [Squad](../../nemo_rl/data/datasets/response_datasets/squad.py)); you can see the full list [here](../../nemo_rl/data/datasets/response_datasets/__init__.py).
+All of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
 
 We provide a [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py) class that is compatible with JSONL-formatted response datasets for loading datasets from local path or Hugging Face. You can use `input_key`, `output_key` to specify which fields in your data correspond to the question and answer respectively. Here's an example configuration:
 
 ```yaml
 data:
-  dataset_name: ResponseDataset
-  train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
-  val_data_path:
-  input_key: , default is "input"
-  output_key: , default is "output"
-  train_split: , default is None # used for HuggingFace datasets
-  val_split: , default is None # used for HuggingFace datasets
+  train:
+    dataset_name: ResponseDataset
+    data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
+    input_key: , default is "input"
+    output_key: , default is "output"
+    split: , default is None # used for HuggingFace datasets
+    split_validation_size: 0.05 # use 5% of the training data as validation data
+  validation:
+    dataset_name: ResponseDataset
+    data_path:
+    input_key: , default is "input"
+    output_key: , default is "output"
+    split: , default is None # used for HuggingFace datasets
+```
+
+We support using a single dataset for both training and validation: set `split_validation_size` to the fraction of the training data to hold out for validation.
+This feature is currently supported by [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py).
+To enable it for a custom dataset or another built-in dataset, add the following code to the dataset class, as done in [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py):
+```python +# `self.val_dataset` is used (not None) only when current dataset is used for both training and validation +self.val_dataset = None +self.split_train_validation(split_validation_size, seed) ``` #### Common Data Format @@ -99,21 +115,15 @@ We have an example of this as `math_data_processor` in [processors.py](../../nem Example (simplified): ```python +# task_spec default_task_spec = TaskDataSpec( task_name="math_default", prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], ) -task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = defaultdict( - lambda: (default_task_spec, math_hf_data_processor) -) - -# Resolve task_name from dataset or spec -task_spec = data.task_spec -task_name = data.task_name -assert hasattr(data, "processor"), "Dataset must have a processor attribute" -task_data_processors[task_name] = (task_spec, data.processor) +# task_data_processors +task_data_processors = {data.task_name: (data.task_spec, data.processor)} ``` #### Putting It All Together @@ -139,39 +149,34 @@ default_task_spec = TaskDataSpec( system_prompt_file=data_config["system_prompt_file"], ) -# 3) Define default processor mapping -task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = defaultdict( - lambda: (default_task_spec, math_hf_data_processor) -) +# 3) Load dataset using the helper (built-ins or local/HF datasets) +data = load_response_dataset(data_config["train"], seed) -# 4) Load dataset using the helper (built-ins or local/HF datasets) -data = load_response_dataset(data_config, seed) +# 4) Build task_data_processors mapping +task_data_processors = {data.task_name: (data.task_spec, data.processor)} -# 5) Resolve task spec/name and ensure dataset provides a processor -task_spec = data.task_spec -task_name = data.task_name -assert hasattr(data, "processor"), "Dataset must have a processor attribute" -task_data_processors[task_name] = (task_spec, data.processor) - -# 6) Construct processed datasets (train and optional validation) +# 5) Construct processed dataset dataset = AllTaskProcessedDataset( - data.formatted_ds["train"], + data.dataset, tokenizer, default_task_spec, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) -val_dataset = ( - AllTaskProcessedDataset( - data.formatted_ds["validation"], + +# 6) Do the same thing for validation dataset if it exists +if data_config["validation"] is not None: + val_data = load_response_dataset(data_config["validation"], seed) + + val_task_data_processors = {val_data.task_name: (val_data.task_spec, val_data.processor)} + + val_dataset = AllTaskProcessedDataset( + val_data.dataset, tokenizer, default_task_spec, - task_data_processors, + val_task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - if data.formatted_ds["validation"] - else None -) ``` Ensure you provide a mapping of tasks to their processors so the dataset knows which processor to use when handling samples. diff --git a/docs/guides/sft.md b/docs/guides/sft.md index 726ab45933..bd59657b39 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -37,7 +37,7 @@ SFT datasets in NeMo RL are encapsulated using classes. Each SFT data class is e SFT datasets are expected to follow the HuggingFace chat format. Refer to the [chat dataset document](../design-docs/chat-datasets.md) for details. If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. 
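+For instance, a minimal standalone conversion script might look like the following (a sketch only; the `raw.jsonl`/`chat.jsonl` paths and the `question`/`answer` field names are illustrative, not a NeMo RL API):
+
+```python
+import json
+
+# Convert a generic {"question": ..., "answer": ...} JSONL file into the
+# HuggingFace chat format ({"messages": [...]}) expected by NeMo RL SFT.
+with open("raw.jsonl") as fin, open("chat.jsonl", "w") as fout:
+    for line in fin:
+        example = json.loads(line)
+        chat_example = {
+            "messages": [
+                {"role": "user", "content": example["question"]},
+                {"role": "assistant", "content": example["answer"]},
+            ]
+        }
+        fout.write(json.dumps(chat_example) + "\n")
+```
+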
[response_datasets/squad.py](../../nemo_rl/data/datasets/response_datasets/squad.py) has an example:
 
 ```python
-def format_squad(data):
+def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
     return {
         "messages": [
             {
@@ -71,18 +71,34 @@ NeMo RL SFT uses HuggingFace chat templates to format the individual examples. T
   custom_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}"
 ```
 
-By default, NeMo RL has support for [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [Squad](../../nemo_rl/data/datasets/response_datasets/squad.py) and [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py) datasets. All of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
+By default, NeMo RL ships with several built-in datasets (e.g., [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), and [Squad](../../nemo_rl/data/datasets/response_datasets/squad.py)); you can see the full list [here](../../nemo_rl/data/datasets/response_datasets/__init__.py).
+All of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
 
 We provide a [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py) class that is compatible with jsonl-formatted response datasets for loading datasets from local path or HuggingFace. You can use `input_key`, `output_key` to specify which fields in your data correspond to the question and answer respectively. Here's an example configuration:
 
 ```yaml
 data:
-  dataset_name: ResponseDataset
-  train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
-  val_data_path:
-  input_key: , default is "input"
-  output_key: , default is "output"
-  train_split: , default is None # used for HuggingFace datasets
-  val_split: , default is None # used for HuggingFace datasets
+  train:
+    dataset_name: ResponseDataset
+    data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
+    input_key: , default is "input"
+    output_key: , default is "output"
+    split: , default is None # used for HuggingFace datasets
+    split_validation_size: 0.05 # use 5% of the training data as validation data
+  validation:
+    dataset_name: ResponseDataset
+    data_path:
+    input_key: , default is "input"
+    output_key: , default is "output"
+    split: , default is None # used for HuggingFace datasets
+```
+
+We support using a single dataset for both training and validation: set `split_validation_size` to the fraction of the training data to hold out for validation.
+This feature is currently supported by [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py).
+To enable it for a custom dataset or another built-in dataset, add the following code to the dataset class, as done in [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py):
+```python
+# `self.val_dataset` is used (not None) only when current dataset is used for both training and validation
+self.val_dataset = None
+self.split_train_validation(split_validation_size, seed)
 ```
 
 ### OpenAI Format Datasets (with Tool Calling Support)
@@ -95,14 +111,16 @@ To use an OpenAI format dataset, configure your YAML as follows:
 
 ```yaml
 data:
-  dataset_name: openai_format
-  train_data_path: "/path/to/train.jsonl" # Path to training data
-  val_data_path: "/path/to/val.jsonl" # Path to validation data
-  chat_key: "messages" # Key for messages in the data (default: "messages")
-  system_key: null # Key for system message in the data (optional)
-  system_prompt: null # Default system prompt if not in data (optional)
-  tool_key: "tools" # Key for tools in the data (default: "tools")
-  use_preserving_dataset: false # Set to true for heterogeneous tool schemas (see below)
+  train:
+    dataset_name: openai_format
+    data_path: # Path to training data
+    chat_key: "messages" # Key for messages in the data (default: "messages")
+    system_key: null # Key for system message in the data (optional)
+    system_prompt: null # Default system prompt if not in data (optional)
+    tool_key: "tools" # Key for tools in the data (default: "tools")
+    use_preserving_dataset: false # Set to true for heterogeneous tool schemas (see below)
+  validation:
+    ...
 ```
 
 #### Data Format
diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml
index 62937754f1..d8b0610b27 100644
--- a/examples/configs/distillation_math.yaml
+++ b/examples/configs/distillation_math.yaml
@@ -206,11 +206,20 @@ teacher:
 
 data:
   max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
-  prompt_file: "examples/prompts/cot.txt"
-  system_prompt_file: null
-  dataset_name: "DeepScaler"
   shuffle: true
+  # dataset
+  train:
+    dataset_name: DeepScaler
+  validation:
+    dataset_name: AIME2024
+    repeat: 16
+  # default settings for all datasets
+  default:
+    prompt_file: "examples/prompts/cot.txt"
+    system_prompt_file: null
+    env_name: "math"
+
 env:
   math:
     num_workers: 8
@@ -225,12 +234,12 @@ logger:
   monitor_gpus: true
   wandb:
     project: "nemo-distillation"
-    name: "distillation-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
+    name: "distillation-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
   swanlab:
     project: "nemo-distillation"
-    name: "distillation-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
+    name: "distillation-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
   tensorboard:
-    log_dir: "tb_logs-distillation-${data.dataset_name}"
+    log_dir: "tb_logs-distillation-${data.train.dataset_name}"
   mlflow:
     experiment_name: "distillation-dev"
     run_name: "distillation-math-cl-logger"
diff --git a/examples/configs/distillation_math_megatron.yaml b/examples/configs/distillation_math_megatron.yaml
index 644d240a7b..c720818a93 100644
--- a/examples/configs/distillation_math_megatron.yaml
+++ b/examples/configs/distillation_math_megatron.yaml
@@ -147,11 +147,11 @@ logger:
  wandb_enabled:
true wandb: project: "nemo-distillation" - name: "distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" + name: "distillation-megatron-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" tensorboard: - log_dir: "tb_logs-distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" + log_dir: "tb_logs-distillation-megatron-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" mlflow: - run_name: "distillation-math-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" + run_name: "distillation-math-megatron-${data.train.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" cluster: gpus_per_node: 8 diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 1dd9639472..970e56275d 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -246,22 +246,33 @@ policy: data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len - prompt_file: "examples/prompts/cot.txt" - system_prompt_file: null shuffle: true num_workers: 1 - processor: "math_hf_data_processor" - env_name: "math" - dataset_name: "OpenMathInstruct-2" + + # dataset + train: + dataset_name: OpenMathInstruct-2 + split_validation_size: 0.05 # use 5% of the training data as validation data + validation: null + # default settings for all datasets + default: + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + processor: "math_hf_data_processor" + env_name: "math" # You can use custom response datasets for training and validation. For example: - # data: + # train: + # dataset_name: ResponseDataset + # data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) + # input_key: , default is "input" + # output_key: , default is "output" + # split: , default is None # used for HuggingFace datasets + # validation: # dataset_name: ResponseDataset - # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) - # val_data_path: + # data_path: # input_key: , default is "input" # output_key: , default is "output" - # train_split: , default is None # used for HuggingFace datasets - # val_split: , default is None # used for HuggingFace datasets + # split: , default is None # used for HuggingFace datasets # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#datasets for more details. 
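+  # Tip: `train` and `validation` can also be lists of datasets to mix multiple
+  # sources; see examples/configs/grpo_multiple_datasets.yaml for an example.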
env: diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 95d85f74c7..fdeee8a4c6 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -157,13 +157,6 @@ policy: gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} -data: - max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len - prompt_file: "examples/prompts/cot.txt" - system_prompt_file: null - dataset_name: "OpenMathInstruct-2" - shuffle: true - env: math: num_workers: 8 diff --git a/examples/configs/grpo_multiple_datasets.yaml b/examples/configs/grpo_multiple_datasets.yaml new file mode 100644 index 0000000000..704cb1b18b --- /dev/null +++ b/examples/configs/grpo_multiple_datasets.yaml @@ -0,0 +1,26 @@ +# GRPO Algorithm Configuration +defaults: "grpo_math_1B.yaml" + +data: + _override_: true # override the data config instead of merging with it + + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + shuffle: true + num_workers: 1 + + # dataset + train: + - dataset_name: OpenMathInstruct-2 + split_validation_size: 0.05 + - dataset_name: DeepScaler + validation: + - dataset_name: AIME2024 + repeat: 16 + - dataset_name: DAPOMathAIME2024 + + # default settings for all datasets + default: + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + processor: "math_hf_data_processor" + env_name: "math" diff --git a/examples/configs/grpo_rm_1B.yaml b/examples/configs/grpo_rm_1B.yaml index b0a709b253..61e6204b9a 100644 --- a/examples/configs/grpo_rm_1B.yaml +++ b/examples/configs/grpo_rm_1B.yaml @@ -2,7 +2,8 @@ defaults: "grpo_math_1B.yaml" data: - env_name: "reward_model" + default: + env_name: "reward_model" env: reward_model: diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index 54e03ae524..edfc1096d1 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -77,4 +77,4 @@ logger: run_name: "grpo-dev-sliding_puzzle" gpu_monitoring: collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) \ No newline at end of file + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) diff --git a/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml b/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml index 29ee217517..ec2705be5e 100644 --- a/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml +++ b/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml @@ -82,8 +82,12 @@ policy: enforce_eager: true data: max_input_seq_length: 2048 - prompt_file: null - dataset_name: DAPOMath17K + train: + dataset_name: DAPOMath17K + validation: + dataset_name: DAPOMathAIME2024 + default: + prompt_file: null env: math: num_workers: 16 diff --git a/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml b/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml index 6e00ecd37c..f9c54d76f1 100644 --- a/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml @@ -39,8 +39,12 @@ policy: async_engine: true tensor_parallel_size: 32 data: - prompt_file: null - dataset_name: DAPOMath17K + train: + dataset_name: DAPOMath17K + validation: + dataset_name: DAPOMathAIME2024 + default: + 
prompt_file: null logger: monitor_gpus: false wandb: diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml index 584b807663..ca29b07aac 100644 --- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml @@ -28,7 +28,11 @@ policy: compilation_config: use_inductor: false data: - dataset_name: DeepScaler + train: + dataset_name: DeepScaler + validation: + dataset_name: AIME2024 + repeat: 16 env: math: num_workers: 16 diff --git a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml index d5525fc027..e98d7d4680 100644 --- a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml @@ -30,7 +30,11 @@ policy: vllm_cfg: enforce_eager: true data: - dataset_name: DeepScaler + train: + dataset_name: DeepScaler + validation: + dataset_name: AIME2024 + repeat: 16 env: math: num_workers: 16 diff --git a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-8n8g-fsdp2tp8cp4.yaml.disabled b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-8n8g-fsdp2tp8cp4.yaml.disabled index b1f65495fa..f442856807 100644 --- a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-8n8g-fsdp2tp8cp4.yaml.disabled +++ b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-8n8g-fsdp2tp8cp4.yaml.disabled @@ -44,11 +44,16 @@ policy: data: # Training with HelpSteer3 will lead to high logprob error. # ISSUE: https://github.com/NVIDIA-NeMo/RL/issues/1570 - prompt_file: null - dataset_name: HelpSteer3 - split: preference - env_name: "code_jaccard" - processor: helpsteer3_data_processor + train: + dataset_name: HelpSteer3 + split: train + validation: + dataset_name: HelpSteer3 + split: validation + default: + prompt_file: null + env_name: "code_jaccard" + processor: helpsteer3_data_processor env: code_jaccard: num_workers: 8 diff --git a/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml b/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml index 78b4597c2c..69ff4a4229 100644 --- a/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml @@ -37,8 +37,12 @@ policy: use_deep_gemm: true data: max_input_seq_length: 2048 - prompt_file: null - dataset_name: DAPOMath17K + train: + dataset_name: DAPOMath17K + validation: + dataset_name: DAPOMathAIME2024 + default: + prompt_file: null env: dapo: num_workers: 16 diff --git a/examples/configs/recipes/llm/sft-llama3.1-70b-8n8g-tp4pp2-long-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-70b-8n8g-tp4pp2-long-megatron.yaml index aa009da464..e46a920997 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-70b-8n8g-tp4pp2-long-megatron.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-70b-8n8g-tp4pp2-long-megatron.yaml @@ -43,12 +43,16 @@ policy: weight_decay: 0.01 eps: 1.0e-08 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: 
examples/prompts/math.txt logger: monitor_gpus: false wandb: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml index 88d446283d..90de698675 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml @@ -28,18 +28,22 @@ policy: weight_decay: 0.01 eps: 1.0e-08 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long wandb: project: nemo-rl name: sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long tensorboard: - log_dir: tb_logs-sft-dev-squad + log_dir: tb_logs-sft-dev-openmathinstruct2 cluster: gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml index 86db9da5e0..535f9c8bda 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml @@ -24,18 +24,22 @@ policy: weight_decay: 0.01 eps: 1.0e-08 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long wandb: project: nemo-rl name: sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long tensorboard: - log_dir: tb_logs-sft-dev-squad + log_dir: tb_logs-sft-dev-openmathinstruct2 cluster: gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora.yaml index 784e4a02d5..4a67b3581d 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora.yaml @@ -26,9 +26,12 @@ policy: weight_decay: 0.01 eps: 1.0e-08 data: - dataset_name: tulu3_sft_mixture add_generation_prompt: true - seed: 42 + train: + dataset_name: tulu3_sft_mixture + seed: 42 + split_validation_size: 0.05 + validation: null logger: log_dir: logs/sft-tmblog-llama3.1-8b tensorboard_enabled: false diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.yaml index 31b7538c1c..f2b9ca3ba3 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.yaml @@ -22,12 +22,16 @@ policy: weight_decay: 0.01 eps: 1.0e-08 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: 
examples/prompts/math.txt logger: log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2 wandb: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.yaml index 3afca7ba02..8af89e6c46 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.yaml @@ -30,12 +30,16 @@ policy: scheduler: lr_warmup_init: 1.9999e-65 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: log_dir: logs/sft-llama3.1-8b-1n8g-megatron wandb: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron.yaml index 2c08bef6f6..5b1d3166d9 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-megatron.yaml @@ -28,12 +28,16 @@ policy: scheduler: lr_warmup_init: 1.9999e-65 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: log_dir: logs/sft-llama3.1-8b-1n8g-megatron wandb: diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml index 77ff8aac89..69e3e6a7d5 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml @@ -9,12 +9,16 @@ policy: name: meta-llama/Llama-3.2-1B make_sequence_length_divisible_by: 1 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution - seed: 42 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + seed: 42 + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: log_dir: logs/sft-llama3.2-1b-1n8g-fsdp2tp1 wandb: diff --git a/examples/configs/recipes/llm/sft-nemotron-super-49b-8n8g-fsdp2tp4cp8-tulu-v3.yaml.disabled b/examples/configs/recipes/llm/sft-nemotron-super-49b-8n8g-fsdp2tp4cp8-tulu-v3.yaml.disabled index d224a6d51f..9cc94d8574 100644 --- a/examples/configs/recipes/llm/sft-nemotron-super-49b-8n8g-fsdp2tp4cp8-tulu-v3.yaml.disabled +++ b/examples/configs/recipes/llm/sft-nemotron-super-49b-8n8g-fsdp2tp4cp8-tulu-v3.yaml.disabled @@ -44,9 +44,11 @@ policy: - milestones: - 10 data: - dataset_name: tulu3_sft_mixture num_workers: 20 - test_size: 0.05 + train: + dataset_name: tulu3_sft_mixture + split_validation_size: 0.05 + validation: null logger: tensorboard_enabled: false monitor_gpus: false diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml index c94683c61f..e6d6d184fd 100644 --- 
a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -15,11 +15,15 @@ policy: tensor_parallel_size: 8 make_sequence_length_divisible_by: 8 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: log_dir: logs/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt wandb: diff --git a/examples/configs/recipes/llm/sft-qwen2.5-math7b-2n8g-megatron.yaml b/examples/configs/recipes/llm/sft-qwen2.5-math7b-2n8g-megatron.yaml index 299e426084..95ff4375d9 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-math7b-2n8g-megatron.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-math7b-2n8g-megatron.yaml @@ -33,12 +33,16 @@ policy: enabled: true make_sequence_length_divisible_by: 32 data: - dataset_name: openmathinstruct2 - prompt_file: examples/prompts/math.txt - split: train_1M add_generation_prompt: true - output_key: generated_solution num_workers: 8 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + split_validation_size: 0.05 + validation: null + default: + prompt_file: examples/prompts/math.txt logger: wandb: project: nemo-rl diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 4a0625895e..42b85bad80 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -165,24 +165,33 @@ data: shuffle: true num_workers: 1 - dataset_name: "squad" + # dataset + train: + dataset_name: "squad" + split: "train" + validation: + dataset_name: "squad" + split: "validation" + # default settings for all datasets + default: + prompt_file: null + system_prompt_file: null + processor: "sft_processor" # You can use custom response datasets for training and validation. For example: - # data: + # train: # dataset_name: ResponseDataset - # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) - # val_data_path: + # data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) # input_key: , default is "input" # output_key: , default is "output" - # train_split: , default is None # used for HuggingFace datasets - # val_split: , default is None # used for HuggingFace datasets + # split: , default is None # used for HuggingFace datasets + # validation: + # dataset_name: ResponseDataset + # data_path: + # input_key: , default is "input" + # output_key: , default is "output" + # split: , default is None # used for HuggingFace datasets # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. 
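+  # Settings under `default` above are applied to each dataset entry in `train`/`validation`.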
- ## unused with squad dataset - prompt_file: null - split: null - output_key: null - seed: null - ## OpenAI format specific configs # train_data_path: "/path/to/train.jsonl" # Path to training data @@ -202,15 +211,15 @@ logger: monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: project: "sft-dev" - name: "sft-dev-${data.dataset_name}" + name: "sft-dev-${data.train.dataset_name}" swanlab: project: "sft-dev" - name: "sft-dev-${data.dataset_name}" + name: "sft-dev-${data.train.dataset_name}" tensorboard: - log_dir: "tb_logs-sft-dev-${data.dataset_name}" + log_dir: "tb_logs-sft-dev-${data.train.dataset_name}" mlflow: experiment_name: "sft-dev" - run_name: "sft-dev-${data.dataset_name}" + run_name: "sft-dev-${data.train.dataset_name}" gpu_monitoring: collection_interval: 10 # How often to collect GPU usage metrics (in seconds) flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index 25368f7df5..1f35e62fdb 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -69,15 +69,22 @@ policy: data: max_input_seq_length: ${policy.max_total_sequence_length} - dataset_name: "openmathinstruct2" - prompt_file: examples/prompts/math.txt - split: "train_1M" add_bos: true add_eos: true add_generation_prompt: true - output_key: 'generated_solution' shuffle: true + # dataset + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + split_validation_size: 0.05 # use 5% of the training data as validation data + validation: null + # default settings for all datasets + default: + prompt_file: examples/prompts/math.txt + logger: log_dir: "logs" # Base directory for all logs wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running diff --git a/examples/configs/sft_openmathinstruct2_megatron.yaml b/examples/configs/sft_openmathinstruct2_megatron.yaml index b0f94fff6d..fc44396026 100644 --- a/examples/configs/sft_openmathinstruct2_megatron.yaml +++ b/examples/configs/sft_openmathinstruct2_megatron.yaml @@ -125,14 +125,6 @@ policy: optimizer: null data: - max_input_seq_length: ${policy.max_total_sequence_length} - dataset_name: "openmathinstruct2" - prompt_file: examples/prompts/math.txt - split: "train_1M" - add_bos: true - add_eos: true - add_generation_prompt: true - output_key: 'generated_solution' num_workers: 1 logger: diff --git a/examples/configs/sft_vlm_3B.yaml b/examples/configs/sft_vlm_3B.yaml index 5615e2f99d..b67a0d2087 100644 --- a/examples/configs/sft_vlm_3B.yaml +++ b/examples/configs/sft_vlm_3B.yaml @@ -23,12 +23,20 @@ checkpointing: data: max_input_seq_length: ${policy.max_total_sequence_length} - dataset_name: "clevr_cogent" add_bos: true add_eos: true add_generation_prompt: false - split: trainA - prompt_file: null + + # dataset + train: + dataset_name: clevr-cogent + split: train + validation: + dataset_name: clevr-cogent + split: valA + # default settings for all datasets + default: + prompt_file: null logger: log_dir: "logs" # Base directory for all logs @@ -37,9 +45,9 @@ logger: monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: project: "sft-dev" - name: "sft-dev-${data.dataset_name}" + name: "sft-dev-${data.train.dataset_name}" tensorboard: - log_dir: "tb_logs-sft-dev-${data.dataset_name}" + log_dir: "tb_logs-sft-dev-${data.train.dataset_name}" 
gpu_monitoring: collection_interval: 10 # How often to collect GPU usage metrics (in seconds) flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 47233d87db..774313e684 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -228,14 +228,23 @@ policy: data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len - prompt_file: "examples/prompts/clevr_cogent_cot.txt" - system_prompt_file: null - dataset_name: "clevr-cogent" - env_name: "clevr-cogent" - split: "trainA" shuffle: true num_workers: 1 + # dataset + train: + dataset_name: clevr-cogent + split: train + validation: + dataset_name: clevr-cogent + split: valA + # default settings for all datasets + default: + prompt_file: examples/prompts/clevr_cogent_cot.txt + system_prompt_file: null + processor: "vlm_hf_data_processor" + env_name: "clevr-cogent" + env: clevr-cogent: num_workers: 8 diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 64f8ea158d..54cabe8103 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -180,13 +180,21 @@ policy: data_parallel_sharding_strategy: optim_grads_params data: max_input_seq_length: ${policy.max_total_sequence_length} - prompt_file: examples/prompts/clevr_cogent_cot.txt - system_prompt_file: null - dataset_name: clevr-cogent - env_name: "clevr-cogent" - split: trainA shuffle: true num_workers: 1 + # dataset + train: + dataset_name: clevr-cogent + split: train + validation: + dataset_name: clevr-cogent + split: valA + # default settings for all datasets + default: + prompt_file: examples/prompts/clevr_cogent_cot.txt + system_prompt_file: null + processor: "vlm_hf_data_processor" + env_name: "clevr-cogent" env: clevr-cogent: num_workers: 8 diff --git a/examples/nemo_gym/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml b/examples/nemo_gym/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml index d6d550a12c..88c56e4b42 100644 --- a/examples/nemo_gym/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml +++ b/examples/nemo_gym/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml @@ -65,6 +65,7 @@ policy: max_total_sequence_length: 32768 precision: "bfloat16" logprob_chunk_size: 1024 + offload_optimizer_for_logprob: false dtensor_cfg: _v2: false @@ -210,6 +211,7 @@ policy: num_first_layers_in_bf16: 0 expose_http_server: true skip_tokenizer_init: false + kv_cache_dtype: ${policy.precision} http_server_serving_chat_kwargs: # This is the tool parser for Qwen 3 4B Instruct. This needs to be changed for other models. 
enable_auto_tools: true @@ -232,10 +234,21 @@ policy: num_nodes: null # Decides number of nodes to be dedicated to generation data: - train_jsonl_fpath: 3rdparty/Gym-workspace/Gym/data/bytedtsinghua_dapo17k/train.jsonl - validation_jsonl_fpath: 3rdparty/Gym-workspace/Gym/data/bytedtsinghua_dapo17k/validation.jsonl + max_input_seq_length: ${policy.max_total_sequence_length} shuffle: true num_workers: 0 + train: + dataset_name: NemoGymDataset + data_path: 3rdparty/Gym-workspace/Gym/data/train.jsonl + repeat: 1 + validation: + dataset_name: NemoGymDataset + data_path: 3rdparty/Gym-workspace/Gym/data/validation.jsonl + default: + env_name: "nemo_gym" + prompt_file: null + system_prompt_file: null + processor: "nemo_gym_data_processor" env: should_use_nemo_gym: true @@ -243,10 +256,10 @@ env: nemo_gym: # This is passed into NeMo-Gym as the initial_global_config_dict config_paths: - responses_api_models/vllm_model/configs/vllm_model_for_training.yaml # Required! And it must be *for_training - - resources_servers/library_judge_math/configs/library_judge_math.yaml - library_judge_math: + - resources_servers/math_with_judge/configs/math_with_judge.yaml + math_with_judge: resources_servers: - library_judge_math: + math_with_judge: judge_model_server: name: policy_model should_use_judge: false diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index c8d2c911e2..c2f47c13a8 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -13,11 +13,9 @@ # limitations under the License. import argparse -import json import os import pprint -from itertools import chain, repeat -from typing import Optional +from typing import Dict, Optional # Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB) import wandb.util @@ -25,6 +23,7 @@ wandb.util.VALUE_BYTES_LIMIT = 10_000_000 import ray +from datasets import concatenate_datasets from omegaconf import OmegaConf from wandb import Table @@ -42,18 +41,18 @@ setup, ) from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import DatumSpec -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + extract_necessary_env_names, + load_response_dataset, + update_single_dataset_config, ) from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.nemo_gym import ( - NemoGym, NemoGymConfig, - nemo_gym_example_to_nemo_rl_datum_spec, setup_nemo_gym_config, ) +from nemo_rl.environments.utils import create_env from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides @@ -75,38 +74,97 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]: return args, overrides -def setup_single_nemo_gym_dataset( - jsonl_fpath: str, tokenizer, num_repeats: Optional[int] = None -): - with open(jsonl_fpath) as f: - nemo_gym_examples = list(map(json.loads, f)) - - print(f"Loaded data at {jsonl_fpath}. 
Found {len(nemo_gym_examples)} examples") - - if num_repeats: - previous_length = len(nemo_gym_examples) - nemo_gym_examples = list( - chain.from_iterable( - repeat(nemo_gym_example, num_repeats) - for nemo_gym_example in nemo_gym_examples - ) - ) - print( - f"Repeating examples (in a pattern of abc to aabbcc) for {jsonl_fpath} from {previous_length} to {len(nemo_gym_examples)}!" - ) - - nemo_rl_compatible_examples: list[DatumSpec] = [ - nemo_gym_example_to_nemo_rl_datum_spec(nemo_gym_example, idx) - for idx, nemo_gym_example in enumerate(nemo_gym_examples) - ] - - passthrough_task_processor = lambda datum_dict, *args, **kwargs: datum_dict - return AllTaskProcessedDataset( - nemo_rl_compatible_examples, +def setup_data( + tokenizer: TokenizerType, + data_config: Dict, + env_configs: Dict, + seed: int, +) -> tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], +]: + print("\n▶ Setting up envs...") + env_name_list = extract_necessary_env_names(data_config) + envs = { + env_name: create_env(env_name=env_name, env_config=env_configs[env_name]) + for env_name in env_name_list + if env_name != "nemo_gym" + } + print("\n▶ Setting up data...") + # setup train dataset + task_data_processors = {} + task_to_env = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + for cfg in data_config["train"]: + update_single_dataset_config(cfg, data_config["default"]) + data = load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + task_data_processors[task_name] = (data.task_spec, data.processor) + # Skip binding nemo_gym env to task_to_env, nemo_gym env need to initialize policy first + if cfg["env_name"] != "nemo_gym": + task_to_env[task_name] = envs[cfg["env_name"]] + + merged_data = concatenate_datasets([data.dataset for data in data_list]) + dataset = AllTaskProcessedDataset( + merged_data, tokenizer, None, - passthrough_task_processor, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], ) + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + val_task_to_env = {} + val_data_list = [] + + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + if task_name in task_to_env: + val_task_to_env[task_name] = task_to_env[task_name] + + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors + task_name = val_data.task_name + val_task_data_processors[task_name] = ( + val_data.task_spec, + val_data.processor, + ) + if cfg["env_name"] != "nemo_gym": + val_task_to_env[task_name] = envs[cfg["env_name"]] + + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) + val_dataset = AllTaskProcessedDataset( + merged_val_data, + tokenizer, + None, + val_task_data_processors, + 
max_seq_length=data_config["max_input_seq_length"], + ) + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") + + return dataset, val_dataset, task_to_env, val_task_to_env # These types are directly imported from grpo_train since if something about the architecture changes we want to immediately fail. @@ -202,13 +260,11 @@ def main() -> None: assert _should_use_nemo_gym(config) print("\n▶ Setting up data...") - train_dataset = setup_single_nemo_gym_dataset( - jsonl_fpath=config["data"]["train_jsonl_fpath"], - tokenizer=tokenizer, - ) - val_dataset = setup_single_nemo_gym_dataset( - jsonl_fpath=config["data"]["validation_jsonl_fpath"], + train_dataset, val_dataset, task_to_env, val_task_to_env = setup_data( tokenizer=tokenizer, + data_config=config["data"], + env_configs=config["env"], + seed=config["grpo"]["seed"], ) # Validation dataset config setup. @@ -254,17 +310,12 @@ def main() -> None: base_urls=policy_generation.dp_openai_server_base_urls, initial_global_config_dict=config["env"]["nemo_gym"], ) - nemo_gym = NemoGym.options( - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.nemo_gym.NemoGym" - ), - } - ).remote(nemo_gym_config) + # Default nemo_gym env is used for trajectory collection + nemo_gym = create_env(env_name="nemo_gym", env_config=nemo_gym_config) # Blocking wait for NeMo-Gym to spin up ray.get(nemo_gym.health_check.remote()) - task_to_env = {"nemo_gym": nemo_gym} - val_task_to_env = task_to_env + task_to_env["nemo_gym"] = nemo_gym + val_task_to_env["nemo_gym"] = nemo_gym if is_trajectory_collection: collect_trajectories( diff --git a/examples/run_distillation_math.py b/examples/run_distillation_math.py index 51fc4b4283..237b5ccd3f 100644 --- a/examples/run_distillation_math.py +++ b/examples/run_distillation_math.py @@ -14,27 +14,24 @@ import argparse import os -from collections import defaultdict from typing import Any, Optional +from datasets import concatenate_datasets from omegaconf import OmegaConf from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.distillation import MasterConfig, distillation_train, setup from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.interfaces import ( - TaskDataProcessFnCallable, - TaskDataSpec, -) -from nemo_rl.data.processors import math_hf_data_processor -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + extract_necessary_env_names, + load_response_dataset, + update_single_dataset_config, ) from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.environments.math_environment import MathEnvironment +from nemo_rl.environments.utils import create_env from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir @@ -74,57 +71,87 @@ def setup_data( dict[str, EnvironmentInterface], dict[str, EnvironmentInterface], ]: - print("\n▶ Setting up data...") - math_task_spec = TaskDataSpec( - task_name="math", - prompt_file=data_config["prompt_file"], - system_prompt_file=data_config["system_prompt_file"], - ) - - # load dataset - data: Any = load_response_dataset(data_config, seed) - task_name = ( - data.task_name if hasattr(data, 
"task_name") else data.task_spec.task_name - ) - # data processor - task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( - defaultdict(lambda: (math_task_spec, math_hf_data_processor)) - ) - task_data_processors[task_name] = (math_task_spec, math_hf_data_processor) - - # setup math environment - math_env = MathEnvironment.options( # type: ignore # it's wrapped with ray.remote - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.math_environment.MathEnvironment" - ), - "env_vars": dict(os.environ), # Pass thru all user environment variables - } - ).remote(env_configs["math"]) + print("\n▶ Setting up envs...") + env_name_list = extract_necessary_env_names(data_config) + envs = { + env_name: create_env(env_name=env_name, env_config=env_configs[env_name]) + for env_name in env_name_list + } + print("\n▶ Setting up data...") + # setup train dataset + task_data_processors = {} + task_to_env = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + + for cfg in data_config["train"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + data = load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + task_data_processors[task_name] = (data.task_spec, data.processor) + task_to_env[task_name] = envs[cfg["env_name"]] + + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( - data.formatted_ds["train"], + merged_data, tokenizer, - math_task_spec, + None, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - - val_dataset: Optional[AllTaskProcessedDataset] = None - if data.formatted_ds["validation"]: + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + val_task_to_env = {} + val_data_list = [] + + # validation dataset from train dataset (when train dataset's split_validation_size > 0) + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + val_task_to_env[task_name] = task_to_env[task_name] + + # validation dataset from config + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors and task_to_env + task_name = val_data.task_name + val_task_data_processors[task_name] = ( + val_data.task_spec, + val_data.processor, + ) + val_task_to_env[task_name] = envs[cfg["env_name"]] + + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( - data.formatted_ds["validation"], + merged_val_data, tokenizer, - math_task_spec, - task_data_processors, + None, + val_task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - else: - val_dataset = None + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - task_to_env: dict[str, 
EnvironmentInterface] = defaultdict(lambda: math_env) - task_to_env[task_name] = math_env - return dataset, val_dataset, task_to_env, task_to_env + return dataset, val_dataset, task_to_env, val_task_to_env def main() -> None: diff --git a/examples/run_grpo.py b/examples/run_grpo.py index cd9d47f628..40a32fa484 100644 --- a/examples/run_grpo.py +++ b/examples/run_grpo.py @@ -15,21 +15,21 @@ import argparse import os import pprint -from collections import defaultdict from typing import Any, Optional +from datasets import concatenate_datasets from omegaconf import OmegaConf from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.interfaces import ( - TaskDataProcessFnCallable, - TaskDataSpec, +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + extract_necessary_env_names, + load_response_dataset, + update_single_dataset_config, ) -from nemo_rl.data.processors import math_hf_data_processor from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.environments.utils import create_env @@ -71,50 +71,86 @@ def setup_data( dict[str, EnvironmentInterface], ]: print("\n▶ Setting up envs...") - env_name = data_config["env_name"] - env = create_env(env_name=env_name, env_configs=env_configs) + env_name_list = extract_necessary_env_names(data_config) + envs = { + env_name: create_env(env_name=env_name, env_config=env_configs[env_name]) + for env_name in env_name_list + } print("\n▶ Setting up data...") - default_task_spec = TaskDataSpec( - task_name="math_default", - prompt_file=data_config["prompt_file"], - system_prompt_file=data_config["system_prompt_file"], - ) - # define default task data processor - task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( - defaultdict(lambda: (default_task_spec, math_hf_data_processor)) - ) - - # load dataset - data: Any = load_response_dataset(data_config, seed) - task_spec = data.task_spec - task_name = data.task_name - assert hasattr(data, "processor"), "Dataset must have a processor attribute" - task_data_processors[task_name] = (task_spec, data.processor) - + # setup train dataset + task_data_processors = {} + task_to_env = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + + for cfg in data_config["train"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + data = load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + task_data_processors[task_name] = (data.task_spec, data.processor) + task_to_env[task_name] = envs[cfg["env_name"]] + + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( - data.formatted_ds["train"], + merged_data, tokenizer, - default_task_spec, # default task data spec to process any values not specified in the task-specific specs + None, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - - val_dataset: Optional[AllTaskProcessedDataset] = None - if data.formatted_ds["validation"]: + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + 
val_task_to_env = {} + val_data_list = [] + + # validation dataset from train dataset (when train dataset's split_validation_size > 0) + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + val_task_to_env[task_name] = task_to_env[task_name] + + # validation dataset from config + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors and task_to_env + task_name = val_data.task_name + val_task_data_processors[task_name] = ( + val_data.task_spec, + val_data.processor, + ) + val_task_to_env[task_name] = envs[cfg["env_name"]] + + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( - data.formatted_ds["validation"], + merged_val_data, tokenizer, - default_task_spec, - task_data_processors, + None, + val_task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - else: - val_dataset = None + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: env) - task_to_env[task_name] = env - return dataset, val_dataset, task_to_env, task_to_env + return dataset, val_dataset, task_to_env, val_task_to_env def main() -> None: diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index bf790080d9..aee33aee48 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -15,27 +15,24 @@ import argparse import os import pprint -from collections import defaultdict from typing import Any, Optional +from datasets import concatenate_datasets from omegaconf import OmegaConf from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.interfaces import ( - TaskDataProcessFnCallable, - TaskDataSpec, -) -from nemo_rl.data.processors import math_hf_data_processor -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + extract_necessary_env_names, + load_response_dataset, + update_single_dataset_config, ) from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.environments.math_environment import MathEnvironment +from nemo_rl.environments.utils import create_env from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir @@ -73,58 +70,87 @@ def setup_data( dict[str, EnvironmentInterface], dict[str, EnvironmentInterface], ]: - print("\n▶ Setting up data...") - math_task_spec = TaskDataSpec( - task_name="math", - prompt_file=data_config["prompt_file"], - 
system_prompt_file=data_config["system_prompt_file"], - ) - - # load dataset - data: Any = load_response_dataset(data_config, seed) - task_name = ( - data.task_name if hasattr(data, "task_name") else data.task_spec.task_name - ) - - # data processor - task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( - defaultdict(lambda: (math_task_spec, math_hf_data_processor)) - ) - task_data_processors[task_name] = (math_task_spec, math_hf_data_processor) - - # setup math environment - math_env = MathEnvironment.options( # type: ignore # it's wrapped with ray.remote - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.math_environment.MathEnvironment" - ), - "env_vars": dict(os.environ), # Pass thru all user environment variables - } - ).remote(env_configs["math"]) + print("\n▶ Setting up envs...") + env_name_list = extract_necessary_env_names(data_config) + envs = { + env_name: create_env(env_name=env_name, env_config=env_configs[env_name]) + for env_name in env_name_list + } + print("\n▶ Setting up data...") + # setup train dataset + task_data_processors = {} + task_to_env = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + + for cfg in data_config["train"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + data = load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + task_data_processors[task_name] = (data.task_spec, data.processor) + task_to_env[task_name] = envs[cfg["env_name"]] + + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( - data.formatted_ds["train"], + merged_data, tokenizer, - math_task_spec, + None, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - - val_dataset: Optional[AllTaskProcessedDataset] = None - if data.formatted_ds["validation"]: + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + val_task_to_env = {} + val_data_list = [] + + # validation dataset from train dataset (when train dataset's split_validation_size > 0) + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + val_task_to_env[task_name] = task_to_env[task_name] + + # validation dataset from config + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors and task_to_env + task_name = val_data.task_name + val_task_data_processors[task_name] = ( + val_data.task_spec, + val_data.processor, + ) + val_task_to_env[task_name] = envs[cfg["env_name"]] + + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( - data.formatted_ds["validation"], + merged_val_data, tokenizer, - math_task_spec, - task_data_processors, + None, + val_task_data_processors, 
max_seq_length=data_config["max_input_seq_length"], ) - else: - val_dataset = None + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env) - task_to_env[task_name] = math_env - return dataset, val_dataset, task_to_env, task_to_env + return dataset, val_dataset, task_to_env, val_task_to_env def main() -> None: diff --git a/examples/run_grpo_rm.py b/examples/run_grpo_rm.py index b36e34bf7e..21baf9252e 100644 --- a/examples/run_grpo_rm.py +++ b/examples/run_grpo_rm.py @@ -15,25 +15,24 @@ import argparse import os import pprint -from collections import defaultdict from typing import Any, Optional +from datasets import concatenate_datasets from omegaconf import OmegaConf from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.interfaces import ( - TaskDataProcessFnCallable, - TaskDataSpec, +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + extract_necessary_env_names, + load_response_dataset, + update_single_dataset_config, ) -from nemo_rl.data.processors import math_hf_data_processor -from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.environments.reward_model_environment import RewardModelEnvironment +from nemo_rl.environments.utils import create_env from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir @@ -77,56 +76,87 @@ def setup_data( dict[str, EnvironmentInterface], dict[str, EnvironmentInterface], ]: - print("\n▶ Setting up data...") - # load dataset - data: Any = load_response_dataset(data_config, seed) - task_name = ( - data.task_name if hasattr(data, "task_name") else data.task_spec.task_name - ) - - reward_model_task_spec = TaskDataSpec( - task_name=task_name, - prompt_file=data_config["prompt_file"], - system_prompt_file=data_config["system_prompt_file"], - ) - # data processor - task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( - defaultdict(lambda: (reward_model_task_spec, math_hf_data_processor)) - ) - task_data_processors[task_name] = (reward_model_task_spec, math_hf_data_processor) - - reward_model_env = RewardModelEnvironment.options( # type: ignore # it's wrapped with ray.remote - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.reward_model_environment.RewardModelEnvironment" - ), - "env_vars": dict(os.environ), # Pass thru all user environment variables - } - ).remote(env_configs["reward_model"]) + print("\n▶ Setting up envs...") + env_name_list = extract_necessary_env_names(data_config) + envs = { + env_name: create_env(env_name=env_name, env_config=env_configs[env_name]) + for env_name in env_name_list + } + print("\n▶ Setting up data...") + # setup train dataset + task_data_processors = {} + task_to_env = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + + for cfg in data_config["train"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + data = 
load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + task_data_processors[task_name] = (data.task_spec, data.processor) + task_to_env[task_name] = envs[cfg["env_name"]] + + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( - data.formatted_ds["train"], + merged_data, tokenizer, - reward_model_task_spec, + None, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - - val_dataset: Optional[AllTaskProcessedDataset] = None - if data.formatted_ds["validation"]: + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + val_task_to_env = {} + val_data_list = [] + + # validation dataset from train dataset (when train dataset's split_validation_size > 0) + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + val_task_to_env[task_name] = task_to_env[task_name] + + # validation dataset from config + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors and task_to_env + task_name = val_data.task_name + val_task_data_processors[task_name] = ( + val_data.task_spec, + val_data.processor, + ) + val_task_to_env[task_name] = envs[cfg["env_name"]] + + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( - data.formatted_ds["validation"], + merged_val_data, tokenizer, - reward_model_task_spec, - task_data_processors, + None, + val_task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - else: - val_dataset = None + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: reward_model_env) - task_to_env[task_name] = reward_model_env - return dataset, val_dataset, task_to_env, task_to_env + return dataset, val_dataset, task_to_env, val_task_to_env def main() -> None: diff --git a/examples/run_sft.py b/examples/run_sft.py index 8f65262c73..cdd7ec50a9 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -16,17 +16,19 @@ import os import pprint from functools import partial -from typing import Any, Callable, Optional +from datasets import concatenate_datasets from omegaconf import OmegaConf from transformers import AutoTokenizer from nemo_rl.algorithms.sft import MasterConfig, setup, sft_train from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec -from nemo_rl.data.llm_message_utils import get_formatted_message_log +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + load_response_dataset, + update_single_dataset_config, +) from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.utils.config 
import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir @@ -51,104 +53,88 @@ def parse_args(): # ======================================================= # Data Processing # ======================================================= -def sft_preprocessor( - datum_dict: dict[str, Any], - task_data_spec: TaskDataSpec, - tokenizer, - max_seq_length: int, - idx: int, - add_bos: bool = True, - add_eos: bool = True, - add_generation_prompt: bool = False, - datum_preprocessor: Optional[Callable] = None, -) -> DatumSpec: - """Process a datum dictionary for SFT training.""" - # optional preprocessor - if datum_preprocessor is not None: - datum_dict = datum_preprocessor(datum_dict) - - message_log = get_formatted_message_log( - datum_dict["messages"], - tokenizer, - task_data_spec, - add_bos_token=add_bos, - add_eos_token=add_eos, - add_generation_prompt=add_generation_prompt, - tools=datum_dict.get("tools", None), # Pass tools from data if present - ) - - length = sum(len(m["token_ids"]) for m in message_log) - - loss_multiplier = 1.0 - if length > max_seq_length: - # make smaller and mask out - for message in message_log: - message["token_ids"] = message["token_ids"][ - : min(4, max_seq_length // len(message_log)) - ] - loss_multiplier = 0.0 - - output = { - "message_log": message_log, - "length": length, - "extra_env_info": None, - "loss_multiplier": loss_multiplier, - "idx": idx, - } - return output def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int): print("\n▶ Setting up data...") - - # load dataset - data = load_response_dataset(data_config, seed) - train_dataset = data.formatted_ds["train"] - val_dataset = data.formatted_ds["validation"] - sft_task_spec = data.task_spec - print( - f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset) if val_dataset else 0} samples, respectively." 
- ) - - # add preprocessor if needed - datum_preprocessor = None - if "dataset_name" in data_config and data_config["dataset_name"] == "clevr_cogent": - from nemo_rl.data.datasets.response_datasets.clevr import ( - format_clevr_cogent_dataset, - ) - - datum_preprocessor = partial(format_clevr_cogent_dataset, return_pil=True) - - train_dataset = AllTaskProcessedDataset( - train_dataset, - tokenizer, - sft_task_spec, - partial( - sft_preprocessor, + # setup train dataset + task_data_processors = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + + for cfg in data_config["train"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + data = load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors + data_processor = partial( + data.processor, add_bos=data_config["add_bos"], add_eos=data_config["add_eos"], add_generation_prompt=data_config["add_generation_prompt"], - datum_preprocessor=datum_preprocessor, - ), + ) + task_data_processors[data.task_name] = (data.task_spec, data_processor) + + merged_data = concatenate_datasets([data.dataset for data in data_list]) + dataset = AllTaskProcessedDataset( + merged_data, + tokenizer, + None, + task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - - if val_dataset is not None: + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + val_data_list = [] + + # validation dataset from train dataset (when train dataset's split_validation_size > 0) + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + + # validation dataset from config + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors + val_data_processor = partial( + val_data.processor, + add_bos=data_config["add_bos"], + add_eos=data_config["add_eos"], + add_generation_prompt=data_config["add_generation_prompt"], + ) + val_task_data_processors[val_data.task_name] = ( + val_data.task_spec, + val_data_processor, + ) + + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( - val_dataset, + merged_val_data, tokenizer, - sft_task_spec, - partial( - sft_preprocessor, - add_bos=data_config.get("add_bos", True), - add_eos=data_config.get("add_eos", True), - add_generation_prompt=data_config["add_generation_prompt"], - datum_preprocessor=datum_preprocessor, - ), + None, + val_task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - return train_dataset, val_dataset, sft_task_spec + return dataset, val_dataset def main(is_vlm: bool = False): @@ -186,11 +172,7 @@ def main(is_vlm: bool = False): tokenizer = get_tokenizer(config["policy"]["tokenizer"], get_processor=is_vlm) # setup data - ( - dataset, - val_dataset, - sft_task_spec, - ) = 
setup_data(tokenizer, config["data"], config["sft"]["seed"]) + dataset, val_dataset = setup_data(tokenizer, config["data"], config["sft"]["seed"]) ( policy, @@ -212,7 +194,6 @@ def main(is_vlm: bool = False): loss_fn, master_config, logger, - sft_task_spec, checkpointer, sft_save_state, ) diff --git a/examples/run_vlm_grpo.py b/examples/run_vlm_grpo.py index 5e8cb1ef0c..29dcfdd627 100644 --- a/examples/run_vlm_grpo.py +++ b/examples/run_vlm_grpo.py @@ -13,42 +13,26 @@ # limitations under the License. import argparse -import base64 import os import pprint -from collections import defaultdict -from io import BytesIO from typing import Any, Optional -import requests +from datasets import concatenate_datasets from omegaconf import OmegaConf -from PIL import Image from transformers import AutoProcessor from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.datasets.response_datasets.clevr import format_clevr_cogent_dataset -from nemo_rl.data.datasets.response_datasets.geometry3k import format_geometry3k_dataset -from nemo_rl.data.datasets.response_datasets.refcoco import format_refcoco_dataset -from nemo_rl.data.interfaces import ( - DatumSpec, - LLMMessageLogType, - TaskDataProcessFnCallable, - TaskDataSpec, -) -from nemo_rl.data.multimodal_utils import ( - PackedTensor, - get_dim_to_pack_along, - get_multimodal_keys_from_processor, -) -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, +from nemo_rl.data.datasets import ( + AllTaskProcessedDataset, + extract_necessary_env_names, + load_response_dataset, + update_single_dataset_config, ) from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.environments.vlm_environment import VLMEnvironment +from nemo_rl.environments.utils import create_env from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir @@ -68,168 +52,8 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]: # =============================================================================== -# VLM Data Processor +# Data Processor # =============================================================================== - - -def resolve_to_image(image_path_or_image: str | Image.Image) -> Image.Image: - """Resolve the image path to a PIL.Image object. - - image_path can be either: - - path to local file - - url to image - - base64 encoded image - """ - if isinstance(image_path_or_image, Image.Image): - return image_path_or_image - - if image_path_or_image.startswith(("http://", "https://")): - # Handle URL - response = requests.get(image_path_or_image) - response.raise_for_status() - return Image.open(BytesIO(response.content)).convert("RGB") - elif image_path_or_image.startswith("data:"): - # Handle base64 encoded image - # Format: data:image/jpeg;base64,/9j/4AAQSkZJRg... 
- header, encoded = image_path_or_image.split(",", 1) - image_data = base64.b64decode(encoded) - return Image.open(BytesIO(image_data)).convert("RGB") - else: - # Handle local file path - return Image.open(image_path_or_image).convert("RGB") - - -def hf_data_processor( - datum_dict: dict[str, Any], - task_data_spec: TaskDataSpec, - processor: AutoProcessor, - max_seq_length: int, - idx: int, -) -> DatumSpec: - """Process a datum dictionary (directly loaded from response_datasets/.py) into a DatumSpec for the VLM Environment.""" - # depending on the task, format the data differently - if task_data_spec.task_name == "clevr-cogent": - datum_dict = format_clevr_cogent_dataset(datum_dict) - elif task_data_spec.task_name == "refcoco": - datum_dict = format_refcoco_dataset(datum_dict) - elif task_data_spec.task_name == "geometry3k": - datum_dict = format_geometry3k_dataset(datum_dict) - else: - raise ValueError(f"No data processor for task {task_data_spec.task_name}") - - user_message = datum_dict["messages"] - problem = user_message[0]["content"] - extra_env_info = {"ground_truth": user_message[1]["content"]} - - message_log: LLMMessageLogType = [] - ### only one round of interaction is assumed, this can easily be extended to a conversational setting - user_message = {"role": "user", "content": []} - # - images = [] - if isinstance(problem, list): - for content in problem: - # for image, video, just append it - # for text, format the prompt to the problem - if content["type"] != "text": - user_message["content"].append(content) - if content["type"] == "image": - images.append(content["image"]) - else: - raise ValueError(f"Unsupported content type: {content['type']}") - elif content["type"] == "text": - user_message["content"].append( - { - "type": "text", - "text": task_data_spec.prompt.format(content["text"]) - if task_data_spec.prompt - else content["text"], - } - ) - else: - # conversation consists of a text-only message - user_message["content"] = task_data_spec.prompt.format(problem) - - images = [resolve_to_image(image) for image in images] - - # get formatted user message - if hasattr(processor, "conversation_preprocessor"): - user_message_for_chat_template = processor.conversation_preprocessor( - user_message - ) - else: - user_message_for_chat_template = user_message - - # this is the string-tokenized conversation template for the generation policy (for vllm) - string_formatted_dialog = processor.apply_chat_template( - [user_message_for_chat_template], - tokenize=False, - add_generation_prompt=True, - ) - - # this is the id-tokenized and image processed conversation template for the policy - message: dict = processor.apply_chat_template( - [user_message], - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - ) - - # add this for backward compatibility - user_message["token_ids"] = message["input_ids"][0] - # add all keys and values to the user message, and the list of keys - multimodal_keys = get_multimodal_keys_from_processor(processor) - for key in multimodal_keys: - if key in message: - user_message[key] = PackedTensor( - message[key], dim_to_pack=get_dim_to_pack_along(processor, key) - ) - - # specifically for gemma, we need to add token_type_ids to the user message as a sequence-type value - if "token_type_ids" in message: - user_message["token_type_ids"] = message["token_type_ids"][0] - - ### append to user message - message_log.append(user_message) - - length = sum(len(m["token_ids"]) for m in message_log) - loss_multiplier = 1.0 - if length 
>= max_seq_length: - # Treat truncated messages as text only - vllm_kwargs = { - "vllm_content": None, - "vllm_images": [], - } - - # make smaller and mask out - for chat_message in message_log: - chat_message["token_ids"] = chat_message["token_ids"][ - : min(4, max_seq_length // len(message_log)) - ] - for key, value in chat_message.items(): - if isinstance(value, PackedTensor): - chat_message[key] = PackedTensor.empty_like(value) - loss_multiplier = 0.0 - else: - # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation - # add images for vllm serving - vllm_kwargs = { - "vllm_content": string_formatted_dialog, - "vllm_images": images, - } - - output: DatumSpec = { - "message_log": message_log, - "length": length, - "extra_env_info": extra_env_info, - "loss_multiplier": loss_multiplier, - "idx": idx, - "task_name": task_data_spec.task_name, - **vllm_kwargs, - } - return output - - def setup_data( processor: AutoProcessor, data_config: DataConfig, @@ -241,62 +65,87 @@ def setup_data( dict[str, EnvironmentInterface], dict[str, EnvironmentInterface], ]: - """This function will create a TaskSpec, DatumSpec, and connect the two. + print("\n▶ Setting up envs...") + env_name_list = extract_necessary_env_names(data_config) + envs = { + env_name: create_env(env_name="vlm", env_config=env_configs[env_name]) + for env_name in env_name_list + } - task_spec contains the task name as well as prompt and system prompt modifiers that can be used by data processor - """ print("\n▶ Setting up data...") - - # load dataset - # TODO @yukih: currently seed is not used for vlm datasets - data: Any = load_response_dataset(data_config, seed) - - task_name = data.task_name - vlm_task_spec = TaskDataSpec( - task_name=task_name, - prompt_file=data_config["prompt_file"], - system_prompt_file=data_config["system_prompt_file"], - ) - - # add data processor for different tasks - task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( - defaultdict(lambda: (vlm_task_spec, hf_data_processor)) - ) - task_data_processors[task_name] = (vlm_task_spec, hf_data_processor) - - env_name = data_config["env_name"] - vlm_env = VLMEnvironment.options( # type: ignore # it's wrapped with ray.remote - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.vlm_environment.VLMEnvironment" - ), - "env_vars": dict(os.environ), # Pass thru all user environment variables - } - ).remote(env_configs[env_name]) - + # setup train dataset + task_data_processors = {} + task_to_env = {} + data_list = [] + + if isinstance(data_config["train"], dict): + data_config["train"] = [data_config["train"]] + + for cfg in data_config["train"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + data = load_response_dataset(cfg, seed) + data_list.append(data) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + task_data_processors[task_name] = (data.task_spec, data.processor) + task_to_env[task_name] = envs[cfg["env_name"]] + + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( - data.formatted_ds["train"], + merged_data, processor, - vlm_task_spec, + None, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) + print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + + # setup validation dataset + val_task_data_processors = {} + val_task_to_env = {} + val_data_list = [] + + # validation 
dataset from train dataset (when train dataset's split_validation_size > 0) + for data in data_list: + if hasattr(data, "val_dataset") and data.val_dataset is not None: + val_data_list.append(data.val_dataset) + # bind task_name to task_data_processors and task_to_env + task_name = data.task_name + val_task_data_processors[task_name] = task_data_processors[task_name] + val_task_to_env[task_name] = task_to_env[task_name] + + # validation dataset from config + if data_config["validation"] is not None: + if isinstance(data_config["validation"], dict): + data_config["validation"] = [data_config["validation"]] + + for cfg in data_config["validation"]: + # load dataset + update_single_dataset_config(cfg, data_config["default"]) + val_data = load_response_dataset(cfg, seed) + val_data_list.append(val_data.dataset) + # bind task_name to task_data_processors and task_to_env + task_name = val_data.task_name + val_task_data_processors[task_name] = ( + val_data.task_spec, + val_data.processor, + ) + val_task_to_env[task_name] = envs[cfg["env_name"]] - val_dataset: Optional[AllTaskProcessedDataset] = None - if data.formatted_ds["validation"]: + val_dataset = None + if len(val_data_list) > 0: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( - data.formatted_ds["validation"], + merged_val_data, processor, - vlm_task_spec, - task_data_processors, + None, + val_task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - else: - val_dataset = None + print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: vlm_env) - task_to_env[task_name] = vlm_env - return dataset, val_dataset, task_to_env, task_to_env + return dataset, val_dataset, task_to_env, val_task_to_env def main() -> None: diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 09cbdf93c2..b5787fdb28 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -28,7 +28,6 @@ from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import TaskDataSpec from nemo_rl.data.llm_message_utils import ( add_loss_mask_to_message_log, batched_message_log_to_flat_message, @@ -238,7 +237,6 @@ def validate( loss_fn, step: int, master_config: MasterConfig, - sft_task_spec: TaskDataSpec, val_batches: int, val_batch_size: int, val_mbs: int, @@ -358,7 +356,6 @@ def sft_train( loss_fn, master_config, logger, - sft_task_spec, checkpointer, sft_save_state: SFTSaveState, ) -> None: @@ -400,7 +397,6 @@ def sft_train( loss_fn, step=0, master_config=master_config, - sft_task_spec=sft_task_spec, val_batches=sft_config["val_batches"], val_batch_size=sft_config["val_global_batch_size"], val_mbs=sft_config["val_micro_batch_size"], @@ -474,7 +470,6 @@ def sft_train( loss_fn, step=total_steps + 1, master_config=master_config, - sft_task_spec=sft_task_spec, val_batches=sft_config["val_batches"], val_batch_size=sft_config["val_global_batch_size"], val_mbs=sft_config["val_micro_batch_size"], diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 3e40c9d78c..ad7d10a99e 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -15,32 +15,52 @@ from typing import Literal, NotRequired, TypedDict -# TODO: split this typed dict up so it can be PreferenceDataConfig | ResponseDataConfig | etc +class ResponseDatasetConfig(TypedDict): + dataset_name: str + data_path: 
NotRequired[str] + input_key: NotRequired[str] + output_key: NotRequired[str] + split: NotRequired[str] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + env_name: NotRequired[str] + processor: NotRequired[str] # remove once processor is refactored + download_dir: NotRequired[str] + split_validation_size: NotRequired[float] + + +# TODO: split this typed dict up so it can be PreferenceDatasetConfig | ResponseDatasetConfig | etc # so that we can type check the configs more rigorously as opposed to saying everything # is not required. class DataConfig(TypedDict): max_input_seq_length: int - prompt_file: NotRequired[str | None] - system_prompt_file: NotRequired[str | None] - dataset_name: str - val_dataset_name: NotRequired[str] add_bos: NotRequired[bool] add_eos: NotRequired[bool] - input_key: NotRequired[str] - output_key: NotRequired[str | None] add_generation_prompt: NotRequired[bool] add_system_prompt: NotRequired[bool] - split: NotRequired[str | None] shuffle: bool - seed: NotRequired[int | None] - download_dir: NotRequired[str] - train_data_path: NotRequired[str] - val_data_paths: NotRequired[dict[str, str]] # Number of data loader workers. # Set to 8 or 10 for large batches to improve loading speed. # This saturates CPU threads without consuming too much memory # However, setting it too high might cause memory issues for long seqlens. num_workers: NotRequired[int] + # dataset configs + # TODO: remove NotRequired once preference dataset is refactored + train: NotRequired[ResponseDatasetConfig] + validation: NotRequired[ResponseDatasetConfig | None] + default: NotRequired[ResponseDatasetConfig | None] + # TODO: remove once preference dataset is refactored + dataset_name: NotRequired[str] + val_dataset_name: NotRequired[str] + input_key: NotRequired[str] + output_key: NotRequired[str | None] + split: NotRequired[str] + train_data_path: NotRequired[str] + val_data_paths: NotRequired[dict[str, str]] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + env_name: NotRequired[str] + processor: NotRequired[str] # remove once processor is refactored # =============================================================================== diff --git a/nemo_rl/data/datasets/__init__.py b/nemo_rl/data/datasets/__init__.py index f859705dba..a4747b7114 100644 --- a/nemo_rl/data/datasets/__init__.py +++ b/nemo_rl/data/datasets/__init__.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from nemo_rl.data.datasets.eval_datasets import load_eval_dataset from nemo_rl.data.datasets.preference_datasets import load_preference_dataset from nemo_rl.data.datasets.processed_dataset import AllTaskProcessedDataset from nemo_rl.data.datasets.response_datasets import load_response_dataset -from nemo_rl.data.datasets.utils import assert_no_double_bos +from nemo_rl.data.datasets.utils import ( + assert_no_double_bos, + extract_necessary_env_names, + update_single_dataset_config, +) __all__ = [ "AllTaskProcessedDataset", @@ -23,4 +28,6 @@ "load_preference_dataset", "load_response_dataset", "assert_no_double_bos", + "extract_necessary_env_names", + "update_single_dataset_config", ] diff --git a/nemo_rl/data/datasets/processed_dataset.py b/nemo_rl/data/datasets/processed_dataset.py index 906ab591fc..ea1cbf87d3 100644 --- a/nemo_rl/data/datasets/processed_dataset.py +++ b/nemo_rl/data/datasets/processed_dataset.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Any, Optional, Union import torch @@ -55,17 +56,18 @@ def __init__( ): self.dataset = dataset self.tokenizer = tokenizer + # TODO: will be removed once preference dataset is refactored self.default_task_data_spec = default_task_data_spec self.task_data_processors = task_data_processors self.max_seq_length = max_seq_length self._bos_checked = False - if isinstance(task_data_processors, dict): + if ( + isinstance(task_data_processors, dict) + and default_task_data_spec is not None + ): # apply defaults to all task data specs - for task_name, ( - task_data_spec, - task_data_processor, - ) in task_data_processors.items(): + for _, (task_data_spec, _) in task_data_processors.items(): task_data_spec.copy_defaults(self.default_task_data_spec) def __len__(self) -> int: diff --git a/nemo_rl/data/datasets/raw_dataset.py b/nemo_rl/data/datasets/raw_dataset.py index e63217a469..c795480e49 100644 --- a/nemo_rl/data/datasets/raw_dataset.py +++ b/nemo_rl/data/datasets/raw_dataset.py @@ -12,18 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
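+# NOTE: RawDataset is the shared base for response datasets below: subclasses
+# populate `dataset`, `task_name`, `processor`, and `task_spec`, and may call
+# `split_train_validation` to carve `val_dataset` out of the training data.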
+from datasets import Dataset + +from nemo_rl.data import ResponseDatasetConfig from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec from nemo_rl.data.processors import PROCESSOR_REGISTRY class RawDataset: - def __init__(self, data_config: dict, seed: int = 42): - self.data_config: dict = data_config - self.seed: int = seed - self.task_name: str | None = None - self.processor: TaskDataProcessFnCallable | None = None - self.task_spec: TaskDataSpec | None = None - raise NotImplementedError("__init__ is not implemented") + # change to ResponseDatasetConfig | PreferenceDatasetConfig once preference dataset is refactored + data_config: ResponseDatasetConfig + dataset: Dataset + # `val_dataset` is used only when current dataset is used for both training and validation + val_dataset: Dataset | None + processor: TaskDataProcessFnCallable + task_spec: TaskDataSpec + + def split_train_validation(self, test_size: float, seed: int): + if test_size > 0: + split_dataset = self.dataset.train_test_split( + test_size=test_size, seed=seed + ) + self.dataset = split_dataset["train"] + self.val_dataset = split_dataset["test"] def set_processor(self): processor_name = ( @@ -36,7 +47,7 @@ def set_processor(self): ) self.processor = PROCESSOR_REGISTRY[processor_name] - def set_task_spec(self, data_config: dict): + def set_task_spec(self, data_config: ResponseDatasetConfig): self.data_config = data_config system_prompt_file = self.data_config.get("system_prompt_file", None) prompt_file = self.data_config.get("prompt_file", None) diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index a259b8a152..761c6992d8 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -11,13 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Any +from nemo_rl.data import ResponseDatasetConfig +from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset -from nemo_rl.data.datasets.response_datasets.dapo_math import DAPOMath17KDataset +from nemo_rl.data.datasets.response_datasets.dapo_math import ( + DAPOMath17KDataset, + DAPOMathAIME2024Dataset, +) from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset +from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset from nemo_rl.data.datasets.response_datasets.oai_format_dataset import ( OpenAIFormatDataset, ) @@ -29,102 +36,60 @@ from nemo_rl.data.datasets.response_datasets.response_dataset import ResponseDataset from nemo_rl.data.datasets.response_datasets.squad import SquadDataset from nemo_rl.data.datasets.response_datasets.tulu3 import Tulu3SftMixtureDataset -from nemo_rl.data.datasets.utils import get_extra_kwargs # TODO: refactor this to use the new processor interface and RawDataset interface. 
https://github.com/NVIDIA-NeMo/RL/issues/1552 -def load_response_dataset(data_config, seed: int = 42): +def load_response_dataset(data_config: ResponseDatasetConfig, seed: int = 42): """Loads response dataset.""" dataset_name = data_config["dataset_name"] - # TODO @yukih: remove duplicated dataset_name (openmathinstruct2, clevr_cogent) + if "data_path" in data_config: + print(f" • Loading {dataset_name} dataset from {data_config['data_path']}") + else: + print(f" • Loading {dataset_name} dataset") + # for sft training if dataset_name == "open_assistant": - base_dataset = OasstDataset( - output_dir="/tmp/open_assistant", - seed=seed, - ) + base_dataset: Any = OasstDataset(**data_config, seed=seed) elif dataset_name == "squad": - base_dataset = SquadDataset() - elif dataset_name == "openmathinstruct2": - base_dataset = OpenMathInstruct2Dataset( - split=data_config["split"], - output_key=data_config["output_key"], - prompt_file=data_config["prompt_file"], - seed=seed, - ) - elif dataset_name == "clevr_cogent": - base_dataset = CLEVRCoGenTDataset( - split=data_config["split"], - prompt_file=data_config["prompt_file"], - ) + base_dataset: Any = SquadDataset(**data_config) + elif dataset_name == "tulu3_sft_mixture": + base_dataset: Any = Tulu3SftMixtureDataset(**data_config, seed=seed) elif dataset_name == "openai_format": - base_dataset = OpenAIFormatDataset( - data_config["train_data_path"], - data_config["val_data_path"], - data_config["chat_key"], - data_config["system_key"], - data_config["system_prompt"], - data_config["tool_key"], - data_config["use_preserving_dataset"], + base_dataset: Any = OpenAIFormatDataset( + **data_config # pyrefly: ignore[missing-argument] `data_path` is required for this class ) # for rl training elif dataset_name == "OpenMathInstruct-2": - print("Loading nvidia/OpenMathInstruct2Dataset for training and validation") - base_dataset: Any = OpenMathInstruct2Dataset(seed=seed) + # TODO: also test after SFT updated + base_dataset: Any = OpenMathInstruct2Dataset(**data_config, seed=seed) elif dataset_name == "DeepScaler": - print( - "Loading agentica-org/DeepScaleR-Preview-Dataset for training and validation" - ) - base_dataset: Any = DeepScalerDataset(seed=seed) + base_dataset: Any = DeepScalerDataset(**data_config) elif dataset_name == "DAPOMath17K": - print( - "Loading BytedTsinghua-SIA/DAPO-Math-17k for training and AIME 2024 for validation" - ) - base_dataset: Any = DAPOMath17KDataset(seed=seed) - # for vlm rl training + base_dataset: Any = DAPOMath17KDataset(**data_config) + elif dataset_name == "HelpSteer3": + base_dataset: Any = HelpSteer3Dataset(**data_config) + elif dataset_name == "AIME2024": + base_dataset: Any = AIME2024Dataset(**data_config) + elif dataset_name == "DAPOMathAIME2024": + base_dataset: Any = DAPOMathAIME2024Dataset(**data_config) + # for vlm training + # TODO: test after GRPO-VLM updated elif dataset_name == "clevr-cogent": - base_dataset: Any = CLEVRCoGenTDataset( - split=data_config["split"], - ) + # TODO: also test after SFT updated + base_dataset: Any = CLEVRCoGenTDataset(**data_config) elif dataset_name == "refcoco": - base_dataset: Any = RefCOCODataset( - split=data_config["split"], - download_dir=data_config["download_dir"], - ) + base_dataset: Any = RefCOCODataset(**data_config) elif dataset_name == "geometry3k": - base_dataset: Any = Geometry3KDataset( - split=data_config["split"], - ) - elif dataset_name == "tulu3_sft_mixture": - base_dataset: Any = Tulu3SftMixtureDataset( - test_size=data_config.get("test_size", 0.05), - 
prompt_file=data_config.get("prompt_file", None), - max_samples=data_config.get("max_samples", None), - seed=seed, - ) - elif dataset_name == "HelpSteer3": - base_dataset: Any = HelpSteer3Dataset() + base_dataset: Any = Geometry3KDataset(**data_config) # fall back to load from JSON file elif dataset_name == "ResponseDataset": - if "train_data_path" not in data_config: - raise ValueError( - "train_data_path is required when dataset_name is not one of the built-ins." - ) - extra_kwargs = get_extra_kwargs( - data_config, - [ - "val_data_path", - "input_key", - "output_key", - "train_split", - "val_split", - ], - ) - base_dataset = ResponseDataset( - train_data_path=data_config["train_data_path"], - **extra_kwargs, + base_dataset: Any = ResponseDataset( + **data_config, # pyrefly: ignore[missing-argument] `data_path` is required for this class + seed=seed, ) + elif dataset_name == "NemoGymDataset": + base_dataset: Any = NemoGymDataset(**data_config) else: raise ValueError( f"Unsupported {dataset_name=}. " @@ -133,25 +98,17 @@ def load_response_dataset(data_config, seed: int = 42): ) base_dataset.set_task_spec(data_config) - # Skip sft datasets, the run_sft.py has not been refactored yet. - # TODO: refactor run_sft.py to use the new processor interface. https://github.com/NVIDIA-NeMo/RL/issues/1552 - if dataset_name not in [ - "open_assistant", - "squad", - "openmathinstruct2", - "clevr_cogent", - "openai_format", - "tulu3_sft_mixture", - ]: - base_dataset.set_processor() + base_dataset.set_processor() return base_dataset __all__ = [ + "AIME2024Dataset", "CLEVRCoGenTDataset", "DeepScalerDataset", "DAPOMath17KDataset", + "DAPOMathAIME2024Dataset", "Geometry3KDataset", "OpenAIFormatDataset", "OasstDataset", @@ -161,4 +118,5 @@ def load_response_dataset(data_config, seed: int = 42): "SquadDataset", "Tulu3SftMixtureDataset", "HelpSteer3Dataset", + "NemoGymDataset", ] diff --git a/nemo_rl/data/datasets/response_datasets/aime24.py b/nemo_rl/data/datasets/response_datasets/aime24.py new file mode 100644 index 0000000000..cb9c7b0395 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/aime24.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from datasets import load_dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +class AIME2024Dataset(RawDataset): + """Simple wrapper around the AIME2024 dataset with train split. 
+ + Args: + repeat: Number of times to repeat the dataset, default is 16 + """ + + def __init__(self, repeat: int = 16, **kwargs) -> None: + self.task_name = "AIME2024" + + # load from huggingface + self.dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, + ) + + # repeat the dataset + self.dataset = self.dataset.repeat(repeat) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + {"role": "user", "content": data["problem"]}, + {"role": "assistant", "content": data["answer"]}, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/clevr.py b/nemo_rl/data/datasets/response_datasets/clevr.py index 30bf67b47f..775b67e8b2 100644 --- a/nemo_rl/data/datasets/response_datasets/clevr.py +++ b/nemo_rl/data/datasets/response_datasets/clevr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any from datasets import load_dataset @@ -52,68 +52,38 @@ def format_clevr_cogent_dataset( ret = { "messages": [ {"role": "user", "content": user_content}, - { - "role": "assistant", - "content": assistant_content, - }, + {"role": "assistant", "content": assistant_content}, ], - "task_name": "clevr-cogent", + "task_name": example["task_name"], } return ret -# contain different variants of the CLEVR dataset -def prepare_clevr_cogent_dataset( - split: str = "trainA", task_name: Optional[str] = None -): - if task_name is None: - task_name = "clevr-cogent" - - if split == "trainA": - tr_dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")[ - "train" - ] - val_dataset = load_dataset("MMInstruction/Clevr_CoGenT_ValA")["train"] - elif split == "trainB": - tr_dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")[ - "train" - ] - val_dataset = load_dataset("MMInstruction/Clevr_CoGenT_ValB")["train"] - elif split == "valA": - tr_dataset = load_dataset("MMInstruction/Clevr_CoGenT_ValA")["train"] - val_dataset = load_dataset("MMInstruction/Clevr_CoGenT_ValA")["train"] - elif split == "valB": - tr_dataset = load_dataset("MMInstruction/Clevr_CoGenT_ValB")["train"] - val_dataset = load_dataset("MMInstruction/Clevr_CoGenT_ValB")["train"] - - # format - disable features to avoid schema conflicts - tr_dataset = tr_dataset.add_column("task_name", [task_name] * len(tr_dataset)) - val_dataset = val_dataset.add_column("task_name", [task_name] * len(val_dataset)) - - return { - "train": tr_dataset, - "validation": val_dataset, - } - - class CLEVRCoGenTDataset(RawDataset): - def __init__( - self, - split: str = "trainA", - prompt_file: Optional[str] = None, - ): - """Simple wrapper around the CLEVR-CoGenT dataset. - - Args: - split: The split of the dataset to use. - prompt_file: The file containing the prompt for the dataset. - """ - if split not in ["trainA", "trainB", "valA", "valB"]: + """Simple wrapper around the CLEVR-CoGenT dataset. + + Args: + split: Split name for the dataset, default is "train" + """ + + def __init__(self, split: str = "train", **kwargs): + # train, valA, and valB are supported splits. 
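+        # "train" maps to the CoGenT condition-A training set
+        # (Clevr_CoGenT_TrainA_70K_Complex); "valA" and "valB" map to the two
+        # validation conditions.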
+ SPLIT_TO_HF_NAME = { + "train": "MMInstruction/Clevr_CoGenT_TrainA_70K_Complex", + "valA": "MMInstruction/Clevr_CoGenT_ValA", + "valB": "MMInstruction/Clevr_CoGenT_ValB", + } + if split not in SPLIT_TO_HF_NAME: raise ValueError( - f"Invalid split: {split}. Please use 'trainA', 'trainB', 'valA', or 'valB'." + f"Invalid split: {split}. Please use 'train', 'valA', or 'valB'." ) + self.task_name = "clevr-cogent" - self.formatted_ds = prepare_clevr_cogent_dataset( - split=split, task_name=self.task_name + # this dataset will process the image during training using `format_clevr_cogent_dataset` + self.dataset = load_dataset(SPLIT_TO_HF_NAME[split])["train"] + + # format - disable features to avoid schema conflicts + self.dataset = self.dataset.add_column( + "task_name", [self.task_name] * len(self.dataset) ) diff --git a/nemo_rl/data/datasets/response_datasets/dapo_math.py b/nemo_rl/data/datasets/response_datasets/dapo_math.py index 3a9988923b..096c6fe835 100644 --- a/nemo_rl/data/datasets/response_datasets/dapo_math.py +++ b/nemo_rl/data/datasets/response_datasets/dapo_math.py @@ -12,72 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any -from datasets import Dataset, load_dataset +from datasets import load_dataset from nemo_rl.data.datasets.raw_dataset import RawDataset -def format_dapo_math_17k( - data: dict[str, str | float | int], - task_name: str = "DAPOMath17K", -) -> dict[str, list[Any] | str]: - return { - "messages": [ - { - "role": "user", - "content": data["prompt"][0]["content"], - }, - { - "role": "assistant", - "content": data["reward_model"]["ground_truth"], - }, - ], - "task_name": task_name, - } - +class DAPOMath17KDataset(RawDataset): + """Simple wrapper around the DAPO Math 17K dataset with train split.""" -def prepare_dapo_math_17k_dataset( - seed: int = 42, task_name: str = "DAPOMath17K" -) -> dict[str, Dataset | None]: - """Load and split the DeepScaler dataset into train and test sets.""" - # Load the original dataset for training - train_ds = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train") + def __init__(self, **kwargs) -> None: + self.task_name = "DAPOMath17K" - # Load hendrydong/aime24 dataset for validation - val_ds = load_dataset("BytedTsinghua-SIA/AIME-2024", split="train") + # load from huggingface + self.dataset = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train") - # Shuffle the training dataset with the specified seed - train_ds = train_ds.shuffle(seed=seed) + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, + ) - # Format the examples, removing original columns - train_formatted = train_ds.map( - format_dapo_math_17k, - remove_columns=train_ds.column_names, - fn_kwargs={"task_name": task_name}, - ) - val_formatted = val_ds.map( - format_dapo_math_17k, - remove_columns=val_ds.column_names, - fn_kwargs={"task_name": task_name}, - ) + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + { + "role": "user", + "content": data["prompt"][0]["content"], + }, + { + "role": "assistant", + "content": data["reward_model"]["ground_truth"], + }, + ], + "task_name": self.task_name, + } - return { - "train": train_formatted, - "validation": val_formatted, - } +class DAPOMathAIME2024Dataset(DAPOMath17KDataset): + def __init__(self, **kwargs) -> None: + """Initialize the DAPO Math AIME 2024 dataset with train split.""" + self.task_name = "DAPOMathAIME2024" 
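+        # Reuses DAPOMath17KDataset.format_data: only the source dataset
+        # (BytedTsinghua-SIA/AIME-2024) and the task name differ.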
-class DAPOMath17KDataset(RawDataset): - def __init__(self, seed: int = 42) -> None: - """Initialize the DAPO Math 17K dataset with train split. + # load from huggingface + self.dataset = load_dataset("BytedTsinghua-SIA/AIME-2024", split="train") - Args: - seed: Random seed for reproducible splitting - """ - self.task_name = "DAPOMath17K" - self.formatted_ds = prepare_dapo_math_17k_dataset( - seed=seed, task_name=self.task_name + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, ) diff --git a/nemo_rl/data/datasets/response_datasets/deepscaler.py b/nemo_rl/data/datasets/response_datasets/deepscaler.py index 3465491225..7f6189281d 100644 --- a/nemo_rl/data/datasets/response_datasets/deepscaler.py +++ b/nemo_rl/data/datasets/response_datasets/deepscaler.py @@ -12,77 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any -from datasets import Dataset, load_dataset +from datasets import load_dataset from nemo_rl.data.datasets.raw_dataset import RawDataset -def format_math( - data: dict[str, str | float | int], task_name: str = "DeepScaler" -) -> dict[str, list[Any] | str]: - return { - "messages": [ - { - "role": "user", - "content": data["problem"], - }, - { - "role": "assistant", - "content": data["answer"], - }, - ], - "task_name": task_name, - } - - -def prepare_deepscaler_dataset( - seed: int = 42, task_name: str = "DeepScaler" -) -> dict[str, Dataset | None]: - """Load and split the DeepScaler dataset into train and test sets.""" - # Load the original dataset for training - train_ds = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train") - - # Load hendrydong/aime24 dataset for validation - val_ds = load_dataset("HuggingFaceH4/aime_2024", split="train") - - # Shuffle the training dataset with the specified seed - train_ds = train_ds.shuffle(seed=seed) - - # Format the examples, removing original columns - train_formatted = train_ds.map( - format_math, - remove_columns=train_ds.column_names, - fn_kwargs={"task_name": task_name}, - ) - val_formatted = val_ds.map( - format_math, - remove_columns=val_ds.column_names, - fn_kwargs={"task_name": task_name}, - ) - - # Compute accuracy 16 times per sample (matching the DeepScaleR evaluation setting) - val_repeated = [] - for _ in range(16): - val_repeated.extend(val_formatted) - val_formatted = val_formatted.from_list(val_repeated) - - return { - "train": train_formatted, - "validation": val_formatted, - } - - class DeepScalerDataset(RawDataset): - def __init__(self, seed: int = 42) -> None: - """Initialize the DeepScaler dataset with train/test split. 
+ """Simple wrapper around the DeepScaler dataset with train split.""" - Args: - seed: Random seed for reproducible splitting - """ + def __init__(self, **kwargs) -> None: self.task_name = "DeepScaler" - self.formatted_ds = prepare_deepscaler_dataset( - seed=seed, task_name=self.task_name + + # load from huggingface + self.dataset = load_dataset( + "agentica-org/DeepScaleR-Preview-Dataset", split="train" ) + + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, + ) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + {"role": "user", "content": data["problem"]}, + {"role": "assistant", "content": data["answer"]}, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/geometry3k.py b/nemo_rl/data/datasets/response_datasets/geometry3k.py index d45fb15127..429decb522 100644 --- a/nemo_rl/data/datasets/response_datasets/geometry3k.py +++ b/nemo_rl/data/datasets/response_datasets/geometry3k.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional + +from typing import Any from datasets import load_dataset @@ -24,11 +25,8 @@ def format_geometry3k_dataset( ) -> dict[str, Any]: """Format the Geometry3K dataset into an OpenAI-API-like message log.""" # isolate single image - example["image"] = ( - example["images"][0] - if isinstance(example["images"], list) - else example["images"] - ) + if isinstance(example["images"], list): + example["image"] = example["images"][0] user_content = [ { @@ -48,50 +46,32 @@ def format_geometry3k_dataset( ret = { "messages": [ {"role": "user", "content": user_content}, - { - "role": "assistant", - "content": assistant_content, - }, + {"role": "assistant", "content": assistant_content}, ], - "task_name": "geometry3k", + "task_name": example["task_name"], } return ret -def prepare_geometry3k_dataset(split: str = "train", task_name: str = "geometry3k"): - if split == "train": - tr_dataset = load_dataset("hiyouga/geometry3k")["train"] - val_dataset = load_dataset("hiyouga/geometry3k")["validation"] - else: - tr_dataset = load_dataset("hiyouga/geometry3k")[split] - val_dataset = load_dataset("hiyouga/geometry3k")[split] - - # format - disable features to avoid schema conflicts - tr_dataset = tr_dataset.add_column("task_name", [task_name] * len(tr_dataset)) - val_dataset = val_dataset.add_column("task_name", [task_name] * len(val_dataset)) - return { - "train": tr_dataset, - "validation": val_dataset, - } - - class Geometry3KDataset(RawDataset): - def __init__( - self, - split: str = "train", - prompt_file: Optional[str] = None, - ): - """Simple wrapper around the Geometry3K dataset. + """Simple wrapper around the Geometry3K dataset. + + Args: + split: Split name for the dataset, default is "train" + """ - Args: - split: The split of the dataset to use. - prompt_file: The file containing the prompt for the dataset. - """ + def __init__(self, split: str = "train", **kwargs): + # train, validation, and test are supported splits. assert split in ["train", "validation", "test"], ( f"Invalid split: {split}. Please use 'train' or 'validation' or 'test'." 
         )
+
         self.task_name = "geometry3k"
-        self.formatted_ds = prepare_geometry3k_dataset(
-            split=split, task_name=self.task_name
+        # this dataset will process the image during training using `format_geometry3k_dataset`
+        self.dataset = load_dataset("hiyouga/geometry3k")[split]
+
+        # tag each example with its task name (add_column avoids schema conflicts)
+        self.dataset = self.dataset.add_column(
+            "task_name", [self.task_name] * len(self.dataset)
         )
diff --git a/nemo_rl/data/datasets/response_datasets/helpsteer3.py b/nemo_rl/data/datasets/response_datasets/helpsteer3.py
index 7d275634ef..af7e00be05 100644
--- a/nemo_rl/data/datasets/response_datasets/helpsteer3.py
+++ b/nemo_rl/data/datasets/response_datasets/helpsteer3.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import Any
 
 from absl import logging
@@ -19,44 +20,49 @@
 
 from nemo_rl.data.datasets.raw_dataset import RawDataset
 
 
-# Choose the chosen response as the response and the rejected response as the target
-def to_response_data_format(
-    data: dict[str, Any], task_name: str = "HelpSteer3"
-) -> dict:
-    response_1 = data["response1"]
-    response_2 = data["response2"]
-    overall_preference = data["overall_preference"]
-
-    if overall_preference < 0:
-        chosen = response_1
-    elif overall_preference == 0:
-        logging.log_every_n(
-            logging.WARNING,
-            "Preference is 0 for some examples! Setting chosen and rejected to response 1 since we don't know which response is better",
-            1000,
-        )
-        chosen = response_1
-    else:
-        chosen = response_2
-
-    if isinstance(data["context"], str):
-        context = [{"role": "user", "content": data["context"]}]
-    else:
-        context = data["context"]
+class HelpSteer3Dataset(RawDataset):
+    """Simple wrapper around the HelpSteer3 dataset (preference subset).
 
-    return {
-        "context": context,
-        "response": [{"role": "assistant", "content": chosen}],
-        "task_name": task_name,
-    }
+    Args:
+        split: Split name for the dataset, default is "train"
+    """
 
+    def __init__(self, split: str = "train", **kwargs):
+        self.task_name = "HelpSteer3"
 
-class HelpSteer3Dataset(RawDataset):
-    """HelpSteer3 preference dataset for DPO training."""
+        # load from huggingface
+        self.dataset = load_dataset("nvidia/HelpSteer3", "preference")[split]
 
-    def __init__(self) -> None:
-        ds = load_dataset("nvidia/HelpSteer3", "preference")
-        self.task_name = "HelpSteer3"
-        self.formatted_ds = ds.map(
-            to_response_data_format, fn_kwargs={"task_name": self.task_name}
+        # format the dataset
+        self.dataset = self.dataset.map(
+            self.format_data,
+            remove_columns=self.dataset.column_names,
+        )
+
+    def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
+        # overall_preference < 0 prefers response 1; > 0 prefers response 2
+        response_1 = data["response1"]
+        response_2 = data["response2"]
+        overall_preference = data["overall_preference"]
+
+        if overall_preference < 0:
+            chosen = response_1
+        elif overall_preference == 0:
+            logging.log_every_n(
+                logging.WARNING,
+                "Preference is 0 for some examples! 
Setting chosen and rejected to response 1 since we don't know which response is better", + 1000, + ) + chosen = response_1 + else: + chosen = response_2 + + if isinstance(data["context"], str): + context = [{"role": "user", "content": data["context"]}] + else: + context = data["context"] + + return { + "context": context, + "response": [{"role": "assistant", "content": chosen}], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/nemogym_dataset.py b/nemo_rl/data/datasets/response_datasets/nemogym_dataset.py new file mode 100644 index 0000000000..5277484786 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/nemogym_dataset.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + +from nemo_rl.data.datasets.raw_dataset import RawDataset +from nemo_rl.data.datasets.utils import load_dataset_from_path + + +class NemoGymDataset(RawDataset): + """Simple wrapper around the Nemo Gym dataset.""" + + def __init__(self, data_path: Optional[str] = None, **kwargs) -> None: + self.task_name = "NemoGymDataset" + + # load from jsonl + if data_path is None: + # Allow optional at type level for config validation; enforce at runtime for clarity + raise ValueError( + "NemoGymDataset requires `data_path` in data_config to load examples." 
+            )
+        self.dataset = load_dataset_from_path(data_path)
+
+        # format the dataset
+        # NOTE: HuggingFace Datasets does not persist torch.Tensor through map/Arrow
+        # writes; tensors are serialized into plain Python lists, so downstream
+        # consumers read back [] (a list), which triggers the assertion.
+        self.dataset = self.dataset.map(
+            self.format_data,
+            with_indices=True,
+        )
+        if "repeat" in kwargs:
+            self.dataset = self.dataset.repeat(kwargs["repeat"])
+
+    def format_data(self, data: dict[str, Any], idx: int) -> dict[str, Any]:
+        return {
+            "message_log": [
+                {"role": "user", "content": "", "token_ids": torch.tensor([])}
+            ],
+            "task_name": self.task_name,
+            "length": 0,
+            "extra_env_info": data,
+            "loss_multiplier": 1.0,  # Fixed to 1.0 to backprop on all examples
+            "idx": idx,
+            "stop_strings": None,
+            # Extra vars
+            "token_ids": [],  # Just need this empty key to be compatible with the current NeMo RL GRPO impl
+        }
diff --git a/nemo_rl/data/datasets/response_datasets/oai_format_dataset.py b/nemo_rl/data/datasets/response_datasets/oai_format_dataset.py
index 2dfb44aada..674940e88e 100644
--- a/nemo_rl/data/datasets/response_datasets/oai_format_dataset.py
+++ b/nemo_rl/data/datasets/response_datasets/oai_format_dataset.py
@@ -97,8 +97,7 @@ class OpenAIFormatDataset(RawDataset):
     }
 
     Args:
-        train_ds_path: Path to the training dataset JSON file
-        val_ds_path: Path to the validation dataset JSON file
+        data_path: Path to the dataset JSON file
         chat_key: Key for the messages list in the dataset (default: "messages")
         system_key: Optional key for system prompt in the dataset
         system_prompt: Optional system prompt to add if not in the dataset
@@ -121,36 +120,33 @@ class OpenAIFormatDataset(RawDataset):
 
     def __init__(
         self,
-        train_ds_path: str,
-        val_ds_path: str,
+        data_path: str,
         chat_key: str = "messages",
         system_key: str | None = None,
         system_prompt: str | None = None,
         tool_key: str | None = "tools",
         use_preserving_dataset: bool = False,
+        **kwargs,
    ):
         self.chat_key = chat_key
         self.system_key = system_key
         self.system_prompt = system_prompt
         self.tool_key = tool_key
-        self.task_name = "json_dataset"
+        self.task_name = data_path.split("/")[-1].split(".")[0]
+
         if not use_preserving_dataset:
             # Use the standard HuggingFace approach (faster and more standard)
-            train_original_dataset = load_dataset("json", data_files=train_ds_path)[
-                "train"
-            ]
-            val_original_dataset = load_dataset("json", data_files=val_ds_path)["train"]
-
-            formatted_train_dataset = train_original_dataset.map(self.add_messages_key)
-            formatted_val_dataset = val_original_dataset.map(self.add_messages_key)
+            original_dataset = load_dataset("json", data_files=data_path)["train"]
 
+            # Format the dataset
+            self.dataset = original_dataset.map(self.format_data)
             print(
-                f"Loaded dataset using standard approach (train: {len(formatted_train_dataset)}, val: {len(formatted_val_dataset)})"
+                f"Loaded dataset using standard approach: {len(self.dataset)} samples."
             )
 
             # Warn if tools are present in the dataset
             if self.tool_key and any(
-                self.tool_key in sample for sample in formatted_train_dataset
+                self.tool_key in sample for sample in self.dataset
             ):
                 warnings.warn(
                     "Tools detected in dataset. Set use_preserving_dataset=True to preserve heterogeneous tool schemas. 
" @@ -173,46 +169,28 @@ def __init__( ) # Load JSON files directly - with open(train_ds_path, "r") as f: - train_data = [json.loads(line) for line in f] - - with open(val_ds_path, "r") as f: - val_data = [json.loads(line) for line in f] - - # Apply transformations - formatted_train_data = [self.add_messages_key(item) for item in train_data] - formatted_val_data = [self.add_messages_key(item) for item in val_data] - + with open(data_path, "r") as f: + original_dataset = [json.loads(line) for line in f] + # Format the dataset + formatted_data = [self.format_data(item) for item in original_dataset] # Use PreservingDataset to maintain exact structure - formatted_train_dataset = PreservingDataset(formatted_train_data) - formatted_val_dataset = PreservingDataset(formatted_val_data) + self.dataset = PreservingDataset(formatted_data) print( - f"Loaded dataset using PreservingDataset (train: {len(formatted_train_dataset)}, val: {len(formatted_val_dataset)})" + f"Loaded dataset using PreservingDataset: {len(self.dataset)} samples." ) - self.formatted_ds = { - "train": formatted_train_dataset, - "validation": formatted_val_dataset, - } - self.task_name = "json_dataset" - - def add_messages_key( - self, - example: dict[str, Any], - ) -> dict[str, list[dict[str, Any]]]: - messages = [message for message in example[self.chat_key]] - if self.system_key is not None and self.system_key in example: - messages = [ - {"role": "system", "content": example[self.system_key]} - ] + messages + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + messages = [message for message in data[self.chat_key]] + if self.system_key is not None and self.system_key in data: + messages = [{"role": "system", "content": data[self.system_key]}] + messages elif self.system_prompt: messages = [{"role": "system", "content": self.system_prompt}] + messages assert messages[-1]["role"] == "assistant" # Preserve tools if they exist in the data - result = {"messages": messages} - if self.tool_key and self.tool_key in example: - result["tools"] = example[self.tool_key] + result = {"messages": messages, "task_name": self.task_name} + if self.tool_key and self.tool_key in data: + result["tools"] = data[self.tool_key] return result diff --git a/nemo_rl/data/datasets/response_datasets/oasst.py b/nemo_rl/data/datasets/response_datasets/oasst.py index 327bc52b8f..e76316e77e 100644 --- a/nemo_rl/data/datasets/response_datasets/oasst.py +++ b/nemo_rl/data/datasets/response_datasets/oasst.py @@ -15,10 +15,9 @@ import copy import gzip import json -import os -import random -import requests +from datasets import Dataset +from huggingface_hub import hf_hub_download from nemo_rl.data.datasets.raw_dataset import RawDataset @@ -67,7 +66,7 @@ def parse_conversations(tree_obj, first: bool = False): return all_conversations -def get_data_records(objs, task_name: str = "OASST"): +def get_data_records(objs, task_name: str = "oasst"): ## TODO: old format was multi-conversation per example, but ours is single conversation ## is this just because of the input data format? 
output = [] @@ -87,46 +86,31 @@ def get_data_records(objs, task_name: str = "OASST"): return output -def download_and_process_oasst( - output_directory: str = ".", - seed: int = 42, - task_name: str = "OASST", - split_ratio: float = 0.95, -) -> dict[str, list]: - os.makedirs(output_directory, exist_ok=True) - filename = f"{output_directory}/2023-04-12_oasst_all.trees.jsonl.gz" - - # only download if doesn't exist - if not os.path.isfile(filename): - url = "https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_all.trees.jsonl.gz" - response = requests.get(url) - with open(filename, mode="wb") as fw: - fw.write(response.content) - - with gzip.open(filename) as f: - file_content = f.readlines() - - all_objs = [json.loads(dp.decode("utf-8")) for dp in file_content] +class OasstDataset(RawDataset): + """Simple wrapper around the OASST dataset. - random.seed(seed) - random.shuffle(all_objs) - train_num = int(len(all_objs) * split_ratio) - train_objs = all_objs[:train_num] - val_objs = all_objs[train_num:] - train_records = get_data_records(train_objs, task_name=task_name) - val_records = get_data_records(val_objs, task_name=task_name) + Args: + split_validation_size: Size of the validation data, default is 0.05 + seed: Seed for train/validation split when split_validation_size > 0, default is 42 + """ - formatted_ds = { - "train": train_records, - "validation": val_records, - } + def __init__(self, split_validation_size: float = 0.05, seed: int = 42, **kwargs): + self.task_name = "oasst" - return formatted_ds + # load from huggingface + filename = hf_hub_download( + repo_id="OpenAssistant/oasst1", + filename="2023-04-12_oasst_all.trees.jsonl.gz", + repo_type="dataset", + ) + with gzip.open(filename) as f: + file_content = f.readlines() + # format the dataset + all_objs = [json.loads(dp.decode("utf-8")) for dp in file_content] + self.dataset = get_data_records(all_objs, task_name=self.task_name) + self.dataset = Dataset.from_list(self.dataset) -class OasstDataset(RawDataset): - def __init__(self, output_dir: str = ".", seed: int = 42) -> None: - self.task_name = "OASST" - self.formatted_ds = download_and_process_oasst( - output_dir, seed, task_name=self.task_name - ) + # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation + self.val_dataset = None + self.split_train_validation(split_validation_size, seed) diff --git a/nemo_rl/data/datasets/response_datasets/openmathinstruct2.py b/nemo_rl/data/datasets/response_datasets/openmathinstruct2.py index f2bb228427..1b2c651997 100644 --- a/nemo_rl/data/datasets/response_datasets/openmathinstruct2.py +++ b/nemo_rl/data/datasets/response_datasets/openmathinstruct2.py @@ -12,96 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any -from typing import Any, Optional - -from datasets import Dataset, load_dataset +from datasets import load_dataset from nemo_rl.data.datasets.raw_dataset import RawDataset -def format_math( - data: dict[str, str | float | int], - output_key: str = "expected_answer", - task_name: str = "OpenMathInstruct-2", -) -> dict[str, list[Any] | str]: - return { - "messages": [ - { - "role": "user", - "content": data["problem"], - }, - { - "role": "assistant", - "content": data[output_key], - }, - ], - "task_name": task_name, - } - - -def prepare_openinstructmath2_dataset( - split: str = "train_1M", - seed: int = 42, - test_size: float = 0.05, - output_key: str = "expected_answer", - task_name: str = "OpenMathInstruct-2", -) -> dict[str, Dataset | None]: - """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split.""" - print( - "WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets." - ) - - # Load the original dataset - original_ds = load_dataset("nvidia/OpenMathInstruct-2", split=split) - - # Split into train and validation sets using HF's train_test_split - split_ds = original_ds.train_test_split(test_size=test_size, seed=seed) - - # Format the examples, removing original columns - train_formatted = split_ds["train"].map( - format_math, - remove_columns=split_ds["train"].column_names, - fn_kwargs={"output_key": output_key, "task_name": task_name}, - ) - val_formatted = split_ds["test"].map( - format_math, - remove_columns=split_ds["test"].column_names, - fn_kwargs={"output_key": output_key, "task_name": task_name}, - ) - - return { - "train": train_formatted, - "validation": val_formatted, - } +class OpenMathInstruct2Dataset(RawDataset): + """Simple wrapper around the OpenMathInstruct2 dataset. + Args: + output_key: Key for the output text, default is "expected_answer" + split: Split name for the dataset, default is "train_1M" + split_validation_size: Size of the validation data, default is 0.05 + seed: Seed for train/validation split when split_validation_size > 0, default is 42 + """ -class OpenMathInstruct2Dataset(RawDataset): def __init__( self, + output_key: str = "expected_answer", split: str = "train_1M", + split_validation_size: float = 0.05, seed: int = 42, - test_size: float = 0.05, - output_key: str = "expected_answer", - prompt_file: Optional[str] = None, + **kwargs, ): - """Initialize the OpenMathInstruct2 dataset with train/validation split. - - Args: - seed: Random seed for reproducible splitting - test_size: Proportion of data to use for validation (0.0-1.0) - """ # train, train_1M, train_2M, and train_5M are supported splits. if split not in ["train", "train_1M", "train_2M", "train_5M"]: raise ValueError( f"Invalid split: {split}. Please use 'train', 'train_1M', 'train_2M', or 'train_5M'." 
) + self.input_key = "problem" + self.output_key = output_key self.task_name = "OpenMathInstruct-2" - self.formatted_ds = prepare_openinstructmath2_dataset( - split=split, - seed=seed, - test_size=test_size, - output_key=output_key, - task_name=self.task_name, + + # load from huggingface + self.dataset = load_dataset("nvidia/OpenMathInstruct-2", split=split) + + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, ) + + # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation + self.val_dataset = None + self.split_train_validation(split_validation_size, seed) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + {"role": "user", "content": data[self.input_key]}, + {"role": "assistant", "content": data[self.output_key]}, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/refcoco.py b/nemo_rl/data/datasets/response_datasets/refcoco.py index 9f32b1a12d..a8630e2c6b 100644 --- a/nemo_rl/data/datasets/response_datasets/refcoco.py +++ b/nemo_rl/data/datasets/response_datasets/refcoco.py @@ -15,8 +15,7 @@ import os import random import zipfile -from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import requests from datasets import load_dataset @@ -98,7 +97,6 @@ def format_refcoco_dataset( width: int = 256, height: int = 256, caption_type: str = "random", - prompt_file: Optional[str] = None, ) -> dict[str, Any]: """Format the RefCOCO dataset from huggingface. @@ -158,101 +156,56 @@ def format_refcoco_dataset( ret = { "messages": [ {"role": "user", "content": user_content}, - { - "role": "assistant", - "content": solution, - }, + {"role": "assistant", "content": solution}, ], - "task_name": "refcoco", + "task_name": example["task_name"], } return ret -# contain different variants of the CLEVR dataset -def prepare_refcoco_dataset( - split: str = "default", - task_name: Optional[str] = None, - path_to_coco_images: Optional[Union[str, Path]] = None, -): - if task_name is None: - task_name = "refcoco" - - tr_dataset = load_dataset("jxu124/refcoco")["train"] - val_dataset = load_dataset("jxu124/refcoco")["validation"] - - # format - disable features to avoid schema conflicts - tr_dataset = tr_dataset.add_column("task_name", [task_name] * len(tr_dataset)) - val_dataset = val_dataset.add_column("task_name", [task_name] * len(val_dataset)) - - if path_to_coco_images is None: - print("No path to coco images provided, downloading images to ./coco_images") - path_to_coco_images = Path("./coco_images") - os.makedirs(path_to_coco_images, exist_ok=True) - else: - path_to_coco_images = Path(path_to_coco_images) - - # check for images - if not os.path.exists(str(path_to_coco_images / "train2014")): - print(f"Downloading train2014 images to {path_to_coco_images}") - download_and_unzip( - "http://images.cocodataset.org/zips/train2014.zip", str(path_to_coco_images) - ) - if not os.path.exists(str(path_to_coco_images / "val2014")): - print(f"Downloading val2014 images to {path_to_coco_images}") - download_and_unzip( - "http://images.cocodataset.org/zips/val2014.zip", str(path_to_coco_images) - ) - - # add image column - tr_dataset = tr_dataset.map( - lambda example: { - **example, - "image_path": str(example["image_path"]).replace( - "coco/", str(path_to_coco_images) + "/" - ) - if "image_path" in example - else example.get("image_path"), - } - ) - val_dataset = val_dataset.map( - 
lambda example: { - **example, - "image_path": str(example["image_path"]).replace( - "coco/", str(path_to_coco_images) + "/" - ) - if "image_path" in example - else example.get("image_path"), - } - ) - - return { - "train": tr_dataset, - "validation": val_dataset, - } +class RefCOCODataset(RawDataset): + """Simple wrapper around the RefCOCO dataset. + Args: + split: Split name for the dataset, default is "train" + download_dir: Directory to download the dataset to, default is "./coco_images" + """ -class RefCOCODataset(RawDataset): def __init__( self, - split: str = "default", - prompt_file: Optional[str] = None, - download_dir: Optional[str] = None, + split: str = "train", + download_dir: str = "./coco_images", + **kwargs, ): - """Simple wrapper around the RefCOCO dataset. - - Args: - split: The split of the dataset to use (currently only 'default' is supported) - prompt_file: The file containing the prompt for the dataset. - """ - VALID_SPLITS = ["default"] - if split not in VALID_SPLITS: + # train and validation are supported splits. + SPLIT_TO_IMAGE_URL = { + "train": "http://images.cocodataset.org/zips/train2014.zip", + "validation": "http://images.cocodataset.org/zips/val2014.zip", + } + if split not in SPLIT_TO_IMAGE_URL: raise ValueError( - f"Invalid split: {split}. Please use one of {VALID_SPLITS}." + f"Invalid split: {split}. Please use 'train' or 'validation'." ) + + self.download_dir = download_dir self.task_name = "refcoco" - self.formatted_ds = prepare_refcoco_dataset( - split=split, - task_name=self.task_name, - path_to_coco_images=download_dir, - ) + # check for images + filename = SPLIT_TO_IMAGE_URL[split].split("/")[-1].split(".")[0] + if not os.path.exists(f"{download_dir}/{filename}"): + print(f"Downloading {filename} images to {download_dir}") + download_and_unzip(SPLIT_TO_IMAGE_URL[split], download_dir) + + # this dataset will process the image during training using `format_refcoco_dataset` + self.dataset = load_dataset("jxu124/refcoco")[split] + self.dataset = self.dataset.map(self.format_data) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + image_path = None + if "image_path" in data: + image_path = data["image_path"].replace("coco/", self.download_dir + "/") + + return { + "image_path": image_path, + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/response_dataset.py b/nemo_rl/data/datasets/response_datasets/response_dataset.py index 15af21206e..3fa6acfa7a 100644 --- a/nemo_rl/data/datasets/response_datasets/response_dataset.py +++ b/nemo_rl/data/datasets/response_datasets/response_dataset.py @@ -29,56 +29,51 @@ class ResponseDataset(RawDataset): } Args: - train_data_path: Path to the JSON file containing training data - val_data_path: Path to the JSON file containing validation data - input_key: Key for the input text - output_key: Key for the output text - train_split: Split name for the training data, used for HuggingFace datasets, default is None - val_split: Split name for the validation data, used for HuggingFace datasets, default is None + data_path: Path to the dataset JSON file + input_key: Key for the input text, default is "input" + output_key: Key for the output text, default is "output" + split: Optional split name for the dataset, used for HuggingFace datasets + split_validation_size: Size of the validation data, default is 0 + seed: Seed for train/validation split when split_validation_size > 0, default is 42 """ def __init__( self, - train_data_path: str, - val_data_path: Optional[str] = 
None, + data_path: str, input_key: str = "input", output_key: str = "output", - train_split: Optional[str] = None, - val_split: Optional[str] = None, + split: Optional[str] = None, + split_validation_size: float = 0, + seed: int = 42, + **kwargs, ): self.input_key = input_key self.output_key = output_key - self.task_name = "ResponseDataset" - # load from json file or huggingface - train_ds = load_dataset_from_path(train_data_path, train_split) - if val_data_path: - val_ds = load_dataset_from_path(val_data_path, val_split) - else: - val_ds = None + self.task_name = data_path.split("/")[-1].split(".")[0] + + # load from local or huggingface + self.dataset = load_dataset_from_path(data_path, split) - # Only apply add_messages_key if 'messages' column doesn't exist - if "messages" not in train_ds.column_names: - train_ds = train_ds.map( - self.add_messages_key, fn_kwargs={"task_name": self.task_name} + # format the dataset + if "messages" not in self.dataset.column_names: + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, ) - if val_ds is not None and "messages" not in val_ds.column_names: - val_ds = val_ds.map( - self.add_messages_key, fn_kwargs={"task_name": self.task_name} + else: + self.dataset = self.dataset.add_column( + "task_name", [self.task_name] * len(self.dataset) ) - # store the formatted dataset - self.formatted_ds = { - "train": train_ds, - "validation": val_ds, - } + # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation + self.val_dataset = None + self.split_train_validation(split_validation_size, seed) - def add_messages_key( - self, example: dict[str, Any], task_name: str = "ResponseDataset" - ) -> dict[str, str | list[dict[str, Any]]]: + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: return { "messages": [ - {"role": "user", "content": example[self.input_key]}, - {"role": "assistant", "content": example[self.output_key]}, + {"role": "user", "content": data[self.input_key]}, + {"role": "assistant", "content": data[self.output_key]}, ], - "task_name": task_name, + "task_name": self.task_name, } diff --git a/nemo_rl/data/datasets/response_datasets/squad.py b/nemo_rl/data/datasets/response_datasets/squad.py index c4e1023424..dba0f7c243 100644 --- a/nemo_rl/data/datasets/response_datasets/squad.py +++ b/nemo_rl/data/datasets/response_datasets/squad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any from datasets import load_dataset @@ -20,27 +19,40 @@ from nemo_rl.data.datasets.raw_dataset import RawDataset -def format_squad(data: dict[str, Any]) -> dict[str, list[dict[str, str]]]: - return { - "messages": [ - { - "role": "system", - "content": data["context"], - }, - { - "role": "user", - "content": data["question"], - }, - { - "role": "assistant", - "content": data["answers"]["text"][0], - }, - ] - } - - class SquadDataset(RawDataset): - def __init__(self) -> None: - original_ds = load_dataset("rajpurkar/squad") - self.task_name = "SQuAD" - self.formatted_ds = original_ds.map(format_squad) + """Simple wrapper around the squad dataset. 
+ + Args: + split: Split name for the dataset, default is "train" + """ + + def __init__(self, split: str = "train", **kwargs) -> None: + self.task_name = "squad" + + # load from huggingface + self.dataset = load_dataset("rajpurkar/squad")[split] + + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, + ) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + { + "role": "system", + "content": data["context"], + }, + { + "role": "user", + "content": data["question"], + }, + { + "role": "assistant", + "content": data["answers"]["text"][0], + }, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/tulu3.py b/nemo_rl/data/datasets/response_datasets/tulu3.py index 9dc29dd83f..1e27d25a2f 100644 --- a/nemo_rl/data/datasets/response_datasets/tulu3.py +++ b/nemo_rl/data/datasets/response_datasets/tulu3.py @@ -19,74 +19,54 @@ from nemo_rl.data.datasets.raw_dataset import RawDataset -def format_tulu3_sft_mixture( - data: dict[str, Any], task_name: str = "tulu3_sft_mixture" -) -> dict[str, str | dict[str, str]]: - """Format for Tulu3 SFT data.""" - messages = data["messages"] - - # Ensure last message is from assistant - if not messages or messages[-1]["role"] != "assistant": - raise ValueError(f"Expected last message to be from assistant, got: {messages}") - - return { - "messages": messages, - "task_name": task_name, - } - - class Tulu3SftMixtureDataset(RawDataset): - """Tulu3 SFT mixture dataset.""" + """Simple wrapper around the Tulu3 SFT mixture dataset with train split. + + Args: + split_validation_size: Size of the validation data, default is 0.05 + seed: Seed for train/validation split when split_validation_size > 0, default is 42 + max_samples: Optional maximum number of samples to use from the dataset + """ def __init__( self, + split_validation_size: float = 0.05, seed: int = 42, - test_size: float = 0.05, - prompt_file: str | None = None, max_samples: int | None = None, + **kwargs, ) -> None: - """Initialize the Tulu3 SFT mixture dataset. - - Args: - seed: Random seed for train/validation split - test_size: Proportion of data to use for validation (0.0-1.0) - prompt_file: Optional prompt file path to be applied via TaskDataSpec - max_samples: Optional maximum number of samples to use from the dataset - """ print( "WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets." 
) self.task_name = "tulu3_sft_mixture" - # Load the original dataset - original_ds = load_dataset( - path="allenai/tulu-3-sft-mixture", - trust_remote_code=True, - )["train"] # This dataset only has a train split + # load from huggingface + self.dataset = load_dataset("allenai/tulu-3-sft-mixture")["train"] # Optionally limit the number of samples if max_samples is not None and max_samples > 0: - original_ds = original_ds.shuffle(seed=seed).select( - range(min(max_samples, len(original_ds))) + self.dataset = self.dataset.shuffle(seed=seed).select( + range(min(max_samples, len(self.dataset))) ) - # Split into train and validation sets - split_ds = original_ds.train_test_split(test_size=test_size, seed=seed) - - # Format the examples without any reasoning processing - train_formatted = split_ds["train"].map( - format_tulu3_sft_mixture, - remove_columns=split_ds["train"].column_names, - fn_kwargs={"task_name": self.task_name}, - ) - val_formatted = split_ds["test"].map( - format_tulu3_sft_mixture, - remove_columns=split_ds["test"].column_names, - fn_kwargs={"task_name": self.task_name}, + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=["id", "source"], ) - self.formatted_ds = { - "train": train_formatted, - "validation": val_formatted, - } + # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation + self.val_dataset = None + self.split_train_validation(split_validation_size, seed) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + messages = data["messages"] + + # Ensure last message is from assistant + if not messages or messages[-1]["role"] != "assistant": + raise ValueError( + f"Expected last message to be from assistant, got: {messages}" + ) + + return {"task_name": self.task_name} diff --git a/nemo_rl/data/datasets/utils.py b/nemo_rl/data/datasets/utils.py index eb78becc45..151c79d47d 100644 --- a/nemo_rl/data/datasets/utils.py +++ b/nemo_rl/data/datasets/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import base64 import io import os @@ -106,3 +107,34 @@ def get_extra_kwargs(data_config: dict, keys: list[str]) -> dict: if key in data_config: extra_kwargs[key] = data_config[key] return extra_kwargs + + +def update_single_dataset_config(data_config: dict, default_data_config: dict) -> None: + """Fill the single dataset config with default dataset config.""" + for key in default_data_config.keys(): + if key not in data_config: + data_config[key] = default_data_config[key] + + +def extract_necessary_env_names(data_config: dict) -> list[str]: + """Extract the necessary environment names from the data config. + + Some environments are set in env_configs but not used in the data config. + This function extracts the necessary environment names from the data config. + + Args: + data_config: The data config. + + Returns: + The necessary environment names. 
+ """ + necessary_env_names = set() + keys = ["train", "validation", "default"] + for key in keys: + if ( + key in data_config + and data_config[key] is not None + and "env_name" in data_config[key] + ): + necessary_env_names.add(data_config[key]["env_name"]) + return list(necessary_env_names) diff --git a/nemo_rl/data/interfaces.py b/nemo_rl/data/interfaces.py index 05f10236c5..207b702bda 100644 --- a/nemo_rl/data/interfaces.py +++ b/nemo_rl/data/interfaces.py @@ -18,8 +18,11 @@ import torch from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from nemo_rl.data.multimodal_utils import PackedTensor + # OpenAI-API-like message log, but every messsage may contain associated tensors (i.e. tokenized strings and logprobs) in addition to the original "content" string LLMMessageLogType = list[dict[str, Union[str, torch.Tensor]]] +VLMMessageLogType = list[dict[str, Union[str, torch.Tensor, PackedTensor]]] # Flattened message log where all tensors and data are concatenated together for a conversation # Converts a conversation from list-of-turns format to key-value format with concatenated tensors @@ -30,9 +33,9 @@ class DatumSpec(TypedDict): - message_log: LLMMessageLogType + message_log: LLMMessageLogType | VLMMessageLogType length: int # total (concatenated) length of the message tensors - extra_env_info: dict[str, Any] + extra_env_info: Optional[dict[str, Any]] loss_multiplier: float # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid) idx: int task_name: NotRequired[str] diff --git a/nemo_rl/data/multimodal_utils.py b/nemo_rl/data/multimodal_utils.py index 0da507acc7..918c589ad1 100644 --- a/nemo_rl/data/multimodal_utils.py +++ b/nemo_rl/data/multimodal_utils.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 +from io import BytesIO from typing import Optional, Union +import requests import torch +from PIL import Image from transformers import PreTrainedTokenizerBase @@ -179,3 +183,30 @@ def get_dim_to_pack_along(processor, key: str) -> int: return 1 # return zero by default return 0 + + +def resolve_to_image(image_path_or_image: str | Image.Image) -> Image.Image: + """Resolve the image path to a PIL.Image object. + + image_path can be either: + - path to local file + - url to image + - base64 encoded image + """ + if isinstance(image_path_or_image, Image.Image): + return image_path_or_image + + if image_path_or_image.startswith(("http://", "https://")): + # Handle URL + response = requests.get(image_path_or_image) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert("RGB") + elif image_path_or_image.startswith("data:"): + # Handle base64 encoded image + # Format: data:image/jpeg;base64,/9j/4AAQSkZJRg... 
+ header, encoded = image_path_or_image.split(",", 1) + image_data = base64.b64decode(encoded) + return Image.open(BytesIO(image_data)).convert("RGB") + else: + # Handle local file path + return Image.open(image_path_or_image).convert("RGB") diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 235e77c225..e571db8a7b 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -17,14 +17,16 @@ from typing import Any, Dict, cast import torch -from transformers import PreTrainedTokenizerBase +from transformers import AutoProcessor, PreTrainedTokenizerBase from nemo_rl.data.interfaces import ( DatumSpec, LLMMessageLogType, TaskDataProcessFnCallable, TaskDataSpec, + VLMMessageLogType, ) +from nemo_rl.data.llm_message_utils import get_formatted_message_log TokenizerType = PreTrainedTokenizerBase @@ -132,6 +134,56 @@ def helpsteer3_data_processor( return output +def sft_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False, +) -> DatumSpec: + """Process a datum dictionary for SFT training.""" + # optional preprocessor + if datum_dict["task_name"] == "clevr-cogent": + from nemo_rl.data.datasets.response_datasets.clevr import ( + format_clevr_cogent_dataset, + ) + + datum_dict = format_clevr_cogent_dataset(datum_dict) + + message_log = get_formatted_message_log( + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + add_eos_token=add_eos, + add_generation_prompt=add_generation_prompt, + tools=datum_dict.get("tools", None), # Pass tools from data if present + ) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + return output + + # Example of a generic math data processor def math_data_processor( datum_dict: dict[str, Any], @@ -260,6 +312,151 @@ def math_hf_data_processor( return output +def vlm_hf_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + processor: AutoProcessor, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from response_datasets/.py) into a DatumSpec for the VLM Environment.""" + from nemo_rl.data.datasets.response_datasets.clevr import ( + format_clevr_cogent_dataset, + ) + from nemo_rl.data.datasets.response_datasets.geometry3k import ( + format_geometry3k_dataset, + ) + from nemo_rl.data.datasets.response_datasets.refcoco import format_refcoco_dataset + from nemo_rl.data.multimodal_utils import ( + PackedTensor, + get_dim_to_pack_along, + get_multimodal_keys_from_processor, + resolve_to_image, + ) + + # depending on the task, format the data differently + if datum_dict["task_name"] == "clevr-cogent": + datum_dict = format_clevr_cogent_dataset(datum_dict) + elif datum_dict["task_name"] == "refcoco": + datum_dict = format_refcoco_dataset(datum_dict) + elif datum_dict["task_name"] == "geometry3k": + datum_dict = format_geometry3k_dataset(datum_dict) + else: + raise ValueError(f"No data processor for task {datum_dict['task_name']}") + + user_message = datum_dict["messages"] + problem = 
user_message[0]["content"] + extra_env_info = {"ground_truth": user_message[1]["content"]} + + message_log: VLMMessageLogType = [] + ### only one round of interaction is assumed, this can easily be extended to a conversational setting + user_message: dict[str, Any] = {"role": "user", "content": []} + # + images = [] + if isinstance(problem, list): + for content in problem: + # for image, video, just append it + # for text, format the prompt to the problem + if content["type"] != "text": + user_message["content"].append(content) + if content["type"] == "image": + images.append(content["image"]) + else: + raise ValueError(f"Unsupported content type: {content['type']}") + elif content["type"] == "text": + user_message["content"].append( + { + "type": "text", + "text": task_data_spec.prompt.format(content["text"]) + if task_data_spec.prompt + else content["text"], + } + ) + else: + # conversation consists of a text-only message + user_message["content"] = task_data_spec.prompt.format(problem) + + images = [resolve_to_image(image) for image in images] + + # get formatted user message + if hasattr(processor, "conversation_preprocessor"): + user_message_for_chat_template = processor.conversation_preprocessor( + user_message + ) + else: + user_message_for_chat_template = user_message + + # this is the string-tokenized conversation template for the generation policy (for vllm) + string_formatted_dialog = processor.apply_chat_template( + [user_message_for_chat_template], + tokenize=False, + add_generation_prompt=True, + ) + + # this is the id-tokenized and image processed conversation template for the policy + message: dict = processor.apply_chat_template( + [user_message], + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ) + + # add this for backward compatibility + user_message["token_ids"] = message["input_ids"][0] + # add all keys and values to the user message, and the list of keys + multimodal_keys = get_multimodal_keys_from_processor(processor) + for key in multimodal_keys: + if key in message: + user_message[key] = PackedTensor( + message[key], dim_to_pack=get_dim_to_pack_along(processor, key) + ) + + # specifically for gemma, we need to add token_type_ids to the user message as a sequence-type value + if "token_type_ids" in message: + user_message["token_type_ids"] = message["token_type_ids"][0] + + ### append to user message + message_log.append(user_message) + + length = sum(len(m["token_ids"]) for m in message_log) + loss_multiplier = 1.0 + if length >= max_seq_length: + # Treat truncated messages as text only + vllm_kwargs = { + "vllm_content": None, + "vllm_images": [], + } + + # make smaller and mask out + for chat_message in message_log: + chat_message["token_ids"] = chat_message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + for key, value in chat_message.items(): + if isinstance(value, PackedTensor): + chat_message[key] = PackedTensor.empty_like(value) + loss_multiplier = 0.0 + else: + # get the prompt content! 
(use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation + # add images for vllm serving + vllm_kwargs = { + "vllm_content": string_formatted_dialog, + "vllm_images": images, + } + + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict["task_name"], + **vllm_kwargs, # pyrefly: ignore[bad-unpacking] + } + return output + + def _construct_multichoice_prompt( prompt: str, question: str, options: dict[str, str] ) -> str: @@ -291,7 +488,7 @@ def multichoice_qa_processor( if "subject" in datum_dict: extra_env_info.update({"subject": datum_dict["subject"]}) - message_log = [] + message_log: LLMMessageLogType = [] # system prompt if task_data_spec.system_prompt: @@ -341,6 +538,26 @@ def multichoice_qa_processor( return output +def nemo_gym_data_processor( + datum_dict: dict[str, Any], + *args, + **kwargs, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym.""" + # Ensure message_log exists and contains tensor token_ids so downstream padding works + if "message_log" not in datum_dict or not datum_dict["message_log"]: + datum_dict["message_log"] = [ + {"role": "user", "content": "", "token_ids": torch.tensor([])} + ] + else: + for msg in datum_dict["message_log"]: + if "token_ids" not in msg: + msg["token_ids"] = torch.tensor([]) + elif not isinstance(msg["token_ids"], torch.Tensor): + msg["token_ids"] = torch.tensor(msg["token_ids"]) + return cast(DatumSpec, datum_dict) + + # Processor registry. Key is the processor name, value is the processor function. # Note: We cast the literal dict to Dict[str, TaskDataProcessFnCallable] because # type checkers see each concrete function's signature as a distinct callable type. @@ -351,10 +568,13 @@ def multichoice_qa_processor( Dict[str, TaskDataProcessFnCallable], { "default": math_hf_data_processor, + "helpsteer3_data_processor": helpsteer3_data_processor, + "math_data_processor": math_data_processor, "math_hf_data_processor": math_hf_data_processor, "multichoice_qa_processor": multichoice_qa_processor, - "math_data_processor": math_data_processor, - "helpsteer3_data_processor": helpsteer3_data_processor, + "sft_processor": sft_processor, + "vlm_hf_data_processor": vlm_hf_data_processor, + "nemo_gym_data_processor": nemo_gym_data_processor, }, ) diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index a9e50c67e1..9b4f4d6279 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -43,6 +43,12 @@ class EnvRegistryEntry(TypedDict, total=False): "code_jaccard": { "actor_class_fqn": "nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment", }, + "vlm": { + "actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment", + }, + "nemo_gym": { + "actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym", + }, } @@ -93,7 +99,7 @@ def chunk_list_to_workers(to_chunk: list[Any], num_workers: int) -> list[list[An return chunks -def create_env(env_name: str, env_configs: dict) -> EnvironmentInterface: +def create_env(env_name: str, env_config: dict) -> EnvironmentInterface: assert env_name in ENV_REGISTRY, ( f"Env name {env_name} is not registered in ENV_REGISTRY. Please call register_env() to register the environment." 
) @@ -104,7 +110,7 @@ def create_env(env_name: str, env_configs: dict) -> EnvironmentInterface: "py_executable": get_actor_python_env(actor_class_fqn), "env_vars": dict(os.environ), } - ).remote(env_configs[env_name]) + ).remote(env_config) return env diff --git a/nemo_rl/utils/config.py b/nemo_rl/utils/config.py index 690c8f164c..156a1b9b1c 100644 --- a/nemo_rl/utils/config.py +++ b/nemo_rl/utils/config.py @@ -27,6 +27,23 @@ def resolve_path(base_path: Path, path: str) -> Path: return base_path / path +def merge_with_override( + base_config: DictConfig, override_config: DictConfig +) -> DictConfig: + """Merge configs with support for _override_ marker to completely override sections.""" + for key in list(override_config.keys()): + if isinstance(override_config[key], DictConfig): + if override_config[key].get("_override_", False): + # remove the _override_ marker + override_config[key].pop("_override_") + # remove the key from base_config so it won't be merged + if key in base_config: + base_config.pop(key) + + merged_config = cast(DictConfig, OmegaConf.merge(base_config, override_config)) + return merged_config + + def load_config_with_inheritance( config_path: Union[str, Path], base_dir: Optional[Union[str, Path]] = None, @@ -63,10 +80,12 @@ def load_config_with_inheritance( for default in defaults: parent_path = resolve_path(base_dir, str(default)) parent_config = load_config_with_inheritance(parent_path, base_dir) - base_config = cast(DictConfig, OmegaConf.merge(base_config, parent_config)) + base_config = cast( + DictConfig, merge_with_override(base_config, parent_config) + ) # Merge with current config - config = cast(DictConfig, OmegaConf.merge(base_config, config)) + config = cast(DictConfig, merge_with_override(base_config, config)) return config diff --git a/pyrefly.toml b/pyrefly.toml index 95f8943e42..215b05ea0d 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -38,8 +38,8 @@ project-includes = [ "examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py", "nemo_rl/algorithms/__init__.py", "nemo_rl/algorithms/interfaces.py", - "nemo_rl/algorithms/utils.py", "nemo_rl/algorithms/reward_functions.py", + "nemo_rl/algorithms/utils.py", "nemo_rl/data/__init__.py", "nemo_rl/data/chat_templates.py", "nemo_rl/data/collate_fn.py", @@ -59,13 +59,15 @@ project-includes = [ "nemo_rl/data/datasets/processed_dataset.py", "nemo_rl/data/datasets/raw_dataset.py", "nemo_rl/data/datasets/response_datasets/__init__.py", + "nemo_rl/data/datasets/response_datasets/aime24.py", "nemo_rl/data/datasets/response_datasets/clevr.py", + "nemo_rl/data/datasets/response_datasets/dapo_math.py", "nemo_rl/data/datasets/response_datasets/deepscaler.py", "nemo_rl/data/datasets/response_datasets/geometry3k.py", + "nemo_rl/data/datasets/response_datasets/helpsteer3.py", "nemo_rl/data/datasets/response_datasets/oai_format_dataset.py", "nemo_rl/data/datasets/response_datasets/oasst.py", "nemo_rl/data/datasets/response_datasets/openmathinstruct2.py", - "nemo_rl/data/datasets/response_datasets/helpsteer3.py", "nemo_rl/data/datasets/response_datasets/refcoco.py", "nemo_rl/data/datasets/response_datasets/response_dataset.py", "nemo_rl/data/datasets/response_datasets/squad.py", @@ -82,8 +84,8 @@ project-includes = [ "nemo_rl/distributed/virtual_cluster.py", "nemo_rl/distributed/worker_group_utils.py", "nemo_rl/environments/__init__.py", - "nemo_rl/environments/games/sliding_puzzle.py", "nemo_rl/environments/code_jaccard_environment.py", + "nemo_rl/environments/games/sliding_puzzle.py", 
"nemo_rl/environments/interfaces.py", "nemo_rl/environments/math_environment.py", "nemo_rl/environments/metrics.py", @@ -110,10 +112,10 @@ project-includes = [ "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", "nemo_rl/utils/__init__.py", + "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", "nemo_rl/utils/native_checkpoint.py", - "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/nsys.py", "nemo_rl/utils/nvml.py", "nemo_rl/utils/packed_tensor.py", diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 8b26b5e5e1..1145716a2f 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -31,6 +31,7 @@ time uv run --no-sync bash ./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh +time uv run --no-sync bash ./tests/functional/grpo_multiple_datasets.sh time uv run --no-sync bash ./tests/functional/dpo.sh time uv run --no-sync bash ./tests/functional/rm.sh time uv run --no-sync bash ./tests/functional/eval.sh diff --git a/tests/functional/distillation.sh b/tests/functional/distillation.sh index 19cb71252c..195e3fc3a5 100644 --- a/tests/functional/distillation.sh +++ b/tests/functional/distillation.sh @@ -37,7 +37,9 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE distillation.max_val_samples=16 \ distillation.val_batch_size=8 \ distillation.val_period=3 \ - data.dataset_name=OpenMathInstruct-2 \ + data.train.dataset_name=OpenMathInstruct-2 \ + ++data.train.split_validation_size=0.05 \ + data.validation=null \ loss_fn.zero_outside_topk=true \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ diff --git a/tests/functional/distillation_megatron.sh b/tests/functional/distillation_megatron.sh index b56ea672fb..d40516d939 100644 --- a/tests/functional/distillation_megatron.sh +++ b/tests/functional/distillation_megatron.sh @@ -40,7 +40,9 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE distillation.max_val_samples=16 \ distillation.val_batch_size=8 \ distillation.val_period=3 \ - data.dataset_name=OpenMathInstruct-2 \ + data.train.dataset_name=OpenMathInstruct-2 \ + ++data.train.split_validation_size=0.05 \ + data.validation=null \ loss_fn.zero_outside_topk=false \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ diff --git a/tests/functional/grpo_multiple_datasets.sh b/tests/functional/grpo_multiple_datasets.sh new file mode 100755 index 0000000000..517fe637c0 --- /dev/null +++ b/tests/functional/grpo_multiple_datasets.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py \ + --config $PROJECT_ROOT/examples/configs/grpo_multiple_datasets.yaml \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.val_at_start=true \ + grpo.max_val_samples=4 \ + grpo.val_batch_size=4 \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/gen_kl_error"]) < 0.001' + diff --git a/tests/unit/data/datasets/test_oai_format_dataset.py b/tests/unit/data/datasets/test_oai_format_dataset.py index aad989ed15..ef7b000c59 100644 --- a/tests/unit/data/datasets/test_oai_format_dataset.py +++ b/tests/unit/data/datasets/test_oai_format_dataset.py @@ -16,9 +16,10 @@ import tempfile import pytest -from transformers import AutoTokenizer +from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES +from nemo_rl.data.datasets import load_response_dataset from nemo_rl.data.datasets.response_datasets import OpenAIFormatDataset @@ -27,74 +28,73 @@ def sample_data(request): chat_key = request.param[0] system_key = request.param[1] - train_data = { + data = { chat_key: [ {"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}, ], } - val_data = { - chat_key: [ - {"role": "user", "content": "What is the capital of Germany?"}, - {"role": "assistant", "content": "The capital of Germany is Berlin."}, - ], - } if system_key is not None: - train_data[system_key] = "You are a helpful assistant." - if system_key is not None: - val_data[system_key] = "You are a helpful assistant." + data[system_key] = "You are a helpful assistant." 
-    # Create temporary files for train and validation data
-    with tempfile.NamedTemporaryFile(
-        mode="w", suffix=".json", delete=False
-    ) as train_file:
-        json.dump(train_data, train_file)
-        train_path = train_file.name
+    # Create a temporary file for the sample data
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+        json.dump(data, f)
+        data_path = f.name
+
+    return data_path

-    with tempfile.NamedTemporaryFile(
-        mode="w", suffix=".json", delete=False
-    ) as val_file:
-        json.dump(val_data, val_file)
-        val_path = val_file.name
-
-    return train_path, val_path


+@pytest.fixture(scope="function")
+def tokenizer():
+    """Initialize tokenizer for the test model."""
+    tokenizer = get_tokenizer({"name": "Qwen/Qwen3-0.6B"})
+    return tokenizer
+
+
 @pytest.mark.parametrize("sample_data", [("messages", None)], indirect=True)
 def test_dataset_initialization(sample_data):
-    train_path, val_path = sample_data
-    dataset = OpenAIFormatDataset(train_path, val_path)
+    data_path = sample_data
+    data_config = {
+        "dataset_name": "openai_format",
+        "data_path": data_path,
+    }
+    dataset = load_response_dataset(data_config)

     assert dataset.chat_key == "messages"
-    assert "train" in dataset.formatted_ds
-    assert "validation" in dataset.formatted_ds
+    assert len(dataset.dataset) == 1


 @pytest.mark.parametrize("sample_data", [("conversations", None)], indirect=True)
 def test_custom_keys(sample_data):
-    train_path, val_path = sample_data
-    dataset = OpenAIFormatDataset(
-        train_path,
-        val_path,
-        chat_key="conversations",
-        system_prompt="You are a helpful assistant.",
-    )
+    data_path = sample_data
+    data_config = {
+        "dataset_name": "openai_format",
+        "data_path": data_path,
+        "chat_key": "conversations",
+        "system_prompt": "You are a helpful assistant.",
+    }
+    dataset = load_response_dataset(data_config)

     assert dataset.chat_key == "conversations"
     assert dataset.system_prompt == "You are a helpful assistant."


-@pytest.mark.hf_gated
 @pytest.mark.parametrize("sample_data", [("messages", "system_key")], indirect=True)
-def test_message_formatting(sample_data):
-    train_path, val_path = sample_data
+def test_message_formatting(sample_data, tokenizer):
+    # load the dataset
+    data_path = sample_data
     dataset = OpenAIFormatDataset(
-        train_path, val_path, chat_key="messages", system_key="system_key"
+        data_path,
+        chat_key="messages",
+        system_key="system_key",
     )

-    first_example = dataset.formatted_ds["train"][0]
+    # check the first example
+    first_example = dataset.dataset[0]
+
+    assert "task_name" in first_example
     assert first_example["messages"][0]["role"] == "system"
     assert first_example["messages"][0]["content"] == "You are a helpful assistant."
     assert first_example["messages"][1]["role"] == "user"
@@ -102,9 +102,8 @@ def test_message_formatting(sample_data):
     assert first_example["messages"][2]["role"] == "assistant"
     assert first_example["messages"][2]["content"] == "The capital of France is Paris."
+ # check the combined message chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response - tokenizer = AutoTokenizer.from_pretrained("Meta-Llama/Meta-Llama-3-8B-Instruct") - combined_message = tokenizer.apply_chat_template( first_example["messages"], chat_template=chat_template, diff --git a/tests/unit/data/datasets/test_response_dataset.py b/tests/unit/data/datasets/test_response_dataset.py index 22bc7168fe..23c7923066 100644 --- a/tests/unit/data/datasets/test_response_dataset.py +++ b/tests/unit/data/datasets/test_response_dataset.py @@ -16,100 +16,155 @@ import tempfile import pytest -from transformers import AutoTokenizer +from datasets import Dataset -from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES +from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.datasets import load_response_dataset +from nemo_rl.data.datasets.response_datasets.clevr import format_clevr_cogent_dataset +from nemo_rl.data.datasets.response_datasets.geometry3k import format_geometry3k_dataset -@pytest.fixture -def sample_data(request): - input_key = request.param[0] - output_key = request.param[1] - - train_data = [ +def create_sample_data(input_key, output_key, is_save_to_disk=False): + data = [ {input_key: "Hello", output_key: "Hi there!"}, {input_key: "How are you?", output_key: "I'm good, thanks!"}, ] - val_data = [ - {input_key: "What's up?", output_key: "Not much!"}, - {input_key: "Bye", output_key: "Goodbye!"}, - ] - # Create temporary files for train and validation data - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as train_file: - json.dump(train_data, train_file) - train_path = train_file.name + # Create temporary dataset file + if is_save_to_disk: + data_path = tempfile.mktemp() + dataset = Dataset.from_list(data) + dataset.save_to_disk(data_path) + else: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + data_path = f.name - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as val_file: - json.dump(val_data, val_file) - val_path = val_file.name + return data_path - return train_path, val_path +@pytest.fixture(scope="function") +def tokenizer(): + """Initialize tokenizer for the test model.""" + tokenizer = get_tokenizer({"name": "Qwen/Qwen3-0.6B"}) + return tokenizer -@pytest.mark.parametrize("sample_data", [("input", "output")], indirect=True) -def test_dataset_initialization(sample_data): + +@pytest.mark.parametrize( + "input_key,output_key", [("input", "output"), ("question", "answer")] +) +@pytest.mark.parametrize("is_save_to_disk", [True, False]) +def test_response_dataset(input_key, output_key, is_save_to_disk, tokenizer): # load the dataset - train_path, val_path = sample_data + data_path = create_sample_data(input_key, output_key, is_save_to_disk) data_config = { "dataset_name": "ResponseDataset", - "train_data_path": train_path, - "val_data_path": val_path, + "data_path": data_path, + "input_key": input_key, + "output_key": output_key, } dataset = load_response_dataset(data_config) - assert dataset.input_key == "input" - assert dataset.output_key == "output" - assert "train" in dataset.formatted_ds - assert "validation" in dataset.formatted_ds + # check the input and output keys + assert dataset.input_key == input_key + assert dataset.output_key == output_key + + # check the first example + first_example = dataset.dataset[0] + + # only contains messages and task_name + assert len(first_example.keys()) == 2 + assert "messages" in 
first_example
+    assert "task_name" in first_example
+
+    # check the combined message
+    chat_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
+    combined_message = tokenizer.apply_chat_template(
+        first_example["messages"],
+        chat_template=chat_template,
+        tokenize=False,
+        add_generation_prompt=False,
+        add_special_tokens=False,
+    )
+    assert combined_message == " Question: Hello Answer: Hi there!"


-@pytest.mark.parametrize("sample_data", [("question", "answer")], indirect=True)
-def test_custom_keys(sample_data):
+def test_helpsteer3_dataset():
     # load the dataset
-    train_path, val_path = sample_data
-    data_config = {
-        "dataset_name": "ResponseDataset",
-        "train_data_path": train_path,
-        "val_data_path": val_path,
-        "input_key": "question",
-        "output_key": "answer",
-    }
+    data_config = {"dataset_name": "HelpSteer3"}
     dataset = load_response_dataset(data_config)

-    assert dataset.input_key == "question"
-    assert dataset.output_key == "answer"
+    # check the first example
+    first_example = dataset.dataset[0]
+
+    # only contains context, response, and task_name
+    assert len(first_example.keys()) == 3
+    assert "context" in first_example
+    assert "response" in first_example
+    assert "task_name" in first_example
+
+    # check the content
+    assert len(first_example["context"]) == 7
+    assert first_example["response"][0]["role"] == "assistant"
+    assert first_example["response"][0]["content"][:20] == "Yes, you are correct"


-@pytest.mark.hf_gated
-@pytest.mark.parametrize("sample_data", [("question", "answer")], indirect=True)
-def test_message_formatting(sample_data):
+def test_open_assistant_dataset():
     # load the dataset
-    train_path, val_path = sample_data
     data_config = {
-        "dataset_name": "ResponseDataset",
-        "train_data_path": train_path,
-        "val_data_path": val_path,
-        "input_key": "question",
-        "output_key": "answer",
+        "dataset_name": "open_assistant",
+        "split_validation_size": 0.05,
     }
     dataset = load_response_dataset(data_config)

-    first_example = dataset.formatted_ds["train"][0]
+    # check the first example
+    first_example = dataset.dataset[0]
+    first_val_example = dataset.val_dataset[0]

-    assert first_example["messages"][0]["role"] == "user"
-    assert first_example["messages"][0]["content"] == "Hello"
-    assert first_example["messages"][1]["role"] == "assistant"
-    assert first_example["messages"][1]["content"] == "Hi there!"
+    # only contains messages and task_name
+    assert len(first_example.keys()) == 2
+    assert "messages" in first_example
+    assert "task_name" in first_example
-    chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response
-    tokenizer = AutoTokenizer.from_pretrained("Meta-Llama/Meta-Llama-3-8B-Instruct")
+    # check the content
+    assert first_example["messages"][-1]["content"][:20] == "```\n    def forward("
+    assert len(first_example["messages"]) == 7
+    assert first_val_example["messages"][-1]["content"][:20] == "The colors you shoul"
+    assert len(first_val_example["messages"]) == 5
+
+@pytest.mark.parametrize(
+    "dataset_name",
+    ["DAPOMath17K", "DAPOMathAIME2024", "DeepScaler", "AIME2024", "squad"],
+)
+def test_built_in_dataset(dataset_name, tokenizer):
+    # load the dataset
+    data_config = {"dataset_name": dataset_name}
+    dataset = load_response_dataset(data_config)
+
+    # check the first example
+    first_example = dataset.dataset[0]
+
+    # only contains messages and task_name
+    assert len(first_example.keys()) == 2
+    assert "messages" in first_example
+    assert "task_name" in first_example
+
+    # check the content
+    if dataset_name == "DAPOMath17K":
+        assert first_example["messages"][1]["content"] == "34"
+    elif dataset_name == "DAPOMathAIME2024":
+        assert first_example["messages"][1]["content"] == "540"
+    elif dataset_name == "DeepScaler":
+        assert first_example["messages"][1]["content"] == "-\\frac{2}{3}"
+    elif dataset_name == "AIME2024":
+        assert first_example["messages"][1]["content"] == "204"
+        assert len(dataset.dataset) == 480
+    elif dataset_name == "squad":
+        assert first_example["messages"][2]["content"] == "Saint Bernadette Soubirous"
+
+    # check the combined message
+    chat_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
     combined_message = tokenizer.apply_chat_template(
         first_example["messages"],
         chat_template=chat_template,
@@ -118,122 +173,112 @@
         add_special_tokens=False,
     )

-    assert combined_message == "".join(
-        message["content"] for message in first_example["messages"]
-    )
+    if dataset_name == "squad":
+        assert combined_message == (
+            "Context: "
+            + first_example["messages"][0]["content"]
+            + " Question: "
+            + first_example["messages"][1]["content"]
+            + " Answer: "
+            + first_example["messages"][2]["content"]
+        )
+    else:
+        assert combined_message == (
+            " Question: "
+            + first_example["messages"][0]["content"]
+            + " Answer: "
+            + first_example["messages"][1]["content"]
+        )


-@pytest.mark.hf_gated
-@pytest.mark.skip(reason="dataset download is flaky")
-def test_squad_dataset():
+@pytest.mark.parametrize(
+    "dataset_name,output_key",
+    [
+        ("OpenMathInstruct-2", "expected_answer"),
+        ("OpenMathInstruct-2", "generated_solution"),
+        ("tulu3_sft_mixture", None),
+    ],
+)
+def test_built_in_dataset_with_split_validation(dataset_name, output_key, tokenizer):
     # load the dataset
     data_config = {
-        "dataset_name": "squad",
-        "prompt_file": None,
-        "system_prompt_file": None,
+        "dataset_name": dataset_name,
+        "output_key": output_key,
+        "split_validation_size": 0.05,
     }
-    squad_dataset = load_response_dataset(data_config)
-
-    # load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+    dataset = load_response_dataset(data_config)

-    # check that the dataset is formatted correctly
-    for example in squad_dataset.formatted_ds["train"].take(5):
-        assert "messages" in example
-        assert len(example["messages"]) == 3
+    # check the first example
+    first_example = dataset.dataset[0]
+    first_val_example = dataset.val_dataset[0]
+
+    # only contains messages and task_name
+    assert len(first_example.keys()) == 2
+    assert "messages" in first_example
+    assert "task_name" in first_example
+
+    # check the content
+    if dataset_name == "OpenMathInstruct-2":
+        if output_key == "expected_answer":
+            assert first_example["messages"][1]["content"] == "\\frac{8\\sqrt{3}}{3}"
+        elif output_key == "generated_solution":
+            assert (
+                first_example["messages"][1]["content"][:20] == "Let's denote the poi"
+            )
+    elif dataset_name == "tulu3_sft_mixture":
+        assert first_example["messages"][1]["content"][:20] == "I'm sorry, but I can"
+
+    # check the combined message
+    messages = [first_example["messages"], first_val_example["messages"]]
+    chat_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
+    combined_message = tokenizer.apply_chat_template(
+        messages,
+        chat_template=chat_template,
+        tokenize=False,
+        add_generation_prompt=False,
+        add_special_tokens=False,
+    )

-        assert example["messages"][0]["role"] == "system"
-        assert example["messages"][1]["role"] == "user"
-        assert example["messages"][2]["role"] == "assistant"
+    for i in range(2):
+        assert combined_message[i] == (
+            " Question: "
+            + messages[i][0]["content"]
+            + " Answer: "
+            + messages[i][1]["content"]
+        )

-        template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
-        ## check that applying chat template works as expected
-        default_templated = tokenizer.apply_chat_template(
-            example["messages"],
-            chat_template=template,
-            tokenize=False,
-            add_generation_prompt=False,
-            add_special_tokens=False,
-        )

+@pytest.mark.parametrize(
+    "dataset_name,format_func",
+    [
+        ("clevr-cogent", format_clevr_cogent_dataset),
+        ("geometry3k", format_geometry3k_dataset),
+        # ("refcoco", format_refcoco_dataset),  # this requires downloading ~13.5 GB of images
+    ],
+)
+def test_vlm_dataset(dataset_name, format_func):
+    # load the dataset
+    data_config = {"dataset_name": dataset_name}
+    dataset = load_response_dataset(data_config)

-        assert default_templated == (
-            "Context: "
-            + example["messages"][0]["content"]
-            + " Question: "
-            + example["messages"][1]["content"]
-            + " Answer: "
-            + example["messages"][2]["content"]
-        )
+    # check the first example
+    first_example = dataset.dataset[0]
+    first_example = format_func(first_example)

+    # only contains messages and task_name
+    assert len(first_example.keys()) == 2
+    assert "messages" in first_example
+    assert "task_name" in first_example

-def test_load_dataset_saved_with_save_to_disk():
-    """Test loading a dataset that was saved using HuggingFace's save_to_disk().
-
-    This tests the fix for datasets that already have a 'messages' column,
-    which should be preserved without applying add_messages_key again.
- """ - from datasets import Dataset - - # Create a dataset with 'messages' column already present - train_data = [ - { - "messages": [ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - ] - }, - { - "messages": [ - {"role": "user", "content": "What is the capital of France?"}, - {"role": "assistant", "content": "Paris"}, - ] - }, - ] - val_data = [ - { - "messages": [ - {"role": "user", "content": "What is 3+3?"}, - {"role": "assistant", "content": "6"}, - ] - }, - ] + # check the content + assert first_example["messages"][0]["role"] == "user" + assert first_example["messages"][0]["content"][0]["type"] == "image" + assert first_example["messages"][0]["content"][1]["type"] == "text" + assert first_example["messages"][1]["role"] == "assistant" - with tempfile.TemporaryDirectory() as tmpdir: - # Create HF datasets and save using save_to_disk - train_dataset = Dataset.from_list(train_data) - val_dataset = Dataset.from_list(val_data) - - train_path = f"{tmpdir}/train" - val_path = f"{tmpdir}/val" - - train_dataset.save_to_disk(train_path) - val_dataset.save_to_disk(val_path) - - # Load using load_response_dataset - data_config = { - "dataset_name": "ResponseDataset", - "train_data_path": train_path, - "val_data_path": val_path, - } - dataset = load_response_dataset(data_config) - - # Verify the dataset loaded correctly - assert "train" in dataset.formatted_ds - assert "validation" in dataset.formatted_ds - assert len(dataset.formatted_ds["train"]) == 2 - assert len(dataset.formatted_ds["validation"]) == 1 - - # Verify messages are preserved correctly - first_train_example = dataset.formatted_ds["train"][0] - assert "messages" in first_train_example - assert len(first_train_example["messages"]) == 2 - assert first_train_example["messages"][0]["role"] == "user" - assert first_train_example["messages"][0]["content"] == "What is 2+2?" - assert first_train_example["messages"][1]["role"] == "assistant" - assert first_train_example["messages"][1]["content"] == "4" - - # Verify validation data - first_val_example = dataset.formatted_ds["validation"][0] - assert first_val_example["messages"][0]["content"] == "What is 3+3?" 
- assert first_val_example["messages"][1]["content"] == "6" + if dataset_name == "clevr-cogent": + assert first_example["messages"][1]["content"] == "3" + elif dataset_name == "geometry3k": + assert first_example["messages"][1]["content"] == "3" + elif dataset_name == "refcoco": + assert first_example["messages"][1]["content"] == "[243, 469, 558, 746]" diff --git a/tests/unit/data/test_data_processor.py b/tests/unit/data/test_data_processor.py index 7e2fa903f8..343bbe30bb 100644 --- a/tests/unit/data/test_data_processor.py +++ b/tests/unit/data/test_data_processor.py @@ -146,7 +146,7 @@ def test_math_hf_data_processor(tokenizer_name, dataset_cls): task_data_processors[task_name] = (math_task_spec, math_hf_data_processor) dataset = AllTaskProcessedDataset( - dataset=data.formatted_ds["train"], + dataset=data.dataset, tokenizer=tokenizer, default_task_data_spec=math_task_spec, task_data_processors=task_data_processors, diff --git a/tests/unit/data/test_data_shuffle_reproducity.py b/tests/unit/data/test_data_shuffle_reproducity.py index a918648dc6..4074e0d0fa 100644 --- a/tests/unit/data/test_data_shuffle_reproducity.py +++ b/tests/unit/data/test_data_shuffle_reproducity.py @@ -63,7 +63,7 @@ def create_dataloader( task_data_processors[task_name] = (math_task_spec, math_hf_data_processor) dataset = AllTaskProcessedDataset( - dataset=data.formatted_ds["train"].select(range(1000)), + dataset=data.dataset.select(range(1000)), tokenizer=tokenizer, default_task_data_spec=math_task_spec, task_data_processors=task_data_processors,