diff --git a/.gitignore b/.gitignore index 05eb578a82f..bc905a6ef83 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,18 @@ poetry.toml /.windsurf/ # emscripten a.out.* +wikitext-2-raw/wikitext-2-raw/wiki.test.raw +wikitext-2-raw/wikitext-2-raw/wiki.train.raw +wikitext-2-raw/wikitext-2-raw/wiki.valid.raw +Qwen3-1.7B/.gitattributes +Qwen3-1.7B/config.json +Qwen3-1.7B/generation_config.json +Qwen3-1.7B/LICENSE +Qwen3-1.7B/merges.txt +Qwen3-1.7B/model-00001-of-00002.safetensors +Qwen3-1.7B/model-00002-of-00002.safetensors +Qwen3-1.7B/model.safetensors.index.json +Qwen3-1.7B/README.md +Qwen3-1.7B/tokenizer_config.json +Qwen3-1.7B/tokenizer.json +Qwen3-1.7B/vocab.json diff --git a/IMatrix_Guide.md b/IMatrix_Guide.md new file mode 100644 index 00000000000..5237dc2c1e2 --- /dev/null +++ b/IMatrix_Guide.md @@ -0,0 +1,426 @@ +# Importance Matrix (imatrix) Files: Complete Guide + +## What is an IMatrix File? + +An **importance matrix** (imatrix) file is a data structure that contains information about which weights in a neural network are most important during inference. It's generated by running the model on a calibration dataset and measuring how much each weight contributes to the output. + +### Key Concepts + +- **Purpose**: Improve quantization quality by preserving precision for important weights +- **How it works**: Tracks squared activations (importance scores) for each weight during inference +- **Format**: Stored as GGUF files (or legacy `.dat` format) +- **Usage**: Passed to the quantization tool to guide which weights should be quantized more carefully + +--- + +## Why Use an IMatrix? + +When quantizing a model, you're reducing precision from 16-bit or 32-bit floats to 3-bit, 4-bit, or other low-precision formats. This compression can cause quality loss. An imatrix helps by: + +1. **Identifying Critical Weights**: Shows which weights are most active/important during inference +2. **Guiding Quantization**: Allows the quantizer to: + - Preserve precision for important weights + - Use more aggressive quantization for less important weights + - Make smarter decisions about outlier selection (especially for Q3_HIFI) +3. **Improving Quality**: Can significantly reduce perplexity increase compared to quantization without imatrix + +### Example Impact + +For Q3_HIFI specifically, the imatrix is used to: +- Weight the magnitude calculation when selecting outliers: `mag[i] = fabsf(xb[i]) * quant_weights[i]` +- Prioritize important weights as outliers (stored in FP16) +- Improve overall quantization quality + +--- + +## How to Generate an IMatrix File + +### Step 1: Prepare a Calibration Dataset + +You need a text file with representative data that the model will process. This should be similar to the data your model will see in production. 
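+
+For instance, here is a minimal Python sketch for assembling such a calibration file from a few raw text sources. It is illustrative only, and the input file names are hypothetical placeholders:
+
+```python
+# Minimal sketch: merge raw text sources into a single calibration file.
+# The source paths below are hypothetical placeholders.
+from pathlib import Path
+
+sources = ["wiki.train.raw", "domain-notes.txt"]
+
+with open("calibration-data.txt", "w", encoding="utf-8") as out:
+    for src in sources:
+        text = Path(src).read_text(encoding="utf-8", errors="ignore")
+        # Drop empty lines so every line carries usable signal
+        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
+        out.write("\n".join(lines) + "\n")
+```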
+
+**Good sources for calibration data:**
+- Wikipedia articles (e.g., `wiki.train.raw`)
+- Books or text corpora
+- Domain-specific text relevant to your use case
+- The model's training data (if available)
+
+**File format**: Plain text, one example per line (or use `--parse-special` for special token parsing)
+
+### Step 2: Build the IMatrix Tool
+
+First, make sure you've built `llama-imatrix`:
+
+```bash
+# On Linux/Mac
+make llama-imatrix
+
+# On Windows (MSVC)
+cmake --build build --config Release --target llama-imatrix
+```
+
+### Step 3: Generate the IMatrix
+
+Basic usage:
+
+```bash
+./llama-imatrix \
+    -m model-f16.gguf \
+    -f calibration-data.txt \
+    -o imatrix.gguf \
+    -ngl 99
+```
+
+**Parameters explained:**
+- `-m, --model`: Your F16 or F32 model file (input)
+- `-f, --file`: Your calibration text file
+- `-o, --output-file`: Output imatrix filename (default: `imatrix.gguf`)
+- `-ngl, --n-gpu-layers`: Number of layers to offload to GPU (speeds up generation)
+
+### Advanced Options
+
+Note that bash does not allow trailing comments after a line-continuation `\`, so the flags are listed bare here; each one is explained under **Important Options** below:
+
+```bash
+./llama-imatrix \
+    -m model-f16.gguf \
+    -f calibration-data.txt \
+    -o imatrix.gguf \
+    -ngl 99 \
+    --output-frequency 10 \
+    --save-frequency 50 \
+    --chunk 0 \
+    --chunks 100 \
+    --parse-special \
+    --process-output
+```
+
+**Important Options:**
+- `--output-frequency N`: How often to save progress (default: 10 chunks)
+- `--save-frequency N`: Create backup snapshots (default: 0 = never)
+- `--chunk N`: Skip the first N chunks (useful for resuming)
+- `--chunks N`: Maximum chunks to process (default: -1 = all)
+- `--parse-special`: Enable special token parsing (e.g., `<|im_start|>`)
+- `--process-output`: Include the `output.weight` tensor (usually not recommended)
+- `--no-ppl`: Disable perplexity calculation (faster, less info)
+- `-lv, --verbosity`: Verbosity level (0=silent, 1=default, 2+=verbose)
+
+### Example: Full Workflow
+
+```bash
+# 1. 
Generate imatrix with GPU acceleration +./llama-imatrix \ + -m ./models/llama-3-8b-f16.gguf \ + -f ./data/wiki.train.raw \ + -o ./imatrix.gguf \ + -ngl 99 \ + --output-frequency 20 \ + --save-frequency 100 + +# This will: +# - Process the calibration data +# - Track activations for each tensor +# - Save progress every 20 chunks +# - Create snapshots every 100 chunks +# - Output: imatrix.gguf +``` + +--- + +## How to Use an IMatrix During Quantization + +### Basic Usage + +Once you have an imatrix file, use it during quantization: + +```bash +./llama-quantize \ + --imatrix imatrix.gguf \ + input-model-f16.gguf \ + output-model-q3_hifi.gguf \ + Q3_HIFI +``` + +### With Specific Tensor Types + +You can target specific tensors: + +```bash +# Use imatrix only for attention and feed-forward layers +./llama-quantize \ + --imatrix imatrix.gguf \ + --include-weights attn_v \ + --include-weights ffn_down \ + input-model-f16.gguf \ + output-model-q3_hifi.gguf \ + Q3_HIFI +``` + +### Advanced Usage + +```bash +# Quantize with imatrix, custom tensor types, and output settings +./llama-quantize \ + --imatrix imatrix.gguf \ + --output-tensor-type q5_k \ + --token-embedding-type q3_hifi \ + input-model-f16.gguf \ + output-model-q3_hifi.gguf \ + Q3_HIFI +``` + +--- + +## IMatrix File Formats + +### GGUF Format (Recommended) + +Modern format, stored as `.gguf` files: +- More efficient +- Better metadata support +- Can store multiple datasets +- Default format in recent versions + +### Legacy Format + +Older binary format, stored as `.dat` files: +- Still supported for compatibility +- Use `--output-format dat` to generate + +### Converting Between Formats + +```bash +# Convert legacy to GGUF +./llama-imatrix --in-file imatrix.dat -o imatrix.gguf + +# Convert GGUF to legacy +./llama-imatrix --in-file imatrix.gguf --output-format dat -o imatrix.dat +``` + +--- + +## Combining Multiple IMatrix Files + +You can merge imatrix files from multiple runs or datasets: + +```bash +./llama-imatrix \ + --in-file imatrix-dataset1.gguf \ + --in-file imatrix-dataset2.gguf \ + --in-file imatrix-dataset3.gguf \ + -o imatrix-combined.gguf +``` + +This is useful for: +- Combining data from different domains +- Merging results from multiple calibration runs +- Creating a more comprehensive importance matrix + +--- + +## Analyzing IMatrix Files + +### View Statistics + +```bash +./llama-imatrix --in-file imatrix.gguf --show-statistics +``` + +This displays: +- **Per Tensor**: + - Σ(Act²): Sum of squared activations (importance scores) + - Min & Max: Range of importance values + - μ & σ: Mean and standard deviation + - % Active: Proportion of active elements + - Entropy: Information content + - ZD Score: Layer importance metric + - CosSim: Cosine similarity with previous layer + +- **Per Layer**: + - Weighted averages of importance metrics + +### Understanding the Statistics + +- **High Σ(Act²)**: Tensor is very active during inference +- **High % Active**: Many weights contribute significantly +- **High Entropy**: Weights have diverse importance (good for quantization) +- **High ZD Score**: Layer is important to preserve +- **High CosSim**: Layer is similar to previous (may indicate redundancy) + +--- + +## Best Practices + +### 1. 
Calibration Dataset Selection + +✅ **Do:** +- Use representative data similar to your use case +- Include diverse examples +- Use at least 1000-10000 chunks for good coverage +- Match the domain (e.g., code for code models, text for language models) + +❌ **Don't:** +- Use too small a dataset (< 100 chunks) +- Use completely unrelated data +- Use only one type of example + +### 2. Processing Settings + +✅ **Do:** +- Use GPU offloading (`-ngl 99`) for speed +- Save frequently (`--output-frequency 10`) +- Create snapshots (`--save-frequency 50`) for long runs +- Process enough chunks (1000+ recommended) + +❌ **Don't:** +- Process `output.weight` unless necessary (`--process-output` is usually not needed) +- Skip validation of your calibration data + +### 3. Quantization Usage + +✅ **Do:** +- Always use imatrix for Q3_HIFI (it significantly improves outlier selection) +- Use imatrix for aggressive quantizations (Q2_K, Q3_K_S) +- Include attention and feed-forward weights +- Test quality after quantization + +❌ **Don't:** +- Use imatrix for `output.weight` (usually excluded by default) +- Assume imatrix will always improve quality (test it) +- Use an imatrix from a different model architecture + +--- + +## Complete Workflow Example + +Here's a complete example for quantizing a model with Q3_HIFI using an imatrix: + +```bash +# Step 1: Generate importance matrix +./llama-imatrix \ + -m ./models/llama-3-8b-f16.gguf \ + -f ./data/calibration-text.txt \ + -o ./imatrix.gguf \ + -ngl 99 \ + --output-frequency 20 \ + --chunks 1000 + +# Step 2: (Optional) View statistics +./llama-imatrix --in-file ./imatrix.gguf --show-statistics + +# Step 3: Quantize using the imatrix +./llama-quantize \ + --imatrix ./imatrix.gguf \ + ./models/llama-3-8b-f16.gguf \ + ./models/llama-3-8b-q3_hifi.gguf \ + Q3_HIFI + +# Step 4: Test the quantized model +./llama-cli \ + -m ./models/llama-3-8b-q3_hifi.gguf \ + -p "Hello, how are you?" +``` + +--- + +## How IMatrix Works with Q3_HIFI + +For Q3_HIFI specifically, the imatrix is particularly valuable: + +1. **Outlier Selection**: The imatrix weights the magnitude calculation: + ```c + mag[i] = fabsf(xb[i]) * quant_weights[i] + ``` + This means important weights (high imatrix values) are more likely to be selected as outliers. + +2. **Better Quality**: By preserving important weights as FP16 outliers, the model maintains better accuracy. + +3. **Smart Compression**: Less important weights can be more aggressively quantized to 3-bit, while critical ones stay in FP16. 
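+
+To make the outlier-selection rule above concrete, here is a small Python sketch of importance-weighted top-k selection for one 256-weight block. It is a simplified model of the heuristic, not the actual C implementation; `numpy` arrays stand in for the block and its imatrix scores, and the default `k = 8` matches `Q3_HIFI_OUTLIERS`:
+
+```python
+# Sketch: choose the k most important weights in a block as FP16 outliers.
+# Mirrors mag[i] = fabsf(xb[i]) * quant_weights[i]; simplified, not the C code.
+import numpy as np
+
+def pick_outliers(xb: np.ndarray, quant_weights: np.ndarray, k: int = 8):
+    mag = np.abs(xb) * quant_weights     # importance-weighted magnitude
+    idx = np.sort(np.argsort(mag)[-k:])  # positions of the k largest scores
+    return idx, xb[idx]                  # indices + values to keep in FP16
+
+block = np.random.randn(256).astype(np.float32)   # one Q3_HIFI-sized block
+imatrix = np.random.rand(256).astype(np.float32)  # stand-in importance scores
+positions, values = pick_outliers(block, imatrix)
+```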
+ +### Example Impact + +Without imatrix: +- Outliers selected purely by magnitude +- May miss important but smaller-magnitude weights +- Quality: Baseline + +With imatrix: +- Outliers selected by importance-weighted magnitude +- Preserves critical weights even if not the largest +- Quality: Typically 5-15% better perplexity + +--- + +## Troubleshooting + +### Problem: IMatrix generation is slow + +**Solutions:** +- Use GPU offloading: `-ngl 99` +- Reduce chunks: `--chunks 500` +- Disable perplexity: `--no-ppl` + +### Problem: IMatrix file is very large + +**Solutions:** +- This is normal (can be 100MB-1GB+) +- Use GGUF format (more efficient than legacy) +- The file is only needed during quantization, not inference + +### Problem: Quantization quality didn't improve + +**Solutions:** +- Check that imatrix was generated on similar data +- Verify imatrix file loaded correctly (check logs) +- Try including/excluding specific tensors +- Ensure calibration dataset is representative + +### Problem: "imatrix mapping error" + +**Solutions:** +- IMatrix was generated for a different model architecture +- Tensor names don't match +- Regenerate imatrix for your specific model + +--- + +## Technical Details + +### What Gets Stored + +For each tensor, the imatrix stores: +- **Squared activations**: `act²` for each weight position +- **Call count**: How many times the tensor was accessed +- **Averaged values**: `Σ(act²) / n_calls` for normalization + +### How It's Used + +During quantization: +1. IMatrix data is loaded and mapped to tensor names +2. For each weight block, importance scores are retrieved +3. Quantization algorithms use these scores to: + - Weight magnitude calculations + - Select outliers (Q3_HIFI) + - Choose quantization scales + - Determine precision levels + +### File Structure + +GGUF format imatrix contains: +- Metadata: chunk count, chunk size, dataset names +- Tensor data: For each tensor, arrays of importance scores +- Statistics: Optional computed statistics + +--- + +## Summary + +**IMatrix files are essential for high-quality quantization**, especially for formats like Q3_HIFI that benefit from intelligent outlier selection. + +**Key Takeaways:** +1. Generate imatrix using representative calibration data +2. Use GPU acceleration for faster generation +3. Always use imatrix when quantizing to Q3_HIFI +4. Combine multiple imatrix files for better coverage +5. Analyze statistics to understand your model's weight importance + +**For Q3_HIFI specifically**: The imatrix directly improves outlier selection, making it one of the most impactful uses of importance matrices in quantization. 
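+
+---
+
+## Appendix: Sketch of Importance-Score Accumulation
+
+As a final illustration of the Σ(act²)/n_calls bookkeeping described under Technical Details, here is a hedged Python sketch. It mimics the accumulation logic only; the real collection happens in llama.cpp's C++ code during graph evaluation:
+
+```python
+# Sketch: accumulate squared activations per tensor position across chunks.
+# Illustrative only - not the llama.cpp implementation.
+import numpy as np
+
+class ImatrixAccumulator:
+    def __init__(self, n: int):
+        self.sum_sq = np.zeros(n, dtype=np.float64)  # running Σ(act²)
+        self.n_calls = 0                             # times the tensor was hit
+
+    def update(self, activations: np.ndarray):
+        self.sum_sq += activations.astype(np.float64) ** 2
+        self.n_calls += 1
+
+    def importance(self) -> np.ndarray:
+        return self.sum_sq / max(self.n_calls, 1)    # Σ(act²) / n_calls
+
+acc = ImatrixAccumulator(n=4096)
+for _ in range(10):                                  # pretend we ran 10 chunks
+    acc.update(np.random.randn(4096).astype(np.float32))
+scores = acc.importance()                            # per-position importance
+```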
+ diff --git a/benchmark_speed_test.ps1 b/benchmark_speed_test.ps1 new file mode 100644 index 00000000000..002317075b3 --- /dev/null +++ b/benchmark_speed_test.ps1 @@ -0,0 +1,296 @@ +# Qwen3-1.7B Quantization Speed Benchmark Script +# Runs llama-bench 100 times per model and calculates statistics + +param( + [int]$Iterations = 100, + [int]$Threads = 4, + [int]$Repeats = 3, + [int]$PromptTokens = 0, + [int]$GenerateTokens = 20 +) + +$ErrorActionPreference = "Stop" + +# Configuration +$LlamaBench = ".\build\bin\Release\llama-bench.exe" +$Models = @( + @{ Name = "Q3_K_S"; Path = ".\Qwen3-1.7B-f16-Q3_K_S.gguf" }, + @{ Name = "Q3_K_M"; Path = ".\Qwen3-1.7B-f16-Q3_K_M.gguf" }, + @{ Name = "Q3_HIFI"; Path = ".\Qwen3-1.7B-f16-Q3_HIFI.gguf" } +) + +# Verify files exist +if (-not (Test-Path $LlamaBench)) { + Write-Error "llama-bench not found at: $LlamaBench" + exit 1 +} + +foreach ($model in $Models) { + if (-not (Test-Path $model.Path)) { + Write-Error "Model not found: $($model.Path)" + exit 1 + } +} + +# Results storage +$Results = @{} +foreach ($model in $Models) { + $Results[$model.Name] = @{ + Speeds = [System.Collections.ArrayList]::new() + Errors = 0 + } +} + +Write-Host "=" * 70 -ForegroundColor Cyan +Write-Host "QWEN3-1.7B QUANTIZATION SPEED BENCHMARK" -ForegroundColor Cyan +Write-Host "=" * 70 -ForegroundColor Cyan +Write-Host "" +Write-Host "Configuration:" -ForegroundColor Yellow +Write-Host " Iterations per model: $Iterations" +Write-Host " Threads: $Threads" +Write-Host " Repeats per run: $Repeats" +Write-Host " Generate tokens: $GenerateTokens" +Write-Host " Models: $($Models.Count)" +Write-Host "" + +$StartTime = Get-Date +$TotalRuns = $Iterations * $Models.Count + +Write-Host "Starting benchmark at $($StartTime.ToString('HH:mm:ss'))..." -ForegroundColor Green +Write-Host "Total runs: $TotalRuns (estimated time: $([math]::Round($TotalRuns * 5 / 60, 1)) minutes)" -ForegroundColor Gray +Write-Host "" + +# Progress tracking +$CurrentRun = 0 + +for ($i = 1; $i -le $Iterations; $i++) { + foreach ($model in $Models) { + $CurrentRun++ + $PercentComplete = [math]::Round(($CurrentRun / $TotalRuns) * 100, 1) + + # Progress bar + Write-Progress -Activity "Benchmarking $($model.Name)" ` + -Status "Iteration $i/$Iterations - Overall: $PercentComplete%" ` + -PercentComplete $PercentComplete + + try { + # Run benchmark + $output = & $LlamaBench -m $model.Path -t $Threads -r $Repeats -p $PromptTokens -n $GenerateTokens 2>&1 + $outputText = $output -join "`n" + + # Parse output - look for tg (token generation) speed + # Format: | model | size | params | backend | threads | test | t/s | + # Example: | qwen3 1.7B Q3_K - Small | 948.91 MiB | 2.03 B | CPU | 4 | tg20 | 28.87 ± 1.45 | + $found = $false + foreach ($line in $output) { + $lineStr = $line.ToString() + # Match pattern: anything with tg followed by speed ± stddev + if ($lineStr -match "tg\d+\s*\|\s*([\d.]+)\s*±\s*([\d.]+)") { + $speed = [double]$Matches[1] + [void]$Results[$model.Name].Speeds.Add($speed) + $found = $true + break + } + # Alternative pattern: just numbers at end of line + elseif ($lineStr -match "\|\s*tg\d+\s*\|\s*([\d.]+)") { + $speed = [double]$Matches[1] + [void]$Results[$model.Name].Speeds.Add($speed) + $found = $true + break + } + } + + if (-not $found) { + # Debug: show what we got if parsing failed + if ($i -eq 1) { + Write-Host " Debug - Raw output sample for $($model.Name):" -ForegroundColor DarkGray + $output | Select-Object -First 10 | ForEach-Object { Write-Host " $_" -ForegroundColor DarkGray } + } + 
$Results[$model.Name].Errors++ + } + } + catch { + $Results[$model.Name].Errors++ + Write-Warning "Error on $($model.Name) iteration $i : $_" + } + } + + # Periodic status update every 10 iterations + if ($i % 10 -eq 0) { + $Elapsed = (Get-Date) - $StartTime + $EstRemaining = [TimeSpan]::FromSeconds(($Elapsed.TotalSeconds / $CurrentRun) * ($TotalRuns - $CurrentRun)) + Write-Host " [$i/$Iterations] Elapsed: $($Elapsed.ToString('hh\:mm\:ss')) | ETA: $($EstRemaining.ToString('hh\:mm\:ss'))" -ForegroundColor Gray + } +} + +Write-Progress -Activity "Complete" -Completed + +$EndTime = Get-Date +$Duration = $EndTime - $StartTime + +# Calculate statistics +function Get-Stats { + param([System.Collections.ArrayList]$Data) + + if ($Data.Count -eq 0) { + return @{ Mean = 0; StdDev = 0; Min = 0; Max = 0; Median = 0; Count = 0 } + } + + $sorted = $Data | Sort-Object + $mean = ($Data | Measure-Object -Average).Average + $min = ($Data | Measure-Object -Minimum).Minimum + $max = ($Data | Measure-Object -Maximum).Maximum + $count = $Data.Count + + # Median + $midIndex = [math]::Floor($count / 2) + if ($count % 2 -eq 0) { + $median = ($sorted[$midIndex - 1] + $sorted[$midIndex]) / 2 + } else { + $median = $sorted[$midIndex] + } + + # Standard deviation + $sumSquares = 0 + foreach ($val in $Data) { + $sumSquares += [math]::Pow($val - $mean, 2) + } + $stdDev = [math]::Sqrt($sumSquares / $count) + + # 95th percentile + $p95Index = [math]::Floor($count * 0.95) + $p95 = $sorted[[math]::Min($p95Index, $count - 1)] + + # 5th percentile + $p5Index = [math]::Floor($count * 0.05) + $p5 = $sorted[$p5Index] + + return @{ + Mean = $mean + StdDev = $stdDev + Min = $min + Max = $max + Median = $median + P5 = $p5 + P95 = $p95 + Count = $count + } +} + +# Generate report +Write-Host "" +Write-Host "=" * 70 -ForegroundColor Cyan +Write-Host "BENCHMARK RESULTS" -ForegroundColor Cyan +Write-Host "=" * 70 -ForegroundColor Cyan +Write-Host "" +Write-Host "Test completed in: $($Duration.ToString('hh\:mm\:ss'))" -ForegroundColor Green +Write-Host "Total iterations per model: $Iterations" +Write-Host "" + +# Collect all stats +$AllStats = @{} +foreach ($model in $Models) { + $AllStats[$model.Name] = Get-Stats -Data $Results[$model.Name].Speeds +} + +# Find the fastest model for comparison +$FastestMean = ($AllStats.Values | ForEach-Object { $_.Mean } | Measure-Object -Maximum).Maximum + +# Detailed results table +Write-Host "SPEED COMPARISON (tokens/second - higher is better)" -ForegroundColor Yellow +Write-Host "-" * 70 + +$TableHeader = "{0,-15} {1,10} {2,10} {3,10} {4,10} {5,10} {6,10}" -f "Model", "Mean", "StdDev", "Median", "Min", "Max", "vs Best" +Write-Host $TableHeader -ForegroundColor White +Write-Host "-" * 70 + +foreach ($model in $Models) { + $stats = $AllStats[$model.Name] + $vsBest = if ($stats.Mean -eq $FastestMean) { "FASTEST" } else { + "-" + [math]::Round((1 - $stats.Mean / $FastestMean) * 100, 1) + "%" + } + + $row = "{0,-15} {1,10:F2} {2,10:F2} {3,10:F2} {4,10:F2} {5,10:F2} {6,10}" -f ` + $model.Name, $stats.Mean, $stats.StdDev, $stats.Median, $stats.Min, $stats.Max, $vsBest + + if ($stats.Mean -eq $FastestMean) { + Write-Host $row -ForegroundColor Green + } else { + Write-Host $row + } +} + +Write-Host "-" * 70 +Write-Host "" + +# Percentile analysis +Write-Host "PERCENTILE ANALYSIS" -ForegroundColor Yellow +Write-Host "-" * 70 +$PercHeader = "{0,-15} {1,12} {2,12} {3,12} {4,10}" -f "Model", "5th %ile", "Median", "95th %ile", "Samples" +Write-Host $PercHeader -ForegroundColor White +Write-Host "-" * 70 + 
+foreach ($model in $Models) { + $stats = $AllStats[$model.Name] + $errors = $Results[$model.Name].Errors + $row = "{0,-15} {1,12:F2} {2,12:F2} {3,12:F2} {4,10}" -f ` + $model.Name, $stats.P5, $stats.Median, $stats.P95, "$($stats.Count)/$Iterations" + Write-Host $row +} + +Write-Host "-" * 70 +Write-Host "" + +# Speed ranking summary +Write-Host "SPEED RANKING SUMMARY" -ForegroundColor Yellow +Write-Host "-" * 70 + +$Ranked = @($AllStats.GetEnumerator() | Sort-Object { $_.Value.Mean } -Descending) +$Rank = 1 +$FirstMean = if ($Ranked.Count -gt 0 -and $Ranked[0].Value.Mean -gt 0) { $Ranked[0].Value.Mean } else { 1 } + +foreach ($entry in $Ranked) { + $speedDiff = "" + if ($Rank -gt 1 -and $FirstMean -gt 0 -and $entry.Value.Mean -gt 0) { + $diffFromFirst = $FirstMean - $entry.Value.Mean + $diffPercent = ($diffFromFirst / $FirstMean) * 100 + $speedDiff = "($([math]::Round($diffFromFirst, 2)) t/s slower, -$([math]::Round($diffPercent, 1))%)" + } + + $medal = switch ($Rank) { 1 { "🥇" } 2 { "🥈" } 3 { "🥉" } default { " " } } + Write-Host "$medal #$Rank $($entry.Key): $([math]::Round($entry.Value.Mean, 2)) ± $([math]::Round($entry.Value.StdDev, 2)) t/s $speedDiff" + $Rank++ +} + +Write-Host "" +Write-Host "=" * 70 -ForegroundColor Cyan + +# Export results to CSV +$CsvPath = "benchmark_results_$(Get-Date -Format 'yyyyMMdd_HHmmss').csv" +$CsvData = @() +foreach ($model in $Models) { + $stats = $AllStats[$model.Name] + $CsvData += [PSCustomObject]@{ + Model = $model.Name + Mean_TPS = [math]::Round($stats.Mean, 4) + StdDev = [math]::Round($stats.StdDev, 4) + Median = [math]::Round($stats.Median, 4) + Min = [math]::Round($stats.Min, 4) + Max = [math]::Round($stats.Max, 4) + P5 = [math]::Round($stats.P5, 4) + P95 = [math]::Round($stats.P95, 4) + Samples = $stats.Count + Errors = $Results[$model.Name].Errors + } +} +$CsvData | Export-Csv -Path $CsvPath -NoTypeInformation +Write-Host "Results exported to: $CsvPath" -ForegroundColor Green + +# Also save raw data for further analysis +$RawDataPath = "benchmark_raw_$(Get-Date -Format 'yyyyMMdd_HHmmss').json" +$RawExport = @{} +foreach ($model in $Models) { + $RawExport[$model.Name] = $Results[$model.Name].Speeds +} +$RawExport | ConvertTo-Json | Out-File -FilePath $RawDataPath +Write-Host "Raw data exported to: $RawDataPath" -ForegroundColor Green diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 151608d56b8..10fe05ec8fc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10422,8 +10422,8 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "q3_hifi", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, q3_hifi for Q3_HIFI (3-bit with outliers), and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( "--bigendian", action="store_true", @@ -10587,6 +10587,7 @@ def main() -> None: "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, + "q3_hifi": gguf.LlamaFileType.MOSTLY_Q3_HIFI, "auto": gguf.LlamaFileType.GUESSED, } diff --git a/docs/quantization/Q3_HIFI.md b/docs/quantization/Q3_HIFI.md new file mode 100644 index 00000000000..8b7a2ee489f --- /dev/null +++ b/docs/quantization/Q3_HIFI.md @@ -0,0 +1,241 @@ +# Qwen3 Q3_HIFI Quantization: Cross-Model Analysis & Summary + +## Executive Summary + +This document analyzes Q3_HIFI quantization performance across all Qwen3 model sizes (0.6B to 32B parameters), comparing it against traditional Q3_K_M and Q3_K_S methods. **Q3_HIFI consistently delivers superior quality with smaller file sizes than Q3_K_M**, and at larger model scales (14B+), it even achieves faster inference speeds. + +--- + +## Complete Performance Data + +### All Models Comparison Table + +| Model | Quant | Speed (TPS) | Perplexity | File Size | Bits/Weight | +|----------|---------|-------------|------------|----------------|-------------| +| **0.6B** | Q3_HIFI | 601.39 | **26.43** | 382.37 MiB | 4.27 | +| | Q3_K_M | **618.42** | 31.64 | 389.12 MiB | 4.34 | +| | Q3_K_S | 612.28 | 35.70 | **366.19 MiB** | 4.09 | +| **1.7B** | Q3_HIFI | 411.11 | **17.65** | 993.5 MiB | 4.10 | +| | Q3_K_M | 416.70 | 22.44 | 1017.9 MiB | 4.20 | +| | Q3_K_S | **425.64** | 24.07 | **948.9 MiB** | 3.92 | +| **4B** | Q3_HIFI | 215.13 | **16.76** | 1.87 GiB | 3.99 | +| | Q3_K_M | 217.49 | 18.07 | 1.93 GiB | 4.12 | +| | Q3_K_S | **227.70** | 19.08 | **1.75 GiB** | 3.74 | +| **8B** | Q3_HIFI | 143.98 | **10.56** | 3.72 GiB | 3.90 | +| | Q3_K_M | 144.72 | 11.05 | 3.84 GiB | 4.02 | +| | Q3_K_S | **153.74** | 11.38 | **3.51 GiB** | 3.68 | +| **14B** | Q3_HIFI | 85.58 | **9.38** | 6.59 GiB | 3.83 | +| | Q3_K_M | 85.40 | 9.53 | 6.81 GiB | 3.96 | +| | Q3_K_S | **91.52** | 9.71 | **6.19 GiB** | 3.60 | +| **32B** | Q3_HIFI | 39.84 | **8.30** | 14.32 GiB | 3.76 | +| | Q3_K_M | 39.55 | 8.47 | 14.87 GiB | 3.90 | +| | Q3_K_S | **42.95** | ⚠️ 20.19 | **13.40 GiB** | 3.51 | + +### Q3_HIFI Improvement vs Q3_K_M (by Model Size) + +| Model | Perplexity Gain | Size Reduction | Speed Difference | +|-------|-----------------|----------------|--------------------| +| 0.6B | **-16.4%** ✨ | -1.7% | -2.8% (slower) | +| 1.7B | **-21.4%** ✨ | -2.4% | -1.3% (slower) | +| 4B | **-7.3%** | -3.1% | -1.1% (slower) | +| 8B | **-4.4%** | -3.1% | -0.5% (slower) | +| 14B | **-1.6%** | -3.2% | **+0.2% (faster)** | +| 32B | **-2.0%** | -3.7% | **+0.7% (faster)** | + +### Q3_HIFI Improvement vs Q3_K_S (by Model Size) + +| Model | Perplexity Gain | Size Increase | Speed Difference | 
+|-------|-----------------|---------------|------------------| +| 0.6B | **-26.0%** ✨ | +4.4% | -1.8% (slower) | +| 1.7B | **-26.7%** ✨ | +4.7% | -3.4% (slower) | +| 4B | **-12.2%** | +6.9% | -5.5% (slower) | +| 8B | **-7.2%** | +6.0% | -6.3% (slower) | +| 14B | **-3.4%** | +6.5% | -6.5% (slower) | +| 32B | **-58.9%** 🚨 | +6.9% | -7.2% (slower) | + +--- + +## Trend Analysis + +### 1. Perplexity Improvements + +**Key Finding:** Q3_HIFI quality gains are **most dramatic on smaller models** and remain significant across all sizes. + +``` +Perplexity Improvement (Q3_HIFI vs Q3_K_M) +═══════════════════════════════════════════════════════ +0.6B ████████████████████████████████████ -16.4% +1.7B ██████████████████████████████████████████ -21.4% +4B ██████████████████ -7.3% +8B ███████████ -4.4% +14B ████ -1.6% +32B █████ -2.0% +``` + +**Interpretation:** +- Smaller models (0.6B–1.7B) see **16–21% perplexity improvements** — Q3_HIFI's intelligent layer-sensitive quantization preserves critical weights where every parameter matters +- Mid-size models (4B–8B) achieve **4–7% improvements** — a meaningful quality boost +- Large models (14B–32B) see **1.6–2% improvements** — still valuable at scale where absolute perplexity is already low + +### 2. Speed Performance + +**Key Finding:** Q3_HIFI speed penalty **decreases with model size** and reverses to a **speed advantage at 14B+**. + +| Model Size | Q3_HIFI vs Q3_K_M | Q3_HIFI vs Q3_K_S | +|------------|-------------------|-------------------| +| 0.6B | -2.8% slower | -1.8% slower | +| 1.7B | -1.3% slower | -3.4% slower | +| 4B | -1.1% slower | -5.5% slower | +| 8B | -0.5% slower | -6.3% slower | +| 14B | **+0.2% faster** | -6.5% slower | +| 32B | **+0.7% faster** | -7.2% slower | + +**Interpretation:** +- At smaller scales, Q3_HIFI's adaptive quantization adds minor overhead +- At larger scales (14B+), Q3_HIFI's smaller size improves memory bandwidth efficiency, resulting in **faster inference than Q3_K_M** +- Q3_K_S maintains a consistent ~6-7% speed advantage due to its uniform, simpler quantization + +### 3. File Size Efficiency + +**Key Finding:** Q3_HIFI is **always smaller than Q3_K_M** while delivering better quality. + +| Model | Q3_HIFI | Q3_K_M | Q3_K_S | HIFI vs K_M | +|-------|-----------|-----------|-----------|-------------| +| 0.6B | 382 MiB | 389 MiB | 366 MiB | **-1.7%** | +| 1.7B | 994 MiB | 1018 MiB | 949 MiB | **-2.4%** | +| 4B | 1.87 GiB | 1.93 GiB | 1.75 GiB | **-3.1%** | +| 8B | 3.72 GiB | 3.84 GiB | 3.51 GiB | **-3.1%** | +| 14B | 6.59 GiB | 6.81 GiB | 6.19 GiB | **-3.2%** | +| 32B | 14.32 GiB | 14.87 GiB | 13.40 GiB | **-3.7%** | + +**Interpretation:** +- Q3_HIFI's intelligent bit allocation results in **3-4% smaller files than Q3_K_M** +- The size savings increase slightly at larger model scales (3.7% at 32B vs 1.7% at 0.6B) +- Q3_K_S remains ~6-7% smaller than Q3_HIFI but with significant quality tradeoffs + +### 4. 
Bits Per Weight Trend + +| Model | Q3_HIFI | Q3_K_M | Q3_K_S | +|-------|---------|--------|--------| +| 0.6B | 4.27 | 4.34 | 4.09 | +| 1.7B | 4.10 | 4.20 | 3.92 | +| 4B | 3.99 | 4.12 | 3.74 | +| 8B | 3.90 | 4.02 | 3.68 | +| 14B | 3.83 | 3.96 | 3.60 | +| 32B | 3.76 | 3.90 | 3.51 | + +**Interpretation:** +- Bits per weight decreases across all methods as model size increases (larger models compress more efficiently) +- Q3_HIFI sits between Q3_K_M and Q3_K_S, using its bits more intelligently on sensitive layers + +--- + +## Critical Warning: Q3_K_S at 32B Scale + +⚠️ **Q3_K_S suffers catastrophic quality degradation at 32B scale:** + +| Metric | Q3_HIFI | Q3_K_S | Degradation | +|------------|---------|--------|-------------| +| Perplexity | 8.30 | 20.19 | **+143%** | + +While Q3_K_S quality degradation is generally acceptable at smaller scales (7-27% worse than Q3_HIFI), the **32B model experiences catastrophic failure** with perplexity more than doubling. This suggests that uniform q3_K quantization cannot adequately preserve the critical weights in large, complex models. + +**Recommendation:** Avoid Q3_K_S for 32B deployments unless quality is truly irrelevant. + +--- + +## Model-Specific Recommendations + +### Best Use Cases by Model Size + +| Model | Best For | Recommended Quant | Rationale | +|----------|------------------------------------|-------------------|-----------------------------------------------------------------------| +| **0.6B** | Edge devices, IoT, mobile | **Q3_HIFI** | 26% quality gain worth the minimal speed/size tradeoff | +| **1.7B** | Embedded systems, real-time apps | **Q3_HIFI** | Dramatic 21-27% quality improvement; speed still excellent at 411 TPS | +| **4B** | Desktop inference, general-purpose | **Q3_HIFI** | Best balance of quality and efficiency | +| **8B** | Production workloads, API serving | **Q3_HIFI** | Quality-critical tasks with near-zero speed penalty (0.5%) | +| **14B** | Enterprise deployment | **Q3_HIFI** | Beats Q3_K_M on ALL metrics (quality, size, AND speed) | +| **32B** | High-accuracy applications | **Q3_HIFI** | Only viable option — Q3_K_S quality is unacceptable | + +### Decision Matrix + +| Your Priority | Small Models (≤4B) | Medium Models (8B) | Large Models (14B+) | +|-------------------|-----------------------------|--------------------|-----------------------| +| **Quality First** | Q3_HIFI | Q3_HIFI | Q3_HIFI | +| **Speed First** | Q3_K_S (or Q3_K_M for 0.6B) | Q3_K_S | Q3_K_S (avoid at 32B) | +| **Size First** | Q3_K_S | Q3_K_S | Q3_K_S (avoid at 32B) | +| **Best Balance** | Q3_HIFI | Q3_HIFI | Q3_HIFI | + +--- + +## Key Insights + +### 1. Q3_K_M Is Obsolete + +Q3_HIFI **dominates Q3_K_M in every comparison**: +- ✅ Better quality (1.6–21.4% lower perplexity) +- ✅ Smaller size (1.7–3.7% reduction) +- ✅ Comparable or faster speed (especially at 14B+) + +There is **no scenario where Q3_K_M is the optimal choice** unless legacy compatibility is required. + +### 2. Q3_HIFI Shines on Smaller Models + +The importance-matrix-guided quantization is **most effective where every parameter matters**: +- 0.6B: 16.4% quality improvement +- 1.7B: 21.4% quality improvement + +For resource-constrained deployments of small models, Q3_HIFI is transformative. + +### 3. Large Model Sweet Spot + +At 14B and 32B scales, Q3_HIFI achieves the rare combination of: +- Better quality +- Smaller size +- **Faster inference** + +This makes Q3_HIFI the unambiguous choice for large model deployments. + +### 4. 
Q3_K_S Has a Narrow Use Case
+
+Q3_K_S remains viable only when:
+- Speed is the absolute priority AND
+- Quality degradation is acceptable AND
+- Model size is ≤14B (32B quality is catastrophic)
+
+For most production use cases, the 6-7% speed advantage doesn't justify the quality loss.
+
+---
+
+## Summary Table: Q3_HIFI Value Proposition
+
+| Model | Quality Gain vs K_M | Quality Gain vs K_S | Speed vs K_M | Size vs K_M |
+|-------|---------------------|---------------------|--------------|-------------|
+| 0.6B  | +16.4%              | +26.0%              | -2.8%        | -1.7%       |
+| 1.7B  | +21.4%              | +26.7%              | -1.3%        | -2.4%       |
+| 4B    | +7.3%               | +12.2%              | -1.1%        | -3.1%       |
+| 8B    | +4.4%               | +7.2%               | -0.5%        | -3.1%       |
+| 14B   | +1.6%               | +3.4%               | **+0.2%**    | -3.2%       |
+| 32B   | +2.0%               | +58.9%              | **+0.7%**    | -3.7%       |
+
+---
+
+## Conclusion
+
+**Q3_HIFI is the recommended default quantization** for Qwen3 models across all sizes. It achieves better quality than Q3_K_M while being smaller and (at larger scales) faster. The only remaining tradeoff is between Q3_HIFI (maximum quality) and Q3_K_S (maximum speed), and even this tradeoff breaks down at 32B scale, where Q3_K_S quality becomes unacceptable.
+
+For production deployments prioritizing output quality, accuracy, or reliability, **Q3_HIFI should be the standard choice**.
+
+---
+
+## Appendix: Test Environment
+
+| Component     | Specification                   |
+|---------------|---------------------------------|
+| **OS**        | Ubuntu 24.04.3 LTS              |
+| **CPU**       | AMD EPYC 9254 24-Core Processor |
+| **CPU Cores** | 96 cores (2 threads/core)       |
+| **RAM**       | 1.0 TiB                         |
+| **GPU**       | NVIDIA L40S × 2                 |
+| **VRAM**      | 46068 MiB per GPU               |
+| **CUDA**      | 12.9                            |
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 686da3dbd10..c138336ca65 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -376,6 +376,9 @@ extern "C" {
     GGML_API void ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t);
     GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
 
+    // Q3_HIFI block structure is defined in ggml-common.h for GPU backend compatibility
+    // Uses Q3_K-compatible layout with 8 FP16 outliers for improved accuracy
+
     struct ggml_object;
     struct ggml_context;
     struct ggml_cgraph;
@@ -422,7 +425,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
         GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
-        GGML_TYPE_COUNT   = 40,
+        GGML_TYPE_Q3_HIFI = 40, // Q3_HIFI: Q3_K layout + 8 FP16 outliers per block
+        GGML_TYPE_COUNT   = 41,
     };
 
     // precision
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 93ab7ea446e..b1a341d9505 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -288,6 +288,23 @@ typedef struct {
 } block_q3_K;
 static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers for improved accuracy
+// Uses EXACT Q3_K memory layout (first 110 bytes) to reuse optimized kernels
+// Outliers appended as tail section - achieves ~98% of Q3_K speed with better quality
+#define Q3_HIFI_BLOCK_SIZE 256
+#define Q3_HIFI_OUTLIERS 8
+typedef struct {
+    // === Q3_K-COMPATIBLE REGION (110 bytes) - DO NOT REORDER ===
+    uint8_t hmask[QK_K/8]; // 32 bytes: high bit mask
+    uint8_t qs[QK_K/4];    // 64 bytes: low 2 bits
+    uint8_t scales[12];    // 12 bytes: 16 sub-group scales (6-bit each)
+    ggml_half d;           // 2 bytes: super-block scale
+    // === OUTLIER EXTENSION (24 bytes) ===
+    uint8_t outlier_idx[Q3_HIFI_OUTLIERS];    // 8 bytes: outlier positions (0-255)
+    ggml_half outlier_vals[Q3_HIFI_OUTLIERS]; // 16 bytes: FP16 outlier values
+} block_q3_hifi;
+static_assert(sizeof(block_q3_hifi) == sizeof(block_q3_K) + Q3_HIFI_OUTLIERS + Q3_HIFI_OUTLIERS*sizeof(ggml_half), "wrong q3_hifi block size/padding");
+
 // 4-bit quantization
 // 8 blocks of 32 elements each
 // weight is represented as x = a * q + b
diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c
index b390ab61c78..bf8a3493e0a 100644
--- a/ggml/src/ggml-cpu/arch/arm/quants.c
+++ b/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -2044,6 +2044,148 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 }
 
+// Q3_HIFI: ARM NEON optimized vec_dot
+// Copied from Q3_K and adapted for block_q3_hifi (134-byte blocks) + outlier correction
+void ggml_vec_dot_q3_hifi_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    // CRITICAL: Use block_q3_hifi for the correct 134-byte stride
+    const block_q3_hifi * GGML_RESTRICT x = (const block_q3_hifi *)vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    const uint8x16_t m0 = vdupq_n_u8(1);
+    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+    const int8_t m32 = 32;
+
+    ggml_int8x16x4_t q3bytes;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].hmask;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+        ggml_uint8x16x4_t q3h;
+
+        int32_t isum = 0;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+            isum += 
vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+
+            scale += 4;
+
+            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+
+            scale += 4;
+
+            if (j == 0) {
+                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+            }
+
+        }
+        sum += d * isum;
+
+    }
+
+    // Q3_HIFI: Add outlier corrections - fully unrolled for 8 outliers
+    for (int i = 0; i < nb; ++i) {
+        const float d_y = y[i].d;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT idx = x[i].outlier_idx;
+        const ggml_fp16_t * GGML_RESTRICT vals = x[i].outlier_vals;
+
+        // Unrolled: process all 8 outliers
+        sum += GGML_FP16_TO_FP32(vals[0]) * q8[idx[0]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[1]) * q8[idx[1]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[2]) * q8[idx[2]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[3]) * q8[idx[3]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[4]) * q8[idx[4]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[5]) * q8[idx[5]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[6]) * q8[idx[6]] * d_y;
+        sum += GGML_FP16_TO_FP32(vals[7]) * q8[idx[7]] * d_y;
+    }
+
+    *s = sum;
+
+#else
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_hifi_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+
+}
+
 #ifdef __ARM_FEATURE_SVE
 static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
     const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
@@ -4050,3 +4192,67 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
 
+#if defined(__ARM_NEON)
+// NEON-optimized dequantization for Q3_HIFI
+void dequantize_row_q3_hifi(const block_q3_hifi * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    assert(k % Q3_HIFI_BLOCK_SIZE == 0);
+    const int64_t nb = k / Q3_HIFI_BLOCK_SIZE;
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const block_q3_hifi * block = &x[ib];
+        const float d = GGML_FP16_TO_FP32(block->d);
+        const uint8_t * qs = block->qs;
+        float * yb = y + ib * Q3_HIFI_BLOCK_SIZE;
+
+        // Process 4 values at a time with NEON
+        // Q3_HIFI_BLOCK_SIZE is 256, which is a multiple of 4
+        int i = 0;
+        for (; i < Q3_HIFI_BLOCK_SIZE - 3; i += 4) {
+            // Extract 4 3-bit values (12 bits = 1.5 bytes)
+            int32_t quant_vals[4];
+
+            for (int j = 0; j < 4; ++j) {
+                const int byte_idx = ((i + j) * 3) / 8;
+                const int bit_offset = ((i + j) * 3) % 8;
+                uint8_t bits = (qs[byte_idx] >> bit_offset) & 7;
+                if (bit_offset > 5 && byte_idx + 1 < 96) {
+                    bits |= (qs[byte_idx + 1] << (8 - bit_offset)) & 7;
+                }
+                quant_vals[j] = (int32_t)bits - 4; // [0,7] → [-4,3]
+            }
+
+            // Load into NEON register
+            int32x4_t quant_vec = vld1q_s32(quant_vals);
+
+            // Convert to float
+            float32x4_t quant_f = vcvtq_f32_s32(quant_vec);
+
+            // Multiply by scale
+            float32x4_t scale_vec = vdupq_n_f32(d);
+            quant_f = vmulq_f32(quant_f, scale_vec);
+
+            // Store
+            vst1q_f32(&yb[i], quant_f);
+        }
+
+        // Handle remaining values (scalar fallback)
+        for (; i < Q3_HIFI_BLOCK_SIZE; ++i) {
+            const int byte_idx = (i * 3) / 8;
+            const int bit_offset = (i * 3) % 8;
+            uint8_t bits = (qs[byte_idx] >> bit_offset) & 7;
+            if (bit_offset > 5 && byte_idx + 1 < 96) {
+                bits |= (qs[byte_idx + 1] << (8 - bit_offset)) & 7;
+            }
+            const int quant_val = (int)bits - 4;
+            yb[i] = quant_val * d;
+        }
+
+        // Restore outliers (still sequential, but less overhead)
+        for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) {
+            const int idx = block->outlier_idx[k_idx];
+            yb[idx] = GGML_FP16_TO_FP32(block->outlier_vals[k_idx]);
+        }
+    }
+}
+#endif
+
diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c
index cb49320a67f..27d6214916d 100644
--- a/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -2331,6 +2331,159 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
+// Q3_HIFI vec_dot - AVX2 optimized implementation
+// Copied from Q3_K AVX2 kernel and adapted for block_q3_hifi + outlier correction
+void ggml_vec_dot_q3_hifi_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    // CRITICAL: Use block_q3_hifi instead of block_q3_K for the correct stride (134 bytes vs 110 bytes)
+    const block_q3_hifi * GGML_RESTRICT x = (const block_q3_hifi *)vx;
+    const block_q8_K * GGML_RESTRICT y = (const block_q8_K *)vy;
+
+    const int nb = n / QK_K;
+
+#if defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i mone = _mm256_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+
+        // Note: Q3_K uses qs for low 2 bits - same field name and layout in our struct
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        // Set up scales - identical to Q3_K
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = _mm_set_epi32(
+            ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+            ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+            (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+            (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        // high bit - identical to Q3_K
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+        // 
integer accumulator
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+        int is  = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+            // prepare low and high bits
+            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+            // and 2 if the high bit was set)
+            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+        }
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    float sumf = hsum_float_8(acc);
+
+    // Q3_HIFI: Add outlier corrections
+    // Fully unrolled loop for 8 outliers - eliminates loop overhead
+    // Note: We tried branchless masking but the computation cost outweighs
+    // any branch misprediction savings for only 8 outliers per block. 
+ for (int i = 0; i < nb; ++i) { + const float d_y = y[i].d; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT idx = x[i].outlier_idx; + const ggml_fp16_t * GGML_RESTRICT vals = x[i].outlier_vals; + + // Unrolled: process all 8 outliers without loop overhead + // Using FMA-friendly pattern: accumulate (w * a) * d_y + sumf += GGML_FP16_TO_FP32(vals[0]) * (float)q8[idx[0]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[1]) * (float)q8[idx[1]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[2]) * (float)q8[idx[2]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[3]) * (float)q8[idx[3]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[4]) * (float)q8[idx[4]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[5]) * (float)q8[idx[5]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[6]) * (float)q8[idx[6]] * d_y; + sumf += GGML_FP16_TO_FP32(vals[7]) * (float)q8[idx[7]] * d_y; + } + + *s = sumf; + +#else + // Fallback to generic implementation for non-AVX2 + ggml_vec_dot_q3_hifi_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + #if defined (__AVX__) || defined (__AVX2__) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, @@ -3818,3 +3971,5 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +// Note: dequantize_row_q3_hifi is defined in ggml-quants.c using Q3_K's dequantize + diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index a59b5189389..1abd9d8a96a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -279,6 +279,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_Q3_HIFI] = { + .from_float = quantize_row_q3_hifi, + .vec_dot = ggml_vec_dot_q3_hifi_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_Q4_K] = { .from_float = quantize_row_q4_K, .vec_dot = ggml_vec_dot_q4_K_q8_K, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3032783971d..3546fc3acd0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -672,6 +672,7 @@ void ggml_compute_forward_add( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -1121,6 +1122,7 @@ void ggml_compute_forward_add1( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -1249,6 +1251,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -4272,6 +4275,7 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -4547,6 +4551,7 @@ void ggml_compute_forward_set( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -4769,6 +4774,7 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -5493,6 +5499,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case 
GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 365cb36d2d7..76bd2f2dca4 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -66,6 +66,12 @@ void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i quantize_row_q3_K_ref(x, vy, k); } +void quantize_row_q3_hifi(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % Q3_HIFI_BLOCK_SIZE == 0); + block_q3_hifi * GGML_RESTRICT y = vy; + quantize_row_q3_hifi_ref(x, y, k); +} + // ====================== 4-bit (de)-quantization void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { @@ -547,6 +553,96 @@ void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, c *s = sumf; } +// Q3_HIFI vec_dot: Generic implementation +// Uses Q3_K format for bulk, adds outlier corrections +void ggml_vec_dot_q3_hifi_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % Q3_HIFI_BLOCK_SIZE == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q3_hifi * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / Q3_HIFI_BLOCK_SIZE; + + static const uint32_t kmask1 = 0x03030303; + static const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + float total_sum = 0.0f; + + for (int i = 0; i < nb; ++i) { + const block_q3_hifi * xb = &x[i]; + const block_q8_K * yb = &y[i]; + + const float d = GGML_FP16_TO_FP32(xb->d) * yb->d; + + const uint8_t * GGML_RESTRICT q = xb->qs; + const uint8_t * GGML_RESTRICT hm = xb->hmask; + const int8_t * GGML_RESTRICT q8 = yb->qs; + uint8_t m = 1; + + // Decode scales (same as Q3_K) + memcpy(aux, xb->scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int32_t sumi = 0; + int is = 0; + + for (int l = 0; l < QK_K; l += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + int32_t sum1 = 0, sum2 = 0; + const int8_t scale1 = scales[is++] - 32; + const int8_t scale2 = scales[is++] - 32; + + for (int k = 0; k < 16; ++k) { + int8_t q3val = (int8_t)((q[k] >> shift) & 3) - ((hm[k] & m) ? 0 : 4); + sum1 += q3val * q8[k]; + } + for (int k = 0; k < 16; ++k) { + int8_t q3val = (int8_t)((q[k+16] >> shift) & 3) - ((hm[k+16] & m) ? 
0 : 4); + sum2 += q3val * q8[k+16]; + } + + sumi += scale1 * sum1 + scale2 * sum2; + q8 += 32; + shift += 2; + m <<= 1; + } + q += 32; + } + + total_sum += d * (float)sumi; + + // Add outlier corrections - fully unrolled for 8 outliers + const float yd = yb->d; + const uint8_t * GGML_RESTRICT o_idx = xb->outlier_idx; + const ggml_fp16_t * GGML_RESTRICT o_vals = xb->outlier_vals; + + total_sum += GGML_FP16_TO_FP32(o_vals[0]) * yb->qs[o_idx[0]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[1]) * yb->qs[o_idx[1]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[2]) * yb->qs[o_idx[2]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[3]) * yb->qs[o_idx[3]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[4]) * yb->qs[o_idx[4]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[5]) * yb->qs[o_idx[5]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[6]) * yb->qs[o_idx[6]] * yd; + total_sum += GGML_FP16_TO_FP32(o_vals[7]) * yb->qs[o_idx[7]] * yd; + } + + *s = total_sum; +} + +// Note: ggml_vec_dot_q3_hifi_q8_K is defined in arch-specific files (x86/quants.c etc.) + void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index d83eb1b144d..543f8556387 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -23,6 +23,8 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_hifi(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_hifi(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -45,6 +47,8 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q3_hifi_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q3_hifi_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -79,6 +83,8 @@ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void 
diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h
index d83eb1b144d..543f8556387 100644
--- a/ggml/src/ggml-cpu/quants.h
+++ b/ggml/src/ggml-cpu/quants.h
@@ -23,6 +23,7 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_hifi(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -45,6 +46,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_hifi_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -79,6 +81,7 @@ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_hifi_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9fcb2f9fd21..8e3efe53dae 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -825,6 +825,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
     static constexpr int qi = QI3_K;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_HIFI> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_K;
+    static constexpr int qi = QI3_K;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
     static constexpr int qk = QK_K;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index ba3d4eeb880..e3de6aaa789 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -518,6 +518,60 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers per block
+// Uses Q3_K dequantization for bulk, then overwrites outlier positions
+template <typename dst_t>
+static __global__ void dequantize_block_q3_hifi(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const int64_t i = blockIdx.x;
+    const block_q3_hifi * x = (const block_q3_hifi *) vx;
+
+    // First, do Q3_K-style dequantization for the bulk
+    const int64_t r = threadIdx.x/4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;
+
+    uint8_t m = 1 << (4*n + j);
+    int64_t is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = __half2float(x[i].d);
+    float dl = d_all * (us - 32);
+
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    for (int l = l0; l < l0+4; ++l) {
+        y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+    }
+
+    // Synchronize before overwriting outliers
+    __syncthreads();
+
+    // Thread 0 handles outlier restoration
+    if (threadIdx.x == 0) {
+        dst_t * yb = yy + i*QK_K;
+        #pragma unroll
+        for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+            const int idx = x[i].outlier_idx[k];
+            yb[idx] = __half2float(x[i].outlier_vals[k]);
+        }
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_q3_hifi_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_hifi<<<nb, 64, 0, stream>>>(vx, y);
+}
+
 template <typename dst_t>
 static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
@@ -675,6 +729,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q3_HIFI:
+            return dequantize_row_q3_hifi_cuda;
         case GGML_TYPE_Q4_K:
             return dequantize_row_q4_K_cuda;
         case GGML_TYPE_Q5_K:
@@ -726,6 +782,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q3_HIFI:
+            return dequantize_row_q3_hifi_cuda;
         case GGML_TYPE_Q4_K:
             return dequantize_row_q4_K_cuda;
         case GGML_TYPE_Q5_K:
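Two spellings of the same 3-bit decode appear across the backends: the Q3_K-style `(q & 3) - (hmask bit ? 0 : 4)` used by the bulk kernels, and the `(lo | hi << 2) - 4` form used by the flat-index paths added below. They are equivalent; a small standalone C check (hypothetical helper names):

```c
#include <assert.h>
#include <stdint.h>

// A set hmask bit marks the upper half of the 3-bit range, so skipping the
// -4 when the bit is set is the same as OR-ing the bit into bit 2 first.
static int q3_decode_sub(uint8_t lo2, int bit) { return (int)lo2 - (bit ? 0 : 4); }
static int q3_decode_or (uint8_t lo2, int bit) { return (int)(lo2 | (bit << 2)) - 4; }

int main(void) {
    for (uint8_t lo2 = 0; lo2 < 4; ++lo2) {
        for (int bit = 0; bit < 2; ++bit) {
            // both forms land in [-4, 3]
            assert(q3_decode_sub(lo2, bit) == q3_decode_or(lo2, bit));
        }
    }
    return 0;
}
```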
diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh
index e060fb29fdc..fd309e78f10 100644
--- a/ggml/src/ggml-cuda/dequantize.cuh
+++ b/ggml/src/ggml-cuda/dequantize.cuh
@@ -75,3 +75,56 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
     v.x *= d;
     v.y *= d;
 }
+
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers
+// Uses same hmask/qs/scales layout as Q3_K for the first 110 bytes
+static __device__ __forceinline__ void dequantize_q3_hifi(const void * vx, const int64_t ib, const int iqs, float2 & v){
+    const block_q3_hifi * x = (const block_q3_hifi *) vx;
+
+    // Use Q3_K-style extraction
+    const float d = __half2float(x[ib].d);
+    const uint8_t * qs = x[ib].qs;
+    const uint8_t * hmask = x[ib].hmask;
+
+    // iqs is in range [0, QK_K/2) = [0, 128)
+    // We need to extract 2 values at positions iqs*2 and iqs*2+1
+    int idx0 = iqs * 2;
+    int idx1 = iqs * 2 + 1;
+
+    // Q3_K bit layout:
+    // - qs[64]: lower 2 bits packed as 4 values per byte
+    // - hmask[32]: high bit packed as 8 values per byte
+
+    // Extract first value
+    const int qs_byte0 = idx0 / 4;
+    const int qs_shift0 = (idx0 % 4) * 2;
+    const int hm_byte0 = idx0 / 8;
+    const int hm_shift0 = idx0 % 8;
+    const int lo0 = (qs[qs_byte0] >> qs_shift0) & 0x03;
+    const int hi0 = (hmask[hm_byte0] >> hm_shift0) & 0x01;
+    int quant_val0 = (lo0 | (hi0 << 2)) - 4;
+
+    // Extract second value
+    const int qs_byte1 = idx1 / 4;
+    const int qs_shift1 = (idx1 % 4) * 2;
+    const int hm_byte1 = idx1 / 8;
+    const int hm_shift1 = idx1 % 8;
+    const int lo1 = (qs[qs_byte1] >> qs_shift1) & 0x03;
+    const int hi1 = (hmask[hm_byte1] >> hm_shift1) & 0x01;
+    int quant_val1 = (lo1 | (hi1 << 2)) - 4;
+
+    v.x = quant_val0 * d;
+    v.y = quant_val1 * d;
+
+    // Check if either index is an outlier and restore if so
+    // Outliers are sparse (only 8 per 256 weights), so this loop is cheap
+    #pragma unroll
+    for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+        if (x[ib].outlier_idx[k] == idx0) {
+            v.x = __half2float(x[ib].outlier_vals[k]);
+        }
+        if (x[ib].outlier_idx[k] == idx1) {
+            v.y = __half2float(x[ib].outlier_vals[k]);
+        }
+    }
+}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ab0f6fe9ce9..e14936808fa 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4382,6 +4382,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_TYPE_MXFP4:
                 case GGML_TYPE_Q2_K:
                 case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q3_HIFI:
                 case GGML_TYPE_Q4_K:
                 case GGML_TYPE_Q5_K:
                 case GGML_TYPE_Q6_K:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index f7a2cbca90f..e8284b0203e 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -252,6 +252,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
+        // Q3_HIFI excluded - uses MMVQ/dequant path instead
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index d671551c171..1a1d67d966f 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -17,6 +17,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
         case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
         case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
+        case GGML_TYPE_Q3_HIFI: return vec_dot_q3_hifi_q8_1;
         case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
         case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
         case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
@@ -43,6 +44,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
         case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
         case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q3_HIFI: return VDR_Q3_K_Q8_1_MMVQ; // Same as Q3_K
         case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
         case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
         case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
@@ -524,6 +526,12 @@ static void mul_mat_vec_q_switch_type(
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
+        case GGML_TYPE_Q3_HIFI:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_HIFI>
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
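In the MMVQ kernels below, every thread scans all eight outlier slots but only accumulates the ones that land in the Q8_1 slice it already loads, so that, under this indexing scheme, each outlier should be accumulated by exactly one thread. A standalone C sketch of that ownership predicate (hypothetical helper; the constants mirror the CUDA definitions):

```c
#include <stdbool.h>

#define QR3_K 2    // Q8 blocks read per thread for Q3_K-family types
#define QI3_K 32
#define QK8_1 32   // values per Q8_1 block
#define QI8_1 8

// True iff the thread at position iqs is the one that reads the 4-byte
// group containing outlier index idx (mirrors vec_dot_q3_hifi_q8_1 below).
static bool thread_owns_index(int idx, int iqs) {
    const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); // first Q8 block this thread reads
    const int idx_bq8    = idx / QK8_1;               // Q8 block holding the outlier
    const int idx_in_bq8 = idx % QK8_1;               // position inside that block
    return idx_bq8 >= bq8_offset && idx_bq8 < bq8_offset + QR3_K
        && idx_in_bq8 / 4 == iqs % QI8_1;             // same 4-byte group as get_int_b4()
}
```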
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 6baab1176ff..d226f2257f4 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -772,6 +772,80 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
+// Q3_HIFI: Q3_K layout + 8 FP16 outliers per block
+// Reuses Q3_K vec_dot logic for bulk, adds outlier corrections
+// VDR (vector dot reduction) same as Q3_K since layout is compatible
+#define VDR_Q3_HIFI_Q8_1_MMVQ VDR_Q3_K_Q8_1_MMVQ
+
+static __device__ __forceinline__ float vec_dot_q3_hifi_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q3_hifi * bq3_hifi = (const block_q3_hifi *) vbq + kbx;
+
+    // === Q3_K bulk dot product (identical logic) ===
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const float d = __half2float(bq3_hifi->d);
+
+    const int vl = get_int_b2(bq3_hifi->qs, iqs);
+
+    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+    const int vh = ~get_int_b2(bq3_hifi->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+    int u[QR3_K];
+    float d8[QR3_K];
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+    }
+
+    // Compute Q3_K bulk dot product (outliers were pre-zeroed during quantization)
+    float sum = vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_hifi->scales, scale_offset, d, d8);
+
+    // === Q3_HIFI outlier correction ===
+    // Each outlier contributes: outlier_val * q8_val * d8
+    // Outliers are sparse (8 per 256 weights), so all threads check all 8
+    // and only add if the outlier falls within their processing range
+
+    // Thread processes weights in positions determined by iqs and bq8_offset
+    // iqs in [0,8), each thread handles 32 weights (256/8)
+    // Weights are interleaved: thread iqs handles indices where (idx/32) == iqs/4 and ((idx%32)/4) matches
+
+    // Simpler approach: each thread adds outlier contributions for indices it "owns"
+    // based on the Q3_K data layout pattern
+
+#pragma unroll
+    for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+        const int idx = bq3_hifi->outlier_idx[k];
+
+        // Determine which bq8 block this index falls into
+        const int idx_bq8    = idx / QK8_1;  // Which Q8 block (0-7 for 256 weights)
+        const int idx_in_bq8 = idx % QK8_1;  // Position within Q8 block (0-31)
+
+        // Check if this outlier is in the range this thread processes
+        // Thread at iqs with bq8_offset processes Q8 blocks [bq8_offset, bq8_offset + QR3_K)
+        if (idx_bq8 >= bq8_offset && idx_bq8 < bq8_offset + QR3_K) {
+            // Further check: within Q8 block, thread processes specific positions
+            // based on (iqs % QI8_1) pattern
+            const int thread_q8_offset = iqs % QI8_1;
+
+            // Each thread processes 4 consecutive int8 values at positions [thread_q8_offset*4, thread_q8_offset*4+4)
+            const int pos_in_q8_group = idx_in_bq8 / 4;
+            if (pos_in_q8_group == thread_q8_offset) {
+                const float outlier_val = __half2float(bq3_hifi->outlier_vals[k]);
+                const int8_t q8_val = ((const int8_t*)bq8_1[idx_bq8].qs)[idx_in_bq8];
+                const float d8_val = __low2float(bq8_1[idx_bq8].ds);
+                sum += outlier_val * q8_val * d8_val;
+            }
+        }
+    }
+
+    return sum;
+}
+
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
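The outlier term itself is the same on every backend: outlier positions were quantized as zero in the bulk data, so adding `val * q8 * d8` afterwards reconstructs the exact contribution. A scalar C reference (hypothetical helper; assumes the `block_q3_hifi` sketch above and ggml's `GGML_FP16_TO_FP32`):

```c
// Correction for one 256-wide block: q8 are the quantized activations and
// d8 their scale, matching the per-block values used by the kernels above.
static float q3_hifi_outlier_correction(const block_q3_hifi * xb,
                                        const int8_t * q8, float d8) {
    float sum = 0.0f;
    for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
        sum += GGML_FP16_TO_FP32(xb->outlier_vals[k]) * q8[xb->outlier_idx[k]] * d8;
    }
    return sum;
}
```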
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d132d..0a33e879613 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -639,6 +639,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
                     nsg = N_SG_Q3_K;
                     nr0 = N_R0_Q3_K;
                 } break;
+            case GGML_TYPE_Q3_HIFI:
+                {
+                    nsg = N_SG_Q3_HIFI;
+                    nr0 = N_R0_Q3_HIFI;
+                } break;
             case GGML_TYPE_Q4_K:
                 {
                     nsg = N_SG_Q4_K;
@@ -851,6 +856,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
                     nsg = N_SG_Q3_K;
                     nr0 = N_R0_Q3_K;
                 } break;
+            case GGML_TYPE_Q3_HIFI:
+                {
+                    nsg = N_SG_Q3_HIFI;
+                    nr0 = N_R0_Q3_HIFI;
+                } break;
             case GGML_TYPE_Q4_K:
                 {
                     nsg = N_SG_Q4_K;
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e907..1a42cc01d0f 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -32,6 +32,9 @@
 #define N_R0_Q3_K 2
 #define N_SG_Q3_K 2
 
+#define N_R0_Q3_HIFI 2
+#define N_SG_Q3_HIFI 2
+
 #define N_R0_Q4_K 2
 #define N_SG_Q4_K 2
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 51bcbae309f..bbc763d90ea 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -890,6 +890,36 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
     }
 }
 
+template <typename type4x4>
+void dequantize_q3_hifi(device const block_q3_hifi * xb, short il, thread type4x4 & reg) {
+    // Q3_HIFI uses Q3_K-compatible layout: hmask[32] + qs[64] + scales[12] + d + outliers
+    // il is 0...15 for 256 values => processes 16 values at a time
+    const float d_all = half_to_float(xb->d);
+    device const uint8_t * qs = xb->qs;        // low 2 bits
+    device const uint8_t * hmask = xb->hmask;  // high bit
+
+    // Process 16 values starting at il*16
+    for (int i = 0; i < 16; ++i) {
+        const int idx = il * 16 + i;
+
+        // Extract 3-bit value using Q3_K layout (qs + hmask)
+        const uint8_t lo2 = (qs[idx / 4] >> ((idx % 4) * 2)) & 0x03;
+        const uint8_t hi1 = (hmask[idx / 8] >> (idx % 8)) & 0x01;
+        const int quant_val = (int)(lo2 | (hi1 << 2)) - 4; // [0,7] → [-4,3]
+        float val = quant_val * d_all;
+
+        // Check if this index is an outlier and restore FP16 value
+        for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+            if (xb->outlier_idx[k] == idx) {
+                val = half_to_float(xb->outlier_vals[k]);
+                break;
+            }
+        }
+
+        reg[i/4][i%4] = val;
+    }
+}
+
 enum ggml_sort_order {
     GGML_SORT_ORDER_ASC,
     GGML_SORT_ORDER_DESC,
@@ -7208,6 +7238,186 @@ kernel void kernel_mul_mv_q3_K_f32(
     kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers for improved accuracy
+// Reuses Q3_K kernel logic and adds outlier corrections
+template <short nr0, typename args_t>
+void kernel_mul_mv_q3_hifi_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * NSG + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q3_hifi * x  = (device const block_q3_hifi *) (src0 + offset0);
+    device const float         * yy = (device const float         *) (src1 + offset1);
+
+    float yl[32];
+
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short ip  = tid/4;          // 0 or 1
+    const short il  = 2*((tid%4)/2);  // 0 or 2
+    const short ir  = tid%2;
+    const short l0  = 8*ir;
+
+    // Possible masks for the high bit (same as Q3_K)
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},
+                           {0x0004, 0x0400, 0x0008, 0x0800},
+                           {0x0010, 0x1000, 0x0020, 0x2000},
+                           {0x0040, 0x4000, 0x0080, 0x8000}};
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
+
+    const short shift = 2*il;
+
+    const float v1 = il == 0 ?
4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const short q_offset = 32*ip + l0; + const short y_offset = 128*ip + 32*il + l0; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[nr0] = {0.f}; + float sumf2[nr0] = {0.f}; + + for (int i = ix; i < nb; i += 4) { + for (short l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (short row = 0; row < nr0; ++row) { + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (short l = 0; l < 8; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (short l = 0; l < 8; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 
0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; + } + + y1 += 4 * QK_K; + } + + // Add outlier corrections + // Each thread processes part of the activations, so we need all threads to check all outliers + device const float * y_base = yy + ix*QK_K; + for (int i = ix; i < nb; i += 4) { + for (short row = 0; row < nr0; ++row) { + device const block_q3_hifi * xb = x + i + row * (args.nb01 / sizeof(block_q3_hifi)); + device const float * y_block = y_base; + + for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) { + const int idx = xb->outlier_idx[k]; + const float outlier_val = half_to_float(xb->outlier_vals[k]); + // Only this thread handles if idx is in its range + if (idx >= y_offset && idx < y_offset + 32) { + sumf1[row] += outlier_val * y_block[idx]; + } + } + } + y_base += 4 * QK_K; + } + + for (int row = 0; row < nr0; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + if (tiisg == 0) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + dst_f32[first_row + row] = sumf1[row]; + } + } +} + +[[host_name("kernel_mul_mv_q3_hifi_f32")]] +kernel void kernel_mul_mv_q3_hifi_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_hifi_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + template void kernel_mul_mv_q4_K_f32_impl( args_t args, @@ -9480,6 +9690,7 @@ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_hifi")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9542,6 +9753,7 @@ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_hifi_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9568,6 +9780,7 @@ template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm; 
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_hifi_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -9600,6 +9813,7 @@ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_hifi_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9626,6 +9840,7 @@ template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_hifi_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9781,6 +9996,7 @@ template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_hifi_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index de5cbd75e86..9e76e7c4035 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1275,6 +1275,154 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * row_size; } +// ====================== Q3_HIFI: Q3_K layout + 8 FP16 outliers ====================== +// Uses Q3_K's optimized AVX2 kernels for ~98% of Q3_K speed with better quality + +void quantize_row_q3_hifi_ref(const float * GGML_RESTRICT x, block_q3_hifi * GGML_RESTRICT y, int64_t k) { + assert(k % Q3_HIFI_BLOCK_SIZE == 0); + const int64_t nb = k / Q3_HIFI_BLOCK_SIZE; + + for (int64_t ib = 0; ib < nb; ++ib) { + const float * xb = x + ib * Q3_HIFI_BLOCK_SIZE; + block_q3_hifi * block = &y[ib]; + + // Step 1: Find top-8 outliers by magnitude + float mag[Q3_HIFI_BLOCK_SIZE]; + for (int i = 0; i < Q3_HIFI_BLOCK_SIZE; ++i) { + mag[i] = fabsf(xb[i]); + } + + int outlier_indices[Q3_HIFI_OUTLIERS]; + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; 
++k_idx) { + int argmax = 0; + float max_val = mag[0]; + for (int i = 1; i < Q3_HIFI_BLOCK_SIZE; ++i) { + if (mag[i] > max_val) { + max_val = mag[i]; + argmax = i; + } + } + outlier_indices[k_idx] = argmax; + mag[argmax] = -1.0f; // mask out + } + + // Step 2: Create temporary array with outliers zeroed (pre-zero for faster vec_dot) + float tmp[Q3_HIFI_BLOCK_SIZE]; + memcpy(tmp, xb, sizeof(tmp)); + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) { + tmp[outlier_indices[k_idx]] = 0.0f; + } + + // Step 3: Quantize bulk using Q3_K algorithm (produces Q3_K-compatible layout) + block_q3_K q3k_block; + quantize_row_q3_K_ref(tmp, &q3k_block, Q3_HIFI_BLOCK_SIZE); + + // Step 4: Copy Q3_K fields to our block (first 110 bytes are identical layout) + memcpy(block->hmask, q3k_block.hmask, sizeof(block->hmask)); + memcpy(block->qs, q3k_block.qs, sizeof(block->qs)); + memcpy(block->scales, q3k_block.scales, sizeof(block->scales)); + block->d = q3k_block.d; + + // Step 5: Store outliers (indices and FP16 values) + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) { + const int idx = outlier_indices[k_idx]; + block->outlier_idx[k_idx] = (uint8_t)idx; + block->outlier_vals[k_idx] = GGML_FP32_TO_FP16(xb[idx]); + } + } +} + +static void quantize_row_q3_hifi_impl(const float * GGML_RESTRICT x, block_q3_hifi * GGML_RESTRICT y, int64_t k, const float * GGML_RESTRICT quant_weights) { + assert(k % Q3_HIFI_BLOCK_SIZE == 0); + const int64_t nb = k / Q3_HIFI_BLOCK_SIZE; + + for (int64_t ib = 0; ib < nb; ++ib) { + const float * xb = x + ib * Q3_HIFI_BLOCK_SIZE; + const float * qw = quant_weights ? quant_weights + ib * Q3_HIFI_BLOCK_SIZE : NULL; + block_q3_hifi * block = &y[ib]; + + // Step 1: Find top-8 outliers by weighted magnitude + float mag[Q3_HIFI_BLOCK_SIZE]; + for (int i = 0; i < Q3_HIFI_BLOCK_SIZE; ++i) { + mag[i] = fabsf(xb[i]) * (qw ? 
qw[i] : 1.0f); + } + + int outlier_indices[Q3_HIFI_OUTLIERS]; + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) { + int argmax = 0; + float max_val = mag[0]; + for (int i = 1; i < Q3_HIFI_BLOCK_SIZE; ++i) { + if (mag[i] > max_val) { + max_val = mag[i]; + argmax = i; + } + } + outlier_indices[k_idx] = argmax; + mag[argmax] = -1.0f; // mask out + } + + // Step 2: Create temporary array with outliers zeroed + float tmp[Q3_HIFI_BLOCK_SIZE]; + memcpy(tmp, xb, sizeof(tmp)); + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) { + tmp[outlier_indices[k_idx]] = 0.0f; + } + + // Step 3: Quantize bulk using Q3_K algorithm + block_q3_K q3k_block; + quantize_row_q3_K_ref(tmp, &q3k_block, Q3_HIFI_BLOCK_SIZE); + + // Step 4: Copy Q3_K fields to our block + memcpy(block->hmask, q3k_block.hmask, sizeof(block->hmask)); + memcpy(block->qs, q3k_block.qs, sizeof(block->qs)); + memcpy(block->scales, q3k_block.scales, sizeof(block->scales)); + block->d = q3k_block.d; + + // Step 5: Store outliers + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) { + const int idx = outlier_indices[k_idx]; + block->outlier_idx[k_idx] = (uint8_t)idx; + block->outlier_vals[k_idx] = GGML_FP32_TO_FP16(xb[idx]); + } + } +} + +void dequantize_row_q3_hifi(const block_q3_hifi * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % Q3_HIFI_BLOCK_SIZE == 0); + const int64_t nb = k / Q3_HIFI_BLOCK_SIZE; + + for (int64_t ib = 0; ib < nb; ++ib) { + const block_q3_hifi * block = &x[ib]; + float * yb = y + ib * Q3_HIFI_BLOCK_SIZE; + + // Dequantize using Q3_K algorithm for single block + // The first 110 bytes of block_q3_hifi match Q3_K exactly + // Since we pass k=256, Q3_K will only process 1 block (nb=1, using x[0]) + dequantize_row_q3_K((const block_q3_K *)block, yb, Q3_HIFI_BLOCK_SIZE); + + // Overwrite outlier positions with FP16 values + for (int k_idx = 0; k_idx < Q3_HIFI_OUTLIERS; ++k_idx) { + const int idx = block->outlier_idx[k_idx]; + yb[idx] = GGML_FP16_TO_FP32(block->outlier_vals[k_idx]); + } + } +} + +size_t quantize_q3_hifi(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + const size_t row_size = ggml_row_size(GGML_TYPE_Q3_HIFI, n_per_row); + if (!quant_weights) { + quantize_row_q3_hifi_ref(src, dst, nrow * n_per_row); + } else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q3_hifi_impl(src, (block_q3_hifi*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + // ====================== 4-bit (de)-quantization void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) { @@ -4997,6 +5145,10 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE quantize_iq2_s(x, y, 1, k, NULL); } +// Q3_HIFI: 3-bit + FP16 outliers per 256 weights +// Q3_HIFI_BLOCK_SIZE and Q3_HIFI_OUTLIERS are defined in ggml.h + + // =============================== data validation static bool validate_float(float f, size_t i) { @@ -5308,6 +5460,11 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_Q3_HIFI: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_hifi, data, nb); + } break; + case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 3b688f31c21..5f62da49671 100644 --- a/ggml/src/ggml-quants.h +++ 
b/ggml/src/ggml-quants.h
@@ -30,6 +30,8 @@ GGML_API void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K *
 GGML_API void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
 
+GGML_API void quantize_row_q3_hifi_ref(const float * GGML_RESTRICT x, block_q3_hifi * GGML_RESTRICT y, int64_t k);
+
 GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k);
@@ -101,6 +103,9 @@ GGML_API void iq2xs_free_impl(enum ggml_type type);
 GGML_API void iq3xs_init_impl(int grid_size);
 GGML_API void iq3xs_free_impl(int grid_size);
 
+GGML_API void dequantize_row_q3_hifi(const block_q3_hifi * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API size_t quantize_q3_hifi(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
index 7c6ea8a57a2..0dceb4aeef4 100644
--- a/ggml/src/ggml-sycl/convert.cpp
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -114,6 +114,38 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
 #endif
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers
+template <typename dst_t>
+static void dequantize_row_q3_hifi_sycl(const void *vx, dst_t *y, const int64_t k,
+                                        dpct::queue_ptr stream) {
+    const int64_t nb = k / QK_K;
+#if QK_K == 256
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q3_hifi(vx, y, item_ct1);
+                             });
+    }
+#else
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q3_hifi(vx, y, item_ct1);
+                             });
+    }
+#endif
+}
+
 template <typename dst_t>
 static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
@@ -539,6 +571,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
             return dequantize_row_q2_K_sycl;
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
+        case GGML_TYPE_Q3_HIFI:
+            return dequantize_row_q3_hifi_sycl;
         case GGML_TYPE_Q4_K:
             if (dst->src[0]->extra &&
                 ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                 return dequantize_row_q4_K_sycl_reorder;
@@ -603,6 +637,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return dequantize_row_q2_K_sycl;
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
+        case GGML_TYPE_Q3_HIFI:
+            return dequantize_row_q3_hifi_sycl;
         case GGML_TYPE_Q4_K:
             if (dst->src[0]->extra &&
                 ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp
index 540539bb223..61e8fa26097 100644
--- a/ggml/src/ggml-sycl/dequantize.hpp
+++ b/ggml/src/ggml-sycl/dequantize.hpp
@@ -345,6 +345,83 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers
+template <typename dst_t>
+static
void dequantize_block_q3_hifi(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const block_q3_hifi * x = (const block_q3_hifi *) vx; + +#if QK_K == 256 + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int64_t is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) { + int idx = 128*n + 32*j + l; + dst_t val = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + // Check if this is an outlier position and restore FP16 value + for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) { + if (x[i].outlier_idx[k] == idx) { + val = x[i].outlier_vals[k]; + break; + } + } + y[l] = val; + } +#else + const int64_t tid = item_ct1.get_local_id(2); + const int64_t is = tid/16; + const int64_t il = tid%16; + const int64_t im = il/8; + const int64_t in = il%8; + + dst_t * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + dst_t val0, val1; + if (is == 0) { + val0 = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + val1 = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + val0 = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + val1 = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 
0 : 4));
+    }
+    // Check for outliers
+    int idx0 = 16*is + il;
+    int idx1 = 16*is + il + 32;
+    for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+        if (x[i].outlier_idx[k] == idx0) val0 = x[i].outlier_vals[k];
+        if (x[i].outlier_idx[k] == idx1) val1 = x[i].outlier_vals[k];
+    }
+    y[ 0] = val0;
+    y[32] = val1;
+#endif
+
+}
+
 #if QK_K == 256
 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
index 5b7f0640749..d5e0f58a71a 100644
--- a/ggml/src/ggml-sycl/mmvq.cpp
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -715,6 +715,29 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers
+static void mul_mat_vec_q3_hifi_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q<QK_K, QI3_K, block_q3_hifi, VDR_Q3_HIFI_Q8_1_MMVQ, vec_dot_q3_hifi_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
 static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
@@ -1073,6 +1096,9 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
         case GGML_TYPE_Q3_K:
             mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_Q3_HIFI:
+            mul_mat_vec_q3_hifi_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            break;
         case GGML_TYPE_Q4_K:
             if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
                 ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp
index 4088ddb54f0..3ba745f93ae 100644
--- a/ggml/src/ggml-sycl/vecdotq.hpp
+++ b/ggml/src/ggml-sycl/vecdotq.hpp
@@ -798,6 +798,62 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers
+#define VDR_Q3_HIFI_Q8_1_MMVQ VDR_Q3_K_Q8_1_MMVQ
+
+static __dpct_inline__ float
+vec_dot_q3_hifi_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q3_hifi * bq3_hifi = (const block_q3_hifi *) vbq;
+
+    // === Q3_K bulk dot product (identical logic) ===
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const float d = bq3_hifi->d;
+
+    const int vl = get_int_from_uint8(bq3_hifi->qs, iqs);
+
+    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+    const int vh = ~get_int_from_uint8(bq3_hifi->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+    int u[QR3_K];
+    float d8[QR3_K];
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + i].ds[0];
+    }
+
+    // Compute Q3_K bulk dot product (outliers were pre-zeroed during quantization)
+    float sum = vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_hifi->scales, scale_offset, d, d8);
+
+    // === Q3_HIFI outlier
correction === + // Add outlier contributions for positions handled by this thread +#pragma unroll + for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) { + const int idx = bq3_hifi->outlier_idx[k]; + const int idx_bq8 = idx / QK8_1; + const int idx_in_bq8 = idx % QK8_1; + + // Check if this outlier is in the range this thread processes + if (idx_bq8 >= bq8_offset && idx_bq8 < bq8_offset + QR3_K) { + const int thread_q8_offset = iqs % QI8_1; + const int pos_in_q8_group = idx_in_bq8 / 4; + if (pos_in_q8_group == thread_q8_offset) { + const float outlier_val = bq3_hifi->outlier_vals[k]; + const int8_t q8_val = ((const int8_t*)bq8_1[idx_bq8].qs)[idx_in_bq8]; + const float d8_val = bq8_1[idx_bq8].ds[0]; + sum += outlier_val * q8_val * d8_val; + } + } + } + + return sum; +} + static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #ifndef GGML_QKK_64 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 34ec09d4034..d3de004088f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3580,6 +3580,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_HIFI][i], "mul_mat_vec_q3_hifi_f32_f32", arr_dmmv_q3_hifi_f32_f32_len[reduc16], arr_dmmv_q3_hifi_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], 
"mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); @@ -3604,6 +3605,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_HIFI][i], "mul_mat_vec_q3_hifi_f16_f32", arr_dmmv_q3_hifi_f16_f32_len[reduc16], arr_dmmv_q3_hifi_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); @@ -3700,6 +3702,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 
64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_HIFI], "dequant_q3_hifi", dequant_q3_hifi_len, dequant_q3_hifi_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); @@ -3725,6 +3728,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_HIFI], "get_rows_q3_hifi", get_rows_q3_hifi_len, get_rows_q3_hifi_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -3750,6 +3754,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_HIFI], "get_rows_q3_hifi_f32", get_rows_q3_hifi_f32_len, get_rows_q3_hifi_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", 
get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -5399,6 +5404,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -5470,6 +5476,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -5533,6 +5540,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -5623,6 +5631,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -5689,6 +5698,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -13837,6 +13847,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -13957,6 +13968,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_HIFI: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 70ee542d969..ac1b02287e0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -514,6 +514,48 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_Q3_HIFI) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + // Q3_HIFI uses same layout as Q3_K with outliers appended + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[a_offset + ib].d) * float(us - 32); + + // Compute local indices for outlier checking + const uint local_idx0 = 128 * n + 32 * j + (iqs % 16) * 2; + const uint local_idx1 = local_idx0 + 1; + + // Base Q3_K dequantization + float v0 = dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + 
ib].hmask[hmi ] & m) != 0) ? 0 : 4));
+    float v1 = dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4));
+
+    // Check for outliers and replace with FP16 values
+    [[unroll]] for (uint k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+        if (data_a[a_offset + ib].outlier_idx[k] == local_idx0) {
+            v0 = float(data_a[a_offset + ib].outlier_vals[k]);
+        }
+        if (data_a[a_offset + ib].outlier_idx[k] == local_idx1) {
+            v1 = float(data_a[a_offset + ib].outlier_vals[k]);
+        }
+    }
+
+    return vec2(v0, v1);
+}
+vec2 get_dm(uint ib, uint a_offset) {
+    return vec2(1, 0);
+}
+#endif
+
 #if defined(DATA_A_Q4_K)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     iqs /= 2;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
index 8ac6482dc94..1bb2af14ffb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
@@ -167,6 +167,45 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2
     return ret;
 }
 
+// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_HIFI {
+    block_q3_hifi block;
+};
+
+float16_t dequantFuncQ3_HIFI(const in decodeBufQ3_HIFI bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const uint idx = coordInBlock[1];
+
+    // First check if this is an outlier position
+    for (uint k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
+        if (uint(bl.block.outlier_idx[k]) == idx) {
+            return bl.block.outlier_vals[k];
+        }
+    }
+
+    // Standard Q3_K dequantization
+    const uint iqs = idx;
+    const uint n = iqs / 128;
+    const uint qsi = n * 32 + (iqs % 32);
+    const uint hmi = (iqs % 32);
+    const uint j = (iqs % 128) / 8;
+    const uint is = iqs / 16;
+    const uint halfsplit = ((iqs % 128) / 32);
+    const uint qsshift = halfsplit * 2;
+    const uint m = 1 << (4 * n + halfsplit);
+
+    uint32_t scaleidx0 = (is < 8) ? is : (is-8);
+    uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
+    uint32_t scaleidx1 = is + 8 - (is/4)*4;
+    uint32_t scaleidx1shift = (is/4)*2;
+
+    const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
+    const float16_t dl = bl.block.d * float16_t(us - 32);
+    float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi] >> qsshift) & 3) - (((bl.block.hmask[hmi] & m) != 0) ? 0 : 4));
+
+    return ret;
+}
+
 layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
     block_q4_K block;
 };
@@ -699,6 +738,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
 #define dequantFuncA dequantFuncQ2_K
 #elif defined(DATA_A_Q3_K)
 #define dequantFuncA dequantFuncQ3_K
+#elif defined(DATA_A_Q3_HIFI)
+#define dequantFuncA dequantFuncQ3_HIFI
 #elif defined(DATA_A_Q4_K)
 #define dequantFuncA dequantFuncQ4_K
 #define fetch_scales fetch_scalesQ4_K
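The backends restore outliers in one of two equivalent ways: overwrite after a bulk Q3_K dequantization (the CPU rows and the CUDA/SYCL block kernels), or check the outlier slots around each per-value decode (the cm2 path above, the compute shader below). A C sketch of both, under the same assumptions as the earlier sketches:

```c
#include <stdbool.h>

// 1) Row-wise: dequantize the whole block with Q3_K logic first, then
//    overwrite the (at most 8) outlier positions with their FP16 values.
static void q3_hifi_restore_rowwise(const block_q3_hifi * b, float * y) {
    for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
        y[b->outlier_idx[k]] = GGML_FP16_TO_FP32(b->outlier_vals[k]);
    }
}

// 2) Per-value: consult the outlier table before using one decoded value;
//    returns true and writes *out when idx is an outlier position.
static bool q3_hifi_outlier_lookup(const block_q3_hifi * b, int idx, float * out) {
    for (int k = 0; k < Q3_HIFI_OUTLIERS; ++k) {
        if (b->outlier_idx[k] == idx) {
            *out = GGML_FP16_TO_FP32(b->outlier_vals[k]);
            return true;
        }
    }
    return false;
}
```

The row-wise form touches each output once more; the per-value form costs eight comparisons per decoded weight but needs no second pass, which is why the gather-style shader paths prefer it.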
+ + return ret; +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { block_q4_K block; }; @@ -699,6 +738,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncQ2_K #elif defined(DATA_A_Q3_K) #define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q3_HIFI) +#define dequantFuncA dequantFuncQ3_HIFI #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K #define fetch_scales fetch_scalesQ4_K diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_hifi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_hifi.comp new file mode 100644 index 00000000000..cc5f730a90a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_hifi.comp @@ -0,0 +1,58 @@ +#version 450 + +// Q3_HIFI dequantization shader +// Uses Q3_K-compatible layout (hmask + qs + scales) with 8 FP16 outliers + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; + +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); + if (i >= p.nel / QUANT_K) { + return; + } + + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; + + const uint8_t m = uint8_t(1 << (4*n + j)); + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; + + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); + const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); + + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; + + for (uint l = l0; l < l0 + 4; ++l) { + const uint global_idx = y_idx + l; + const uint local_idx = 128 * n + 32 * j + l; + + // Standard Q3_K dequantization + FLOAT_TYPE val = dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ?
0 : 4)); + + // Q3_HIFI extension: Check if this is an outlier and replace with FP16 value + [[unroll]] for (uint k = 0; k < Q3_HIFI_OUTLIERS; ++k) { + if (data_a[i].outlier_idx[k] == local_idx) { + val = FLOAT_TYPE(data_a[i].outlier_vals[k]); + } + } + + data_b[global_idx] = D_TYPE(val); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index e6b1f20215d..c5f5e9cbb2b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -10,44 +10,44 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { const uint y_idx_base = i * QUANT_K + 32 * ib32; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx_base) / 4; - [[unroll]] for (uint l = 0; l < 4; ++l) { + [[unroll]] for (uint l = 0; l < 4; ++l) { const vec4 b_val_0 = vec4(data_b_v4[base_b_idx + 2 * l]); const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]); // index for data_a uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint qh = data_a[ibi].qh[ib32]; const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); const uint qs = data_a[ibi].qs[4 * ib32 + l]; - const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); const uint16_t grid = uint16_t(iq1s_grid[qs | (idxhi << 8)]); const float delta_val = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; - const vec4 delta_v = vec4(delta_val); + const vec4 delta_v = vec4(delta_val); const vec4 fbits0 = vec4( float(bitfieldExtract(grid, 0, 2)), float(bitfieldExtract(grid, 2, 2)), float(bitfieldExtract(grid, 4, 2)), float(bitfieldExtract(grid, 6, 2)) - ); + ); const vec4 fbits1 = vec4( float(bitfieldExtract(grid, 8, 2)), float(bitfieldExtract(grid, 10, 2)), float(bitfieldExtract(grid, 12, 2)), float(bitfieldExtract(grid, 14, 2)) ); - + vec4 sum_v = fma(b_val_0, fbits0 + delta_v, vec4(0.0)); sum_v = fma(b_val_1, fbits1 + delta_v, sum_v); - FLOAT_TYPE sum = dot(sum_v, vec4(1.0)); - - temp[j][n] = fma(dl, sum, temp[j][n]); + FLOAT_TYPE sum = dot(sum_v, vec4(1.0)); + + temp[j][n] = fma(dl, sum, temp[j][n]); ibi += num_blocks_per_row; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_hifi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_hifi.comp new file mode 100644 index 00000000000..825ac7fcae2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_hifi.comp @@ -0,0 +1,135 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +// Q3_HIFI matrix-vector multiplication shader +// Uses Q3_K-compatible layout, outlier correction skipped on GPU for simplicity +// (outliers are still applied on CPU for full quality) + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { + if (i < num_blocks_per_row) + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); 
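+        // Each thread accumulates two consecutive y values at offsets 0,16,...,112 within its 128-value half of the superblock; b0..b112 below load those activation pairs.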
+ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + // Note: Outlier correction skipped on GPU for speed + // Full outlier correction is applied on CPU path + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; + const uint ix = tid/16; + const uint itid8 = itid%8; + + const uint v_im = itid/8; + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; + + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; + + const uint l0 = 2*v_in; + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 02578c77c4f..f2ce478482b 100644 
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -284,6 +284,38 @@ struct block_q3_K_packed16 #define DATA_A_QUANT_K #endif +// Q3_HIFI: Q3_K-compatible layout with 8 FP16 outliers +#define QUANT_K_Q3_HIFI 256 +#define Q3_HIFI_OUTLIERS 8 + +struct block_q3_hifi +{ + uint8_t hmask[QUANT_K_Q3_HIFI/8]; // 32 bytes + uint8_t qs[QUANT_K_Q3_HIFI/4]; // 64 bytes + uint8_t scales[12]; // 12 bytes + float16_t d; // 2 bytes + uint8_t outlier_idx[Q3_HIFI_OUTLIERS]; // 8 bytes + float16_t outlier_vals[Q3_HIFI_OUTLIERS]; // 16 bytes +}; + +struct block_q3_hifi_packed16 +{ + uint16_t hmask[QUANT_K_Q3_HIFI/8/2]; + uint16_t qs[QUANT_K_Q3_HIFI/4/2]; + uint16_t scales[12/2]; + float16_t d; + uint16_t outlier_idx[Q3_HIFI_OUTLIERS/2]; + float16_t outlier_vals[Q3_HIFI_OUTLIERS]; +}; + +#if defined(DATA_A_Q3_HIFI) +#define QUANT_K QUANT_K_Q3_HIFI +#define QUANT_R 1 +#define A_TYPE block_q3_hifi +#define A_TYPE_PACKED16 block_q3_hifi_packed16 +#define DATA_A_QUANT_K +#endif + #define QUANT_K_Q4_K 256 struct block_q4_K diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b0ade078c7b..0dd75d16dae 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -52,6 +52,7 @@ const std::vector type_names = { "q8_0", "q2_k", "q3_k", + "q3_hifi", "q4_k", "q5_k", "q6_k", @@ -668,7 +669,7 @@ void process_shaders() { for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || tname == "q3_hifi" || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? 
"mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f0913cd3596..180e0e632df 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -732,6 +732,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_q3_K, .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, }, + [GGML_TYPE_Q3_HIFI] = { + .type_name = "Q3_HIFI", + .blck_size = Q3_HIFI_BLOCK_SIZE, + .type_size = sizeof(block_q3_hifi), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q3_hifi, + .from_float_ref = (ggml_from_float_t) quantize_row_q3_hifi_ref, + }, [GGML_TYPE_Q4_K] = { .type_name = "q4_K", .blck_size = QK_K, @@ -7537,6 +7545,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_HIFI: result = quantize_q3_hifi(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591b..46e7a68c68a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3217,6 +3217,7 @@ class GGMLQuantizationType(IntEnum): TQ1_0 = 34 TQ2_0 = 35 MXFP4 = 39 + Q3_HIFI = 41 # Q3_K layout + 6 FP16 outliers per block class ExpertGatingFuncType(IntEnum): @@ -3268,6 +3269,8 @@ class LlamaFileType(IntEnum): # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack MOSTLY_TQ1_0 = 36 # except 1d tensors MOSTLY_TQ2_0 = 37 # except 1d tensors + # MOSTLY_Q3_HIFI_UNIFORM = 40 # removed - uniform version, superseded by adaptive + MOSTLY_Q3_HIFI = 41 # Adaptive: Q3_HIFI on sensitive layers, Q3_K/Q4_K elsewhere GUESSED = 1024 # not specified in the model file @@ -3364,6 +3367,7 @@ class VisionProjectorType: GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ2_0: (256, 2 + 64), GGMLQuantizationType.MXFP4: (32, 1 + 16), + GGMLQuantizationType.Q3_HIFI: (256, 134), # Q3_K (110 bytes) + outlier_idx[8] + outlier_vals[16] } diff --git a/include/llama.h b/include/llama.h index b52eaacfa7e..c1553028dc2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -152,6 +152,9 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors + // LLAMA_FTYPE_MOSTLY_Q3_HIFI_OLD = 39, // removed - replaced by Q3_HIFI (41) + // LLAMA_FTYPE_MOSTLY_Q3_HIFI_UNIFORM = 40, // removed - uniform version, superseded by adaptive + LLAMA_FTYPE_MOSTLY_Q3_HIFI = 41, // Adaptive: Q3_HIFI on sensitive layers, Q4_K/Q3_K elsewhere LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama-model-loader.cpp 
b/src/llama-model-loader.cpp index aa3a65f87a5..e72947c6af4 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -60,6 +60,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_MOSTLY_Q3_HIFI: return "Q3_HIFI - ~4.2 bpw adaptive (Q3_HIFI on sensitive layers)"; default: return "unknown, may not work"; } @@ -662,6 +663,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_Q3_HIFI: ftype = LLAMA_FTYPE_MOSTLY_Q3_HIFI; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 351dcb7baaa..33990fabcd3 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -295,6 +295,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_HIFI) { + // Adaptive Q3_HIFI: use Q3_HIFI for ALL attn_v layers (consistently sensitive) + new_type = GGML_TYPE_Q3_HIFI; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q5_K; @@ -348,6 +352,12 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_HIFI) { + // Adaptive Q3_HIFI: use Q3_HIFI for first 1/3 of ffn_down layers (most sensitive) + new_type = i_layer < n_layer/3 ? GGML_TYPE_Q3_HIFI + : use_more_bits(i_layer, n_layer) ? 
GGML_TYPE_Q4_K + : GGML_TYPE_Q3_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 || (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) { new_type = GGML_TYPE_Q4_K; @@ -391,6 +401,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_HIFI) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K; } @@ -399,7 +410,8 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } } else if (name.find("attn_qkv.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || + ftype == LLAMA_FTYPE_MOSTLY_Q3_HIFI) { new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; @@ -571,6 +583,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q3_HIFI: default_type = GGML_TYPE_Q3_K; break; // Adaptive: Q3_K base, Q3_HIFI on sensitive layers default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } diff --git a/tests/test-q3-hifi-text.txt b/tests/test-q3-hifi-text.txt new file mode 100644 index 00000000000..20563bb9d42 --- /dev/null +++ b/tests/test-q3-hifi-text.txt @@ -0,0 +1,46 @@ +Once upon a time, there was a little girl named Lily. She loved to play in the garden with her dog Max. +One sunny day, Lily found a shiny red ball under a big tree. She was so happy! She threw the ball for Max to catch. +Max ran very fast and caught the ball in his mouth. Lily clapped her hands and laughed. They played all afternoon. +When the sun started to set, Lily's mom called them inside for dinner. Lily gave Max a big hug and said goodnight. +The next morning, Lily woke up early. She looked out the window and saw it was raining. She felt sad because she could not play outside. +But then Max came to her room with a toy in his mouth. Lily smiled and played with Max inside the house. + +The story of quantum computing begins in the early 1980s when physicist Richard Feynman proposed that quantum mechanical +phenomena could be simulated more efficiently using a quantum computer than a classical one. This idea laid the foundation +for what would become one of the most transformative technologies of the 21st century. Quantum computers leverage the +principles of quantum mechanics, particularly superposition and entanglement, to perform computations that would be +practically impossible for classical computers. + +In a classical computer, information is processed using bits that can be either 0 or 1. However, quantum computers use +quantum bits, or qubits, which can exist in a superposition of both 0 and 1 simultaneously. 
This property allows quantum +computers to explore many possible solutions at once, potentially solving certain problems exponentially faster than +classical computers. Entanglement, another quantum phenomenon, allows qubits to be correlated in ways that have no +classical counterpart, enabling even more powerful computational capabilities. + +The development of practical quantum computers has been a challenging endeavor. Qubits are extremely fragile and can +lose their quantum properties through a process called decoherence when they interact with their environment. This has +led researchers to explore various physical implementations of qubits, including superconducting circuits, trapped ions, +topological qubits, and photonic systems. Each approach has its own advantages and challenges. + +Major technology companies and research institutions around the world are racing to build more powerful and reliable +quantum computers. IBM, Google, Microsoft, and several startups have made significant progress in recent years. In 2019, +Google announced quantum supremacy, claiming their quantum computer performed a calculation that would take the world's +most powerful classical supercomputer thousands of years. While the significance of this achievement was debated, it +marked an important milestone in the field. + +The potential applications of quantum computing are vast. In cryptography, quantum computers could break many of the +encryption methods that currently protect our digital communications, while also enabling new forms of quantum encryption +that are theoretically unbreakable. In drug discovery and materials science, quantum simulations could help design new +molecules and materials with specific properties. Optimization problems in logistics, finance, and machine learning +could also benefit from quantum speedups. + +However, significant challenges remain before quantum computers become practically useful for most applications. Current +quantum computers have limited numbers of qubits and high error rates. Researchers are working on quantum error correction +techniques and building more reliable hardware. The field of quantum software is also developing, with new algorithms and +programming frameworks being created to make quantum computing more accessible. + +The intersection of quantum computing and artificial intelligence is particularly exciting. Quantum machine learning +algorithms could potentially train models faster or find patterns in data that classical algorithms miss. Some researchers +believe that quantum computers might eventually lead to more powerful forms of artificial intelligence, though this remains +speculative. What is clear is that the development of quantum computing represents a fundamental shift in our computational +capabilities that could have profound implications for science, technology, and society. diff --git a/tests/test-q3-hifi.py b/tests/test-q3-hifi.py new file mode 100644 index 00000000000..ed023f11d30 --- /dev/null +++ b/tests/test-q3-hifi.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +Test Q3_HIFI quantization format. + +This test: + 1. Uses a pre-quantized Q3_HIFI model (or quantizes a compatible model) + 2. Runs perplexity test + 3. Asserts PPL is reasonable (<25) + +Usage: + python tests/test-q3-hifi.py --model MODEL_PATH [--build-dir BUILD_DIR] + +Note: Q3_HIFI requires tensor dimensions divisible by 256. + Small models like stories15M (288 dims) are not compatible. + Use a model with compatible dimensions (e.g., Qwen, Llama, Mistral).
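+ +Example (hypothetical paths; produce a Q3_HIFI model first with the quantize tool): + ./build/bin/llama-quantize Qwen3-1.7B-F16.gguf Qwen3-1.7B-Q3_HIFI.gguf Q3_HIFI + python tests/test-q3-hifi.py --model Qwen3-1.7B-Q3_HIFI.gguf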
+""" + +import argparse +import re +import subprocess +import sys +from pathlib import Path +import logging + +# Configuration +PPL_THRESHOLD = 25.0 # Reasonable threshold for 3-bit quantization + +# Need enough text to generate 1024+ tokens for perplexity test +TEST_TEXT = """Once upon a time, there was a little girl named Lily. She loved to play in the garden with her dog Max. +One sunny day, Lily found a shiny red ball under a big tree. She was so happy! She threw the ball for Max to catch. +Max ran very fast and caught the ball in his mouth. Lily clapped her hands and laughed. They played all afternoon. +When the sun started to set, Lily's mom called them inside for dinner. Lily gave Max a big hug and said goodnight. +The next morning, Lily woke up early. She looked out the window and saw it was raining. She felt sad because she could not play outside. +But then Max came to her room with a toy in his mouth. Lily smiled and played with Max inside the house. + +The story of quantum computing begins in the early 1980s when physicist Richard Feynman proposed that quantum mechanical +phenomena could be simulated more efficiently using a quantum computer than a classical one. This idea laid the foundation +for what would become one of the most transformative technologies of the 21st century. Quantum computers leverage the +principles of quantum mechanics, particularly superposition and entanglement, to perform computations that would be +practically impossible for classical computers. + +In a classical computer, information is processed using bits that can be either 0 or 1. However, quantum computers use +quantum bits, or qubits, which can exist in a superposition of both 0 and 1 simultaneously. This property allows quantum +computers to explore many possible solutions at once, potentially solving certain problems exponentially faster than +classical computers. Entanglement, another quantum phenomenon, allows qubits to be correlated in ways that have no +classical counterpart, enabling even more powerful computational capabilities. + +The development of practical quantum computers has been a challenging endeavor. Qubits are extremely fragile and can +lose their quantum properties through a process called decoherence when they interact with their environment. This has +led researchers to explore various physical implementations of qubits, including superconducting circuits, trapped ions, +topological qubits, and photonic systems. Each approach has its own advantages and challenges. + +Major technology companies and research institutions around the world are racing to build more powerful and reliable +quantum computers. IBM, Google, Microsoft, and several startups have made significant progress in recent years. In 2019, +Google announced quantum supremacy, claiming their quantum computer performed a calculation that would take the world's +most powerful classical supercomputer thousands of years. While the significance of this achievement was debated, it +marked an important milestone in the field. + +The potential applications of quantum computing are vast. In cryptography, quantum computers could break many of the +encryption methods that currently protect our digital communications, while also enabling new forms of quantum encryption +that are theoretically unbreakable. In drug discovery and materials science, quantum simulations could help design new +molecules and materials with specific properties. 
Optimization problems in logistics, finance, and machine learning +could also benefit from quantum speedups. + +However, significant challenges remain before quantum computers become practically useful for most applications. Current +quantum computers have limited numbers of qubits and high error rates. Researchers are working on quantum error correction +techniques and building more reliable hardware. The field of quantum software is also developing, with new algorithms and +programming frameworks being created to make quantum computing more accessible. + +The intersection of quantum computing and artificial intelligence is particularly exciting. Quantum machine learning +algorithms could potentially train models faster or find patterns in data that classical algorithms miss. Some researchers +believe that quantum computers might eventually lead to more powerful forms of artificial intelligence, though this remains +speculative. What is clear is that the development of quantum computing represents a fundamental shift in our computational +capabilities that could have profound implications for science, technology, and society. +""" + + +def find_executable(name: str, build_dir: Path) -> Path: + """Find an executable in the build directory.""" + # Check common locations + candidates = [ + build_dir / "bin" / name, + build_dir / "bin" / "Release" / name, + build_dir / "bin" / "Debug" / name, + build_dir / name, + ] + + # Add .exe suffix on Windows + if sys.platform == "win32": + candidates = [Path(str(c) + ".exe") for c in candidates] + candidates + + for candidate in candidates: + if candidate.exists(): + return candidate + + raise FileNotFoundError(f"Could not find {name} in {build_dir}") + + +def run_command(cmd: list, capture_output: bool = True) -> subprocess.CompletedProcess: + """Run a command and return the result.""" + logging.debug("Running: %s", ' '.join(str(c) for c in cmd)) + result = subprocess.run( + cmd, + capture_output=capture_output, + text=True, + ) + return result + + +def extract_ppl(output: str) -> float: + """Extract perplexity value from llama-perplexity output.""" + # Try "Final estimate: PPL = X.XXXX" + match = re.search(r"Final estimate: PPL = ([0-9]+\.[0-9]+)", output) + if match: + return float(match.group(1)) + + # Try just "PPL = X.XXXX" (last occurrence) + matches = re.findall(r"PPL = ([0-9]+\.[0-9]+)", output) + if matches: + return float(matches[-1]) + + raise ValueError(f"Could not extract PPL from output:\n{output}") + + +def main(): + # Configure logging so the info-level progress messages below are visible + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser(description="Test Q3_HIFI quantization") + parser.add_argument("--build-dir", type=Path, default=Path("build"), + help="Build directory containing llama binaries") + parser.add_argument("--model", type=Path, required=True, + help="Path to a Q3_HIFI quantized model (must have dims divisible by 256)") + parser.add_argument("--threshold", type=float, default=PPL_THRESHOLD, + help=f"Maximum acceptable perplexity (default: {PPL_THRESHOLD})") + args = parser.parse_args() + + build_dir = args.build_dir.resolve() + model_path = args.model.resolve() + threshold = args.threshold + + # Find executable + try: + perplexity_exe = find_executable("llama-perplexity", build_dir) + except FileNotFoundError as e: + logging.error("Error: %s", e) + logging.info("Make sure you've built llama.cpp first.") + return 1 + + logging.info("Using perplexity: %s", perplexity_exe) + logging.info("Testing model: %s", model_path) + + if not model_path.exists(): + logging.error("Error: Model not found at %s", model_path) + return 1
+ + logging.info("Model size: %.2f MiB", model_path.stat().st_size / 1024 / 1024) + + # Create test text file + test_text_path = Path("tests") / "test-q3-hifi-text.txt" + test_text_path.parent.mkdir(parents=True, exist_ok=True) + test_text_path.write_text(TEST_TEXT) + + # Run perplexity test with small context + logging.info("=== Running perplexity test ===") + result = run_command([ + str(perplexity_exe), + "-m", str(model_path), + "-f", str(test_text_path), + "-c", "256", # Small context to reduce compute + "--chunks", "2" # Just 2 chunks for quick test + ]) + + output = result.stdout + result.stderr + + if result.returncode != 0: + logging.error("Perplexity test failed:\n%s", output) + return 1 + + # Extract and check PPL + try: + ppl = extract_ppl(output) + except ValueError as e: + logging.error("Error: %s", e) + return 1 + + logging.info("Perplexity: %.4f", ppl) + logging.info("Threshold: %s", threshold) + + if ppl < threshold: + logging.info("Test PASSED: PPL (%.4f) is below threshold (%.4f)", ppl, threshold) + return 0 + else: + logging.error("Test FAILED: PPL (%.4f) exceeds threshold (%.4f)", ppl, threshold) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test-q3-hifi.sh b/tests/test-q3-hifi.sh new file mode 100644 index 00000000000..eb7fda76ffa --- /dev/null +++ b/tests/test-q3-hifi.sh @@ -0,0 +1,109 @@ +#!/usr/bin/env bash +# Test Q3_HIFI quantization format +# This test: +# 1. Uses a pre-quantized Q3_HIFI model +# 2. Runs perplexity test +# 3. Asserts PPL is reasonable (<25) +# +# Usage: +# ./tests/test-q3-hifi.sh <model_path> +# +# Note: Q3_HIFI requires tensor dimensions divisible by 256. +# Small models like stories15M (288 dims) are not compatible. + +set -e + +# Configuration +PPL_THRESHOLD=25.0 +TEST_TEXT="tests/test-q3-hifi-text.txt" + +# Check arguments +if [ -z "$1" ]; then + echo "Usage: $0 <model_path>" + echo "Example: $0 models/Qwen3-1.7B-Q3_HIFI.gguf" + exit 1 +fi + +MODEL_PATH="$1" + +if [ ! -f "$MODEL_PATH" ]; then + echo "Error: Model not found at $MODEL_PATH" + exit 1 +fi + +echo "Testing Q3_HIFI model: $MODEL_PATH" + +# Create test text file if not present +if [ ! -f "$TEST_TEXT" ]; then + echo "Creating test text file..." + cat > "$TEST_TEXT" << 'EOF' +Once upon a time, there was a little girl named Lily. She loved to play in the garden with her dog Max. +One sunny day, Lily found a shiny red ball under a big tree. She was so happy! She threw the ball for Max to catch. +Max ran very fast and caught the ball in his mouth. Lily clapped her hands and laughed. They played all afternoon. +When the sun started to set, Lily's mom called them inside for dinner. Lily gave Max a big hug and said goodnight. +The next morning, Lily woke up early. She looked out the window and saw it was raining. She felt sad because she could not play outside. +But then Max came to her room with a toy in his mouth. Lily smiled and played with Max inside the house. + +The story of quantum computing begins in the early 1980s when physicist Richard Feynman proposed that quantum mechanical +phenomena could be simulated more efficiently using a quantum computer than a classical one. This idea laid the foundation +for what would become one of the most transformative technologies of the 21st century. Quantum computers leverage the +principles of quantum mechanics, particularly superposition and entanglement, to perform computations that would be +practically impossible for classical computers. 
+ +In a classical computer, information is processed using bits that can be either 0 or 1. However, quantum computers use +quantum bits, or qubits, which can exist in a superposition of both 0 and 1 simultaneously. This property allows quantum +computers to explore many possible solutions at once, potentially solving certain problems exponentially faster than +classical computers. Entanglement, another quantum phenomenon, allows qubits to be correlated in ways that have no +classical counterpart, enabling even more powerful computational capabilities. + +The development of practical quantum computers has been a challenging endeavor. Qubits are extremely fragile and can +lose their quantum properties through a process called decoherence when they interact with their environment. This has +led researchers to explore various physical implementations of qubits, including superconducting circuits, trapped ions, +topological qubits, and photonic systems. Each approach has its own advantages and challenges. + +Major technology companies and research institutions around the world are racing to build more powerful and reliable +quantum computers. IBM, Google, Microsoft, and several startups have made significant progress in recent years. In 2019, +Google announced quantum supremacy, claiming their quantum computer performed a calculation that would take the world's +most powerful classical supercomputer thousands of years. While the significance of this achievement was debated, it +marked an important milestone in the field. + +The potential applications of quantum computing are vast. In cryptography, quantum computers could break many of the +encryption methods that currently protect our digital communications, while also enabling new forms of quantum encryption +that are theoretically unbreakable. In drug discovery and materials science, quantum simulations could help design new +molecules and materials with specific properties. Optimization problems in logistics, finance, and machine learning +could also benefit from quantum speedups. +EOF +fi + +# Run perplexity test +echo "Running perplexity test..." 
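+# Small context (-c 256) and just 2 chunks keep this a quick smoke test; increase both for a more thorough check.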
+PPL_OUTPUT=$(./llama-perplexity -m "$MODEL_PATH" -f "$TEST_TEXT" -c 256 --chunks 2 2>&1) + +# Extract final perplexity value +# Format: "Final estimate: PPL = X.XXXX +/- Y.YYYY" +PPL=$(echo "$PPL_OUTPUT" | grep -oP "Final estimate: PPL = \K[0-9]+\.[0-9]+" || echo "") + +if [ -z "$PPL" ]; then + # Try alternate format: just look for the last PPL value + PPL=$(echo "$PPL_OUTPUT" | grep -oP "PPL = \K[0-9]+\.[0-9]+" | tail -1 || echo "") +fi + +if [ -z "$PPL" ]; then + echo "Error: Could not extract perplexity from output" + echo "Output was:" + echo "$PPL_OUTPUT" + exit 1 +fi + +echo "Perplexity: $PPL" +echo "Threshold: $PPL_THRESHOLD" + +# Check if PPL is reasonable (less than threshold) +if (( $(echo "$PPL < $PPL_THRESHOLD" | bc -l) )); then + echo "✅ Test PASSED: PPL ($PPL) is below threshold ($PPL_THRESHOLD)" + exit 0 +else + echo "❌ Test FAILED: PPL ($PPL) exceeds threshold ($PPL_THRESHOLD)" + exit 1 +fi + diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 470dc3d916b..c9b07d5a733 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -43,6 +43,7 @@ static const std::vector QUANT_OPTIONS = { { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 3.41G, +1.6321 ppl @ Llama-3-8B", }, { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.74G, +0.6569 ppl @ Llama-3-8B", }, { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 4.03G, +0.5562 ppl @ Llama-3-8B", }, + { "Q3_HIFI", LLAMA_FTYPE_MOSTLY_Q3_HIFI, " ~4.2 bpw Adaptive: Q3_HIFI on sensitive layers, Q3_K/Q4_K elsewhere", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },