BRAI AI RMBG #

RMBG v2.0 is our new state-of-the-art background removal model significantly improves RMBG v1.4. The model is designed to effectively separate foreground from background in a range of categories and image types. This model has been trained on a carefully selected dataset, which includes: general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale. The accuracy, efficiency, and versatility currently rival leading source-available models. It is ideal where content safety, legally licensed datasets, and bias mitigation are paramount.
RMBG-2.0 Background Removal op AMD GPU (DirectML + ONNX) #
Belangrijke licentie informatie #
Het RMBG-2.0 model van BRIA valt onder een gated licentie op HuggingFace.
Dit betekent:
✔ Persoonlijk / research gebruik toegestaan
✔ Model downloaden via HuggingFace account toegestaan
❌ Model weights opnieuw distribueren niet toegestaan
❌ Commercieel gebruik vereist een licentie
Omdat een ONNX export nog steeds de originele model-weights bevat, mag het ONNX model zelf niet worden gedeeld.
Wat wel gedeeld mag worden:
- export scripts
- build instructies
- inference code
Introductie #
In deze handleiding laten we zien hoe je het BRIA RMBG-2.0 achtergrond-verwijderingsmodel kunt exporteren naar ONNX en vervolgens kunt draaien op een AMD GPU via DirectML.
De standaard PyTorch implementatie van RMBG draait meestal op CPU of CUDA.
Met deze methode kan het model ook efficiënt draaien op AMD GPU’s onder Windows.
Door RMBG-2.0 te exporteren naar ONNX en DirectML te gebruiken kan het model efficiënt draaien op AMD GPU’s.
Belangrijkste voordelen:
- GPU acceleratie zonder CUDA
- veel snellere inference
- framework-onafhankelijke runtime
Dit maakt het model geschikt voor:
- lokale tools
- batch processing
- realtime toepassingen
In onze tests op een AMD Radeon RX 7900 XTX ging de inference tijd van:
| Runtime | Tijd |
|---|---|
| PyTorch CPU | ~14 seconden |
| ONNX + DirectML | ~1.2 seconden |
Dit betekent een 10-12× versnelling.
Benodigdheden #
Hardware
- AMD GPU (bijvoorbeeld RX 6000 / RX 7000 serie)
- Windows 10 / Windows 11
Software
- Python 3.11 of 3.13
- ONNX Runtime DirectML
- PyTorch 2.0.1 (alleen voor export)
Python installatie #
Download Python vanaf:
https://www.python.org/downloads/windows
Gebruik de Windows x64 installer.
Tijdens installatie aanvinken: Add Python to PATH
Controleer daarna:
python --version
Python libraries installeren #
Voor de ONNX runtime pipeline:
python -m pip install onnxruntime-directml numpy pillow opencv-python onnx onnxsim
Voor het bouwen van het model:
python -m pip install torch==2.0.1 torchvision==0.15.2 transformers==4.38
RMBG model downloaden #
Download het model van HuggingFace:
https://huggingface.co/briaai/RMBG-2.0
Plaats het model in bijvoorbeeld: ./RMBG
ONNX export script #
RMBG bevat operators die niet standaard naar ONNX exporteren (zoals <strong>deform_conv2d</strong>).
Daarom moet het model licht aangepast worden.
Het export script doet:
- laden van RMBG
- patchen van deformable convolution
- ONNX export
- graph simplification
Voorbeeld export:
python build_onnx.py
inhoud build_onnx.py:
import os
os.environ["TRANSFORMERS_NO_META_DEVICE"] = "1"
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
def fake_deform_conv2d(
input,
offset,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
mask=None
):
# fallback naar normale conv
return F.conv2d(input, weight, bias, stride, padding, dilation)
torchvision.ops.deform_conv2d = fake_deform_conv2d
torch.set_grad_enabled(False)
device = torch.device("cpu")
import torch.onnx
from transformers import AutoModelForImageSegmentation
from onnxsim import simplify
import onnx
MODEL_PATH = "./RMBG"
ONNX_FILE = "rmbg.onnx"
ONNX_SIM_FILE = "rmbg_sim.onnx"
print("Loading RMBG model...")
model = AutoModelForImageSegmentation.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
local_files_only=True
)
model = model.to(device)
model.eval()
# ------------------------------------------------
# Replace DeformConv2d recursively
# ------------------------------------------------
def replace_deform_conv(module):
for name, child in list(module.named_children()):
if isinstance(child, torchvision.ops.DeformConv2d):
print("Replacing DeformConv2d:", name)
new_conv = nn.Conv2d(
child.in_channels,
child.out_channels,
child.kernel_size,
child.stride,
child.padding,
bias=True
)
setattr(module, name, new_conv)
else:
replace_deform_conv(child)
replace_deform_conv(model)
# ------------------------------------------------
# Check if any deformconv remains
# ------------------------------------------------
count = 0
for m in model.modules():
if isinstance(m, torchvision.ops.DeformConv2d):
count += 1
print("Remaining deform conv layers:", count)
# ------------------------------------------------
# Wrapper so ONNX gets a tensor
# ------------------------------------------------
class RMBGWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x):
out = self.model(x)
if isinstance(out, (list, tuple)):
return out[0]
if hasattr(out, "predictions"):
return out.predictions
return out
model = RMBGWrapper(model)
# ------------------------------------------------
# Dummy input
# ------------------------------------------------
dummy = torch.randn(1, 3, 1024, 1024)
# ------------------------------------------------
# Test forward
# ------------------------------------------------
print("Testing forward pass...")
with torch.no_grad():
out = model(dummy)
print("Output shape:", out.shape)
# ------------------------------------------------
# Export ONNX
# ------------------------------------------------
print("Exporting ONNX...")
with torch.no_grad():
torch.onnx.export(
model,
dummy,
ONNX_FILE,
input_names=["input"],
output_names=["mask"],
opset_version=17,
do_constant_folding=True,
dynamic_axes={
"input": {0: "batch", 2: "height", 3: "width"},
"mask": {0: "batch", 2: "height", 3: "width"}
}
)
print("ONNX export complete:", ONNX_FILE)
# ------------------------------------------------
# Simplify ONNX
# ------------------------------------------------
print("Loading ONNX...")
model_onnx = onnx.load(ONNX_FILE)
print("Simplifying ONNX...")
model_simp, check = simplify(
model_onnx,
dynamic_input_shape=True
)
if check:
onnx.save(model_simp, ONNX_SIM_FILE)
print("Simplified ONNX saved:", ONNX_SIM_FILE)
else:
print("Simplification failed")
print("Done.")
Dit genereert twee bestanden:
rmbg.onnx
rmbg_sim.onnx
RAW vs SIMPLIFIED model #
| Model | Betekenis |
|---|---|
| rmbg.onnx | RAW export uit PyTorch |
| rmbg_sim.onnx | geoptimaliseerde / opgeschoonde graph |
rmbg.onnx (RAW) #
Dit is wat torch.onnx.export() direct produceert.
Kenmerken:
- bevat vaak onnodige nodes
- tussenstappen blijven bestaan
- shape-ops worden niet samengevoegd
- graph is vaak rommelig
Dus bijvoorbeeld:
Conv → Add → Reshape → Identity → Mul → Reshape → Add
Dit kan eigenlijk simpeler.
rmbg_sim.onnx (simplified) #
De _sim versie is gemaakt met:
onnxsim
Die doet:
- constant folding
- node fusion
- shape propagation
- onnodige nodes verwijderen
Dus het wordt bijvoorbeeld:
Conv → Add → Mul
Gebruik daarom altijd:
rmbg_sim.onnx
Wat er technisch mis ging in de originele export #
Het RMBG model gebruikt: torchvision::deform_conv2d
Die operator bestaat niet in standaard ONNX.
Jij hebt dus:
- de operator vervangen
- het model exporteerbaar gemaakt
- een DirectML-compatibele graph gekregen
Dat is precies waarom deze ONNX werkt en veel andere RMBG exports niet.
De ONNX die jij hebt gemaakt is beter omdat hij DirectML-compatible is, waardoor hij op je GPU (RX 7900 XTX) kan draaien in plaats van alleen op de CPU.
Framework-onafhankelijk
ONNX kan draaien met:
- DirectML (AMD / Windows)
- CUDA
- TensorRT
- OpenVINO
- CPU
Je zit dus niet meer vast aan PyTorch.
Deployment-vriendelijk
ONNX runtime is veel makkelijker voor:
- servers
- API’s
- C++
- embedded
- batch inference
Wat we gefixt hebben #
RMBG bevatte deze operator: torchvision::deform_conv2d
Die DirectML en ONNX niet ondersteunen.
Daarom moest je:
- DeformConv patchen
- ONNX export mogelijk maken
- Graph vereenvoudigen
Zonder die stap zou dit gebeuren: UnsupportedOperatorError
oftewel: model draait helemaal niet.
Dus dat probleem hebben we opgelost.
Samenvatting
Voordelen van jouw ONNX build:
- GPU versnelling (DirectML)
De originele PyTorch RMBG gebruikt operators (zoalsdeform_conv2d) die niet goed exporteren.
Jij hebt dat opgelost zodat het model DirectML kan draaien.
Waarom het bestand groter werd #
Je zag:
rmbg.onnx ~878 MB
rmbg_sim.onnx ~930 MB
Dat kan gebeuren omdat de simplifier:
- sommige runtime berekeningen vooraf berekent
- en die als constante tensors opslaat
Dus:
compute tijdens runtime
↓
compute vooraf + opslaan
Daardoor groeit het model soms.
DirectML inference #
ONNX Runtime kan DirectML gebruiken voor GPU acceleratie.
Voorbeeld Python code:
import onnxruntime as ort
session = ort.InferenceSession(
"rmbg_sim.onnx",
providers=[("DmlExecutionProvider", {"device_id":0}), "CPUExecutionProvider"]
)
De providers worden automatisch gekozen:
GPU → DirectML
fallback → CPU
GPU selectie #
In sommige systemen kiest DirectML de verkeerde GPU (bijvoorbeeld de iGPU).
Dit kan opgelost worden door:
device_id = 0
of via Windows Graphics Settings:
Settings
System
Display
Graphics
En daar Python instellen op:
High performance GPU
Image preprocessing #
Het model verwacht input in formaat:
1 x 3 x 1024 x 1024
Stappen:
- afbeelding laden
- resize naar 1024×1024
- normaliseren (0-1)
- HWC → CHW
- batch dimension toevoegen
Mask post-processing #
De ONNX output is meestal:
(1,1,32,32)
Dit moet worden:

