diff --git a/tests/datasets/test_qwen3_vl_tokenize_fn.py b/tests/datasets/test_qwen3_vl_tokenize_fn.py
index 6d9ca7a0e..b8127c3bd 100644
--- a/tests/datasets/test_qwen3_vl_tokenize_fn.py
+++ b/tests/datasets/test_qwen3_vl_tokenize_fn.py
@@ -14,11 +14,16 @@ class TestMLLMTokenizeFn(TestCase):
 
     def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained(QWEN3_VL_PATH)
-        self.tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH, rand_video_max_frames=14).build(
+        self.tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
+                                                   rand_video_max_frames=14,
+                                                   add_vision_id=False).build(
             self.tokenizer)
         self.processor = AutoProcessor.from_pretrained(QWEN3_VL_PATH)
 
-    def test_qwen3_vl_sft_single_image(self):
+    @parametrize.parametrize("add_vision_id", [(True,), (False,)])
+    def test_qwen3_vl_sft_single_image(self, add_vision_id):
+        tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
+                                              add_vision_id=add_vision_id).build(self.tokenizer)
         data_path = 'tests/resource/mllm_sft_single_image_example_data.jsonl'
         total_step = 5
         with open(data_path) as f:
@@ -27,7 +32,7 @@ def test_qwen3_vl_sft_single_image(self):
                     break
                 raw_data = json.loads(line)
 
-                ret = self.tokenize_fn(raw_data, media_root='tests/')
+                ret = tokenize_fn(raw_data, media_root='tests/')
                 input_ids_xtuner = ret['input_ids']
                 pixel_values_xtuner: torch.Tensor = ret['pixel_values']
                 image_grid_thw_xtuner: torch.Tensor = ret['image_grid_thw']
@@ -44,7 +49,10 @@ def test_qwen3_vl_sft_single_image(self):
                     if not isinstance(msg['content'], list):
                         msg['content'] = [{"type": "text", "text": msg['content']}]
 
-                ret = self.processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=True,
+                ret = self.processor.apply_chat_template(messages,
+                                                         add_generation_prompt=False,
+                                                         tokenize=True,
+                                                         add_vision_id=add_vision_id,
                                                          return_dict=True)
                 input_ids_hf = ret['input_ids'][0]
                 pixel_values_hf = ret['pixel_values']
@@ -58,16 +66,20 @@ def test_qwen3_vl_sft_multi_image(self, add_vision_id):
         tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
                                               add_vision_id=add_vision_id).build(self.tokenizer)
         data_path = 'tests/resource/mllm_sft_multi_image_example_data.jsonl'
-        total_step = 5
+        total_index = [0, 1, 2, 3, 4, 10]
         with open(data_path) as f:
             for i, line in enumerate(f):
-                if i >= total_step:
-                    break
+                if i not in total_index:
+                    continue
                 raw_data = json.loads(line)
 
                 # The \n must be removed, otherwise it cannot be aligned with hf
                 messages = raw_data['messages']
-                messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('\n', '')
+                if i != 10:
+                    messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('\n', '')
+                else:
+                    messages[0]['content'][1]['text'] = messages[0]['content'][1]['text'].replace('\n', '')
+                    messages[4]['content'][1]['text'] = messages[4]['content'][1]['text'].replace('\n', '')
 
                 ret = tokenize_fn(raw_data, media_root='tests/')
                 input_ids_xtuner = ret['input_ids']
@@ -76,13 +88,25 @@ def test_qwen3_vl_sft_multi_image(self, add_vision_id):
 
                 # to hf openai format
                 messages = raw_data['messages']
-                messages[0]['content'][0]['type'] = 'image'
-                messages[0]['content'][0]['path'] = 'tests/' + messages[0]['content'][0]['image_url']['url']
-                messages[0]['content'][1]['type'] = 'image'
-                messages[0]['content'][1]['path'] = 'tests/' + messages[0]['content'][1]['image_url']['url']
-                del messages[0]['content'][0]['image_url']
-                del messages[0]['content'][1]['image_url']
-                messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('<image>', '')
+                if i != 10:
+                    messages[0]['content'][0]['type'] = 'image'
+                    messages[0]['content'][0]['path'] = 'tests/' + messages[0]['content'][0]['image_url']['url']
+                    messages[0]['content'][1]['type'] = 'image'
+                    messages[0]['content'][1]['path'] = 'tests/' + messages[0]['content'][1]['image_url']['url']
+                    del messages[0]['content'][0]['image_url']
+                    del messages[0]['content'][1]['image_url']
+                    messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('<image>', '')
+                else:
+                    messages[0]['content'][0]['type'] = 'image'
+                    messages[0]['content'][0]['path'] = 'tests/' + messages[0]['content'][0]['image_url']['url']
+                    del messages[0]['content'][0]['image_url']
+                    messages[0]['content'][1]['text'] = messages[0]['content'][1]['text'].replace('<image>', '')
+
+                    messages[4]['content'][0]['type'] = 'image'
+                    messages[4]['content'][0]['path'] = 'tests/' + messages[4]['content'][0]['image_url']['url']
+                    del messages[4]['content'][0]['image_url']
+                    messages[4]['content'][1]['text'] = messages[4]['content'][1]['text'].replace('<image>', '')
+
                 for msg in messages:
                     if not isinstance(msg['content'], list):
                         msg['content'] = [{"type": "text", "text": msg['content']}]
@@ -92,6 +116,7 @@ def test_qwen3_vl_sft_multi_image(self, add_vision_id):
                 input_ids_hf = ret['input_ids'][0]
                 pixel_values_hf = ret['pixel_values']
                 image_grid_thw_hf = ret['image_grid_thw']
+
                 self.assertEqual(input_ids_xtuner, input_ids_hf)
                 self.assertTrue(torch.allclose(pixel_values_xtuner, pixel_values_hf))
                 self.assertTrue(torch.allclose(image_grid_thw_xtuner, image_grid_thw_hf))
@@ -144,38 +169,40 @@ def test_calc_frame_info(self):
                 # case: if origin_fps exists, timestamps are computed based on origin_fps
                 self.assertEqual(num_frames_list, [20, 4])
                 self.assertEqual(origin_fps_list, [10, 8])
-                self.assertEqual(timestamps_list, [[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
-                                                   [0.25, 1.125]])
+                self.assertEqual(timestamps_list,
+                                 [[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
+                                  [0.25, 1.125]])
             elif i == 2:
                 # case: test that origin_fps == 1 with a length less than 4 works correctly
                 self.assertEqual(num_frames_list, [20, 4])
                 self.assertEqual(origin_fps_list, [10, 1])
-                self.assertEqual(timestamps_list, [[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
-                                                   [0.0, 0.0]])
+                self.assertEqual(timestamps_list,
+                                 [[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
+                                  [0.0, 0.0]])
             elif i == 3:
                 # case: processed_fps exists; one video is evenly divisible by fps, the other is not and is longer than rand_video_max_frames
                 self.assertEqual(num_frames_list, [10, 14])
                 self.assertEqual(origin_fps_list, [20, 10])
                 self.assertEqual(timestamps_list, [[0.25, 1.35, 2.45, 3.55, 4.65],
-                                 [0.3, 1.3, 2.4000000000000004, 3.5, 4.6, 5.7, 6.7]])
+                                                   [0.3, 1.3, 2.4000000000000004, 3.5, 4.6, 5.7, 6.7]])
             elif i == 4:
                 # case: processed_fps exists; one video is evenly divisible by fps, the other is not and is shorter than rand_video_max_frames
                 self.assertEqual(num_frames_list, [10, 12])
                 self.assertEqual(origin_fps_list, [20, 10])
                 self.assertEqual(timestamps_list, [[0.25, 1.35, 2.45, 3.55, 4.65],
-                                 [0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
+                                                   [0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
             elif i == 5:
                 # case: frames_timestamp exists; one video is evenly divisible by fps, the other is not and is shorter than rand_video_max_frames
                 self.assertEqual(num_frames_list, [4, 14])
                 self.assertEqual(origin_fps_list, [20, 10])
                 self.assertEqual(timestamps_list, [[0.25, 1.5],
-                                 [0.1, 0.5, 1.1, 1.5, 1.9, 2.5, 2.9]])
+                                                   [0.1, 0.5, 1.1, 1.5, 1.9, 2.5, 2.9]])
             elif i == 6:
                 # case: frames_timestamp exists; one video is evenly divisible by fps, the other is not and is shorter than rand_video_max_frames
                 self.assertEqual(num_frames_list, [4, 12])
                 self.assertEqual(origin_fps_list, [20, 10])
                 self.assertEqual(timestamps_list, [[0.25, 1.5],
-                                 [0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
+                                                   [0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
             elif i == 7:
                 # case: test a single video
                 self.assertEqual(num_frames_list, [4])
@@ -194,7 +221,7 @@ def test_qwen3_vl_sft_video(self, add_vision_id):
             for line in f:
                 hf_raw_datas.append(json.loads(line))
-        total_index = [1,4,5,6,7,8,9]
+        total_index = [1, 4, 5, 6, 7, 8, 9]
         with open(data_path) as f:
             for i, line in enumerate(f):
                 if i not in total_index:
                     continue
@@ -221,16 +248,10 @@ def test_qwen3_vl_sft_video(self, add_vision_id):
                 messages = hf_raw_data['messages']
                 add_video_root(messages, VIDEO_ROOT)
 
-                # add_vision_id does not take effect when there is only one video
-                if len(tokenize_fn._video_path) <= 1:
-                    add_vision_id_ = False
-                else:
-                    add_vision_id_ = add_vision_id
-
                 if i not in [8, 9]:
                     ret = self.processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=True,
                                                              do_sample_frames=do_sample_frames,
-                                                             return_dict=True, add_vision_id=add_vision_id_,
+                                                             return_dict=True, add_vision_id=add_vision_id,
                                                              return_tensors="pt")
                     input_ids_hf = ret['input_ids'][0]
                     pixel_values_hf = ret['pixel_values_videos']
@@ -277,7 +298,10 @@ def test_qwen3_vl_pretrain_pure_text(self):
                 input_ids_hf = self.tokenizer(content)['input_ids']
                 self.assertEqual(input_ids_xtuner, input_ids_hf)
 
-    def test_qwen3_vl_pretrain_image(self):
+    @parametrize.parametrize("add_vision_id", [(True,), (False,)])
+    def test_qwen3_vl_pretrain_image(self, add_vision_id):
+        tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
+                                              add_vision_id=add_vision_id).build(self.tokenizer)
         data_path = 'tests/resource/mllm_pretrain_image_example_data.jsonl'
         total_step = 6
         with open(data_path, encoding='utf-8') as f:
@@ -286,10 +310,10 @@ def test_qwen3_vl_pretrain_image(self):
                     break
                 raw_data = json.loads(line)
 
-                ret = self.tokenize_fn(raw_data, media_root='tests/')
+                ret = tokenize_fn(raw_data, media_root='tests/')
                 input_ids_xtuner = ret['input_ids']
                 labels_xtuner = torch.tensor(ret['labels'])
-                input_str = self.tokenize_fn.tokenizer.decode(input_ids_xtuner, skip_special_tokens=False)
+                input_str = tokenize_fn.tokenizer.decode(input_ids_xtuner, skip_special_tokens=False)
                 input_str = input_str.replace('<|image_pad|>', '')
                 input_xtuner_str = input_str.replace('<|vision_start|><|vision_end|>', '')
                 ground_truth_content = raw_data['messages'][0]
@@ -297,7 +321,7 @@ def test_qwen3_vl_pretrain_image(self):
                     if item['type'] == 'text':
                         ground_truth_str = item['text'] + "<|im_end|>"
                         image_cnt = ground_truth_str.count('<image>')
-                        if image_cnt > 1:
+                        if add_vision_id:
                             for i in range(image_cnt):
                                 ground_truth_str = ground_truth_str.replace('<image>', f'Picture {i + 1}: ', 1)
 
diff --git a/tests/resource/mllm_sft_multi_image_example_data.jsonl b/tests/resource/mllm_sft_multi_image_example_data.jsonl
index fe6e0f0e4..1e8b17c94 100644
--- a/tests/resource/mllm_sft_multi_image_example_data.jsonl
+++ b/tests/resource/mllm_sft_multi_image_example_data.jsonl
@@ -7,4 +7,5 @@
 {"id": 7, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<image>\n<image>\nWhat are the similarities between the two images?"}]}, {"role": "assistant", "content": "Both images contain animals."}, {"role": "user", "content": "What animals are there?"}, {"role": "assistant", "content": "The first image contains two cats, and the second image contains a dog."}]}
 {"id": 8, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<image>\n<image>\nCan you describe the color of the dog in the second image?"}]}, {"role": "assistant", "content": "The dog in the image is brown."}]}
 {"id": 9, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<image>\n<image>\nHow many cats are in the first image?"}]}, {"role": "assistant", "content": "There are 2 cats in the image."}, {"role": "user", "content": "What else is in the first image?"}, {"role": "assistant", "content": "There are also 2 TV remotes in the first image."}, {"role": "user", "content": "What are the two cats doing?"}, {"role": "assistant", "content": "They are leisurely lying on the sofa."}, {"role": "user", "content": "Can you describe the first image?"}, {"role": "assistant", "content": "The image shows two cats leisurely lying on the sofa, with 2 TV remotes next to them."}]}
-{"id": 10, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<image>\n<image>\nCan you describe the type of collar the dog in the second image has?"}]}, {"role": "assistant", "content": "The dog has a red collar."}]}
\ No newline at end of file
+{"id": 10, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<image>\n<image>\nCan you describe the type of collar the dog in the second image has?"}]}, {"role": "assistant", "content": "The dog has a red collar."}]}
+{"id": 11, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "text", "text": "<image>\nHow many cats are in the first image?"}]}, {"role": "assistant", "content": "There are 2 cats in the image."}, {"role": "user", "content": "What else is in the first image?"}, {"role": "assistant", "content": "There are also 2 TV remotes in the first image."}, {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}},{"type": "text", "text": "<image>\nWhat are the two cats doing?"}]}, {"role": "assistant", "content": "They are leisurely lying on the sofa."}, {"role": "user", "content": "Can you describe the first image?"}, {"role": "assistant", "content": "The image shows two cats leisurely lying on the sofa, with 2 TV remotes next to them."}]}
\ No newline at end of file
diff --git a/xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py b/xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py
index 09116515c..f4786f183 100644
--- a/xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py
+++ b/xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py
@@ -91,6 +91,7 @@ def replace_image_token(
     add_vision_id: bool = False,
 ):
     current_image_idx = 0
+    total_img_cnt = 0
     for msg in messages.messages:
         if msg.role == "pretrain":
             assert len(messages.messages) == 1, "pretrain message should only have one message"
@@ -104,9 +105,9 @@ def replace_image_token(
                 image_cnt = text.count(IMAGE_TOKEN_ALIAS)
                 for i in range(image_cnt):
                     image_tokens = f"{chat_template.image_start_token}{chat_template.image_context_token * num_image_token_list[current_image_idx]}{chat_template.image_end_token}"  # type: ignore
-                    if add_vision_id and image_cnt > 1:
-                        # add vision id for each image when there are multiple images
-                        image_tokens = f"Picture {i + 1}: " + image_tokens
+                    if add_vision_id:
+                        image_tokens = f"Picture {total_img_cnt + 1}: " + image_tokens
+                        total_img_cnt += 1
                     text = text.replace(IMAGE_TOKEN_ALIAS, image_tokens, 1)
                     current_image_idx += 1
             c.text = text
diff --git a/xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py b/xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py
index c2b837906..ea909d74e 100644
--- a/xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py
+++ b/xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py
@@ -135,6 +135,7 @@ def replace_video_token(
     add_vision_id: bool = False,
 ):
     current_image_idx = 0
+    total_video_cnt = 0
     n_video = len(num_image_token_list)
     n_image = sum([len(num_image_token_list[i]) for i in range(n_video)])
     if len(timestamps_list) > 0:
@@ -166,12 +167,13 @@ def replace_video_token(
                         timestamps = f"<{start_time:.1f}-{end_time:.1f} seconds>"
                         text = text.replace(IMAGE_TOKEN_ALIAS, f"{timestamps}<video>", 1)
 
-                if add_vision_id and video_cnt > 1:
+                if add_vision_id:
                     # label each video
                     text = text.replace("<video>", IMAGE_TOKEN_ALIAS)
                     for i in range(video_cnt):
-                        video_index = f"Video {i + 1}: "
+                        video_index = f"Video {total_video_cnt + 1}: "
                         text = text.replace(IMAGE_TOKEN_ALIAS, f"{video_index}<video>", 1)
+                        total_video_cnt += 1
 
                 text = text.replace("<video>", IMAGE_TOKEN_ALIAS)
                 video_cnt = text.count(IMAGE_TOKEN_ALIAS)
@@ -885,8 +887,6 @@ class Qwen3VLTokenizeFnConfig(BaseMLLMTokenizeFnConfig):
 
     # When handling multiple images or multiple videos,
    # it's helpful to add labels to the images and videos for better reference.
-    # Note that this logic is not fully consistent with the official hf implementation: once this flag is enabled, hf always appends the label, regardless of whether there is a single image or a single video.
-    # xtuner optimized this: the label is appended only when the flag is enabled and there are multiple images or multiple videos.
     add_vision_id: bool = True
 
     def build(
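The behavioral core of this patch is easy to state outside the tokenize-fn plumbing: with `add_vision_id` enabled, every image (or video) now gets a `Picture N: ` (or `Video N: `) prefix, numbered by a single counter that runs across the whole conversation (`total_img_cnt` / `total_video_cnt`), instead of the old per-message `i + 1` numbering that was also skipped whenever a message contained only one image. The new id-11 JSONL sample, where the two images arrive in different user turns, exercises exactly this cross-message numbering. Below is a minimal standalone sketch of the rule, not the xtuner implementation; `ALIAS` and the expansion string are simplified stand-ins for `IMAGE_TOKEN_ALIAS` and the chat-template tokens:

```python
ALIAS = "<IMAGE_ALIAS>"  # stand-in for IMAGE_TOKEN_ALIAS
EXPANDED = "<|vision_start|><|image_pad|><|vision_end|>"  # stand-in expansion


def number_images(texts: list[str], add_vision_id: bool) -> list[str]:
    """Expand each ALIAS, optionally prefixing a 'Picture N: ' label
    numbered across ALL texts (the new behavior in this patch)."""
    total_img_cnt = 0  # one counter shared by every message, never reset
    out = []
    for text in texts:
        for _ in range(text.count(ALIAS)):
            image_tokens = EXPANDED
            if add_vision_id:
                # Old behavior labeled only when a message held >1 images and
                # restarted the index per message; the new behavior always
                # labels and keeps counting, matching HF apply_chat_template.
                image_tokens = f"Picture {total_img_cnt + 1}: " + image_tokens
                total_img_cnt += 1
            # Replace one occurrence at a time so each image gets its own label.
            text = text.replace(ALIAS, image_tokens, 1)
        out.append(text)
    return out


# Two user turns with one image each: labels continue as Picture 1 / Picture 2.
print(number_images([f"{ALIAS}\nfirst turn", f"{ALIAS}\nsecond turn"], add_vision_id=True))
```

Because the counter now matches HF's `add_vision_id=True` semantics even for a single image or video, the test-side shim that forced `add_vision_id_ = False` for single-video samples can be deleted, and `apply_chat_template` is called with the parametrized flag directly.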