6 changes: 6 additions & 0 deletions onnxruntime/core/providers/openvino/backend_manager.cc
@@ -781,5 +781,11 @@ void BackendManager::RewindKVCache(size_t index) {
}
}

void BackendManager::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
if (concrete_backend_) {
concrete_backend_->ReorderKVCache(src_indices, dst_indices);
}
}

} // namespace openvino_ep
} // namespace onnxruntime
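
For context, this new path is driven entirely through ONNX Runtime's dynamic EP options: the EP parses the option value (see openvino_execution_provider.cc below) and fans it out to every BackendManager, which forwards to the concrete backend and, from there, to each idle infer request. A minimal sketch of how an application might trigger it, assuming the Ort::Session::SetEpDynamicOptions entry point available in recent ONNX Runtime releases (the key and value strings are the ones parsed below):

#include <onnxruntime_cxx_api.h>

// Ask the OpenVINO EP to copy KV cache rows 1,2,3 onto rows 4,5,6
// before the next inference. "kvcache_reorder" is the dynamic option
// key handled in openvino_execution_provider.cc.
void RequestKVCacheReorder(Ort::Session& session) {
  const char* keys[] = {"kvcache_reorder"};
  const char* values[] = {"1,2,3;4,5,6"};
  session.SetEpDynamicOptions(keys, values, 1);
}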
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/backend_manager.h
@@ -31,6 +31,7 @@ class BackendManager {
void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
ov::CompiledModel GetOVCompiledModel();
void RewindKVCache(size_t index);
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices);

private:
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -315,6 +315,12 @@ void BasicBackend::RewindKVCache(size_t index) {
});
}

void BasicBackend::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
infer_request->ReorderKVCache(src_indices, dst_indices);
});
}

void BasicBackend::Infer(OrtKernelContext* ctx) const {
Ort::KernelContext context(ctx);

1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -137,6 +137,7 @@ class BasicBackend : public IBackend {
return exe_network_.Get();
}
void RewindKVCache(size_t index) override;
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;

private:
bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/ibackend.h
@@ -18,6 +18,7 @@ class IBackend {
virtual ov::CompiledModel GetOVCompiledModel() = 0;
virtual ~IBackend() = default;
virtual void RewindKVCache(size_t index) {}
virtual void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {}
};
using ptr_stream_t = std::unique_ptr<ModelBlobWrapper>;
class BackendFactory {
59 changes: 59 additions & 0 deletions onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -5,6 +5,7 @@
#include <string>
#include <memory>
#include <vector>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <variant>
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/openvino/openvino_execution_provider.h"
#include "core/providers/openvino/contexts.h"
@@ -286,6 +287,64 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
}
}
} else if (key == "kvcache_reorder") {
// Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors
// src_indices = [1,2,3], dst_indices = [4,5,6]
size_t delimiter_pos = value.find(';');
if (delimiter_pos == std::string::npos) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists");
}

std::string_view src_string(value.begin(), value.begin() + delimiter_pos);
std::string_view dst_string(value.begin() + delimiter_pos + 1, value.end());

constexpr auto parse_indices = [](std::string_view input, const std::string& index_type) -> std::variant<Status, std::vector<int32_t>> {
std::vector<int32_t> indices;
while (!input.empty()) {
const auto delimiter_pos = input.find(',');
const auto part = input.substr(0, delimiter_pos);
errno = 0;
char* parse_end = nullptr;
// std::strtol skips leading whitespace before parsing
const auto index = std::strtol(part.data(), &parse_end, 10);
if (parse_end == part.data()) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Failed to parse kvcache_reorder " + index_type + ": " + std::string(part));
}
if (errno == ERANGE || index > INT32_MAX) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"kvcache_reorder " + index_type + " exceeds INT32_MAX: " + std::string(part));
}
if (index < 0) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"kvcache_reorder " + index_type + " cannot be negative: " + std::string(part));
}
indices.push_back(static_cast<int32_t>(index));
if (delimiter_pos != std::string_view::npos) {
// Ignore any trailing characters between the number and the delimiter; further validation could be added if needed
input.remove_prefix(part.size() + 1);
} else {
break;
}
}
return indices;
};

const auto src_indices = parse_indices(src_string, "src_index");
if (src_indices.index() == 0) {
return std::get<0>(src_indices);
}

const auto dst_indices = parse_indices(dst_string, "dst_index");
if (dst_indices.index() == 0) {
return std::get<0>(dst_indices);
}

// Trigger a KV cache reorder on every backend with the parsed index vectors
for (auto& backend : backend_managers_) {
backend.ReorderKVCache(std::get<1>(src_indices), std::get<1>(dst_indices));
}
} else {
// Handle unknown options
LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
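
A worked trace of the parsing above, for value = "0,2;1,3":

// src_string = "0,2"  -> src_indices = {0, 2}
// dst_string = "1,3"  -> dst_indices = {1, 3}
// backend.ReorderKVCache({0, 2}, {1, 3});  // invoked on every BackendManager

Malformed inputs such as "0,2" (no ';'), "a,2;1,3" (non-numeric), or "-1;0" (negative) are rejected with INVALID_ARGUMENT. Note that the parser does not require the two lists to have equal length; that invariant is enforced later, in StatefulOVInferRequest::ReorderKVCache.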
63 changes: 58 additions & 5 deletions onnxruntime/core/providers/openvino/ov_interface.cc
@@ -109,9 +109,13 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,

bool model_status = IsStateful(model);
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
// Flag to add a Gather + ScatterElementsUpdate subgraph that reorders the KV cache for LLM speculative decoding
bool is_fused_kvcache_reorder = false;
if (!model_status) {
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
// TODO: extend to the NPU device once OpenVINO NPU has the corresponding optimization
is_fused_kvcache_reorder = hw_target.find("GPU") != std::string::npos;
PatchStatefulDecoder(model, is_fused_kvcache_reorder);
}

if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -152,7 +156,7 @@

LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
OVExeNetwork exe(compiled_model, hw_target, true, is_fused_kvcache_reorder);
return exe;
}

@@ -332,7 +336,7 @@ std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
auto infReq = compiled_model_obj.create_infer_request();
std::shared_ptr<OVInferRequest> ovInfReq;
if (is_stateful_causallm) {
ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device, is_fused_kvcache_reorder);
} else {
ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
}
@@ -377,8 +381,8 @@ void OVInferRequest::Infer() {
"In Error Couldn't start Inference");
}

StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder)
: OVInferRequest(std::move(infer_request)), target_device(device), is_fused_kvcache_reorder(fused_kvcache_reorder) {
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));

_npu_logits_slice_required = IsNPULogitsSliceRequired();
@@ -469,6 +473,32 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
FillTensor("beam_idx", ov::element::i32, {1}, 0);

if (is_fused_kvcache_reorder) {
ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape();
const auto kv_num_heads = dst_idx_shape[1];
const auto kv_head_size = dst_idx_shape[3];
if (!kv_src_indices.empty()) {
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
const auto src_idx_ptr = src_idx_tensor.data<int32_t>();
for (size_t i = 0; i < kv_src_indices.size(); ++i) {
src_idx_ptr[i] = static_cast<int32_t>(kv_src_indices[i]);
}
ovInfReq.set_tensor("src_idx", src_idx_tensor);

ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size});
const auto dst_idx_ptr = dst_idx_tensor.data<int32_t>();
for (size_t i = 0; i < kv_num_heads; ++i) {
for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]);
}
}
ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
} else {
FillTensor("src_idx", ov::element::i32, {0}, 0);
FillTensor("dst_idx", ov::element::i32, {1, kv_num_heads, 0, kv_head_size}, 0);
}
}

// If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids.
if (prefill_use_full_chat_history) {
auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
@@ -505,6 +535,29 @@
void StatefulOVInferRequest::Infer() {
PreProcessInferRequest();
OVInferRequest::Infer();
PostProcessInferRequest();
}

void StatefulOVInferRequest::PostProcessInferRequest() {
if (is_fused_kvcache_reorder) {
kv_src_indices.clear();
kv_dst_indices.clear();
}
}

void StatefulOVInferRequest::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
// Validate input parameters
if (src_indices.size() != dst_indices.size()) {
ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
"Got src_indices.size()=" + std::to_string(src_indices.size()) +
", dst_indices.size()=" + std::to_string(dst_indices.size()));
}

LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
<< src_indices.size() << " index pairs";

kv_src_indices = src_indices;
kv_dst_indices = dst_indices;
}

void StatefulOVInferRequest::RewindKVCache(size_t index) {
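
Worth noting for reviewers: ReorderKVCache never touches device state directly. It only records the index pairs; PreProcessInferRequest materializes them as the "src_idx"/"dst_idx" model inputs on the next Infer(), and PostProcessInferRequest clears them so the reorder is applied exactly once. The resulting call sequence, using the names from this file:

// req.ReorderKVCache({0, 2}, {1, 3});   // stores kv_src_indices / kv_dst_indices
// req.Infer();
//   -> PreProcessInferRequest(): sets "src_idx" = [0, 2] and a broadcast "dst_idx"
//      tensor of shape {1, kv_num_heads, 2, kv_head_size} filled with 1 and 3
//   -> OVInferRequest::Infer(): the fused Gather/ScatterElementsUpdate subgraph
//      rewrites the KV cache state on-device
//   -> PostProcessInferRequest(): clears both index vectors
// req.Infer();                           // later runs send empty index tensors (no-op)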
19 changes: 14 additions & 5 deletions onnxruntime/core/providers/openvino/ov_interface.h
@@ -91,10 +91,11 @@ class OVExeNetwork {
ov::CompiledModel compiled_model_obj;
std::string target_device;
bool is_stateful_causallm;
bool is_fused_kvcache_reorder = false;

public:
explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false, bool fused_kvcache_reorder = false)
: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm), is_fused_kvcache_reorder(fused_kvcache_reorder) {}
OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {}
ov::CompiledModel& Get() { return compiled_model_obj; }
std::shared_ptr<OVInferRequest> CreateInferRequest();
@@ -136,14 +137,16 @@ class OVInferRequest {
return ovInfReq;
}
virtual void RewindKVCache([[maybe_unused]] size_t index) {}
virtual void ReorderKVCache([[maybe_unused]] const std::vector<int32_t>& src_indices, [[maybe_unused]] const std::vector<int32_t>& dst_indices) {}
};

class StatefulOVInferRequest : public OVInferRequest {
public:
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder = false);

void Infer() override;
void RewindKVCache(size_t index) override;
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;
void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
const std::vector<size_t>& shape, int32_t fill_value);
void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
@@ -153,13 +156,19 @@ class StatefulOVInferRequest : public OVInferRequest {

private:
void PreProcessInferRequest();
void PostProcessInferRequest();
std::string target_device;

std::vector<int64_t> cached_input_ids;
std::vector<int64_t> cached_position_ids;
std::vector<int32_t> kv_src_indices;
std::vector<int32_t> kv_dst_indices;

// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
// and ensure that full chat history is passed for each prefill call.
bool prefill_use_full_chat_history = false;
// If is_fused_kvcache_reorder is true, kv_src_indices/kv_dst_indices are fed to the model
// as the additional "src_idx"/"dst_idx" inputs.
bool is_fused_kvcache_reorder = false;

bool IsNPULogitsSliceRequired();
bool _npu_logits_slice_required = false;
43 changes: 39 additions & 4 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -75,7 +75,8 @@ std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::vector<std::string>& not_kv_inputs,
const std::vector<std::string>& key_value_input_names,
int gather_dim,
const bool is_fused_kvcache_reorder) {
if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
throw std::runtime_error("Model already has fused cache");
}
@@ -91,13 +92,31 @@
std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);

auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();

auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
beam_idx->set_friendly_name("beam_idx");
beam_idx->output(0).get_tensor().add_names({"beam_idx"});
ov_model->add_parameters({beam_idx});
not_kv_inputs.push_back(beam_idx->get_friendly_name());

std::shared_ptr<ov::opset13::Parameter> src_idx;
std::shared_ptr<ov::opset13::Parameter> dst_idx;

if (is_fused_kvcache_reorder) {
src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
src_idx->set_friendly_name("src_idx");
src_idx->output(0).get_tensor().add_names({"src_idx"});
ov_model->add_parameters({src_idx});
not_kv_inputs.push_back(src_idx->get_friendly_name());

dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
dst_idx->set_friendly_name("dst_idx");
dst_idx->output(0).get_tensor().add_names({"dst_idx"});
ov_model->add_parameters({dst_idx});
not_kv_inputs.push_back(dst_idx->get_friendly_name());
}

// Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
for (const auto& input_name : key_value_input_names) {
auto parameter_output_port = ov_model->input(input_name);
@@ -108,9 +127,25 @@
beam_idx,
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));

std::shared_ptr<ov::Node> output_node;
if (is_fused_kvcache_reorder) {
auto updatekv_gather_op =
std::make_shared<ov::opset13::Gather>(gather_op,
src_idx,
ov::opset13::Constant::create(ov::element::i64, {}, {2}));

auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
dst_idx,
updatekv_gather_op,
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
output_node = updatekv_op;
} else {
output_node = gather_op;
}

// Replace the source output for all consumers of the input tensor
for (auto& consumer : consumers) {
consumer.replace_source_output(output_node->output(0));
}
}

@@ -247,7 +282,7 @@ std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTens
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool is_fused_kvcache_reorder) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
Expand All @@ -269,7 +304,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
auto batch_dim = 0;

FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, is_fused_kvcache_reorder);

MakeStateful(model, key_value_input_names, key_value_output_names);
}
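
To make the transformation concrete: for each KV cache state of shape [batch, num_heads, seq_len, head_size], the patched graph now computes kv = ScatterElementsUpdate(Gather(kv, beam_idx, axis=0), dst_idx, Gather(kv_beam, src_idx, axis=2), axis=2), i.e. the rows selected by src_idx are written over the rows selected by dst_idx along the sequence axis. A host-side sketch of the same semantics on a single head (illustration only, not code from this PR):

#include <cstddef>
#include <cstdint>
#include <vector>

// rows models one [seq_len, head_size] slice of the KV cache.
// rows[src_indices[i]] is copied onto rows[dst_indices[i]] for each pair.
void ReorderSeqRows(std::vector<std::vector<float>>& rows,
                    const std::vector<int32_t>& src_indices,
                    const std::vector<int32_t>& dst_indices) {
  // Gather all source rows first so overlapping src/dst pairs read a
  // consistent snapshot, mirroring the Gather -> ScatterElementsUpdate dataflow.
  std::vector<std::vector<float>> gathered;
  gathered.reserve(src_indices.size());
  for (int32_t s : src_indices) gathered.push_back(rows[s]);
  for (std::size_t i = 0; i < dst_indices.size(); ++i) rows[dst_indices[i]] = gathered[i];
}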
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
@@ -13,6 +13,7 @@

#include "openvino/pass/manager.hpp"
#include "openvino/pass/make_stateful.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/opsets/opset13.hpp"

namespace onnxruntime {
@@ -25,13 +26,14 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::vector<std::string>& not_kv_inputs,
const std::vector<std::string>& key_value_input_names,
int gather_dim,
const bool is_fused_kvcache_reorder = false);

void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
const std::vector<std::string>& key_value_input_names,
const std::vector<std::string>& key_value_output_names);

void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool is_fused_kvcache_reorder = false);

bool HasOpWithType(const std::shared_ptr<const ov::Model>& function, const std::string& type_name);
