94 changes: 59 additions & 35 deletions tests/datasets/test_qwen3_vl_tokenize_fn.py
@@ -14,11 +14,16 @@
class TestMLLMTokenizeFn(TestCase):
def setUp(self):
self.tokenizer = AutoTokenizer.from_pretrained(QWEN3_VL_PATH)
self.tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH, rand_video_max_frames=14).build(
self.tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
rand_video_max_frames=14,
add_vision_id=False).build(
self.tokenizer)
self.processor = AutoProcessor.from_pretrained(QWEN3_VL_PATH)

def test_qwen3_vl_sft_single_image(self):
@parametrize.parametrize("add_vision_id", [(True,), (False,)])
def test_qwen3_vl_sft_single_image(self, add_vision_id):
tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
add_vision_id=add_vision_id).build(self.tokenizer)
data_path = 'tests/resource/mllm_sft_single_image_example_data.jsonl'
total_step = 5
with open(data_path) as f:
@@ -27,7 +32,7 @@ def test_qwen3_vl_sft_single_image(self):
break
raw_data = json.loads(line)

ret = self.tokenize_fn(raw_data, media_root='tests/')
ret = tokenize_fn(raw_data, media_root='tests/')
input_ids_xtuner = ret['input_ids']
pixel_values_xtuner: torch.Tensor = ret['pixel_values']
image_grid_thw_xtuner: torch.Tensor = ret['image_grid_thw']
@@ -44,7 +49,10 @@
if not isinstance(msg['content'], list):
msg['content'] = [{"type": "text", "text": msg['content']}]

ret = self.processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=True,
ret = self.processor.apply_chat_template(messages,
add_generation_prompt=False,
tokenize=True,
add_vision_id=add_vision_id,
return_dict=True)
input_ids_hf = ret['input_ids'][0]
pixel_values_hf = ret['pixel_values']
@@ -58,16 +66,20 @@ def test_qwen3_vl_sft_multi_image(self, add_vision_id):
tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
add_vision_id=add_vision_id).build(self.tokenizer)
data_path = 'tests/resource/mllm_sft_multi_image_example_data.jsonl'
total_step = 5
total_index = [0, 1, 2, 3, 4, 10]
with open(data_path) as f:
for i, line in enumerate(f):
if i >= total_step:
break
if i not in total_index:
continue
raw_data = json.loads(line)

# The '\n' must be stripped, otherwise the result cannot be aligned with hf
messages = raw_data['messages']
messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('\n', '')
if i != 10:
messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('\n', '')
else:
messages[0]['content'][1]['text'] = messages[0]['content'][1]['text'].replace('\n', '')
messages[4]['content'][1]['text'] = messages[4]['content'][1]['text'].replace('\n', '')

ret = tokenize_fn(raw_data, media_root='tests/')
input_ids_xtuner = ret['input_ids']
@@ -76,13 +88,25 @@ def test_qwen3_vl_sft_multi_image(self, add_vision_id):

# to hf openai format
messages = raw_data['messages']
messages[0]['content'][0]['type'] = 'image'
messages[0]['content'][0]['path'] = 'tests/' + messages[0]['content'][0]['image_url']['url']
messages[0]['content'][1]['type'] = 'image'
messages[0]['content'][1]['path'] = 'tests/' + messages[0]['content'][1]['image_url']['url']
del messages[0]['content'][0]['image_url']
del messages[0]['content'][1]['image_url']
messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('<IMG_CONTEXT>', '')
if i != 10:
messages[0]['content'][0]['type'] = 'image'
messages[0]['content'][0]['path'] = 'tests/' + messages[0]['content'][0]['image_url']['url']
messages[0]['content'][1]['type'] = 'image'
messages[0]['content'][1]['path'] = 'tests/' + messages[0]['content'][1]['image_url']['url']
del messages[0]['content'][0]['image_url']
del messages[0]['content'][1]['image_url']
messages[0]['content'][2]['text'] = messages[0]['content'][2]['text'].replace('<IMG_CONTEXT>', '')
else:
messages[0]['content'][0]['type'] = 'image'
messages[0]['content'][0]['path'] = 'tests/' + messages[0]['content'][0]['image_url']['url']
del messages[0]['content'][0]['image_url']
messages[0]['content'][1]['text'] = messages[0]['content'][1]['text'].replace('<IMG_CONTEXT>', '')

messages[4]['content'][0]['type'] = 'image'
messages[4]['content'][0]['path'] = 'tests/' + messages[4]['content'][0]['image_url']['url']
del messages[4]['content'][0]['image_url']
messages[4]['content'][1]['text'] = messages[4]['content'][1]['text'].replace('<IMG_CONTEXT>', '')

for msg in messages:
if not isinstance(msg['content'], list):
msg['content'] = [{"type": "text", "text": msg['content']}]
@@ -92,6 +116,7 @@ def test_qwen3_vl_sft_multi_image(self, add_vision_id):
input_ids_hf = ret['input_ids'][0]
pixel_values_hf = ret['pixel_values']
image_grid_thw_hf = ret['image_grid_thw']

self.assertEqual(input_ids_xtuner, input_ids_hf)
self.assertTrue(torch.allclose(pixel_values_xtuner, pixel_values_hf))
self.assertTrue(torch.allclose(image_grid_thw_xtuner, image_grid_thw_hf))
@@ -144,38 +169,40 @@ def test_calc_frame_info(self):
# case: if origin_fps is present, timestamps are computed based on origin_fps
self.assertEqual(num_frames_list, [20, 4])
self.assertEqual(origin_fps_list, [10, 8])
self.assertEqual(timestamps_list, [[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
[0.25, 1.125]])
self.assertEqual(timestamps_list,
[[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
[0.25, 1.125]])
elif i == 2:
# case: verify behavior when origin_fps is 1 and the length is less than 4
self.assertEqual(num_frames_list, [20, 4])
self.assertEqual(origin_fps_list, [10, 1])
self.assertEqual(timestamps_list, [[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
[0.0, 0.0]])
self.assertEqual(timestamps_list,
[[0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.600000000000001, 9.65],
[0.0, 0.0]])
elif i == 3:
# case: processed_fps is present; one video divides evenly by fps, one does not, and the video length exceeds rand_video_max_frames
self.assertEqual(num_frames_list, [10, 14])
self.assertEqual(origin_fps_list, [20, 10])
self.assertEqual(timestamps_list, [[0.25, 1.35, 2.45, 3.55, 4.65],
[0.3, 1.3, 2.4000000000000004, 3.5, 4.6, 5.7, 6.7]])
[0.3, 1.3, 2.4000000000000004, 3.5, 4.6, 5.7, 6.7]])
elif i == 4:
# case: processed_fps is present; one video divides evenly by fps, one does not, and the video length is less than rand_video_max_frames
self.assertEqual(num_frames_list, [10, 12])
self.assertEqual(origin_fps_list, [20, 10])
self.assertEqual(timestamps_list, [[0.25, 1.35, 2.45, 3.55, 4.65],
[0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
[0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
elif i == 5:
# case: frames_timestamp is present; one video divides evenly by fps, one does not, and the video length exceeds rand_video_max_frames
self.assertEqual(num_frames_list, [4, 14])
self.assertEqual(origin_fps_list, [20, 10])
self.assertEqual(timestamps_list, [[0.25, 1.5],
[0.1, 0.5, 1.1, 1.5, 1.9, 2.5, 2.9]])
[0.1, 0.5, 1.1, 1.5, 1.9, 2.5, 2.9]])
elif i == 6:
# case: frames_timestamp is present; one video divides evenly by fps, one does not, and the video length is less than rand_video_max_frames
self.assertEqual(num_frames_list, [4, 12])
self.assertEqual(origin_fps_list, [20, 10])
self.assertEqual(timestamps_list, [[0.25, 1.5],
[0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
[0.1, 0.5, 0.9, 1.2999999999999998, 1.7000000000000002, 2.1]])
elif i == 7:
# case: single video
self.assertEqual(num_frames_list, [4])
@@ -194,7 +221,7 @@ def test_qwen3_vl_sft_video(self, add_vision_id):
for line in f:
hf_raw_datas.append(json.loads(line))

total_index = [1,4,5,6,7,8,9]
total_index = [1, 4, 5, 6, 7, 8, 9]
with open(data_path) as f:
for i, line in enumerate(f):
if i not in total_index:
@@ -221,16 +248,10 @@ def test_qwen3_vl_sft_video(self, add_vision_id):
messages = hf_raw_data['messages']
add_video_root(messages, VIDEO_ROOT)

# If there is only one video, add_vision_id has no effect
if len(tokenize_fn._video_path) <= 1:
add_vision_id_ = False
else:
add_vision_id_ = add_vision_id

if i not in [8, 9]:
ret = self.processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=True,
do_sample_frames=do_sample_frames,
return_dict=True, add_vision_id=add_vision_id_,
return_dict=True, add_vision_id=add_vision_id,
return_tensors="pt")
input_ids_hf = ret['input_ids'][0]
pixel_values_hf = ret['pixel_values_videos']
@@ -277,7 +298,10 @@ def test_qwen3_vl_pretrain_pure_text(self):
input_ids_hf = self.tokenizer(content)['input_ids']
self.assertEqual(input_ids_xtuner, input_ids_hf)

def test_qwen3_vl_pretrain_image(self):
@parametrize.parametrize("add_vision_id", [(True,), (False,)])
def test_qwen3_vl_pretrain_image(self, add_vision_id):
tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_PATH,
add_vision_id=add_vision_id).build(self.tokenizer)
data_path = 'tests/resource/mllm_pretrain_image_example_data.jsonl'
total_step = 6
with open(data_path, encoding='utf-8') as f:
@@ -286,18 +310,18 @@ def test_qwen3_vl_pretrain_image(self, add_vision_id):
break
raw_data = json.loads(line)

ret = self.tokenize_fn(raw_data, media_root='tests/')
ret = tokenize_fn(raw_data, media_root='tests/')
input_ids_xtuner = ret['input_ids']
labels_xtuner = torch.tensor(ret['labels'])
input_str = self.tokenize_fn.tokenizer.decode(input_ids_xtuner, skip_special_tokens=False)
input_str = tokenize_fn.tokenizer.decode(input_ids_xtuner, skip_special_tokens=False)
input_str = input_str.replace('<|image_pad|>', '')
input_xtuner_str = input_str.replace('<|vision_start|><|vision_end|>', '<IMG_CONTEXT>')
ground_truth_content = raw_data['messages'][0]
for item in ground_truth_content['content']:
if item['type'] == 'text':
ground_truth_str = item['text'] + "<|im_end|>"
image_cnt = ground_truth_str.count('<IMG_CONTEXT>')
if image_cnt > 1:
if add_vision_id:
for i in range(image_cnt):
ground_truth_str = ground_truth_str.replace('<IMG_CONTEXT>',
f'Picture {i + 1}: <IMG_CONTEXT_1>', 1)
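Note on the test_calc_frame_info expectations above: the expected timestamp lists are consistent with rounding an evenly spaced linspace over the source frame indices and then averaging each pair of frame times (Qwen3-VL merges two frames per temporal patch, so 20 sampled frames yield 10 timestamps). Below is a minimal sketch that reproduces the first expected list; the helper name and the 100-frame / 10-fps source video are assumptions inferred from the expected values, not read from the test data.

import numpy as np

def sampled_timestamps(total_frames: int, origin_fps: float, num_frames: int,
                       temporal_patch: int = 2) -> list[float]:
    # Evenly sample frame indices across the source video, rounding to integers.
    indices = np.round(np.linspace(0, total_frames - 1, num_frames)).astype(int)
    # Convert frame indices to seconds, then average each temporal patch of frames.
    times = indices / origin_fps
    return [float(times[i:i + temporal_patch].mean())
            for i in range(0, num_frames, temporal_patch)]

# Reproduces [0.25, 1.3, 2.35, 3.35, 4.45, 5.45, 6.55, 7.55, 8.6..., 9.65] above.
print(sampled_timestamps(total_frames=100, origin_fps=10, num_frames=20))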
3 changes: 2 additions & 1 deletion tests/resource/mllm_sft_multi_image_example_data.jsonl
@@ -7,4 +7,5 @@
{"id": 7, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<IMG_CONTEXT>\n<IMG_CONTEXT>\nWhat are the similarities between the two images?"}]}, {"role": "assistant", "content": "Both images contain animals."}, {"role": "user", "content": "What animals are there?"}, {"role": "assistant", "content": "The first image contains two cats, and the second image contains a dog."}]}
{"id": 8, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<IMG_CONTEXT>\n<IMG_CONTEXT>\nCan you describe the color of the dog in the second image?"}]}, {"role": "assistant", "content": "The dog in the image is brown."}]}
{"id": 9, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<IMG_CONTEXT>\n<IMG_CONTEXT>\nHow many cats are in the first image?"}]}, {"role": "assistant", "content": "There are 2 cats in the image."}, {"role": "user", "content": "What else is in the first image?"}, {"role": "assistant", "content": "There are also 2 TV remotes in the first image."}, {"role": "user", "content": "What are the two cats doing?"}, {"role": "assistant", "content": "They are leisurely lying on the sofa."}, {"role": "user", "content": "Can you describe the first image?"}, {"role": "assistant", "content": "The image shows two cats leisurely lying on the sofa, with 2 TV remotes next to them."}]}
{"id": 10, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<IMG_CONTEXT>\n<IMG_CONTEXT>\nCan you describe the type of collar the dog in the second image has?"}]}, {"role": "assistant", "content": "The dog has a red collar."}]}
{"id": 10, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}}, {"type": "text", "text": "<IMG_CONTEXT>\n<IMG_CONTEXT>\nCan you describe the type of collar the dog in the second image has?"}]}, {"role": "assistant", "content": "The dog has a red collar."}]}
{"id": 11, "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_twocat_000000039769.jpg", "image_wh": [640, 480]}}, {"type": "text", "text": "<IMG_CONTEXT>\nHow many cats are in the first image?"}]}, {"role": "assistant", "content": "There are 2 cats in the image."}, {"role": "user", "content": "What else is in the first image?"}, {"role": "assistant", "content": "There are also 2 TV remotes in the first image."}, {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}},{"type": "text", "text": "<IMG_CONTEXT>\nWhat are the two cats doing?"}]}, {"role": "assistant", "content": "They are leisurely lying on the sofa."}, {"role": "user", "content": "Can you describe the first image?"}, {"role": "assistant", "content": "The image shows two cats leisurely lying on the sofa, with 2 TV remotes next to them."}]}
7 changes: 4 additions & 3 deletions xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py
@@ -91,6 +91,7 @@ def replace_image_token(
add_vision_id: bool = False,
):
current_image_idx = 0
total_img_cnt = 0
for msg in messages.messages:
if msg.role == "pretrain":
assert len(messages.messages) == 1, "pretrain message should only have one message"
@@ -104,9 +105,9 @@ def replace_image_token(
image_cnt = text.count(IMAGE_TOKEN_ALIAS)
for i in range(image_cnt):
image_tokens = f"{chat_template.image_start_token}{chat_template.image_context_token * num_image_token_list[current_image_idx]}{chat_template.image_end_token}" # type: ignore
if add_vision_id and image_cnt > 1:
# add vision id for each image when there are multiple images
image_tokens = f"Picture {i + 1}: " + image_tokens
if add_vision_id:
image_tokens = f"Picture {total_img_cnt + 1}: " + image_tokens
total_img_cnt += 1
text = text.replace(IMAGE_TOKEN_ALIAS, image_tokens, 1)
current_image_idx += 1
c.text = text