diff --git a/README.md b/README.md
index 98d26bb4d3..1a83503bf4 100644
--- a/README.md
+++ b/README.md
@@ -464,12 +464,11 @@ For multimodal models deployed directly with `NeMoMultimodalDeployable`, use the
 ```python
 from nemo_deploy.multimodal import NemoQueryMultimodalPytorch
-from PIL import Image
 
 nq = NemoQueryMultimodalPytorch(url="localhost:8000", model_name="qwen")
 
 output = nq.query_multimodal(
     prompts=["What is in this image?"],
-    images=[Image.open("/path/to/image.jpg")],
+    images=["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"],
     max_length=100,
     top_k=1,
     top_p=0.0,
diff --git a/nemo_deploy/multimodal/nemo_multimodal_deployable.py b/nemo_deploy/multimodal/nemo_multimodal_deployable.py
index 5bcbb8a05a..e84d1be6b1 100644
--- a/nemo_deploy/multimodal/nemo_multimodal_deployable.py
+++ b/nemo_deploy/multimodal/nemo_multimodal_deployable.py
@@ -157,8 +157,16 @@ def apply_chat_template(self, messages, add_generation_prompt=True):
         )
         return text
 
-    def base64_to_image(self, image_base64):
-        """Convert base64-encoded image to PIL Image."""
+    def process_image_input(self, image_source):
+        """Process image input from a base64-encoded string or an HTTP URL.
+
+        Args:
+            image_source (str): Image source - either a base64-encoded image string with a data URI prefix
+                (e.g., "data:image;base64,...") or an HTTP/HTTPS URL (e.g., "http://example.com/image.jpg").
+
+        Returns:
+            Processed image content suitable for model inference.
+        """
         if isinstance(self.inference_wrapped_model, QwenVLInferenceWrapper):
             from qwen_vl_utils import process_vision_info
@@ -166,7 +174,7 @@ def base64_to_image(self, image_base64):
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "image": f"data:image;base64,{image_base64}"},
+                    {"type": "image", "image": image_source},
                 ],
             }
         ]
@@ -259,6 +267,12 @@ def _infer_fn(
         Returns:
             dict: sentences.
         """
+        # Handle temperature=0.0 for greedy decoding
+        if temperature == 0.0:
+            LOGGER.warning("temperature=0.0 detected. Setting top_k=1 and top_p=0.0 for greedy sampling.")
+            top_k = 1
+            top_p = 0.0
+
         inference_params = CommonInferenceParams(
             temperature=float(temperature),
             top_k=int(top_k),
@@ -266,7 +280,7 @@ def _infer_fn(
             num_tokens_to_generate=num_tokens_to_generate,
         )
 
-        images = [self.base64_to_image(img_b64) for img_b64 in images]
+        images = [self.process_image_input(image_source) for image_source in images]
 
         results = self.generate(
             prompts,
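To make the `process_image_input` change concrete, here is a minimal sketch of the message shape it now hands to `qwen_vl_utils` for either source form. The example sources are illustrative and not part of this diff's runtime path (the base64 string is the 1x1 PNG reused in the unit tests below); the message structure and the `(image_inputs, video_inputs)` return pair follow the code and tests in this diff.

```python
# Sketch: process_vision_info receives the same message shape for a URL or a
# data URI source; qwen_vl_utils resolves each to image content. The example
# URL is an illustrative assumption.
from qwen_vl_utils import process_vision_info

sources = [
    "https://example.com/image.jpg",  # HTTP/HTTPS URL, passed through as-is
    # data URI-prefixed base64 (a 1x1 PNG, same string as in the unit tests)
    "data:image;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
]

for image_source in sources:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_source},
            ],
        }
    ]
    # qwen_vl_utils downloads the URL or decodes the data URI as appropriate
    image_inputs, video_inputs = process_vision_info(messages)
    print(image_source[:40], "->", type(image_inputs))
```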
diff --git a/nemo_deploy/multimodal/query_multimodal.py b/nemo_deploy/multimodal/query_multimodal.py
index 17d1e49a87..41590ca0a1 100644
--- a/nemo_deploy/multimodal/query_multimodal.py
+++ b/nemo_deploy/multimodal/query_multimodal.py
@@ -195,9 +195,16 @@ class NemoQueryMultimodalPytorch:
 
         nq = NemoQueryMultimodalPytorch(url="localhost", model_name="qwen")
 
-        # Encode image to base64
+        # Option 1: Use HTTP URL directly
+        output = nq.query_multimodal(
+            prompts=["Describe this image"],
+            images=["http://example.com/image.jpg"],
+            max_length=100,
+        )
+
+        # Option 2: Encode image to base64 with data URI prefix
         with open("image.jpg", "rb") as f:
-            image_base64 = base64.b64encode(f.read()).decode('utf-8')
+            image_base64 = "data:image;base64," + base64.b64encode(f.read()).decode('utf-8')
 
         output = nq.query_multimodal(
             prompts=["Describe this image"],
@@ -231,7 +238,8 @@ def query_multimodal(
 
         Args:
             prompts (List[str]): List of input text prompts.
-            images (List[str]): List of base64-encoded image strings.
+            images (List[str]): List of image strings - either base64-encoded with a data URI prefix
+                (e.g., "data:image;base64,...") or HTTP/HTTPS URLs (e.g., "http://example.com/image.jpg").
             max_length (Optional[int]): Maximum number of tokens to generate.
             max_batch_size (Optional[int]): Maximum batch size for inference.
             top_k (Optional[int]): Limits to the top K tokens to consider at each step.
diff --git a/nemo_deploy/service/fastapi_interface_to_pytriton_multimodal.py b/nemo_deploy/service/fastapi_interface_to_pytriton_multimodal.py
index 3955753ea3..df854e0e59 100644
--- a/nemo_deploy/service/fastapi_interface_to_pytriton_multimodal.py
+++ b/nemo_deploy/service/fastapi_interface_to_pytriton_multimodal.py
@@ -19,7 +19,7 @@
 import numpy as np
 import requests
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel
 from pydantic_settings import BaseSettings
 
 from nemo_deploy.multimodal.query_multimodal import NemoQueryMultimodalPytorch
@@ -82,18 +82,10 @@ class BaseMultimodalRequest(BaseModel):
     max_tokens: int = 50
     temperature: float = 1.0
     top_p: float = 0.0
-    top_k: int = 1
+    top_k: int = 0
     random_seed: Optional[int] = None
    max_batch_size: int = 4
 
-    @model_validator(mode="after")
-    def set_greedy_params(self):
-        """Validate parameters for greedy decoding."""
-        if self.temperature == 0 and self.top_p == 0:
-            logging.warning("Both temperature and top_p are 0. Setting top_k to 1 to ensure greedy sampling.")
-            self.top_k = 1
-        return self
-
 
 class MultimodalCompletionRequest(BaseMultimodalRequest):
     """Represents a request for multimodal text completion.
@@ -290,12 +282,33 @@ def dict_to_str(messages):
 
 @app.post("/v1/chat/completions/")
 async def chat_completions_v1(request: MultimodalChatCompletionRequest):
-    """Defines the multimodal chat completions endpoint and queries the model deployed on PyTriton server."""
+    """Defines the multimodal chat completions endpoint and queries the model deployed on the PyTriton server.
+
+    Supports two image content formats (normalized internally to format 1):
+    1. {"type": "image", "image": "url_or_base64"}
+    2. {"type": "image_url", "image_url": {"url": "url_or_base64"}} (OpenAI-style, converted to format 1)
+    """
     url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}"
     prompts = request.messages
     if not isinstance(request.messages, list):
         prompts = [request.messages]
+
+    # Normalize image_url format to image format for consistent processing
+    for message in prompts:
+        for content in message["content"]:
+            if content["type"] == "image_url":
+                # Convert OpenAI-style image_url to standard image format
+                if isinstance(content.get("image_url"), dict):
+                    image_data = content["image_url"]["url"]
+                else:
+                    image_data = content["image_url"]
+                # Transform to image format
+                content["type"] = "image"
+                content["image"] = image_data
+                # Remove image_url field
+                content.pop("image_url", None)
+
     # Serialize the dictionary to a JSON string representation to be able to convert to numpy array
     # (str_list2numpy) and back to list (str_ndarray2list) as required by PyTriton. Using the dictionaries directly
     # with these methods is not possible as they expect string type.
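For context on the endpoint change above, a minimal client-side sketch of the two accepted image content formats. The endpoint path, payload shape, and response shape are taken from this diff and its tests; the host/port, model name, local file path, and image URL are illustrative assumptions.

```python
# Sketch: exercising both image content formats against /v1/chat/completions/.
# "localhost:8080", "qwen", "image.jpg", and the example URL are assumptions.
import base64

import requests

# Local file as a data URI-prefixed base64 string (format 1)
with open("image.jpg", "rb") as f:
    local_image = "data:image;base64," + base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "qwen",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Compare these images"},
                # Format 1: passed through unchanged by the endpoint
                {"type": "image", "image": local_image},
                # Format 2 (OpenAI-style): normalized to format 1 server-side
                {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
            ],
        }
    ],
}

response = requests.post("http://localhost:8080/v1/chat/completions/", json=payload)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```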
diff --git a/scripts/deploy/multimodal/query_ray_deployment.py b/scripts/deploy/multimodal/query_fastapi_inframework.py
similarity index 85%
rename from scripts/deploy/multimodal/query_ray_deployment.py
rename to scripts/deploy/multimodal/query_fastapi_inframework.py
index 971af56559..1ee17fe3d9 100644
--- a/scripts/deploy/multimodal/query_ray_deployment.py
+++ b/scripts/deploy/multimodal/query_fastapi_inframework.py
@@ -64,19 +64,16 @@ def load_image_from_path(image_path: str) -> str:
         image_path: Path to local image file or URL
 
     Returns:
-        Base64-encoded image string
+        Image string - the HTTP(S) URL unchanged for remote images, or a data URI-prefixed base64 string for local files
     """
     if image_path.startswith(("http://", "https://")):
-        LOGGER.info(f"Loading image from URL: {image_path}")
-        response = requests.get(image_path, timeout=30)
-        response.raise_for_status()
-        image_content = response.content
+        LOGGER.info(f"Using image URL directly: {image_path}")
+        return image_path
     else:
-        LOGGER.info(f"Loading image from local path: {image_path}")
+        LOGGER.info(f"Loading and encoding image from local path: {image_path}")
         with open(image_path, "rb") as f:
             image_content = f.read()
-
-    return base64.b64encode(image_content).decode("utf-8")
+        return "data:image;base64," + base64.b64encode(image_content).decode("utf-8")
 
 
 def test_completions_endpoint(base_url: str, model_id: str, prompt: str = None, image_source: str = None) -> None:
@@ -114,8 +111,8 @@ def test_completions_endpoint(base_url: str, model_id: str, prompt: str = None,
     payload["prompt"] = text
 
     try:
-        image_base64 = load_image_from_path(image_source)
-        payload["image"] = image_base64
+        image_data = load_image_from_path(image_source)
+        payload["image"] = image_data
     except Exception as e:
         LOGGER.error(f"Failed to load image: {e}")
         return
@@ -130,7 +127,12 @@ def test_chat_completions_endpoint(base_url: str, model_id: str, prompt: str = None, image_source: str = None) -> None:
-    """Test the chat completions endpoint for multimodal models."""
+    """Test the chat completions endpoint for multimodal models.
+
+    Supports two image content formats:
+    1. {"type": "image", "image": "url_or_base64"}
+    2. {"type": "image_url", "image_url": {"url": "url_or_base64"}} (OpenAI-style)
+    """
     url = f"{base_url}/v1/chat/completions/"
 
     # Use provided prompt or default
@@ -141,8 +143,10 @@ def test_chat_completions_endpoint(base_url: str, model_id: str, prompt: str = N
     content = []
 
     try:
-        image_base64 = load_image_from_path(image_source)
-        content.append({"type": "image", "image": image_base64})
+        image_data = load_image_from_path(image_source)
+        # Using format 1: {"type": "image", "image": "url_or_base64"}
+        # Alternative format 2: {"type": "image_url", "image_url": {"url": "url_or_base64"}}
+        content.append({"type": "image", "image": image_data})
     except Exception as e:
         LOGGER.error(f"Failed to load image: {e}")
         return
@@ -167,19 +171,6 @@ def test_chat_completions_endpoint(base_url: str, model_id: str, prompt: str = N
         LOGGER.error(f"Error: {response.text}")
 
 
-def test_models_endpoint(base_url: str) -> None:
-    """Test the models endpoint."""
-    url = f"{base_url}/v1/models"
-
-    LOGGER.info(f"Testing models endpoint at {url}")
-    response = requests.get(url)
-    LOGGER.info(f"Response status code: {response.status_code}")
-    if response.status_code == 200:
-        LOGGER.info(f"Response: {json.dumps(response.json(), indent=2)}")
-    else:
-        LOGGER.error(f"Error: {response.text}")
-
-
 def test_health_endpoint(base_url: str) -> None:
     """Test the health endpoint."""
     url = f"{base_url}/v1/health"
@@ -218,7 +209,6 @@ def main():
     test_completions_endpoint(base_url, args.model_id, args.prompt, args.image)
     test_chat_completions_endpoint(base_url, args.model_id, args.prompt, args.image)
     test_health_endpoint(base_url)
-    test_models_endpoint(base_url)
 
 
 if __name__ == "__main__":
diff --git a/scripts/deploy/multimodal/query_inframework.py b/scripts/deploy/multimodal/query_inframework.py
index a7ddf1cc63..24accc0966 100644
--- a/scripts/deploy/multimodal/query_inframework.py
+++ b/scripts/deploy/multimodal/query_inframework.py
@@ -17,7 +17,6 @@
 import logging
 import time
 
-import requests
 from transformers import AutoProcessor
 
 from nemo_deploy.multimodal.query_multimodal import NemoQueryMultimodalPytorch
@@ -32,19 +31,16 @@ def load_image_from_path(image_path: str) -> str:
         image_path: Path to local image file or URL
 
     Returns:
-        Base64-encoded image string
+        Image string - the HTTP(S) URL unchanged for remote images, or a data URI-prefixed base64 string for local files
     """
     if image_path.startswith(("http://", "https://")):
-        LOGGER.info(f"Loading image from URL: {image_path}")
-        response = requests.get(image_path, timeout=30)
-        response.raise_for_status()
-        image_content = response.content
+        LOGGER.info(f"Using image URL directly: {image_path}")
+        return image_path
     else:
-        LOGGER.info(f"Loading image from local path: {image_path}")
+        LOGGER.info(f"Loading and encoding image from local path: {image_path}")
         with open(image_path, "rb") as f:
             image_content = f.read()
-
-    return base64.b64encode(image_content).decode("utf-8")
+        return "data:image;base64," + base64.b64encode(image_content).decode("utf-8")
 
 
 def get_args():
@@ -121,7 +117,7 @@ def query():
         with open(args.prompt_file, "r") as f:
             args.prompt = f.read()
 
-    image_base64 = load_image_from_path(args.image)
+    image_source = load_image_from_path(args.image)
 
     if "Qwen" in args.processor_name:
         processor = AutoProcessor.from_pretrained(args.processor_name)
@@ -146,7 +142,7 @@ def query():
     nemo_query = NemoQueryMultimodalPytorch(args.url, args.model_name)
     outputs = nemo_query.query_multimodal(
         prompts=[args.prompt],
-        images=[image_base64],
+        images=[image_source],
         max_length=args.max_output_len,
         max_batch_size=args.max_batch_size,
         top_k=args.top_k,
diff --git a/tests/unit_tests/deploy/test_fastapi_interface_to_pytriton_multimodal.py b/tests/unit_tests/deploy/test_fastapi_interface_to_pytriton_multimodal.py
index 304f811025..2da62db31e 100644
--- a/tests/unit_tests/deploy/test_fastapi_interface_to_pytriton_multimodal.py
+++ b/tests/unit_tests/deploy/test_fastapi_interface_to_pytriton_multimodal.py
@@ -89,7 +89,7 @@ def test_base_multimodal_request_defaults(self):
         assert request.max_tokens == 50
         assert request.temperature == 1.0
         assert request.top_p == 0.0
-        assert request.top_k == 1
+        assert request.top_k == 0
         assert request.random_seed is None
         assert request.max_batch_size == 4
@@ -112,11 +112,6 @@ def test_base_multimodal_request_custom_values(self):
         assert request.random_seed == 42
         assert request.max_batch_size == 8
 
-    def test_base_multimodal_request_greedy_validation(self):
-        """Test BaseMultimodalRequest validator for greedy sampling."""
-        request = BaseMultimodalRequest(model="test-model", temperature=0, top_p=0, top_k=5)
-        assert request.top_k == 1
-
     def test_multimodal_completion_request(self):
         """Test MultimodalCompletionRequest."""
         request = MultimodalCompletionRequest(
@@ -274,7 +269,7 @@ def test_completions_with_image(self, client, mock_triton_settings):
         request_data = {
             "model": "test-model",
             "prompt": "Describe this image",
-            "image": "base64_encoded_image_data",
+            "image": "data:image;base64,base64_encoded_image_data",
             "temperature": 0.7,
         }
@@ -291,7 +286,7 @@ def test_completions_with_image(self, client, mock_triton_settings):
 
         mock_query.assert_called_once()
         call_kwargs = mock_query.call_args[1]
-        assert call_kwargs["images"] == ["base64_encoded_image_data"]
+        assert call_kwargs["images"] == ["data:image;base64,base64_encoded_image_data"]
         assert call_kwargs["temperature"] == 0.7
 
     def test_completions_with_custom_params(self, client, mock_triton_settings):
@@ -357,7 +352,7 @@ def test_chat_completions_with_image(self, client, mock_triton_settings):
                 "role": "user",
                 "content": [
                     {"type": "text", "text": "What's in this image?"},
-                    {"type": "image", "image": "base64_image_data"},
+                    {"type": "image", "image": "data:image;base64,base64_image_data"},
                 ],
             }
         ]
@@ -376,7 +371,7 @@ def test_chat_completions_with_image(self, client, mock_triton_settings):
 
         mock_query.assert_called_once()
         call_kwargs = mock_query.call_args[1]
-        assert call_kwargs["images"] == ["base64_image_data"]
+        assert call_kwargs["images"] == ["data:image;base64,base64_image_data"]
 
     def test_chat_completions_multiple_images(self, client, mock_triton_settings):
         """Test /v1/chat/completions/ endpoint with multiple images."""
@@ -385,8 +380,8 @@ def test_chat_completions_multiple_images(self, client, mock_triton_settings):
                 "role": "user",
                 "content": [
                     {"type": "text", "text": "Compare these images"},
-                    {"type": "image", "image": "base64_image_1"},
-                    {"type": "image", "image": "base64_image_2"},
+                    {"type": "image", "image": "data:image;base64,base64_image_1"},
+                    {"type": "image", "image": "data:image;base64,base64_image_2"},
                 ],
             }
         ]
@@ -403,9 +398,64 @@ def test_chat_completions_multiple_images(self, client, mock_triton_settings):
 
         mock_query.assert_called_once()
         call_kwargs = mock_query.call_args[1]
-        assert call_kwargs["images"] == ["base64_image_1", "base64_image_2"]
+        assert call_kwargs["images"] == ["data:image;base64,base64_image_1", "data:image;base64,base64_image_2"]
         assert call_kwargs["max_length"] == 200
 
+    def test_chat_completions_with_image_url_format(self, client, mock_triton_settings):
+        """Test /v1/chat/completions/ endpoint with OpenAI-style image_url format."""
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What's in this image?"},
+                    {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
+                ],
+            }
+        ]
+        request_data = {"model": "test-model", "messages": messages}
+
+        mock_output = {"choices": [{"text": [["I see a cat"]]}], "model": "test-model"}
+
+        with patch("nemo_deploy.service.fastapi_interface_to_pytriton_multimodal.query_multimodal_async") as mock_query:
+            mock_query.return_value = mock_output
+
+            response = client.post("/v1/chat/completions/", json=request_data)
+
+            assert response.status_code == 200
+            result = response.json()
+            assert result["choices"][0]["message"]["content"] == "I see a cat"
+
+            mock_query.assert_called_once()
+            call_kwargs = mock_query.call_args[1]
+            assert call_kwargs["images"] == ["https://example.com/image.jpg"]
+
+    def test_chat_completions_with_mixed_image_formats(self, client, mock_triton_settings):
+        """Test /v1/chat/completions/ endpoint with mixed image and image_url formats."""
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Compare these images"},
+                    {"type": "image", "image": "data:image;base64,base64_data"},
+                    {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
+                ],
+            }
+        ]
+        request_data = {"model": "test-model", "messages": messages}
+
+        mock_output = {"choices": [{"text": [["Comparison"]]}], "model": "test-model"}
+
+        with patch("nemo_deploy.service.fastapi_interface_to_pytriton_multimodal.query_multimodal_async") as mock_query:
+            mock_query.return_value = mock_output
+
+            response = client.post("/v1/chat/completions/", json=request_data)
+
+            assert response.status_code == 200
+
+            mock_query.assert_called_once()
+            call_kwargs = mock_query.call_args[1]
+            assert call_kwargs["images"] == ["data:image;base64,base64_data", "https://example.com/image.jpg"]
+
     def test_chat_completions_with_params(self, client, mock_triton_settings):
         """Test /v1/chat/completions/ endpoint with custom parameters."""
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
@@ -452,7 +502,7 @@ def test_helper_fun(self):
             url="http://localhost:8000",
             model="test-model",
             prompts=["test prompt"],
-            images=["image_data"],
+            images=["data:image;base64,image_data"],
             temperature=0.7,
             top_k=10,
             top_p=0.9,
@@ -465,7 +515,7 @@ def test_helper_fun(self):
         mock_nq_class.assert_called_once_with(url="http://localhost:8000", model_name="test-model")
         mock_nq.query_multimodal.assert_called_once_with(
             prompts=["test prompt"],
-            images=["image_data"],
+            images=["data:image;base64,image_data"],
             temperature=0.7,
             top_k=10,
             top_p=0.9,
diff --git a/tests/unit_tests/deploy/test_nemo_multimodal_deployable.py b/tests/unit_tests/deploy/test_nemo_multimodal_deployable.py
index 538fe06653..56554cc4ae 100644
--- a/tests/unit_tests/deploy/test_nemo_multimodal_deployable.py
+++ b/tests/unit_tests/deploy/test_nemo_multimodal_deployable.py
@@ -255,10 +255,10 @@ def test_infer_fn(self, deployable, sample_image_base64, sample_image):
         prompts = ["Test prompt 1", "Test prompt 2"]
         images = [sample_image_base64, sample_image_base64]
 
-        with patch.object(deployable, "base64_to_image") as mock_base64_to_image:
+        with patch.object(deployable, "process_image_input") as mock_process_image_input:
             with patch.object(deployable, "generate") as mock_generate:
-                # Mock base64_to_image to return PIL Images
-                mock_base64_to_image.return_value = sample_image
+                # Mock process_image_input to return PIL Images
+                mock_process_image_input.return_value = sample_image
                 mock_generate.return_value = [MockResult("Generated text 1"), MockResult("Generated text 2")]
 
                 result = deployable._infer_fn(
@@ -272,8 +272,8 @@ def test_infer_fn(self, deployable, sample_image_base64, sample_image):
                     max_batch_size=3,
                 )
 
-        # Check that base64_to_image was called for each image
-        assert mock_base64_to_image.call_count == 2
+        # Check that process_image_input was called for each image
+        assert mock_process_image_input.call_count == 2
 
         # Check that generate was called with the right parameters
         assert mock_generate.call_count == 1
@@ -301,16 +301,16 @@ def test_infer_fn_default_params(self, deployable, sample_image_base64, sample_i
         prompts = ["Test prompt"]
         images = [sample_image_base64]
 
-        with patch.object(deployable, "base64_to_image") as mock_base64_to_image:
+        with patch.object(deployable, "process_image_input") as mock_process_image_input:
             with patch.object(deployable, "generate") as mock_generate:
-                # Mock base64_to_image to return PIL Images
-                mock_base64_to_image.return_value = sample_image
+                # Mock process_image_input to return PIL Images
+                mock_process_image_input.return_value = sample_image
                 mock_generate.return_value = [MockResult("Generated text 1")]
 
                 result = deployable._infer_fn(prompts=prompts, images=images)
 
-                # Check that base64_to_image was called
-                assert mock_base64_to_image.call_count == 1
+                # Check that process_image_input was called
+                assert mock_process_image_input.call_count == 1
 
                 # Check that generate was called with the right parameters
                 assert mock_generate.call_count == 1
@@ -331,6 +331,42 @@ def test_infer_fn_default_params(self, deployable, sample_image_base64, sample_i
 
         assert result["sentences"] == ["Generated text 1"]
 
+    def test_infer_fn_with_temperature_zero(self, deployable):
+        """Test _infer_fn with temperature=0.0 for greedy decoding."""
+        sample_image = Image.new("RGB", (100, 100))
+        sample_image_base64 = "data:image;base64,test_base64_string"
+
+        prompts = ["Test prompt"]
+        images = [sample_image_base64]
+
+        with patch.object(deployable, "process_image_input") as mock_process_image:
+            with patch.object(deployable, "generate") as mock_generate:
+                # Mock process_image_input to return PIL Images
+                mock_process_image.return_value = sample_image
+                mock_generate.return_value = [MockResult("Generated text")]
+
+                result = deployable._infer_fn(
+                    prompts=prompts,
+                    images=images,
+                    temperature=0.0,  # Should trigger greedy sampling handling
+                    top_k=5,  # Should be overridden to 1
+                    top_p=0.5,  # Should be overridden to 0.0
+                    num_tokens_to_generate=100,
+                )
+
+                # Check that generate was called with the right parameters
+                assert mock_generate.call_count == 1
+                call_args = mock_generate.call_args
+
+                # Check that inference_params has greedy sampling parameters
+                assert isinstance(call_args[0][2], CommonInferenceParams)
+                assert call_args[0][2].temperature == 0.0  # Kept as 0.0
+                assert call_args[0][2].top_k == 1  # Overridden for greedy sampling
+                assert call_args[0][2].top_p == 0.0  # Overridden for greedy sampling
+                assert call_args[0][2].num_tokens_to_generate == 100
+
+                assert result["sentences"] == ["Generated text"]
+
     def test_dict_to_str_function(self):
         """Test the dict_to_str utility function."""
         from nemo_deploy.multimodal.nemo_multimodal_deployable import dict_to_str
@@ -484,8 +520,8 @@ def test_apply_chat_template_without_generation_prompt(self, deployable):
         )
         assert result == expected_text
 
-    def test_base64_to_image_with_qwenvl_wrapper(self, deployable):
-        """Test base64_to_image with QwenVLInferenceWrapper."""
+    def test_process_image_input_with_qwenvl_wrapper(self, deployable):
+        """Test process_image_input with QwenVLInferenceWrapper using base64 image."""
         # Create a mock QwenVLInferenceWrapper class
         mock_qwenvl_class = MagicMock()
@@ -493,9 +529,8 @@ def test_base64_to_image_with_qwenvl_wrapper(self, deployable):
         # Use isinstance check to return True for QwenVLInferenceWrapper
         deployable.inference_wrapped_model = MagicMock()
 
-        image_base64 = (
-            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
-        )
+        # Image source with data URI prefix (new format)
+        image_source = "data:image;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
         expected_image = Image.new("RGB", (100, 100))
 
         with patch("nemo_deploy.multimodal.nemo_multimodal_deployable.QwenVLInferenceWrapper", mock_qwenvl_class):
@@ -506,7 +541,7 @@ def test_base64_to_image_with_qwenvl_wrapper(self, deployable):
                 with patch("qwen_vl_utils.process_vision_info") as mock_process:
                     mock_process.return_value = (expected_image, None)
 
-                    result = deployable.base64_to_image(image_base64)
+                    result = deployable.process_image_input(image_source)
 
                     # Verify isinstance was called to check the model type
                     mock_isinstance.assert_called_once_with(deployable.inference_wrapped_model, mock_qwenvl_class)
@@ -516,19 +551,50 @@ def test_base64_to_image_with_qwenvl_wrapper(self, deployable):
                     assert len(call_args) == 1
                     assert call_args[0]["role"] == "user"
                     assert call_args[0]["content"][0]["type"] == "image"
-                    assert call_args[0]["content"][0]["image"] == f"data:image;base64,{image_base64}"
+                    assert call_args[0]["content"][0]["image"] == image_source
+
+        assert result == expected_image
+
+    def test_process_image_input_with_http_url(self, deployable):
+        """Test process_image_input with HTTP URL."""
+        # Create a mock QwenVLInferenceWrapper class
+        mock_qwenvl_class = MagicMock()
+
+        # Make deployable.inference_wrapped_model an instance of the mock class
+        deployable.inference_wrapped_model = MagicMock()
+
+        # HTTP URL as image source
+        image_source = "https://example.com/image.jpg"
+        expected_image = Image.new("RGB", (100, 100))
+
+        with patch("nemo_deploy.multimodal.nemo_multimodal_deployable.QwenVLInferenceWrapper", mock_qwenvl_class):
+            # Make isinstance return True for our mock
+            with patch("nemo_deploy.multimodal.nemo_multimodal_deployable.isinstance") as mock_isinstance:
+                mock_isinstance.return_value = True
+
+                with patch("qwen_vl_utils.process_vision_info") as mock_process:
+                    mock_process.return_value = (expected_image, None)
+
+                    result = deployable.process_image_input(image_source)
+
+                    # Verify process_vision_info was called with URL
+                    call_args = mock_process.call_args[0][0]
+                    assert len(call_args) == 1
+                    assert call_args[0]["role"] == "user"
+                    assert call_args[0]["content"][0]["type"] == "image"
+                    assert call_args[0]["content"][0]["image"] == image_source
 
         assert result == expected_image
 
-    def test_base64_to_image_with_unsupported_model(self, deployable):
-        """Test base64_to_image with unsupported model raises ValueError."""
+    def test_process_image_input_with_unsupported_model(self, deployable):
+        """Test process_image_input with unsupported model raises ValueError."""
         # Create a mock QwenVLInferenceWrapper class
         mock_qwenvl_class = MagicMock()
 
         # Make sure the wrapped model is NOT a QwenVLInferenceWrapper
         deployable.inference_wrapped_model = MagicMock()
 
-        image_base64 = "test_base64_string"
+        image_source = "data:image;base64,test_base64_string"
 
         with patch("nemo_deploy.multimodal.nemo_multimodal_deployable.QwenVLInferenceWrapper", mock_qwenvl_class):
             # Make isinstance return False for our mock (not a QwenVLInferenceWrapper)
@@ -536,7 +602,7 @@ def test_base64_to_image_with_unsupported_model(self, deployable):
             mock_isinstance.return_value = False
 
             with pytest.raises(ValueError, match="not supported"):
-                deployable.base64_to_image(image_base64)
+                deployable.process_image_input(image_source)
 
     def test_ray_infer_fn(self, deployable):
         """Test ray_infer_fn method."""
diff --git a/tests/unit_tests/deploy/test_query_multimodal.py b/tests/unit_tests/deploy/test_query_multimodal.py
index e383d05dc0..1a5fa5b249 100644
--- a/tests/unit_tests/deploy/test_query_multimodal.py
+++ b/tests/unit_tests/deploy/test_query_multimodal.py
@@ -138,7 +138,7 @@ def query_multimodal_pytorch(self):
     @pytest.fixture
     def mock_images(self):
         # Create sample base64-encoded image strings for testing
-        return ["mock_base64_image_1", "mock_base64_image_2"]
+        return ["data:image;base64,mock_base64_image_1", "data:image;base64,mock_base64_image_2"]
 
     @pytest.fixture
     def mock_prompts(self):
@@ -305,7 +305,7 @@ def test_query_multimodal_single_prompt_single_image(self, mock_model_client, qu
         mock_model_client.return_value.__enter__.return_value = mock_client_instance
 
         # Use mock base64 image string
-        base64_image = "mock_base64_single_image"
+        base64_image = "data:image;base64,mock_base64_single_image"
 
         result = query_multimodal_pytorch.query_multimodal(prompts=["Single prompt"], images=[base64_image])