@@ -104,9 +104,13 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
 
   bool model_status = IsStateful(model);
   LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
+  // Flag to add a Gather + ScatterElementsUpdate subgraph that reorders the KV cache for LLM speculative decoding
+  bool is_fused_kvcache_reorder = false;
   if (!model_status) {
     LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
-    PatchStatefulDecoder(model, hw_target);
+    // TODO: extend to the NPU device once the OpenVINO NPU plugin supports the related optimization
+    is_fused_kvcache_reorder = hw_target.find("GPU") != std::string::npos;
+    PatchStatefulDecoder(model, is_fused_kvcache_reorder);
   }
 
   if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
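
For context on what the new flag enables, here is a minimal sketch of the kind of subgraph PatchStatefulDecoder could insert when it is set. This is not the actual patching code: the KV state layout ([batch, num_heads, seq_len, head_size]), the sequence axis, and the helper name `MakeKVCacheReorder` are assumptions. The Gather + ScatterElementsUpdate pairing and the `src_idx`/`dst_idx` inputs do match this diff; note that `dst_idx` arrives pre-broadcast to the full updates shape because ScatterElementsUpdate requires its indices and updates to have identical shapes.

```cpp
// Hypothetical sketch, not the actual PatchStatefulDecoder implementation.
// Assumes a KV state laid out as [batch, num_heads, seq_len, head_size],
// with entries moved along the sequence axis (axis = 2).
#include <memory>
#include <openvino/openvino.hpp>
#include <openvino/op/ops.hpp>

std::shared_ptr<ov::Node> MakeKVCacheReorder(const ov::Output<ov::Node>& kv_state,  // [1, H, L, S]
                                             const ov::Output<ov::Node>& src_idx,   // i32, [n]
                                             const ov::Output<ov::Node>& dst_idx) { // i32, [1, H, n, S]
  auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
  // Pull the n rows to move out of the cache: result is [1, H, n, S].
  auto gathered = std::make_shared<ov::op::v8::Gather>(kv_state, src_idx, axis);
  // Write them back at their new positions; dst_idx is pre-broadcast to the
  // updates shape since ScatterElementsUpdate needs indices == updates shape.
  return std::make_shared<ov::op::v12::ScatterElementsUpdate>(kv_state, dst_idx, gathered, axis);
}
```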
@@ -147,7 +151,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
 
   LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
-  OVExeNetwork exe(compiled_model, hw_target, true);
+  OVExeNetwork exe(compiled_model, hw_target, true, is_fused_kvcache_reorder);
   return exe;
 }
 
@@ -313,7 +317,7 @@ std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   auto infReq = compiled_model_obj.create_infer_request();
   std::shared_ptr<OVInferRequest> ovInfReq;
   if (is_stateful_causallm) {
-    ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device);
+    ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device, is_fused_kvcache_reorder);
   } else {
     ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
   }
@@ -358,10 +362,9 @@ void OVInferRequest::Infer() {
              "In Error Couldn't start Inference");
 }
 
-StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
-    : OVInferRequest(std::move(infer_request)), target_device(device) {
+StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder)
+    : OVInferRequest(std::move(infer_request)), target_device(device), is_fused_kvcache_reorder(fused_kvcache_reorder) {
   bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
-  is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
 
   // Check whether there is an input_ids tensor and whether its type is int64,
   // because the prefill_use_full_chat_history logic only applies to specific inputs and data types.
@@ -424,23 +427,23 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
   // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
   FillTensor("beam_idx", ov::element::i32, {1}, 0);
 
-  if (is_support_kvcache_reorder) {
+  if (is_fused_kvcache_reorder) {
     ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape();
-    uint64_t kv_num_heads = dst_idx_shape[1];
-    uint64_t kv_head_size = dst_idx_shape[3];
+    const auto kv_num_heads = dst_idx_shape[1];
+    const auto kv_head_size = dst_idx_shape[3];
     if (kv_src_indices.size() > 0) {
       ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
-      for (auto i = 0; i < kv_src_indices.size(); ++i) {
-        src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
+      const auto src_idx_ptr = src_idx_tensor.data<int32_t>();
+      for (size_t i = 0; i < kv_src_indices.size(); ++i) {
+        src_idx_ptr[i] = static_cast<int32_t>(kv_src_indices[i]);
       }
       ovInfReq.set_tensor("src_idx", src_idx_tensor);
 
       ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size});
-      for (auto i = 0; i < kv_dst_indices.size(); ++i) {
-        for (auto j = 0; j < kv_num_heads; ++j) {
-          for (auto k = 0; k < kv_head_size; ++k) {
-            dst_idx_tensor.data<int32_t>()[(j * kv_dst_indices.size() + i) * kv_head_size + k] = int32_t(kv_dst_indices[i]);
-          }
+      const auto dst_idx_ptr = dst_idx_tensor.data<int32_t>();
+      for (size_t i = 0; i < kv_num_heads; ++i) {
+        for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
+          std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]);
         }
       }
       ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
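
The dst_idx construction broadcasts each destination index across every head and across the head_size dimension of a row-major [1, H, n, S] tensor (element [0, i, j, k] sits at offset (i * n + j) * S + k). Since the rewritten fill pattern is easy to misread, here is a small self-contained sanity check (toy dimensions; hypothetical, not part of the PR) that the new per-row `std::fill_n` form writes exactly the same buffer as the original per-element triple loop:

```cpp
// Hypothetical standalone check, not part of this PR: verify the per-row
// std::fill_n form matches the per-element triple loop it replaces.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  const size_t H = 2, n = 3, S = 4;            // toy kv_num_heads, #indices, kv_head_size
  const std::vector<int32_t> dst = {7, 8, 9};  // toy kv_dst_indices
  std::vector<int32_t> a(H * n * S), b(H * n * S);
  // Original form: one element at a time.
  for (size_t i = 0; i < n; ++i)
    for (size_t j = 0; j < H; ++j)
      for (size_t k = 0; k < S; ++k)
        a[(j * n + i) * S + k] = dst[i];
  // Rewritten form: one head_size-long row at a time.
  for (size_t i = 0; i < H; ++i)
    for (size_t j = 0; j < n; ++j)
      std::fill_n(b.data() + (i * n + j) * S, S, dst[j]);
  assert(a == b);  // identical buffers, element for element
  return 0;
}
```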
@@ -490,13 +493,13 @@ void StatefulOVInferRequest::Infer() {
 }
 
 void StatefulOVInferRequest::PostProcessInferRequest() {
-  if (is_support_kvcache_reorder) {
+  if (is_fused_kvcache_reorder) {
     kv_src_indices.clear();
     kv_dst_indices.clear();
   }
 }
 
-void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+void StatefulOVInferRequest::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
   // Validate input parameters
   if (src_indices.size() != dst_indices.size()) {
     ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
@@ -507,12 +510,8 @@ void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indic
   LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
                      << src_indices.size() << " index pairs";
 
-  kv_src_indices.clear();
-  kv_dst_indices.clear();
-  for (int i = 0; i < src_indices.size(); ++i) {
-    kv_src_indices.emplace_back(src_indices[i]);
-    kv_dst_indices.emplace_back(dst_indices[i]);
-  }
+  kv_src_indices = src_indices;
+  kv_dst_indices = dst_indices;
 }
 
 void StatefulOVInferRequest::RewindKVCache(size_t index) {
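
To make the int32_t signature change concrete, a caller-side sketch (illustrative only; the real call site is not shown in this diff). During speculative decoding, when some draft tokens are rejected, the accepted cache entries are compacted leftward before the next step:

```cpp
#include <memory>
#include <vector>

// Illustrative fragment only; the actual call site is not part of this diff.
void CompactAfterDraftRejection(const std::shared_ptr<StatefulOVInferRequest>& req) {
  // Suppose the draft tokens at cache positions 5 and 6 were accepted while
  // those at positions 3 and 4 were rejected: move the survivors left so the
  // cache stays contiguous for the next decode step.
  const std::vector<int32_t> src_indices = {5, 6};
  const std::vector<int32_t> dst_indices = {3, 4};
  req->ReorderKVCache(src_indices, dst_indices);
  // The indices are staged here and consumed during the next Infer(), where
  // PreProcessInferRequest() packs them into the src_idx / dst_idx tensors
  // (and PostProcessInferRequest() clears them afterwards).
}
```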