Skip to content

Commit 1665865

Browse files
committed
refactor with int32 indices, string_view parsing.
move fuse flag in exenetwork.
1 parent ab95e39 commit 1665865

File tree

10 files changed

+90
-91
lines changed

10 files changed

+90
-91
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ void BackendManager::RewindKVCache(size_t index) {
781781
}
782782
}
783783

784-
void BackendManager::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
784+
void BackendManager::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
785785
if (concrete_backend_) {
786786
concrete_backend_->ReorderKVCache(src_indices, dst_indices);
787787
}

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class BackendManager {
3131
void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
3232
ov::CompiledModel GetOVCompiledModel();
3333
void RewindKVCache(size_t index);
34-
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices);
34+
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices);
3535

3636
private:
3737
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ void BasicBackend::RewindKVCache(size_t index) {
315315
});
316316
}
317317

318-
void BasicBackend::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
318+
void BasicBackend::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
319319
infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
320320
infer_request->ReorderKVCache(src_indices, dst_indices);
321321
});

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class BasicBackend : public IBackend {
137137
return exe_network_.Get();
138138
}
139139
void RewindKVCache(size_t index) override;
140-
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
140+
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;
141141

142142
private:
143143
bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class IBackend {
1818
virtual ov::CompiledModel GetOVCompiledModel() = 0;
1919
virtual ~IBackend() = default;
2020
virtual void RewindKVCache(size_t index) {}
21-
virtual void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {}
21+
virtual void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {}
2222
};
2323
using ptr_stream_t = std::unique_ptr<ModelBlobWrapper>;
2424
class BackendFactory {

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <string>
66
#include <memory>
77
#include <vector>
8+
#include <cerrno>
89
#include "core/providers/shared_library/provider_api.h"
910
#include "core/providers/openvino/openvino_execution_provider.h"
1011
#include "core/providers/openvino/contexts.h"
@@ -295,53 +296,54 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
295296
"kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists");
296297
}
297298

298-
std::string src_string = value.substr(0, delimiter_pos);
299-
std::string dst_string = value.substr(delimiter_pos + 1);
300-
301-
auto parse_indices = [](const std::string& input, const std::string& index_type) -> std::pair<Status, std::vector<size_t>> {
302-
std::vector<size_t> indices;
303-
std::stringstream stream(input);
304-
std::string token;
305-
306-
try {
307-
while (std::getline(stream, token, ',')) {
308-
// Trim whitespace
309-
token.erase(0, token.find_first_not_of(" \t"));
310-
token.erase(token.find_last_not_of(" \t") + 1);
311-
312-
if (!token.empty()) {
313-
int64_t index = std::stoll(token);
314-
if (index >= 0) {
315-
indices.push_back(static_cast<size_t>(index));
316-
} else {
317-
return {Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
318-
"kvcache_reorder " + index_type + " cannot be negative: " + std::to_string(index)),
319-
std::vector<size_t>()};
320-
}
321-
}
299+
std::string_view src_string(value.begin(), value.begin() + delimiter_pos);
300+
std::string_view dst_string(value.begin() + delimiter_pos + 1, value.end());
301+
302+
constexpr auto parse_indices = [](std::string_view input, const std::string& index_type) -> std::variant<Status, std::vector<int32_t>> {
303+
std::vector<int32_t> indices;
304+
while (!input.empty()) {
305+
const auto delimiter_pos = input.find(',');
306+
const auto part = input.substr(0, delimiter_pos);
307+
errno = 0;
308+
char* parse_end = nullptr;
309+
// strtol already skips leading whitespace
310+
const auto index = std::strtol(part.data(), &parse_end, 10);
311+
if (parse_end == part.data()) {
312+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
313+
"Failed to parse kvcache_reorder " + index_type + ": " + std::string(part));
314+
}
315+
if (index < 0) {
316+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
317+
"kvcache_reorder " + index_type + " cannot be negative: " + std::string(part));
318+
}
319+
if (errno == ERANGE) {
320+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
321+
"kvcache_reorder " + index_type + " exceed INT32_MAX: " + std::string(part));
322+
}
323+
indices.push_back(static_cast<int32_t>(index));
324+
if (delimiter_pos != std::string_view::npos) {
325+
// ignore any trailing chars after the number; can do further checking if needed
326+
input.remove_prefix(part.size() + 1);
327+
} else {
328+
break;
322329
}
323-
} catch (const std::exception& e) {
324-
return {Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
325-
"Failed to parse kvcache_reorder " + index_type + ": " + std::string(e.what())),
326-
std::vector<size_t>()};
327330
}
328-
329-
return {Status::OK(), std::move(indices)};
331+
return indices;
330332
};
331333

332-
auto [src_status, src_indices] = parse_indices(src_string, "src_index");
333-
if (!src_status.IsOK()) {
334-
return src_status;
334+
const auto src_indices = parse_indices(src_string, "src_index");
335+
if (src_indices.index() == 0) {
336+
return std::get<0>(src_indices);
335337
}
336338

337-
auto [dst_status, dst_indices] = parse_indices(dst_string, "dst_index");
338-
if (!dst_status.IsOK()) {
339-
return dst_status;
339+
const auto dst_indices = parse_indices(dst_string, "dst_index");
340+
if (dst_indices.index() == 0) {
341+
return std::get<0>(dst_indices);
340342
}
341343

342344
// Trigger KVCache Reorder for target Backend with vector arguments
343345
for (auto& backend : backend_managers_) {
344-
backend.ReorderKVCache(src_indices, dst_indices);
346+
backend.ReorderKVCache(std::get<1>(src_indices), std::get<1>(dst_indices));
345347
}
346348
} else {
347349
// Handle unknown options

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,13 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
109109

110110
bool model_status = IsStateful(model);
111111
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
112+
// Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding
113+
bool is_fused_kvcache_reorder = false;
112114
if (!model_status) {
113115
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
114-
PatchStatefulDecoder(model, hw_target);
116+
// TO-DO: extend to NPU device when OpenVINO NPU has related optimization
117+
is_fused_kvcache_reorder = hw_target.find("GPU") != std::string::npos;
118+
PatchStatefulDecoder(model, is_fused_kvcache_reorder);
115119
}
116120

117121
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -152,7 +156,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
152156

153157
LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
154158
compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
155-
OVExeNetwork exe(compiled_model, hw_target, true);
159+
OVExeNetwork exe(compiled_model, hw_target, true, is_fused_kvcache_reorder);
156160
return exe;
157161
}
158162

@@ -332,7 +336,7 @@ std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
332336
auto infReq = compiled_model_obj.create_infer_request();
333337
std::shared_ptr<OVInferRequest> ovInfReq;
334338
if (is_stateful_causallm) {
335-
ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device);
339+
ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device, is_fused_kvcache_reorder);
336340
} else {
337341
ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
338342
}
@@ -377,10 +381,9 @@ void OVInferRequest::Infer() {
377381
"In Error Couldn't start Inference");
378382
}
379383

380-
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
381-
: OVInferRequest(std::move(infer_request)), target_device(device) {
384+
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder)
385+
: OVInferRequest(std::move(infer_request)), target_device(device), is_fused_kvcache_reorder(fused_kvcache_reorder) {
382386
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
383-
is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
384387

385388
_npu_logits_slice_required = IsNPULogitsSliceRequired();
386389

@@ -470,23 +473,23 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
470473
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
471474
FillTensor("beam_idx", ov::element::i32, {1}, 0);
472475

473-
if (is_support_kvcache_reorder){
476+
if (is_fused_kvcache_reorder){
474477
ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape();
475-
uint64_t kv_num_heads = dst_idx_shape[1];
476-
uint64_t kv_head_size = dst_idx_shape[3];
478+
const auto kv_num_heads = dst_idx_shape[1];
479+
const auto kv_head_size = dst_idx_shape[3];
477480
if (kv_src_indices.size() > 0) {
478481
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
479-
for (auto i = 0; i < kv_src_indices.size(); ++i) {
480-
src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
482+
const auto src_idx_ptr = src_idx_tensor.data<int32_t>();
483+
for (size_t i = 0; i < kv_src_indices.size(); ++i) {
484+
src_idx_ptr[i] = static_cast<int32_t>(kv_src_indices[i]);
481485
}
482486
ovInfReq.set_tensor("src_idx", src_idx_tensor);
483487

484488
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size});
485-
for (auto i = 0; i < kv_dst_indices.size(); ++i) {
486-
for (auto j = 0; j < kv_num_heads; ++j) {
487-
for (auto k = 0; k < kv_head_size; ++k) {
488-
dst_idx_tensor.data<int32_t>()[(j * kv_dst_indices.size() + i) * kv_head_size + k] = int32_t(kv_dst_indices[i]);
489-
}
489+
const auto dst_idx_ptr = dst_idx_tensor.data<int32_t>();
490+
for (size_t i = 0; i < kv_num_heads; ++i) {
491+
for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
492+
std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]);
490493
}
491494
}
492495
ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
@@ -536,13 +539,13 @@ void StatefulOVInferRequest::Infer() {
536539
}
537540

538541
void StatefulOVInferRequest::PostProcessInferRequest() {
539-
if(is_support_kvcache_reorder){
542+
if(is_fused_kvcache_reorder){
540543
kv_src_indices.clear();
541544
kv_dst_indices.clear();
542545
}
543546
}
544547

545-
void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
548+
void StatefulOVInferRequest::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
546549
// Validate input parameters
547550
if (src_indices.size() != dst_indices.size()) {
548551
ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
@@ -553,12 +556,8 @@ void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indic
553556
LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
554557
<< src_indices.size() << " index pairs";
555558

556-
kv_src_indices.clear();
557-
kv_dst_indices.clear();
558-
for (int i = 0; i < src_indices.size(); ++i) {
559-
kv_src_indices.emplace_back(src_indices[i]);
560-
kv_dst_indices.emplace_back(dst_indices[i]);
561-
}
559+
kv_src_indices = src_indices;
560+
kv_dst_indices = dst_indices;
562561
}
563562

564563
void StatefulOVInferRequest::RewindKVCache(size_t index) {

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,11 @@ class OVExeNetwork {
9191
ov::CompiledModel compiled_model_obj;
9292
std::string target_device;
9393
bool is_stateful_causallm;
94+
bool is_fused_kvcache_reorder = false;
9495

9596
public:
96-
explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false)
97-
: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm) {}
97+
explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false, bool fused_kvcache_reorder = false)
98+
: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm), is_fused_kvcache_reorder(fused_kvcache_reorder) {}
9899
OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {}
99100
ov::CompiledModel& Get() { return compiled_model_obj; }
100101
std::shared_ptr<OVInferRequest> CreateInferRequest();
@@ -136,16 +137,16 @@ class OVInferRequest {
136137
return ovInfReq;
137138
}
138139
virtual void RewindKVCache([[maybe_unused]] size_t index) {}
139-
virtual void ReorderKVCache([[maybe_unused]] const std::vector<size_t>& src_indices, [[maybe_unused]] const std::vector<size_t>& dst_indices) {}
140+
virtual void ReorderKVCache([[maybe_unused]] const std::vector<int32_t>& src_indices, [[maybe_unused]] const std::vector<int32_t>& dst_indices) {}
140141
};
141142

142143
class StatefulOVInferRequest : public OVInferRequest {
143144
public:
144-
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device);
145+
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder = false);
145146

146147
void Infer() override;
147148
void RewindKVCache(size_t index) override;
148-
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
149+
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;
149150
void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
150151
const std::vector<size_t>& shape, int32_t fill_value);
151152
void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
@@ -158,15 +159,16 @@ class StatefulOVInferRequest : public OVInferRequest {
158159
void PostProcessInferRequest();
159160
std::string target_device;
160161

161-
// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
162-
// and ensure that full chat history is passed for each prefill call.
163-
bool prefill_use_full_chat_history = false;
164162
std::vector<int64_t> cached_input_ids;
165163
std::vector<int64_t> cached_position_ids;
164+
std::vector<int32_t> kv_src_indices;
165+
std::vector<int32_t> kv_dst_indices;
166166

167-
bool is_support_kvcache_reorder = false;
168-
std::vector<int64_t> kv_src_indices;
169-
std::vector<int64_t> kv_dst_indices;
167+
// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
168+
// and ensure that full chat history is passed for each prefill call.
169+
bool prefill_use_full_chat_history = false;
170+
// If fused_kvcache_reorder, will include kv_src/dst_indices as input
171+
bool is_fused_kvcache_reorder = false;
170172

171173
bool IsNPULogitsSliceRequired();
172174
bool _npu_logits_slice_required = false;

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,11 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
7676
std::vector<std::string>& not_kv_inputs,
7777
const std::vector<std::string>& key_value_input_names,
7878
int gather_dim,
79-
const std::string& device) {
79+
const bool is_fused_kvcache_reorder) {
8080
if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
8181
throw std::runtime_error("Model already has fused cache");
8282
}
8383

84-
// Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding
85-
// TO-DO: extend to NPU device when OpenVINO NPU has related optimization
86-
bool is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
87-
8884
// Define input name candidates in priority order
8985
const std::vector<std::string> input_name_candidates = {
9086
"inputs_embeds", // Default fallback
@@ -107,7 +103,7 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
107103
std::shared_ptr<ov::opset13::Parameter> src_idx;
108104
std::shared_ptr<ov::opset13::Parameter> dst_idx;
109105

110-
if (is_support_kvcache_reorder) {
106+
if (is_fused_kvcache_reorder) {
111107
src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
112108
src_idx->set_friendly_name("src_idx");
113109
src_idx->output(0).get_tensor().add_names({"src_idx"});
@@ -132,16 +128,16 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
132128
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
133129

134130
std::shared_ptr<ov::Node> output_node;
135-
if (is_support_kvcache_reorder) {
131+
if (is_fused_kvcache_reorder) {
136132
auto updatekv_gather_op =
137133
std::make_shared<ov::opset13::Gather>(gather_op,
138134
src_idx,
139135
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
140136

141137
auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
142-
dst_idx,
143-
updatekv_gather_op,
144-
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
138+
dst_idx,
139+
updatekv_gather_op,
140+
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
145141
output_node = updatekv_op;
146142
} else {
147143
output_node = gather_op;
@@ -286,7 +282,7 @@ std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTens
286282
}
287283

288284
// Updated PatchStatefulDecoder function
289-
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const std::string& device) {
285+
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool is_fused_kvcache_reorder) {
290286
// Use the dynamic pattern-based extraction logic
291287
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
292288
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
@@ -308,7 +304,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const std::string& d
308304
// batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
309305
auto batch_dim = 0;
310306

311-
FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, device);
307+
FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, is_fused_kvcache_reorder);
312308

313309
MakeStateful(model, key_value_input_names, key_value_output_names);
314310
}

0 commit comments

Comments
 (0)