Skip to content

RF-DETR Seg Large

Bases: RFDETR

Source code in src/rfdetr/detr.py
class RFDETRSegLarge(RFDETR):
    """Large instance-segmentation variant of RF-DETR.

    Pairs the large segmentation architecture configuration with the
    segmentation-specific training configuration; all behavior is
    inherited from :class:`RFDETR`.
    """

    # Model-size identifier (used, e.g., as the deploy `size` string).
    size = "rfdetr-seg-large"

    def get_model_config(self, **kwargs):
        """Build the large segmentation architecture config from ``**kwargs``."""
        return RFDETRSegLargeConfig(**kwargs)

    def get_train_config(self, **kwargs):
        """Build the segmentation training config from ``**kwargs``."""
        return SegmentationTrainConfig(**kwargs)

Attributes

class_names property

Retrieve the class names supported by the loaded model.

Returns:

Name Type Description
dict

A dictionary mapping class IDs to class names. The keys are integer class IDs and the values are the corresponding class-name strings.

Functions

deploy_to_roboflow(workspace, project_id, version, api_key=None, size=None)

Deploy the trained RF-DETR model to Roboflow.

Deploying with Roboflow will create a Serverless API to which you can make requests.

You can also download weights into a Roboflow Inference deployment for use in Roboflow Workflows and on-device deployment.

Parameters:

Name Type Description Default

workspace

str

The name of the Roboflow workspace to deploy to.

required

project_id

str

The project ID to which the model will be deployed.

required

version

str

The project version to which the model will be deployed.

required

api_key

Optional[str]

Your Roboflow API key. If not provided, it will be read from the environment variable ROBOFLOW_API_KEY.

None

size

Optional[str]

The size of the model to deploy. If not provided, it will default to the size of the model being trained (e.g., "rfdetr-base", "rfdetr-large", etc.).

None

Raises:

Type Description
ValueError

If the api_key is not provided and not found in the environment variable ROBOFLOW_API_KEY, or if the size is not set for custom architectures.

Source code in src/rfdetr/detr.py
def deploy_to_roboflow(
    self,
    workspace: str,
    project_id: str,
    version: str,
    api_key: Optional[str] = None,
    size: Optional[str] = None,
) -> None:
    """
    Deploy the trained RF-DETR model to Roboflow.

    Deploying with Roboflow will create a Serverless API to which you can make requests.

    You can also download weights into a Roboflow Inference deployment for use in
    Roboflow Workflows and on-device deployment.

    Args:
        workspace: The name of the Roboflow workspace to deploy to.
        project_id: The project ID to which the model will be deployed.
        version: The project version to which the model will be deployed.
        api_key: Your Roboflow API key. If not provided,
            it will be read from the environment variable `ROBOFLOW_API_KEY`.
        size: The size of the model to deploy. If not provided,
            it will default to the size of the model being trained (e.g., "rfdetr-base", "rfdetr-large", etc.).

    Raises:
        ValueError: If the `api_key` is not provided and not found in the
            environment variable `ROBOFLOW_API_KEY`, or if the `size` is
            not set for custom architectures.
    """
    import shutil

    from roboflow import Roboflow

    if api_key is None:
        api_key = os.getenv("ROBOFLOW_API_KEY")
        if api_key is None:
            raise ValueError("Set api_key=<KEY> in deploy_to_roboflow or export ROBOFLOW_API_KEY=<KEY>")

    rf = Roboflow(api_key=api_key)
    workspace = rf.workspace(workspace)

    if self.size is None and size is None:
        raise ValueError("Must set size for custom architectures")

    # Bug fix: an explicitly passed `size` must take precedence over the
    # model's own size. The previous `self.size or size` silently ignored
    # the caller's argument whenever `self.size` was set, contradicting the
    # documented "defaults to the size of the model being trained" contract.
    size = size or self.size

    tmp_out_dir = ".roboflow_temp_upload"
    os.makedirs(tmp_out_dir, exist_ok=True)
    try:
        outpath = os.path.join(tmp_out_dir, "weights.pt")
        # Serialize both weights and build args so the server can reconstruct the model.
        torch.save({"model": self.model.model.state_dict(), "args": self.model.args}, outpath)
        project = workspace.project(project_id)
        version = project.version(version)
        version.deploy(model_type=size, model_path=tmp_out_dir, filename="weights.pt")
    finally:
        # Always clean up the temp upload dir, even if the upload fails.
        shutil.rmtree(tmp_out_dir, ignore_errors=True)

export(output_dir='output', infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs)

Export the trained model to ONNX format.

See the [ONNX export documentation](https://rfdetr.roboflow.com/learn/export/) for more information.

Parameters:

Name Type Description Default

output_dir

str

Directory to write the ONNX file to.

'output'

infer_dir

str

Optional directory of sample images for dynamic-axes inference.

None

simplify

bool

Whether to run onnx-simplifier on the exported graph.

False

backbone_only

bool

Export only the backbone (feature extractor).

False

opset_version

int

ONNX opset version to target.

17

verbose

bool

Print export progress information.

True

force

bool

Force re-export even if output already exists.

False

shape

tuple

(height, width) tuple; defaults to square at model resolution.

None

batch_size

int

Static batch size to bake into the ONNX graph.

1

**kwargs

Additional keyword arguments forwarded to export_onnx.

{}
Source code in src/rfdetr/detr.py
def export(
    self,
    output_dir: str = "output",
    infer_dir: str = None,
    simplify: bool = False,
    backbone_only: bool = False,
    opset_version: int = 17,
    verbose: bool = True,
    force: bool = False,
    shape: tuple = None,
    batch_size: int = 1,
    **kwargs,
) -> None:
    """Export the trained model to ONNX format.

    See the `ONNX export documentation <https://rfdetr.roboflow.com/learn/export/>`_
    for more information.

    Args:
        output_dir: Directory to write the ONNX file to.
        infer_dir: Optional directory of sample images for dynamic-axes inference.
        simplify: Whether to run onnx-simplifier on the exported graph.
        backbone_only: Export only the backbone (feature extractor).
        opset_version: ONNX opset version to target.
        verbose: Print export progress information.
        force: Force re-export even if output already exists.
        shape: ``(height, width)`` tuple; defaults to square at model resolution.
        batch_size: Static batch size to bake into the ONNX graph.
        **kwargs: Additional keyword arguments forwarded to export_onnx.

    Raises:
        ImportError: If the optional ONNX export dependencies are not installed.
        ValueError: If ``shape`` is given and either dimension is not divisible by 14.
    """
    logger.info("Exporting model to ONNX format")
    # The ONNX toolchain is an optional extra; import lazily and give the
    # user an actionable install hint before re-raising.
    try:
        from rfdetr.export.main import export_onnx, make_infer_image, onnx_simplify
    except ImportError:
        logger.error(
            "It seems some dependencies for ONNX export are missing."
            " Please run `pip install rfdetr[onnx]` and try again."
        )
        raise

    device = self.model.device
    # NOTE: `.to("cpu")` moves the live model in place *before* the deepcopy;
    # the live model is moved back onto `device` at the end of this method.
    model = deepcopy(self.model.model.to("cpu"))
    model.to(device)

    os.makedirs(output_dir, exist_ok=True)
    output_dir_path = Path(output_dir)
    if shape is None:
        # Default to a square input at the model's native resolution.
        shape = (self.model.resolution, self.model.resolution)
    else:
        # Dimensions must be multiples of 14 — presumably the backbone's
        # patch size; TODO confirm against the backbone config.
        if shape[0] % 14 != 0 or shape[1] % 14 != 0:
            raise ValueError("Shape must be divisible by 14")

    input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
    input_names = ["input"]
    if backbone_only:
        output_names = ["features"]
    elif self.model_config.segmentation_head:
        output_names = ["dets", "labels", "masks"]
    else:
        output_names = ["dets", "labels"]

    # Static axes only: batch size and resolution are baked into the graph.
    dynamic_axes = None
    model.eval()
    # Sanity-check a PyTorch forward pass (and log output shapes) before tracing.
    with torch.no_grad():
        if backbone_only:
            features = model(input_tensors)
            logger.debug(f"PyTorch inference output shape: {features.shape}")
        elif self.model_config.segmentation_head:
            outputs = model(input_tensors)
            dets = outputs["pred_boxes"]
            labels = outputs["pred_logits"]
            masks = outputs["pred_masks"]
            if isinstance(masks, torch.Tensor):
                logger.debug(
                    f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}, "
                    f"Masks: {masks.shape}"
                )
            else:
                logger.debug(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
        else:
            outputs = model(input_tensors)
            dets = outputs["pred_boxes"]
            labels = outputs["pred_logits"]
            logger.debug(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")

    # The ONNX export itself runs on CPU.
    model.cpu()
    input_tensors = input_tensors.cpu()

    output_file = export_onnx(
        output_dir=str(output_dir_path),
        model=model,
        input_names=input_names,
        input_tensors=input_tensors,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        backbone_only=backbone_only,
        verbose=verbose,
        opset_version=opset_version,
    )

    logger.info(f"Successfully exported ONNX model to: {output_file}")

    if simplify:
        sim_output_file = onnx_simplify(
            onnx_dir=output_file, input_names=input_names, input_tensors=input_tensors, force=force
        )
        logger.info(f"Successfully simplified ONNX model to: {sim_output_file}")

    logger.info("ONNX export completed successfully")
    # Restore the live model to its original device (it was moved to CPU above).
    self.model.model = self.model.model.to(device)

get_model(config)

Retrieve a model context from the provided architecture configuration.

Parameters:

Name Type Description Default

config

ModelConfig

Architecture configuration.

required

Returns:

Type Description
'_ModelContext'

_ModelContext with model, postprocess, device, resolution, args,

'_ModelContext'

and class_names attributes.

Source code in src/rfdetr/detr.py
def get_model(self, config: ModelConfig) -> "_ModelContext":
    """Construct the runtime model context for the given architecture config.

    Args:
        config: Architecture configuration describing the model to build.

    Returns:
        A ``_ModelContext`` exposing ``model``, ``postprocess``, ``device``,
        ``resolution``, ``args``, and ``class_names``.
    """
    context = _build_model_context(config)
    return context

maybe_download_pretrain_weights()

Download pre-trained weights if they are not already downloaded.

Source code in src/rfdetr/detr.py
def maybe_download_pretrain_weights(self):
    """Fetch the configured pre-trained weights unless none are set.

    A no-op when the model config carries no ``pretrain_weights``; the
    download helper itself skips files that are already present locally.
    """
    weights = self.model_config.pretrain_weights
    if weights is not None:
        download_pretrain_weights(weights)

predict(images, threshold=0.5, **kwargs)

Performs object detection on the input images and returns bounding box predictions.

This method accepts a single image or a list of images in various formats (file path, image URL, PIL Image, NumPy array, or torch.Tensor). The images should be in RGB channel order. If a torch.Tensor is provided, it must already be normalized to values in the [0, 1] range and have the shape (C, H, W).

Parameters:

Name Type Description Default

images

Union[str, Image, ndarray, Tensor, List[Union[str, ndarray, Image, Tensor]]]

A single image or a list of images to process. Images can be provided as file paths, PIL Images, NumPy arrays, or torch.Tensors.

required

threshold

float

The minimum confidence score needed to consider a detected bounding box valid.

0.5

**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Union[Detections, List[Detections]]

A single or multiple Detections objects, each containing bounding box

Union[Detections, List[Detections]]

coordinates, confidence scores, and class IDs.

Source code in src/rfdetr/detr.py
def predict(
    self,
    images: Union[
        str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]
    ],
    threshold: float = 0.5,
    **kwargs,
) -> Union[sv.Detections, List[sv.Detections]]:
    """Performs object detection on the input images and returns bounding box
    predictions.

    This method accepts a single image or a list of images in various formats
    (file path, image url, PIL Image, NumPy array, or torch.Tensor). The images should be in
    RGB channel order. If a torch.Tensor is provided, it must already be normalized
    to values in the [0, 1] range and have the shape (C, H, W).

    Args:
        images:
            A single image or a list of images to process. Images can be provided
            as file paths, PIL Images, NumPy arrays, or torch.Tensors.
        threshold:
            The minimum confidence score needed to consider a detected bounding box valid.
        **kwargs:
            Additional keyword arguments.

    Returns:
        A single or multiple Detections objects, each containing bounding box
        coordinates, confidence scores, and class IDs.

    Raises:
        ValueError: If an image tensor has pixel values above 1, does not
            have 3 channels, or if the batch does not match the resolution /
            batch size the optimized model was compiled for.
    """
    import supervision as sv

    # One-time warning when running the un-optimized model path.
    if not self._is_optimized_for_inference and not self._has_warned_about_not_being_optimized_for_inference:
        logger.warning(
            "Model is not optimized for inference. Latency may be higher than expected."
            " You can optimize the model for inference by calling model.optimize_for_inference()."
        )
        self._has_warned_about_not_being_optimized_for_inference = True

        # NOTE(review): eval() is only invoked on this first un-optimized
        # call (inside the warning branch) — confirm this is intentional.
        self.model.model.eval()

    # Normalize to a list so single images and batches share one code path.
    if not isinstance(images, list):
        images = [images]

    orig_sizes = []
    processed_images = []

    for img in images:
        if isinstance(img, str):
            # String inputs are either URLs (fetched) or local file paths.
            if img.startswith("http"):
                img = requests.get(img, stream=True).raw
            img = Image.open(img)

        if not isinstance(img, torch.Tensor):
            # to_tensor converts PIL/ndarray to a [0, 1] float CHW tensor.
            img = F.to_tensor(img)

        # Tensors passed in directly must already be scaled to [0, 1].
        if (img > 1).any():
            raise ValueError(
                "Image has pixel values above 1. Please ensure the image is normalized (scaled to [0, 1])."
            )
        if img.shape[0] != 3:
            raise ValueError(f"Invalid image shape. Expected 3 channels (RGB), but got {img.shape[0]} channels.")
        img_tensor = img

        # Remember the original size so boxes can be mapped back by postprocess.
        h, w = img_tensor.shape[1:]
        orig_sizes.append((h, w))

        img_tensor = img_tensor.to(self.model.device)
        # Normalize with the model's channel means/stds, then resize to the
        # model's square input resolution.
        img_tensor = F.normalize(img_tensor, self.means, self.stds)
        img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution))

        processed_images.append(img_tensor)

    batch_tensor = torch.stack(processed_images)

    # Guard against inputs that no longer match what the optimized model
    # was built/compiled for.
    if self._is_optimized_for_inference:
        if self._optimized_resolution != batch_tensor.shape[2]:
            # this could happen if someone manually changes self.model.resolution after optimizing the model
            raise ValueError(
                f"Resolution mismatch. "
                f"Model was optimized for resolution {self._optimized_resolution}, "
                f"but got {batch_tensor.shape[2]}."
                " You can explicitly remove the optimized model by calling model.remove_optimized_model()."
            )
        if self._optimized_has_been_compiled:
            if self._optimized_batch_size != batch_tensor.shape[0]:
                raise ValueError(
                    f"Batch size mismatch. "
                    f"Optimized model was compiled for batch size {self._optimized_batch_size}, "
                    f"but got {batch_tensor.shape[0]}."
                    " You can explicitly remove the optimized model by calling model.remove_optimized_model()."
                    " Alternatively, you can recompile the optimized model for a different batch size"
                    " by calling model.optimize_for_inference(batch_size=<new_batch_size>)."
                )

    with torch.no_grad():
        if self._is_optimized_for_inference:
            predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
        else:
            predictions = self.model.model(batch_tensor)
        # Tuple outputs (boxes, logits[, masks]) are mapped back onto the
        # dict layout that postprocess expects.
        if isinstance(predictions, tuple):
            return_predictions = {
                "pred_logits": predictions[1],
                "pred_boxes": predictions[0],
            }
            if len(predictions) == 3:
                return_predictions["pred_masks"] = predictions[2]
            predictions = return_predictions
        target_sizes = torch.tensor(orig_sizes, device=self.model.device)
        results = self.model.postprocess(predictions, target_sizes=target_sizes)

    detections_list = []
    for result in results:
        scores = result["scores"]
        labels = result["labels"]
        boxes = result["boxes"]

        # Keep only detections above the confidence threshold.
        keep = scores > threshold
        scores = scores[keep]
        labels = labels[keep]
        boxes = boxes[keep]

        if "masks" in result:
            masks = result["masks"]
            masks = masks[keep]

            detections = sv.Detections(
                xyxy=boxes.float().cpu().numpy(),
                confidence=scores.float().cpu().numpy(),
                class_id=labels.cpu().numpy(),
                # squeeze(1) drops the singleton channel dim from each mask.
                mask=masks.squeeze(1).cpu().numpy(),
            )
        else:
            detections = sv.Detections(
                xyxy=boxes.float().cpu().numpy(),
                confidence=scores.float().cpu().numpy(),
                class_id=labels.cpu().numpy(),
            )

        detections_list.append(detections)

    # NOTE(review): a one-element *list* input also yields a bare Detections
    # (not a list), since the check is on the output length — confirm intended.
    return detections_list if len(detections_list) > 1 else detections_list[0]

train(**kwargs)

Train an RF-DETR model via the PyTorch Lightning stack.

All keyword arguments are forwarded to get_train_config to build a TrainConfig. Several legacy kwargs are absorbed so existing call-sites do not break:

  • device — mapped to TrainConfig.accelerator; "cpu" becomes accelerator="cpu", all others default to "auto".
  • callbacks — if the dict contains any non-empty lists a DeprecationWarning is emitted; the dict is then discarded. Use PyTorch Lightning Callback objects passed via build_trainer instead.
  • start_epoch — emits a DeprecationWarning and is dropped.
  • do_benchmark — emits a DeprecationWarning and is dropped.

After training completes, the underlying nn.Module is synced back onto self.model.model so that predict and export continue to work without reloading the checkpoint.

Source code in src/rfdetr/detr.py
def train(self, **kwargs):
    """Train an RF-DETR model via the PyTorch Lightning stack.

    Every keyword argument is forwarded to :meth:`get_train_config`, which
    builds a :class:`~rfdetr.config.TrainConfig`.  A handful of legacy
    kwargs are recognized and absorbed so older call-sites keep working:

    * ``device`` — translated into the PTL ``accelerator``: ``"cpu"`` is
      honoured as ``accelerator="cpu"``; any other value is dropped so PTL
      auto-selects the best available device.
    * ``callbacks`` — a dict with any non-empty lists triggers a
      :class:`DeprecationWarning`; the dict is discarded either way.  Pass
      PTL :class:`~pytorch_lightning.Callback` objects through
      :func:`~rfdetr.training.build_trainer` instead.
    * ``start_epoch`` — warned about and dropped; PTL resumes from
      ``resume`` automatically.
    * ``do_benchmark`` — warned about and dropped; use ``rfdetr benchmark``.

    Once training finishes, the trained ``nn.Module`` is written back to
    ``self.model.model`` so :meth:`predict` and :meth:`export` keep working
    without reloading the checkpoint.
    """
    from rfdetr.training import RFDETRDataModule, RFDETRModule, build_trainer

    # Legacy `callbacks` dict: warn when it actually carries callbacks, then drop it.
    legacy_callbacks = kwargs.pop("callbacks", None)
    if legacy_callbacks and any(legacy_callbacks.values()):
        warnings.warn(
            "Custom callbacks dict is not forwarded to PTL. Use PTL Callback objects instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Legacy `device` kwarg: only an explicit CPU request is forwarded as the
    # PTL accelerator; cuda/mps strings are dropped so PTL can auto-select.
    requested_device = kwargs.pop("device", None)
    accelerator = "cpu" if requested_device == "cpu" else None

    # Legacy `start_epoch`: PTL resumes automatically via ckpt_path.
    if "start_epoch" in kwargs:
        del kwargs["start_epoch"]
        warnings.warn(
            "`start_epoch` is deprecated and ignored; PTL resumes automatically via `resume`.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Legacy `do_benchmark`: benchmarking through `.train()` is deprecated.
    if bool(kwargs.pop("do_benchmark", False)):
        warnings.warn(
            "`do_benchmark` in `.train()` is deprecated; use `rfdetr benchmark`.",
            DeprecationWarning,
            stacklevel=2,
        )

    train_cfg = self.get_train_config(**kwargs)
    lightning_module = RFDETRModule(self.model_config, train_cfg)
    data = RFDETRDataModule(self.model_config, train_cfg)
    trainer = build_trainer(train_cfg, self.model_config, accelerator=accelerator)
    trainer.fit(lightning_module, data, ckpt_path=train_cfg.resume or None)

    # Sync trained weights back so predict() / export() see the updated model.
    self.model.model = lightning_module.model