38 changes: 25 additions & 13 deletions candle-core/src/gcu_backend.rs
@@ -2689,8 +2689,7 @@ impl BackendStorage for GcuStorage {
pub struct Rope {
pub cos_sin_length: i32,
pub cos_sin_stride: i32,
pub index_positions: Vec<i32>,
pub batch: i32,
pub index_positions: crate::Tensor,
pub num_tokens: i32,
pub q_head_size: i32,
pub k_head_size: i32,
@@ -2737,10 +2736,18 @@ impl crate::CustomOp3 for Rope {
};
let shape = query_l.shape();

let positions = dev.htod_copy(self.index_positions.to_vec()).w()?;
let (positions, positions_l) = self.index_positions.storage_and_layout();
let positions = match &*positions {
crate::Storage::Gcu(p) => p,
_ => panic!("positions must be a gcu tensor"),
};

match (&query.slice, &key.slice) {
(GcuStorageSlice::BF16(query_), GcuStorageSlice::BF16(key_)) => {
match (&query.slice, &key.slice, &positions.slice) {
(
GcuStorageSlice::BF16(query_),
GcuStorageSlice::BF16(key_),
GcuStorageSlice::I64(positions_),
) => {
let (func, cos_sin_ptr) = match &cos_sin.slice {
GcuStorageSlice::BF16(cos_sin_) => (
dev.get_or_load_func("rope_bf16", ubridge::EMBEDDING)?,
@@ -2758,8 +2765,7 @@
cos_sin_ptr,
self.cos_sin_length,
self.cos_sin_stride,
positions.device_ptr(),
self.batch,
positions_.device_ptr(),
self.num_tokens,
self.q_head_size,
self.k_head_size,
@@ -2769,7 +2775,11 @@
);
unsafe { func.launch(cfg, params) }.w()?;
}
(GcuStorageSlice::F32(query_), GcuStorageSlice::F32(key_)) => {
(
GcuStorageSlice::F32(query_),
GcuStorageSlice::F32(key_),
GcuStorageSlice::I64(positions_),
) => {
let (func, cos_sin_ptr) = match &cos_sin.slice {
GcuStorageSlice::F32(cos_sin_) => (
dev.get_or_load_func("rope_f32", ubridge::EMBEDDING)?,
@@ -2783,8 +2793,7 @@
cos_sin_ptr,
self.cos_sin_length,
self.cos_sin_stride,
positions.device_ptr(),
self.batch,
positions_.device_ptr(),
self.num_tokens,
self.q_head_size,
self.k_head_size,
@@ -2794,7 +2803,11 @@
);
unsafe { func.launch(cfg, params) }.w()?;
}
(GcuStorageSlice::F16(query_), GcuStorageSlice::F16(key_)) => {
(
GcuStorageSlice::F16(query_),
GcuStorageSlice::F16(key_),
GcuStorageSlice::I64(positions_),
) => {
let (func, cos_sin_ptr) = match &cos_sin.slice {
GcuStorageSlice::F16(cos_sin_) => (
dev.get_or_load_func("rope_f16", ubridge::EMBEDDING)?,
@@ -2812,8 +2825,7 @@
cos_sin_ptr,
self.cos_sin_length,
self.cos_sin_stride,
positions.device_ptr(),
self.batch,
positions_.device_ptr(),
self.num_tokens,
self.q_head_size,
self.k_head_size,
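Reviewer note (not part of the diff): the Rope op no longer copies a host-side Vec<i32> of positions plus a separate batch count; it now borrows the storage of an index_positions tensor, which must be an I64 tensor already resident on the GCU device, otherwise the Storage::Gcu match above panics and none of the (query, key, positions) slice arms match. A minimal sketch of how a caller can build that argument, mirroring the model changes later in this PR (variable names are placeholders):

    // Scalar i64 offset, created directly on the same device as query/key.
    let index_positions = Tensor::new(seqlen_offset as i64, &query.device())?;
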
4 changes: 2 additions & 2 deletions candle-core/src/tensor_cat.rs
@@ -67,10 +67,10 @@ impl Tensor {
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.map(|a| a.as_ref().transpose(0, dim)?.contiguous())
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
cat.transpose(0, dim)?.contiguous()
}
}

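Reviewer note (not part of the diff): transpose in candle only swaps strides, so both the per-argument transposes fed into cat0 and the final transposed result were non-contiguous views; the added .contiguous() calls materialize them into dense row-major buffers, which kernels that assume dense layouts (such as the GCU ops touched in this PR) can then rely on. A small illustration of the behaviour, as an assumed standalone example:

    use candle_core::{DType, Device, Tensor};

    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let v = t.transpose(0, 1)?; // stride-swapped view, no data copied
    assert!(!v.is_contiguous());
    let v = v.contiguous()?; // copies into a dense row-major buffer
    assert!(v.is_contiguous());
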
48 changes: 22 additions & 26 deletions candle-nn/src/ops.rs
@@ -1056,7 +1056,7 @@ pub fn apply_rotary_emb_qkv(
key: &Tensor,
cos_sin: &Tensor,
_: &Tensor,
index_positions: &Vec<i32>,
index_positions: &Tensor,
split_dim: usize,
query_key_transposed: bool,
gpt_neox: bool,
@@ -1067,8 +1067,7 @@
cos_sin: &Tensor,
cos_sin_length: i32,
cos_sin_stride: i32,
index_positions: &Vec<i32>,
batch: i32,
index_positions: &Tensor,
num_tokens: i32,
q_head_size: i32,
k_head_size: i32,
@@ -1081,7 +1080,6 @@
cos_sin_length,
cos_sin_stride,
index_positions: index_positions.clone(),
batch,
num_tokens,
q_head_size,
k_head_size,
@@ -1105,8 +1103,7 @@
cos_sin_length as i32,
cos_sin_stride as i32,
index_positions,
b_sz as i32,
seq_len as i32,
(b_sz * seq_len) as i32,
q_head_size as i32,
k_head_size as i32,
hidden_size as i32,
@@ -1126,8 +1123,7 @@
cos_sin_length as i32,
cos_sin_stride as i32,
index_positions,
b_sz as i32,
seq_len as i32,
(b_sz * seq_len) as i32,
q_head_size as i32,
k_head_size as i32,
hidden_size as i32,
@@ -1144,7 +1140,7 @@
k: &Tensor,
cos: &Tensor,
sin: &Tensor,
index_pos: usize,
index_pos: &usize,
split_dim: usize,
query_key_transposed: bool,
gpt_neox: bool,
@@ -1160,8 +1156,8 @@
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = cos.narrow(0, index_pos, seq_len)?;
let sin = sin.narrow(0, index_pos, seq_len)?;
let cos = cos.narrow(0, *index_pos, seq_len)?;
let sin = sin.narrow(0, *index_pos, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
@@ -1191,7 +1187,7 @@ pub fn partial_rotary_emb_qkv(
&key_rot,
&cos_sin,
&sin,
index_pos,
&index_pos,
0,
query_key_transposed,
true,
@@ -1222,7 +1218,7 @@ pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result<Tensor> {

#[cfg(not(feature = "gcu"))]
pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result<Tensor> {
Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?.contiguous()
Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?
}

#[cfg(feature = "gcu")]
@@ -1807,7 +1803,7 @@ fn update_cache<
let v = v.as_gcu_slice::<T>()?;
let kc = kc.as_gcu_slice::<T>()?;
let vc = vc.as_gcu_slice::<T>()?;
let s = s.as_gcu_slice::<i32>()?;
let s = s.as_gcu_slice::<i64>()?;

// Get cuda views for all tensors
let k = k.slice(k_l.start_offset()..);
@@ -2496,7 +2492,6 @@ pub fn expert_mask(input: &Tensor, v: u32) -> Result<Tensor> {
)?)
}


//input: [batch, M (topk or 1), k]
//weight: [num_experts, n, k]
//indices: [batch, topk]
@@ -2527,21 +2522,23 @@ fn indexed_moe_func<
let (indices_value, indices_l) = indices.storage_and_layout();

assert!(
input.dims().len() == 3
&& weight.dims().len() == 3
&& indices.dims().len() == 2,
input.dims().len() == 3 && weight.dims().len() == 3 && indices.dims().len() == 2,
"Invalid input dims!"
);
let (b1, topk) = indices.dims2()?;
let (batch, m, k) = input.dims3()?;
let (num_experts, n, k1) = weight.dims3()?;
let tile_size = if batch > 12 {
128
} else {
64
};
assert!(k % tile_size ==0, "indexed_moe: k dim must be aligned to {}!", tile_size);
assert!(n % tile_size ==0, "indexed_moe: n dim must be aligned to {}!", tile_size);
let tile_size = if batch > 12 { 128 } else { 64 };
assert!(
k % tile_size == 0,
"indexed_moe: k dim must be aligned to {}!",
tile_size
);
assert!(
n % tile_size == 0,
"indexed_moe: n dim must be aligned to {}!",
tile_size
);

// let (b2, _, n1) = out_l.dims3()?;

@@ -2584,7 +2581,6 @@ fn indexed_moe_func<
let indices_value = indices_value.as_gcu_slice::<u32>()?;
let indices_value = indices_value.slice(indices_l.start_offset()..);


match input.dtype() {
DType::F16 => unsafe {
indexed_moe_f16(
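Reviewer note (not part of the diff): the gcu variant of apply_rotary_emb_qkv now takes index_positions as a &Tensor rather than a &Vec<i32>, drops the separate batch argument, and hands the kernel a single flattened token count (b_sz * seq_len); update_cache correspondingly reads the slot mapping as i64, and the indexed_moe asserts were only reformatted. A sketch of a call site under the new signature, assuming the gcu feature and placeholder names for the surrounding tensors:

    #[cfg(feature = "gcu")]
    let (q, k) = {
        // Positions are now an i64 tensor on the device, not a host Vec<i32>.
        let index_positions = Tensor::new(seqlen_offset as i64, &q.device())?;
        candle_nn::ops::apply_rotary_emb_qkv(
            &q,
            &k,
            &cos_sin,         // concatenated cos/sin table
            &cache,           // extra tensor argument, unused by the gcu path
            &index_positions, // was &Vec<i32>
            0,                // split_dim
            true,             // query_key_transposed
            true,             // gpt_neox
        )?
    };
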
7 changes: 3 additions & 4 deletions candle-transformers/src/models/chatglm.rs
@@ -76,7 +76,7 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((cfg.seq_length, 1))?;
let freqs = t.matmul(&inv_freq)?;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor;
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
Ok(Self { cache, cos_sin })
}
@@ -294,15 +294,14 @@ impl SelfAttention {
// let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;

let rot_dim = rotary_emb.cache.dim(D::Minus2)? * 2;
let mut input_positions = Vec::<i32>::new();
input_positions.push(seqlen_offset as i32);
#[cfg(feature = "gcu")]
let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?;
let (query_layer, key_layer) = candle_nn::ops::apply_rotary_emb_qkv(
&query_layer,
&key_layer,
&rotary_emb.cos_sin,
&rotary_emb.cache,
&input_positions,
&seqlen_offset,
rot_dim,
false,
false,
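Reviewer note (not part of the diff): this conditional shadowing is the pattern repeated in gemma, glm4, llama, and mistral below. Without the gcu feature the offset stays a plain usize, so &seqlen_offset can still satisfy the &usize parameter of the non-gcu apply_rotary_emb_qkv where that path exists; with the feature enabled the same binding is shadowed by a scalar i64 tensor, so the unchanged call site matches the &Tensor parameter of the gcu variant instead:

    #[cfg(feature = "gcu")]
    let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?;
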
11 changes: 4 additions & 7 deletions candle-transformers/src/models/gemma.rs
@@ -71,7 +71,7 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
@@ -206,11 +206,8 @@ impl Attention {
(q, k, v.contiguous()?)
};

// let (query_states, key_states) =
// self.rotary_emb
// .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let mut input_positions = Vec::<i32>::new();
input_positions.push(seqlen_offset as i32);
#[cfg(feature = "gcu")]
let seqlen_offset = Tensor::new(seqlen_offset as i64, &query_states.device())?;
let (query_states, key_states) = apply_rotary_emb_qkv(
&query_states,
&key_states,
@@ -220,7 +217,7 @@
&self.rotary_emb.cos
},
&self.rotary_emb.sin,
&input_positions,
&seqlen_offset,
0,
true,
true,
7 changes: 3 additions & 4 deletions candle-transformers/src/models/glm4.rs
@@ -80,7 +80,7 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((cfg.seq_length, 1))?;
let freqs = t.matmul(&inv_freq)?;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor;
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
Ok(Self { cache, cos_sin })
}
@@ -301,15 +301,14 @@ impl SelfAttention {
// let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
// let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;

let mut input_positions = Vec::<i32>::new();
input_positions.push(seqlen_offset as i32);
#[cfg(feature = "gcu")]
let seqlen_offset = Tensor::new(seqlen_offset as i64, &xs.device())?;
let (query_layer, key_layer) = candle_nn::ops::apply_rotary_emb_qkv(
&query_layer,
&key_layer,
&rotary_emb.cos_sin,
&rotary_emb.cache,
&input_positions,
&seqlen_offset,
rotary_emb.cache.dim(D::Minus2)? * 2,
false,
false,
12 changes: 5 additions & 7 deletions candle-transformers/src/models/llama.rs
@@ -203,8 +203,7 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let cos_sin =
Tensor::cat(&[&idx_theta.cos()?, &idx_theta.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor;
let cos_sin = Tensor::cat(&[&idx_theta.cos()?, &idx_theta.sin()?], D::Minus1)?; //must be contiguous tensor;
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
@@ -293,10 +292,9 @@ impl CausalSelfAttention {
.transpose(1, 2)?;
(q, k, v.contiguous()?)
};
let mut input_positions = Vec::<i32>::new();
for _ in 0..b_sz {
input_positions.push(index_pos as i32);
}

#[cfg(feature = "gcu")]
let index_pos = Tensor::new(index_pos as i64, &q.device())?;
let (q, mut k) = candle_nn::apply_rotary_emb_qkv(
&q,
&k,
@@ -306,7 +304,7 @@
&cache.cos
},
&cache.sin,
&input_positions,
&index_pos,
0,
true,
true,
11 changes: 5 additions & 6 deletions candle-transformers/src/models/mistral.rs
@@ -130,7 +130,7 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor;
let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; //must be contiguous tensor;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;

Ok(Self {
@@ -275,10 +275,9 @@ impl Attention {
.transpose(1, 2)?;
(q, k, v.contiguous()?)
};
let mut input_positions = Vec::<i32>::new();
for _ in 0..b_sz {
input_positions.push(seqlen_offset as i32);
}

#[cfg(feature = "gcu")]
let seqlen_offset = Tensor::new(seqlen_offset as i64, &query_states.device())?;
let (query_states, key_states) = candle_nn::apply_rotary_emb_qkv(
&query_states,
&key_states,
@@ -288,7 +287,7 @@
&self.rotary_emb.cos
},
&self.rotary_emb.sin,
&input_positions,
&seqlen_offset,
0,
true,
true,