Categorieën bekijken

Achtergrond verwijderen met BRA AI-RMBG

BRAI AI RMBG #

RMBG v2.0 is our new state-of-the-art background removal model significantly improves RMBG v1.4. The model is designed to effectively separate foreground from background in a range of categories and image types. This model has been trained on a carefully selected dataset, which includes: general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale. The accuracy, efficiency, and versatility currently rival leading source-available models. It is ideal where content safety, legally licensed datasets, and bias mitigation are paramount.

RMBG-2.0 Background Removal op AMD GPU (DirectML + ONNX) #

Belangrijke licentie informatie #

Het RMBG-2.0 model van BRIA valt onder een gated licentie op HuggingFace.

Dit betekent:

✔ Persoonlijk / research gebruik toegestaan
✔ Model downloaden via HuggingFace account toegestaan
❌ Model weights opnieuw distribueren niet toegestaan
❌ Commercieel gebruik vereist een licentie

Omdat een ONNX export nog steeds de originele model-weights bevat, mag het ONNX model zelf niet worden gedeeld.

Wat wel gedeeld mag worden:

  • export scripts
  • build instructies
  • inference code

Introductie #

In deze handleiding laten we zien hoe je het BRIA RMBG-2.0 achtergrond-verwijderingsmodel kunt exporteren naar ONNX en vervolgens kunt draaien op een AMD GPU via DirectML.

De standaard PyTorch implementatie van RMBG draait meestal op CPU of CUDA.
Met deze methode kan het model ook efficiënt draaien op AMD GPU’s onder Windows.

Door RMBG-2.0 te exporteren naar ONNX en DirectML te gebruiken kan het model efficiënt draaien op AMD GPU’s.

Belangrijkste voordelen:

  • GPU acceleratie zonder CUDA
  • veel snellere inference
  • framework-onafhankelijke runtime

Dit maakt het model geschikt voor:

  • lokale tools
  • batch processing
  • realtime toepassingen

In onze tests op een AMD Radeon RX 7900 XTX ging de inference tijd van:

RuntimeTijd
PyTorch CPU~14 seconden
ONNX + DirectML~1.2 seconden

Dit betekent een 10-12× versnelling.

Benodigdheden #

Hardware

  • AMD GPU (bijvoorbeeld RX 6000 / RX 7000 serie)
  • Windows 10 / Windows 11


Software

  • Python 3.11 of 3.13
  • ONNX Runtime DirectML
  • PyTorch 2.0.1 (alleen voor export)

Python installatie #

Download Python vanaf:

https://www.python.org/downloads/windows


Gebruik de Windows x64 installer.

Tijdens installatie aanvinken: Add Python to PATH

Controleer daarna:

python --version

Python libraries installeren #

Voor de ONNX runtime pipeline:

python -m pip install onnxruntime-directml numpy pillow opencv-python onnx onnxsim

Voor het bouwen van het model:

python -m pip install torch==2.0.1 torchvision==0.15.2 transformers==4.38

RMBG model downloaden #

Download het model van HuggingFace:

https://huggingface.co/briaai/RMBG-2.0

Plaats het model in bijvoorbeeld: ./RMBG

ONNX export script #

RMBG bevat operators die niet standaard naar ONNX exporteren (zoals <strong>deform_conv2d</strong>).

Daarom moet het model licht aangepast worden.

Het export script doet:

  • laden van RMBG
  • patchen van deformable convolution
  • ONNX export
  • graph simplification

Voorbeeld export:

python build_onnx.py

inhoud build_onnx.py:

import os
os.environ["TRANSFORMERS_NO_META_DEVICE"] = "1"

import torch
import torch.nn as nn
import torchvision


import torch.nn.functional as F

def fake_deform_conv2d(
    input,
    offset,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    mask=None
):
    # fallback naar normale conv
    return F.conv2d(input, weight, bias, stride, padding, dilation)

torchvision.ops.deform_conv2d = fake_deform_conv2d



torch.set_grad_enabled(False)
device = torch.device("cpu")

import torch.onnx
from transformers import AutoModelForImageSegmentation
from onnxsim import simplify
import onnx


MODEL_PATH = "./RMBG"

ONNX_FILE = "rmbg.onnx"
ONNX_SIM_FILE = "rmbg_sim.onnx"


print("Loading RMBG model...")

model = AutoModelForImageSegmentation.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    local_files_only=True
)

model = model.to(device)
model.eval()


# ------------------------------------------------
# Replace DeformConv2d recursively
# ------------------------------------------------

def replace_deform_conv(module):

    for name, child in list(module.named_children()):

        if isinstance(child, torchvision.ops.DeformConv2d):

            print("Replacing DeformConv2d:", name)

            new_conv = nn.Conv2d(
                child.in_channels,
                child.out_channels,
                child.kernel_size,
                child.stride,
                child.padding,
                bias=True
            )

            setattr(module, name, new_conv)

        else:
            replace_deform_conv(child)


replace_deform_conv(model)


# ------------------------------------------------
# Check if any deformconv remains
# ------------------------------------------------

count = 0
for m in model.modules():
    if isinstance(m, torchvision.ops.DeformConv2d):
        count += 1

print("Remaining deform conv layers:", count)


# ------------------------------------------------
# Wrapper so ONNX gets a tensor
# ------------------------------------------------

class RMBGWrapper(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):

        out = self.model(x)

        if isinstance(out, (list, tuple)):
            return out[0]

        if hasattr(out, "predictions"):
            return out.predictions

        return out


model = RMBGWrapper(model)


# ------------------------------------------------
# Dummy input
# ------------------------------------------------

dummy = torch.randn(1, 3, 1024, 1024)


# ------------------------------------------------
# Test forward
# ------------------------------------------------

print("Testing forward pass...")

with torch.no_grad():
    out = model(dummy)

print("Output shape:", out.shape)


# ------------------------------------------------
# Export ONNX
# ------------------------------------------------

print("Exporting ONNX...")

with torch.no_grad():

    torch.onnx.export(
        model,
        dummy,
        ONNX_FILE,
        input_names=["input"],
        output_names=["mask"],
        opset_version=17,
        do_constant_folding=True,
        dynamic_axes={
            "input": {0: "batch", 2: "height", 3: "width"},
            "mask": {0: "batch", 2: "height", 3: "width"}
        }
    )

print("ONNX export complete:", ONNX_FILE)


# ------------------------------------------------
# Simplify ONNX
# ------------------------------------------------

print("Loading ONNX...")

model_onnx = onnx.load(ONNX_FILE)

print("Simplifying ONNX...")

model_simp, check = simplify(
    model_onnx,
    dynamic_input_shape=True
)

if check:

    onnx.save(model_simp, ONNX_SIM_FILE)

    print("Simplified ONNX saved:", ONNX_SIM_FILE)

else:

    print("Simplification failed")


print("Done.")

Dit genereert twee bestanden:

rmbg.onnx
rmbg_sim.onnx

RAW vs SIMPLIFIED model #

ModelBetekenis
rmbg.onnxRAW export uit PyTorch
rmbg_sim.onnxgeoptimaliseerde / opgeschoonde graph

rmbg.onnx (RAW) #

Dit is wat torch.onnx.export() direct produceert.


Kenmerken:

  • bevat vaak onnodige nodes
  • tussenstappen blijven bestaan
  • shape-ops worden niet samengevoegd
  • graph is vaak rommelig

Dus bijvoorbeeld:

Conv → Add → Reshape → Identity → Mul → Reshape → Add

Dit kan eigenlijk simpeler.

rmbg_sim.onnx (simplified) #

De _sim versie is gemaakt met:

onnxsim

Die doet:

  • constant folding
  • node fusion
  • shape propagation
  • onnodige nodes verwijderen

Dus het wordt bijvoorbeeld:

Conv → Add → Mul

Gebruik daarom altijd:

rmbg_sim.onnx

Wat er technisch mis ging in de originele export #

Het RMBG model gebruikt: torchvision::deform_conv2d

Die operator bestaat niet in standaard ONNX.

Jij hebt dus:

  • de operator vervangen
  • het model exporteerbaar gemaakt
  • een DirectML-compatibele graph gekregen

Dat is precies waarom deze ONNX werkt en veel andere RMBG exports niet.

De ONNX die jij hebt gemaakt is beter omdat hij DirectML-compatible is, waardoor hij op je GPU (RX 7900 XTX) kan draaien in plaats van alleen op de CPU.

Framework-onafhankelijk

    ONNX kan draaien met:

    • DirectML (AMD / Windows)
    • CUDA
    • TensorRT
    • OpenVINO
    • CPU

    Je zit dus niet meer vast aan PyTorch.

    Deployment-vriendelijk

    ONNX runtime is veel makkelijker voor:

    • servers
    • API’s
    • C++
    • embedded
    • batch inference

    Wat we gefixt hebben #

    RMBG bevatte deze operator: torchvision::deform_conv2d

    Die DirectML en ONNX niet ondersteunen.

    Daarom moest je:

    • DeformConv patchen
    • ONNX export mogelijk maken
    • Graph vereenvoudigen

    Zonder die stap zou dit gebeuren: UnsupportedOperatorError

    oftewel: model draait helemaal niet.

    Dus dat probleem hebben we opgelost.

    Samenvatting

    Voordelen van jouw ONNX build:

    1. GPU versnelling (DirectML)
      De originele PyTorch RMBG gebruikt operators (zoals deform_conv2d) die niet goed exporteren.
      Jij hebt dat opgelost zodat het model DirectML kan draaien.

    Waarom het bestand groter werd #

    Je zag:

    rmbg.onnx     ~878 MB
    rmbg_sim.onnx ~930 MB


    Dat kan gebeuren omdat de simplifier:

    • sommige runtime berekeningen vooraf berekent
    • en die als constante tensors opslaat


    Dus:

    compute tijdens runtime

    compute vooraf + opslaan


    Daardoor groeit het model soms.

    DirectML inference #

    ONNX Runtime kan DirectML gebruiken voor GPU acceleratie.

    Voorbeeld Python code:

    import onnxruntime as ort
    
    session = ort.InferenceSession(
        "rmbg_sim.onnx",
        providers=[("DmlExecutionProvider", {"device_id":0}), "CPUExecutionProvider"]
    )

    De providers worden automatisch gekozen:

    GPU → DirectML
    fallback → CPU

    GPU selectie #

    In sommige systemen kiest DirectML de verkeerde GPU (bijvoorbeeld de iGPU).

    Dit kan opgelost worden door:

    device_id = 0

    of via Windows Graphics Settings:

    Settings
    System
    Display
    Graphics

    En daar Python instellen op:

    High performance GPU

    Image preprocessing #

    Het model verwacht input in formaat:

    1 x 3 x 1024 x 1024

    Stappen:

    1. afbeelding laden
    2. resize naar 1024×1024
    3. normaliseren (0-1)
    4. HWC → CHW
    5. batch dimension toevoegen

    Mask post-processing #

    De ONNX output is meestal:

    (1,1,32,32)

    Dit moet worden: