@@ -109,9 +109,13 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
 
   bool model_status = IsStateful(model);
   LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
+  // Flag to add a Gather + ScatterElementsUpdate subgraph that reorders the KV cache for LLM speculative decoding
+  bool is_fused_kvcache_reorder = false;
   if (!model_status) {
     LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
-    PatchStatefulDecoder(model, hw_target);
+    // TODO: extend to the NPU device once OpenVINO NPU has the corresponding optimization
+    is_fused_kvcache_reorder = hw_target.find("GPU") != std::string::npos;
+    PatchStatefulDecoder(model, is_fused_kvcache_reorder);
   }
 
   if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
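The new flag gates the fused KV-cache reorder to GPU and is threaded from here through `OVExeNetwork` into the infer request. As a conceptual sketch of what the fused subgraph computes (an assumption for illustration; the real graph rewrite lives in `PatchStatefulDecoder`, which is not shown in this diff), the reorder can be expressed with standard OpenVINO ops: a `Gather` pulls the KV rows addressed by `src_idx` off the sequence axis, and a `ScatterElementsUpdate` writes them back at the positions encoded in `dst_idx`.

// Conceptual sketch (assumption, not the actual PatchStatefulDecoder rewrite):
// reorder a KV-cache state of shape [batch, num_heads, seq_len, head_size]
// along the sequence axis (axis 2).
#include <memory>
#include <openvino/op/constant.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/scatter_elements_update.hpp>

std::shared_ptr<ov::Node> MakeKVCacheReorder(const ov::Output<ov::Node>& kv_state,
                                             const ov::Output<ov::Node>& src_idx,   // i32, shape [n]
                                             const ov::Output<ov::Node>& dst_idx) { // i32, shape [1, heads, n, head_size]
  const auto seq_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {2});
  // Gather the KV rows that have to move; result shape is [1, heads, n, head_size].
  const auto gathered = std::make_shared<ov::op::v8::Gather>(kv_state, src_idx, seq_axis);
  // Scatter them back at their destination sequence positions; ScatterElementsUpdate
  // requires the indices tensor to match the updates shape, which is why dst_idx is
  // pre-broadcast on the host side (see PreProcessInferRequest below).
  return std::make_shared<ov::op::v3::ScatterElementsUpdate>(kv_state, dst_idx, gathered, seq_axis);
}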
@@ -152,7 +156,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
 
   LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
-  OVExeNetwork exe(compiled_model, hw_target, true);
+  OVExeNetwork exe(compiled_model, hw_target, true, is_fused_kvcache_reorder);
   return exe;
 }
 
@@ -332,7 +336,7 @@ std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   auto infReq = compiled_model_obj.create_infer_request();
   std::shared_ptr<OVInferRequest> ovInfReq;
   if (is_stateful_causallm) {
-    ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device);
+    ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device, is_fused_kvcache_reorder);
   } else {
     ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
   }
@@ -377,10 +381,9 @@ void OVInferRequest::Infer() {
             "In Error Couldn't start Inference");
 }
 
-StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
-    : OVInferRequest(std::move(infer_request)), target_device(device) {
+StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool fused_kvcache_reorder)
+    : OVInferRequest(std::move(infer_request)), target_device(device), is_fused_kvcache_reorder(fused_kvcache_reorder) {
   bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
-  is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
 
   _npu_logits_slice_required = IsNPULogitsSliceRequired();
 
@@ -470,23 +473,23 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
   // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
   FillTensor("beam_idx", ov::element::i32, {1}, 0);
 
-  if (is_support_kvcache_reorder) {
+  if (is_fused_kvcache_reorder) {
     ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape();
-    uint64_t kv_num_heads = dst_idx_shape[1];
-    uint64_t kv_head_size = dst_idx_shape[3];
+    const auto kv_num_heads = dst_idx_shape[1];
+    const auto kv_head_size = dst_idx_shape[3];
     if (kv_src_indices.size() > 0) {
       ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
-      for (auto i = 0; i < kv_src_indices.size(); ++i) {
-        src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
+      const auto src_idx_ptr = src_idx_tensor.data<int32_t>();
+      for (size_t i = 0; i < kv_src_indices.size(); ++i) {
+        src_idx_ptr[i] = static_cast<int32_t>(kv_src_indices[i]);
       }
       ovInfReq.set_tensor("src_idx", src_idx_tensor);
 
       ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size});
-      for (auto i = 0; i < kv_dst_indices.size(); ++i) {
-        for (auto j = 0; j < kv_num_heads; ++j) {
-          for (auto k = 0; k < kv_head_size; ++k) {
-            dst_idx_tensor.data<int32_t>()[(j * kv_dst_indices.size() + i) * kv_head_size + k] = int32_t(kv_dst_indices[i]);
-          }
+      const auto dst_idx_ptr = dst_idx_tensor.data<int32_t>();
+      for (size_t i = 0; i < kv_num_heads; ++i) {
+        for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
+          std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]);
         }
       }
       ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
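A note on the `dst_idx` layout built above: its shape `{1, kv_num_heads, kv_dst_indices.size(), kv_head_size}` matches the output of the `Gather` over `src_idx`, and every element of row `j` holds the same destination position `kv_dst_indices[j]`, because `ScatterElementsUpdate` expects an indices tensor with the same shape as the updates. A minimal standalone sketch (illustrative shapes, not part of the patch) that reproduces the fill and spot-checks the flattened offset math:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const size_t kv_num_heads = 2, kv_head_size = 4;
  const std::vector<int32_t> kv_dst_indices = {5, 7, 8};  // hypothetical destination slots

  // Flattened [1, kv_num_heads, kv_dst_indices.size(), kv_head_size] buffer.
  std::vector<int32_t> dst_idx(kv_num_heads * kv_dst_indices.size() * kv_head_size);
  for (size_t i = 0; i < kv_num_heads; ++i) {
    for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
      std::fill_n(dst_idx.data() + (i * kv_dst_indices.size() + j) * kv_head_size,
                  kv_head_size, kv_dst_indices[j]);
    }
  }

  // Every element of head i, row j equals kv_dst_indices[j], independent of i and k.
  for (size_t i = 0; i < kv_num_heads; ++i)
    for (size_t j = 0; j < kv_dst_indices.size(); ++j)
      for (size_t k = 0; k < kv_head_size; ++k)
        assert(dst_idx[(i * kv_dst_indices.size() + j) * kv_head_size + k] == kv_dst_indices[j]);
  return 0;
}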
@@ -536,13 +539,13 @@ void StatefulOVInferRequest::Infer() {
 }
 
 void StatefulOVInferRequest::PostProcessInferRequest() {
-  if (is_support_kvcache_reorder) {
+  if (is_fused_kvcache_reorder) {
     kv_src_indices.clear();
     kv_dst_indices.clear();
   }
 }
 
-void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+void StatefulOVInferRequest::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {
   // Validate input parameters
   if (src_indices.size() != dst_indices.size()) {
     ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
@@ -553,12 +556,8 @@ void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indic
   LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
                      << src_indices.size() << " index pairs";
 
-  kv_src_indices.clear();
-  kv_dst_indices.clear();
-  for (int i = 0; i < src_indices.size(); ++i) {
-    kv_src_indices.emplace_back(src_indices[i]);
-    kv_dst_indices.emplace_back(dst_indices[i]);
-  }
+  kv_src_indices = src_indices;
+  kv_dst_indices = dst_indices;
 }
 
 void StatefulOVInferRequest::RewindKVCache(size_t index) {
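For context, a hypothetical caller-side sketch (assumed usage, not code from this PR; `accepted_draft_slots` and `first_free_slot` are illustrative names) of how speculative decoding would drive `ReorderKVCache`: after the target model verifies draft tokens, the accepted tokens' KV entries are moved from the slots they were written at into contiguous final positions.

#include <cstdint>
#include <utility>
#include <vector>

// Builds the (src, dst) index pairs to hand to StatefulOVInferRequest::ReorderKVCache().
std::pair<std::vector<int32_t>, std::vector<int32_t>> BuildReorderIndices(
    const std::vector<int32_t>& accepted_draft_slots, int32_t first_free_slot) {
  std::vector<int32_t> src_indices;
  std::vector<int32_t> dst_indices;
  src_indices.reserve(accepted_draft_slots.size());
  dst_indices.reserve(accepted_draft_slots.size());
  for (size_t n = 0; n < accepted_draft_slots.size(); ++n) {
    src_indices.push_back(accepted_draft_slots[n]);                    // where the entry lives now
    dst_indices.push_back(first_free_slot + static_cast<int32_t>(n));  // where it belongs
  }
  return {src_indices, dst_indices};
}

The pairs are only cached by `ReorderKVCache`; `PreProcessInferRequest` materializes them into the `src_idx`/`dst_idx` tensors of the fused subgraph on the next `Infer()`, and `PostProcessInferRequest` clears them afterwards.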