diff --git a/blur.html b/blur.html index dc07f71..1f4568b 100644 --- a/blur.html +++ b/blur.html @@ -29,6 +29,10 @@ display: block; margin-bottom: 10px; } + + .inline label { + display: inline-block + } @@ -53,6 +57,15 @@

Renderer Selection

+
+ Shader Type: + + +
diff --git a/blur.js b/blur.js index 96d8e71..c8e14cf 100644 --- a/blur.js +++ b/blur.js @@ -54,8 +54,9 @@ async function initializeBlurRenderer(webGpuDevice) { const zeroCopy = zeroCopyCheckbox.checked; const directOutput = directOutputCheckbox.checked; const zeroCopyTensor = webnnZeroCopyCheckbox.checked; + const useFragment = shaderTypeFrag.checked; const useWebNN = webnnRadio.checked; - appBlurRenderer = await createWebGPUBlurRenderer(webGpuDevice, segmenter, zeroCopy, directOutput, useWebNN, zeroCopyTensor); + appBlurRenderer = await createWebGPUBlurRenderer(webGpuDevice, segmenter, zeroCopy, directOutput, useWebNN, zeroCopyTensor, useFragment); appStatus.innerText = 'Renderer: WebGPU'; console.log('Using WebGPU for blur rendering'); } else { @@ -145,6 +146,7 @@ async function runInWorker(trackProcessor, trackGenerator) { segmenterType: document.querySelector('input[name="segmenter"]:checked').value, zeroCopy: zeroCopyCheckbox ? zeroCopyCheckbox.checked : false, directOutput: directOutputCheckbox ? directOutputCheckbox.checked : false, + useFragment: shaderTypeFrag ? shaderTypeFrag.checked : false, }; // Transfer the readable and writable streams to the worker for zero-copy data handling. @@ -213,6 +215,10 @@ const zeroCopyCheckbox = document.getElementById('zeroCopy'); const zeroCopyLabel = document.getElementById('zeroCopyLabel'); const directOutputCheckbox = document.getElementById('directOutput'); const directOutputLabel = document.getElementById('directOutputLabel'); +const shaderTypeComputeLabel = document.getElementById('shaderTypeComputeLabel'); +const shaderTypeCompute = document.getElementById('shaderTypeCompute'); +const shaderTypeFragLabel = document.getElementById('shaderTypeFragLabel'); +const shaderTypeFrag = document.getElementById('shaderTypeFrag'); const webnnRadio = document.getElementById('segmenter-webnn'); const webnnZeroCopyCheckbox = document.getElementById('webnnZeroCopy'); const webrtcSink = document.getElementById('webrtcSink'); @@ -305,6 +311,10 @@ function updateOptionState() { zeroCopyLabel.style.color = isWebGPU ? '' : '#aaa'; directOutputCheckbox.disabled = !isWebGPU; directOutputLabel.style.color = isWebGPU ? '' : '#aaa'; + shaderTypeCompute.disabled = !isWebGPU; + shaderTypeComputeLabel.style.color = isWebGPU ? '' : '#aaa'; + shaderTypeFrag.disabled = !isWebGPU; + shaderTypeFragLabel.style.color = isWebGPU ? '' : '#aaa'; const useWebRTCSink = webrtcSink.checked; webrtcCodec.style.display = useWebRTCSink ? 'inline-block' : 'none'; @@ -384,7 +394,7 @@ async function initializeApp() { updateOptionState(); // If the app is running, and a core pipeline option changed, restart the pipeline. - const restartNeededOptions = ['renderer', 'useWorker', 'segmenter', 'zeroCopy', 'directOutput']; + const restartNeededOptions = ['renderer', 'useWorker', 'segmenter', 'zeroCopy', 'directOutput', 'shaderType']; if (isRunning && restartNeededOptions.includes(event.target.name)) { console.log(`Restarting pipeline due to change in '${event.target.name}'`); stopVideoProcessing(); diff --git a/blur4/shaders/blend.fragment.wgsl b/blur4/shaders/blend.fragment.wgsl new file mode 100644 index 0000000..aa8853f --- /dev/null +++ b/blur4/shaders/blend.fragment.wgsl @@ -0,0 +1,39 @@ +struct VertexOut { + @builtin(position) pos: vec4f, + @location(0) uv: vec2f, +} + +@vertex +fn vertMain(@builtin(vertex_index) vertexIndex: u32) -> VertexOut { + var pos = array( + vec2f(-1.0, -1.0), vec2f(1.0, -1.0), vec2f(-1.0, 1.0), vec2f(1.0, 1.0), + ); + var uv = pos[vertexIndex] * vec2f(0.5, -0.5) + 0.5; + return VertexOut(vec4f(pos[vertexIndex], 0, 1), uv); +} + +struct ImageSize { + width : i32, + height : i32, + texel_size : vec2, +}; + +@group(0) @binding(0) var input : ${inputTextureType}; +@group(0) @binding(1) var blurred : texture_2d; +@group(0) @binding(2) var mask : texture_2d; +@group(0) @binding(3) var s : sampler; +@group(0) @binding(4) var size : ImageSize; + +const k00 = f32(${k00}); +const kTileSize = ${tileSize}u; + +@fragment +fn fragMain(@location(0) uv: vec2f) -> @location(0) vec4f { + let coord_norm = uv + (0.5 * size.texel_size); + var m = textureSampleLevel(mask, s, coord_norm, 0.0).r; + var c = ${textureSampleCall}; + var b = textureSampleLevel(blurred, s, coord_norm, 0.0); + b = b + (k00 * m) * c; + b = b / b.a; + return mix(b, c, m); +} diff --git a/blur4/shaders/blur.fragment.wgsl b/blur4/shaders/blur.fragment.wgsl new file mode 100644 index 0000000..baa3a40 --- /dev/null +++ b/blur4/shaders/blur.fragment.wgsl @@ -0,0 +1,78 @@ +struct VertexOut { + @builtin(position) pos: vec4f, + @location(0) uv: vec2f, +} + +@vertex +fn vertMain(@builtin(vertex_index) vertexIndex: u32) -> VertexOut { + var pos = array( + vec2f(-1.0, -1.0), vec2f(1.0, -1.0), vec2f(-1.0, 1.0), vec2f(1.0, 1.0), + ); + var uv = pos[vertexIndex] * vec2f(0.5, -0.5) + 0.5; + return VertexOut(vec4f(pos[vertexIndex], 0, 1), uv); +} + +struct ImageSize { + width : i32, + height : i32, + texel_size : vec2, +}; + +@group(0) @binding(0) var input : ${inputTextureType}; +@group(0) @binding(1) var mask : texture_2d; +@group(0) @binding(2) var s : sampler; +@group(0) @binding(3) var size : ImageSize; + +const kRadius = ${radius}u; +const kTileSize = ${tileSize}u; + +fn blur(sample_coordinate : vec2, dir : vec2, pass_no : i32) -> vec4f { + let sample_coordinate_norm = + (vec2(sample_coordinate) + vec2(0.5)) * size.texel_size; + + var kKernel = array(${kernelInitializer}); + + let alpha = 1.0 - textureSampleLevel(mask, s, sample_coordinate_norm, 0.0).r; + var color = ${textureSampleCall}; + var w = kKernel[0]; + if (pass_no == 0) { + w = w * alpha; + } + color = color * w; + + let step = dir * alpha; + var offset = step; + + var coord : vec2; + for (var i = 1u; i <= kRadius; i=i+1u) { + coord = sample_coordinate_norm + offset; + w = kKernel[i]; + if (pass_no == 0) { + w = w * (1.0 - textureSampleLevel(mask, s, coord, 0.0).r); + } + color = color + (w * ${loopTextureSampleCall}); + + coord = sample_coordinate_norm - offset; + w = kKernel[i]; + if (pass_no == 0) { + w = w * (1.0 - textureSampleLevel(mask, s, coord, 0.0).r); + } + color = color + (w * ${loopTextureSampleCall}); + + offset = offset + step; + } + + return color; +} + +@fragment +fn main_horizontal(@builtin(position) pos: vec4f) -> @location(0) vec4f { + let dir = vec2(size.texel_size.x, 0.0); + return blur(vec2i(pos.xy), dir, 0); +} + +@fragment +fn main_vertical(@builtin(position) pos: vec4f) -> @location(0) vec4f { + let dir = vec2(0.0, size.texel_size.y); + return blur(vec2i(pos.xy), dir, 1); +} diff --git a/blur4/shaders/downscale.fragment.wgsl b/blur4/shaders/downscale.fragment.wgsl new file mode 100644 index 0000000..5b09097 --- /dev/null +++ b/blur4/shaders/downscale.fragment.wgsl @@ -0,0 +1,21 @@ +struct VertexOut { + @builtin(position) pos: vec4f, + @location(0) uv: vec2f, +} + +@vertex +fn vertMain(@builtin(vertex_index) vertexIndex: u32) -> VertexOut { + var pos = array( + vec2f(-1.0, -1.0), vec2f(1.0, -1.0), vec2f(-1.0, 1.0), vec2f(1.0, 1.0), + ); + var uv = pos[vertexIndex] * vec2f(0.5, -0.5) + 0.5; + return VertexOut(vec4f(pos[vertexIndex], 0, 1), uv); +} + +@group(0) @binding(0) var inputTexture: ${inputTextureType}; +@group(0) @binding(1) var textureSampler: sampler; + +@fragment +fn fragMain(@location(0) uv: vec2f) -> @location(0) vec4f { + return textureSampleBaseClampToEdge(inputTexture, textureSampler, uv) + vec4f(0, 1, 0, 0); +} diff --git a/webgpu-blur.js b/webgpu-blur.js index e8b4cfc..f133766 100644 --- a/webgpu-blur.js +++ b/webgpu-blur.js @@ -45,10 +45,11 @@ function adjustSizeByResolution(resolution, width, height) { } export class WebGPUBlur { - constructor(device, zeroCopy, directOutput) { + constructor(device, zeroCopy, directOutput, useFragment) { this.device = device; this.zeroCopy = zeroCopy; this.directOutput = directOutput; + this.useFragment = useFragment; this.pipelines = {}; this.sampler = device.createSampler({ magFilter: 'linear', @@ -70,27 +71,85 @@ export class WebGPUBlur { const kernel = this.calculateKernel(this.radius); const kernelInitializer = kernel.join(', '); - const blurHorizontalShader = await this.getBlurShader(this.radius, this.tileSize, kernel.length, kernelInitializer, true); - const blurHorizontalModule = this.device.createShaderModule({ code: blurHorizontalShader }); - this.pipelines.horizontal = await this.device.createComputePipelineAsync({ - layout: 'auto', - compute: { module: blurHorizontalModule, entryPoint: 'main_horizontal' }, - }); + if (this.useFragment) { + const format = this.directOutput ? navigator.gpu.getPreferredCanvasFormat() : 'rgba8unorm'; - const blurVerticalShader = await this.getBlurShader(this.radius, this.tileSize, kernel.length, kernelInitializer, false); - const blurVerticalModule = this.device.createShaderModule({ code: blurVerticalShader }); - this.pipelines.vertical = await this.device.createComputePipelineAsync({ - layout: 'auto', - compute: { module: blurVerticalModule, entryPoint: 'main_vertical' }, - }); + const blurHorizontalShader = await this.getBlurShader(this.radius, this.tileSize, kernel.length, kernelInitializer, true); + const blurHorizontalModule = this.device.createShaderModule({ code: blurHorizontalShader }); + this.pipelines.horizontal = await this.device.createRenderPipelineAsync({ + label: 'blurHorizontal', + layout: 'auto', + vertex: { + module: blurHorizontalModule, + }, + primitive: { + topology: 'triangle-strip' + }, + fragment: { + module: blurHorizontalModule, + entryPoint: 'main_horizontal', + targets: [{ format }] + }, + }); - const k00 = kernel[0] * kernel[0]; - const blendShader = await this.getBlendShader(k00, this.tileSize); - const blendModule = this.device.createShaderModule({ code: blendShader }); - this.pipelines.blend = await this.device.createComputePipelineAsync({ - layout: 'auto', - compute: { module: blendModule, entryPoint: 'main' }, - }); + const blurVerticalShader = await this.getBlurShader(this.radius, this.tileSize, kernel.length, kernelInitializer, false); + const blurVerticalModule = this.device.createShaderModule({ code: blurVerticalShader }); + this.pipelines.vertical = await this.device.createRenderPipelineAsync({ + label: 'blurVertical', + layout: 'auto', + vertex: { + module: blurVerticalModule, + }, + primitive: { + topology: 'triangle-strip' + }, + fragment: { + module: blurVerticalModule, + entryPoint: 'main_vertical', + targets: [{ format }] + }, + }); + + const k00 = kernel[0] * kernel[0]; + const blendShader = await this.getBlendShader(k00, this.tileSize); + const blendModule = this.device.createShaderModule({ label: 'blend', code: blendShader }); + this.pipelines.blend = await this.device.createRenderPipelineAsync({ + label: 'blend', + layout: 'auto', + vertex: { + module: blendModule, + }, + primitive: { + topology: 'triangle-strip' + }, + fragment: { + module: blendModule, + targets: [{ format }] + }, + }); + } else { + const blurHorizontalShader = await this.getBlurShader(this.radius, this.tileSize, kernel.length, kernelInitializer, true); + const blurHorizontalModule = this.device.createShaderModule({ code: blurHorizontalShader }); + this.pipelines.horizontal = await this.device.createComputePipelineAsync({ + layout: 'auto', + compute: { module: blurHorizontalModule, entryPoint: 'main_horizontal' }, + }); + + const blurVerticalShader = await this.getBlurShader(this.radius, this.tileSize, kernel.length, kernelInitializer, false); + const blurVerticalModule = this.device.createShaderModule({ code: blurVerticalShader }); + this.pipelines.vertical = await this.device.createComputePipelineAsync({ + layout: 'auto', + compute: { module: blurVerticalModule, entryPoint: 'main_vertical' }, + }); + + const k00 = kernel[0] * kernel[0]; + const blendShader = await this.getBlendShader(k00, this.tileSize); + const blendModule = this.device.createShaderModule({ code: blendShader }); + this.pipelines.blend = await this.device.createComputePipelineAsync({ + layout: 'auto', + compute: { module: blendModule, entryPoint: 'main' }, + }); + } } calculateKernel(radius) { @@ -106,7 +165,8 @@ export class WebGPUBlur { } getBlurShader(radius, tileSize, kernelSize, kernelInitializer, isHorizontal) { - return fetch('blur4/shaders/blur.wgsl').then(res => res.text()).then(code => code.replace(/\${(\w+)}/g, (...groups) => ({ + const url = this.useFragment ? 'blur4/shaders/blur.fragment.wgsl' : 'blur4/shaders/blur.wgsl'; + return fetch(url).then(res => res.text()).then(code => code.replace(/\${(\w+)}/g, (...groups) => ({ inputTextureType: (isHorizontal && this.zeroCopy) ? 'texture_external' : 'texture_2d', outputFormat: this.directOutput ? navigator.gpu.getPreferredCanvasFormat() : 'rgba8unorm', radius, @@ -123,7 +183,8 @@ export class WebGPUBlur { } getBlendShader(k00, tileSize) { - return fetch('blur4/shaders/blend.wgsl').then(res => res.text()).then(code => code.replace(/\${(\w+)}/g, (...groups) => ({ + const url = this.useFragment ? 'blur4/shaders/blend.fragment.wgsl' : 'blur4/shaders/blend.wgsl'; + return fetch(url).then(res => res.text()).then(code => code.replace(/\${(\w+)}/g, (...groups) => ({ inputTextureType: this.zeroCopy ? 'texture_external' : 'texture_2d', outputFormat: this.directOutput ? navigator.gpu.getPreferredCanvasFormat() : 'rgba8unorm', k00, @@ -158,54 +219,127 @@ export class WebGPUBlur { device.queue.writeBuffer(blurSizeBuffer, 0, blurSizeData); device.queue.writeBuffer(blurSizeBuffer, 8, blurTexelSizeData); - const horizontalTexture = getOrCreateTexture(device, this.resourceCache, 'horizontal', [blurWidth, blurHeight], this.directOutput, GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING); - const blurredTexture = getOrCreateTexture(device, this.resourceCache, 'blurred', [blurWidth, blurHeight], this.directOutput, GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING); + const textureUsage = (this.useFragment ? GPUTextureUsage.RENDER_ATTACHMENT : GPUTextureUsage.STORAGE_BINDING) | GPUTextureUsage.TEXTURE_BINDING + const horizontalTexture = getOrCreateTexture(device, this.resourceCache, 'horizontal', [blurWidth, blurHeight], this.directOutput, textureUsage); + const blurredTexture = getOrCreateTexture(device, this.resourceCache, 'blurred', [blurWidth, blurHeight], this.directOutput, textureUsage); - const passEncoder = commandEncoder.beginComputePass(); + if (this.useFragment) { + let passEncoder = commandEncoder.beginRenderPass({ + colorAttachments: [{ + view: horizontalTexture.createView(), + loadOp: 'clear', + storeOp: 'store', + }] + }); - const horizontalBindGroup = device.createBindGroup({ - layout: this.pipelines.horizontal.getBindGroupLayout(0), - entries: [ - { binding: 0, resource: this.zeroCopy ? inputTexture : inputTexture.createView() }, - { binding: 1, resource: maskTexture.createView() }, - { binding: 2, resource: horizontalTexture.createView() }, - { binding: 3, resource: this.sampler }, - { binding: 4, resource: { buffer: blurSizeBuffer } }, - ], - }); - passEncoder.setPipeline(this.pipelines.horizontal); - passEncoder.setBindGroup(0, horizontalBindGroup); - passEncoder.dispatchWorkgroups(Math.ceil(blurWidth / this.tileSize), Math.ceil(blurHeight / this.tileSize)); - - const verticalBindGroup = device.createBindGroup({ - layout: this.pipelines.vertical.getBindGroupLayout(0), - entries: [ - { binding: 0, resource: horizontalTexture.createView() }, - { binding: 1, resource: maskTexture.createView() }, - { binding: 2, resource: blurredTexture.createView() }, - { binding: 3, resource: this.sampler }, - { binding: 4, resource: { buffer: blurSizeBuffer } }, - ], - }); - passEncoder.setPipeline(this.pipelines.vertical); - passEncoder.setBindGroup(0, verticalBindGroup); - passEncoder.dispatchWorkgroups(Math.ceil(blurWidth / this.tileSize), Math.ceil(blurHeight / this.tileSize)); - - const blendBindGroup = device.createBindGroup({ - layout: this.pipelines.blend.getBindGroupLayout(0), - entries: [ - { binding: 0, resource: this.zeroCopy ? inputTexture : inputTexture.createView() }, - { binding: 1, resource: blurredTexture.createView() }, - { binding: 2, resource: maskTexture.createView() }, - { binding: 3, resource: outputTexture.createView() }, - { binding: 4, resource: this.sampler }, - { binding: 5, resource: { buffer: imageSizeBuffer } }, - ], - }); - passEncoder.setPipeline(this.pipelines.blend); - passEncoder.setBindGroup(0, blendBindGroup); - passEncoder.dispatchWorkgroups(Math.ceil(width / this.tileSize), Math.ceil(height / this.tileSize)); + const horizontalBindGroup = device.createBindGroup({ + layout: this.pipelines.horizontal.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.zeroCopy ? inputTexture : inputTexture.createView() }, + { binding: 1, resource: maskTexture.createView() }, + { binding: 2, resource: this.sampler }, + { binding: 3, resource: { buffer: blurSizeBuffer } }, + ], + }); + passEncoder.setPipeline(this.pipelines.horizontal); + passEncoder.setBindGroup(0, horizontalBindGroup); + passEncoder.draw(4); + + passEncoder.end(); - passEncoder.end(); + passEncoder = commandEncoder.beginRenderPass({ + colorAttachments: [{ + view: blurredTexture.createView(), + loadOp: 'clear', + storeOp: 'store', + }] + }); + + const verticalBindGroup = device.createBindGroup({ + layout: this.pipelines.vertical.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: horizontalTexture.createView() }, + { binding: 1, resource: maskTexture.createView() }, + { binding: 2, resource: this.sampler }, + { binding: 3, resource: { buffer: blurSizeBuffer } }, + ], + }); + passEncoder.setPipeline(this.pipelines.vertical); + passEncoder.setBindGroup(0, verticalBindGroup); + passEncoder.draw(4); + + passEncoder.end(); + + passEncoder = commandEncoder.beginRenderPass({ + colorAttachments: [{ + view: outputTexture.createView(), + loadOp: 'clear', + storeOp: 'store', + }] + }); + + const blendBindGroup = device.createBindGroup({ + layout: this.pipelines.blend.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.zeroCopy ? inputTexture : inputTexture.createView() }, + { binding: 1, resource: blurredTexture.createView() }, + { binding: 2, resource: maskTexture.createView() }, + { binding: 3, resource: this.sampler }, + { binding: 4, resource: { buffer: imageSizeBuffer } }, + ], + }); + passEncoder.setPipeline(this.pipelines.blend); + passEncoder.setBindGroup(0, blendBindGroup); + passEncoder.draw(6); + + passEncoder.end(); + } else { + const passEncoder = commandEncoder.beginComputePass(); + + const horizontalBindGroup = device.createBindGroup({ + layout: this.pipelines.horizontal.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.zeroCopy ? inputTexture : inputTexture.createView() }, + { binding: 1, resource: maskTexture.createView() }, + { binding: 2, resource: horizontalTexture.createView() }, + { binding: 3, resource: this.sampler }, + { binding: 4, resource: { buffer: blurSizeBuffer } }, + ], + }); + passEncoder.setPipeline(this.pipelines.horizontal); + passEncoder.setBindGroup(0, horizontalBindGroup); + passEncoder.dispatchWorkgroups(Math.ceil(blurWidth / this.tileSize), Math.ceil(blurHeight / this.tileSize)); + + const verticalBindGroup = device.createBindGroup({ + layout: this.pipelines.vertical.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: horizontalTexture.createView() }, + { binding: 1, resource: maskTexture.createView() }, + { binding: 2, resource: blurredTexture.createView() }, + { binding: 3, resource: this.sampler }, + { binding: 4, resource: { buffer: blurSizeBuffer } }, + ], + }); + passEncoder.setPipeline(this.pipelines.vertical); + passEncoder.setBindGroup(0, verticalBindGroup); + passEncoder.dispatchWorkgroups(Math.ceil(blurWidth / this.tileSize), Math.ceil(blurHeight / this.tileSize)); + + const blendBindGroup = device.createBindGroup({ + layout: this.pipelines.blend.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.zeroCopy ? inputTexture : inputTexture.createView() }, + { binding: 1, resource: blurredTexture.createView() }, + { binding: 2, resource: maskTexture.createView() }, + { binding: 3, resource: outputTexture.createView() }, + { binding: 4, resource: this.sampler }, + { binding: 5, resource: { buffer: imageSizeBuffer } }, + ], + }); + passEncoder.setPipeline(this.pipelines.blend); + passEncoder.setBindGroup(0, blendBindGroup); + passEncoder.dispatchWorkgroups(Math.ceil(width / this.tileSize), Math.ceil(height / this.tileSize)); + + passEncoder.end(); + } } } diff --git a/webgpu-renderer.js b/webgpu-renderer.js index d7db1ae..5b530ca 100644 --- a/webgpu-renderer.js +++ b/webgpu-renderer.js @@ -5,13 +5,14 @@ import { WebGPUBlur } from './webgpu-blur.js'; class WebGPURenderer { - constructor(device, segmenter, blurrer, { zeroCopy, directOutput }, {useWebNN, zeroCopyTensor}) { + constructor(device, segmenter, blurrer, { zeroCopy, directOutput, useFragment }, {useWebNN, zeroCopyTensor}) { console.log("createWebGPUBlurRenderer", { zeroCopy, directOutput }); this.device = device; this.segmenter = segmenter; this.blurrer = blurrer; this.zeroCopy = zeroCopy; this.directOutput = directOutput; + this.useFragment = useFragment; this.useWebNN = useWebNN; this.zeroCopyTensor = zeroCopyTensor; @@ -33,7 +34,9 @@ class WebGPURenderer { this.segmentationHeight = 144; this.downscaledImageData = new ImageData(this.segmentationWidth, this.segmentationHeight); let downscaleShaderUrl; - if (this.useWebNN && this.zeroCopyTensor) { + if (this.useFragment) { + downscaleShaderUrl = 'blur4/shaders/downscale.fragment.wgsl'; + } else if (this.useWebNN && this.zeroCopyTensor) { downscaleShaderUrl = 'blur4/shaders/downscale-and-convert-to-rgb16float.wgsl'; } else { downscaleShaderUrl = 'blur4/shaders/downscale-and-convert-to-rgba8unorm.wgsl'; @@ -44,10 +47,30 @@ class WebGPURenderer { inputTextureType: zeroCopy ? "texture_external" : "texture_2d", }[groups[1]])) })); - this.downscalePipeline = this.downscaleModule.then(module => this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module, entryPoint: 'main' }, - })); + + if (this.useFragment) { + this.downscalePipeline = this.downscaleModule.then(module => this.device.createRenderPipeline({ + label: 'downscale', + layout: 'auto', + vertex: { + module: module, + }, + primitive: { + topology: 'triangle-strip' + }, + fragment: { + module: module, + targets: [{ + format: 'rgba8unorm' + }] + } + })); + } else { + this.downscalePipeline = this.downscaleModule.then(module => this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module, entryPoint: 'main' }, + })); + } this.downscaleSampler = this.device.createSampler({ magFilter: 'linear', minFilter: 'linear', @@ -78,7 +101,7 @@ class WebGPURenderer { const cacheKey = `${key}_${width}x${height}_${format}_${usage}`; let texture = this.resourceCache[cacheKey]; - if (!texture || texture.width !== width || texture.height !== height) { + if (!texture || texture.width !== width || texture.height !== height || texture.usage !== usage) { console.log("Creating new texture", cacheKey, "with format", format, "and usage", usage); if (texture) { console.log("Destroying old texture", cacheKey); @@ -91,7 +114,7 @@ class WebGPURenderer { } async getOutputRendererFragmentShader(device, width, height) { - if (!this.outputRendererFragmentShader || this.lastDim !== [width, height]) { + if (!this.outputRendererFragmentShader || this.lastDim[0] != width || this.lastDim[1] != height) { this.outputRendererFragmentShader = await this.device.createShaderModule({ code: await fetch('blur4/shaders/render.fragment.wgsl').then(res => res.text()) .then(code => code.replace(/\${(\w+)}/g, (...groups) => ({ width, height }[groups[1]]))), @@ -110,25 +133,49 @@ class WebGPURenderer { } else { destTexture = this.getOrCreateTexture('downscaleDest', { size: [this.segmentationWidth, this.segmentationHeight, 1], - usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.COPY_DST | GPUTextureUsage.COPY_SRC | GPUTextureUsage.TEXTURE_BINDING, + usage: (this.useFragment ? GPUTextureUsage.RENDER_ATTACHMENT : GPUTextureUsage.STORAGE_BINDING) | GPUTextureUsage.COPY_DST | GPUTextureUsage.COPY_SRC | GPUTextureUsage.TEXTURE_BINDING, format: 'rgba8unorm', }); } - const downscaleBindGroup = this.device.createBindGroup({ - layout: (await this.downscalePipeline).getBindGroupLayout(0), - entries: [ - { binding: 0, resource: this.zeroCopy ? sourceTexture : sourceTexture.createView() }, - { binding: 1, resource: this.downscaleSampler }, - { binding: 2, resource: useInterop ? destTexture : destTexture.createView() }, - ], - }); const commandEncoder = this.device.createCommandEncoder(); - const computePass = commandEncoder.beginComputePass(); - computePass.setPipeline(await this.downscalePipeline); - computePass.setBindGroup(0, downscaleBindGroup); - computePass.dispatchWorkgroups(Math.ceil(this.segmentationWidth / 8), Math.ceil(this.segmentationHeight / 8)); - computePass.end(); + + if (this.useFragment) { + const downscaleBindGroup = this.device.createBindGroup({ + layout: (await this.downscalePipeline).getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.zeroCopy ? sourceTexture : sourceTexture.createView() }, + { binding: 1, resource: this.downscaleSampler }, + ], + }); + + const renderPass = commandEncoder.beginRenderPass({ + colorAttachments: [{ + view: destTexture.createView(), + loadOp: 'clear', + storeOp: 'store' + }] + }); + renderPass.setPipeline(await this.downscalePipeline); + renderPass.setBindGroup(0, downscaleBindGroup); + renderPass.draw(4); + renderPass.end(); + } else { + const downscaleBindGroup = this.device.createBindGroup({ + layout: (await this.downscalePipeline).getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.zeroCopy ? sourceTexture : sourceTexture.createView() }, + { binding: 1, resource: this.downscaleSampler }, + { binding: 2, resource: useInterop ? destTexture : destTexture.createView() }, + ], + }); + + const computePass = commandEncoder.beginComputePass(); + computePass.setPipeline(await this.downscalePipeline); + computePass.setBindGroup(0, downscaleBindGroup); + computePass.dispatchWorkgroups(Math.ceil(this.segmentationWidth / 8), Math.ceil(this.segmentationHeight / 8)); + computePass.end(); + } if (useInterop) { this.device.queue.submit([commandEncoder.finish()]); @@ -231,7 +278,7 @@ class WebGPURenderer { const canvasTexture = this.context.getCurrentTexture(); const outputTexture = this.getOrCreateTexture('outputTexture', { size: [width, height, 1], - usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.COPY_SRC | GPUTextureUsage.TEXTURE_BINDING + usage: (this.useFragment ? GPUTextureUsage.RENDER_ATTACHMENT : GPUTextureUsage.STORAGE_BINDING) | GPUTextureUsage.COPY_SRC | GPUTextureUsage.TEXTURE_BINDING }); const commandEncoder = this.device.createCommandEncoder(); @@ -299,12 +346,20 @@ export async function getWebGPUDevice() { // Ensure we're compatible with directOutput console.log("Adapter features:"); console.log([...adapter.features]); - const requiredFeatures = ['bgra8unorm-storage', 'shader-f16', 'texture-formats-tier1']; + const requiredFeatures = ['bgra8unorm-storage', 'shader-f16']; for (const feature of requiredFeatures) { if (!adapter.features.has(feature)) { console.log(`${feature} is not supported`); } } + // Optional features that we would nevertheless like to have. + const desiredFeatures = ['texture-formats-tier1']; + for (const feature of desiredFeatures) { + if (adapter.features.has(feature)) { + requiredFeatures.push(feature); + } + } + const device = await adapter.requestDevice({ requiredFeatures }); if (!device) { console.error('WebGPU adapter does not support the required features:', requiredFeatures); @@ -314,9 +369,9 @@ export async function getWebGPUDevice() { } // WebGPU blur renderer -export async function createWebGPUBlurRenderer(device, segmenter, zeroCopy, directOutput, useWebNN, zeroCopyTensor) { - const blurrer = new WebGPUBlur(device, zeroCopy, directOutput); +export async function createWebGPUBlurRenderer(device, segmenter, zeroCopy, directOutput, useWebNN, zeroCopyTensor, useFragment) { + const blurrer = new WebGPUBlur(device, zeroCopy, directOutput, useFragment); await blurrer.init(); - return new WebGPURenderer(device, segmenter, blurrer, { zeroCopy, directOutput }, {useWebNN, zeroCopyTensor}); + return new WebGPURenderer(device, segmenter, blurrer, { zeroCopy, directOutput, useFragment }, {useWebNN, zeroCopyTensor}); }