From ff3b6022677e2ed4351d63a5bc3b6672def37a0b Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 18 Dec 2025 12:52:10 +0800 Subject: [PATCH 01/14] backend buffer: allocate on host --- ggml/src/ggml-openvino/ggml-decoder.cpp | 125 +++--- ggml/src/ggml-openvino/ggml-openvino-extra.h | 247 +++++++++++ ggml/src/ggml-openvino/ggml-openvino.cpp | 425 ++++++++++++++++++- ggml/src/ggml-openvino/ggml-quants.cpp | 161 +++++-- ggml/src/ggml-openvino/ggml-quants.hpp | 29 +- ggml/src/ggml-openvino/utils.cpp | 20 +- 6 files changed, 904 insertions(+), 103 deletions(-) create mode 100644 ggml/src/ggml-openvino/ggml-openvino-extra.h diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 275a8a216ae..409a16e8162 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -2,6 +2,7 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" +#include "ggml-openvino-extra.h" #include "ggml-quants.hpp" #include @@ -17,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +35,7 @@ #include #include #include +#include #include GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, @@ -512,8 +515,49 @@ std::map> GgmlOvDecoder::create_weight_no return model_weights; } +// Static cache for quantized weight nodes (keyed by tensor data pointer) +// This is a fallback for when tensors don't have pre-built constants in extra +static std::unordered_map> s_quantized_weight_cache; +static std::mutex s_quantized_weight_cache_mutex; + std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor, std::optional requant_type) { + // Check if we have a pre-built constant from the OpenVINO backend buffer + // This is set during ggml_backend_openvino_buffer_set_tensor + if (tensor->extra != nullptr && !requant_type.has_value()) { + // Cast to our extra base type and check the type + auto * extra_base = static_cast(tensor->extra); + + if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT) { + // F16/F32/BF16 weight with shared-memory constant + auto * weight_extra = static_cast(tensor->extra); + if (weight_extra->constant) { + GGML_LOG_DEBUG("%s: using pre-built constant for %s\n", __func__, tensor->name); + return weight_extra->constant; + } + } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) { + // Quantized weight with pre-extracted data + auto * quant_extra = static_cast(tensor->extra); + if (quant_extra->constant) { + GGML_LOG_DEBUG("%s: using pre-extracted quantized constant for %s\n", __func__, tensor->name); + return quant_extra->constant; + } + } + } + + // Fallback: Check static cache for quantized weights (keyed by data pointer) + // This handles cases where tensors weren't loaded through OpenVINO buffer + if (ggml_is_quantized(tensor->type) && !requant_type.has_value()) { + std::lock_guard lock(s_quantized_weight_cache_mutex); + auto it = s_quantized_weight_cache.find(tensor->data); + if (it != s_quantized_weight_cache.end()) { + GGML_LOG_DEBUG("%s: using cached quantized constant for %s\n", __func__, tensor->name); + return it->second; + } + } + + GGML_LOG_DEBUG("%s: creating new constant for %s (extra=%p)\n", __func__, tensor->name, tensor->extra); + std::set weight_types = {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K}; if (weight_types.find(tensor->type) == weight_types.end()) { @@ -543,63 +587,48 @@ std::shared_ptr 
GgmlOvDecoder::create_weight_node(ggml_tensor * tensor return weight_node; } - // Quantized case - OPENVINO_ASSERT(tensor->extra == nullptr, "Unsupported weight tensor: " + std::string(tensor->name) + - " Possibly this is a repacked quantized weights"); + // Quantized case - extra should be nullptr (not our type) + // Our ggml_openvino_weight_extra is only set for F16/F32 weights + if (tensor->extra != nullptr) { + // Check if it's our type - if so, something is wrong + auto * extra_base = static_cast(tensor->extra); + if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT || + extra_base->type == ggml_openvino_extra_base::Type::TENSOR) { + OPENVINO_ASSERT(false, "Quantized weight tensor has unexpected extra type: " + std::string(tensor->name)); + } + // Otherwise it might be repacked quantized weights from another backend + OPENVINO_ASSERT(false, "Unsupported weight tensor: " + std::string(tensor->name) + + " Possibly this is a repacked quantized weights"); + } if (requant_type.has_value()) { return requantize(tensor, requant_type.value()); } - ov::element::Type weight_type; - if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) { - weight_type = ov::element::u4; - } else { // tensor.type == GGUF_TYPE_Q8_0 || tensor.type == GGUF_TYPE_Q6_K || tensor.type == GGUF_TYPE_Q5_K - weight_type = ov::element::u8; - } - - uint64_t weights_per_block; - // here we only consider sub block, q6k:16 q4k:32 q5k:32 - if (tensor->type == GGML_TYPE_Q6_K) { - weights_per_block = 16; - } else { - weights_per_block = 32; + // Extract quantized weights using the shared function + auto layout = ggml_openvino_get_extracted_layout(tensor); + if (layout.total_size == 0) { + OPENVINO_THROW("Unsupported quantized type for ", tensor->name, " type=", ggml_type_name(tensor->type)); } - OPENVINO_ASSERT(node_shape.back() % weights_per_block == 0, "[load_gguf] tensor ", tensor->name, - " has incompatible last dim shape: ", node_shape.back()); + ov::element::Type weight_type = layout.is_u4 ? 
ov::element::u4 : ov::element::u8; + ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; ov::Tensor weights(weight_type, node_shape); - // For scales and biases - node_shape[node_shape.size() - 1] = node_shape[node_shape.size() - 1] / weights_per_block; - ov::Tensor scales(ov::element::f16, node_shape); - ov::Tensor biases(ov::element::f16, node_shape); - - ov::Output weight_node; - if (tensor->type == GGML_TYPE_Q4_0) { - extract_q4_0_data(tensor, weights, scales, biases); - weight_node = make_int4_weights(weights, scales, biases, weights_per_block); - } else if (tensor->type == GGML_TYPE_Q4_1) { - extract_q4_1_data(tensor, weights, scales, biases); - weight_node = make_int4_weights(weights, scales, biases, weights_per_block); - } else if (tensor->type == GGML_TYPE_Q8_0) { - extract_q8_0_data(tensor, weights, scales, biases); - weight_node = make_int8_weights(weights, scales, biases, weights_per_block); - } else if (tensor->type == GGML_TYPE_Q6_K) { - extract_q6_k_data(tensor, weights, scales, biases); - weight_node = make_int8_weights(weights, scales, biases, weights_per_block); - } else if (tensor->type == GGML_TYPE_Q4_K) { - extract_q4_k_data(tensor, weights, scales, biases); - weight_node = make_int4_weights(weights, scales, biases, weights_per_block); - } else if (tensor->type == GGML_TYPE_Q5_K) { - extract_q5_k_data(tensor, weights, scales, biases); - weight_node = make_int8_weights(weights, scales, biases, weights_per_block); - } - - OPENVINO_ASSERT(weight_node.get_shape().size() == 2, "Weight should be 2D"); - - weight_node.get_node_shared_ptr()->set_friendly_name(tensor->name); - return weight_node.get_node_shared_ptr(); + ov::Tensor scales(ov::element::f16, scale_shape); + ov::Tensor biases(ov::element::f16, scale_shape); + + auto result = extract_quantized_weights(tensor, tensor->data, weights, scales, biases); + result->set_friendly_name(tensor->name); + + // Cache the quantized weight node for future reuse + if (ggml_is_quantized(tensor->type) && !requant_type.has_value()) { + std::lock_guard lock(s_quantized_weight_cache_mutex); + s_quantized_weight_cache[tensor->data] = result; + GGML_LOG_DEBUG("%s: cached quantized constant for %s\n", __func__, tensor->name); + } + + return result; } void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filename) { diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h new file mode 100644 index 00000000000..99db8704123 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ggml.h" + +// ExtraQuantType enum - defines requantization target formats +enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; + +// ===================================================== +// Global Device Configuration (singleton) +// ===================================================== +// Initialized once during backend init from GGML_OPENVINO_DEVICE env var + +struct ggml_openvino_device_config { + std::string device_name = "CPU"; + bool is_npu = false; + bool initialized = false; + + void init() { + if (initialized) return; + const char* env = std::getenv("GGML_OPENVINO_DEVICE"); + if (env) { + device_name = env; + is_npu = (device_name == "NPU"); + } + initialized = true; + } +}; + +// Get the global device config singleton +inline ggml_openvino_device_config& ggml_openvino_get_device_config() { + static 
ggml_openvino_device_config config; + return config; +} + +// Initialize device config (call during backend init) +inline void ggml_openvino_init_device_config() { + ggml_openvino_get_device_config().init(); +} + +// Get the device name +inline const std::string& ggml_openvino_get_device_name() { + return ggml_openvino_get_device_config().device_name; +} + +// Check if running on NPU +inline bool ggml_openvino_is_npu() { + return ggml_openvino_get_device_config().is_npu; +} + +// Get requantization type for a tensor type (returns nullopt if no requant needed) +inline std::optional ggml_openvino_get_requant_type(ggml_type type) { + if (!ggml_openvino_is_npu()) { + return std::nullopt; + } + // NPU requantization rules + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + return ExtraQuantType::Q4_0_128; + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q5_K: + return ExtraQuantType::F16; + default: + return std::nullopt; + } +} + +// ===================================================== +// OpenVINO Tensor Extra Types +// ===================================================== +// These types are stored in tensor->extra by the OpenVINO backend buffer. +// They allow: +// 1. Pre-built ov::Constant nodes for weights (avoiding memcpy during graph construction) +// 2. ov::Tensor wrappers for KV cache / compute tensors (for direct use with infer_request) + +// Base class for OpenVINO tensor extra data +struct ggml_openvino_extra_base { + enum class Type { WEIGHT, QUANTIZED_WEIGHT, TENSOR }; + Type type; + virtual ~ggml_openvino_extra_base() = default; +protected: + explicit ggml_openvino_extra_base(Type t) : type(t) {} +}; + +// Extra data for F16/F32/BF16 weight tensors - stores the pre-built ov::Constant node +struct ggml_openvino_weight_extra : public ggml_openvino_extra_base { + std::shared_ptr constant; // Pre-built OpenVINO Constant node + + explicit ggml_openvino_weight_extra(std::shared_ptr c) + : ggml_openvino_extra_base(Type::WEIGHT), constant(std::move(c)) {} +}; + +// Extra data for quantized weight tensors - stores extracted weights/scales/biases and ov::Constant +struct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base { + ov::Tensor weights; // U4 or U8 extracted weights + ov::Tensor scales; // F16 scales + ov::Tensor biases; // F16 biases (zero points) + std::shared_ptr constant; // Pre-built OpenVINO weight subgraph + + ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor b, std::shared_ptr c) + : ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT), + weights(std::move(w)), scales(std::move(s)), biases(std::move(b)), constant(std::move(c)) {} +}; + +// Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request +struct ggml_openvino_tensor_extra : public ggml_openvino_extra_base { + std::shared_ptr tensor; // For direct use with infer_request + + explicit ggml_openvino_tensor_extra(std::shared_ptr t) + : ggml_openvino_extra_base(Type::TENSOR), tensor(std::move(t)) {} +}; + +// ===================================================== +// Extracted Size Calculation for Quantized Tensors +// ===================================================== +// For quantized tensors, we need extra space to store extracted weights, scales, and biases. +// Returns the total size needed in the buffer for extracted data. 
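+//
+// Worked example (illustrative only): assume a 4096 x 4096 Q4_0 weight with no NPU
+// requantization. n_elements = 16,777,216, so the U4 weights need 8,388,608 bytes;
+// with 32 weights per block there are 524,288 blocks, i.e. 1,048,576 bytes each for
+// the F16 scales and biases. With 64-byte alignment the regions start at offsets
+// 0, 8,388,608 and 9,437,184, giving a total_size of 10,485,760 bytes (10 MiB).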
+ +struct ggml_openvino_extracted_layout { + size_t total_size; // Total bytes needed + size_t weights_offset; // Offset to weights in buffer + size_t weights_size; // Size of weights in bytes + size_t scales_offset; // Offset to scales in buffer + size_t scales_size; // Size of scales in bytes + size_t biases_offset; // Offset to biases in buffer + size_t biases_size; // Size of biases in bytes + bool is_u4; // true for U4 weights, false for U8 + int64_t weights_per_block;// weights per scale/bias block + + // Requantization info + bool is_requant; // true if this tensor needs requantization + std::optional requant_type; // target requant type if is_requant +}; + +// Calculate the buffer layout for extracted quantized data +inline ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor) { + ggml_openvino_extracted_layout layout = {}; + + if (!ggml_is_quantized(tensor->type)) { + return layout; + } + + // Only handle 2D weight tensors + if (tensor->ne[2] != 1 || tensor->ne[3] != 1) { + return layout; + } + + int64_t n_elements = ggml_nelements(tensor); + const size_t alignment = 64; // Good for SIMD + + // Check if requantization is needed (NPU-specific) + auto requant_type = ggml_openvino_get_requant_type(tensor->type); + if (requant_type.has_value()) { + layout.is_requant = true; + layout.requant_type = requant_type; + + // Special case: requant to F16 - just store F16 weights, no scales/biases + if (requant_type.value() == ExtraQuantType::F16) { + layout.weights_size = n_elements * sizeof(uint16_t); // F16 = 2 bytes + layout.total_size = layout.weights_size; + layout.weights_offset = 0; + // No scales/biases for F16 + return layout; + } + + // Requant to different quantized format (e.g., Q4_0_128) + switch (requant_type.value()) { + case ExtraQuantType::Q4_0_128: + layout.is_u4 = true; + layout.weights_per_block = 128; + break; + case ExtraQuantType::Q8_0_32: + layout.is_u4 = false; + layout.weights_per_block = 32; + break; + default: + // Unsupported requant type - fall through to normal extraction + layout.is_requant = false; + layout.requant_type = std::nullopt; + break; + } + + if (layout.is_requant) { + // Calculate sizes for requantized format + layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); + layout.biases_size = n_blocks * sizeof(uint16_t); + + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.biases_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.biases_offset + layout.biases_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); + return layout; + } + } + + // Normal extraction (no requant) - determine format based on tensor type + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + layout.is_u4 = true; + layout.weights_per_block = 32; + break; + case GGML_TYPE_Q8_0: + layout.is_u4 = false; + layout.weights_per_block = 32; + break; + case GGML_TYPE_Q6_K: + layout.is_u4 = false; + layout.weights_per_block = 16; + break; + case GGML_TYPE_Q5_K: + layout.is_u4 = false; + layout.weights_per_block = 32; + break; + default: + // Unsupported quantization type + return layout; + } + + // Calculate sizes + // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes + layout.weights_size = layout.is_u4 ? 
(n_elements / 2) : n_elements; + + // Scales and biases: F16 per block + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + layout.biases_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + + // Layout in buffer: [weights | scales | biases] with alignment + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.biases_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.biases_offset + layout.biases_size; + + return layout; +} diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index e809d250f70..747d1b8a307 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -3,18 +3,429 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" #include "ggml-impl.h" +#include "ggml-openvino-extra.h" #include "ggml-openvino/utils.h" +#include "ggml-quants.hpp" #include "ggml.h" #include +#include +#include #include #include +#include #include #include #include #define GGML_OPENVINO_MAX_STREAMS 8 +// OpenVINO buffer alignment (same as CPU for compatibility) +#define GGML_OPENVINO_BUFFER_ALIGNMENT 64 + +// ===================================================== +// OpenVINO Buffer Implementation using ov::Tensor +// ===================================================== +// +// Design: This implementation uses a hybrid approach: +// 1. For weight tensors: Store a pre-built ov::op::v0::Constant in tensor->extra +// - This avoids the memcpy during graph construction +// - For quantized weights, the constant is already converted to OpenVINO format +// 2. For KV cache / compute tensors: Store an ov::Tensor in tensor->extra +// - This can be directly passed to infer_request +// - Future: can be changed to ov::RemoteTensor for GPU/NPU +// +// This design is similar to: +// - CUDA split buffer: tensor->extra stores device pointers +// - CPU repack buffer: tensor->extra stores tensor_traits with repacked data +// ===================================================== + +// Buffer context that manages per-tensor allocations (no contiguous buffer for weights) +struct ggml_backend_openvino_buffer_context { + int device; + std::string name; + + // For non-weight buffers (KV cache, compute), we still use contiguous allocation + void * data; + size_t size; + bool is_weight_buffer; // Set when buffer usage is set to WEIGHTS + + // Track all extras for cleanup + std::vector tensor_extras; + + ggml_backend_openvino_buffer_context(int device, size_t size) : + device(device), + name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)), + data(nullptr), + size(size), + is_weight_buffer(false) { + // Allocate aligned contiguous memory + if (size > 0) { +#ifdef _WIN32 + data = _aligned_malloc(size, GGML_OPENVINO_BUFFER_ALIGNMENT); +#else + data = aligned_alloc(GGML_OPENVINO_BUFFER_ALIGNMENT, size); +#endif + if (data == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size); + } + } + } + + ~ggml_backend_openvino_buffer_context() { + // Clean up all tensor extras + for (auto * extra : tensor_extras) { + delete extra; + } + tensor_extras.clear(); + + // Free contiguous memory + if (data != nullptr) { +#ifdef _WIN32 + _aligned_free(data); +#else + free(data); +#endif + data = nullptr; + } + } +}; + +// Buffer type context (per-device) +struct ggml_backend_openvino_buffer_type_context { 
+ int device; + std::string name; +}; + +// Buffer interface functions +static void ggml_backend_openvino_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + delete ctx; +} + +static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + return ctx->data; +} + +static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + // Views share the extra from view_src + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + if (tensor->view_src->extra != nullptr) { + tensor->extra = tensor->view_src->extra; + } + return GGML_STATUS_SUCCESS; + } + + // For non-view tensors, tensor->extra will be set in set_tensor + // when the actual weight data is loaded + GGML_UNUSED(buffer); + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + memset((char *) tensor->data + offset, value, size); + GGML_UNUSED(buffer); +} + +static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + // Check if this is a weight buffer (usage is set BEFORE set_tensor is called) + bool is_weight_buffer = (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + // Full tensor set: offset=0, full size, not a view + bool is_full_tensor_set = (offset == 0 && size == ggml_nbytes(tensor) && tensor->view_src == nullptr); + // 2D tensor (typical weight shape) + bool is_2d = (tensor->ne[2] == 1 && tensor->ne[3] == 1); + + // Check if this is a quantized weight tensor that needs extraction/requantization + ggml_openvino_extracted_layout layout = {}; + if (is_weight_buffer && is_full_tensor_set && is_2d && ggml_is_quantized(tensor->type)) { + layout = ggml_openvino_get_extracted_layout(tensor); + } + + if (layout.total_size > 0) { + uint8_t * buf_base = (uint8_t *) tensor->data; + + // 2D shape for weights [rows, cols] + ov::Shape weight_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; + + try { + std::shared_ptr constant; + + if (layout.is_requant && layout.requant_type.has_value()) { + // Requantization path + if (layout.requant_type.value() == ExtraQuantType::F16) { + // Requant to F16: create F16 tensor with external memory, requantize fills it + ov::Tensor weights(ov::element::f16, weight_shape, buf_base); + ov::Tensor dummy_scales, dummy_biases; // Not used for F16 + // requantize_to_buffers fills weights and returns a Constant wrapping it + constant = requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, weights, dummy_scales, + dummy_biases); + + // Store in tensor->extra (use weight_extra since it's F16) + auto * extra = new ggml_openvino_weight_extra(constant); + ctx->tensor_extras.push_back(extra); + tensor->extra = extra; + + GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); + } else { + // Requant to quantized format (Q4_0_128, Q8_0_32, etc.) 
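+                    // Wrap the pre-sized regions of tensor->data (see ggml_openvino_get_extracted_layout)
+                    // as ov::Tensors so that requantize_to_buffers writes the U4/U8 weights and the
+                    // per-block F16 scales/biases in place; the resulting Constant shares this memory.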
+ ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + ov::Shape scale_shape = {static_cast(tensor->ne[1]), + static_cast(tensor->ne[0] / layout.weights_per_block)}; + + ov::Tensor weights(weight_type, weight_shape, buf_base + layout.weights_offset); + ov::Tensor scales(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + ov::Tensor biases(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + + constant = requantize_to_buffers(tensor, data, layout.requant_type.value(), + layout.weights_per_block, weights, scales, biases); + + // Store in tensor->extra + auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), + std::move(biases), constant); + ctx->tensor_extras.push_back(extra); + tensor->extra = extra; + + GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, + layout.requant_type.value() == ExtraQuantType::Q4_0_128 ? "Q4_0_128" : "Q8_0_32", + layout.is_u4 ? 4 : 8, layout.weights_per_block); + } + } else { + // Normal extraction path (no requant) + ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; + ov::Shape scale_shape = {static_cast(tensor->ne[1]), + static_cast(tensor->ne[0] / layout.weights_per_block)}; + + ov::Tensor weights(weight_type, weight_shape, buf_base + layout.weights_offset); + ov::Tensor scales(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + ov::Tensor biases(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + + constant = extract_quantized_weights(tensor, data, weights, scales, biases); + + // Store in tensor->extra + auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), + std::move(biases), constant); + ctx->tensor_extras.push_back(extra); + tensor->extra = extra; + + GGML_LOG_DEBUG("%s: extracted quantized constant for %s (u%d, %zu weights, %ld blocks)\n", __func__, + tensor->name, layout.is_u4 ? 
4 : 8, layout.weights_size, n_blocks); + } + + } catch (const std::exception & e) { + GGML_LOG_ERROR("%s: failed to process quantized data for %s: %s\n", __func__, tensor->name, e.what()); + // Fall back to storing raw data + memcpy((char *) tensor->data + offset, data, size); + } + } else if (is_weight_buffer && is_full_tensor_set && is_2d && + (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16)) { + // F16/F32/BF16 weight tensor - copy data and create shared-memory constant + memcpy((char *) tensor->data + offset, data, size); + + try { + // Get OpenVINO element type + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + default: + return; // Should not happen + } + + // Create 2D shape (OpenVINO expects [rows, cols]) + ov::Shape shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; + + // Create ov::Tensor with external memory, then wrap with Constant + ov::Tensor ov_tensor(element_type, shape, tensor->data); + auto constant = std::make_shared(ov_tensor); + constant->set_friendly_name(tensor->name); + + // Store in tensor->extra + ggml_openvino_weight_extra * extra = new ggml_openvino_weight_extra(constant); + ctx->tensor_extras.push_back(extra); + tensor->extra = extra; + + GGML_LOG_DEBUG("%s: created shared-memory constant for %s\n", __func__, tensor->name); + + } catch (const std::exception & e) { + GGML_LOG_DEBUG("%s: failed to create shared-memory constant for %s: %s\n", __func__, tensor->name, + e.what()); + } + } else { + // Non-weight tensor (KV cache, activations, etc.) - just copy data + memcpy((char *) tensor->data + offset, data, size); + } +} + +static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + memcpy(data, (const char *) tensor->data + offset, size); + GGML_UNUSED(buffer); +} + +static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { + GGML_ASSERT(src != nullptr && dst != nullptr); + // Can copy from any host buffer (including other OpenVINO buffers) + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; + GGML_UNUSED(buffer); +} + +static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + if (ctx->data != nullptr) { + memset(ctx->data, value, ctx->size); + } +} + +static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = { + /* .free_buffer = */ ggml_backend_openvino_buffer_free_buffer, + /* .get_base = */ ggml_backend_openvino_buffer_get_base, + /* .init_tensor = */ ggml_backend_openvino_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor, + /* .clear = */ ggml_backend_openvino_buffer_clear, + /* .reset = */ NULL, +}; + +// Buffer type interface functions +static const char * 
ggml_backend_openvino_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_openvino_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + ggml_backend_openvino_buffer_type_context * buft_ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + + // Create buffer context with contiguous memory allocation + ggml_backend_openvino_buffer_context * ctx = new ggml_backend_openvino_buffer_context(buft_ctx->device, size); + + if (ctx->data == nullptr && size > 0) { + GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); + delete ctx; + return nullptr; + } + + return ggml_backend_buffer_init(buft, ggml_backend_openvino_buffer_interface, ctx, size); +} + +static size_t ggml_backend_openvino_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return GGML_OPENVINO_BUFFER_ALIGNMENT; +} + +static size_t ggml_backend_openvino_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return SIZE_MAX; +} + +static size_t ggml_backend_openvino_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + GGML_UNUSED(buft); + + // For quantized 2D tensors (weights), we need extra space for extracted data + if (ggml_is_quantized(tensor->type) && tensor->ne[2] == 1 && tensor->ne[3] == 1) { + ggml_openvino_extracted_layout layout = ggml_openvino_get_extracted_layout(tensor); + if (layout.total_size > 0) { + GGML_LOG_DEBUG( + "%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu scales=%zu biases=%zu)\n", + __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size, layout.scales_size, + layout.biases_size); + return layout.total_size; + } + } + + return ggml_nbytes(tensor); +} + +static bool ggml_backend_openvino_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + // Currently using host memory via ov::Tensor + // This will be false when using GPU/NPU remote tensors + return true; +} + +static const ggml_backend_buffer_type_i ggml_backend_openvino_buffer_type_interface = { + /* .get_name = */ ggml_backend_openvino_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_openvino_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_openvino_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_openvino_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_openvino_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_openvino_buffer_type_is_host, +}; + +// Get buffer type for a specific device +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) { + GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count()); + + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector buffer_types; + static std::vector buffer_type_contexts; + + if (buffer_types.empty()) { + int device_count = ggml_backend_openvino_get_device_count(); + buffer_types.resize(device_count); + buffer_type_contexts.resize(device_count); + + for (int i = 0; i < device_count; i++) { + buffer_type_contexts[i].device = i; + buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i); + + buffer_types[i] = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_openvino_buffer_type_interface, + /* .device = */ 
ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i), + /* .context = */ &buffer_type_contexts[i], + }; + } + } + + return &buffer_types[device]; +} + +// Check if a buffer is an OpenVINO buffer +static bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer; +} + +// ===================================================== +// OpenVINO Backend Context and Interface +// ===================================================== + struct ggml_backend_openvino_context { int device; // the device ID currently in use std::string name; // context Name @@ -111,13 +522,6 @@ GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_openvino_guid()); } -// device buffer -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) { - GGML_ASSERT(device >= 0); - return ggml_backend_cpu_buffer_type(); - GGML_UNUSED(device); -} - struct ggml_backend_openvino_device_context { int device; std::string name; @@ -350,7 +754,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con } static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return ggml_backend_buft_is_host(buft); + // Support our own buffer type and any host buffer (for mmap'd files, etc.) + return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name || ggml_backend_buft_is_host(buft); GGML_UNUSED(dev); } @@ -410,6 +815,10 @@ static int get_openvino_device_count() { } static ggml_openvino_device_info ggml_openvino_init() { + // Initialize device config singleton from env var + ggml_openvino_init_device_config(); + GGML_LOG_INFO("OpenVINO: using device %s\n", ggml_openvino_get_device_name().c_str()); + ggml_openvino_device_info info = {}; info.device_count = get_openvino_device_count(); return info; diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 2076c3c75d3..662f27be7ad 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -418,11 +418,124 @@ ov::Output make_int4_weights(ov::Tensor & weight, return std::make_shared(w_zp_s, ov::element::f32); } -std::shared_ptr requantize(const ggml_tensor * tensor, ExtraQuantType requant_type) { - std::vector weights_f32(tensor->ne[0] * tensor->ne[1]); - ggml_get_type_traits(tensor->type)->to_float(tensor->data, weights_f32.data(), ggml_nelements(tensor)); +// Extract quantized weights from tensor and create weight subgraph +std::shared_ptr extract_quantized_weights(const ggml_tensor * tensor, + const void * data, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & biases) { + // Create a temporary tensor for extraction functions that read from tensor->data + ggml_tensor temp_tensor = *tensor; + temp_tensor.data = const_cast(data); + + // Determine block size based on tensor type + int64_t weights_per_block; + bool is_u4; + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + is_u4 = true; + weights_per_block = 32; + break; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_K: + is_u4 = false; + weights_per_block = 32; + break; + case GGML_TYPE_Q6_K: + is_u4 = false; + weights_per_block = 16; + break; + default: + throw std::runtime_error("Unsupported quantized type for extraction: " + + std::string(ggml_type_name(tensor->type))); + } + + // Extract quantized data + 
switch (tensor->type) { + case GGML_TYPE_Q4_0: + extract_q4_0_data(&temp_tensor, weights, scales, biases); + break; + case GGML_TYPE_Q4_1: + extract_q4_1_data(&temp_tensor, weights, scales, biases); + break; + case GGML_TYPE_Q4_K: + extract_q4_k_data(&temp_tensor, weights, scales, biases); + break; + case GGML_TYPE_Q8_0: + extract_q8_0_data(&temp_tensor, weights, scales, biases); + break; + case GGML_TYPE_Q6_K: + extract_q6_k_data(&temp_tensor, weights, scales, biases); + break; + case GGML_TYPE_Q5_K: + extract_q5_k_data(&temp_tensor, weights, scales, biases); + break; + default: + throw std::runtime_error("Unsupported quantized type: " + std::string(ggml_type_name(tensor->type))); + } + + // Create the OpenVINO weight subgraph + ov::Output weight_node; + if (is_u4) { + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + } else { + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } + + auto result = weight_node.get_node_shared_ptr(); + result->set_friendly_name(tensor->name); + return result; +} + +// Requantize weights to target format, writing to provided buffers +std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, + const void * data, + ExtraQuantType requant_type, + int64_t block_size, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & biases) { + int64_t n_elements = ggml_nelements(tensor); + + // First dequantize to F32 + std::vector weights_f32(n_elements); + ggml_get_type_traits(tensor->type)->to_float(data, weights_f32.data(), n_elements); + + // Handle F16 case - just convert and create constant + if (requant_type == ExtraQuantType::F16) { + ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), n_elements); + auto result = std::make_shared(weights); + result->set_friendly_name(tensor->name); + return result; + } + + // Requantize to target quantized format + bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128); + + if (is_u4) { + quantize_q4_0(weights_f32.data(), weights, scales, biases, n_elements, block_size); + } else if (requant_type == ExtraQuantType::Q8_1_C) { + quantize_q8_1(weights_f32.data(), weights, scales, biases, n_elements, block_size); + } else { + quantize_q8_0(weights_f32.data(), weights, scales, biases, n_elements, block_size); + } + + // Create the OpenVINO weight subgraph + ov::Output weight_node; + if (is_u4) { + weight_node = make_int4_weights(weights, scales, biases, block_size); + } else { + weight_node = make_int8_weights(weights, scales, biases, block_size); + } + + auto result = weight_node.get_node_shared_ptr(); + result->set_friendly_name(tensor->name); + return result; +} - std::shared_ptr weight_node; +std::shared_ptr requantize(const ggml_tensor * tensor, ExtraQuantType requant_type) { ov::Shape node_shape = {(uint64_t) (tensor->ne[1]), (uint64_t) (tensor->ne[0])}; // FIXME hardcoded workaround to fix the case where token_emb.weight is q4_0 (instead of q6_k) @@ -432,42 +545,28 @@ std::shared_ptr requantize(const ggml_tensor * tensor, ExtraQuantType requant_type = ExtraQuantType::F16; } - if (requant_type == ExtraQuantType::F16) { - ov::Tensor weights(ov::element::f16, node_shape); - ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), ggml_nelements(tensor)); - std::shared_ptr weight_node = std::make_shared(weights); - weight_node->set_friendly_name(tensor->name); - return weight_node; - } - + // Determine block size int64_t block_size = node_shape[1]; if 
(requant_type == ExtraQuantType::Q4_0_128) { block_size = 128; } else if (requant_type == ExtraQuantType::Q8_0_32) { block_size = 32; } - auto scales_shape = ov::Shape{node_shape[0], node_shape[1] / block_size}; - - ov::Tensor weights; - ov::Tensor scales(ov::element::f16, scales_shape); - ov::Tensor bias(ov::element::f16, scales_shape); - if (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128) { - weights = ov::Tensor(ov::element::u4, node_shape); - quantize_q4_0(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); - weight_node = make_int4_weights(weights, scales, bias, block_size).get_node_shared_ptr(); - } else if (requant_type == ExtraQuantType::Q8_1_C) { - weights = ov::Tensor(ov::element::u8, node_shape); - quantize_q8_1(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); - weight_node = make_int8_weights(weights, scales, bias, block_size).get_node_shared_ptr(); - } else if (requant_type == ExtraQuantType::Q8_0_C || requant_type == ExtraQuantType::Q8_0_32) { - weights = ov::Tensor(ov::element::u8, node_shape); - quantize_q8_0(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); - weight_node = make_int8_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + // Allocate tensors + ov::Tensor weights, scales, biases; + if (requant_type == ExtraQuantType::F16) { + weights = ov::Tensor(ov::element::f16, node_shape); + } else { + bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128); + ov::element::Type weight_type = is_u4 ? ov::element::u4 : ov::element::u8; + ov::Shape scales_shape = {node_shape[0], node_shape[1] / block_size}; + weights = ov::Tensor(weight_type, node_shape); + scales = ov::Tensor(ov::element::f16, scales_shape); + biases = ov::Tensor(ov::element::f16, scales_shape); } - weight_node->set_friendly_name(tensor->name); - return weight_node; + return requantize_to_buffers(tensor, tensor->data, requant_type, block_size, weights, scales, biases); } void quantize_q4_0(const float * x, diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp index 71ae317a39e..0f14a6ed2dc 100644 --- a/ggml/src/ggml-openvino/ggml-quants.hpp +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -1,10 +1,11 @@ #pragma once +#include "ggml-openvino-extra.h" // For ExtraQuantType +#include "ggml.h" + #include #include #include -#include "ggml.h" - void unpack_32_4(const uint8_t* data, uint8_t* dst); void extract_q4_0_data(const ggml_tensor* tensor, @@ -51,10 +52,32 @@ ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& biases, size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); -enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; +// ExtraQuantType is defined in ggml-openvino-extra.h std::shared_ptr requantize(const ggml_tensor* tensor, ExtraQuantType requant_type); +// Extract quantized weights from tensor and create weight subgraph +// If weights/scales/biases are provided (non-empty), uses them as output buffers +// Otherwise allocates new ov::Tensors internally +// Returns the weight node (make_int4_weights or make_int8_weights result) +std::shared_ptr extract_quantized_weights( + const ggml_tensor * tensor, + const void * data, // Source data pointer (may differ from tensor->data) + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & biases); + +// Requantize weights from tensor to target format, writing to provided buffers +// For F16 target, only weights buffer is 
used (scales/biases ignored) +// Returns the weight node +std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, + const void * data, // Source data pointer + ExtraQuantType requant_type, + int64_t block_size, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & biases); + void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, int64_t qk); void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 836e366fd7f..251fb82361b 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -1,6 +1,7 @@ #include "utils.h" #include "ggml-impl.h" +#include "ggml-openvino-extra.h" #include "ggml-openvino/ggml-decoder.h" #include "ggml.h" #include "openvino/frontend.hpp" @@ -39,23 +40,14 @@ static ov::Core core; enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) { - auto get_device = [&] { - std::string device = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : "CPU"; - auto available_devices = core.get_available_devices(); - if (std::find(available_devices.begin(), available_devices.end(), device) == available_devices.end()) { - GGML_LOG_WARN("GGML OpenVINO Backend: device %s is not available, fallback to CPU\n", device.c_str()); - device = "CPU"; - } - return device; - }; - if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { std::string filename = "cgraph.txt"; GgmlOvDecoder::dump_cgraph(cgraph, filename); } - static const auto device = get_device(); - static const auto is_static = device == "NPU" ? true : false; + // Use device from singleton (initialized during backend init) + const auto & device = ggml_openvino_get_device_name(); + const auto is_static = ggml_openvino_is_npu(); return is_static ? 
ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device); } @@ -413,7 +405,8 @@ ov::AnyMap get_ov_compile_config(const std::string & device) { } std::map get_types_to_requant(const std::string & device) { - if (device == "NPU") { + // Use singleton to check if NPU (device param kept for API compatibility) + if (ggml_openvino_is_npu()) { return { {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128}, {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128}, @@ -423,6 +416,7 @@ std::map get_types_to_requant(const std::string & dev }; } return {}; + GGML_UNUSED(device); } bool is_naive(ggml_cgraph * cgraph) { From b9b02a6169fa576a6ee206d926884b956584cf70 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 18 Dec 2025 17:03:03 +0800 Subject: [PATCH 02/14] Use shared_buffer for GPU NPU; Refactor --- ggml/src/ggml-openvino/CMakeLists.txt | 3 +- ggml/src/ggml-openvino/ggml-decoder.cpp | 78 ++------ ggml/src/ggml-openvino/ggml-decoder.h | 7 +- .../src/ggml-openvino/ggml-openvino-extra.cpp | 177 ++++++++++++++++++ ggml/src/ggml-openvino/ggml-openvino-extra.h | 159 ++-------------- ggml/src/ggml-openvino/ggml-openvino.cpp | 154 +++++++-------- ggml/src/ggml-openvino/ggml-quants.cpp | 106 +++++++++++ ggml/src/ggml-openvino/ggml-quants.hpp | 10 + ggml/src/ggml-openvino/utils.cpp | 19 +- ggml/src/ggml-openvino/utils.h | 2 - 10 files changed, 389 insertions(+), 326 deletions(-) create mode 100644 ggml/src/ggml-openvino/ggml-openvino-extra.cpp diff --git a/ggml/src/ggml-openvino/CMakeLists.txt b/ggml/src/ggml-openvino/CMakeLists.txt index 3051a8b2405..175b585661d 100644 --- a/ggml/src/ggml-openvino/CMakeLists.txt +++ b/ggml/src/ggml-openvino/CMakeLists.txt @@ -1,4 +1,5 @@ find_package(OpenVINO REQUIRED) +find_package(OpenCL REQUIRED) include("${OpenVINO_DIR}/../3rdparty/tbb/lib/cmake/TBB/TBBConfig.cmake") @@ -10,7 +11,7 @@ ggml_add_backend_library(ggml-openvino ${GGML_HEADERS_OPENVINO} ) -target_link_libraries(ggml-openvino PRIVATE openvino::runtime TBB::tbb) +target_link_libraries(ggml-openvino PRIVATE openvino::runtime TBB::tbb OpenCL::OpenCL) if (GGML_OPENVINO) if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 409a16e8162..2d6437f0691 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -3,6 +3,7 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" #include "ggml-openvino-extra.h" +#include "ggml-openvino.h" #include "ggml-quants.hpp" #include @@ -471,9 +472,7 @@ const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name // return kv_param_res_names; // } -std::map> GgmlOvDecoder::create_weight_nodes( - ggml_cgraph * cgraph, - std::map types_to_requantize) { +std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph) { std::map> model_weights; static std::mutex weights_mutex; auto * nodes = cgraph->nodes; @@ -498,10 +497,7 @@ std::map> GgmlOvDecoder::create_weight_no } } if (should_create) { - auto requant_type = types_to_requantize.count(src->type) ? 
- std::optional(types_to_requantize.at(src->type)) : - std::nullopt; - auto weight_node = create_weight_node(src, requant_type); + auto weight_node = create_weight_node(src); weight_node->set_friendly_name(src_name); { std::lock_guard lock(weights_mutex); @@ -520,11 +516,14 @@ std::map> GgmlOvDecoder::create_weight_no static std::unordered_map> s_quantized_weight_cache; static std::mutex s_quantized_weight_cache_mutex; -std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor, - std::optional requant_type) { +std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor) { // Check if we have a pre-built constant from the OpenVINO backend buffer // This is set during ggml_backend_openvino_buffer_set_tensor - if (tensor->extra != nullptr && !requant_type.has_value()) { + if (tensor->extra) { + if (!ggml_backend_buffer_is_openvino(tensor->buffer)) { + OPENVINO_ASSERT(false, "Unsupported weight tensor: " + std::string(tensor->name) + + " Possibly this is a cpu backend repacked quantized weights"); + } // Cast to our extra base type and check the type auto * extra_base = static_cast(tensor->extra); @@ -547,7 +546,7 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor // Fallback: Check static cache for quantized weights (keyed by data pointer) // This handles cases where tensors weren't loaded through OpenVINO buffer - if (ggml_is_quantized(tensor->type) && !requant_type.has_value()) { + if (ggml_is_quantized(tensor->type)) { std::lock_guard lock(s_quantized_weight_cache_mutex); auto it = s_quantized_weight_cache.find(tensor->data); if (it != s_quantized_weight_cache.end()) { @@ -565,64 +564,11 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor ggml_type_name(tensor->type)); } - auto node_type = get_ov_type(tensor); - auto node_shape = get_shape(tensor); - auto ne_total = ggml_nelements(tensor); - - OPENVINO_ASSERT(node_shape[0] == 1, "Got 4D weights, expect all weights to be 2D: ", tensor->name); - node_shape.erase(node_shape.begin()); - OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name); - node_shape.erase(node_shape.begin()); - - // F16 and F32 case - if (node_type != ov::element::dynamic) { - ov::Tensor weights(node_type, node_shape); - memcpy(weights.data(), tensor->data, ne_total * node_type.size()); - std::shared_ptr weight_node = std::make_shared(weights); - // Disabled because it triggers a bug in NPUW, no performance impact on CPU GPU - // if (node_type == ov::element::f16) { - // weight_node = std::make_shared(weight_node, ov::element::f32); - // } - weight_node->set_friendly_name(tensor->name); - return weight_node; - } - - // Quantized case - extra should be nullptr (not our type) - // Our ggml_openvino_weight_extra is only set for F16/F32 weights - if (tensor->extra != nullptr) { - // Check if it's our type - if so, something is wrong - auto * extra_base = static_cast(tensor->extra); - if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT || - extra_base->type == ggml_openvino_extra_base::Type::TENSOR) { - OPENVINO_ASSERT(false, "Quantized weight tensor has unexpected extra type: " + std::string(tensor->name)); - } - // Otherwise it might be repacked quantized weights from another backend - OPENVINO_ASSERT(false, "Unsupported weight tensor: " + std::string(tensor->name) + - " Possibly this is a repacked quantized weights"); - } - - if (requant_type.has_value()) { - return requantize(tensor, requant_type.value()); - } - - // Extract quantized weights using 
the shared function - auto layout = ggml_openvino_get_extracted_layout(tensor); - if (layout.total_size == 0) { - OPENVINO_THROW("Unsupported quantized type for ", tensor->name, " type=", ggml_type_name(tensor->type)); - } - - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; - ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - - ov::Tensor weights(weight_type, node_shape); - ov::Tensor scales(ov::element::f16, scale_shape); - ov::Tensor biases(ov::element::f16, scale_shape); - - auto result = extract_quantized_weights(tensor, tensor->data, weights, scales, biases); + std::shared_ptr result = process_weight_tensor(tensor, tensor->data, nullptr); result->set_friendly_name(tensor->name); // Cache the quantized weight node for future reuse - if (ggml_is_quantized(tensor->type) && !requant_type.has_value()) { + if (ggml_is_quantized(tensor->type)) { std::lock_guard lock(s_quantized_weight_cache_mutex); s_quantized_weight_cache[tensor->data] = result; GGML_LOG_DEBUG("%s: cached quantized constant for %s\n", __func__, tensor->name); diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index edcd0367854..0b302b9320b 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -179,12 +179,9 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename); - static std::shared_ptr create_weight_node(ggml_tensor * tensor, - std::optional requant_type = std::nullopt); + static std::shared_ptr create_weight_node(ggml_tensor * tensor); - static std::map> create_weight_nodes( - ggml_cgraph * cgraph, - std::map types_to_requantize = {}); + static std::map> create_weight_nodes(ggml_cgraph * cgraph); const ggml_tensor * get_tensor_used_op(const ggml_tensor * tensor) const; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp new file mode 100644 index 00000000000..75b27c8fa81 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -0,0 +1,177 @@ +#include "ggml-openvino-extra.h" + +#include "ggml-impl.h" + +ov::Core & ov_singleton_core() { + static ov::Core core; + return core; +} + +// ===================================================== +// Device Configuration Implementations +// ===================================================== + +void ggml_openvino_device_config::init() { + if (initialized) { + return; + } + device_name = getenv("GGML_OPENVINO_DEVICE") ? 
getenv("GGML_OPENVINO_DEVICE") : "CPU"; + auto available_devices = ov_singleton_core().get_available_devices(); + if (std::find(available_devices.begin(), available_devices.end(), device_name) == available_devices.end()) { + GGML_LOG_WARN("GGML OpenVINO Backend: device %s is not available, fallback to CPU\n", device_name.c_str()); + device_name = "CPU"; + } + is_npu = (device_name == "NPU"); + initialized = true; +} + +// Get the global device config singleton +ggml_openvino_device_config & ggml_openvino_get_device_config() { + static ggml_openvino_device_config config; + return config; +} + +// Initialize device config (call during backend init) +void ggml_openvino_init_device_config() { + ggml_openvino_get_device_config().init(); +} + +// Get the device name +const std::string & ggml_openvino_get_device_name() { + return ggml_openvino_get_device_config().device_name; +} + +// Check if running on NPU +bool ggml_openvino_is_npu() { + return ggml_openvino_get_device_config().is_npu; +} + +// Get requantization type for a tensor type (returns nullopt if no requant needed) +std::optional ggml_openvino_get_requant_type(ggml_type type) { + if (!ggml_openvino_is_npu()) { + return std::nullopt; + } + // NPU requantization rules + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + return ExtraQuantType::Q4_0_128; + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q5_K: + return ExtraQuantType::F16; + default: + return std::nullopt; + } +} + +// ===================================================== +// Extracted Layout Calculation +// ===================================================== + +ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor) { + ggml_openvino_extracted_layout layout = {}; + + if (!ggml_is_quantized(tensor->type)) { + return layout; + } + + // Only handle 2D weight tensors + if (tensor->ne[2] != 1 || tensor->ne[3] != 1) { + return layout; + } + + int64_t n_elements = ggml_nelements(tensor); + const size_t alignment = 64; // Good for SIMD + + // Check if requantization is needed (NPU-specific) + auto requant_type = ggml_openvino_get_requant_type(tensor->type); + if (requant_type.has_value()) { + layout.is_requant = true; + layout.requant_type = requant_type; + + // Special case: requant to F16 - just store F16 weights, no scales/biases + if (requant_type.value() == ExtraQuantType::F16) { + layout.weights_size = n_elements * sizeof(uint16_t); // F16 = 2 bytes + layout.total_size = layout.weights_size; + layout.weights_offset = 0; + // No scales/biases for F16 + return layout; + } + + // Requant to different quantized format (e.g., Q4_0_128) + switch (requant_type.value()) { + case ExtraQuantType::Q4_0_128: + layout.is_u4 = true; + layout.weights_per_block = 128; + break; + case ExtraQuantType::Q8_0_32: + layout.is_u4 = false; + layout.weights_per_block = 32; + break; + default: + // Unsupported requant type - fall through to normal extraction + layout.is_requant = false; + layout.requant_type = std::nullopt; + break; + } + + if (layout.is_requant) { + // Calculate sizes for requantized format + layout.weights_size = layout.is_u4 ? 
(n_elements / 2) : n_elements; + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); + layout.biases_size = n_blocks * sizeof(uint16_t); + + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.biases_offset = + layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.biases_offset + layout.biases_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); + return layout; + } + } + + // Normal extraction (no requant) - determine format based on tensor type + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + layout.is_u4 = true; + layout.weights_per_block = 32; + break; + case GGML_TYPE_Q8_0: + layout.is_u4 = false; + layout.weights_per_block = 32; + break; + case GGML_TYPE_Q6_K: + layout.is_u4 = false; + layout.weights_per_block = 16; + break; + case GGML_TYPE_Q5_K: + layout.is_u4 = false; + layout.weights_per_block = 32; + break; + default: + // Unsupported quantization type + return layout; + } + + // Calculate sizes + // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes + layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; + + // Scales and biases: F16 per block + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + layout.biases_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + + // Layout in buffer: [weights | scales | biases] with alignment + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.biases_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.biases_offset + layout.biases_size; + + return layout; +} diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index 99db8704123..7e0138388ff 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -1,16 +1,20 @@ #pragma once +#include "ggml.h" +#include "openvino/runtime/core.hpp" + #include #include -#include #include #include +#include #include -#include "ggml.h" // ExtraQuantType enum - defines requantization target formats enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; +ov::Core & ov_singleton_core(); + // ===================================================== // Global Device Configuration (singleton) // ===================================================== @@ -21,56 +25,23 @@ struct ggml_openvino_device_config { bool is_npu = false; bool initialized = false; - void init() { - if (initialized) return; - const char* env = std::getenv("GGML_OPENVINO_DEVICE"); - if (env) { - device_name = env; - is_npu = (device_name == "NPU"); - } - initialized = true; - } + void init(); }; // Get the global device config singleton -inline ggml_openvino_device_config& ggml_openvino_get_device_config() { - static ggml_openvino_device_config config; - return config; -} +ggml_openvino_device_config & ggml_openvino_get_device_config(); // Initialize device config (call during backend init) -inline void ggml_openvino_init_device_config() { - ggml_openvino_get_device_config().init(); -} +void ggml_openvino_init_device_config(); // Get the device name -inline const std::string& ggml_openvino_get_device_name() { - return 
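// Illustrative worked example for the layout computed above (not part of the diff).
// Assume a 4096 x 4096 Q4_K weight requantized to Q4_0_128 on NPU:
//   n_elements        = 4096 * 4096                 = 16,777,216
//   weights_size (u4) = n_elements / 2              =  8,388,608 bytes
//   n_blocks          = n_elements / 128            =    131,072
//   scales_size = biases_size = n_blocks * 2        =    262,144 bytes (f16 per block)
//   weights_offset = 0, scales_offset = 8,388,608, biases_offset = 8,650,752
//   (both offsets happen to be multiples of the 64-byte alignment already)
//   total_size = max(8,650,752 + 262,144, ggml_nbytes(tensor))
//              = max(8,912,896, 9,437,184) = 9,437,184 bytes,
//   where 9,437,184 assumes ggml's usual Q4_K layout of 144 bytes per 256-element block.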
ggml_openvino_get_device_config().device_name; -} +const std::string & ggml_openvino_get_device_name(); // Check if running on NPU -inline bool ggml_openvino_is_npu() { - return ggml_openvino_get_device_config().is_npu; -} +bool ggml_openvino_is_npu(); // Get requantization type for a tensor type (returns nullopt if no requant needed) -inline std::optional ggml_openvino_get_requant_type(ggml_type type) { - if (!ggml_openvino_is_npu()) { - return std::nullopt; - } - // NPU requantization rules - switch (type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_K: - return ExtraQuantType::Q4_0_128; - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q5_K: - return ExtraQuantType::F16; - default: - return std::nullopt; - } -} +std::optional ggml_openvino_get_requant_type(ggml_type type); // ===================================================== // OpenVINO Tensor Extra Types @@ -140,108 +111,4 @@ struct ggml_openvino_extracted_layout { }; // Calculate the buffer layout for extracted quantized data -inline ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor) { - ggml_openvino_extracted_layout layout = {}; - - if (!ggml_is_quantized(tensor->type)) { - return layout; - } - - // Only handle 2D weight tensors - if (tensor->ne[2] != 1 || tensor->ne[3] != 1) { - return layout; - } - - int64_t n_elements = ggml_nelements(tensor); - const size_t alignment = 64; // Good for SIMD - - // Check if requantization is needed (NPU-specific) - auto requant_type = ggml_openvino_get_requant_type(tensor->type); - if (requant_type.has_value()) { - layout.is_requant = true; - layout.requant_type = requant_type; - - // Special case: requant to F16 - just store F16 weights, no scales/biases - if (requant_type.value() == ExtraQuantType::F16) { - layout.weights_size = n_elements * sizeof(uint16_t); // F16 = 2 bytes - layout.total_size = layout.weights_size; - layout.weights_offset = 0; - // No scales/biases for F16 - return layout; - } - - // Requant to different quantized format (e.g., Q4_0_128) - switch (requant_type.value()) { - case ExtraQuantType::Q4_0_128: - layout.is_u4 = true; - layout.weights_per_block = 128; - break; - case ExtraQuantType::Q8_0_32: - layout.is_u4 = false; - layout.weights_per_block = 32; - break; - default: - // Unsupported requant type - fall through to normal extraction - layout.is_requant = false; - layout.requant_type = std::nullopt; - break; - } - - if (layout.is_requant) { - // Calculate sizes for requantized format - layout.weights_size = layout.is_u4 ? 
(n_elements / 2) : n_elements; - int64_t n_blocks = n_elements / layout.weights_per_block; - layout.scales_size = n_blocks * sizeof(uint16_t); - layout.biases_size = n_blocks * sizeof(uint16_t); - - layout.weights_offset = 0; - layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; - layout.biases_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; - layout.total_size = layout.biases_offset + layout.biases_size; - layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); - return layout; - } - } - - // Normal extraction (no requant) - determine format based on tensor type - switch (tensor->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_K: - layout.is_u4 = true; - layout.weights_per_block = 32; - break; - case GGML_TYPE_Q8_0: - layout.is_u4 = false; - layout.weights_per_block = 32; - break; - case GGML_TYPE_Q6_K: - layout.is_u4 = false; - layout.weights_per_block = 16; - break; - case GGML_TYPE_Q5_K: - layout.is_u4 = false; - layout.weights_per_block = 32; - break; - default: - // Unsupported quantization type - return layout; - } - - // Calculate sizes - // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes - layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; - - // Scales and biases: F16 per block - int64_t n_blocks = n_elements / layout.weights_per_block; - layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - layout.biases_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - - // Layout in buffer: [weights | scales | biases] with alignment - layout.weights_offset = 0; - layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; - layout.biases_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; - layout.total_size = layout.biases_offset + layout.biases_size; - - return layout; -} +ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor); diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 747d1b8a307..e20ae71e408 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -12,7 +12,11 @@ #include #include #include +#include #include +#include +#include +#include #include #include #include @@ -48,7 +52,8 @@ struct ggml_backend_openvino_buffer_context { // For non-weight buffers (KV cache, compute), we still use contiguous allocation void * data; size_t size; - bool is_weight_buffer; // Set when buffer usage is set to WEIGHTS + + std::shared_ptr ov_tensor; // Track all extras for cleanup std::vector tensor_extras; @@ -57,18 +62,42 @@ struct ggml_backend_openvino_buffer_context { device(device), name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)), data(nullptr), - size(size), - is_weight_buffer(false) { - // Allocate aligned contiguous memory - if (size > 0) { + size(size) { + if (size == 0) { + return; + } + + const auto & device_name = ggml_openvino_get_device_name(); + auto & core = ov_singleton_core(); + + if (device_name == "CPU") { #ifdef _WIN32 - data = _aligned_malloc(size, GGML_OPENVINO_BUFFER_ALIGNMENT); + data = _aligned_malloc(alloc_size, GGML_OPENVINO_BUFFER_ALIGNMENT); #else data = aligned_alloc(GGML_OPENVINO_BUFFER_ALIGNMENT, size); #endif - if (data == nullptr) { - GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size); - } + ov_tensor = std::make_shared(ov::element::u8, ov::Shape{size}, data); + 
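// Illustrative note (not part of the diff): regardless of device, the whole backend buffer
// is exposed as a single flat u8 ov::Tensor whose data pointer doubles as the ggml buffer
// base. A tensor placed into this buffer by ggml's allocator is therefore expected to
// satisfy, in terms of plain pointer arithmetic:
//
//   void * base = ggml_backend_buffer_get_base(buffer);   // == ctx->data
//   assert((char *) tensor->data >= (char *) base &&
//          (char *) tensor->data + ggml_nbytes(tensor) <= (char *) base + ctx->size);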
} else if (device_name == "GPU") { + auto gpu_context = core.get_default_context("GPU").as(); + auto usm_tensor = gpu_context.create_usm_host_tensor(ov::element::u8, ov::Shape{size}); + data = usm_tensor.get(); + ov_tensor = std::make_shared(std::move(usm_tensor)); + } else { + auto npu_context = core.get_default_context("NPU").as(); + auto l0_tensor = npu_context.create_l0_host_tensor(ov::element::u8, ov::Shape{size}); + data = l0_tensor.get(); + ov_tensor = std::make_shared(std::move(l0_tensor)); + } + + if (data == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size); + return; + } + + if (reinterpret_cast(data) % GGML_OPENVINO_BUFFER_ALIGNMENT != 0) { + GGML_LOG_ERROR("%s: %s buffer is not aligned to %d bytes\n", __func__, device_name.c_str(), + GGML_OPENVINO_BUFFER_ALIGNMENT); + GGML_ABORT("fatal error"); } } @@ -78,15 +107,12 @@ struct ggml_backend_openvino_buffer_context { delete extra; } tensor_extras.clear(); - - // Free contiguous memory - if (data != nullptr) { + if (data && ggml_openvino_get_device_name() == "CPU") { #ifdef _WIN32 _aligned_free(data); #else free(data); #endif - data = nullptr; } } }; @@ -156,57 +182,26 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer } if (layout.total_size > 0) { + // Quantized weight tensor with extraction/requantization uint8_t * buf_base = (uint8_t *) tensor->data; - // 2D shape for weights [rows, cols] - ov::Shape weight_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; - try { - std::shared_ptr constant; - - if (layout.is_requant && layout.requant_type.has_value()) { - // Requantization path - if (layout.requant_type.value() == ExtraQuantType::F16) { - // Requant to F16: create F16 tensor with external memory, requantize fills it - ov::Tensor weights(ov::element::f16, weight_shape, buf_base); - ov::Tensor dummy_scales, dummy_biases; // Not used for F16 - // requantize_to_buffers fills weights and returns a Constant wrapping it - constant = requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, weights, dummy_scales, - dummy_biases); - - // Store in tensor->extra (use weight_extra since it's F16) - auto * extra = new ggml_openvino_weight_extra(constant); - ctx->tensor_extras.push_back(extra); - tensor->extra = extra; - - GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); - } else { - // Requant to quantized format (Q4_0_128, Q8_0_32, etc.) - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; - ov::Shape scale_shape = {static_cast(tensor->ne[1]), - static_cast(tensor->ne[0] / layout.weights_per_block)}; - - ov::Tensor weights(weight_type, weight_shape, buf_base + layout.weights_offset); - ov::Tensor scales(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - ov::Tensor biases(ov::element::f16, scale_shape, buf_base + layout.biases_offset); - - constant = requantize_to_buffers(tensor, data, layout.requant_type.value(), - layout.weights_per_block, weights, scales, biases); - - // Store in tensor->extra - auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), - std::move(biases), constant); - ctx->tensor_extras.push_back(extra); - tensor->extra = extra; + std::shared_ptr constant = process_weight_tensor(tensor, data, buf_base); + constant->set_friendly_name(tensor->name); - GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, - layout.requant_type.value() == ExtraQuantType::Q4_0_128 ? 
"Q4_0_128" : "Q8_0_32", - layout.is_u4 ? 4 : 8, layout.weights_per_block); - } + // Store in tensor->extra + if (layout.is_requant && layout.requant_type.has_value() && + layout.requant_type.value() == ExtraQuantType::F16) { + // F16 requant case - use weight_extra + auto * extra = new ggml_openvino_weight_extra(constant); + ctx->tensor_extras.push_back(extra); + tensor->extra = extra; + GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); } else { - // Normal extraction path (no requant) + // Quantized case - use quantized_weight_extra + // Create tensors with external memory (already filled by process_weight_tensor) ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; - int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; + ov::Shape weight_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; ov::Shape scale_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0] / layout.weights_per_block)}; @@ -214,16 +209,20 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer ov::Tensor scales(ov::element::f16, scale_shape, buf_base + layout.scales_offset); ov::Tensor biases(ov::element::f16, scale_shape, buf_base + layout.biases_offset); - constant = extract_quantized_weights(tensor, data, weights, scales, biases); - - // Store in tensor->extra auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), std::move(biases), constant); ctx->tensor_extras.push_back(extra); tensor->extra = extra; - GGML_LOG_DEBUG("%s: extracted quantized constant for %s (u%d, %zu weights, %ld blocks)\n", __func__, - tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks); + if (layout.is_requant) { + GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, + layout.requant_type.value() == ExtraQuantType::Q4_0_128 ? "Q4_0_128" : "Q8_0_32", + layout.is_u4 ? 4 : 8, layout.weights_per_block); + } else { + int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; + GGML_LOG_DEBUG("%s: extracted quantized constant for %s (u%d, %zu weights, %ld blocks)\n", __func__, + tensor->name, layout.is_u4 ? 
4 : 8, layout.weights_size, n_blocks); + } } } catch (const std::exception & e) { @@ -233,32 +232,9 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer } } else if (is_weight_buffer && is_full_tensor_set && is_2d && (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16)) { - // F16/F32/BF16 weight tensor - copy data and create shared-memory constant - memcpy((char *) tensor->data + offset, data, size); - + // F16/F32/BF16 weight tensor try { - // Get OpenVINO element type - ov::element::Type element_type; - switch (tensor->type) { - case GGML_TYPE_F32: - element_type = ov::element::f32; - break; - case GGML_TYPE_F16: - element_type = ov::element::f16; - break; - case GGML_TYPE_BF16: - element_type = ov::element::bf16; - break; - default: - return; // Should not happen - } - - // Create 2D shape (OpenVINO expects [rows, cols]) - ov::Shape shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; - - // Create ov::Tensor with external memory, then wrap with Constant - ov::Tensor ov_tensor(element_type, shape, tensor->data); - auto constant = std::make_shared(ov_tensor); + std::shared_ptr constant = process_weight_tensor(tensor, data, tensor->data); constant->set_friendly_name(tensor->name); // Store in tensor->extra @@ -418,7 +394,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(in } // Check if a buffer is an OpenVINO buffer -static bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { +bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer; } diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 662f27be7ad..6cacc7b0340 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -569,6 +569,112 @@ std::shared_ptr requantize(const ggml_tensor * tensor, ExtraQuantType return requantize_to_buffers(tensor, tensor->data, requant_type, block_size, weights, scales, biases); } +std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr) { + GGML_ASSERT(tensor != nullptr); + GGML_ASSERT(data != nullptr); + + // Get 2D shape for weights [rows, cols] + ov::Shape node_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; + + // Handle F16/F32/BF16 weights + if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) { + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + default: + OPENVINO_THROW("Unexpected tensor type in F16/F32/BF16 path"); + } + + if (output_base_ptr) { + // Using external buffer - copy data and create shared-memory constant + size_t tensor_bytes = ggml_nbytes(tensor); + memcpy(output_base_ptr, data, tensor_bytes); + ov::Tensor ov_tensor(element_type, node_shape, output_base_ptr); + return std::make_shared(ov_tensor); + } else { + // Allocate internal buffer + ov::Tensor weights(element_type, node_shape); + memcpy(weights.data(), data, ggml_nelements(tensor) * element_type.size()); + return std::make_shared(weights); + } + } + + // Handle quantized weights + if (!ggml_is_quantized(tensor->type)) { + OPENVINO_THROW("Unsupported weight tensor type: ", 
ggml_type_name(tensor->type)); + } + + auto layout = ggml_openvino_get_extracted_layout(tensor); + if (layout.total_size == 0) { + OPENVINO_THROW("Unsupported quantized type: ", ggml_type_name(tensor->type)); + } + + std::shared_ptr result; + + if (layout.is_requant && layout.requant_type.has_value()) { + // Requantization path + if (layout.requant_type.value() == ExtraQuantType::F16) { + // Requant to F16 + ov::Tensor weights; + if (output_base_ptr) { + weights = ov::Tensor(ov::element::f16, node_shape, + static_cast(output_base_ptr) + layout.weights_offset); + } else { + weights = ov::Tensor(ov::element::f16, node_shape); + } + ov::Tensor dummy_scales, dummy_biases; // Not used for F16 + result = requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, weights, dummy_scales, dummy_biases); + } else { + // Requant to quantized format (Q4_0_128, Q8_0_32, etc.) + ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + + ov::Tensor weights, scales, biases; + if (output_base_ptr) { + uint8_t * buf_base = static_cast(output_base_ptr); + weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); + scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + biases = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + } else { + weights = ov::Tensor(weight_type, node_shape); + scales = ov::Tensor(ov::element::f16, scale_shape); + biases = ov::Tensor(ov::element::f16, scale_shape); + } + + result = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, weights, + scales, biases); + } + } else { + // Normal extraction path (no requant) + ov::element::Type weight_type = layout.is_u4 ? 
ov::element::u4 : ov::element::u8; + ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + + ov::Tensor weights, scales, biases; + if (output_base_ptr) { + uint8_t * buf_base = static_cast(output_base_ptr); + weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); + scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + biases = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + } else { + weights = ov::Tensor(weight_type, node_shape); + scales = ov::Tensor(ov::element::f16, scale_shape); + biases = ov::Tensor(ov::element::f16, scale_shape); + } + + result = extract_quantized_weights(tensor, data, weights, scales, biases); + } + + return result; +} + void quantize_q4_0(const float * x, ov::Tensor & weights_arr, ov::Tensor & scales_arr, diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp index 0f14a6ed2dc..b1d286f1b83 100644 --- a/ggml/src/ggml-openvino/ggml-quants.hpp +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -78,6 +78,16 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, ov::Tensor & scales, ov::Tensor & biases); +// Process weight tensor and create an OpenVINO constant node +// Handles F16/F32/BF16 and quantized weights, with optional requantization +// If output_base_ptr is nullptr, allocates internal buffers (for decoder use) +// If output_base_ptr is provided, uses pre-allocated buffers at specified offsets (for backend buffer use) +// Returns the weight constant node +std::shared_ptr process_weight_tensor( + const ggml_tensor * tensor, + const void * data, // Source data pointer (may differ from tensor->data) + void * output_base_ptr = nullptr); // Base pointer for output buffers (or nullptr for internal allocation) + void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, int64_t qk); void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 251fb82361b..6d56af93189 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -107,7 +107,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin infer_request_cache.erase(key); std::shared_ptr model; - auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, get_types_to_requant(device)); + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); ggml_decoder = std::make_shared(cgraph, m_params, c_params, model_weights, is_static); decoder_end_time = ggml_time_us(); @@ -255,7 +255,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { infer_request_cache_prefill.erase(key); std::shared_ptr model; - auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, get_types_to_requant(device)); + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); auto ggml_decoder_prefill = std::make_shared(cgraph, m_params, c_params, model_weights, is_static, true, prefill_chunk_size); @@ -404,21 +404,6 @@ ov::AnyMap get_ov_compile_config(const std::string & device) { return config; } -std::map get_types_to_requant(const std::string & device) { - // Use singleton to check if NPU (device param kept for API compatibility) - if (ggml_openvino_is_npu()) { - return { - {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q4_K, 
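// Usage note (illustrative, not part of the diff): process_weight_tensor() is now the single
// entry point for turning a ggml weight into an OpenVINO constant, and this series calls it
// in two modes:
//
//   // decoder fallback path - let OpenVINO own the storage:
//   auto node = process_weight_tensor(tensor, tensor->data, /*output_base_ptr=*/nullptr);
//
//   // backend-buffer path - extract/requantize in place into the pre-allocated host buffer:
//   auto node = process_weight_tensor(tensor, incoming_data, tensor->data);
//
// incoming_data stands for the `data` pointer handed to set_tensor; the name is a
// placeholder, not an identifier from the patch.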
ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q6_K, ExtraQuantType::F16 }, - {GGML_TYPE_Q5_K, ExtraQuantType::F16 }, - }; - } - return {}; - GGML_UNUSED(device); -} - bool is_naive(ggml_cgraph * cgraph) { constexpr int naive_graph_size_threshold = 20; return cgraph->n_nodes < naive_graph_size_threshold; diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 85bb3a2f882..81fb2c2035d 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -73,8 +73,6 @@ graph_key compute_graph_key(struct ggml_cgraph * cgraph); ov::AnyMap get_ov_compile_config(const std::string & device); -std::map get_types_to_requant(const std::string & device); - ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string & param_name); ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml_decoder, const std::string & param_name); From e6e99445fd7375a8c94003528d51153e8215ad8a Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 19 Dec 2025 16:58:07 +0800 Subject: [PATCH 03/14] Add ov_backend_host_buffer; Use cached remote context --- ggml/include/ggml-openvino.h | 8 ++ ggml/src/ggml-openvino/ggml-decoder.cpp | 29 ++++- .../src/ggml-openvino/ggml-openvino-extra.cpp | 85 ++++++++++++++ ggml/src/ggml-openvino/ggml-openvino-extra.h | 11 ++ ggml/src/ggml-openvino/ggml-openvino.cpp | 111 ++++++++++++++++-- ggml/src/ggml-openvino/utils.cpp | 92 ++++++++------- ggml/src/ggml-openvino/utils.h | 2 - 7 files changed, 281 insertions(+), 57 deletions(-) diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h index b690a16378e..392e26c48ef 100644 --- a/ggml/include/ggml-openvino.h +++ b/ggml/include/ggml-openvino.h @@ -18,9 +18,17 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device); GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend); +GGML_BACKEND_API bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer); + +GGML_BACKEND_API bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft); + +GGML_BACKEND_API bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft); + // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device); + GGML_BACKEND_API int ggml_backend_openvino_get_device_count(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void); diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 2d6437f0691..13ef00dcb64 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -593,11 +593,19 @@ void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filena << std::setw(20) << "op" << std::setw(20) << "name" << std::setw(3) << " " - << std::setw(50) << "stride" + << std::setw(62) << "stride" + << std::setw(20) << "buffer_type" << "\n"; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + // Get buffer type name + const char * buf_name = "none"; + ggml_backend_buffer_t buf = node->view_src ? 
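// Illustrative usage of the public helpers declared above in ggml-openvino.h (not part of
// the diff); device index 0 is an assumption:
//
//   ggml_backend_buffer_type_t buft = ggml_backend_openvino_host_buffer_type(0);
//   ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, 1u << 20);
//   bool is_ov      = ggml_backend_buffer_is_openvino(buf);       // expected: true
//   bool is_ov_host = ggml_backend_buft_is_openvino_host(buft);   // expected: true
//   ggml_backend_buffer_free(buf);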
node->view_src->buffer : node->buffer; + if (buf) { + buf_name = ggml_backend_buffer_name(buf); + } + file << " - " << std::setw(3) << i << ": [ " << std::setw(5) << node->ne[0] << ", " << std::setw(5) << node->ne[1] << ", " @@ -610,10 +618,18 @@ void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filena << std::setw(5) << node->nb[1] << ", " << std::setw(5) << node->nb[2] << ", " << std::setw(5) << node->nb[3] << "] " + << std::right << std::setw(15) << buf_name << std::right << "\n"; for (int i = 0; i < GGML_MAX_SRC; i++) { if (auto* src = node->src[i]) { + // Get buffer type name for source + const char * src_buf_name = "none"; + ggml_backend_buffer_t src_buf = src->view_src ? src->view_src->buffer : src->buffer; + if (src_buf) { + src_buf_name = ggml_backend_buffer_name(src_buf); + } + file << std::setw(10) << " [ " << std::setw(5) << src->ne[0] << ", " << std::setw(5) << src->ne[1] << ", " @@ -627,6 +643,7 @@ void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filena << std::setw(5) << src->nb[1] << ", " << std::setw(5) << src->nb[2] << ", " << std::setw(5) << src->nb[3] << "] " + << std::right << std::setw(15) << src_buf_name << std::right << "\n"; } } @@ -636,11 +653,19 @@ void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filena for (int i = 0; i < cgraph->n_leafs; i++) { ggml_tensor * node = cgraph->leafs[i]; + // Get buffer type name for leaf + const char * leaf_buf_name = "none"; + ggml_backend_buffer_t leaf_buf = node->view_src ? node->view_src->buffer : node->buffer; + if (leaf_buf) { + leaf_buf_name = ggml_backend_buffer_name(leaf_buf); + } + file << " - " << std::setw(3) << i << ": [ " << std::setw(5) << node->ne[0] << ", " << std::setw(5) << node->ne[1] << "] " << std::setw(8) << ggml_op_name(node->op) << " " - << std::setw(16) << ggml_get_name(node) << "\n"; + << std::setw(16) << ggml_get_name(node) + << std::setw(20) << leaf_buf_name << "\n"; } // clang-format on file << "========================================\n"; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 75b27c8fa81..085ae1ece4c 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -2,6 +2,9 @@ #include "ggml-impl.h" +#include +#include + ov::Core & ov_singleton_core() { static ov::Core core; return core; @@ -22,6 +25,31 @@ void ggml_openvino_device_config::init() { device_name = "CPU"; } is_npu = (device_name == "NPU"); + + auto * cache_dir = getenv("GGML_OPENVINO_CACHE_DIR"); + if (device_name == "NPU") { + compile_config = { + {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" }, + {"NPU_USE_NPUW", "YES" }, + {"NPUW_DEVICES", "NPU" }, + {"NPUW_FOLD", "YES" }, + {"NPUW_WEIGHTS_BANK", "shared"}, + {"NPUW_FUNCALL_FOR_ALL", "YES" }, + {"NPUW_FUNCALL_ASYNC", "YES" }, + {"NPUW_DQ", "YES" }, + {"NPUW_DQ_FULL", "NO" }, + }; + if (cache_dir) { + compile_config["NPUW_CACHE_DIR"] = cache_dir; + } + } else if (cache_dir) { + ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } + + if (device_name != "CPU") { + remote_context = ov_singleton_core().get_default_context(device_name); + } + initialized = true; } @@ -46,6 +74,16 @@ bool ggml_openvino_is_npu() { return ggml_openvino_get_device_config().is_npu; } +// Get the remote context for the current device (returns empty optional for CPU) +std::optional ggml_openvino_get_remote_context() { + return ggml_openvino_get_device_config().remote_context; +} + +// Get the compile 
config for the current device +const ov::AnyMap & ggml_openvino_get_compile_config() { + return ggml_openvino_get_device_config().compile_config; +} + // Get requantization type for a tensor type (returns nullopt if no requant needed) std::optional ggml_openvino_get_requant_type(ggml_type type) { if (!ggml_openvino_is_npu()) { @@ -175,3 +213,50 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten return layout; } + +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor) { + ov::Shape shape; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + shape.push_back(static_cast(tensor->ne[i])); + } + + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + case GGML_TYPE_I32: + element_type = ov::element::i32; + break; + case GGML_TYPE_I64: + element_type = ov::element::i64; + break; + default: + GGML_LOG_ERROR("%s: unsupported tensor type for ov::Tensor: %s\n", __func__, ggml_type_name(tensor->type)); + return nullptr; + } + + const auto & device_name = ggml_openvino_get_device_name(); + auto remote_context = ggml_openvino_get_remote_context(); + + std::shared_ptr ov_tensor; + if (device_name == "CPU") { + ov_tensor = std::make_shared(element_type, shape, tensor->data); + } else if (device_name == "GPU") { + auto gpu_context = remote_context->as(); + auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared(std::move(usm_tensor)); + } else { + auto npu_context = remote_context->as(); + auto l0_tensor = npu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared(std::move(l0_tensor)); + } + + return new ggml_openvino_tensor_extra(ov_tensor); +} diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index 7e0138388ff..fdd8312dfff 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -15,6 +16,12 @@ enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; ov::Core & ov_singleton_core(); +// Get the remote context for the current device (returns empty optional for CPU) +std::optional ggml_openvino_get_remote_context(); + +// Get the compile config for the current device +const ov::AnyMap & ggml_openvino_get_compile_config(); + // ===================================================== // Global Device Configuration (singleton) // ===================================================== @@ -24,6 +31,8 @@ struct ggml_openvino_device_config { std::string device_name = "CPU"; bool is_npu = false; bool initialized = false; + std::optional remote_context; + ov::AnyMap compile_config; void init(); }; @@ -112,3 +121,5 @@ struct ggml_openvino_extracted_layout { // Calculate the buffer layout for extracted quantized data ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor); + +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor); diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index e20ae71e408..c5c25fb6c1b 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -56,7 +56,7 @@ struct 
ggml_backend_openvino_buffer_context { std::shared_ptr ov_tensor; // Track all extras for cleanup - std::vector tensor_extras; + std::map tensor_extras; ggml_backend_openvino_buffer_context(int device, size_t size) : device(device), @@ -103,8 +103,8 @@ struct ggml_backend_openvino_buffer_context { ~ggml_backend_openvino_buffer_context() { // Clean up all tensor extras - for (auto * extra : tensor_extras) { - delete extra; + for (auto & pair : tensor_extras) { + delete pair.second; } tensor_extras.clear(); if (data && ggml_openvino_get_device_name() == "CPU") { @@ -144,9 +144,20 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu return GGML_STATUS_SUCCESS; } - // For non-view tensors, tensor->extra will be set in set_tensor - // when the actual weight data is loaded - GGML_UNUSED(buffer); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (tensor->data != nullptr) { + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor); + if (extra != nullptr) { + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + } + } + return GGML_STATUS_SUCCESS; } @@ -194,7 +205,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer layout.requant_type.value() == ExtraQuantType::F16) { // F16 requant case - use weight_extra auto * extra = new ggml_openvino_weight_extra(constant); - ctx->tensor_extras.push_back(extra); + ctx->tensor_extras[tensor] = extra; tensor->extra = extra; GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); } else { @@ -211,7 +222,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), std::move(biases), constant); - ctx->tensor_extras.push_back(extra); + ctx->tensor_extras[tensor] = extra; tensor->extra = extra; if (layout.is_requant) { @@ -239,7 +250,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer // Store in tensor->extra ggml_openvino_weight_extra * extra = new ggml_openvino_weight_extra(constant); - ctx->tensor_extras.push_back(extra); + ctx->tensor_extras[tensor] = extra; tensor->extra = extra; GGML_LOG_DEBUG("%s: created shared-memory constant for %s\n", __func__, tensor->name); @@ -251,6 +262,19 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer } else { // Non-weight tensor (KV cache, activations, etc.) 
- just copy data memcpy((char *) tensor->data + offset, data, size); + + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor); + if (extra == nullptr) { + GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name); + return; + } + + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; } } @@ -393,11 +417,67 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(in return &buffer_types[device]; } -// Check if a buffer is an OpenVINO buffer +// ===================================================== +// OpenVINO Host Buffer Implementation +// ===================================================== + +static const char * ggml_backend_openvino_host_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + static std::string name; + name = ctx->name + "_HOST"; + return name.c_str(); +} + +static const ggml_backend_buffer_type_i ggml_backend_openvino_host_buffer_type_interface = { + /* .get_name = */ ggml_backend_openvino_host_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_openvino_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_openvino_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_openvino_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_openvino_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_openvino_buffer_type_is_host, +}; + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device) { + GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count()); + + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector buffer_types; + static std::vector buffer_type_contexts; + + if (buffer_types.empty()) { + int device_count = ggml_backend_openvino_get_device_count(); + buffer_types.resize(device_count); + buffer_type_contexts.resize(device_count); + + for (int i = 0; i < device_count; i++) { + buffer_type_contexts[i].device = i; + buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i); + + buffer_types[i] = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_openvino_host_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i), + /* .context = */ &buffer_type_contexts[i], + }; + } + } + + return &buffer_types[device]; +} + bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer; } +bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name; +} + +bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_host_buffer_type_get_name; +} + // ===================================================== // OpenVINO Backend Context and Interface // ===================================================== @@ -552,6 +632,11 @@ static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_buffer_type(g return ggml_backend_openvino_buffer_type(ctx->device); } +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context 
*) dev->context; + return ggml_backend_openvino_host_buffer_type(ctx->device); +} + static bool is_op_unsupported_case(const ggml_tensor * op) { switch (op->op) { case GGML_OP_SOFT_MAX: { @@ -731,7 +816,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { // Support our own buffer type and any host buffer (for mmap'd files, etc.) - return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name || ggml_backend_buft_is_host(buft); + return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_host(buft); + // return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_openvino_host(buft); GGML_UNUSED(dev); } @@ -743,7 +829,8 @@ static const struct ggml_backend_device_i ggml_backend_openvino_device_interface /* .get_props = */ ggml_backend_openvino_device_get_props, /* .init_backend = */ ggml_backend_openvino_device_init, /* .get_buffer_type = */ ggml_backend_openvino_device_get_buffer_type, - /* .get_host_buffer_type = */ NULL, + // /* .get_host_buffer_type = */ NULL, + /* .get_host_buffer_type = */ ggml_backend_openvino_device_get_host_buffer_type, /* .buffer_from_host_ptr = */ NULL, /* .supports_op = */ ggml_backend_openvino_device_supports_op, /* .supports_buft = */ ggml_backend_openvino_device_supports_buft, diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 6d56af93189..89cf51f8801 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -37,11 +37,9 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdeprecated-declarations" -static ov::Core core; - enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) { if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { - std::string filename = "cgraph.txt"; + std::string filename = "cgraph_ov.txt"; GgmlOvDecoder::dump_cgraph(cgraph, filename); } @@ -52,8 +50,9 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) { } enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device) { + auto & core = ov_singleton_core(); + const auto & config = ggml_openvino_get_compile_config(); static auto is_static = false; - static auto config = get_ov_compile_config(device); // if (is_naive(cgraph)) { // return naive_compute(cgraph, core, device, config); @@ -124,7 +123,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin ov::serialize(model, timestamped_filename); } - auto compiled_model = core.compile_model(model, device, config); + ov::CompiledModel compiled_model; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + compiled_model = core.compile_model(model, remote_context.value(), config); + } else { + compiled_model = core.compile_model(model, device, config); + } compile_end_time = ggml_time_us(); infer_request = std::make_shared(compiled_model.create_infer_request()); infer_request_cache[key] = infer_request; @@ -173,18 +178,20 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin if (getenv("GGML_OPENVINO_PROFILING")) { GGML_LOG_INFO("\nGGML OpenVINO Backend: \n"); - GGML_LOG_INFO(" - Graph decoder Time: %ld ms \n", (decoder_end_time - start_time) / 1000); + GGML_LOG_INFO(" - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000); if (!cache_hit) { - GGML_LOG_INFO(" - Graph conversion Time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); - 
GGML_LOG_INFO(" - Graph compile Time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + GGML_LOG_INFO(" - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); } - GGML_LOG_INFO(" - Graph Inference Time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); + GGML_LOG_INFO(" - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); } return GGML_STATUS_SUCCESS; } enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { + auto & core = ov_singleton_core(); + auto get_prefill_chunk_size = [] { const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE"); if (chunk_size_str && atoi(chunk_size_str) > 0) { @@ -196,7 +203,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { static std::string device = "NPU"; static auto is_static = true; static auto prefill_chunk_size = get_prefill_chunk_size(); - static auto config = get_ov_compile_config(device); + const auto & config = ggml_openvino_get_compile_config(); if (is_naive(cgraph)) { return naive_compute(cgraph, core, device, config); @@ -281,8 +288,16 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { ov::serialize(model_decode, timestamped_filename); } - auto compiled_model_prefill = core.compile_model(model_prefill, device, get_ov_compile_config(device)); - auto compiled_model_decode = core.compile_model(model_decode, device, get_ov_compile_config(device)); + ov::CompiledModel compiled_model_prefill; + ov::CompiledModel compiled_model_decode; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + compiled_model_prefill = core.compile_model(model_prefill, remote_context.value(), config); + compiled_model_decode = core.compile_model(model_decode, remote_context.value(), config); + } else { + compiled_model_prefill = core.compile_model(model_prefill, device, config); + compiled_model_decode = core.compile_model(model_decode, device, config); + } infer_request_cache_prefill[key] = std::make_shared(compiled_model_prefill.create_infer_request()); @@ -369,41 +384,17 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { if (getenv("GGML_OPENVINO_PROFILING")) { GGML_LOG_INFO("\nGGML OpenVINO Backend: \n"); - GGML_LOG_INFO(" - Graph decoder Time: %ld ms \n", (decoder_end_time - start_time) / 1000); + GGML_LOG_INFO(" - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000); if (!cache_hit) { - GGML_LOG_INFO(" - Graph conversion Time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); - GGML_LOG_INFO(" - Graph compile Time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + GGML_LOG_INFO(" - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); } - GGML_LOG_INFO(" - Graph Inference Time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); + GGML_LOG_INFO(" - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); } return GGML_STATUS_SUCCESS; } -ov::AnyMap get_ov_compile_config(const std::string & device) { - ov::AnyMap config; - auto * cache_dir = getenv("GGML_OPENVINO_CACHE_DIR"); - if (device == "NPU") { - config = { - {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" }, - {"NPU_USE_NPUW", "YES" }, - {"NPUW_DEVICES", "NPU" }, - {"NPUW_FOLD", "YES" }, - {"NPUW_WEIGHTS_BANK", 
"shared"}, - {"NPUW_FUNCALL_FOR_ALL", "YES" }, - {"NPUW_FUNCALL_ASYNC", "YES" }, - {"NPUW_DQ", "YES" }, - {"NPUW_DQ_FULL", "NO" }, - }; - if (cache_dir) { - config["NPUW_CACHE_DIR"] = cache_dir; - } - } else if (cache_dir) { - core.set_property(ov::cache_dir(cache_dir)); - } - return config; -} - bool is_naive(ggml_cgraph * cgraph) { constexpr int naive_graph_size_threshold = 20; return cgraph->n_nodes < naive_graph_size_threshold; @@ -428,7 +419,14 @@ enum ggml_status naive_compute(ggml_cgraph * cgraph, if (getenv("GGML_OPENVINO_DUMP_IR")) { ov::serialize(model, "IR_naive.xml"); } - auto infer_request = core.compile_model(model, device, config).create_infer_request(); + + ov::InferRequest infer_request; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + infer_request = core.compile_model(model, remote_context.value(), config).create_infer_request(); + } else { + infer_request = core.compile_model(model, device, config).create_infer_request(); + } auto ov_params = model->get_parameters(); for (size_t i = 0; i < ov_params.size(); i++) { @@ -451,6 +449,18 @@ enum ggml_status naive_compute(ggml_cgraph * cgraph, namespace { ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, const std::string & name) { const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name); + + if (ggml_tensor->extra != nullptr) { + // GGML_LOG_DEBUG("Using ggml_tensor->extra as ov::Tensor for input: %s\n", name.c_str()); + auto * extra_base = static_cast(ggml_tensor->extra); + if (extra_base->type != ggml_openvino_extra_base::Type::TENSOR) { + throw std::runtime_error("ggml tensor extra is not of type TENSOR for input: " + name); + } + auto * tensor_extra = static_cast(extra_base); + return *tensor_extra->tensor; + } + + // GGML_LOG_DEBUG("Converting ggml tensor to ov::Tensor for input: %s\n", name.c_str()); auto * input_data = ggml_tensor->data; ov::Shape input_shape; if (ggml_tensor->op == GGML_OP_VIEW) { diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 81fb2c2035d..44ca2db00fa 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -71,8 +71,6 @@ bool get_is_prefill(const ggml_tensor * inp_pos); graph_key compute_graph_key(struct ggml_cgraph * cgraph); -ov::AnyMap get_ov_compile_config(const std::string & device); - ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string & param_name); ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml_decoder, const std::string & param_name); From 7146806960d3b4d2210ef28401195a7df8cad1fa Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 22 Dec 2025 16:45:17 +0800 Subject: [PATCH 04/14] Put kvcache on GPU --- .../src/ggml-openvino/ggml-openvino-extra.cpp | 78 +++++++- ggml/src/ggml-openvino/ggml-openvino-extra.h | 33 ++++ ggml/src/ggml-openvino/ggml-openvino.cpp | 169 ++++++++++++++++-- 3 files changed, 262 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 085ae1ece4c..aa50d46c03d 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -46,13 +46,56 @@ void ggml_openvino_device_config::init() { ov_singleton_core().set_property(ov::cache_dir(cache_dir)); } - if (device_name != "CPU") { + // Initialize remote context with queue sharing for GPU + if (device_name == "GPU") { + // Create OpenCL context and queue + cl_int err; + cl_platform_id platform; + err = 
clGetPlatformIDs(1, &platform, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to get OpenCL platform: %d\n", err); + return; + } + + cl_device_id cl_device; + err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &cl_device, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to get OpenCL device: %d\n", err); + return; + } + + cl_context cl_ctx = clCreateContext(nullptr, 1, &cl_device, nullptr, nullptr, &err); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to create OpenCL context: %d\n", err); + return; + } + + cl_queue = clCreateCommandQueueWithProperties(cl_ctx, cl_device, nullptr, &err); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to create OpenCL command queue: %d\n", err); + clReleaseContext(cl_ctx); + return; + } + + // Create OpenVINO remote context with queue sharing + remote_context = ov::intel_gpu::ocl::ClContext(ov_singleton_core(), cl_queue); + + // Release the context (queue keeps a reference) + clReleaseContext(cl_ctx); + } else if (device_name == "NPU") { remote_context = ov_singleton_core().get_default_context(device_name); } initialized = true; } +ggml_openvino_device_config::~ggml_openvino_device_config() { + if (cl_queue != nullptr) { + clReleaseCommandQueue(cl_queue); + cl_queue = nullptr; + } +} + // Get the global device config singleton ggml_openvino_device_config & ggml_openvino_get_device_config() { static ggml_openvino_device_config config; @@ -84,6 +127,39 @@ const ov::AnyMap & ggml_openvino_get_compile_config() { return ggml_openvino_get_device_config().compile_config; } +// Get the OpenCL command queue for GPU operations +cl_command_queue ggml_openvino_get_cl_queue() { + return ggml_openvino_get_device_config().cl_queue; +} + +// Get the clEnqueueMemFillINTEL function pointer (lazy load) +clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL() { + static clEnqueueMemFillINTEL_fn fn = nullptr; + static bool loaded = false; + if (!loaded) { + loaded = true; + cl_platform_id platform; + if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) { + fn = (clEnqueueMemFillINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemFillINTEL"); + } + } + return fn; +} + +// Get the clEnqueueMemcpyINTEL function pointer (lazy load) +clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() { + static clEnqueueMemcpyINTEL_fn fn = nullptr; + static bool loaded = false; + if (!loaded) { + loaded = true; + cl_platform_id platform; + if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) { + fn = (clEnqueueMemcpyINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemcpyINTEL"); + } + } + return fn; +} + // Get requantization type for a tensor type (returns nullopt if no requant needed) std::optional ggml_openvino_get_requant_type(ggml_type type) { if (!ggml_openvino_is_npu()) { diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index fdd8312dfff..a1a85141906 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -3,6 +3,9 @@ #include "ggml.h" #include "openvino/runtime/core.hpp" +#define CL_TARGET_OPENCL_VERSION 300 +#include + #include #include #include @@ -22,6 +25,34 @@ std::optional ggml_openvino_get_remote_context(); // Get the compile config for the current device const ov::AnyMap & ggml_openvino_get_compile_config(); +// Get the OpenCL command queue for GPU operations (returns nullptr for CPU/NPU) +cl_command_queue ggml_openvino_get_cl_queue(); + +// Intel USM 
extension function type +typedef cl_int(CL_API_CALL * clEnqueueMemFillINTEL_fn)(cl_command_queue queue, + void * dst_ptr, + const void * pattern, + size_t pattern_size, + size_t size, + cl_uint num_events_in_wait_list, + const cl_event * event_wait_list, + cl_event * event); + +typedef cl_int(CL_API_CALL * clEnqueueMemcpyINTEL_fn)(cl_command_queue queue, + cl_bool blocking, + void * dst_ptr, + const void * src_ptr, + size_t size, + cl_uint num_events_in_wait_list, + const cl_event * event_wait_list, + cl_event * event); + +// Get the clEnqueueMemFillINTEL function pointer (returns nullptr if not available) +clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL(); + +// Get the clEnqueueMemcpyINTEL function pointer (returns nullptr if not available) +clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL(); + // ===================================================== // Global Device Configuration (singleton) // ===================================================== @@ -33,8 +64,10 @@ struct ggml_openvino_device_config { bool initialized = false; std::optional remote_context; ov::AnyMap compile_config; + cl_command_queue cl_queue = nullptr; void init(); + ~ggml_openvino_device_config(); }; // Get the global device config singleton diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index c5c25fb6c1b..e139c2d662d 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -8,6 +8,8 @@ #include "ggml-quants.hpp" #include "ggml.h" +#include + #include #include #include @@ -52,17 +54,23 @@ struct ggml_backend_openvino_buffer_context { // For non-weight buffers (KV cache, compute), we still use contiguous allocation void * data; size_t size; + bool is_remote; - std::shared_ptr ov_tensor; + // Wrapping of the buffer + std::shared_ptr ov_buffer; // Track all extras for cleanup std::map tensor_extras; - ggml_backend_openvino_buffer_context(int device, size_t size) : + // Used for re-allocation on device for kvcache + void * data_prev; + + ggml_backend_openvino_buffer_context(int device, size_t size, bool is_remote = false) : device(device), name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)), data(nullptr), - size(size) { + size(size), + is_remote(is_remote) { if (size == 0) { return; } @@ -76,17 +84,22 @@ struct ggml_backend_openvino_buffer_context { #else data = aligned_alloc(GGML_OPENVINO_BUFFER_ALIGNMENT, size); #endif - ov_tensor = std::make_shared(ov::element::u8, ov::Shape{size}, data); + ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); } else if (device_name == "GPU") { auto gpu_context = core.get_default_context("GPU").as(); - auto usm_tensor = gpu_context.create_usm_host_tensor(ov::element::u8, ov::Shape{size}); + ov::intel_gpu::ocl::USMTensor usm_tensor; + if (is_remote) { + usm_tensor = gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); + } else { + usm_tensor = gpu_context.create_usm_host_tensor(ov::element::u8, ov::Shape{size}); + } data = usm_tensor.get(); - ov_tensor = std::make_shared(std::move(usm_tensor)); + ov_buffer = std::make_shared(std::move(usm_tensor)); } else { auto npu_context = core.get_default_context("NPU").as(); auto l0_tensor = npu_context.create_l0_host_tensor(ov::element::u8, ov::Shape{size}); data = l0_tensor.get(); - ov_tensor = std::make_shared(std::move(l0_tensor)); + ov_buffer = std::make_shared(std::move(l0_tensor)); } if (data == nullptr) { @@ -135,6 +148,22 @@ static void * 
ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer } static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + // Put kvcache on device memory for GPU + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY && strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && + ggml_openvino_get_device_name() == "GPU") { + GGML_ASSERT(ctx->tensor_extras.empty()); + auto device = ctx->device; + auto size = ctx->size; + auto * data_prev = ctx->data; + delete ctx; + ctx = new ggml_backend_openvino_buffer_context(device, size, true); + buffer->context = ctx; + tensor->data = (char *) ctx->data + ((char *) tensor->data - (char *) data_prev); + } + // Views share the extra from view_src if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); @@ -144,7 +173,7 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu return GGML_STATUS_SUCCESS; } - ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + ctx = (ggml_backend_openvino_buffer_context *) buffer->context; if (tensor->data != nullptr) { ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor); @@ -166,9 +195,28 @@ static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buf uint8_t value, size_t offset, size_t size) { + GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); - memset((char *) tensor->data + offset, value, size); - GGML_UNUSED(buffer); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memfill + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); + if (queue != nullptr && mem_fill_fn != nullptr) { + uint8_t pattern = value; + cl_int err = mem_fill_fn(queue, (char *) tensor->data + offset, &pattern, sizeof(pattern), size, 0, nullptr, + nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err); + } + clFinish(queue); + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer\n", __func__); + } + } else { + memset((char *) tensor->data + offset, value, size); + } } static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -176,6 +224,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer const void * data, size_t offset, size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; @@ -260,8 +309,23 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer e.what()); } } else { - // Non-weight tensor (KV cache, activations, etc.) - just copy data - memcpy((char *) tensor->data + offset, data, size); + // Non-weight tensor (KV cache, activations, etc.) 
- copy data + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy (host-to-device) + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue != nullptr && mem_cpy_fn != nullptr) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, (char *) tensor->data + offset, data, size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err); + } + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + } + } else { + memcpy((char *) tensor->data + offset, data, size); + } ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor); if (extra == nullptr) { @@ -283,28 +347,99 @@ static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer void * data, size_t offset, size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); - memcpy(data, (const char *) tensor->data + offset, size); - GGML_UNUSED(buffer); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy (device-to-host) + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue != nullptr && mem_cpy_fn != nullptr) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, data, (const char *) tensor->data + offset, size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err); + } + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + } + } else { + memcpy(data, (const char *) tensor->data + offset, size); + } } static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + // GGML_LOG_DEBUG("%s: src tensor name=%s, dst tensor name=%s\n", __func__, src->name, dst->name); GGML_ASSERT(src != nullptr && dst != nullptr); - // Can copy from any host buffer (including other OpenVINO buffers) + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue == nullptr || mem_cpy_fn == nullptr) { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + return false; + } + // Can copy from host to device + if (ggml_backend_buffer_is_host(src->buffer)) { + cl_int err = mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (host-to-device) failed with error %d\n", __func__, err); + return false; + } + return true; + } + // Can also copy from device to device if both are OpenVINO remote buffers + if (ggml_backend_buffer_is_openvino(src->buffer)) { + ggml_backend_openvino_buffer_context * src_ctx = + (ggml_backend_openvino_buffer_context *) src->buffer->context; + if (src_ctx->is_remote) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr); 
+ if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (device-to-device) failed with error %d\n", __func__, + err); + return false; + } + return true; + } + } + return false; + } + + // Host buffer - can copy from any host buffer if (ggml_backend_buffer_is_host(src->buffer)) { memcpy(dst->data, src->data, ggml_nbytes(src)); return true; } return false; - GGML_UNUSED(buffer); } static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->data != nullptr) { + GGML_ASSERT(ctx->data != nullptr); + if (!ctx->is_remote) { memset(ctx->data, value, ctx->size); + } else { + // For remote (device) buffers, use OpenCL command queue + GGML_ASSERT(ggml_openvino_get_device_name() == "GPU"); + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); + if (queue != nullptr && mem_fill_fn != nullptr) { + uint8_t pattern = value; + cl_int err = mem_fill_fn(queue, ctx->data, &pattern, sizeof(pattern), ctx->size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_WARN("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err); + } + clFinish(queue); + } else { + GGML_LOG_WARN("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer clear\n", + __func__); + } } } From 1f99f5f0f4b863995daec91bb3fd40638492aaab Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Wed, 24 Dec 2025 10:51:13 +0800 Subject: [PATCH 05/14] Use ggml_aligned_malloc --- ggml/include/ggml-openvino.h | 6 ++++-- ggml/src/ggml-openvino/ggml-openvino.cpp | 23 +++++------------------ 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h index 392e26c48ef..46c1485f663 100644 --- a/ggml/include/ggml-openvino.h +++ b/ggml/include/ggml-openvino.h @@ -51,8 +51,10 @@ struct ggml_openvino_device_info { std::array default_tensor_split = {}; }; -const ggml_openvino_device_info & ggml_openvino_info(); - #ifdef __cplusplus } #endif + +#ifdef __cplusplus +const ggml_openvino_device_info & ggml_openvino_info(); +#endif diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index e139c2d662d..acaa3ddc001 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -24,11 +24,6 @@ #include #include -#define GGML_OPENVINO_MAX_STREAMS 8 - -// OpenVINO buffer alignment (same as CPU for compatibility) -#define GGML_OPENVINO_BUFFER_ALIGNMENT 64 - // ===================================================== // OpenVINO Buffer Implementation using ov::Tensor // ===================================================== @@ -79,11 +74,7 @@ struct ggml_backend_openvino_buffer_context { auto & core = ov_singleton_core(); if (device_name == "CPU") { -#ifdef _WIN32 - data = _aligned_malloc(alloc_size, GGML_OPENVINO_BUFFER_ALIGNMENT); -#else - data = aligned_alloc(GGML_OPENVINO_BUFFER_ALIGNMENT, size); -#endif + data = ggml_aligned_malloc(size); ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); } else if (device_name == "GPU") { auto gpu_context = core.get_default_context("GPU").as(); @@ -107,9 +98,9 @@ struct ggml_backend_openvino_buffer_context { return; } - if (reinterpret_cast(data) % GGML_OPENVINO_BUFFER_ALIGNMENT != 0) { + if (reinterpret_cast(data) % TENSOR_ALIGNMENT != 0) { GGML_LOG_ERROR("%s: %s buffer is not aligned to %d bytes\n", __func__, 
device_name.c_str(), - GGML_OPENVINO_BUFFER_ALIGNMENT); + TENSOR_ALIGNMENT); GGML_ABORT("fatal error"); } } @@ -121,11 +112,7 @@ struct ggml_backend_openvino_buffer_context { } tensor_extras.clear(); if (data && ggml_openvino_get_device_name() == "CPU") { -#ifdef _WIN32 - _aligned_free(data); -#else - free(data); -#endif + ggml_aligned_free(data, size); } } }; @@ -479,7 +466,7 @@ static ggml_backend_buffer_t ggml_backend_openvino_buffer_type_alloc_buffer(ggml static size_t ggml_backend_openvino_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { GGML_UNUSED(buft); - return GGML_OPENVINO_BUFFER_ALIGNMENT; + return TENSOR_ALIGNMENT; } static size_t ggml_backend_openvino_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { From b97bc3f8b589e4b2efb83832f4b75152d91b3499 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 25 Dec 2025 16:07:44 +0800 Subject: [PATCH 06/14] only use remote tensor for kvcache --- .../src/ggml-openvino/ggml-openvino-extra.cpp | 22 +++---- ggml/src/ggml-openvino/ggml-openvino-extra.h | 2 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 58 +++++++++---------- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index aa50d46c03d..908a9752477 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -290,7 +290,7 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten return layout; } -ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor) { +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote) { ov::Shape shape; for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { shape.push_back(static_cast(tensor->ne[i])); @@ -322,16 +322,18 @@ ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor auto remote_context = ggml_openvino_get_remote_context(); std::shared_ptr ov_tensor; - if (device_name == "CPU") { - ov_tensor = std::make_shared(element_type, shape, tensor->data); - } else if (device_name == "GPU") { - auto gpu_context = remote_context->as(); - auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); - ov_tensor = std::make_shared(std::move(usm_tensor)); + if (is_remote) { + if (device_name == "GPU") { + auto gpu_context = remote_context->as(); + auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared(std::move(usm_tensor)); + } else { + auto npu_context = remote_context->as(); + auto l0_tensor = npu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared(std::move(l0_tensor)); + } } else { - auto npu_context = remote_context->as(); - auto l0_tensor = npu_context.create_tensor(element_type, shape, tensor->data); - ov_tensor = std::make_shared(std::move(l0_tensor)); + ov_tensor = std::make_shared(element_type, shape, tensor->data); } return new ggml_openvino_tensor_extra(ov_tensor); diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index a1a85141906..2f9d257769d 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -155,4 +155,4 @@ struct ggml_openvino_extracted_layout { // Calculate the buffer layout for extracted quantized data ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor); 
-ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor); +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote); diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index acaa3ddc001..c0d555e86fa 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -8,8 +8,6 @@ #include "ggml-quants.hpp" #include "ggml.h" -#include - #include #include #include @@ -73,24 +71,22 @@ struct ggml_backend_openvino_buffer_context { const auto & device_name = ggml_openvino_get_device_name(); auto & core = ov_singleton_core(); - if (device_name == "CPU") { - data = ggml_aligned_malloc(size); - ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); - } else if (device_name == "GPU") { - auto gpu_context = core.get_default_context("GPU").as(); - ov::intel_gpu::ocl::USMTensor usm_tensor; - if (is_remote) { - usm_tensor = gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); + if (is_remote) { + if (device_name == "GPU") { + auto gpu_context = core.get_default_context("GPU").as(); + ov::intel_gpu::ocl::USMTensor usm_tensor = + gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); + data = usm_tensor.get(); + ov_buffer = std::make_shared(std::move(usm_tensor)); } else { - usm_tensor = gpu_context.create_usm_host_tensor(ov::element::u8, ov::Shape{size}); + auto npu_context = core.get_default_context("NPU").as(); + auto l0_tensor = npu_context.create_l0_host_tensor(ov::element::u8, ov::Shape{size}); + data = l0_tensor.get(); + ov_buffer = std::make_shared(std::move(l0_tensor)); } - data = usm_tensor.get(); - ov_buffer = std::make_shared(std::move(usm_tensor)); } else { - auto npu_context = core.get_default_context("NPU").as(); - auto l0_tensor = npu_context.create_l0_host_tensor(ov::element::u8, ov::Shape{size}); - data = l0_tensor.get(); - ov_buffer = std::make_shared(std::move(l0_tensor)); + data = ggml_aligned_malloc(size); + ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); } if (data == nullptr) { @@ -111,7 +107,7 @@ struct ggml_backend_openvino_buffer_context { delete pair.second; } tensor_extras.clear(); - if (data && ggml_openvino_get_device_name() == "CPU") { + if (!is_remote && data != nullptr) { ggml_aligned_free(data, size); } } @@ -135,12 +131,12 @@ static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer } static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - // Put kvcache on device memory for GPU + // Put kvcache on device memory if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY && strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && - ggml_openvino_get_device_name() == "GPU") { + ggml_openvino_get_device_name() != "CPU") { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -163,7 +159,7 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu ctx = (ggml_backend_openvino_buffer_context *) buffer->context; if (tensor->data != nullptr) { - ggml_openvino_tensor_extra * extra = 
ggml_openvino_create_tensor_extra(tensor); + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); if (extra != nullptr) { auto it = ctx->tensor_extras.find(tensor); if (it != ctx->tensor_extras.end()) { @@ -186,7 +182,7 @@ static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buf GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->is_remote) { + if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { // For remote (device) buffers, use OpenCL USM memfill cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); @@ -297,8 +293,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer } } else { // Non-weight tensor (KV cache, activations, etc.) - copy data - if (ctx->is_remote) { - // For remote (device) buffers, use OpenCL USM memcpy (host-to-device) + if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); if (queue != nullptr && mem_cpy_fn != nullptr) { @@ -314,7 +309,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer memcpy((char *) tensor->data + offset, data, size); } - ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor); + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); if (extra == nullptr) { GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name); return; @@ -338,7 +333,7 @@ static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->is_remote) { + if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { // For remote (device) buffers, use OpenCL USM memcpy (device-to-host) cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); @@ -363,7 +358,7 @@ static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer GGML_ASSERT(src != nullptr && dst != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->is_remote) { + if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { // For remote (device) buffers, use OpenCL USM memcpy cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); @@ -409,10 +404,7 @@ static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; GGML_ASSERT(ctx->data != nullptr); - if (!ctx->is_remote) { - memset(ctx->data, value, ctx->size); - } else { - // For remote (device) buffers, use OpenCL command queue + if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { GGML_ASSERT(ggml_openvino_get_device_name() == "GPU"); cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); @@ -427,6 +419,8 @@ static void 
ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uin GGML_LOG_WARN("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer clear\n", __func__); } + } else { + memset(ctx->data, value, ctx->size); } } From 6b477166e64e3516bd5acd3a979c1e84093c9a99 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 25 Dec 2025 17:08:51 +0800 Subject: [PATCH 07/14] only use remote tensor for kvcache for GPU --- .../src/ggml-openvino/ggml-openvino-extra.cpp | 19 ++++------- ggml/src/ggml-openvino/ggml-openvino.cpp | 34 ++++++++----------- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 908a9752477..eff1627cb4c 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -1,6 +1,7 @@ #include "ggml-openvino-extra.h" #include "ggml-impl.h" +#include "ggml.h" #include #include @@ -224,9 +225,8 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_per_block = 32; break; default: - // Unsupported requant type - fall through to normal extraction - layout.is_requant = false; - layout.requant_type = std::nullopt; + layout.weights_per_block = -1; + GGML_ABORT("Code of re-quantizing to channel-wise is not updated"); break; } @@ -323,15 +323,10 @@ ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor std::shared_ptr ov_tensor; if (is_remote) { - if (device_name == "GPU") { - auto gpu_context = remote_context->as(); - auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); - ov_tensor = std::make_shared(std::move(usm_tensor)); - } else { - auto npu_context = remote_context->as(); - auto l0_tensor = npu_context.create_tensor(element_type, shape, tensor->data); - ov_tensor = std::make_shared(std::move(l0_tensor)); - } + GGML_ASSERT(device_name == "GPU"); + auto gpu_context = remote_context->as(); + auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared(std::move(usm_tensor)); } else { ov_tensor = std::make_shared(element_type, shape, tensor->data); } diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index c0d555e86fa..9b1fd55adfb 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -72,18 +72,13 @@ struct ggml_backend_openvino_buffer_context { auto & core = ov_singleton_core(); if (is_remote) { - if (device_name == "GPU") { - auto gpu_context = core.get_default_context("GPU").as(); - ov::intel_gpu::ocl::USMTensor usm_tensor = - gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); - data = usm_tensor.get(); - ov_buffer = std::make_shared(std::move(usm_tensor)); - } else { - auto npu_context = core.get_default_context("NPU").as(); - auto l0_tensor = npu_context.create_l0_host_tensor(ov::element::u8, ov::Shape{size}); - data = l0_tensor.get(); - ov_buffer = std::make_shared(std::move(l0_tensor)); - } + // NPU memory is too small even for kvcache + GGML_ASSERT(device_name == "GPU"); + auto gpu_context = core.get_default_context("GPU").as(); + ov::intel_gpu::ocl::USMTensor usm_tensor = + gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); + data = usm_tensor.get(); + ov_buffer = std::make_shared(std::move(usm_tensor)); } else { data = ggml_aligned_malloc(size); ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); @@ -134,9 
+129,9 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - // Put kvcache on device memory + // Put kvcache on device memory for GPU if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY && strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && - ggml_openvino_get_device_name() != "CPU") { + ggml_openvino_get_device_name() == "GPU") { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -182,7 +177,7 @@ static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buf GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { + if (ctx->is_remote) { // For remote (device) buffers, use OpenCL USM memfill cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); @@ -293,7 +288,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer } } else { // Non-weight tensor (KV cache, activations, etc.) - copy data - if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { + if (ctx->is_remote) { cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); if (queue != nullptr && mem_cpy_fn != nullptr) { @@ -333,7 +328,7 @@ static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { + if (ctx->is_remote) { // For remote (device) buffers, use OpenCL USM memcpy (device-to-host) cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); @@ -358,7 +353,7 @@ static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer GGML_ASSERT(src != nullptr && dst != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { + if (ctx->is_remote) { // For remote (device) buffers, use OpenCL USM memcpy cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); @@ -404,8 +399,7 @@ static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; GGML_ASSERT(ctx->data != nullptr); - if (ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { - GGML_ASSERT(ggml_openvino_get_device_name() == "GPU"); + if (ctx->is_remote) { cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); if (queue != nullptr && mem_fill_fn != nullptr) { From beac408768461ce7df57198a9065781e26ee22e4 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 26 Dec 2025 11:38:45 +0800 Subject: [PATCH 08/14] FIX: use remote tensor from singleton --- 
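The fix below replaces the ad-hoc `core.get_default_context("GPU")` call with the remote context cached in the device-config singleton, so the kv-cache device buffer and the compiled model share one OpenCL context and queue. A condensed sketch of the resulting allocate-then-upload pattern is shown here; `alloc_and_upload` and its signature are illustrative placeholders rather than backend API, and the queue plus `clEnqueueMemcpyINTEL` pointer are assumed to come from the helpers added earlier in this series.

```cpp
// Illustrative sketch only. Assumes an OpenCL queue created with queue sharing (as in
// the device-config init) and a clEnqueueMemcpyINTEL pointer obtained via
// clGetExtensionFunctionAddressForPlatform; the typedef mirrors ggml-openvino-extra.h.
#define CL_TARGET_OPENCL_VERSION 300
#include <CL/cl.h>

#include <cstdio>
#include <openvino/runtime/intel_gpu/ocl/ocl.hpp>

typedef cl_int (CL_API_CALL * clEnqueueMemcpyINTEL_fn)(cl_command_queue, cl_bool, void *, const void *,
                                                       size_t, cl_uint, const cl_event *, cl_event *);

// Allocate a USM device buffer from the *shared* remote context (not a fresh
// get_default_context() call) and upload host data into it with a blocking copy.
static ov::intel_gpu::ocl::USMTensor alloc_and_upload(ov::intel_gpu::ocl::ClContext & shared_ctx,
                                                      cl_command_queue queue,
                                                      clEnqueueMemcpyINTEL_fn memcpy_fn,
                                                      const void * host_src,
                                                      size_t size) {
    auto dev_tensor = shared_ctx.create_usm_device_tensor(ov::element::u8, ov::Shape{size});
    if (queue != nullptr && memcpy_fn != nullptr) {
        // CL_TRUE makes the copy blocking, so no explicit clFinish() is needed afterwards.
        cl_int err = memcpy_fn(queue, CL_TRUE, dev_tensor.get(), host_src, size, 0, nullptr, nullptr);
        if (err != CL_SUCCESS) {
            std::fprintf(stderr, "clEnqueueMemcpyINTEL failed: %d\n", err);
        }
    }
    return dev_tensor;
}
```

The likely motivation for the fix is that USM device allocations are only valid within the OpenCL context that created them, so the buffer has to come from the same shared context the inference queue uses.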
ggml/src/ggml-openvino/ggml-openvino.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 9b1fd55adfb..a1b5b5dd321 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -69,12 +69,11 @@ struct ggml_backend_openvino_buffer_context { } const auto & device_name = ggml_openvino_get_device_name(); - auto & core = ov_singleton_core(); if (is_remote) { - // NPU memory is too small even for kvcache GGML_ASSERT(device_name == "GPU"); - auto gpu_context = core.get_default_context("GPU").as(); + auto remote_context = ggml_openvino_get_remote_context(); + auto gpu_context = remote_context->as(); ov::intel_gpu::ocl::USMTensor usm_tensor = gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); data = usm_tensor.get(); @@ -129,7 +128,7 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - // Put kvcache on device memory for GPU + // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY && strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU") { GGML_ASSERT(ctx->tensor_extras.empty()); From c5881540f65274a494b02cb44773dd260cebe300 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 26 Dec 2025 13:52:09 +0800 Subject: [PATCH 09/14] Update build.md to include OpenCL --- docs/build.md | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/docs/build.md b/docs/build.md index f7e793c155a..54d6e9517b7 100644 --- a/docs/build.md +++ b/docs/build.md @@ -610,10 +610,23 @@ Follow the instructions below to install OpenVINO runtime and build llama.cpp wi sudo apt-get update sudo apt-get install -y build-essential libcurl4-openssl-dev libtbb12 cmake ninja-build python3-pip curl wget tar ``` + - OpenCL + ```bash + sudo apt install ocl-icd-opencl-dev opencl-headers opencl-clhpp-headers intel-opencl-icd + ``` - **Windows:** - - Download Microsoft.VisualStudio.2022.BuildTools [Visual_Studio_Build_Tools]https://aka.ms/vs/17/release/vs_BuildTools.exe Select "Desktop development with C++" under workloads. + - Download Microsoft.VisualStudio.2022.BuildTools: [Visual_Studio_Build_Tools](https://aka.ms/vs/17/release/vs_BuildTools.exe) + Select "Desktop development with C++" under workloads - Install git + - Install OpenCL with vcpkg + ```powershell + cd C:\ + git clone https://github.com/microsoft/vcpkg + cd vcpkg + bootstrap-vcpkg.bat + vcpkg install opencl + ``` - Use "x64 Native Tools Command Prompt" for Build ### 1. Install OpenVINO Runtime @@ -625,19 +638,19 @@ Follow the instructions below to install OpenVINO runtime and build llama.cpp wi
📦 Click to expand OpenVINO 2025.3 installation from an archive file on Ubuntu
- + ```bash wget https://raw.githubusercontent.com/ravi9/misc-scripts/main/openvino/ov-archive-install/install-openvino-from-archive.sh chmod +x install-openvino-from-archive.sh ./install-openvino-from-archive.sh ``` + + Verify OpenVINO is initialized properly: + ```bash + echo $OpenVINO_DIR + ```
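Beyond checking the environment variable, the runtime itself can be queried from a few lines of C++. The snippet below is a standalone sanity check (not part of llama.cpp); the compile command and paths depend on how OpenVINO was installed.

```cpp
// check_ov.cpp — illustrative verification snippet, not part of the repository.
// Build after sourcing setupvars.sh (include/library paths depend on your install), e.g.:
//   g++ -std=c++17 check_ov.cpp -lopenvino -o check_ov
#include <iostream>
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // List every device OpenVINO can see along with its full name.
    for (const auto & device : core.get_available_devices()) {
        std::cout << device << ": " << core.get_property(device, ov::device::full_name) << "\n";
    }
    return 0;
}
```

If GPU or NPU does not show up in the output, that usually points at a missing driver rather than a problem with the llama.cpp build.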
- - Verify OpenVINO is initialized properly - - **Linux:** - ```bash - echo $OpenVINO_DIR - ``` ### 2. Build llama.cpp with OpenVINO Backend @@ -657,14 +670,14 @@ git switch dev_backend_openvino cmake --build build/ReleaseOV --config Release -j $(nproc) ``` -- **Windows:** +- **Windows:** ```bash # Build with OpenVINO support "C:\Program Files (x86)\Intel\openvino_2025.3.0\setupvars.bat" - cmake -B build/ReleaseOV -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF -DLLAMA_CURL=OFF + cmake -B build\ReleaseOV -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF -DLLAMA_CURL=OFF -DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake cmake --build build\ReleaseOV --config Release ``` - - For faster compilation, add the -- /m argument to run multiple jobs in parallel with as many CPU cores available. + - For faster compilation, add the -- /m argument to run multiple jobs in parallel with as many CPU cores available. ```bash cmake --build build\ReleaseOV --config Release -- /m ``` @@ -741,7 +754,7 @@ docker build --target=full -t llama-openvino:full -f .devops/openvino.Dockerfile # Build a minimal CLI-only image containing just the llama-cli executable. docker build --target=light -t llama-openvino:light -f .devops/openvino.Dockerfile . -# Builds a server-only image with llama-server executable, health check endpoint, and REST API support. +# Builds a server-only image with llama-server executable, health check endpoint, and REST API support. docker build --target=server -t llama-openvino:server -f .devops/openvino.Dockerfile . # If you are behind a proxy: @@ -764,17 +777,17 @@ llama-openvino:light --no-warmup -m /models/Llama-3.2-1B-Instruct.fp16.gguf docker run --rm -it --env GGML_OPENVINO_DEVICE=NPU -v ~/models:/models \ --device=/dev/accel --group-add=$(stat -c "%g" /dev/dri/render* | head -n 1) -u $(id -u):$(id -g) \ llama-openvino:light --no-warmup -m /models/Llama-3.2-1B-Instruct.fp16.gguf -``` +``` Run Llama.cpp Server with OpenVINO Backend ```bash # Run the Server Docker container server -docker run --rm -it -p 8080:8080 -v ~/models:/models llama-openvino:server --no-warmup -m /models/Llama-3.2-1B-Instruct.fp16.gguf +docker run --rm -it -p 8080:8080 -v ~/models:/models llama-openvino:server --no-warmup -m /models/Llama-3.2-1B-Instruct.fp16.gguf # In a NEW terminal, test the server with curl # If you are behind a proxy, make sure to set NO_PROXY to avoid proxy for localhost -export NO_PROXY=localhost,127.0.0.1 +export NO_PROXY=localhost,127.0.0.1 # Test health endpoint curl -f http://localhost:8080/health From d66a58eeef9c2f111d1613e6fc484135d259e627 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 26 Dec 2025 15:18:30 +0800 Subject: [PATCH 10/14] NPU always requant to q4_0_128 --- .../src/ggml-openvino/ggml-openvino-extra.cpp | 16 ++++++--- ggml/src/ggml-openvino/ggml-openvino-extra.h | 2 +- ggml/src/ggml-openvino/ggml-quants.cpp | 34 ------------------- ggml/src/ggml-openvino/ggml-quants.hpp | 4 --- 4 files changed, 12 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index eff1627cb4c..26cc386dff0 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -3,6 +3,7 @@ #include "ggml-impl.h" #include "ggml.h" +#include #include #include @@ -162,19 +163,24 @@ clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() { } // Get requantization type for a tensor type (returns nullopt if no 
requant needed) -std::optional ggml_openvino_get_requant_type(ggml_type type) { +std::optional ggml_openvino_get_requant_type(const ggml_tensor * tensor) { if (!ggml_openvino_is_npu()) { return std::nullopt; } // NPU requantization rules - switch (type) { + if (strncmp(tensor->name, "token_embd.weight", 17) == 0) { + return ExtraQuantType::F16; + } + if (strncmp(tensor->name, "output.weight", 13) == 0) { + return ExtraQuantType::Q4_0_128; + } + switch (tensor->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_K: - return ExtraQuantType::Q4_0_128; case GGML_TYPE_Q6_K: case GGML_TYPE_Q5_K: - return ExtraQuantType::F16; + return ExtraQuantType::Q4_0_128; default: return std::nullopt; } @@ -200,7 +206,7 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten const size_t alignment = 64; // Good for SIMD // Check if requantization is needed (NPU-specific) - auto requant_type = ggml_openvino_get_requant_type(tensor->type); + auto requant_type = ggml_openvino_get_requant_type(tensor); if (requant_type.has_value()) { layout.is_requant = true; layout.requant_type = requant_type; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index 2f9d257769d..fbfe459edf5 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -83,7 +83,7 @@ const std::string & ggml_openvino_get_device_name(); bool ggml_openvino_is_npu(); // Get requantization type for a tensor type (returns nullopt if no requant needed) -std::optional ggml_openvino_get_requant_type(ggml_type type); +std::optional ggml_openvino_get_requant_type(const ggml_tensor * tensor); // ===================================================== // OpenVINO Tensor Extra Types diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 6cacc7b0340..1a5679cd8dd 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -535,40 +535,6 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, return result; } -std::shared_ptr requantize(const ggml_tensor * tensor, ExtraQuantType requant_type) { - ov::Shape node_shape = {(uint64_t) (tensor->ne[1]), (uint64_t) (tensor->ne[0])}; - - // FIXME hardcoded workaround to fix the case where token_emb.weight is q4_0 (instead of q6_k) - // (In some q4_0 models which use two different weight for token_emb and output, token_emb is q4_0) - std::string device = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : ""; - if (device == "NPU" && std::string(tensor->name) == "token_embd.weight") { - requant_type = ExtraQuantType::F16; - } - - // Determine block size - int64_t block_size = node_shape[1]; - if (requant_type == ExtraQuantType::Q4_0_128) { - block_size = 128; - } else if (requant_type == ExtraQuantType::Q8_0_32) { - block_size = 32; - } - - // Allocate tensors - ov::Tensor weights, scales, biases; - if (requant_type == ExtraQuantType::F16) { - weights = ov::Tensor(ov::element::f16, node_shape); - } else { - bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128); - ov::element::Type weight_type = is_u4 ? 
ov::element::u4 : ov::element::u8; - ov::Shape scales_shape = {node_shape[0], node_shape[1] / block_size}; - weights = ov::Tensor(weight_type, node_shape); - scales = ov::Tensor(ov::element::f16, scales_shape); - biases = ov::Tensor(ov::element::f16, scales_shape); - } - - return requantize_to_buffers(tensor, tensor->data, requant_type, block_size, weights, scales, biases); -} - std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr) { GGML_ASSERT(tensor != nullptr); GGML_ASSERT(data != nullptr); diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp index b1d286f1b83..a1334e2408d 100644 --- a/ggml/src/ggml-openvino/ggml-quants.hpp +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -52,10 +52,6 @@ ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& biases, size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); -// ExtraQuantType is defined in ggml-openvino-extra.h - -std::shared_ptr requantize(const ggml_tensor* tensor, ExtraQuantType requant_type); - // Extract quantized weights from tensor and create weight subgraph // If weights/scales/biases are provided (non-empty), uses them as output buffers // Otherwise allocates new ov::Tensors internally From c26433098174560654f3246b7c91ea0845ede75e Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 29 Dec 2025 15:25:59 +0800 Subject: [PATCH 11/14] Optimize symmetric quant weight extraction: use single zp --- .../src/ggml-openvino/ggml-openvino-extra.cpp | 32 ++++- ggml/src/ggml-openvino/ggml-openvino-extra.h | 1 + ggml/src/ggml-openvino/ggml-quants.cpp | 133 ++++++++++++++---- 3 files changed, 140 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 26cc386dff0..2f24d7a1dbb 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -192,6 +192,7 @@ std::optional ggml_openvino_get_requant_type(const ggml_tensor * ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor) { ggml_openvino_extracted_layout layout = {}; + layout.is_symmetric = false; if (!ggml_is_quantized(tensor->type)) { return layout; @@ -225,10 +226,26 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten case ExtraQuantType::Q4_0_128: layout.is_u4 = true; layout.weights_per_block = 128; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q4_0_C: + layout.is_u4 = true; + layout.weights_per_block = tensor->ne[0]; + layout.is_symmetric = true; break; case ExtraQuantType::Q8_0_32: layout.is_u4 = false; layout.weights_per_block = 32; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_0_C: + layout.is_u4 = false; + layout.weights_per_block = tensor->ne[0]; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_1_C: + layout.is_u4 = false; + layout.weights_per_block = tensor->ne[0]; break; default: layout.weights_per_block = -1; @@ -241,7 +258,8 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - layout.biases_size = n_blocks * sizeof(uint16_t); + // For symmetric quantization, we only need one bias value (not one per block) + layout.biases_size = layout.is_symmetric ? 
sizeof(uint16_t) : n_blocks * sizeof(uint16_t); layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; @@ -256,7 +274,14 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Normal extraction (no requant) - determine format based on tensor type switch (tensor->type) { case GGML_TYPE_Q4_0: + layout.is_u4 = true; + layout.weights_per_block = 32; + layout.is_symmetric = true; + break; case GGML_TYPE_Q4_1: + layout.is_u4 = true; + layout.weights_per_block = 32; + break; case GGML_TYPE_Q4_K: layout.is_u4 = true; layout.weights_per_block = 32; @@ -264,10 +289,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten case GGML_TYPE_Q8_0: layout.is_u4 = false; layout.weights_per_block = 32; + layout.is_symmetric = true; break; case GGML_TYPE_Q6_K: layout.is_u4 = false; layout.weights_per_block = 16; + layout.is_symmetric = true; break; case GGML_TYPE_Q5_K: layout.is_u4 = false; @@ -285,7 +312,8 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Scales and biases: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - layout.biases_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + // For symmetric quantization, we only need one bias value (not one per block) + layout.biases_size = layout.is_symmetric ? sizeof(uint16_t) : n_blocks * sizeof(uint16_t); // Layout in buffer: [weights | scales | biases] with alignment layout.weights_offset = 0; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index fbfe459edf5..e2c5a8ceeae 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -146,6 +146,7 @@ struct ggml_openvino_extracted_layout { size_t biases_size; // Size of biases in bytes bool is_u4; // true for U4 weights, false for U8 int64_t weights_per_block;// weights per scale/bias block + bool is_symmetric; // true for symmetric quantization // Requantization info bool is_requant; // true if this tensor needs requantization diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 1a5679cd8dd..8946b73a561 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -55,9 +55,18 @@ void extract_q4_0_data(const ggml_tensor * tensor, auto * scales = scales_arr.data::value_type>(); auto * biases = biases_arr.data::value_type>(); + bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - biases[i] = ov::float16(-8.f * static_cast(scales[i])); + // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) + if (is_scalar_bias) { + if (i == 0) { + biases[0] = ov::float16(-8.f * static_cast(scales[0])); + } + } else { + biases[i] = ov::float16(-8.f * static_cast(scales[i])); + } unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); }); } @@ -95,10 +104,19 @@ void extract_q8_0_data(const ggml_tensor * tensor, auto * scales = scales_arr.data::value_type>(); auto * biases = biases_arr.data::value_type>(); + bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { uint8_t * block_data = data + i * 
bytes_per_block; scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - biases[i] = ov::float16(-128.f * static_cast(scales[i])); + // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) + if (is_scalar_bias) { + if (i == 0) { + biases[0] = ov::float16(-128.f * static_cast(scales[0])); + } + } else { + biases[i] = ov::float16(-128.f * static_cast(scales[i])); + } for (size_t j = 0; j < weights_per_block; ++j) { uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. // Original data is in int8_t, so we add a bias of -128 and invert the first bit. @@ -190,6 +208,8 @@ void extract_q6_k_data(const ggml_tensor * tensor, auto * scales = scales_arr.data::value_type>(); auto * biases = biases_arr.data::value_type>(); + bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + ov::parallel_for(n_super_block, [&](size_t i) { uint8_t * block_data = data + i * bytes_per_block; @@ -199,7 +219,14 @@ void extract_q6_k_data(const ggml_tensor * tensor, for (size_t j = 0; j < 16; j++) { scales[j + i * 16] = ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - biases[j + i * 16] = ov::float16(-32.f * static_cast(scales[j + i * 16])); + // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) + if (is_scalar_bias) { + if (i == 0 && j == 0) { + biases[0] = ov::float16(-32.f * static_cast(scales[0])); + } + } else { + biases[j + i * 16] = ov::float16(-32.f * static_cast(scales[j + i * 16])); + } } uint8_t * ql = block_data; @@ -302,15 +329,22 @@ ov::Output make_int8_weights(ov::Tensor & weight, // Expand dimensions for scales and biases auto scale_shape = scales.get_shape(); + auto bias_shape = biases.get_shape(); + bool is_scalar_bias = bias_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; if (packed_shape[1] == 1) { + // Requantized channel-wise case packed_shape.erase(packed_shape.begin() + 1); } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - biases.set_shape(scale_shape); + // For symmetric quantization, biases remain scalar (don't resize) + if (!is_scalar_bias) { + bias_shape = scale_shape; + biases.set_shape(bias_shape); + } } // Create graph nodes @@ -318,15 +352,23 @@ ov::Output make_int8_weights(ov::Tensor & weight, static_cast(weight.data()), nullptr); weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - ov::Tensor biases_u8(ov::element::u8, scale_shape); + ov::Tensor biases_u8(ov::element::u8, is_scalar_bias ? 
ov::Shape{} : scale_shape); // Calculate zero point const ov::float16 * bias_data = biases.data::value_type>(); const ov::float16 * scale_data = scales.data::value_type>(); uint8_t * bias_u8_data = biases_u8.data(); - for (size_t i = 0; i < biases_u8.get_size(); ++i) { - bias_u8_data[i] = - (uint8_t) std::round(-1.f * static_cast(bias_data[i]) / static_cast(scale_data[i])); + + if (is_scalar_bias) { + // Symmetric quantization: single bias value for all blocks + // For Q8_0, bias = -128 * scale, so zero_point = 128 + bias_u8_data[0] = (uint8_t) std::round(-1.f * static_cast(bias_data[0]) / static_cast(scale_data[0])); + } else { + // Asymmetric quantization: per-block biases + for (size_t i = 0; i < biases_u8.get_size(); ++i) { + bias_u8_data[i] = + (uint8_t) std::round(-1.f * static_cast(bias_data[i]) / static_cast(scale_data[i])); + } } auto zero_point = std::make_shared(biases_u8); @@ -361,17 +403,23 @@ ov::Output make_int4_weights(ov::Tensor & weight, // Expand dimensions for scales and biases ov::Shape scale_bias_shape = scales.get_shape(); + auto bias_shape = biases.get_shape(); + bool is_scalar_bias = bias_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; - // Requantized channel-wise case if (packed_shape[1] == 1) { + // Requantized channel-wise case packed_shape.erase(packed_shape.begin() + 1); } else { scale_bias_shape.push_back(1); scales.set_shape(scale_bias_shape); - biases.set_shape(scale_bias_shape); + // For symmetric quantization, biases remain scalar (don't resize) + if (!is_scalar_bias) { + bias_shape = scale_bias_shape; + biases.set_shape(bias_shape); + } } auto weights_node = std::make_shared(ov::element::u4, packed_shape, @@ -382,14 +430,23 @@ ov::Output make_int4_weights(ov::Tensor & weight, // Pack zero points: two subsequent values into one const ov::float16 * bias_data = biases.data::value_type>(); const ov::float16 * scale_data = scales.data::value_type>(); - ov::Tensor zero_point_tensor(ov::element::u4, scale_bias_shape); + ov::Tensor zero_point_tensor(ov::element::u4, is_scalar_bias ? ov::Shape{} : scale_bias_shape); uint8_t * zero_point_data = static_cast(zero_point_tensor.data()); - for (size_t i = 0; i < zero_point_tensor.get_byte_size(); ++i) { - uint8_t bias1 = - (uint8_t) std::round(-1.f * static_cast(bias_data[i * 2]) / static_cast(scale_data[i * 2])); - uint8_t bias2 = (uint8_t) std::round(-1.f * static_cast(bias_data[i * 2 + 1]) / - static_cast(scale_data[i * 2 + 1])); - zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F); + + if (is_scalar_bias) { + // Symmetric quantization: single bias value for all blocks + // For Q4_0, bias = -8 * scale, so zero_point = 8 + uint8_t zp = (uint8_t) std::round(-1.f * static_cast(bias_data[0]) / static_cast(scale_data[0])); + zero_point_data[0] = (zp << 4) | (zp & 0x0F); + } else { + // Asymmetric quantization: per-block biases + for (size_t i = 0; i < zero_point_tensor.get_byte_size(); ++i) { + uint8_t bias1 = + (uint8_t) std::round(-1.f * static_cast(bias_data[i * 2]) / static_cast(scale_data[i * 2])); + uint8_t bias2 = (uint8_t) std::round(-1.f * static_cast(bias_data[i * 2 + 1]) / + static_cast(scale_data[i * 2 + 1])); + zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F); + } } auto zero_points_node = std::make_shared(zero_point_tensor); @@ -602,17 +659,19 @@ std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, cons // Requant to quantized format (Q4_0_128, Q8_0_32, etc.) 
ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + // For symmetric quantization, biases are a single value instead of per-block + ov::Shape bias_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; ov::Tensor weights, scales, biases; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - biases = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + biases = ov::Tensor(ov::element::f16, bias_shape, buf_base + layout.biases_offset); } else { weights = ov::Tensor(weight_type, node_shape); scales = ov::Tensor(ov::element::f16, scale_shape); - biases = ov::Tensor(ov::element::f16, scale_shape); + biases = ov::Tensor(ov::element::f16, bias_shape); } result = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, weights, @@ -622,17 +681,19 @@ std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, cons // Normal extraction path (no requant) ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + // For symmetric quantization, biases are a single value instead of per-block + ov::Shape bias_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; ov::Tensor weights, scales, biases; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - biases = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + biases = ov::Tensor(ov::element::f16, bias_shape, buf_base + layout.biases_offset); } else { weights = ov::Tensor(weight_type, node_shape); scales = ov::Tensor(ov::element::f16, scale_shape); - biases = ov::Tensor(ov::element::f16, scale_shape); + biases = ov::Tensor(ov::element::f16, bias_shape); } result = extract_quantized_weights(tensor, data, weights, scales, biases); @@ -653,6 +714,8 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); auto * biases = biases_arr.data::value_type>(); + bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max float max = 0.0f; @@ -669,7 +732,13 @@ void quantize_q4_0(const float * x, if (d == 0) { scales[i] = ov::float16(1.0f); - biases[i] = ov::float16(-8.0f); + if (is_scalar_bias) { + if (i == 0) { + biases[0] = ov::float16(-8.0f); + } + } else { + biases[i] = ov::float16(-8.0f); + } uint8_t zp = 8; memset(weights + i * qk / 2, zp | (zp << 4), qk / 2); continue; @@ -677,7 +746,14 @@ void quantize_q4_0(const float * x, const float id = 1.0f / d; scales[i] = ov::float16(d); - biases[i] = ov::float16(-8.f * d); + // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) + if (is_scalar_bias) { + if (i == 0) { + biases[0] = ov::float16(-8.f * d); + } + } else { + biases[i] = ov::float16(-8.f * d); + } for (int j = 0; j < qk / 2; ++j) { const float x0 = x[i * qk + 2 * j] * id; @@ -701,6 +777,8 @@ void quantize_q8_0(const float * x, auto * weights = 
static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); auto * biases = biases_arr.data::value_type>(); + bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -714,7 +792,14 @@ void quantize_q8_0(const float * x, const float d = amax / 127.0f; const float id = d ? 1.0f / d : 0.0f; scales[i] = ov::float16(d); - biases[i] = ov::float16(-128.0f * d); + // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) + if (is_scalar_bias) { + if (i == 0) { + biases[0] = ov::float16(-128.0f * d); + } + } else { + biases[i] = ov::float16(-128.0f * d); + } for (int j = 0; j < qk; ++j) { const float x0 = x[i * qk + j] * id; From ad86ae56d2aa23b9e1eed41002cdc8068a0c523f Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 29 Dec 2025 15:27:50 +0800 Subject: [PATCH 12/14] Use Q8_0_C in token embd, lm_head, and for 5 and 6 bits quant --- ggml/src/ggml-openvino/ggml-openvino-extra.cpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 2f24d7a1dbb..35d3d93cfd1 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -164,23 +164,19 @@ clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() { // Get requantization type for a tensor type (returns nullopt if no requant needed) std::optional ggml_openvino_get_requant_type(const ggml_tensor * tensor) { - if (!ggml_openvino_is_npu()) { - return std::nullopt; - } - // NPU requantization rules if (strncmp(tensor->name, "token_embd.weight", 17) == 0) { - return ExtraQuantType::F16; + return ExtraQuantType::Q8_0_C; } if (strncmp(tensor->name, "output.weight", 13) == 0) { + return ExtraQuantType::Q8_0_C; + } + if (ggml_openvino_is_npu()) { return ExtraQuantType::Q4_0_128; } switch (tensor->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: case GGML_TYPE_Q5_K: - return ExtraQuantType::Q4_0_128; + return ExtraQuantType::Q8_0_C; default: return std::nullopt; } From e84622855b3064c7c618104453ef45d80b7b5865 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 30 Dec 2025 10:51:40 +0800 Subject: [PATCH 13/14] Update build.md --- docs/build.md | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/docs/build.md b/docs/build.md index 54d6e9517b7..9fa6fccad19 100644 --- a/docs/build.md +++ b/docs/build.md @@ -664,22 +664,16 @@ git switch dev_backend_openvino - **Linux:** ```bash - # Build with OpenVINO support source /opt/intel/openvino/setupvars.sh cmake -B build/ReleaseOV -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF - cmake --build build/ReleaseOV --config Release -j $(nproc) + cmake --build build/ReleaseOV --parallel ``` - **Windows:** ```bash - # Build with OpenVINO support "C:\Program Files (x86)\Intel\openvino_2025.3.0\setupvars.bat" - cmake -B build\ReleaseOV -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF -DLLAMA_CURL=OFF -DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake - cmake --build build\ReleaseOV --config Release - ``` - - For faster compilation, add the -- /m argument to run multiple jobs in parallel with as many CPU cores available. 
From e84622855b3064c7c618104453ef45d80b7b5865 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 30 Dec 2025 10:51:40 +0800
Subject: [PATCH 13/14] Update build.md

---
 docs/build.md | 42 +++++++++---------------------------------
 1 file changed, 9 insertions(+), 33 deletions(-)

diff --git a/docs/build.md b/docs/build.md
index 54d6e9517b7..9fa6fccad19 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -664,22 +664,16 @@ git switch dev_backend_openvino
 
 - **Linux:**
   ```bash
-  # Build with OpenVINO support
   source /opt/intel/openvino/setupvars.sh
   cmake -B build/ReleaseOV -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF
-  cmake --build build/ReleaseOV --config Release -j $(nproc)
+  cmake --build build/ReleaseOV --parallel
   ```
 
 - **Windows:**
   ```bash
-  # Build with OpenVINO support
   "C:\Program Files (x86)\Intel\openvino_2025.3.0\setupvars.bat"
-  cmake -B build\ReleaseOV -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF -DLLAMA_CURL=OFF -DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
-  cmake --build build\ReleaseOV --config Release
-  ```
-  - For faster compilation, add the -- /m argument to run multiple jobs in parallel with as many CPU cores available.
-  ```bash
-  cmake --build build\ReleaseOV --config Release -- /m
+  cmake -B build\ReleaseOV -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF -DLLAMA_CURL=OFF -DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
+  cmake --build build\ReleaseOV --parallel
   ```
 
 ### 3. Download Sample Model
@@ -687,16 +681,9 @@ Download models for testing:
 
 ```bash
-# Create models directory
 mkdir -p ~/models/
-
-# Download model file: Llama-3.2-1B-Instruct.fp16.gguf
-wget https://huggingface.co/MaziyarPanahi/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct.fp16.gguf \
-  -O ~/models/Llama-3.2-1B-Instruct.fp16.gguf
-
-# Download model file: Phi-3-mini-4k-instruct-fp16.gguf
-wget https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-fp16.gguf \
-  -O ~/models/Phi-3-mini-4k-instruct-fp16.gguf
+wget https://huggingface.co/unsloth/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf \
+  -O ~/models/Llama-3.2-1B-Instruct-Q4_0.gguf
 ```
 
 ### 4. Run inference with OpenVINO backend:
@@ -704,20 +691,14 @@
 When using the OpenVINO backend, the first inference token may have slightly higher latency due to on-the-fly conversion to the OpenVINO graph. Subsequent tokens and runs will be faster.
 
 ```bash
-export GGML_OPENVINO_CACHE_DIR=/tmp/ov_cache
-# Default device is GPU.
-# If not set, automatically selects the first available device in priority order: GPU, CPU, NPU.
+# If device is unset or unavailable, default to CPU.
 export GGML_OPENVINO_DEVICE=GPU
-
-./build/ReleaseOV/bin/llama-simple -m ~/models/Llama-3.2-1B-Instruct.fp16.gguf -n 50 "The story of AI is "
-
+./build/ReleaseOV/bin/llama-simple -m ~/models/Llama-3.2-1B-Instruct-Q4_0.gguf -n 50 "The story of AI is "
 ```
 
 To run in chat mode:
 ```bash
-export GGML_OPENVINO_CACHE_DIR=/tmp/ov_cache
-./build/ReleaseOV/bin/llama-cli -m ~/models/Llama-3.2-1B-Instruct.fp16.gguf -n 50 "The story of AI is "
-
+./build/ReleaseOV/bin/llama-cli -m ~/models/Llama-3.2-1B-Instruct-Q4_0.gguf
 ```
 
 ### Configuration Options
@@ -729,16 +710,11 @@ Control OpenVINO behavior using these environment variables:
 - **`GGML_OPENVINO_PROFILING`**: Enable execution time profiling.
 - **`GGML_OPENVINO_DUMP_CGRAPH`**: Save compute graph to `cgraph.txt`.
 - **`GGML_OPENVINO_DUMP_IR`**: Export OpenVINO IR files with timestamps.
-- **`GGML_OPENVINO_DEBUG_INPUT`**: Enable input debugging.
-- **`GGML_OPENVINO_DEBUG_OUTPUT`**: Enable output debugging.
 
 ### Example with Profiling
 
 ```bash
-export GGML_OPENVINO_CACHE_DIR=/tmp/ov_cache
-export GGML_OPENVINO_PROFILING=1
-
-GGML_OPENVINO_DEVICE=GPU ./build/ReleaseOV/bin/llama-simple -m ~/models/Llama-3.2-1B-Instruct.fp16.gguf -n 50 "The story of AI is "
+GGML_OPENVINO_PROFILING=1 GGML_OPENVINO_DEVICE=GPU ./build/ReleaseOV/bin/llama-simple -m ~/models/Llama-3.2-1B-Instruct-Q4_0.gguf -n 50 "The story of AI is "
 ```
 
 ### Docker build Llama.cpp with OpenVINO Backend
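The GGML_OPENVINO_* settings documented above are read from the environment at runtime. As a rough sketch of how such a boolean switch is commonly consumed (the backend's actual option handling may differ):

```cpp
#include <cstdio>
#include <cstdlib>

// Treat the variable as enabled when it is set to anything other than "" or "0".
static bool env_flag_enabled(const char * name) {
    const char * v = std::getenv(name);
    return v != nullptr && v[0] != '\0' && v[0] != '0';
}

int main() {
    if (env_flag_enabled("GGML_OPENVINO_PROFILING")) {
        std::printf("profiling requested\n");
    }
    return 0;
}
```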
From 3f29ffd14edd16f00917a7a5e99dfd20175ad935 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Wed, 7 Jan 2026 16:56:30 +0800
Subject: [PATCH 14/14] Support -ctk f32

---
 ggml/src/ggml-openvino/ggml-decoder.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index 13ef00dcb64..51fb433410c 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -296,6 +296,9 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr
     std::string name = std::string(node->name);
     if (node->op == GGML_OP_FLASH_ATTN_EXT) {
         auto * cache_k_perm = node->src[1];
+        if (cache_k_perm->op == GGML_OP_CPY) {
+            cache_k_perm = cache_k_perm->src[0];
+        }
         assert(cache_k_perm->op == GGML_OP_PERMUTE);
         auto * cache_k_view = cache_k_perm->src[0];
         assert(cache_k_view->op == GGML_OP_VIEW);
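The last patch is a small source-chain adjustment: with `-ctk f32` a GGML_OP_CPY node can sit between the flash-attention K input and the expected PERMUTE, so the decoder now steps over it before asserting the PERMUTE -> VIEW chain. The same pattern as a standalone sketch (the helper name is ours, not from the patch):

```cpp
#include "ggml.h"

// Step over one optional intermediate op (for example the CPY inserted when the
// K cache is kept in a different type) before checking the expected producer.
static ggml_tensor * skip_one_op(ggml_tensor * t, ggml_op op) {
    return (t != nullptr && t->op == op) ? t->src[0] : t;
}

// Usage mirroring the patched traversal:
//   ggml_tensor * cache_k_perm = skip_one_op(node->src[1], GGML_OP_CPY);
//   assert(cache_k_perm->op == GGML_OP_PERMUTE);
```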