From a9de0bdc9574c97664855a92441b92d0e199510e Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 14 Oct 2025 03:51:17 +0000 Subject: [PATCH] Rope Change compatible with varlen attention --- candle-core/src/gcu_backend.rs | 38 ++++++++++------ candle-core/src/tensor_cat.rs | 4 +- candle-nn/src/ops.rs | 48 ++++++++++----------- candle-transformers/src/models/chatglm.rs | 7 ++- candle-transformers/src/models/gemma.rs | 11 ++--- candle-transformers/src/models/glm4.rs | 7 ++- candle-transformers/src/models/llama.rs | 12 +++--- candle-transformers/src/models/mistral.rs | 11 +++-- candle-transformers/src/models/mixformer.rs | 15 +++---- candle-transformers/src/models/phi.rs | 8 ++-- candle-transformers/src/models/phi3.rs | 37 ++++++---------- candle-transformers/src/models/qwen2.rs | 7 +-- candle-transformers/src/models/stable_lm.rs | 9 ++-- candle-transformers/src/models/yi.rs | 11 +++-- 14 files changed, 106 insertions(+), 119 deletions(-) diff --git a/candle-core/src/gcu_backend.rs b/candle-core/src/gcu_backend.rs index d34554e664..4286c8b1c0 100644 --- a/candle-core/src/gcu_backend.rs +++ b/candle-core/src/gcu_backend.rs @@ -2689,8 +2689,7 @@ impl BackendStorage for GcuStorage { pub struct Rope { pub cos_sin_length: i32, pub cos_sin_stride: i32, - pub index_positions: Vec, - pub batch: i32, + pub index_positions: crate::Tensor, pub num_tokens: i32, pub q_head_size: i32, pub k_head_size: i32, @@ -2737,10 +2736,18 @@ impl crate::CustomOp3 for Rope { }; let shape = query_l.shape(); - let positions = dev.htod_copy(self.index_positions.to_vec()).w()?; + let (positions, positions_l) = self.index_positions.storage_and_layout(); + let positions = match &*positions { + crate::Storage::Gcu(p) => p, + _ => panic!("positions must be a gcu tensor"), + }; - match (&query.slice, &key.slice) { - (GcuStorageSlice::BF16(query_), GcuStorageSlice::BF16(key_)) => { + match (&query.slice, &key.slice, &positions.slice) { + ( + GcuStorageSlice::BF16(query_), + GcuStorageSlice::BF16(key_), + GcuStorageSlice::I64(positions_), + ) => { let (func, cos_sin_ptr) = match &cos_sin.slice { GcuStorageSlice::BF16(cos_sin_) => ( dev.get_or_load_func("rope_bf16", ubridge::EMBEDDING)?, @@ -2758,8 +2765,7 @@ impl crate::CustomOp3 for Rope { cos_sin_ptr, self.cos_sin_length, self.cos_sin_stride, - positions.device_ptr(), - self.batch, + positions_.device_ptr(), self.num_tokens, self.q_head_size, self.k_head_size, @@ -2769,7 +2775,11 @@ impl crate::CustomOp3 for Rope { ); unsafe { func.launch(cfg, params) }.w()?; } - (GcuStorageSlice::F32(query_), GcuStorageSlice::F32(key_)) => { + ( + GcuStorageSlice::F32(query_), + GcuStorageSlice::F32(key_), + GcuStorageSlice::I64(positions_), + ) => { let (func, cos_sin_ptr) = match &cos_sin.slice { GcuStorageSlice::F32(cos_sin_) => ( dev.get_or_load_func("rope_f32", ubridge::EMBEDDING)?, @@ -2783,8 +2793,7 @@ impl crate::CustomOp3 for Rope { cos_sin_ptr, self.cos_sin_length, self.cos_sin_stride, - positions.device_ptr(), - self.batch, + positions_.device_ptr(), self.num_tokens, self.q_head_size, self.k_head_size, @@ -2794,7 +2803,11 @@ impl crate::CustomOp3 for Rope { ); unsafe { func.launch(cfg, params) }.w()?; } - (GcuStorageSlice::F16(query_), GcuStorageSlice::F16(key_)) => { + ( + GcuStorageSlice::F16(query_), + GcuStorageSlice::F16(key_), + GcuStorageSlice::I64(positions_), + ) => { let (func, cos_sin_ptr) = match &cos_sin.slice { GcuStorageSlice::F16(cos_sin_) => ( dev.get_or_load_func("rope_f16", ubridge::EMBEDDING)?, @@ -2812,8 +2825,7 @@ impl crate::CustomOp3 for Rope { cos_sin_ptr, 
self.cos_sin_length, self.cos_sin_stride, - positions.device_ptr(), - self.batch, + positions_.device_ptr(), self.num_tokens, self.q_head_size, self.k_head_size, diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs index ab9212059d..c24532d4f8 100644 --- a/candle-core/src/tensor_cat.rs +++ b/candle-core/src/tensor_cat.rs @@ -67,10 +67,10 @@ impl Tensor { } else { let args: Vec = args .iter() - .map(|a| a.as_ref().transpose(0, dim)) + .map(|a| a.as_ref().transpose(0, dim)?.contiguous()) .collect::>>()?; let cat = Self::cat0(&args)?; - cat.transpose(0, dim) + cat.transpose(0, dim)?.contiguous() } } diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index f82b63dd91..09c623963a 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1056,7 +1056,7 @@ pub fn apply_rotary_emb_qkv( key: &Tensor, cos_sin: &Tensor, _: &Tensor, - index_positions: &Vec, + index_positions: &Tensor, split_dim: usize, query_key_transposed: bool, gpt_neox: bool, @@ -1067,8 +1067,7 @@ pub fn apply_rotary_emb_qkv( cos_sin: &Tensor, cos_sin_length: i32, cos_sin_stride: i32, - index_positions: &Vec, - batch: i32, + index_positions: &Tensor, num_tokens: i32, q_head_size: i32, k_head_size: i32, @@ -1081,7 +1080,6 @@ pub fn apply_rotary_emb_qkv( cos_sin_length, cos_sin_stride, index_positions: index_positions.clone(), - batch, num_tokens, q_head_size, k_head_size, @@ -1105,8 +1103,7 @@ pub fn apply_rotary_emb_qkv( cos_sin_length as i32, cos_sin_stride as i32, index_positions, - b_sz as i32, - seq_len as i32, + (b_sz * seq_len) as i32, q_head_size as i32, k_head_size as i32, hidden_size as i32, @@ -1126,8 +1123,7 @@ pub fn apply_rotary_emb_qkv( cos_sin_length as i32, cos_sin_stride as i32, index_positions, - b_sz as i32, - seq_len as i32, + (b_sz * seq_len) as i32, q_head_size as i32, k_head_size as i32, hidden_size as i32, @@ -1144,7 +1140,7 @@ pub fn apply_rotary_emb_qkv( k: &Tensor, cos: &Tensor, sin: &Tensor, - index_pos: usize, + index_pos: &usize, split_dim: usize, query_key_transposed: bool, gpt_neox: bool, @@ -1160,8 +1156,8 @@ pub fn apply_rotary_emb_qkv( Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) } let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; - let cos = cos.narrow(0, index_pos, seq_len)?; - let sin = sin.narrow(0, index_pos, seq_len)?; + let cos = cos.narrow(0, *index_pos, seq_len)?; + let sin = sin.narrow(0, *index_pos, seq_len)?; let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; @@ -1191,7 +1187,7 @@ pub fn partial_rotary_emb_qkv( &key_rot, &cos_sin, &sin, - index_pos, + &index_pos, 0, query_key_transposed, true, @@ -1222,7 +1218,7 @@ pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result Result { - Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?.contiguous() + Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)? } #[cfg(feature = "gcu")] @@ -1807,7 +1803,7 @@ fn update_cache< let v = v.as_gcu_slice::()?; let kc = kc.as_gcu_slice::()?; let vc = vc.as_gcu_slice::()?; - let s = s.as_gcu_slice::()?; + let s = s.as_gcu_slice::()?; // Get cuda views for all tensors let k = k.slice(k_l.start_offset()..); @@ -2496,7 +2492,6 @@ pub fn expert_mask(input: &Tensor, v: u32) -> Result { )?) 
} - //input: [batch, M (topk or 1), k] //weight: [num_experts, n, k] //indices: [batch, topk] @@ -2527,21 +2522,23 @@ fn indexed_moe_func< let (indices_value, indices_l) = indices.storage_and_layout(); assert!( - input.dims().len() == 3 - && weight.dims().len() == 3 - && indices.dims().len() == 2, + input.dims().len() == 3 && weight.dims().len() == 3 && indices.dims().len() == 2, "Invalid input dims!" ); let (b1, topk) = indices.dims2()?; let (batch, m, k) = input.dims3()?; let (num_experts, n, k1) = weight.dims3()?; - let tile_size = if batch > 12 { - 128 - } else { - 64 - }; - assert!(k % tile_size ==0, "indexed_moe: k dim must be aligned to {}!", tile_size); - assert!(n % tile_size ==0, "indexed_moe: n dim must be aligned to {}!", tile_size); + let tile_size = if batch > 12 { 128 } else { 64 }; + assert!( + k % tile_size == 0, + "indexed_moe: k dim must be aligned to {}!", + tile_size + ); + assert!( + n % tile_size == 0, + "indexed_moe: n dim must be aligned to {}!", + tile_size + ); // let (b2, _, n1) = out_l.dims3()?; @@ -2584,7 +2581,6 @@ fn indexed_moe_func< let indices_value = indices_value.as_gcu_slice::()?; let indices_value = indices_value.slice(indices_l.start_offset()..); - match input.dtype() { DType::F16 => unsafe { indexed_moe_f16( diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs index b67ddc8f18..504277e56b 100644 --- a/candle-transformers/src/models/chatglm.rs +++ b/candle-transformers/src/models/chatglm.rs @@ -76,7 +76,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((cfg.seq_length, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; Ok(Self { cache, cos_sin }) } @@ -294,15 +294,14 @@ impl SelfAttention { // let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?; let rot_dim = rotary_emb.cache.dim(D::Minus2)? * 2; - let mut input_positions = Vec::::new(); - input_positions.push(seqlen_offset as i32); #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?; let (query_layer, key_layer) = candle_nn::ops::apply_rotary_emb_qkv( &query_layer, &key_layer, &rotary_emb.cos_sin, &rotary_emb.cache, - &input_positions, + &seqlen_offset, rot_dim, false, false, diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 123978919a..c71e05ec62 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -71,7 +71,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, @@ -206,11 +206,8 @@ impl Attention { (q, k, v.contiguous()?) 
}; - // let (query_states, key_states) = - // self.rotary_emb - // .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; - let mut input_positions = Vec::::new(); - input_positions.push(seqlen_offset as i32); + #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &query_states.device())?; let (query_states, key_states) = apply_rotary_emb_qkv( &query_states, &key_states, @@ -220,7 +217,7 @@ impl Attention { &self.rotary_emb.cos }, &self.rotary_emb.sin, - &input_positions, + &seqlen_offset, 0, true, true, diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs index b51f543450..d940375533 100644 --- a/candle-transformers/src/models/glm4.rs +++ b/candle-transformers/src/models/glm4.rs @@ -80,7 +80,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((cfg.seq_length, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; Ok(Self { cache, cos_sin }) } @@ -301,15 +301,14 @@ impl SelfAttention { // let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?; // let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?; - let mut input_positions = Vec::::new(); - input_positions.push(seqlen_offset as i32); #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?; let (query_layer, key_layer) = candle_nn::ops::apply_rotary_emb_qkv( &query_layer, &key_layer, &rotary_emb.cos_sin, &rotary_emb.cache, - &input_positions, + &seqlen_offset, rotary_emb.cache.dim(D::Minus2)? * 2, false, false, diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index cd4c3645ac..21cc6cd07c 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -203,8 +203,7 @@ impl Cache { .matmul(&theta.reshape((1, theta.elem_count()))?)?; // This is different from the paper, see: // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 - let cos_sin = - Tensor::cat(&[&idx_theta.cos()?, &idx_theta.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&idx_theta.cos()?, &idx_theta.sin()?], D::Minus1)?; //must be contiguous tensor; let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; let cos = idx_theta.cos()?.to_dtype(dtype)?; let sin = idx_theta.sin()?.to_dtype(dtype)?; @@ -293,10 +292,9 @@ impl CausalSelfAttention { .transpose(1, 2)?; (q, k, v.contiguous()?) }; - let mut input_positions = Vec::::new(); - for _ in 0..b_sz { - input_positions.push(index_pos as i32); - } + + #[cfg(feature = "gcu")] + let index_pos = Tensor::new(index_pos as i64, &q.device())?; let (q, mut k) = candle_nn::apply_rotary_emb_qkv( &q, &k, @@ -306,7 +304,7 @@ impl CausalSelfAttention { &cache.cos }, &cache.sin, - &input_positions, + &index_pos, 0, true, true, diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 676dda5664..9138abf4e6 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -130,7 +130,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? 
.reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { @@ -275,10 +275,9 @@ impl Attention { .transpose(1, 2)?; (q, k, v.contiguous()?) }; - let mut input_positions = Vec::::new(); - for _ in 0..b_sz { - input_positions.push(seqlen_offset as i32); - } + + #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &query_states.device())?; let (query_states, key_states) = candle_nn::apply_rotary_emb_qkv( &query_states, &key_states, @@ -288,7 +287,7 @@ impl Attention { &self.rotary_emb.cos }, &self.rotary_emb.sin, - &input_positions, + &seqlen_offset, 0, true, true, diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index ece359e94b..1432582892 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -279,29 +279,28 @@ impl MHA { seqlen_offset: usize, ) -> Result<(Tensor, Tensor, Tensor)> { let (b_sz, _, _, _, _) = qkv.dims5()?; - if qkv.device().is_gcu() { + #[cfg(feature = "gcu")] + { let q = qkv.i((.., .., 0))?; let k = qkv.i((.., .., 1))?; let v = qkv.i((.., .., 2))?; let (_, rotary_dim) = self.rotary_emb.cos.dims2()?; let rotary_dim = rotary_dim * 2; - let mut input_positions = Vec::::new(); - for _ in 0..b_sz { - input_positions.push(seqlen_offset as i32); - } - #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &q.device())?; + let (q, k) = candle_nn::apply_rotary_emb_qkv( &q, &k, &self.rotary_emb.cos_sin, &self.rotary_emb.sin, - &input_positions, + &seqlen_offset, rotary_dim, false, true, )?; Ok((q, k, v)) - } else { + } + #[cfg(not(feature = "gcu"))] { self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) } } diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index d09c1e032e..c399e8e4c4 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -71,7 +71,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((cfg.max_position_embeddings, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { dim, @@ -243,8 +243,8 @@ impl Attention { true, )?; - let mut input_positions = Vec::::new(); - input_positions.push(seqlen_offset as i32); + #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &query_states.device())?; #[cfg(feature = "gcu")] let (query_states, key_states) = candle_nn::apply_rotary_emb_qkv( @@ -252,7 +252,7 @@ impl Attention { &key_states, &self.rotary_emb.cos_sin, &self.rotary_emb.sin, - &input_positions, + &seqlen_offset, self.rotary_emb.dim, true, true, diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs index 2c07339050..1002704b84 100644 --- a/candle-transformers/src/models/phi3.rs +++ b/candle-transformers/src/models/phi3.rs @@ -69,7 +69,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? 
.reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -84,28 +84,19 @@ impl RotaryEmbedding { seqlen_offset: usize, ) -> Result<(Tensor, Tensor)> { let (b_sz, _h, seq_len, _n_embd) = q.dims4()?; - if q.device().is_gcu() { - let mut input_positions = Vec::::new(); - for _ in 0..b_sz { - input_positions.push(seqlen_offset as i32); - } - candle_nn::apply_rotary_emb_qkv( - q, - k, - &self.cos_sin, - &self.sin, - &input_positions, - 0, - true, - true, - ) - } else { - let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; - let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; - let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; - Ok((q_embed, k_embed)) - } + #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &q.device())?; + + candle_nn::apply_rotary_emb_qkv( + q, + k, + &self.cos_sin, + &self.sin, + &seqlen_offset, + 0, + true, + true, + ) } } diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 0e43603567..cd37280d0b 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -186,8 +186,9 @@ impl Attention { .transpose(1, 2)?; (q, k, v.contiguous()?) }; - let mut input_positions = Vec::::new(); - input_positions.push(seqlen_offset as i32); + #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?; + let (query_states, key_states) = candle_nn::apply_rotary_emb_qkv( &query_states, &key_states, @@ -197,7 +198,7 @@ impl Attention { &self.rotary_emb.cos }, &self.rotary_emb.sin, - &input_positions, + &seqlen_offset, 0, true, true, diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs index 6c4eac5c8e..0de4297fc5 100644 --- a/candle-transformers/src/models/stable_lm.rs +++ b/candle-transformers/src/models/stable_lm.rs @@ -278,18 +278,15 @@ impl Attention { true, )?; - let mut input_positions = Vec::::new(); - for _ in 0..b_sz { - input_positions.push(seqlen_offset as i32); - } - #[cfg(feature = "gcu")] + let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?; + let (query_states, key_states) = candle_nn::apply_rotary_emb_qkv( &query_states, &key_states, &self.rotary_emb.cos_sin, &self.rotary_emb.sin, - &input_positions, + &seqlen_offset, self.rotary_ndims, true, true, diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index 28efb05d7f..1d631514c7 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -94,7 +94,7 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor; let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, @@ -224,10 +224,9 @@ impl Attention { .transpose(1, 2)?; (q, k, v.contiguous()?) 
         };
-        let mut input_positions = Vec::<i32>::new();
-        for _ in 0..b_sz {
-            input_positions.push(seqlen_offset as i32);
-        }
+        #[cfg(feature = "gcu")]
+        let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?;
+
         let (query_states, key_states) = candle_nn::apply_rotary_emb_qkv(
             &query_states,
             &key_states,
@@ -237,7 +236,7 @@ impl Attention {
                 &self.rotary_emb.cos
             },
             &self.rotary_emb.sin,
-            &input_positions,
+            &seqlen_offset,
             0,
             true,
             true,
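
The hunk above is the same call-site rewrite this patch applies in every touched model (chatglm, gemma, glm4, llama, mistral, mixformer, phi, phi3, qwen2, stable_lm, yi): the old per-batch Vec<i32> of positions is dropped and, under the gcu feature, the usize offset is shadowed by a 0-dim i64 Tensor before being handed to apply_rotary_emb_qkv. Below is a minimal sketch of that post-patch pattern; the names other than apply_rotary_emb_qkv are illustrative, and the argument order follows the candle-nn signatures in this diff.

use candle_core::{Result, Tensor};

// Illustrative container for the precomputed rotary tables; the models in
// this patch keep equivalent fields on their RotaryEmbedding structs.
struct RotaryTables {
    cos: Tensor,     // (max_seq_len, dim/2)
    sin: Tensor,     // (max_seq_len, dim/2)
    cos_sin: Tensor, // cat([cos, sin], -1); consumed only by the gcu kernel
}

fn apply_rope(
    tables: &RotaryTables,
    q: &Tensor, // (b, num_heads, seq_len, head_dim), already transposed
    k: &Tensor, // (b, num_kv_heads, seq_len, head_dim)
    seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
    // Under the gcu feature the position is carried as a 0-dim i64 tensor;
    // without the feature the shadowing does not happen and &seqlen_offset
    // stays the &usize expected by the fallback implementation.
    #[cfg(feature = "gcu")]
    let seqlen_offset = Tensor::new(seqlen_offset as i64, q.device())?;

    candle_nn::apply_rotary_emb_qkv(
        q,
        k,
        if cfg!(feature = "gcu") { &tables.cos_sin } else { &tables.cos },
        &tables.sin,
        &seqlen_offset,
        0,    // split_dim
        true, // q/k are in the (b, h, s, d) layout
        true, // gpt_neox-style rotation
    )
}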
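
On the "compatible with varlen attention" part of the subject: the GCU Rope op now reads positions from an I64 device tensor and takes a single num_tokens count (b_sz * seq_len) instead of separate batch and seq_len arguments, so a packed batch can in principle supply one position per token. The helper below is hypothetical (packed_positions, past_lens, and seq_lens are not part of this patch) and assumes the kernel indexes one i64 entry per packed token; the model files above still pass a single scalar offset per decode step.

use candle_core::{Device, Result, Tensor};

// Hypothetical helper for a packed (varlen) batch: sequence i has already
// produced past_lens[i] tokens and contributes seq_lens[i] new tokens to the
// packed q/k tensors. Emits one i64 position per packed token.
fn packed_positions(past_lens: &[usize], seq_lens: &[usize], dev: &Device) -> Result<Tensor> {
    let mut positions = Vec::new();
    for (&past, &len) in past_lens.iter().zip(seq_lens.iter()) {
        positions.extend((past..past + len).map(|p| p as i64));
    }
    let num_tokens = positions.len();
    Tensor::from_vec(positions, num_tokens, dev)
}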