Introduction
Running machine learning models in the browser has traditionally been limited by CPU performance and WebGL's constraints. WebGPU changes this by providing direct access to the GPU's compute capabilities. With WebGPU, you can run large language models, image classifiers, and other ML workloads at near-native performance—entirely in the browser, with no server required.
This guide covers running ML models with WebGPU, using ONNX Runtime Web and Transformers.js, and building real-time AI applications.
WebGPU for ML: Why It Matters
WebGL vs WebGPU for ML
| Feature | WebGL | WebGPU |
|---|---|---|
| Compute shaders | No (graphics only) | Yes |
| Memory model | Texture-based | Buffer-based |
| Data types | Limited (float32 mainly) | f32, i32, u32; f16 with the optional shader-f16 feature |
| Pipeline control | Limited | Full control |
| Performance | Good for graphics | Excellent for compute |
| Tensor operations | Workarounds needed | Native compute shaders |
Key Advantages for ML
- Compute shaders: Run arbitrary computations on the GPU
- Storage buffers: Read and write large tensors directly, without packing data into textures
- Reduced-precision math: Optional f16 support (the shader-f16 feature) speeds up matrix operations on hardware that provides it
- Workgroup memory: Cache frequently accessed data on the GPU
ONNX Runtime Web
ONNX Runtime Web runs ONNX models using WebGPU as the execution provider.
Setup
npm install onnxruntime-web
Running a Model
import * as ort from 'onnxruntime-web';

async function runModel() {
  // Thread count for the WASM backend (used when WebGPU cannot run an operator)
  ort.env.wasm.numThreads = navigator.hardwareConcurrency;
  // Request WebGPU as the execution provider
  const session = await ort.InferenceSession.create('./model.onnx', {
    executionProviders: ['webgpu'],
  });
  // Create input tensor
  const inputData = new Float32Array(1 * 3 * 224 * 224); // NCHW format
  const inputTensor = new ort.Tensor('float32', inputData, [1, 3, 224, 224]);
  // Run inference; 'input' and 'output' must match the tensor names in your model
  const results = await session.run({ input: inputTensor });
  const output = results.output.data;
  console.log('Predictions:', output);
}
Image Classification Example
async function classifyImage(imageElement) {
const session = await ort.InferenceSession.create('./mobilenet.onnx', {
executionProviders: ['webgpu'],
});
// Preprocess image
const canvas = document.createElement('canvas');
canvas.width = 224;
canvas.height = 224;
const ctx = canvas.getContext('2d');
ctx.drawImage(imageElement, 0, 0, 224, 224);
const imageData = ctx.getImageData(0, 0, 224, 224);
const { data } = imageData;
// Convert to NCHW format with normalization
const input = new Float32Array(1 * 3 * 224 * 224);
for (let i = 0; i < 224 * 224; i++) {
input[i] = (data[i * 4] / 255 - 0.485) / 0.229; // R
input[i + 224 * 224] = (data[i * 4 + 1] / 255 - 0.456) / 0.224; // G
input[i + 2 * 224 * 224] = (data[i * 4 + 2] / 255 - 0.406) / 0.225; // B
}
const inputTensor = new ort.Tensor('float32', input, [1, 3, 224, 224]);
const results = await session.run({ input: inputTensor });
// Get top predictions
const predictions = Array.from(results.output.data)
.map((score, index) => ({ score, index }))
.sort((a, b) => b.score - a.score)
.slice(0, 5);
return predictions;
}
Transformers.js
Hugging Face's Transformers.js provides a high-level API for running transformer models in the browser.
Setup
npm install @huggingface/transformers
Text Generation
import { pipeline } from '@huggingface/transformers';
async function generateText() {
const generator = await pipeline(
'text-generation',
'Xenova/distilgpt2',
{ device: 'webgpu' }
);
const result = await generator('The future of AI is', {
max_new_tokens: 50,
do_sample: true, // temperature only takes effect when sampling is enabled
temperature: 0.7,
});
console.log(result[0].generated_text);
}
Image Classification
import { pipeline } from '@huggingface/transformers';
async function classifyImage(imageUrl) {
const classifier = await pipeline(
'image-classification',
'Xenova/vit-base-patch16-224',
{ device: 'webgpu' }
);
const result = await classifier(imageUrl);
console.log(result);
// [{ label: 'Egyptian cat', score: 0.95 }, ...]
}
Text Embeddings
import { pipeline } from '@huggingface/transformers';
async function getEmbeddings(texts) {
const extractor = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2',
{ device: 'webgpu' }
);
const embeddings = await extractor(texts, {
pooling: 'mean',
normalize: true,
});
return embeddings;
}
Custom Compute Shaders
For custom ML operations, write WGSL compute shaders directly.
Matrix Multiplication
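// matmul.wgsl
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  let row = id.x;
  let col = id.y;
  // Dimensions are hard-coded for this example: a is M x K, b is K x N
  let M = 128u;
  let N = 128u;
  let K = 128u;
  // Skip invocations that fall outside the result matrix
  if (row >= M || col >= N) {
    return;
  }
  var sum = 0.0;
  for (var k = 0u; k < K; k++) {
    sum += a[row * K + k] * b[k * N + col];
  }
  result[row * N + col] = sum;
}

// Host-side JavaScript. a and b are Float32Arrays; M, N, K must match the
// values hard-coded in the shader (128 here).
async function matrixMultiply(device, a, b, M, N, K) {
  const shaderModule = device.createShaderModule({
    code: await (await fetch('./matmul.wgsl')).text(),
  });
  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: { module: shaderModule, entryPoint: 'main' },
  });
  // Upload the inputs into storage buffers
  const inputUsage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST;
  const aBuffer = device.createBuffer({ size: a.byteLength, usage: inputUsage });
  const bBuffer = device.createBuffer({ size: b.byteLength, usage: inputUsage });
  device.queue.writeBuffer(aBuffer, 0, a);
  device.queue.writeBuffer(bBuffer, 0, b);
  // The result is written on the GPU, then copied into a mappable buffer
  const resultSize = M * N * 4;
  const resultBuffer = device.createBuffer({ size: resultSize, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC });
  const readBuffer = device.createBuffer({ size: resultSize, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
  const bindGroup = device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: aBuffer } },
      { binding: 1, resource: { buffer: bBuffer } },
      { binding: 2, resource: { buffer: resultBuffer } },
    ],
  });
  const commandEncoder = device.createCommandEncoder();
  const pass = commandEncoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // Round up so partial 16x16 tiles are still covered
  pass.dispatchWorkgroups(Math.ceil(M / 16), Math.ceil(N / 16));
  pass.end();
  commandEncoder.copyBufferToBuffer(resultBuffer, 0, readBuffer, 0, resultSize);
  device.queue.submit([commandEncoder.finish()]);
  // Read the result back to the CPU and free the GPU buffers
  await readBuffer.mapAsync(GPUMapMode.READ);
  const result = new Float32Array(readBuffer.getMappedRange().slice(0));
  readBuffer.unmap();
  for (const buffer of [aBuffer, bBuffer, resultBuffer, readBuffer]) buffer.destroy();
  return result;
}
Performance Optimization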
Batch Processing
async function batchInference(session, images) {
const batchSize = 8;
const results = [];
for (let i = 0; i < images.length; i += batchSize) {
const batch = images.slice(i, i + batchSize);
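// createBatchTensor is a placeholder: it should stack the preprocessed images
// into a single tensor of shape [batch.length, 3, 224, 224]
const batchTensor = createBatchTensor(batch);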
const batchResults = await session.run({ input: batchTensor });
results.push(...batchResults.output.data);
}
return results;
}
Model Quantization
Use quantized models (int8, fp16) for faster inference and smaller downloads.
// Use quantized model
const session = await ort.InferenceSession.create('./model_quantized.onnx', {
executionProviders: ['webgpu'],
graphOptimizationLevel: 'all',
});
Best Practices
- Check WebGPU support: Not all browsers support WebGPU yet
- Use quantized models: int8/fp16 models are faster and smaller
- Warm up the GPU: First inference is slower due to shader compilation
- Batch inputs: Process multiple inputs for better GPU utilization
- Show progress: ML inference can take time—show loading states
Common Pitfalls
| Pitfall | Impact | Solution |
|---|---|---|
| No WebGPU support check | App crashes | Use feature detection |
| Large model downloads | Slow initial load | Use quantized models |
| No warm-up | First inference is slow | Run a dummy inference first |
| Memory leaks | GPU memory exhaustion | Dispose tensors after use |
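The "dispose tensors after use" row can be sketched as a small wrapper around `session.run`. This is a minimal sketch, assuming the output tensors are CPU-located (so `data` is readable) and expose a `dispose()` method, as `ort.Tensor` does in recent ONNX Runtime Web releases. The helper name `runAndRelease` is hypothetical.

```javascript
// Hypothetical helper: run inference, copy each output into a plain typed
// array, then release the output tensors so their memory can be reclaimed.
async function runAndRelease(session, feeds) {
  const results = await session.run(feeds);
  const copies = {};
  for (const [name, tensor] of Object.entries(results)) {
    // slice() copies the values so they stay valid after dispose()
    copies[name] = tensor.data.slice();
    if (typeof tensor.dispose === 'function') tensor.dispose();
  }
  return copies;
}
```

Input tensors created by the caller can be disposed the same way once the call returns.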
Framework Comparison for Web ML
| Framework | Model Format | WebGPU Support | Ease of Use | Performance |
|---|---|---|---|---|
| ONNX Runtime Web | ONNX | Native EP | Medium | Excellent |
| Transformers.js | HF/ONNX | Automatic | High | Very Good |
| TensorFlow.js | TF.js Graph/Layers model (converted from SavedModel or Keras) | Backend | High | Good |
| MediaPipe | Custom | WebGPU delegate | High | Very Good |
| WebNN (upcoming) | Multiple | Native backend | Medium | Excellent |
Model Architecture Considerations
Not all machine learning models are equally suited for WebGPU inference. Models with many small matrix multiplications benefit less from GPU acceleration than models with large matrix operations. Convolutional neural networks with large feature maps are ideal candidates because the convolution operation is highly parallelizable. Transformer models with large attention matrices also benefit significantly from GPU acceleration. However, models with many sequential operations or complex control flow may not see as much improvement because the GPU excels at parallel work, not sequential processing. Profile your model on the target hardware to determine whether WebGPU acceleration provides meaningful speedup compared to optimized CPU inference.
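One rough way to see why small matrix multiplications benefit less is to compare arithmetic work against memory traffic. The sketch below is a back-of-the-envelope estimate, not a profiler: it assumes each matrix is read or written exactly once and ignores caching and dispatch overhead. The function name `matmulIntensity` is illustrative.

```javascript
// Estimate FLOPs and bytes moved for C = A (M x K) * B (K x N).
function matmulIntensity(M, N, K, bytesPerElement = 4) {
  const flops = 2 * M * N * K; // one multiply and one add per term
  const bytes = bytesPerElement * (M * K + K * N + M * N); // read A and B, write C
  return { flops, bytes, intensity: flops / bytes };
}

const small = matmulIntensity(8, 8, 8);
const large = matmulIntensity(1024, 1024, 1024);
console.log(small.intensity.toFixed(2), large.intensity.toFixed(2));
```

The large case performs far more arithmetic per byte moved, which is the regime where a GPU's parallel units stay busy. Real measurements on target hardware remain the deciding factor.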
ONNX Runtime Web
ONNX Runtime Web is a runtime for running ONNX models in the browser, with WebGPU as one of its execution providers. It handles model loading, graph optimization, GPU buffer management, and inference execution; input preprocessing remains the application's job. Execution providers are tried in the order you list them in executionProviders, so listing ['webgpu', 'wasm'] falls back to the WASM (CPU) backend when WebGPU is not available. Use the ort.InferenceSession.create method to load a model and the session.run method to perform inference. ONNX Runtime Web supports a wide range of ONNX operators, though coverage differs between the WebGPU and WASM backends, so test quantized models on the backend you plan to ship.
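Input and output tensor names differ between models, so hard-coding `input` and `output` only works when the model happens to use those names. A minimal sketch that reads the names from the session's `inputNames` and `outputNames` arrays instead; the helper names `buildFeeds` and `runFirstOutput` are hypothetical.

```javascript
// Map tensors to the model's declared input names, in declaration order.
function buildFeeds(session, tensors) {
  const names = session.inputNames;
  if (tensors.length !== names.length) {
    throw new Error(`Model expects ${names.length} input(s), got ${tensors.length}`);
  }
  return Object.fromEntries(names.map((name, i) => [name, tensors[i]]));
}

// Run the model and return its first declared output.
async function runFirstOutput(session, tensors) {
  const results = await session.run(buildFeeds(session, tensors));
  return results[session.outputNames[0]];
}
```

With this in place, `await runFirstOutput(session, [inputTensor])` works regardless of how the exporter named the tensors.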
Privacy and Security Benefits
Running machine learning models on the client using WebGPU provides significant privacy and security benefits. User data never leaves the device, eliminating the risk of data breaches during transmission or storage on the server. This is particularly important for applications processing sensitive data like medical images, personal documents, or biometric data. Client-side inference also reduces server costs because the computational work is distributed across user devices. However, be aware that model weights are downloaded to the client and could potentially be extracted, so this approach is not suitable for proprietary models that need to remain confidential.
Quantization and Model Optimization
Quantization reduces model size and improves inference speed by using lower-precision data types. Converting model weights from float32 to float16 halves memory usage, usually with minimal accuracy loss. int8 quantization compresses weights to a quarter of their float32 size, though static quantization requires calibration data to maintain accuracy. WebGPU supports float16 through the f16 WGSL type, but only on adapters that expose the optional shader-f16 feature, so check for it before loading a half-precision model. Select the precision at load time based on the hardware capabilities of the device. Combine quantization with model pruning to remove unnecessary weights and reduce the computational cost of inference. These optimizations are essential for deploying machine learning models on mobile devices with limited memory and processing power.
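To make the int8 trade-off concrete, here is a minimal sketch of asymmetric (affine) quantization for a single weight array. It is an illustration of the arithmetic only: real toolchains quantize per tensor or per channel and use calibration data to pick activation ranges. The function names are hypothetical.

```javascript
// Quantize float values to int8 using a scale and zero point derived
// from the observed value range.
function quantizeInt8(values) {
  let min = Infinity;
  let max = -Infinity;
  for (const v of values) {
    if (v < min) min = v;
    if (v > max) max = v;
  }
  // Include 0 in the range so that 0.0 is exactly representable
  min = Math.min(min, 0);
  max = Math.max(max, 0);
  const scale = (max - min) / 255 || 1; // avoid a zero scale for all-zero input
  const zeroPoint = Math.round(-128 - min / scale);
  const q = new Int8Array(values.length);
  for (let i = 0; i < values.length; i++) {
    const x = Math.round(values[i] / scale) + zeroPoint;
    q[i] = Math.max(-128, Math.min(127, x)); // clamp to the int8 range
  }
  return { q, scale, zeroPoint };
}

// Recover approximate float values from the quantized representation.
function dequantizeInt8({ q, scale, zeroPoint }) {
  return Float32Array.from(q, (x) => (x - zeroPoint) * scale);
}
```

Each value is stored in one byte instead of four, and the round-trip error per value is bounded by roughly one quantization step (`scale`). A wider value range means a larger step, which is why outliers and poorly chosen ranges hurt accuracy.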
Real-Time Inference Applications
WebGPU enables real-time machine learning inference for interactive applications. Implement real-time object detection using a webcam feed processed through a YOLO or SSD model running on the GPU. Build a real-time style transfer application that applies artistic styles to video frames as they are captured. Create a hand gesture recognition system that processes hand landmarks detected by a pose estimation model. Implement real-time speech recognition by running audio through a Whisper model accelerated by WebGPU. These applications require careful optimization to maintain interactive frame rates, including batching inference requests, using efficient model architectures, and pre-processing data on the GPU.
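For webcam-style workloads, one common way to stay interactive is to drop frames that arrive while the previous inference is still running, rather than queueing them. A minimal sketch of that idea; `createFrameProcessor` and the `infer` callback are hypothetical names, and `infer` stands in for whatever preprocess-and-run function the application uses.

```javascript
// Wrap an async inference function so that frames arriving while a previous
// inference is still in flight are skipped instead of queued.
function createFrameProcessor(infer) {
  let busy = false;
  let dropped = 0;
  async function process(frame) {
    if (busy) {
      dropped++;
      return null; // caller keeps showing the previous result
    }
    busy = true;
    try {
      return await infer(frame);
    } finally {
      busy = false;
    }
  }
  return {
    process,
    get dropped() { return dropped; },
  };
}
```

In a browser, `process` would typically be called from a `requestAnimationFrame` or `requestVideoFrameCallback` loop, and a `null` result means "no new prediction for this frame".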
Framework Support
Several frameworks support WebGPU-based machine learning inference. TensorFlow.js includes a WebGPU backend that accelerates tensor operations using compute shaders. ONNX Runtime Web provides a WebGPU execution provider for running ONNX models. Transformers.js supports running Hugging Face models in the browser with WebGPU acceleration. The WebNN API, currently in development, will provide a standardized neural network API that can use WebGPU as a backend. Evaluate these frameworks based on your model format, performance requirements, and target browsers. The ecosystem is rapidly evolving, so stay updated on new releases and features that improve WebGPU-based machine learning on the web.
Handling Different Hardware
WebGPU performance varies significantly across hardware. High-end desktop GPUs can run large models in milliseconds, while mobile GPUs may struggle with the same models. Implement hardware detection to select appropriate model sizes and precision levels for each device. Use the GPUAdapter limits and features properties to query device capabilities, including maximum buffer sizes, texture dimensions, and compute workgroup limits; the info property reports the vendor and architecture. Provide fallback paths for devices that do not support specific features like float16 or large compute workgroups. Test your application on a range of hardware including desktop GPUs, integrated graphics, and mobile devices to ensure a good experience for all users.
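Hardware-based selection can be sketched as a lookup over a table of model variants, using the adapter's `features` set and `limits` object. The variant table below is entirely hypothetical (file names and sizes are placeholders), and comparing the largest tensor against `maxStorageBufferBindingSize` is a simplifying assumption about what constrains a given runtime.

```javascript
// Hypothetical variants, ordered from most to least preferred.
const VARIANTS = [
  { name: 'model-fp16.onnx', needs: ['shader-f16'], largestTensorBytes: 256 * 1024 * 1024 },
  { name: 'model-fp32-small.onnx', needs: [], largestTensorBytes: 64 * 1024 * 1024 },
];

// Return the first variant whose required features and buffer size the
// adapter can satisfy, or null if none fit.
function selectVariant(features, limits, variants = VARIANTS) {
  const match = variants.find((v) =>
    v.needs.every((f) => features.has(f)) &&
    v.largestTensorBytes <= limits.maxStorageBufferBindingSize
  );
  return match ?? null;
}
```

In the browser this would be called as `selectVariant(adapter.features, adapter.limits)` after `navigator.gpu.requestAdapter()` resolves; a `null` result is the cue to fall back to a non-WebGPU path.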
Model Conversion Pipeline
Preparing models for WebGPU inference requires a conversion pipeline. Start with a model trained in PyTorch, TensorFlow, or another framework. Export the model to ONNX format, which provides a standardized representation of the model architecture and weights. Use ONNX Simplifier to optimize the model graph by removing redundant nodes and fusing operations. Quantize the model to reduce size and improve inference speed. Test the converted model against the original to verify that accuracy is preserved. Automate this pipeline using scripts so that model updates can be deployed to the web application quickly and reliably.
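The verification step can be sketched as a comparison between reference outputs (for example, logits saved from the original framework as JSON, which is an assumption about your pipeline) and the outputs of the converted model on the same input. The function name `compareOutputs` is hypothetical.

```javascript
// Compare two output vectors: largest absolute difference and whether the
// top-1 class index agrees.
function compareOutputs(reference, candidate) {
  if (reference.length !== candidate.length) {
    throw new Error('Output lengths differ');
  }
  let maxAbsDiff = 0;
  for (let i = 0; i < reference.length; i++) {
    maxAbsDiff = Math.max(maxAbsDiff, Math.abs(reference[i] - candidate[i]));
  }
  const argmax = (arr) => arr.reduce((best, v, i) => (v > arr[best] ? i : best), 0);
  return { maxAbsDiff, sameTop1: argmax(reference) === argmax(candidate) };
}
```

Running this over a handful of representative inputs after each conversion or quantization step gives an early signal when a change has shifted predictions; an acceptable `maxAbsDiff` threshold depends on the model and has to be chosen per project.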
Browser Support and Feature Detection
WebGPU shipped in Chrome 113 and Edge 113 on desktop. Firefox and Safari have begun shipping WebGPU more recently, with availability depending on version and platform, so check current compatibility data rather than assuming support. Before deploying WebGPU-based ML applications, implement robust feature detection:
async function checkWebGPUSupport() {
if (!navigator.gpu) {
return { supported: false, reason: 'WebGPU not available in this browser' };
}
try {
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
return { supported: false, reason: 'No suitable GPU adapter found' };
}
const features = [...adapter.features];
const limits = adapter.limits;
return {
supported: true,
features,
maxBufferSize: limits.maxBufferSize,
maxStorageBufferBindingSize: limits.maxStorageBufferBindingSize,
maxComputeWorkgroupsPerDimension: limits.maxComputeWorkgroupsPerDimension,
};
} catch (error) {
return { supported: false, reason: error.message };
}
}
For applications that must work across all browsers, implement a graceful degradation strategy. Use WebGPU as the primary backend with WASM as the fallback for browsers without WebGPU support. ONNX Runtime Web handles this automatically when you specify multiple execution providers:
const session = await ort.InferenceSession.create('./model.onnx', {
executionProviders: ['webgpu', 'wasm'],
});
The runtime will attempt WebGPU first and fall back to WASM if WebGPU is unavailable. This ensures your application works everywhere while taking advantage of GPU acceleration where available.
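When the application needs to know which backend ended up active (for example, to pick a smaller model on WASM), the fallback can also be made explicit. A minimal sketch that tries one provider at a time; the `ort` module is passed in as a parameter, and the helper name `createSessionWithFallback` is hypothetical.

```javascript
// Try each execution provider in order and report which one succeeded.
async function createSessionWithFallback(ort, model, providers = ['webgpu', 'wasm']) {
  const errors = [];
  for (const provider of providers) {
    try {
      const session = await ort.InferenceSession.create(model, {
        executionProviders: [provider],
      });
      return { session, provider };
    } catch (err) {
      errors.push(`${provider}: ${err.message}`);
    }
  }
  throw new Error(`No execution provider could load the model (${errors.join('; ')})`);
}
```

The returned `provider` string can then drive UI hints or model selection, at the cost of a second load attempt when the first provider fails.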
Building a Complete ML Pipeline
A production ML pipeline in the browser involves more than just running inference. You need to handle model loading, preprocessing, postprocessing, and result visualization. Here is a complete pipeline for real-time image classification:
class MLPipeline {
constructor(modelUrl, labelsUrl) {
this.modelUrl = modelUrl;
this.labelsUrl = labelsUrl;
this.session = null;
this.labels = null;
}
async initialize() {
const [sessionResponse, labelsResponse] = await Promise.all([
fetch(this.modelUrl),
fetch(this.labelsUrl),
]);
const modelBuffer = await sessionResponse.arrayBuffer();
this.labels = await labelsResponse.json();
this.session = await ort.InferenceSession.create(modelBuffer, {
executionProviders: ['webgpu', 'wasm'],
});
// Warm up the GPU with a dummy inference
const dummyInput = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
await this.session.run({ input: dummyInput });
}
async classify(imageData) {
const tensor = this.preprocess(imageData);
const results = await this.session.run({ input: tensor });
return this.postprocess(results);
}
preprocess(imageData) {
const { data, width, height } = imageData;
const float32Data = new Float32Array(3 * width * height);
for (let i = 0; i < width * height; i++) {
float32Data[i] = (data[i * 4] / 255 - 0.485) / 0.229;
float32Data[i + width * height] = (data[i * 4 + 1] / 255 - 0.456) / 0.224;
float32Data[i + 2 * width * height] = (data[i * 4 + 2] / 255 - 0.406) / 0.225;
}
return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
}
postprocess(results) {
const output = results.output.data;
const softmax = this.softmax(output);
return softmax
.map((score, index) => ({ label: this.labels[index], score }))
.sort((a, b) => b.score - a.score)
.slice(0, 5);
}
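softmax(arr) {
// Copy into a plain Array first: TypedArray.map() can only produce numbers,
// so postprocess() could not map a Float32Array to { label, score } objects.
const values = Array.from(arr);
const max = Math.max(...values);
const exps = values.map(x => Math.exp(x - max));
const sumExps = exps.reduce((a, b) => a + b, 0);
return exps.map(x => x / sumExps);
}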
}
// Usage
const pipeline = new MLPipeline('./mobilenet.onnx', './labels.json');
await pipeline.initialize();
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const predictions = await pipeline.classify(imageData);
console.log('Top prediction:', predictions[0].label, predictions[0].score);
GPU Memory Management
WebGPU requires explicit memory management for GPU buffers. Failing to destroy buffers after use leads to GPU memory leaks that can crash the application:
class GPUBufferManager {
constructor(device) {
this.device = device;
this.buffers = new Set();
}
createBuffer(descriptor) {
const buffer = this.device.createBuffer(descriptor);
this.buffers.add(buffer);
return buffer;
}
destroyBuffer(buffer) {
buffer.destroy();
this.buffers.delete(buffer);
}
destroyAll() {
for (const buffer of this.buffers) {
buffer.destroy();
}
this.buffers.clear();
}
}
For long-running applications that process multiple images or video frames, allocate a pool of reusable buffers instead of creating new ones for each inference. This reduces memory allocation overhead and prevents GPU memory fragmentation.
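A buffer pool along those lines can be sketched as a map from (size, usage) to a list of idle buffers. This is a minimal sketch: the class name `GPUBufferPool` is hypothetical, it relies on the `size` and `usage` attributes of `GPUBuffer`, and it assumes callers unmap a buffer before releasing it.

```javascript
// Reuse GPU buffers with identical size and usage instead of reallocating.
class GPUBufferPool {
  constructor(device) {
    this.device = device;
    this.free = new Map(); // "size:usage" -> array of idle buffers
  }
  acquire(size, usage) {
    const idle = this.free.get(`${size}:${usage}`);
    if (idle && idle.length > 0) return idle.pop();
    return this.device.createBuffer({ size, usage });
  }
  release(buffer) {
    const key = `${buffer.size}:${buffer.usage}`;
    if (!this.free.has(key)) this.free.set(key, []);
    this.free.get(key).push(buffer);
  }
  destroyAll() {
    for (const idle of this.free.values()) {
      for (const buffer of idle) buffer.destroy();
    }
    this.free.clear();
  }
}
```

Per frame, the application would `acquire` its input and output buffers, run the pass, and `release` them afterwards; `destroyAll` belongs in teardown. Exact-size matching keeps the sketch simple, while a production pool might round sizes up to buckets.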
WebNN API: The Emerging Standard
The Web Neural Network API (WebNN) is a W3C specification that provides a hardware-agnostic interface for neural network inference. Unlike WebGPU, which requires writing compute shaders, WebNN provides high-level operations like conv2d, matmul, and softmax that the browser delegates to the most efficient hardware accelerator available:
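async function runWithWebNN(convData) {
  // The WebNN API has changed between drafts. This version follows the
  // MLTensor-based design (createTensor, dispatch, readTensor) and may need
  // adjusting for the browser build you are testing.
  const context = await navigator.ml.createContext();
  const builder = new MLGraphBuilder(context);
  // Define model graph
  const input = builder.input('input', { dataType: 'float32', shape: [1, 3, 224, 224] });
  // Conv2d layer; convData is a Float32Array of 16 * 3 * 3 * 3 filter weights (example shape)
  const convWeight = builder.constant({ dataType: 'float32', shape: [16, 3, 3, 3] }, convData);
  const conv = builder.conv2d(input, convWeight, {
    padding: [1, 1, 1, 1],
    strides: [1, 1],
  });
  // ReLU activation
  const relu = builder.relu(conv);
  // Build the graph, bind input and output tensors, and dispatch
  const graph = await builder.build({ output: relu });
  const inputTensor = await context.createTensor({ dataType: 'float32', shape: [1, 3, 224, 224], writable: true });
  const outputTensor = await context.createTensor({ dataType: 'float32', shape: [1, 16, 224, 224], readable: true });
  context.writeTensor(inputTensor, new Float32Array(1 * 3 * 224 * 224));
  context.dispatch(graph, { input: inputTensor }, { output: outputTensor });
  return new Float32Array(await context.readTensor(outputTensor));
}
WebNN is available in Chrome 119+ behind the #web-machine-learning-neural-network flag. When it reaches general availability, it will provide the most efficient path for ML inference on the web by directly leveraging hardware accelerators like NPUs, GPUs, and SIMD-capable CPUs.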
Benchmarking WebGPU ML Performance
Benchmarking ML inference on the web requires measuring multiple metrics beyond raw inference time. Track first inference latency (including shader compilation), steady-state inference throughput, GPU memory usage, and power consumption:
class MLBenchmark {
async run(session, inputTensor, iterations = 100) {
// First inference (includes shader compilation)
const start = performance.now();
await session.run({ input: inputTensor });
const firstInference = performance.now() - start;
// Steady-state benchmarks
const times = [];
for (let i = 0; i < iterations; i++) {
const iterStart = performance.now();
await session.run({ input: inputTensor });
times.push(performance.now() - iterStart);
}
times.sort((a, b) => a - b);
return {
firstInference: firstInference.toFixed(2),
median: times[Math.floor(times.length / 2)].toFixed(2),
p95: times[Math.floor(times.length * 0.95)].toFixed(2),
p99: times[Math.floor(times.length * 0.99)].toFixed(2),
throughput: (1000 / (times.reduce((a, b) => a + b) / times.length)).toFixed(1),
};
}
}
Typical performance gains from WebGPU over WASM range from 3x to 30x depending on the model architecture. Models with large matrix operations see the greatest improvement, while models with complex control flow see modest gains.
Deploying ML Models to Production
Production deployment of WebGPU ML applications requires careful attention to model serving, caching, and progressive loading. Use a service worker to cache model files for offline support and faster subsequent loads. Implement progressive model loading that starts with a smaller, less accurate model and upgrades when the full model is downloaded:
async function loadModelProgressive() {
// Load small model first for immediate feedback
const smallSession = await ort.InferenceSession.create('./model-small.onnx', {
executionProviders: ['webgpu', 'wasm'],
});
// Start loading full model in background
const fullModelPromise = ort.InferenceSession.create('./model-full.onnx', {
executionProviders: ['webgpu', 'wasm'],
});
// Use small model until full model is ready
let activeSession = smallSession;
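fullModelPromise.then(session => {
activeSession = session;
console.log('Upgraded to full model');
}).catch(error => {
// Keep serving the small model if the full model fails to load
console.warn('Full model failed to load:', error);
});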
return {
classify: async (input) => activeSession.run({ input }),
};
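}
Use Content Delivery Networks (CDNs) to serve model files from edge locations close to users. Enable gzip or brotli compression on model files to reduce download sizes; the savings depend heavily on the weight data and are often smaller for floating-point weights than for text assets, so measure them for your model. Set appropriate cache headers so browsers cache model files locally after the first download.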
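Alongside HTTP cache headers, the model bytes can be stored explicitly with the Cache API, which also works from a service worker. A minimal sketch, assuming a cache name of `ml-models-v1` (a placeholder); the function name `fetchModelCached` is hypothetical, and the cache and fetch implementations are injectable so the logic can be exercised outside a browser.

```javascript
// Return the model bytes from the Cache API if present; otherwise download
// them once and store a copy for later visits.
async function fetchModelCached(url, options = {}) {
  const {
    cacheStorage = globalThis.caches,
    fetchImpl = globalThis.fetch,
    cacheName = 'ml-models-v1',
  } = options;
  const cache = await cacheStorage.open(cacheName);
  let response = await cache.match(url);
  if (!response) {
    response = await fetchImpl(url);
    if (!response.ok) {
      throw new Error(`Failed to download ${url}: ${response.status}`);
    }
    // Store a clone, because a response body can only be read once
    await cache.put(url, response.clone());
  }
  return response.arrayBuffer();
}
```

The resulting `ArrayBuffer` can be passed directly to `ort.InferenceSession.create`. Bumping the cache name is one simple way to invalidate old model versions.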
Conclusion
WebGPU enables running machine learning models at near-native performance in the browser. With ONNX Runtime Web and Transformers.js, you can run image classifiers, text generators, and other ML models entirely client-side—no server required.
Key takeaways:
- WebGPU compute shaders enable efficient ML inference on the GPU
- ONNX Runtime Web runs ONNX models with WebGPU acceleration
- Transformers.js provides a high-level API for transformer models
- Quantized models reduce download size and improve performance
- Always check WebGPU support and provide fallbacks
- GPU memory management is critical for long-running applications