Skip to content

Commit 4e7c895

Browse files
committed
Refactor to use int32 indices and string_view-based parsing.
Move the fuse flag into OVExeNetwork.
1 parent 7d201fa commit 4e7c895

File tree

10 files changed

+90
-91
lines changed

10 files changed

+90
-91
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ void BackendManager::RewindKVCache(size_t index) {
892892
}
893893
}
894894

895-
void BackendManager::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
895+
void BackendManager::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
896896
if (concrete_backend_) {
897897
concrete_backend_->ReorderKVCache(src_indices, dst_indices);
898898
}

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class BackendManager {
3131
void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
3232
ov::CompiledModel GetOVCompiledModel();
3333
void RewindKVCache(size_t index);
34-
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices);
34+
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices);
3535

3636
private:
3737
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ void BasicBackend::RewindKVCache(size_t index) {
334334
});
335335
}
336336

337-
void BasicBackend::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
337+
void BasicBackend::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
338338
infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
339339
infer_request->ReorderKVCache(src_indices, dst_indices);
340340
});

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class BasicBackend : public IBackend {
138138
return exe_network_.Get();
139139
}
140140
void RewindKVCache(size_t index) override;
141-
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
141+
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;
142142

143143
private:
144144
bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class IBackend {
1818
virtual ov::CompiledModel GetOVCompiledModel() = 0;
1919
virtual ~IBackend() = default;
2020
virtual void RewindKVCache(size_t index) {}
21-
virtual void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {}
21+
virtual void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {}
2222
};
2323
using ptr_stream_t = std::unique_ptr<ModelBlobWrapper>;
2424
class BackendFactory {

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <string>
66
#include <memory>
77
#include <vector>
8+
#include <cerrno>
89
#include "core/providers/shared_library/provider_api.h"
910
#include "core/providers/openvino/openvino_execution_provider.h"
1011
#include "core/providers/openvino/contexts.h"
@@ -297,53 +298,54 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
297298
"kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists");
298299
}
299300

300-
std::string src_string = value.substr(0, delimiter_pos);
301-
std::string dst_string = value.substr(delimiter_pos + 1);
302-
303-
auto parse_indices = [](const std::string& input, const std::string& index_type) -> std::pair<Status, std::vector<size_t>> {
304-
std::vector<size_t> indices;
305-
std::stringstream stream(input);
306-
std::string token;
307-
308-
try {
309-
while (std::getline(stream, token, ',')) {
310-
// Trim whitespace
311-
token.erase(0, token.find_first_not_of(" \t"));
312-
token.erase(token.find_last_not_of(" \t") + 1);
313-
314-
if (!token.empty()) {
315-
int64_t index = std::stoll(token);
316-
if (index >= 0) {
317-
indices.push_back(static_cast<size_t>(index));
318-
} else {
319-
return {Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
320-
"kvcache_reorder " + index_type + " cannot be negative: " + std::to_string(index)),
321-
std::vector<size_t>()};
322-
}
323-
}
301+
std::string_view src_string(value.begin(), value.begin() + delimiter_pos);
302+
std::string_view dst_string(value.begin() + delimiter_pos + 1, value.end());
303+
304+
constexpr auto parse_indices = [](std::string_view input, const std::string& index_type) -> std::variant<Status, std::vector<int32_t>> {
305+
std::vector<int32_t> indices;
306+
while (!input.empty()) {
307+
const auto delimiter_pos = input.find(',');
308+
const auto part = input.substr(0, delimiter_pos);
309+
errno = 0;
310+
char* parse_end = nullptr;
311+
// strtol already skips leading whitespace
312+
const auto index = std::strtol(part.data(), &parse_end, 10);
313+
if (parse_end == part.data()) {
314+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
315+
"Failed to parse kvcache_reorder " + index_type + ": " + std::string(part));
316+
}
317+
if (index < 0) {
318+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
319+
"kvcache_reorder " + index_type + " cannot be negative: " + std::string(part));
320+
}
321+
if (errno == ERANGE) {
322+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
323+
"kvcache_reorder " + index_type + " exceed INT32_MAX: " + std::string(part));
324+
}
325+
indices.push_back(static_cast<int32_t>(index));
326+
if (delimiter_pos != std::string_view::npos) {
327+
// ignore any trailing chars after the number; can do further checking if needed
328+
input.remove_prefix(part.size() + 1);
329+
} else {
330+
break;
324331
}
325-
} catch (const std::exception& e) {
326-
return {Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
327-
"Failed to parse kvcache_reorder " + index_type + ": " + std::string(e.what())),
328-
std::vector<size_t>()};
329332
}
330-
331-
return {Status::OK(), std::move(indices)};
333+
return indices;
332334
};
333335

334-
auto [src_status, src_indices] = parse_indices(src_string, "src_index");
335-
if (!src_status.IsOK()) {
336-
return src_status;
336+
const auto src_indices = parse_indices(src_string, "src_index");
337+
if (src_indices.index() == 0) {
338+
return std::get<0>(src_indices);
337339
}
338340

339-
auto [dst_status, dst_indices] = parse_indices(dst_string, "dst_index");
340-
if (!dst_status.IsOK()) {
341-
return dst_status;
341+
const auto dst_indices = parse_indices(dst_string, "dst_index");
342+
if (dst_indices.index() == 0) {
343+
return std::get<0>(dst_indices);
342344
}
343345

344346
// Trigger KVCache Reorder for target Backend with vector arguments
345347
for (auto& backend : backend_managers_) {
346-
backend.ReorderKVCache(src_indices, dst_indices);
348+
backend.ReorderKVCache(std::get<1>(src_indices), std::get<1>(dst_indices));
347349
}
348350
} else {
349351
// Handle unknown options

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,13 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
104104

105105
bool model_status = IsStateful(model);
106106
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
107+
// Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding
108+
bool is_fused_kvcache_reorder = false;
107109
if (!model_status) {
108110
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
109-
PatchStatefulDecoder(model, hw_target);
111+
// TO-DO: extend to NPU device when OpenVINO NPU has related optimization
112+
is_fused_kvcache_reorder = hw_target.find("GPU") != std::string::npos;
113+
PatchStatefulDecoder(model, is_fused_kvcache_reorder);
110114
}
111115

112116
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -147,7 +151,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
147151

148152
LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
149153
compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
150-
OVExeNetwork exe(compiled_model, hw_target, true);
154+
OVExeNetwork exe(compiled_model, hw_target, true, is_fused_kvcache_reorder);
151155
return exe;
152156
}
153157

@@ -313,7 +317,7 @@ std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
313317
auto infReq = compiled_model_obj.create_infer_request();
314318
std::shared_ptr<OVInferRequest> ovInfReq;
315319
if (is_stateful_causallm) {
316-
ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device);
320+
ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device, is_fused_kvcache_reorder);
317321
} else {
318322
ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
319323
}
@@ -358,10 +362,9 @@ void OVInferRequest::Infer() {
358362
"In Error Couldn't start Inference");
359363
}
360364

361-
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
362-
: OVInferRequest(std::move(infer_request)), target_device(device) {
365+
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder)
366+
: OVInferRequest(std::move(infer_request)), target_device(device), is_fused_kvcache_reorder(fused_kvcache_reorder) {
363367
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
364-
is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
365368

366369
// check if there is input_ids tensors and if the tensor type is int64,
367370
// because logic prefill_use_full_chat_history is only for specific inputs and data type
@@ -424,23 +427,23 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
424427
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
425428
FillTensor("beam_idx", ov::element::i32, {1}, 0);
426429

427-
if (is_support_kvcache_reorder){
430+
if (is_fused_kvcache_reorder){
428431
ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape();
429-
uint64_t kv_num_heads = dst_idx_shape[1];
430-
uint64_t kv_head_size = dst_idx_shape[3];
432+
const auto kv_num_heads = dst_idx_shape[1];
433+
const auto kv_head_size = dst_idx_shape[3];
431434
if (kv_src_indices.size() > 0) {
432435
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
433-
for (auto i = 0; i < kv_src_indices.size(); ++i) {
434-
src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
436+
const auto src_idx_ptr = src_idx_tensor.data<int32_t>();
437+
for (size_t i = 0; i < kv_src_indices.size(); ++i) {
438+
src_idx_ptr[i] = static_cast<int32_t>(kv_src_indices[i]);
435439
}
436440
ovInfReq.set_tensor("src_idx", src_idx_tensor);
437441

438442
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size});
439-
for (auto i = 0; i < kv_dst_indices.size(); ++i) {
440-
for (auto j = 0; j < kv_num_heads; ++j) {
441-
for (auto k = 0; k < kv_head_size; ++k) {
442-
dst_idx_tensor.data<int32_t>()[(j * kv_dst_indices.size() + i) * kv_head_size + k] = int32_t(kv_dst_indices[i]);
443-
}
443+
const auto dst_idx_ptr = dst_idx_tensor.data<int32_t>();
444+
for (size_t i = 0; i < kv_num_heads; ++i) {
445+
for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
446+
std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]);
444447
}
445448
}
446449
ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
@@ -490,13 +493,13 @@ void StatefulOVInferRequest::Infer() {
490493
}
491494

492495
void StatefulOVInferRequest::PostProcessInferRequest() {
493-
if(is_support_kvcache_reorder){
496+
if(is_fused_kvcache_reorder){
494497
kv_src_indices.clear();
495498
kv_dst_indices.clear();
496499
}
497500
}
498501

499-
void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
502+
void StatefulOVInferRequest::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
500503
// Validate input parameters
501504
if (src_indices.size() != dst_indices.size()) {
502505
ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
@@ -507,12 +510,8 @@ void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indic
507510
LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
508511
<< src_indices.size() << " index pairs";
509512

510-
kv_src_indices.clear();
511-
kv_dst_indices.clear();
512-
for (int i = 0; i < src_indices.size(); ++i) {
513-
kv_src_indices.emplace_back(src_indices[i]);
514-
kv_dst_indices.emplace_back(dst_indices[i]);
515-
}
513+
kv_src_indices = src_indices;
514+
kv_dst_indices = dst_indices;
516515
}
517516

518517
void StatefulOVInferRequest::RewindKVCache(size_t index) {

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,11 @@ class OVExeNetwork {
8787
ov::CompiledModel compiled_model_obj;
8888
std::string target_device;
8989
bool is_stateful_causallm;
90+
bool is_fused_kvcache_reorder = false;
9091

9192
public:
92-
explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false)
93-
: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm) {}
93+
explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false, bool fused_kvcache_reorder = false)
94+
: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm), is_fused_kvcache_reorder(fused_kvcache_reorder) {}
9495
OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {}
9596
ov::CompiledModel& Get() { return compiled_model_obj; }
9697
std::shared_ptr<OVInferRequest> CreateInferRequest();
@@ -132,16 +133,16 @@ class OVInferRequest {
132133
return ovInfReq;
133134
}
134135
virtual void RewindKVCache([[maybe_unused]] size_t index) {}
135-
virtual void ReorderKVCache([[maybe_unused]] const std::vector<size_t>& src_indices, [[maybe_unused]] const std::vector<size_t>& dst_indices) {}
136+
virtual void ReorderKVCache([[maybe_unused]] const std::vector<int32_t>& src_indices, [[maybe_unused]] const std::vector<int32_t>& dst_indices) {}
136137
};
137138

138139
class StatefulOVInferRequest : public OVInferRequest {
139140
public:
140-
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device);
141+
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder = false);
141142

142143
void Infer() override;
143144
void RewindKVCache(size_t index) override;
144-
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
145+
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;
145146
void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
146147
const std::vector<size_t>& shape, int32_t fill_value);
147148
void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
@@ -153,15 +154,16 @@ class StatefulOVInferRequest : public OVInferRequest {
153154
void PostProcessInferRequest();
154155
std::string target_device;
155156

156-
// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
157-
// and ensure that full chat history is passed for each prefill call.
158-
bool prefill_use_full_chat_history = false;
159157
std::vector<int64_t> cached_input_ids;
160158
std::vector<int64_t> cached_position_ids;
159+
std::vector<int32_t> kv_src_indices;
160+
std::vector<int32_t> kv_dst_indices;
161161

162-
bool is_support_kvcache_reorder = false;
163-
std::vector<int64_t> kv_src_indices;
164-
std::vector<int64_t> kv_dst_indices;
162+
// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
163+
// and ensure that full chat history is passed for each prefill call.
164+
bool prefill_use_full_chat_history = false;
165+
// If is_fused_kvcache_reorder is set, kv_src_indices/kv_dst_indices are passed as model inputs
166+
bool is_fused_kvcache_reorder = false;
165167
};
166168

167169
} // namespace openvino_ep

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,11 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
7676
std::vector<std::string>& not_kv_inputs,
7777
const std::vector<std::string>& key_value_input_names,
7878
int gather_dim,
79-
const std::string& device) {
79+
const bool is_fused_kvcache_reorder) {
8080
if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
8181
throw std::runtime_error("Model already has fused cache");
8282
}
8383

84-
// Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding
85-
// TO-DO: extend to NPU device when OpenVINO NPU has related optimization
86-
bool is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
87-
8884
// Define input name candidates in priority order
8985
const std::vector<std::string> input_name_candidates = {
9086
"inputs_embeds", // Default fallback
@@ -107,7 +103,7 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
107103
std::shared_ptr<ov::opset13::Parameter> src_idx;
108104
std::shared_ptr<ov::opset13::Parameter> dst_idx;
109105

110-
if (is_support_kvcache_reorder) {
106+
if (is_fused_kvcache_reorder) {
111107
src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
112108
src_idx->set_friendly_name("src_idx");
113109
src_idx->output(0).get_tensor().add_names({"src_idx"});
@@ -132,16 +128,16 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
132128
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
133129

134130
std::shared_ptr<ov::Node> output_node;
135-
if (is_support_kvcache_reorder) {
131+
if (is_fused_kvcache_reorder) {
136132
auto updatekv_gather_op =
137133
std::make_shared<ov::opset13::Gather>(gather_op,
138134
src_idx,
139135
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
140136

141137
auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
142-
dst_idx,
143-
updatekv_gather_op,
144-
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
138+
dst_idx,
139+
updatekv_gather_op,
140+
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
145141
output_node = updatekv_op;
146142
} else {
147143
output_node = gather_op;
@@ -287,7 +283,7 @@ std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTens
287283
}
288284

289285
// Updated PatchStatefulDecoder function
290-
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const std::string& device) {
286+
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool is_fused_kvcache_reorder) {
291287
// Use the dynamic pattern-based extraction logic
292288
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
293289
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
@@ -309,7 +305,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const std::string& d
309305
// batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
310306
auto batch_dim = 0;
311307

312-
FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, device);
308+
FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, is_fused_kvcache_reorder);
313309

314310
MakeStateful(model, key_value_input_names, key_value_output_names);
315311
}

0 commit comments

Comments
 (0)