diff --git a/configs/_base_/datasets/mmlu_fs.py b/configs/_base_/datasets/mmlu_fs.py
new file mode 100644
index 000000000..2179deb96
--- /dev/null
+++ b/configs/_base_/datasets/mmlu_fs.py
@@ -0,0 +1,39 @@
+from datasets import load_dataset
+from mmchat.datasets import process_hf_dataset
+from mmengine.dataset import DefaultSampler
+
+
+data_root = 'data/mmlu/'
+
+mmlu_fs_dataset = dict(
+    type=load_dataset,
+    path='json',
+    data_files=dict(
+        val=data_root + 'five_shot_mmlu_val.json',
+        test=data_root + 'five_shot_mmlu_test.json'))
+
+val_mmlu_fs = dict(
+    type=process_hf_dataset,
+    dataset=mmlu_fs_dataset,
+    mode='val')
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    dataset=val_mmlu_fs,
+    sampler=dict(type=DefaultSampler, shuffle=False))
+
+test_mmlu_fs = dict(
+    type=process_hf_dataset,
+    dataset=mmlu_fs_dataset,
+    mode='test')
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    dataset=test_mmlu_fs,
+    sampler=dict(type=DefaultSampler, shuffle=False))
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+val_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_fs_val')
+test_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_fs_test')
diff --git a/configs/_base_/datasets/mmlu_zs.py b/configs/_base_/datasets/mmlu_zs.py
new file mode 100644
index 000000000..5be62de31
--- /dev/null
+++ b/configs/_base_/datasets/mmlu_zs.py
@@ -0,0 +1,39 @@
+from datasets import load_dataset
+from mmchat.datasets import process_hf_dataset
+from mmengine.dataset import DefaultSampler
+
+
+data_root = 'data/mmlu/'
+
+mmlu_zs_dataset = dict(
+    type=load_dataset,
+    path='json',
+    data_files=dict(
+        val=data_root + 'zero_shot_mmlu_val.json',
+        test=data_root + 'zero_shot_mmlu_test.json'))
+
+val_mmlu_zs = dict(
+    type=process_hf_dataset,
+    dataset=mmlu_zs_dataset,
+    mode='val')
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    dataset=val_mmlu_zs,
+    sampler=dict(type=DefaultSampler, shuffle=False))
+
+test_mmlu_zs = dict(
+    type=process_hf_dataset,
+    dataset=mmlu_zs_dataset,
+    mode='test')
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    dataset=test_mmlu_zs,
+    sampler=dict(type=DefaultSampler, shuffle=False))
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+val_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_zs_val')
+test_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_zs_test')
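These two base configs load the pre-formatted MMLU JSON splits through `datasets.load_dataset` and hand them to `process_hf_dataset`; the `MMLUMetric` evaluators are declared with `tokenizer=None` so that a model config can fill the tokenizer in later. A minimal sketch of what the `mmlu_fs_dataset` dict resolves to when built (the per-record fields are an assumption here, inferred from the `'input'`/`'output'`/`'subject'` keys consumed elsewhere in this diff):

```python
# Sketch only: what the config's `type=load_dataset` entry builds at runtime.
from datasets import load_dataset

data_root = 'data/mmlu/'
mmlu_fs = load_dataset(
    path='json',
    data_files=dict(
        val=data_root + 'five_shot_mmlu_val.json',
        test=data_root + 'five_shot_mmlu_test.json'))

# Assumed per-record fields: 'input', 'output' ('A'..'D'), 'subject'.
print(mmlu_fs['val'][0])
```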
diff --git a/configs/guanaco/gunaco_llama_7B.py b/configs/guanaco/gunaco_llama_7B.py
index 1719eb95e..f8d299b29 100644
--- a/configs/guanaco/gunaco_llama_7B.py
+++ b/configs/guanaco/gunaco_llama_7B.py
@@ -7,6 +7,7 @@ import torch
 
 with read_base():
     from .._base_.datasets.oasst1 import *
+    from .._base_.datasets.mmlu_fs import *
     from .._base_.schedules.guanaco import *
     from .._base_.default_runtime import *
 
@@ -20,7 +21,7 @@
         use_fast = False,
         padding_side="right",
     ),
-    source_max_len = 16,
+    source_max_len = 2048,
     target_max_len = 512,
     train_on_source = False,
     predict_with_generate = False,
@@ -28,13 +29,13 @@
 llm = dict(
     type=AutoModelForCausalLM.from_pretrained,
     pretrained_model_name_or_path = '/nvme/share_data/llama-7b',
-    torch_dtype = torch.float32,
+    torch_dtype = torch.float16,
     quantization_config=dict(
         type = BitsAndBytesConfig,
         load_in_4bit=True,
         load_in_8bit=False,
         llm_int8_has_fp16_weight=False,
-        bnb_4bit_compute_dtype=torch.float32,
+        bnb_4bit_compute_dtype=torch.float16,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type = 'nf4'
     )
@@ -50,3 +51,14 @@
 )
 
 
+val_evaluator['tokenizer'] = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path='/nvme/share_data/llama-7b',
+    use_fast=False,
+    padding_side="right")
+
+test_evaluator['tokenizer'] = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path='/nvme/share_data/llama-7b',
+    use_fast=False,
+    padding_side="right")
diff --git a/mmchat/datasets/huggingface.py b/mmchat/datasets/huggingface.py
index 4a3466855..b7041b178 100644
--- a/mmchat/datasets/huggingface.py
+++ b/mmchat/datasets/huggingface.py
@@ -33,9 +33,10 @@ def _prompt_format(example):
         dataset = dataset.rename_column(old, new)
 
     # Remove unused columns.
-    dataset = dataset.remove_columns(
-        [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
-    )
+    if 'train' in dataset.column_names:
+        dataset = dataset.remove_columns(
+            [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
+        )
 
     return dataset[mode]
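The Guanaco config now runs in fp16 end to end: the weights are loaded 4-bit (NF4 with double quantization) while `torch_dtype` and `bnb_4bit_compute_dtype` switch from fp32 to fp16, `source_max_len` is raised to 2048 so the five-shot MMLU prompts are not truncated, and the MMLU evaluators receive the same LLaMA tokenizer as the training pipeline. Roughly what the dict-style `llm` entry builds, expressed with plain `transformers` (a sketch, reusing the placeholder model path from the config):

```python
# Sketch of the quantized model described by the config, not the config system itself.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,  # fp16 compute, matching torch_dtype above
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')

llm = AutoModelForCausalLM.from_pretrained(
    '/nvme/share_data/llama-7b',
    torch_dtype=torch.float16,
    quantization_config=bnb_config)
```

The `huggingface.py` guard is needed because the MMLU datasets only ship `val`/`test` splits, so unconditionally indexing `dataset.column_names['train']` would raise a `KeyError`.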
["business"], + "medical_genetics": ["health"], + "miscellaneous": ["other"], + "moral_disputes": ["philosophy"], + "moral_scenarios": ["philosophy"], + "nutrition": ["health"], + "philosophy": ["philosophy"], + "prehistory": ["history"], + "professional_accounting": ["other"], + "professional_law": ["law"], + "professional_medicine": ["health"], + "professional_psychology": ["psychology"], + "public_relations": ["politics"], + "security_studies": ["politics"], + "sociology": ["culture"], + "us_foreign_policy": ["politics"], + "virology": ["health"], + "world_religions": ["philosophy"], + }, + 'categories': { + "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"], + "humanities": ["history", "philosophy", "law"], + "social sciences": ["politics", "culture", "economics", "geography", "psychology"], + "other (business, health, misc.)": ["other", "business", "health"], + }, + } + METAINFO['subcategories_list'] = list(set([subcat for subcats in METAINFO['subcategories'].values() + for subcat in subcats])) + + def __init__(self, tokenizer, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logger: MMLogger = MMLogger.get_current_instance() + tokenizer = TOKENIZER.build(tokenizer) + self.abcd_idx = [ + tokenizer("A", add_special_tokens=False).input_ids[0], + tokenizer("B", add_special_tokens=False).input_ids[0], + tokenizer("C", add_special_tokens=False).input_ids[0], + tokenizer("D", add_special_tokens=False).input_ids[0], + ] + + @staticmethod + def ABCD_to_0123(abcd): + return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd] + + @staticmethod + def accuracy(preds, gts): + """Computes the accuracy for preds and gts""" + correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)] + acc = np.mean(correct) * 100 + return acc + + def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Any): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. + """ + subjects = data_batch['subject'] + gts = [self.ABCD_to_0123(gt) for gt in data_batch['output']] + preds = [] + for sample, subject, gt in zip(data_samples, subjects, gts): + pred_logits = sample['logits'] + labels = sample['labels'] + labels_non_zero_id = (labels != -100).nonzero()[0][0] + pred_logtis_abcd = pred_logits[labels_non_zero_id-1, self.abcd_idx] + pred = torch.argmax(pred_logtis_abcd).item() + preds.append(pred) + self.results.append((subject, pred, gt)) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. 
+ """ + subjects_results = {subject: {'preds': [], 'gts': []} for subject in self.METAINFO['subcategories'].keys()} + subcats_results = {subcat: {'preds': [], 'gts': []} for subcat in self.METAINFO['subcategories_list']} + cats_results = {cat: {'preds': [], 'gts': []} for cat in self.METAINFO['categories'].keys()} + for subject, pred, gt in results: + subjects_results[subject]['preds'].append(pred) + subjects_results[subject]['gts'].append(gt) + subcats = self.METAINFO['subcategories'][subject] + for subcat in subcats: + subcats_results[subcat]['preds'].append(pred) + subcats_results[subcat]['gts'].append(gt) + for cat, subcats in self.METAINFO['categories'].items(): + for subcat in subcats: + if subcat in subcats_results: + cats_results[cat]['preds'].extend(subcats_results[subcat]['preds']) + cats_results[cat]['gts'].extend(subcats_results[subcat]['gts']) + + subjects_metrics = dict() + subcats_metrics = dict() + cats_metrics = dict() + for subject in self.METAINFO['subcategories'].keys(): + assert len(subjects_results[subject]['preds']) == len(subjects_results[subject]['gts']) + if len(subjects_results[subject]['preds']) == 0: + self.logger.info(f'Skip subject {subject} for mmlu') + else: + score = self.accuracy(subjects_results[subject]['preds'], subjects_results[subject]['gts']) + subjects_metrics[f'{subject}'] = score + for subcat in self.METAINFO['subcategories_list']: + assert len(subcats_results[subcat]['preds']) == len(subcats_results[subcat]['gts']) + if len(subcats_results[subcat]['preds']) == 0: + self.logger.info(f'Skip subcategory {subcat} for mmlu') + else: + score = self.accuracy(subcats_results[subcat]['preds'], subcats_results[subcat]['gts']) + subcats_metrics[f'{subcat}'] = score + for cat in self.METAINFO['categories'].keys(): + assert len(cats_results[cat]['preds']) == len(cats_results[cat]['gts']) + if len(cats_results[cat]['preds']) == 0: + self.logger.info(f'Skip category {cat} for mmlu') + else: + score = self.accuracy(cats_results[cat]['preds'], cats_results[cat]['gts']) + cats_metrics[f'{cat}'] = score + + metrics = dict() + metrics.update(subjects_metrics) + metrics.update(subcats_metrics) + metrics.update(cats_metrics) + metrics['average'] = np.mean(list(subjects_metrics.values())) + + table_metrics = dict() + table_metrics.update(cats_metrics) + table_metrics['average'] = np.mean(list(subjects_metrics.values())) + self._print_results(table_metrics) + return metrics + + def _print_results(self, table_metrics: dict) -> None: + table_title = ' MMLU Benchmark ' + table = Table(title=table_title) + console = Console() + table.add_column('Categories', justify='left') + table.add_column('Accuracy (%)', justify='right') + for cat, acc in table_metrics.items(): + table.add_row(cat, '{:.1f}'.format(acc)) + with console.capture() as capture: + console.print(table, end='') + self.logger.info('\n' + capture.get()) diff --git a/mmchat/evaluation/mmlu.py b/mmchat/evaluation/mmlu.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/mmchat/models/algorithms/sft.py b/mmchat/models/algorithms/sft.py index 202b7b67d..0a2bcf458 100644 --- a/mmchat/models/algorithms/sft.py +++ b/mmchat/models/algorithms/sft.py @@ -57,11 +57,12 @@ def __init__(self, llm, data_preprocessor): self.llm = self._build_from_cfg_or_module(llm, LLM) self.llm.config.use_cache = False self.llm.config.torch_dtype = torch.float32 - smart_tokenizer_and_embedding_resize( - special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), - tokenizer=self.tokenizer, - model=self.llm, - ) + if 
diff --git a/mmchat/models/utils/data_processor.py b/mmchat/models/utils/data_processor.py
index f305c9d7e..674e6e7cd 100644
--- a/mmchat/models/utils/data_processor.py
+++ b/mmchat/models/utils/data_processor.py
@@ -70,5 +70,5 @@ def forward(self,instances: Sequence[Dict], training=True) -> Dict[str, torch.Te
         if labels is not None:
             data_dict['labels'] = labels
 
-        return {'data': data_dict, 'data_samples': None}
+        return self.cast_data({'data': data_dict, 'data_samples': None})
 
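Wrapping the collated batch in `self.cast_data(...)` lets the data preprocessor move every tensor in the nested dict to its configured device before the batch reaches the model, which the test loop relies on. A minimal sketch of the pattern (toy fields, not the project's collator):

```python
# Minimal sketch: cast_data recursively moves tensors to the preprocessor's device.
import torch
from mmengine.model import BaseDataPreprocessor


class ToyPreprocessor(BaseDataPreprocessor):

    def forward(self, data, training=False):
        data_dict = {'input_ids': torch.tensor([[1, 2, 3]]),
                     'labels': torch.tensor([[-100, 2, 3]])}
        return self.cast_data({'data': data_dict, 'data_samples': None})
```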
diff --git a/tools/test.py b/tools/test.py
new file mode 100644
index 000000000..1a4807f3f
--- /dev/null
+++ b/tools/test.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import os.path as osp
+import warnings
+from copy import deepcopy
+
+from mmengine import ConfigDict
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
+
+from mmchat.registry import RUNNERS
+
+
+# TODO: support fuse_conv_bn and format_only
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='MMChat test a model')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--work-dir',
+        help='the directory to save the file containing evaluation metrics')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+    return args
+
+
+def main():
+    args = parse_args()
+
+    # load config
+    cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+
+    cfg.load_from = args.checkpoint
+
+    # build the runner from config
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)
+
+    # start testing
+    runner.test()
+
+
+if __name__ == '__main__':
+    main()
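With the evaluators wired into the config, the script takes a config and a checkpoint as positional arguments, e.g. `python tools/test.py configs/guanaco/gunaco_llama_7B.py ${CHECKPOINT}` (checkpoint path left as a placeholder). A programmatic equivalent of the script's core flow for a single-process run (a sketch; the work dir and checkpoint path are placeholders):

```python
# Sketch of what tools/test.py does, without argparse.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/guanaco/gunaco_llama_7B.py')
cfg.launcher = 'none'
cfg.work_dir = './work_dirs/gunaco_llama_7B'   # placeholder work dir
cfg.load_from = 'path/to/checkpoint.pth'       # placeholder checkpoint
runner = Runner.from_cfg(cfg)
runner.test()                                  # runs the TestLoop with MMLUMetric
```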