Skip to content

RF-DETR

The base RF-DETR class implements the core methods for training RF-DETR models, running inference on the models, optimising models, and uploading trained models for deployment.

Source code in src/rfdetr/detr.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
class RFDETR:
    """
    The base RF-DETR class implements the core methods for training RF-DETR models,
    running inference on the models, optimising models, and uploading trained
    models for deployment.
    """

    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    size = None

    def __init__(self, **kwargs):
        self.model_config = self.get_model_config(**kwargs)
        self.maybe_download_pretrain_weights()
        self.model = self.get_model(self.model_config)
        self.callbacks = defaultdict(list)

        self.model.inference_model = None
        self._is_optimized_for_inference = False
        self._has_warned_about_not_being_optimized_for_inference = False
        self._optimized_has_been_compiled = False
        self._optimized_batch_size = None
        self._optimized_resolution = None
        self._optimized_dtype = None

    def maybe_download_pretrain_weights(self):
        """
        Download pre-trained weights if they are not already downloaded.
        """
        pretrain_weights = self.model_config.pretrain_weights
        if pretrain_weights is None:
            return
        download_pretrain_weights(pretrain_weights)

    def get_model_config(self, **kwargs):
        """
        Build the model (architecture) configuration from keyword arguments.

        Args:
            **kwargs: Field values passed straight through to ``ModelConfig``.

        Returns:
            The ``ModelConfig`` instance constructed from ``kwargs``.
        """
        model_config = ModelConfig(**kwargs)
        return model_config

    def train(self, **kwargs):
        """Train an RF-DETR model via the PyTorch Lightning stack.

        All keyword arguments are forwarded to :meth:`get_train_config` to build
        a :class:`~rfdetr.config.TrainConfig`.  Several legacy kwargs are absorbed
        so existing call-sites do not break:

        * ``device`` — a CPU device (the string ``"cpu"`` / ``"cpu:0"`` or an
          equivalent ``torch.device``) is passed to
          :func:`~rfdetr.training.build_trainer` as ``accelerator="cpu"``; every
          other value results in ``accelerator=None`` (presumably resolved to
          automatic device selection there — confirm against ``build_trainer``).
        * ``callbacks`` — if the dict contains any non-empty lists a
          :class:`DeprecationWarning` is emitted; the dict is then discarded.
          Use PTL :class:`~pytorch_lightning.Callback` objects passed via
          :func:`~rfdetr.training.build_trainer` instead.
        * ``start_epoch`` — emits :class:`DeprecationWarning` and is dropped.
        * ``do_benchmark`` — emits :class:`DeprecationWarning` when truthy and is
          dropped.

        After training completes the underlying ``nn.Module`` is synced back
        onto ``self.model.model`` so that :meth:`predict` and :meth:`export`
        continue to work without reloading the checkpoint.  Any copy created
        earlier by :meth:`optimize_for_inference` is discarded, because it was
        built from the pre-training weights.
        """
        from rfdetr.training import RFDETRDataModule, RFDETRModule, build_trainer

        # Absorb legacy `callbacks` dict — warn if non-empty, then discard.
        callbacks_dict = kwargs.pop("callbacks", None)
        if callbacks_dict and any(callbacks_dict.values()):
            warnings.warn(
                "Custom callbacks dict is not forwarded to PTL. Use PTL Callback objects instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Absorb legacy `device` kwarg.  When the caller explicitly requests CPU
        # (e.g. in tests or CPU-only environments), honour it by forwarding it as
        # the PTL accelerator.  All other devices (cuda, mps) are ignored so PTL
        # can pick the device itself.  Going through str() means a
        # torch.device("cpu") or an indexed "cpu:0" is recognised too, not only
        # the literal string "cpu".
        _device = kwargs.pop("device", None)
        _wants_cpu = _device is not None and str(_device).split(":", 1)[0] == "cpu"
        _accelerator = "cpu" if _wants_cpu else None

        # Absorb legacy `start_epoch` — PTL resumes automatically via ckpt_path.
        if "start_epoch" in kwargs:
            warnings.warn(
                "`start_epoch` is deprecated and ignored; PTL resumes automatically via `resume`.",
                DeprecationWarning,
                stacklevel=2,
            )
            kwargs.pop("start_epoch")

        # Pop `do_benchmark`; benchmarking via `.train()` is deprecated.
        run_benchmark = bool(kwargs.pop("do_benchmark", False))
        if run_benchmark:
            warnings.warn(
                "`do_benchmark` in `.train()` is deprecated; use `rfdetr benchmark`.",
                DeprecationWarning,
                stacklevel=2,
            )

        config = self.get_train_config(**kwargs)
        module = RFDETRModule(self.model_config, config)
        datamodule = RFDETRDataModule(self.model_config, config)
        trainer = build_trainer(config, self.model_config, accelerator=_accelerator)
        trainer.fit(module, datamodule, ckpt_path=config.resume or None)

        # Sync the trained weights back so predict() / export() see the updated model.
        self.model.model = module.model
        # An optimized copy made before training still holds the old weights;
        # drop it so predict() cannot silently keep using it.
        self.remove_optimized_model()

    def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32):
        self.remove_optimized_model()

        self.model.inference_model = deepcopy(self.model.model)
        self.model.inference_model.eval()
        self.model.inference_model.export()

        self._optimized_resolution = self.model.resolution
        self._is_optimized_for_inference = True

        self.model.inference_model = self.model.inference_model.to(dtype=dtype)
        self._optimized_dtype = dtype

        if compile:
            self.model.inference_model = torch.jit.trace(
                self.model.inference_model,
                torch.randn(
                    batch_size, 3, self.model.resolution, self.model.resolution, device=self.model.device, dtype=dtype
                ),
            )
            self._optimized_has_been_compiled = True
            self._optimized_batch_size = batch_size

    def remove_optimized_model(self):
        self.model.inference_model = None
        self._is_optimized_for_inference = False
        self._optimized_has_been_compiled = False
        self._optimized_batch_size = None
        self._optimized_resolution = None
        self._optimized_half = False

    def export(
        self,
        output_dir: str = "output",
        infer_dir: str = None,
        simplify: bool = False,
        backbone_only: bool = False,
        opset_version: int = 17,
        verbose: bool = True,
        force: bool = False,
        shape: tuple = None,
        batch_size: int = 1,
        **kwargs,
    ) -> None:
        """Export the trained model to ONNX format.

        See the `ONNX export documentation <https://rfdetr.roboflow.com/learn/export/>`_
        for more information.

        The live model is temporarily moved to CPU while a working copy is
        made; it is always moved back to its original device before this method
        returns or raises.

        Args:
            output_dir: Directory to write the ONNX file to (created if missing).
            infer_dir: Passed to ``make_infer_image`` when building the example
                input (presumably a source of sample images — confirm there).
            simplify: Whether to run ``onnx_simplify`` on the exported graph.
            backbone_only: Export only the backbone (feature extractor).
            opset_version: ONNX opset version to target.
            verbose: Passed to ``export_onnx``.
            force: Passed to ``onnx_simplify``; only relevant when ``simplify``
                is true.
            shape: ``(height, width)`` tuple, both divisible by 14; defaults to
                square at model resolution.
            batch_size: Batch size of the example input the graph is exported with.
            **kwargs: Accepted for call-site compatibility; this method does not
                pass them on to ``export_onnx``.

        Raises:
            ImportError: If the ONNX export dependencies are not installed.
            ValueError: If ``shape`` is given and not divisible by 14.
        """
        logger.info("Exporting model to ONNX format")
        try:
            from rfdetr.export.main import export_onnx, make_infer_image, onnx_simplify
        except ImportError:
            logger.error(
                "It seems some dependencies for ONNX export are missing."
                " Please run `pip install rfdetr[onnx]` and try again."
            )
            raise

        # Validate the cheap inputs first, before touching the filesystem or
        # moving any weights between devices.
        if shape is None:
            shape = (self.model.resolution, self.model.resolution)
        elif shape[0] % 14 != 0 or shape[1] % 14 != 0:
            raise ValueError("Shape must be divisible by 14")

        device = self.model.device
        os.makedirs(output_dir, exist_ok=True)
        output_dir_path = Path(output_dir)

        input_names = ["input"]
        if backbone_only:
            output_names = ["features"]
        elif self.model_config.segmentation_head:
            output_names = ["dets", "labels", "masks"]
        else:
            output_names = ["dets", "labels"]
        dynamic_axes = None

        try:
            # Module.to("cpu") moves the live model itself, so everything from
            # here on sits under try/finally to guarantee it is moved back even
            # when the dry run or the export fails.
            model = deepcopy(self.model.model.to("cpu"))
            model.to(device)

            input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)

            # Dry-run the copy once in PyTorch and log the output shapes.
            model.eval()
            with torch.no_grad():
                if backbone_only:
                    features = model(input_tensors)
                    logger.debug(f"PyTorch inference output shape: {features.shape}")
                elif self.model_config.segmentation_head:
                    outputs = model(input_tensors)
                    dets = outputs["pred_boxes"]
                    labels = outputs["pred_logits"]
                    masks = outputs["pred_masks"]
                    if isinstance(masks, torch.Tensor):
                        logger.debug(
                            f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}, "
                            f"Masks: {masks.shape}"
                        )
                    else:
                        logger.debug(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
                else:
                    outputs = model(input_tensors)
                    dets = outputs["pred_boxes"]
                    labels = outputs["pred_logits"]
                    logger.debug(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")

            # The export itself runs on CPU.
            model.cpu()
            input_tensors = input_tensors.cpu()

            output_file = export_onnx(
                output_dir=str(output_dir_path),
                model=model,
                input_names=input_names,
                input_tensors=input_tensors,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                backbone_only=backbone_only,
                verbose=verbose,
                opset_version=opset_version,
            )

            logger.info(f"Successfully exported ONNX model to: {output_file}")

            if simplify:
                sim_output_file = onnx_simplify(
                    onnx_dir=output_file, input_names=input_names, input_tensors=input_tensors, force=force
                )
                logger.info(f"Successfully simplified ONNX model to: {sim_output_file}")

            logger.info("ONNX export completed successfully")
        finally:
            self.model.model = self.model.model.to(device)

    @staticmethod
    def _load_classes(dataset_dir: str) -> List[str]:
        """Load class names from a COCO or YOLO dataset directory.

        COCO: names are read from ``train/_annotations.coco.json``, ordered by
        category ``id``. If no category has a meaningful ``supercategory`` the
        dataset is treated as flat and every name is returned. Otherwise only
        categories that are not themselves the supercategory of another
        category are kept (falling back to all names if that leaves nothing).

        YOLO: names are read from the first ``data*.yaml`` / ``data*.yml`` file
        found in ``dataset_dir``; a ``names`` mapping is returned as a list
        ordered by its keys, a ``names`` list is returned as-is.

        Args:
            dataset_dir: Root directory of the dataset.

        Returns:
            The class names.

        Raises:
            ValueError: If the YOLO data file has no ``names`` field.
            FileNotFoundError: If neither dataset layout is recognised, or a
                YOLO dataset has no ``data*.yaml`` / ``data*.yml`` file.
        """
        if is_valid_coco_dataset(dataset_dir):
            coco_path = os.path.join(dataset_dir, "train", "_annotations.coco.json")
            with open(coco_path, "r", encoding="utf-8") as f:
                anns = json.load(f)
            categories = sorted(anns["categories"], key=lambda category: category.get("id", float("inf")))
            all_names = [c["name"] for c in categories]

            # Catch possible placeholders for no supercategory
            placeholders = {"", "none", "null", None}

            # Every meaningful supercategory value, i.e. every "parent" name.
            parents = {c.get("supercategory") for c in categories if c.get("supercategory", "none") not in placeholders}

            # If no meaningful supercategory exists anywhere, treat as flat dataset
            if not parents:
                return all_names

            # Mixed/Hierarchical: keep only categories that are not parents of other categories.
            # Both leaves (with a real supercategory) and standalone top-level nodes (supercategory is a
            # placeholder) satisfy this condition — neither appears as another category's supercategory.
            class_names = [name for name in all_names if name not in parents]
            # Safety fallback for pathological inputs
            return class_names or all_names

        if is_valid_yolo_dataset(dataset_dir):
            # list all YAML files in the folder
            yaml_paths = glob.glob(os.path.join(dataset_dir, "*.yaml")) + glob.glob(os.path.join(dataset_dir, "*.yml"))
            # any YAML file starting with data e.g. data.yaml, dataset.yaml
            yaml_data_files = [yp for yp in yaml_paths if os.path.basename(yp).startswith("data")]
            # With no matching file, fall through to the FileNotFoundError below
            # instead of failing with an IndexError.
            if yaml_data_files:
                yaml_path = yaml_data_files[0]
                with open(yaml_path, "r", encoding="utf-8") as f:
                    data = yaml.safe_load(f)
                # An empty YAML file loads as None; report it like a missing field.
                if not isinstance(data, dict) or "names" not in data:
                    raise ValueError(f"Found {yaml_path} but it does not contain 'names' field.")
                names = data["names"]
                if isinstance(names, dict):
                    return [names[key] for key in sorted(names)]
                return names

        raise FileNotFoundError(
            f"Could not find class names in {dataset_dir}."
            " Checked for COCO (train/_annotations.coco.json) and YOLO (data.yaml, data.yml) styles."
        )

    def get_train_config(self, **kwargs):
        """
        Build the training configuration from keyword arguments.

        Args:
            **kwargs: Field values passed straight through to ``TrainConfig``.

        Returns:
            The ``TrainConfig`` instance constructed from ``kwargs``.
        """
        train_config = TrainConfig(**kwargs)
        return train_config

    def get_model(self, config: ModelConfig) -> "_ModelContext":
        """Build the model context for an architecture configuration.

        Args:
            config: Architecture configuration handed to ``_build_model_context``.

        Returns:
            The ``_ModelContext`` produced by ``_build_model_context``, carrying
            the model, postprocess, device, resolution, args, and class_names
            attributes.
        """
        context = _build_model_context(config)
        return context

    # Get class_names from the model
    @property
    def class_names(self):
        """
        Retrieve the class names supported by the loaded model.

        Returns:
            dict: A dictionary mapping class IDs to class names. The keys are integers starting from
        """
        if hasattr(self.model, "class_names") and self.model.class_names:
            return {i + 1: name for i, name in enumerate(self.model.class_names)}

        return COCO_CLASSES

    def predict(
        self,
        images: Union[
            str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]
        ],
        threshold: float = 0.5,
        **kwargs,
    ) -> Union[sv.Detections, List[sv.Detections]]:
        """Performs object detection on the input images and returns bounding box
        predictions.

        This method accepts a single image or a list of images in various formats
        (file path, image url, PIL Image, NumPy array, or torch.Tensor). The images should be in
        RGB channel order. If a torch.Tensor is provided, it must already be normalized
        to values in the [0, 1] range and have the shape (C, H, W).

        Args:
            images:
                A single image or a list of images to process. Images can be provided
                as file paths, URLs (strings starting with ``http``), PIL Images,
                NumPy arrays, or torch.Tensors.
            threshold:
                The minimum confidence score needed to consider a detected bounding box valid.
            **kwargs:
                Additional keyword arguments (not used by this method).

        Returns:
            A single Detections object when exactly one image was processed
            (even if it was passed as a one-element list), otherwise a list of
            Detections objects. Each contains bounding box coordinates,
            confidence scores, class IDs and, when the model returns masks, masks.

        Raises:
            ValueError: If an image has values above 1, does not have 3 channels,
                or the batch does not match the resolution / batch size the
                optimized model was prepared for.
        """
        import supervision as sv

        if not self._is_optimized_for_inference:
            if not self._has_warned_about_not_being_optimized_for_inference:
                logger.warning(
                    "Model is not optimized for inference. Latency may be higher than expected."
                    " You can optimize the model for inference by calling model.optimize_for_inference()."
                )
                self._has_warned_about_not_being_optimized_for_inference = True

            # Force eval mode on every un-optimized call, not only the one that
            # emits the warning: the module may have been put back into train
            # mode since the previous predict() (e.g. by a train() run).
            self.model.model.eval()

        if not isinstance(images, list):
            images = [images]

        orig_sizes = []
        processed_images = []

        for img in images:
            if isinstance(img, str):
                if img.startswith("http"):
                    img = requests.get(img, stream=True).raw
                img = Image.open(img)

            if not isinstance(img, torch.Tensor):
                img = F.to_tensor(img)

            if (img > 1).any():
                raise ValueError(
                    "Image has pixel values above 1. Please ensure the image is normalized (scaled to [0, 1])."
                )
            if img.shape[0] != 3:
                raise ValueError(f"Invalid image shape. Expected 3 channels (RGB), but got {img.shape[0]} channels.")
            img_tensor = img

            # Remember the original size; it is handed to postprocess as target_sizes.
            h, w = img_tensor.shape[1:]
            orig_sizes.append((h, w))

            img_tensor = img_tensor.to(self.model.device)
            img_tensor = F.normalize(img_tensor, self.means, self.stds)
            img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution))

            processed_images.append(img_tensor)

        batch_tensor = torch.stack(processed_images)

        if self._is_optimized_for_inference:
            if self._optimized_resolution != batch_tensor.shape[2]:
                # this could happen if someone manually changes self.model.resolution after optimizing the model
                raise ValueError(
                    f"Resolution mismatch. "
                    f"Model was optimized for resolution {self._optimized_resolution}, "
                    f"but got {batch_tensor.shape[2]}."
                    " You can explicitly remove the optimized model by calling model.remove_optimized_model()."
                )
            if self._optimized_has_been_compiled:
                if self._optimized_batch_size != batch_tensor.shape[0]:
                    raise ValueError(
                        f"Batch size mismatch. "
                        f"Optimized model was compiled for batch size {self._optimized_batch_size}, "
                        f"but got {batch_tensor.shape[0]}."
                        " You can explicitly remove the optimized model by calling model.remove_optimized_model()."
                        " Alternatively, you can recompile the optimized model for a different batch size"
                        " by calling model.optimize_for_inference(batch_size=<new_batch_size>)."
                    )

        with torch.no_grad():
            if self._is_optimized_for_inference:
                predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
            else:
                predictions = self.model.model(batch_tensor)
            # A tuple result is (boxes, logits[, masks]); repack it into the
            # dict layout that postprocess is given in the non-tuple case.
            if isinstance(predictions, tuple):
                return_predictions = {
                    "pred_logits": predictions[1],
                    "pred_boxes": predictions[0],
                }
                if len(predictions) == 3:
                    return_predictions["pred_masks"] = predictions[2]
                predictions = return_predictions
            target_sizes = torch.tensor(orig_sizes, device=self.model.device)
            results = self.model.postprocess(predictions, target_sizes=target_sizes)

        detections_list = []
        for result in results:
            scores = result["scores"]
            labels = result["labels"]
            boxes = result["boxes"]

            # Keep only detections strictly above the confidence threshold.
            keep = scores > threshold
            scores = scores[keep]
            labels = labels[keep]
            boxes = boxes[keep]

            if "masks" in result:
                masks = result["masks"]
                masks = masks[keep]

                detections = sv.Detections(
                    xyxy=boxes.float().cpu().numpy(),
                    confidence=scores.float().cpu().numpy(),
                    class_id=labels.cpu().numpy(),
                    mask=masks.squeeze(1).cpu().numpy(),
                )
            else:
                detections = sv.Detections(
                    xyxy=boxes.float().cpu().numpy(),
                    confidence=scores.float().cpu().numpy(),
                    class_id=labels.cpu().numpy(),
                )

            detections_list.append(detections)

        return detections_list if len(detections_list) > 1 else detections_list[0]

    def deploy_to_roboflow(
        self,
        workspace: str,
        project_id: str,
        version: str,
        api_key: Optional[str] = None,
        size: Optional[str] = None,
    ) -> None:
        """
        Deploy the trained RF-DETR model to Roboflow.

        Deploying with Roboflow will create a Serverless API to which you can make requests.

        You can also download weights into a Roboflow Inference deployment for use in
        Roboflow Workflows and on-device deployment.

        Args:
            workspace: The name of the Roboflow workspace to deploy to.
            project_id: The project ID to which the model will be deployed.
            version: The project version to which the model will be deployed.
            api_key: Your Roboflow API key. If not provided,
                it will be read from the environment variable `ROBOFLOW_API_KEY`.
            size: The size of the model to deploy. If not provided,
                it will default to the size of the model being trained (e.g., "rfdetr-base", "rfdetr-large", etc.).

        Raises:
            ValueError: If the `api_key` is not provided and not found in the
                environment variable `ROBOFLOW_API_KEY`, or if the `size` is
                not set for custom architectures.
        """
        import shutil

        from roboflow import Roboflow

        if api_key is None:
            api_key = os.getenv("ROBOFLOW_API_KEY")
            if api_key is None:
                raise ValueError("Set api_key=<KEY> in deploy_to_roboflow or export ROBOFLOW_API_KEY=<KEY>")

        rf = Roboflow(api_key=api_key)
        workspace = rf.workspace(workspace)

        if self.size is None and size is None:
            raise ValueError("Must set size for custom architectures")

        size = self.size or size
        tmp_out_dir = ".roboflow_temp_upload"
        os.makedirs(tmp_out_dir, exist_ok=True)
        outpath = os.path.join(tmp_out_dir, "weights.pt")
        torch.save({"model": self.model.model.state_dict(), "args": self.model.args}, outpath)
        project = workspace.project(project_id)
        version = project.version(version)
        version.deploy(model_type=size, model_path=tmp_out_dir, filename="weights.pt")
        shutil.rmtree(tmp_out_dir)

Attributes

class_names property

Retrieve the class names supported by the loaded model.

Returns:

Name Type Description
dict

A dictionary mapping class IDs to class names. The keys are integers starting from 1.

Functions

deploy_to_roboflow(workspace, project_id, version, api_key=None, size=None)

Deploy the trained RF-DETR model to Roboflow.

Deploying with Roboflow will create a Serverless API to which you can make requests.

You can also download weights into a Roboflow Inference deployment for use in Roboflow Workflows and on-device deployment.

Parameters:

Name Type Description Default

workspace

str

The name of the Roboflow workspace to deploy to.

required

project_id

str

The project ID to which the model will be deployed.

required

version

str

The project version to which the model will be deployed.

required

api_key

Optional[str]

Your Roboflow API key. If not provided, it will be read from the environment variable ROBOFLOW_API_KEY.

None

size

Optional[str]

The size of the model to deploy. If not provided, it will default to the size of the model being trained (e.g., "rfdetr-base", "rfdetr-large", etc.).

None

Raises:

Type Description
ValueError

If the api_key is not provided and not found in the environment variable ROBOFLOW_API_KEY, or if the size is not set for custom architectures.

Source code in src/rfdetr/detr.py
def deploy_to_roboflow(
    self,
    workspace: str,
    project_id: str,
    version: str,
    api_key: Optional[str] = None,
    size: Optional[str] = None,
) -> None:
    """
    Deploy the trained RF-DETR model to Roboflow.

    Deploying with Roboflow will create a Serverless API to which you can make requests.

    You can also download weights into a Roboflow Inference deployment for use in
    Roboflow Workflows and on-device deployment.

    The weights are written to a temporary ``.roboflow_temp_upload`` directory
    in the current working directory, which is removed again whether or not
    the upload succeeds.

    Args:
        workspace: The name of the Roboflow workspace to deploy to.
        project_id: The project ID to which the model will be deployed.
        version: The project version to which the model will be deployed.
        api_key: Your Roboflow API key. If not provided,
            it will be read from the environment variable `ROBOFLOW_API_KEY`.
        size: The size of the model to deploy. If not provided,
            it will default to the size of the model being trained (e.g., "rfdetr-base", "rfdetr-large", etc.).

    Raises:
        ValueError: If the `api_key` is not provided and not found in the
            environment variable `ROBOFLOW_API_KEY`, or if the `size` is
            not set for custom architectures.
    """
    import shutil

    if api_key is None:
        api_key = os.getenv("ROBOFLOW_API_KEY")
        if api_key is None:
            raise ValueError("Set api_key=<KEY> in deploy_to_roboflow or export ROBOFLOW_API_KEY=<KEY>")

    # Check the size before importing roboflow or logging in, so a bad call
    # fails without any network traffic.
    if self.size is None and size is None:
        raise ValueError("Must set size for custom architectures")
    # An explicit ``size`` argument wins; the class-level ``size`` is only
    # the fallback, as the docstring promises.
    model_type = size or self.size

    from roboflow import Roboflow

    rf = Roboflow(api_key=api_key)
    rf_workspace = rf.workspace(workspace)

    tmp_out_dir = ".roboflow_temp_upload"
    os.makedirs(tmp_out_dir, exist_ok=True)
    try:
        outpath = os.path.join(tmp_out_dir, "weights.pt")
        torch.save({"model": self.model.model.state_dict(), "args": self.model.args}, outpath)
        project = rf_workspace.project(project_id)
        rf_version = project.version(version)
        rf_version.deploy(model_type=model_type, model_path=tmp_out_dir, filename="weights.pt")
    finally:
        # Never leave the weights dump behind, even when the upload fails.
        shutil.rmtree(tmp_out_dir)

export(output_dir='output', infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs)

Export the trained model to ONNX format.

See the ONNX export documentation (https://rfdetr.roboflow.com/learn/export/) for more information.

Parameters:

Name Type Description Default

output_dir

str

Directory to write the ONNX file to.

'output'

infer_dir

str

Optional directory of sample images for dynamic-axes inference.

None

simplify

bool

Whether to run onnx-simplifier on the exported graph.

False

backbone_only

bool

Export only the backbone (feature extractor).

False

opset_version

int

ONNX opset version to target.

17

verbose

bool

Print export progress information.

True

force

bool

Force re-export even if output already exists.

False

shape

tuple

(height, width) tuple; defaults to square at model resolution.

None

batch_size

int

Static batch size to bake into the ONNX graph.

1

**kwargs

Additional keyword arguments forwarded to export_onnx.

{}
Source code in src/rfdetr/detr.py
def export(
    self,
    output_dir: str = "output",
    infer_dir: str = None,
    simplify: bool = False,
    backbone_only: bool = False,
    opset_version: int = 17,
    verbose: bool = True,
    force: bool = False,
    shape: tuple = None,
    batch_size: int = 1,
    **kwargs,
) -> None:
    """Export the trained model to ONNX format.

    See the `ONNX export documentation <https://rfdetr.roboflow.com/learn/export/>`_
    for more information.

    The live model is temporarily moved to CPU while a working copy is
    made; it is always moved back to its original device before this method
    returns or raises.

    Args:
        output_dir: Directory to write the ONNX file to (created if missing).
        infer_dir: Passed to ``make_infer_image`` when building the example
            input (presumably a source of sample images — confirm there).
        simplify: Whether to run ``onnx_simplify`` on the exported graph.
        backbone_only: Export only the backbone (feature extractor).
        opset_version: ONNX opset version to target.
        verbose: Passed to ``export_onnx``.
        force: Passed to ``onnx_simplify``; only relevant when ``simplify``
            is true.
        shape: ``(height, width)`` tuple, both divisible by 14; defaults to
            square at model resolution.
        batch_size: Batch size of the example input the graph is exported with.
        **kwargs: Accepted for call-site compatibility; this method does not
            pass them on to ``export_onnx``.

    Raises:
        ImportError: If the ONNX export dependencies are not installed.
        ValueError: If ``shape`` is given and not divisible by 14.
    """
    logger.info("Exporting model to ONNX format")
    try:
        from rfdetr.export.main import export_onnx, make_infer_image, onnx_simplify
    except ImportError:
        logger.error(
            "It seems some dependencies for ONNX export are missing."
            " Please run `pip install rfdetr[onnx]` and try again."
        )
        raise

    # Validate the cheap inputs first, before touching the filesystem or
    # moving any weights between devices.
    if shape is None:
        shape = (self.model.resolution, self.model.resolution)
    elif shape[0] % 14 != 0 or shape[1] % 14 != 0:
        raise ValueError("Shape must be divisible by 14")

    device = self.model.device
    os.makedirs(output_dir, exist_ok=True)
    output_dir_path = Path(output_dir)

    input_names = ["input"]
    if backbone_only:
        output_names = ["features"]
    elif self.model_config.segmentation_head:
        output_names = ["dets", "labels", "masks"]
    else:
        output_names = ["dets", "labels"]
    dynamic_axes = None

    try:
        # Module.to("cpu") moves the live model itself, so everything from
        # here on sits under try/finally to guarantee it is moved back even
        # when the dry run or the export fails.
        model = deepcopy(self.model.model.to("cpu"))
        model.to(device)

        input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)

        # Dry-run the copy once in PyTorch and log the output shapes.
        model.eval()
        with torch.no_grad():
            if backbone_only:
                features = model(input_tensors)
                logger.debug(f"PyTorch inference output shape: {features.shape}")
            elif self.model_config.segmentation_head:
                outputs = model(input_tensors)
                dets = outputs["pred_boxes"]
                labels = outputs["pred_logits"]
                masks = outputs["pred_masks"]
                if isinstance(masks, torch.Tensor):
                    logger.debug(
                        f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}, "
                        f"Masks: {masks.shape}"
                    )
                else:
                    logger.debug(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
            else:
                outputs = model(input_tensors)
                dets = outputs["pred_boxes"]
                labels = outputs["pred_logits"]
                logger.debug(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")

        # The export itself runs on CPU.
        model.cpu()
        input_tensors = input_tensors.cpu()

        output_file = export_onnx(
            output_dir=str(output_dir_path),
            model=model,
            input_names=input_names,
            input_tensors=input_tensors,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            backbone_only=backbone_only,
            verbose=verbose,
            opset_version=opset_version,
        )

        logger.info(f"Successfully exported ONNX model to: {output_file}")

        if simplify:
            sim_output_file = onnx_simplify(
                onnx_dir=output_file, input_names=input_names, input_tensors=input_tensors, force=force
            )
            logger.info(f"Successfully simplified ONNX model to: {sim_output_file}")

        logger.info("ONNX export completed successfully")
    finally:
        self.model.model = self.model.model.to(device)

get_model(config)

Retrieve a model context from the provided architecture configuration.

Parameters:

Name Type Description Default

config

ModelConfig

Architecture configuration.

required

Returns:

Type Description
'_ModelContext'

_ModelContext with model, postprocess, device, resolution, args, and class_names attributes.

Source code in src/rfdetr/detr.py
def get_model(self, config: ModelConfig) -> "_ModelContext":
    """Build the model context for an architecture configuration.

    The work is delegated entirely to ``_build_model_context``; this method
    adds no logic of its own and does not read instance state.

    Args:
        config: Architecture configuration describing the model to build.

    Returns:
        A ``_ModelContext`` exposing the ``model``, ``postprocess``,
        ``device``, ``resolution``, ``args`` and ``class_names`` attributes.
    """
    model_context = _build_model_context(config)
    return model_context

get_model_config(**kwargs)

Retrieve the configuration parameters used by the model.

Source code in src/rfdetr/detr.py
def get_model_config(self, **kwargs):
    """Build a ``ModelConfig`` from keyword arguments.

    Args:
        **kwargs: Passed through unchanged to the ``ModelConfig`` constructor.

    Returns:
        The newly constructed ``ModelConfig`` instance.
    """
    model_config = ModelConfig(**kwargs)
    return model_config

get_train_config(**kwargs)

Retrieve the configuration parameters that will be used for training.

Source code in src/rfdetr/detr.py
def get_train_config(self, **kwargs):
    """Build the ``TrainConfig`` that a training run will use.

    Args:
        **kwargs: Passed through unchanged to the ``TrainConfig`` constructor.

    Returns:
        The newly constructed ``TrainConfig`` instance.
    """
    train_config = TrainConfig(**kwargs)
    return train_config

maybe_download_pretrain_weights()

Download pre-trained weights if they are not already downloaded.

Source code in src/rfdetr/detr.py
def maybe_download_pretrain_weights(self):
    """Fetch the configured pre-trained weights, if any are configured.

    Nothing happens when ``self.model_config.pretrain_weights`` is ``None``.
    Otherwise the configured value is handed to ``download_pretrain_weights``,
    which is presumably responsible for skipping files that are already
    present locally.
    """
    weights = self.model_config.pretrain_weights
    if weights is not None:
        download_pretrain_weights(weights)

predict(images, threshold=0.5, **kwargs)

Performs object detection on the input images and returns bounding box predictions.

This method accepts a single image or a list of images in various formats (file path, image URL, PIL Image, NumPy array, or torch.Tensor). The images should be in RGB channel order. If a torch.Tensor is provided, it must already be normalized to values in the [0, 1] range and have the shape (C, H, W).

Parameters:

Name Type Description Default

images

Union[str, Image, ndarray, Tensor, List[Union[str, ndarray, Image, Tensor]]]

A single image or a list of images to process. Images can be provided as file paths, image URLs, PIL Images, NumPy arrays, or torch.Tensors.

required

threshold

float

The minimum confidence score needed to consider a detected bounding box valid.

0.5

**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Union[Detections, List[Detections]]

A single or multiple Detections objects, each containing bounding box coordinates, confidence scores, and class IDs.

Source code in src/rfdetr/detr.py
def predict(
    self,
    images: Union[
        str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]
    ],
    threshold: float = 0.5,
    **kwargs,
) -> Union[sv.Detections, List[sv.Detections]]:
    """Performs object detection on the input images and returns bounding box
    predictions.

    This method accepts a single image or a list of images in various formats
    (file path, image URL, PIL Image, NumPy array, or torch.Tensor). The images should be in
    RGB channel order. If a torch.Tensor is provided, it must already be normalized
    to values in the [0, 1] range and have the shape (C, H, W).

    Every image is normalized with ``self.means`` / ``self.stds``, resized to
    ``self.model.resolution`` on both sides, and the results are stacked into a
    single batch, so inputs of different sizes can be mixed in one call.

    Args:
        images:
            A single image or a list of images to process. Images can be provided
            as file paths, image URLs, PIL Images, NumPy arrays, or torch.Tensors.
        threshold:
            The minimum confidence score needed to consider a detected bounding box valid.
            Only detections scoring strictly above this value are kept.
        **kwargs:
            Additional keyword arguments. Accepted for call-site compatibility;
            this method does not read them.

    Returns:
        A single or multiple Detections objects, each containing bounding box
        coordinates, confidence scores, and class IDs (plus masks when the
        post-processor provides them). A single ``Detections`` is returned
        whenever exactly one image was processed, even if it was passed inside
        a list. An empty input list yields an empty list.

    Raises:
        ValueError: If an image has pixel values above 1, does not have 3
            channels, or if the batch does not match the resolution / batch
            size the optimized model was prepared for.
    """
    if not self._is_optimized_for_inference:
        if not self._has_warned_about_not_being_optimized_for_inference:
            logger.warning(
                "Model is not optimized for inference. Latency may be higher than expected."
                " You can optimize the model for inference by calling model.optimize_for_inference()."
            )
            self._has_warned_about_not_being_optimized_for_inference = True

        # Switch to eval mode on every non-optimized call, not only the first
        # one: the underlying module can be replaced (e.g. after training) or
        # put back into training mode between calls.
        self.model.model.eval()

    if not isinstance(images, list):
        images = [images]

    # Nothing to process: avoid torch.stack([]) / indexing an empty result list.
    if not images:
        return []

    import supervision as sv

    orig_sizes = []
    processed_images = []

    for img in images:
        if isinstance(img, str):
            if img.startswith("http"):
                # Hand PIL the raw response stream; it reads from it lazily.
                img = requests.get(img, stream=True).raw
            img = Image.open(img)

        if not isinstance(img, torch.Tensor):
            img = F.to_tensor(img)

        if (img > 1).any():
            raise ValueError(
                "Image has pixel values above 1. Please ensure the image is normalized (scaled to [0, 1])."
            )
        if img.shape[0] != 3:
            raise ValueError(f"Invalid image shape. Expected 3 channels (RGB), but got {img.shape[0]} channels.")
        img_tensor = img

        # Remember the original size so boxes can be mapped back after resizing.
        h, w = img_tensor.shape[1:]
        orig_sizes.append((h, w))

        img_tensor = img_tensor.to(self.model.device)
        img_tensor = F.normalize(img_tensor, self.means, self.stds)
        img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution))

        processed_images.append(img_tensor)

    batch_tensor = torch.stack(processed_images)

    if self._is_optimized_for_inference:
        if self._optimized_resolution != batch_tensor.shape[2]:
            # this could happen if someone manually changes self.model.resolution after optimizing the model
            raise ValueError(
                f"Resolution mismatch. "
                f"Model was optimized for resolution {self._optimized_resolution}, "
                f"but got {batch_tensor.shape[2]}."
                " You can explicitly remove the optimized model by calling model.remove_optimized_model()."
            )
        if self._optimized_has_been_compiled:
            if self._optimized_batch_size != batch_tensor.shape[0]:
                raise ValueError(
                    f"Batch size mismatch. "
                    f"Optimized model was compiled for batch size {self._optimized_batch_size}, "
                    f"but got {batch_tensor.shape[0]}."
                    " You can explicitly remove the optimized model by calling model.remove_optimized_model()."
                    " Alternatively, you can recompile the optimized model for a different batch size"
                    " by calling model.optimize_for_inference(batch_size=<new_batch_size>)."
                )

    with torch.no_grad():
        if self._is_optimized_for_inference:
            predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
        else:
            predictions = self.model.model(batch_tensor)
        if isinstance(predictions, tuple):
            # Tuple outputs are ordered (boxes, logits[, masks]); convert them
            # to the dict layout the post-processor is given below.
            return_predictions = {
                "pred_logits": predictions[1],
                "pred_boxes": predictions[0],
            }
            if len(predictions) == 3:
                return_predictions["pred_masks"] = predictions[2]
            predictions = return_predictions
        target_sizes = torch.tensor(orig_sizes, device=self.model.device)
        results = self.model.postprocess(predictions, target_sizes=target_sizes)

    detections_list = []
    for result in results:
        keep = result["scores"] > threshold

        detection_fields = {
            "xyxy": result["boxes"][keep].float().cpu().numpy(),
            "confidence": result["scores"][keep].float().cpu().numpy(),
            "class_id": result["labels"][keep].cpu().numpy(),
        }
        if "masks" in result:
            detection_fields["mask"] = result["masks"][keep].squeeze(1).cpu().numpy()

        detections_list.append(sv.Detections(**detection_fields))

    return detections_list if len(detections_list) > 1 else detections_list[0]

train(**kwargs)

Train an RF-DETR model via the PyTorch Lightning stack.

All keyword arguments are forwarded to :meth:get_train_config to build a :class:~rfdetr.config.TrainConfig. Several legacy kwargs are absorbed so existing call-sites do not break:

  • device — mapped to TrainConfig.accelerator; "cpu" becomes accelerator="cpu", all others default to "auto".
  • callbacks — if the dict contains any non-empty lists a :class:DeprecationWarning is emitted; the dict is then discarded. Use PTL :class:~pytorch_lightning.Callback objects passed via :func:~rfdetr.training.build_trainer instead.
  • start_epoch — emits :class:DeprecationWarning and is dropped.
  • do_benchmark — emits :class:DeprecationWarning and is dropped.

After training completes the underlying nn.Module is synced back onto self.model.model so that :meth:predict and :meth:export continue to work without reloading the checkpoint.

Source code in src/rfdetr/detr.py
def train(self, **kwargs):
    """Train an RF-DETR model via the PyTorch Lightning stack.

    All keyword arguments that remain after the legacy ones below are removed
    are forwarded to :meth:`get_train_config` to build a
    :class:`~rfdetr.config.TrainConfig`.  Several legacy kwargs are absorbed
    so existing call-sites do not break:

    * ``device`` — only the exact string ``"cpu"`` is honoured and forwarded
      to :func:`~rfdetr.training.build_trainer` as ``accelerator="cpu"``;
      for any other value ``accelerator=None`` is passed (presumably letting
      the trainer auto-select a device — confirm in ``build_trainer``).
    * ``callbacks`` — if the dict contains any truthy value (e.g. a
      non-empty list) a :class:`DeprecationWarning` is emitted; the dict is
      then discarded.  Use PTL :class:`~pytorch_lightning.Callback` objects
      passed via :func:`~rfdetr.training.build_trainer` instead.
    * ``start_epoch`` — emits :class:`DeprecationWarning` and is dropped.
    * ``do_benchmark`` — dropped; emits :class:`DeprecationWarning` when
      truthy.

    After training completes the underlying ``nn.Module`` is synced back
    onto ``self.model.model`` so that :meth:`predict` and :meth:`export`
    continue to work without reloading the checkpoint.

    Args:
        **kwargs: Training options; see the list above for the legacy keys
            that are intercepted before the rest reach ``TrainConfig``.
    """
    # Imported lazily so the training stack is only loaded when training runs.
    from rfdetr.training import RFDETRDataModule, RFDETRModule, build_trainer

    # Absorb legacy `callbacks` dict — warn if non-empty, then discard.
    callbacks_dict = kwargs.pop("callbacks", None)
    if callbacks_dict and any(callbacks_dict.values()):
        warnings.warn(
            "Custom callbacks dict is not forwarded to PTL. Use PTL Callback objects instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Absorb legacy `device` kwarg.  When the caller explicitly requests CPU
    # (e.g. in tests or CPU-only environments), honour it by forwarding it as
    # the PTL accelerator.  All other device strings (cuda, mps) are ignored
    # so PTL can auto-select the best available device.
    # NOTE(review): the comparison is against the plain string "cpu", so a
    # non-string device object would not match — confirm this is intended.
    _device = kwargs.pop("device", None)
    _accelerator = "cpu" if _device == "cpu" else None

    # Absorb legacy `start_epoch` — PTL resumes automatically via ckpt_path.
    if "start_epoch" in kwargs:
        warnings.warn(
            "`start_epoch` is deprecated and ignored; PTL resumes automatically via `resume`.",
            DeprecationWarning,
            stacklevel=2,
        )
        kwargs.pop("start_epoch")

    # Pop `do_benchmark`; benchmarking via `.train()` is deprecated.
    # The key is always removed, but the warning fires only for truthy values.
    run_benchmark = bool(kwargs.pop("do_benchmark", False))
    if run_benchmark:
        warnings.warn(
            "`do_benchmark` in `.train()` is deprecated; use `rfdetr benchmark`.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Everything still in kwargs is treated as a TrainConfig field.
    config = self.get_train_config(**kwargs)
    module = RFDETRModule(self.model_config, config)
    datamodule = RFDETRDataModule(self.model_config, config)
    trainer = build_trainer(config, self.model_config, accelerator=_accelerator)
    # A falsy `config.resume` is normalised to None for `ckpt_path`.
    trainer.fit(module, datamodule, ckpt_path=config.resume or None)

    # Sync the trained weights back so predict() / export() see the updated model.
    self.model.model = module.model