diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index fb93c72320d45..19ae040bc4443 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -154,7 +154,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv metadef_id_generator_{ModelMetadefIdGenerator::Create()}, external_alloc_{info.external_alloc}, external_free_{info.external_free}, - external_empty_cache_{info.external_empty_cache} { + external_empty_cache_{info.external_empty_cache}, + max_dynamic_batch_{info.max_dynamic_batch} { InitProviderOrtApi(); // Set GPU device to be used and read device properties for feature usage. @@ -180,6 +181,13 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); + // Get max dynamic batch size from environment variable + const auto max_dynamic_batch_env{GetEnvironmentVar(migraphx_env_vars::kModelMaxDynamicBatch)}; + if (!max_dynamic_batch_env.empty()) { + max_dynamic_batch_ = std::stoull(max_dynamic_batch_env); + LOGS_DEFAULT(INFO) << "\n " << migraphx_env_vars::kModelMaxDynamicBatch << ": " << max_dynamic_batch_; + } + // Verify configuration correctness and adjust accordingly. 
// Helper function to enumerate the batch sizes to pre-compile: every power of
// two up to max_batch, plus max_batch itself when it is not a power of two.
//
// Examples: 0 -> {}, 1 -> {1}, 8 -> {1, 2, 4, 8}, 10 -> {1, 2, 4, 8, 10}.
//
// @param max_batch  Upper bound (inclusive) on the batch sizes returned.
// @return           Sorted, duplicate-free list of batch sizes; empty when
//                   max_batch is 0.
std::vector<std::size_t> GetPowerOf2BatchSizes(std::size_t max_batch) {
  std::vector<std::size_t> batch_sizes;
  if (max_batch == 0) {
    return batch_sizes;
  }

  for (std::size_t batch = 1; batch <= max_batch; batch *= 2) {
    batch_sizes.push_back(batch);
    // Guard against unsigned wrap-around: doubling a value above
    // SIZE_MAX / 2 overflows to 0 and would loop forever for a
    // pathologically large max_batch.
    if (batch > max_batch / 2) {
      break;
    }
  }

  // Always include max_batch itself so the requested maximum is servable,
  // even when it is not a power of two. The loop above pushed at least one
  // element (batch == 1), so back() is safe here.
  if (batch_sizes.back() != max_batch) {
    batch_sizes.push_back(max_batch);
  }

  return batch_sizes;
}
exhaustive_tune, + const std::filesystem::path& model_path) { + + LOGS_DEFAULT(VERBOSE) << "[CompileBatch] Compiling for batch size: " << batch_size; + + // Set input shapes with the specified batch size for ALL inputs + for (size_t i = 0; i < input_names.size() && i < all_input_base_shapes.size(); ++i) { + std::vector shape_with_batch; + shape_with_batch.push_back(batch_size); + for (auto dim : all_input_base_shapes[i]) { + shape_with_batch.push_back(static_cast(dim)); + } + options.set_input_parameter_shape(input_names[i], shape_with_batch); + + std::ostringstream ss; + ss << "["; + for (size_t j = 0; j < shape_with_batch.size(); ++j) { + if (j > 0) ss << ", "; + ss << shape_with_batch[j]; + } + ss << "]"; + LOGS_DEFAULT(VERBOSE) << "[CompileBatch] Input '" << input_names[i] << "' shape: " << ss.str(); + } + +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + if (!model_path.empty()) { + options.set_external_data_path(model_path.parent_path().string()); + } +#endif +#endif + + migraphx::program prog = migraphx::parse_onnx_buffer(onnx_string, options); + migraphx::program_parameters quant_params; + + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, dynamic_range_map); + compile_program(prog, t, exhaustive_tune); + + LOGS_DEFAULT(VERBOSE) << "[CompileBatch] Compilation complete for batch size: " << batch_size; + return prog; +} + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1509,6 +1592,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(VERBOSE) << "[Compile] Saving compiled model to cache: " << model_cache_file.string(); save_compiled_model(prog, model_cache_file); LOGS_DEFAULT(VERBOSE) << "[Compile] Model saved successfully with batch-aware filename"; + // Note: Batch pre-compilation happens after this block for 
all cases (cache hit or miss) } else { LOGS_DEFAULT(VERBOSE) << "[Compile] Cache hit! Loaded precompiled model from: " << model_cache_file.string(); } @@ -1531,8 +1615,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // NOTE: DO NOT set output shapes as input parameters! // Outputs are dynamically inferred by MIGraphX based on input shapes } else { - LOGS_DEFAULT(VERBOSE) << "[Compile] Deferring compilation until runtime (no static input shapes available)"; - LOGS_DEFAULT(VERBOSE) << "[Compile] Will use default batch size of 1, then recompile with actual batch at runtime"; + LOGS_DEFAULT(INFO) << "[Compile] No static input shapes available, compiling with default batch size 1"; + // Still compile with default shapes so we have something in the batch cache +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +#endif +#endif + prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); + migraphx::program_parameters quant_params; + + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); + compile_program(prog, t_, exhaustive_tune_); + LOGS_DEFAULT(INFO) << "[Compile] Compiled model with default shapes"; } // compile the program @@ -1541,6 +1637,119 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& map_onnx_string_[fused_node.Name()] = onnx_string_buffer; map_input_index_[fused_node.Name()] = input_name_index; map_no_input_shape_[fused_node.Name()] = no_input_shape; + + // Initialize batch program cache for this node + batch_program_cache_[fused_node.Name()] = std::map(); + + // Build base shapes for ALL inputs (excluding batch dimension) + // This is needed for both single batch and multi-batch compilation + std::vector> all_input_base_shapes; + for (size_t i = 0; i < input_tensor.size(); ++i) { + std::vector 
base_shape; + if (input_tensor[i]->Shape() != nullptr) { + auto tensor_shape = input_tensor[i]->Shape(); + for (int j = 1; j < tensor_shape->dim_size(); ++j) { + const auto& dim = tensor_shape->dim(j); + if (dim.has_dim_value()) { + base_shape.push_back(dim.dim_value()); + } else { + base_shape.push_back(1); // Default for symbolic dims + } + } + } + all_input_base_shapes.push_back(base_shape); + } + + // Extract the batch size from the main compiled program (from first input) + size_t main_prog_batch_size = 1; // Default + if (!input_tensor.empty() && input_tensor[0]->Shape() != nullptr) { + auto tensor_shape = input_tensor[0]->Shape(); + if (tensor_shape->dim_size() > 0) { + const auto& batch_dim = tensor_shape->dim(0); + if (batch_dim.has_dim_value()) { + main_prog_batch_size = static_cast(batch_dim.dim_value()); + } + } + } + + // Always store the main compiled program in the batch cache + // This ensures at least one batch size is always available + { + std::lock_guard lock(batch_cache_mutex_); + batch_program_cache_[fused_node.Name()][main_prog_batch_size] = prog; + LOGS_DEFAULT(INFO) << "[Compile] Stored main program in batch cache for batch size: " << main_prog_batch_size; + } + + // Pre-compile/load additional programs for other batch sizes when max_dynamic_batch_ > 0 + // This compiles power-of-2 batch sizes up to max_dynamic_batch_ + if (!model_cache_path_.empty() && !no_input_shape && max_dynamic_batch_ > 0) { + std::lock_guard lock(batch_cache_mutex_); + + // Compile power-of-2 batch sizes up to max_dynamic_batch_ + auto batch_sizes_to_compile = GetPowerOf2BatchSizes(max_dynamic_batch_); + LOGS_DEFAULT(INFO) << "[Compile] Pre-compiling " << batch_sizes_to_compile.size() + << " batch sizes (powers of 2 up to " << max_dynamic_batch_ << ")"; + + int compiled_count = 0; + int loaded_count = 0; + for (size_t batch : batch_sizes_to_compile) { + // Skip if already in cache (from earlier pre-compilation) + if 
(batch_program_cache_[fused_node.Name()].find(batch) != batch_program_cache_[fused_node.Name()].end()) { + LOGS_DEFAULT(VERBOSE) << "[Compile] Batch size " << batch << " already in memory cache, skipping"; + continue; + } + + // Build input shapes with this batch size (using first input's shape for hash) + std::vector batch_input_shapes; + batch_input_shapes.push_back(static_cast(batch)); // Batch dimension + if (!all_input_base_shapes.empty()) { + batch_input_shapes.insert(batch_input_shapes.end(), + all_input_base_shapes[0].begin(), + all_input_base_shapes[0].end()); + } + + if (!batch_input_shapes.empty()) { + auto batch_cache_hash = make_hash(batch_input_shapes); + auto batch_cache_file = model_cache_path_ / (mxr_filename_prefix + batch_cache_hash + ".mxr"); + + LOGS_DEFAULT(VERBOSE) << "[Compile] Looking for batch " << batch << " cache file: " << batch_cache_file.string(); + + migraphx::program batch_prog; + if (load_precompiled_model(batch_prog, batch_cache_file)) { + // Successfully loaded from disk + batch_program_cache_[fused_node.Name()][batch] = std::move(batch_prog); + loaded_count++; + LOGS_DEFAULT(INFO) << "[Compile] Loaded cached model for batch size " << batch; + } else { + // Cache miss - compile the program using CompileProgramWithBatch + LOGS_DEFAULT(INFO) << "[Compile] Compiling model for batch size " << batch << "..."; + + try { + batch_prog = CompileProgramWithBatch( + onnx_string_buffer, input_names, all_input_base_shapes, batch, + options, t_, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, + int8_calibration_cache_available_, dynamic_range_map_, exhaustive_tune_, model_path_); + + // Save to disk for future runs + save_compiled_model(batch_prog, batch_cache_file); + LOGS_DEFAULT(INFO) << "[Compile] Saved compiled model for batch " << batch << " to: " << batch_cache_file.string(); + + // Store in memory cache + batch_program_cache_[fused_node.Name()][batch] = std::move(batch_prog); + compiled_count++; + + } catch (const 
std::exception& e) { + LOGS_DEFAULT(ERROR) << "[Compile] Failed to compile batch " << batch << ": " << e.what(); + } + } + } + } + + LOGS_DEFAULT(INFO) << "[Compile] Batch cache ready: " << loaded_count << " loaded from disk, " + << compiled_count << " newly compiled, " + << batch_program_cache_[fused_node.Name()].size() << " total batch sizes available"; + } + NodeComputeInfo compute_info; compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); @@ -1548,7 +1757,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, - model_cache_path_.string(), dump_model_ops_}; + model_cache_path_.string(), dump_model_ops_, exhaustive_tune_, max_dynamic_batch_, + &batch_program_cache_[context->node_name], &batch_cache_mutex_, std::string(context->node_name), + session_input_names}; *state = p.release(); return 0; }; @@ -1685,137 +1896,94 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // input shapes are different, needs to re-parse onnx and // re-compile the program if (!input_shape_match) { - LOGS_DEFAULT(VERBOSE) << "[Compute] Input shape mismatch detected, initiating recompilation"; - - std::filesystem::path model_cache_file; - // empty cache path means the MXR caching is disabled - always compile - if (!model_cache_path_.empty()) { - // Ensure input_shapes has all updated dimensions including new batch sizes - if (input_shapes.empty()) { - LOGS_DEFAULT(WARNING) << "[Compute] Input shapes vector is empty, rebuilding from current inputs"; - for (auto&& name : param_shapes.names()) { - if (map_input_name_index.count(name) > 0) { - auto input_tensor = ctx.GetInput(map_input_name_index[name]); - auto tensor_info = 
input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - input_shapes.insert(input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); - } - } - } + LOGS_DEFAULT(VERBOSE) << "[Compute] Input shape mismatch detected"; - // Log the shapes being used for cache key generation - std::ostringstream shapes_str; - shapes_str << "["; - for (size_t i = 0; i < input_shapes.size(); ++i) { - if (i > 0) shapes_str << ", "; - shapes_str << input_shapes[i]; - } - shapes_str << "]"; - LOGS_DEFAULT(VERBOSE) << "[Compute] Cache key input shapes (including updated batch): " << shapes_str.str(); + // Extract batch size from first ACTUAL runtime input (not constants/weights) + // We need to find an input that varies with batch size + size_t requested_batch = 0; + bool found_batch_input = false; - auto cache_hash = make_hash(input_shapes); - model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + cache_hash + ".mxr"); - LOGS_DEFAULT(VERBOSE) << "[Compute] Cache file with batch-aware hash: " << model_cache_file.string(); - } + // First, try to get batch from actual model inputs (session-level inputs) + for (auto& it : map_input_name_index) { + auto& name = it.first; + auto& index = it.second; - if (!load_precompiled_model(prog, model_cache_file)) { - LOGS_DEFAULT(VERBOSE) << "[Compute] Cache miss. 
Compiling model with updated batch size"; + // Skip if this looks like a weight/constant (session_input_names contains only real inputs) + if (mgx_state->session_input_names.count(name) == 0) { + continue; // This is likely a constant/weight, skip it + } + + auto input_tensor = ctx.GetInput(index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shape = tensor_info.GetShape(); - // CRITICAL: Ensure ALL input parameter shapes are explicitly set as static shapes in cmp_options - // This must be done BEFORE parsing to treat dynamic shapes as static for compilation - // NOTE: Only set shapes for actual runtime input parameters, NOT for constants/initializers - // MIGraphX will automatically infer shapes for constants and intermediate tensors - LOGS_DEFAULT(VERBOSE) << "[Compute] Setting " << map_input_name_index.size() - << " input parameter shapes as static in MIGraphX options (excluding constants)"; + if (!tensor_shape.empty()) { + requested_batch = static_cast(tensor_shape[0]); + found_batch_input = true; + LOGS_DEFAULT(VERBOSE) << "[Compute] Extracted batch size " << requested_batch + << " from session input '" << name << "'"; + break; + } + } + // Fallback: if no session input found, use first available input + if (!found_batch_input && !map_input_name_index.empty()) { for (auto& it : map_input_name_index) { - auto& name = it.first; - auto& index = it.second; - auto input_tensor = ctx.GetInput(index); + auto input_tensor = ctx.GetInput(it.second); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); - std::vector ort_lens(tensor_shape.begin(), tensor_shape.end()); - - // Set shape as static parameter for MIGraphX compilation - // Only for actual input parameters - constants/initializers are handled by MIGraphX - cmp_options.set_input_parameter_shape(name, ort_lens); - - LOGS_DEFAULT(VERBOSE) << "[Compute] Set static shape for input parameter '" << name << "': [" - << [&]() 
{ - std::ostringstream ss; - for (size_t i = 0; i < ort_lens.size(); ++i) { - if (i > 0) ss << ", "; - ss << ort_lens[i]; - } - return ss.str(); - }() << "]"; - } - LOGS_DEFAULT(VERBOSE) << "[Compute] All input parameter shapes set as static"; - LOGS_DEFAULT(VERBOSE) << "[Compute] MIGraphX will infer shapes for constants and intermediate tensors"; -#ifndef ENABLE_TRAINING_CORE -#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH - cmp_options.set_external_data_path(model_path_.parent_path().string()); -#endif -#endif - LOGS_DEFAULT(VERBOSE) << "[Compute] Parsing ONNX buffer with static input shapes"; - prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); - LOGS_DEFAULT(VERBOSE) << "[Compute] ONNX parsing complete"; - - // Verify that MIGraphX parsed with correct shapes for input parameters - auto parsed_param_shapes = prog.get_parameter_shapes(); - LOGS_DEFAULT(VERBOSE) << "[Compute] Verifying parsed parameter shapes (" - << parsed_param_shapes.size() << " total parameters):"; - for (auto&& param_name : parsed_param_shapes.names()) { - auto shape = parsed_param_shapes[param_name]; - auto lens = shape.lengths(); - std::ostringstream ss; - ss << "["; - for (size_t i = 0; i < lens.size(); ++i) { - if (i > 0) ss << ", "; - ss << lens[i]; + if (!tensor_shape.empty()) { + requested_batch = static_cast(tensor_shape[0]); + LOGS_DEFAULT(WARNING) << "[Compute] Extracted batch size " << requested_batch + << " from input '" << it.first << "' (no session input found)"; + break; } - ss << "]"; - - // Distinguish between input parameters we set and constants MIGraphX inferred - bool is_input_param = (map_input_name_index.count(param_name) > 0); - LOGS_DEFAULT(VERBOSE) << "[Compute] Parameter '" << param_name << "' parsed shape: " << ss.str() - << (is_input_param ? 
" (input parameter)" : " (constant/internal)"); } + } - migraphx::program_parameters quant_params; + LOGS_DEFAULT(VERBOSE) << "[Compute] Requested batch size: " << requested_batch; - if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { - auto local_param_shapes = prog.get_parameter_shapes(); - // Add input parameter data and the values they're set to - for (auto&& name : local_param_shapes.names()) { - if (map_input_name_index.count(name) > 0) { - auto input_tensor = ctx.GetInput(map_input_name_index[name]); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - const auto tensor_type = tensor_info.GetElementType(); + // Try to find program in batch cache first + bool found_in_cache = false; + { + std::lock_guard lock(*mgx_state->batch_cache_mutex_ptr); + auto& batch_cache = *mgx_state->batch_program_cache_ptr; - migraphx_shape_datatype_t mgx_type; - getMIGraphXType(tensor_type, mgx_type); - auto mgx_s = local_param_shapes[name]; + if (batch_cache.find(requested_batch) != batch_cache.end()) { + LOGS_DEFAULT(VERBOSE) << "[Compute] Found program in batch cache for batch size: " << requested_batch; + prog = batch_cache[requested_batch]; + found_in_cache = true; + } + } - if (mgx_type != mgx_s.type()) { - LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; - } - quant_params.add(name, migraphx::argument(local_param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); - } + if (!found_in_cache) { + // Batch size not found in pre-compiled cache - return error + // All batch sizes should be pre-compiled during Compile phase + LOGS_DEFAULT(ERROR) << "[Compute] Batch size " << requested_batch + << " not found in pre-compiled cache. 
" + << "Ensure max_dynamic_batch is set correctly and includes this batch size."; + + // List available batch sizes in cache for debugging + { + std::lock_guard lock(*mgx_state->batch_cache_mutex_ptr); + auto& batch_cache = *mgx_state->batch_program_cache_ptr; + std::ostringstream available; + available << "Available batch sizes: ["; + bool first = true; + for (const auto& kv : batch_cache) { + if (!first) available << ", "; + available << kv.first; + first = false; } + available << "]"; + LOGS_DEFAULT(ERROR) << "[Compute] " << available.str(); } - calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, - fp8_enable, int8_calibration_cache_available, map_dynamic_range); - compile_program(prog, t, exhaustive_tune_); - - // Save compiled model with batch-aware filename - LOGS_DEFAULT(VERBOSE) << "[Compute] Saving compiled model with updated batch size to: " - << model_cache_file.string(); - save_compiled_model(prog, model_cache_file); - } else { - LOGS_DEFAULT(VERBOSE) << "[Compute] Cache hit! Loaded precompiled model with matching batch size"; + + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "MIGraphX: Batch size ", requested_batch, " not pre-compiled. 
", + "Set ORT_MIGRAPHX_MAX_DYNAMIC_BATCH environment variable to at least ", requested_batch, + " to pre-compile this batch size during model loading."); } mgx_state->prog = prog; @@ -1917,19 +2085,21 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto prog_outputs = prog.run_async(m, static_cast(rocm_stream)); LOGS_DEFAULT(VERBOSE) << "[Compute] Execution complete, got " << prog_outputs.size() << " outputs"; - // Verify actual output shapes match expectations - for (std::size_t i = 0; i < prog_outputs.size(); ++i) { - auto actual_shape = prog_outputs[i].get_shape(); - auto actual_lens = actual_shape.lengths(); - std::ostringstream ss; - ss << "["; - for (size_t j = 0; j < actual_lens.size(); ++j) { - if (j > 0) ss << ", "; - ss << actual_lens[j]; + // Verify actual output shapes match expectations (only in verbose mode) + if (logging::LoggingManager::DefaultLogger().GetSeverity() <= logging::Severity::kVERBOSE) { + for (std::size_t i = 0; i < prog_outputs.size(); ++i) { + auto actual_shape = prog_outputs[i].get_shape(); + auto actual_lens = actual_shape.lengths(); + std::ostringstream ss; + ss << "["; + for (size_t j = 0; j < actual_lens.size(); ++j) { + if (j > 0) ss << ", "; + ss << actual_lens[j]; + } + ss << "]"; + LOGS_DEFAULT(VERBOSE) << "[Compute] Actual output " << i << " shape after execution: " << ss.str() + << (actual_lens.size() > 0 ? " (batch=" + std::to_string(actual_lens[0]) + ")" : ""); } - ss << "]"; - LOGS_DEFAULT(VERBOSE) << "[Compute] Actual output " << i << " shape after execution: " << ss.str() - << (actual_lens.size() > 0 ? 
" (batch=" + std::to_string(actual_lens[0]) + ")" : ""); } // In case of input parameters are reused as output parameter call hipMemcpy diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 12758b87b2cad..c94eecf2a95df 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -58,6 +58,10 @@ struct MIGraphXFuncState { bool dump_model_ops = false; bool exhaustive_tune = false; size_t max_dynamic_batch; + std::map* batch_program_cache_ptr = nullptr; + std::mutex* batch_cache_mutex_ptr = nullptr; + std::string node_name; + std::set session_input_names; // Track actual model inputs vs constants }; // Logical device representation. @@ -140,6 +144,11 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::unordered_map> map_input_index_; std::unordered_map map_no_input_shape_; + // Cache of compiled programs indexed by batch size for each node + // Key: node_name, Value: map of batch_size -> program + std::unordered_map> batch_program_cache_; + std::mutex batch_cache_mutex_; // Protect batch_program_cache_ + AllocatorPtr allocator_; std::unique_ptr metadef_id_generator_; void* external_alloc_{nullptr};