diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index fa23f6969b633..5e80ee3738ed8 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -781,5 +781,11 @@ void BackendManager::RewindKVCache(size_t index) {
   }
 }
 
+void BackendManager::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+  if (concrete_backend_) {
+    concrete_backend_->ReorderKVCache(src_indices, dst_indices);
+  }
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h
index 9f560340a2033..f8a74b9cbcfa4 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.h
+++ b/onnxruntime/core/providers/openvino/backend_manager.h
@@ -31,6 +31,7 @@ class BackendManager {
   void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
   ov::CompiledModel GetOVCompiledModel();
   void RewindKVCache(size_t index);
+  void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices);
 
  private:
  std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
index 7c3ee7e76c3f9..7f4d1f74cfb7b 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -315,6 +315,12 @@ void BasicBackend::RewindKVCache(size_t index) {
   });
 }
 
+void BasicBackend::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+  infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
+    infer_request->ReorderKVCache(src_indices, dst_indices);
+  });
+}
+
 void BasicBackend::Infer(OrtKernelContext* ctx) const {
   Ort::KernelContext context(ctx);
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h
index 7639e024c52cb..c7505d59eec0c 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.h
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -137,6 +137,7 @@ class BasicBackend : public IBackend {
     return exe_network_.Get();
   }
   void RewindKVCache(size_t index) override;
+  void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
 
  private:
  bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h
index 365a4625815d6..4444f37ac7433 100644
--- a/onnxruntime/core/providers/openvino/ibackend.h
+++ b/onnxruntime/core/providers/openvino/ibackend.h
@@ -18,6 +18,7 @@ class IBackend {
   virtual ov::CompiledModel GetOVCompiledModel() = 0;
   virtual ~IBackend() = default;
   virtual void RewindKVCache(size_t index) {}
+  virtual void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {}
 };
 using ptr_stream_t = std::unique_ptr<std::istream>;
 class BackendFactory {
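Taken together, the five files above thread one entry point through the provider's layering: `BackendManager::ReorderKVCache` forwards to the concrete backend, the `IBackend` default is a no-op so backends without a KV cache are unaffected, and `BasicBackend` fans the call out to every idle infer request. A minimal standalone sketch of that dispatch shape; the `Request`/`Backend` types here are illustrative stand-ins, not the real OVEP classes:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for OVInferRequest: stateful requests override ReorderKVCache,
// everything else inherits the no-op default (mirroring IBackend's default).
struct Request {
  virtual ~Request() = default;
  virtual void ReorderKVCache(const std::vector<std::size_t>&, const std::vector<std::size_t>&) {}
};

struct StatefulRequest : Request {
  void ReorderKVCache(const std::vector<std::size_t>& src, const std::vector<std::size_t>& /*dst*/) override {
    std::cout << "recording a reorder of " << src.size() << " KV rows\n";
  }
};

// Stand-in for BasicBackend: fans the call out to every idle request,
// like infer_req_pool_->forEachIdleRequest does in the diff above.
struct Backend {
  std::vector<std::shared_ptr<Request>> idle_requests;
  void ReorderKVCache(const std::vector<std::size_t>& src, const std::vector<std::size_t>& dst) {
    for (auto& req : idle_requests) req->ReorderKVCache(src, dst);
  }
};

int main() {
  Backend backend;
  backend.idle_requests.push_back(std::make_shared<StatefulRequest>());
  backend.ReorderKVCache({0, 1, 2}, {5, 6, 7});  // prints: recording a reorder of 3 KV rows
}
```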
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
index a099f85b2a4b9..b7b0894d7bff7 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <variant>
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/openvino/openvino_execution_provider.h"
 #include "core/providers/openvino/contexts.h"
@@ -286,6 +287,64 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span
+    } else if (key == "kvcache_reorder") {
+      // Value format: "<src_indices>;<dst_indices>", each a comma-separated list
+      const std::string_view value_view{value};
+      const auto separator_pos = value_view.find(';');
+      if (separator_pos == std::string_view::npos) {
+        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                      "kvcache_reorder value must contain both src and dst index lists");
+      }
+      const auto src_string = value_view.substr(0, separator_pos);
+      const auto dst_string = value_view.substr(separator_pos + 1);
+
+      const auto parse_indices = [](std::string_view input, const std::string& index_type) -> std::variant<Status, std::vector<size_t>> {
+        std::vector<size_t> indices;
+        while (!input.empty()) {
+          const auto delimiter_pos = input.find(',');
+          const auto part = input.substr(0, delimiter_pos);
+          errno = 0;
+          char* parse_end = nullptr;
+          // strtol already skips leading whitespace
+          const auto index = std::strtol(part.data(), &parse_end, 10);
+          if (parse_end == part.data()) {
+            return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                          "Failed to parse kvcache_reorder " + index_type + ": " + std::string(part));
+          }
+          if (index < 0) {
+            return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                          "kvcache_reorder " + index_type + " cannot be negative: " + std::string(part));
+          }
+          if (errno == ERANGE) {
+            return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                          "kvcache_reorder " + index_type + " exceeds INT32_MAX: " + std::string(part));
+          }
+          indices.push_back(static_cast<size_t>(index));
+          if (delimiter_pos != std::string_view::npos) {
+            // ignore any trailing chars after the number; further checking can be added if needed
+            input.remove_prefix(part.size() + 1);
+          } else {
+            break;
+          }
+        }
+        return indices;
+      };
+
+      const auto src_indices = parse_indices(src_string, "src_index");
+      if (src_indices.index() == 0) {
+        return std::get<0>(src_indices);
+      }
+
+      const auto dst_indices = parse_indices(dst_string, "dst_index");
+      if (dst_indices.index() == 0) {
+        return std::get<0>(dst_indices);
+      }
+
+      // Trigger the KVCache reorder on each backend with the parsed index vectors
+      for (auto& backend : backend_managers_) {
+        backend.ReorderKVCache(std::get<1>(src_indices), std::get<1>(dst_indices));
+      }
     } else {
       // Handle unknown options
       LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
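The `parse_indices` lambda rejects anything that is not a non-negative decimal integer but deliberately tolerates leading whitespace and trailing junk after a number. A self-contained sketch of the same strtol-based loop (`ParseIndices` is a hypothetical name, and the `std::vector<size_t>` element type is assumed from the rest of this change):

```cpp
#include <cerrno>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Parse a comma-separated list of non-negative integers, as the
// kvcache_reorder handler above does; nullopt signals a parse error.
std::optional<std::vector<std::size_t>> ParseIndices(std::string_view input) {
  std::vector<std::size_t> indices;
  while (!input.empty()) {
    const auto delimiter_pos = input.find(',');
    const auto part = input.substr(0, delimiter_pos);
    errno = 0;
    char* parse_end = nullptr;
    const long index = std::strtol(part.data(), &parse_end, 10);  // skips leading whitespace
    if (parse_end == part.data() || index < 0 || errno == ERANGE) return std::nullopt;
    indices.push_back(static_cast<std::size_t>(index));
    if (delimiter_pos == std::string_view::npos) break;
    input.remove_prefix(part.size() + 1);
  }
  return indices;
}

int main() {
  for (auto s : {" 0, 1,2", "3,x", "-1"}) {
    auto parsed = ParseIndices(s);
    std::cout << '"' << s << "\" -> "
              << (parsed ? std::to_string(parsed->size()) + " indices" : "error") << '\n';
  }
}
```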
"True" : "False"); + // Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding + bool is_fused_kvcache_reorder = false; if (!model_status) { LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; - PatchStatefulDecoder(model); + // TO-DO: extend to NPU device when OpenVINO NPU has related optimization + is_fused_kvcache_reorder = hw_target.find("GPU") != std::string::npos; + PatchStatefulDecoder(model, is_fused_kvcache_reorder); } if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { @@ -152,7 +156,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow"; compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config); - OVExeNetwork exe(compiled_model, hw_target, true); + OVExeNetwork exe(compiled_model, hw_target, true, is_fused_kvcache_reorder); return exe; } @@ -332,7 +336,7 @@ std::shared_ptr OVExeNetwork::CreateInferRequest() { auto infReq = compiled_model_obj.create_infer_request(); std::shared_ptr ovInfReq; if (is_stateful_causallm) { - ovInfReq = std::make_shared(std::move(infReq), target_device); + ovInfReq = std::make_shared(std::move(infReq), target_device, is_fused_kvcache_reorder); } else { ovInfReq = std::make_shared(std::move(infReq)); } @@ -377,8 +381,8 @@ void OVInferRequest::Infer() { "In Error Couldn't start Inference"); } -StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) - : OVInferRequest(std::move(infer_request)), target_device(device) { +StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder) + : OVInferRequest(std::move(infer_request)), target_device(device), is_fused_kvcache_reorder(fused_kvcache_reorder) { bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); _npu_logits_slice_required = IsNPULogitsSliceRequired(); @@ -469,6 +473,32 @@ void StatefulOVInferRequest::PreProcessInferRequest() { // TODO(ankit): Address this issue and implement the fix at the appropriate layer. FillTensor("beam_idx", ov::element::i32, {1}, 0); + if (is_fused_kvcache_reorder){ + ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape(); + const auto kv_num_heads = dst_idx_shape[1]; + const auto kv_head_size = dst_idx_shape[3]; + if (kv_src_indices.size() > 0) { + ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()}); + const auto src_idx_ptr = src_idx_tensor.data(); + for (size_t i = 0; i < kv_src_indices.size(); ++i) { + src_idx_ptr[i] = static_cast(kv_src_indices[i]); + } + ovInfReq.set_tensor("src_idx", src_idx_tensor); + + ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size}); + const auto dst_idx_ptr = dst_idx_tensor.data(); + for (size_t i = 0; i < kv_num_heads; ++i) { + for (size_t j = 0; j < kv_dst_indices.size(); ++j) { + std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]); + } + } + ovInfReq.set_tensor("dst_idx", dst_idx_tensor); + } else { + FillTensor("src_idx", ov::element::i32, {0}, 0); + FillTensor("dst_idx", ov::element::i32, {1, kv_num_heads, 0, kv_head_size}, 0); + } + } + // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. 
@@ -505,6 +535,29 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
 void StatefulOVInferRequest::Infer() {
   PreProcessInferRequest();
   OVInferRequest::Infer();
+  PostProcessInferRequest();
+}
+
+void StatefulOVInferRequest::PostProcessInferRequest() {
+  if (is_fused_kvcache_reorder) {
+    kv_src_indices.clear();
+    kv_dst_indices.clear();
+  }
+}
+
+void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+  // Validate input parameters
+  if (src_indices.size() != dst_indices.size()) {
+    ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
+              "Got src_indices.size()=" + std::to_string(src_indices.size()) +
+              ", dst_indices.size()=" + std::to_string(dst_indices.size()));
+  }
+
+  LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
+                     << src_indices.size() << " index pairs";
+
+  kv_src_indices = src_indices;
+  kv_dst_indices = dst_indices;
 }
 
 void StatefulOVInferRequest::RewindKVCache(size_t index) {
diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h
index aa4b3fbe64898..2b61a7d603be6 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.h
+++ b/onnxruntime/core/providers/openvino/ov_interface.h
@@ -91,10 +91,11 @@ class OVExeNetwork {
   ov::CompiledModel compiled_model_obj;
   std::string target_device;
   bool is_stateful_causallm;
+  bool is_fused_kvcache_reorder = false;
 
  public:
-  explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false)
-      : compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm) {}
+  explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false, bool fused_kvcache_reorder = false)
+      : compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm), is_fused_kvcache_reorder(fused_kvcache_reorder) {}
   OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {}
   ov::CompiledModel& Get() { return compiled_model_obj; }
   std::shared_ptr<OVInferRequest> CreateInferRequest();
@@ -136,14 +137,16 @@ class OVInferRequest {
     return ovInfReq;
   }
   virtual void RewindKVCache([[maybe_unused]] size_t index) {}
+  virtual void ReorderKVCache([[maybe_unused]] const std::vector<size_t>& src_indices, [[maybe_unused]] const std::vector<size_t>& dst_indices) {}
 };
 
 class StatefulOVInferRequest : public OVInferRequest {
  public:
-  explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device);
+  explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder = false);
 
   void Infer() override;
   void RewindKVCache(size_t index) override;
+  void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
   void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
                   const std::vector<size_t>& shape, int32_t fill_value);
   void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
@@ -153,13 +156,19 @@ class StatefulOVInferRequest : public OVInferRequest {
 
  private:
   void PreProcessInferRequest();
+  void PostProcessInferRequest();
   std::string target_device;
+  std::vector<int64_t> cached_input_ids;
+  std::vector<int64_t> cached_position_ids;
+  std::vector<size_t> kv_src_indices;
+  std::vector<size_t> kv_dst_indices;
+
   // If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
   // and ensure that full chat history is passed for each prefill call.
   bool prefill_use_full_chat_history = false;
-  std::vector<int64_t> cached_input_ids;
-  std::vector<int64_t> cached_position_ids;
+  // If is_fused_kvcache_reorder is true, kv_src/dst_indices are fed to the model as the "src_idx"/"dst_idx" inputs
+  bool is_fused_kvcache_reorder = false;
 
   bool IsNPULogitsSliceRequired();
   bool _npu_logits_slice_required = false;
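Note the lifecycle these declarations imply: `ReorderKVCache` only records the index pairs, the next `Infer()` binds them during pre-processing, and `PostProcessInferRequest` clears them so each reorder is consumed by exactly one inference. A simplified stand-in (not the real class) showing that deferred-apply pattern:

```cpp
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Sketch of the deferred-apply lifecycle used by StatefulOVInferRequest.
struct RequestSketch {
  std::vector<std::size_t> kv_src, kv_dst;

  void ReorderKVCache(std::vector<std::size_t> src, std::vector<std::size_t> dst) {
    kv_src = std::move(src);  // record only; no device work yet
    kv_dst = std::move(dst);
  }

  void Infer() {
    // PreProcess: bind src_idx/dst_idx (empty vectors -> zero-sized tensors, a no-op reorder)
    std::cout << "infer with " << kv_src.size() << " reorder pairs\n";
    // PostProcess: clear so the reorder applies to exactly one inference
    kv_src.clear();
    kv_dst.clear();
  }
};

int main() {
  RequestSketch req;
  req.ReorderKVCache({0, 1}, {4, 5});
  req.Infer();  // infer with 2 reorder pairs
  req.Infer();  // infer with 0 reorder pairs
}
```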
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
index fd2b5797a1f40..770c371c399b8 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -75,7 +75,8 @@ std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
 void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
                       std::vector<std::string>& not_kv_inputs,
                       const std::vector<std::string>& key_value_input_names,
-                      int gather_dim) {
+                      int gather_dim,
+                      const bool is_fused_kvcache_reorder) {
   if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
     throw std::runtime_error("Model already has fused cache");
   }
@@ -91,6 +92,7 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
   std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
 
   auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
+  auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();
 
   auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
   beam_idx->set_friendly_name("beam_idx");
@@ -98,6 +100,23 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
   ov_model->add_parameters({beam_idx});
   not_kv_inputs.push_back(beam_idx->get_friendly_name());
 
+  std::shared_ptr<ov::opset13::Parameter> src_idx;
+  std::shared_ptr<ov::opset13::Parameter> dst_idx;
+
+  if (is_fused_kvcache_reorder) {
+    src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
+    src_idx->set_friendly_name("src_idx");
+    src_idx->output(0).get_tensor().add_names({"src_idx"});
+    ov_model->add_parameters({src_idx});
+    not_kv_inputs.push_back(src_idx->get_friendly_name());
+
+    dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
+    dst_idx->set_friendly_name("dst_idx");
+    dst_idx->output(0).get_tensor().add_names({"dst_idx"});
+    ov_model->add_parameters({dst_idx});
+    not_kv_inputs.push_back(dst_idx->get_friendly_name());
+  }
+
   // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
   for (const auto& input_name : key_value_input_names) {
     auto parameter_output_port = ov_model->input(input_name);
@@ -108,9 +127,25 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
     auto consumers = parameter_output_port.get_target_inputs();
     auto gather_op = std::make_shared<ov::opset13::Gather>(parameter_output_port,
                                                            beam_idx,
                                                            ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
 
+    std::shared_ptr<ov::Node> output_node;
+    if (is_fused_kvcache_reorder) {
+      auto updatekv_gather_op =
+          std::make_shared<ov::opset13::Gather>(gather_op,
+                                                src_idx,
+                                                ov::opset13::Constant::create(ov::element::i64, {}, {2}));
+
+      auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
+                                                                              dst_idx,
+                                                                              updatekv_gather_op,
+                                                                              ov::opset13::Constant::create(ov::element::i64, {}, {2}));
+      output_node = updatekv_op;
+    } else {
+      output_node = gather_op;
+    }
+
     // Replace the source output for all consumers of the input tensor
     for (auto& consumer : consumers) {
-      consumer.replace_source_output(gather_op->output(0));
+      consumer.replace_source_output(output_node->output(0));
     }
   }
@@ -247,7 +282,7 @@ std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTens
 }
 
 // Updated PatchStatefulDecoder function
-void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
+void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool is_fused_kvcache_reorder) {
   // Use the dynamic pattern-based extraction logic
   auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
   auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
@@ -269,7 +304,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
   //     batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
   auto batch_dim = 0;
 
-  FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim);
+  FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, is_fused_kvcache_reorder);
 
   MakeStateful(model, key_value_input_names, key_value_output_names);
 }
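Semantically, the fused `Gather` + `ScatterElementsUpdate` pair computes `kv[:, :, dst[j], :] = kv[:, :, src[j], :]` on every KV state while leaving all other sequence rows untouched, which is exactly the row copy speculative decoding needs when accepted draft tokens land at new offsets. The same effect on a plain array, reduced to one head and `head_size == 1` (illustrative values only):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // One head, head_size 1: the KV "cache" is just one value per sequence slot.
  std::vector<int> kv = {10, 11, 12, 13, 14, 15};
  const std::vector<std::size_t> src = {1, 2}, dst = {4, 5};

  // Gather(kv, src, axis=2) selects the source rows; ScatterElementsUpdate(kv, dst,
  // gathered, axis=2) writes them back at the destination offsets: kv[dst[j]] = kv[src[j]].
  std::vector<int> gathered(src.size());
  for (std::size_t j = 0; j < src.size(); ++j) gathered[j] = kv[src[j]];
  for (std::size_t j = 0; j < dst.size(); ++j) kv[dst[j]] = gathered[j];

  for (int v : kv) std::cout << v << ' ';  // 10 11 12 13 11 12
  std::cout << '\n';
}
```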
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
index 0b89c4ed02e13..bfb6224fc8993 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
@@ -13,6 +13,7 @@
 
 #include "openvino/pass/manager.hpp"
 #include "openvino/pass/make_stateful.hpp"
+#include "openvino/opsets/opset12.hpp"
 #include "openvino/opsets/opset13.hpp"
 
 namespace onnxruntime {
@@ -25,13 +26,14 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
 void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
                       std::vector<std::string>& not_kv_inputs,
                       const std::vector<std::string>& key_value_input_names,
-                      int gather_dim);
+                      int gather_dim,
+                      const bool is_fused_kvcache_reorder = false);
 
 void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
                   const std::vector<std::string>& key_value_input_names,
                   const std::vector<std::string>& key_value_output_names);
 
-void PatchStatefulDecoder(std::shared_ptr<ov::Model> model);
+void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool is_fused_kvcache_reorder = false);
 
 bool HasOpWithType(const std::shared_ptr<const ov::Model>& function, const std::string& type_name);