Train an RF-DETR Model

Key Takeaways

  • Train detection, segmentation, or keypoint preview models with a single model.train(dataset_dir=...) call
  • Detection and segmentation support COCO JSON and YOLO dataset formats with automatic detection
  • Keypoint preview training supports COCO keypoint JSON; YOLO pose training is not supported yet
  • Fine-tune from COCO-pretrained checkpoints (Nano to 2XLarge) for fastest convergence
  • Built on PyTorch Lightning — use the high-level API or access PTL primitives directly for full control
  • EMA weights, early stopping, and best-model checkpointing are included by default

You can train RF-DETR object detection and segmentation models on a custom dataset using the rfdetr Python package, or in the cloud using Roboflow.

This guide describes how to train RF-DETR object detection and segmentation models, and how to fine-tune the keypoint preview model.

Training paths

RF-DETR provides two training paths:

| Path | When to use |
| --- | --- |
| RFDETR.train() (this page) | Quickstart, fine-tuning with standard options, Colab notebooks. One call sets up and runs everything. |
| Custom Training API | Custom callbacks, alternative loggers, multi-GPU strategies, integration with external frameworks, or any other customisation of the training loop. |

Both paths run the same underlying PyTorch Lightning stack. RFDETR.train() constructs RFDETRModelModule, RFDETRDataModule, and a Trainer internally; the Lightning API page shows how to do the same thing explicitly so you can modify each component.

Quick Start

RF-DETR supports training on datasets in both COCO and YOLO formats. The format is automatically detected based on the structure of your dataset directory.

# Object detection
from rfdetr import RFDETRMedium

model = RFDETRMedium()

model.train(
    dataset_dir="<DATASET_PATH>",
    epochs=100,
    batch_size=4,
    grad_accum_steps=4,
    lr=1e-4,
    output_dir="<OUTPUT_PATH>",
)

# Segmentation
from rfdetr import RFDETRSegMedium

model = RFDETRSegMedium()

model.train(
    dataset_dir="<DATASET_PATH>",
    epochs=100,
    batch_size=4,
    grad_accum_steps=4,
    lr=1e-4,
    output_dir="<OUTPUT_PATH>",
)

# Keypoint preview
from rfdetr import RFDETRKeypointPreview

model = RFDETRKeypointPreview()

model.train(
    dataset_dir="<COCO_KEYPOINT_DATASET_PATH>",
    epochs=50,
    batch_size=2,
    grad_accum_steps=8,
    lr=1e-5,
    output_dir="<OUTPUT_PATH>",
)

Different GPUs have different VRAM capacities, so adjust batch_size and grad_accum_steps to keep the effective batch size (batch_size × grad_accum_steps) at 16. For example, on a high-memory GPU such as the A100, use batch_size=16 and grad_accum_steps=1; on a smaller GPU such as the T4, use batch_size=4 and grad_accum_steps=4. Gradient accumulation lets you keep the same effective batch size when memory is limited.
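
The batch-size arithmetic above can be sketched as a small helper. This is only an illustration of the rule batch_size × grad_accum_steps = 16; the function name grad_accum_for is hypothetical and not part of the rfdetr package.

```python
def grad_accum_for(batch_size: int, effective_batch_size: int = 16) -> int:
    """Return the grad_accum_steps that keeps batch_size * steps == effective_batch_size.

    Hypothetical helper for illustration; rfdetr does not ship this function.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if effective_batch_size % batch_size != 0:
        raise ValueError(
            f"batch_size={batch_size} does not divide "
            f"effective_batch_size={effective_batch_size}"
        )
    return effective_batch_size // batch_size


# Print the pairing for a few per-step batch sizes that fit different GPUs.
for bs in (16, 8, 4, 2):
    print(f"batch_size={bs} -> grad_accum_steps={grad_accum_for(bs)}")
```

Pick the largest batch_size that fits in memory, then pass the matching grad_accum_steps to model.train().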

For object detection, the RF-DETR-B checkpoint is used by default. To get started quickly with training an object detection model, please refer to our fine-tuning Google Colab notebook.

Keypoint preview custom datasets

The pretrained keypoint preview checkpoint predicts 17 COCO person keypoints. Fine-tuned keypoint preview models can use the keypoint schema from your own COCO keypoint dataset, so the output keypoint count is not limited to 17.

Use COCO keypoint JSON for custom keypoint training. Roboflow COCO exports are supported when split annotations are named train/_annotations.coco.json, valid/_annotations.coco.json, and optionally test/_annotations.coco.json. YOLO pose/keypoint training is not supported yet.
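
Before starting a run, it can help to confirm that an export matches the layout described above. The following is a minimal standard-library sketch, not part of rfdetr: the helper names are hypothetical, it assumes the train and valid splits are both required, and it relies on the standard COCO convention that each keypoint category lists its keypoint names under a "keypoints" field.

```python
import json
import tempfile
from pathlib import Path

# Assumption: train and valid are required; test is optional per the text above.
REQUIRED_SPLITS = ("train", "valid")


def keypoints_per_category(annotation_path):
    """Map each COCO category name to the number of keypoints it declares."""
    data = json.loads(Path(annotation_path).read_text(encoding="utf-8"))
    return {c["name"]: len(c.get("keypoints", [])) for c in data.get("categories", [])}


def check_keypoint_export(dataset_dir):
    """Verify the split annotation files exist, then summarize the training schema."""
    root = Path(dataset_dir)
    missing = [
        split
        for split in REQUIRED_SPLITS
        if not (root / split / "_annotations.coco.json").is_file()
    ]
    if missing:
        raise FileNotFoundError(f"missing _annotations.coco.json for split(s): {missing}")
    return keypoints_per_category(root / "train" / "_annotations.coco.json")


# Exercise the check on a throwaway dataset with one 3-keypoint category.
with tempfile.TemporaryDirectory() as tmp:
    toy = {"categories": [{"id": 1, "name": "hand", "keypoints": ["wrist", "thumb", "index"]}]}
    for split in REQUIRED_SPLITS:
        (Path(tmp) / split).mkdir()
        (Path(tmp) / split / "_annotations.coco.json").write_text(json.dumps(toy))
    print(check_keypoint_export(tmp))
```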

The keypoint fine-tuning demo infers the class names and keypoint schema from the training annotation file, then passes those values into RFDETRKeypointPreview and model.train():

from pathlib import Path

from rfdetr import RFDETRKeypointPreview
from rfdetr.datasets._keypoint_schema import infer_coco_keypoint_schema

DATASET_DIR = Path("/path/to/coco-keypoint-dataset")
schema = infer_coco_keypoint_schema(DATASET_DIR / "train" / "_annotations.coco.json")

model = RFDETRKeypointPreview(
    num_classes=len(schema.class_names),
    num_keypoints_per_class=schema.num_keypoints_per_class,
)

model.train(
    dataset_file="roboflow",
    dataset_dir=str(DATASET_DIR),
    class_names=schema.class_names,
    keypoint_oks_sigmas=schema.keypoint_oks_sigmas,
    epochs=50,
    batch_size=8,
    grad_accum_steps=2,
    lr=2e-5,
    lr_encoder=2e-5,
    output_dir="output/keypoint_custom",
    use_ema=False,
    run_test=False,
)

Set keypoint_flip_pairs if horizontal flips should swap left/right keypoints for your schema.
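
One way to derive left/right pairs is to match keypoint names by their left_ and right_ prefixes. This guide does not state the exact format keypoint_flip_pairs expects, so the list of (left_index, right_index) tuples below is an assumption; check the training parameter reference before passing it to model.train(). The helper name is hypothetical.

```python
# Standard COCO person keypoint names, in index order.
COCO_PERSON_KEYPOINTS = [
    "nose",
    "left_eye", "right_eye",
    "left_ear", "right_ear",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    "left_hip", "right_hip",
    "left_knee", "right_knee",
    "left_ankle", "right_ankle",
]


def left_right_flip_pairs(names):
    """Pair each 'left_x' keypoint index with its 'right_x' counterpart.

    Hypothetical helper; the (left_index, right_index) tuple format is an
    assumption about what keypoint_flip_pairs accepts.
    """
    index = {name: i for i, name in enumerate(names)}
    pairs = []
    for name, i in index.items():
        if name.startswith("left_"):
            partner = "right_" + name[len("left_"):]
            if partner in index:
                pairs.append((i, index[partner]))
    return pairs


print(left_right_flip_pairs(COCO_PERSON_KEYPOINTS))
```

Keypoints without a mirrored partner, such as nose, are left out of the pairs and therefore stay in place under a horizontal flip.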

Dataset Format

RF-DETR automatically detects whether your dataset is in COCO or YOLO format. Simply pass your dataset directory to the train() method and the appropriate data loader will be used.

| Format | Detection Method | Learn More |
| --- | --- | --- |
| COCO | Looks for train/_annotations.coco.json | COCO Format Guide |
| YOLO | Looks for data.yaml + train/images/ | YOLO Format Guide |

For keypoint preview training, use COCO keypoint JSON. YOLO pose/keypoint datasets are rejected with a clear error instead of being treated as detection-only labels.
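
The detection rule in the table above can be approximated with a few path checks. This is a sketch of the idea only: guess_dataset_format is a hypothetical helper, and the order of the checks is an assumption rather than the library's actual implementation.

```python
import tempfile
from pathlib import Path


def guess_dataset_format(dataset_dir):
    """Guess 'coco' or 'yolo' from the directory layout (illustrative, not rfdetr API)."""
    root = Path(dataset_dir)
    # COCO: a single annotation file inside the train split.
    if (root / "train" / "_annotations.coco.json").is_file():
        return "coco"
    # YOLO: a data.yaml at the root plus an images folder inside the train split.
    if (root / "data.yaml").is_file() and (root / "train" / "images").is_dir():
        return "yolo"
    return "unknown"


# Build two throwaway layouts to exercise the check.
with tempfile.TemporaryDirectory() as tmp:
    coco = Path(tmp) / "coco_ds"
    (coco / "train").mkdir(parents=True)
    (coco / "train" / "_annotations.coco.json").write_text("{}")

    yolo = Path(tmp) / "yolo_ds"
    (yolo / "train" / "images").mkdir(parents=True)
    (yolo / "data.yaml").write_text("names: []\n")

    print(guess_dataset_format(coco))  # coco
    print(guess_dataset_format(yolo))  # yolo
```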

Roboflow allows you to create object detection datasets from scratch and export them in either COCO JSON or YOLO format for training. You can also explore Roboflow Universe to find pre-labeled datasets for a range of use cases.

Learn more about dataset formats

Training Configuration

RF-DETR provides many configuration options to customize your training run. See the complete reference for all available parameters.

View all training parameters

Advanced Topics

Learn more about advanced training

Custom Training API

RF-DETR's training stack is built on PyTorch Lightning. The RFDETR.train() call above constructs and runs PTL primitives internally. Use them directly when you need custom callbacks, non-default loggers, multi-GPU strategies, or full control over the training loop.

Custom Training API guide

Training Loggers

Track your experiments with popular logging platforms:

Learn more about training loggers

Result Checkpoints

During training, multiple model checkpoints are saved to the output directory:

  • checkpoint.pth – the most recent checkpoint, saved at the end of the latest epoch.

  • checkpoint_<number>.pth – periodic checkpoints saved every N epochs (default is every 10).

  • checkpoint_best_ema.pth – best checkpoint based on validation score, using the EMA (Exponential Moving Average) weights. EMA weights are a smoothed version of the model's parameters across training steps, often yielding better generalization.

  • checkpoint_best_regular.pth – best checkpoint based on validation score, using the raw (non-EMA) model weights.

  • checkpoint_best_total.pth – final checkpoint selected for inference and benchmarking. It contains only the model weights (no optimizer state or scheduler) and is chosen as the better of the EMA and non-EMA models based on validation performance.

For detection and segmentation models, the validation score is box mAP (val/mAP_50_95). For keypoint preview models, best-checkpoint selection uses COCO keypoint AP (val/keypoint_map_50_95). Keypoint checkpoints also persist the model's keypoint schema so that RFDETR.from_checkpoint() can reconstruct the same label/keypoint slots.

Checkpoint file sizes

Checkpoint sizes vary based on what they contain:

  • Training checkpoints (e.g. checkpoint.pth, checkpoint_<number>.pth) include model weights, optimizer state, scheduler state, and training metadata. Use these to resume training.

  • Evaluation checkpoints (e.g. checkpoint_best_ema.pth, checkpoint_best_regular.pth) store only the model weights — either EMA or raw — and are used to track the best-performing models. These may come from different epochs depending on which version achieved the highest validation score.

  • Stripped checkpoint (e.g. checkpoint_best_total.pth) contains only the final model weights and is optimized for inference and deployment.
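
Given the file names listed above, choosing a checkpoint for a given purpose can be scripted with the standard library. The helpers below are hypothetical and only encode this page's naming scheme: checkpoint.pth for resuming, checkpoint_best_total.pth for inference, and checkpoint_<number>.pth for periodic snapshots. They do not load or inspect the files.

```python
import re
import tempfile
from pathlib import Path

# File names taken from the checkpoint list above.
CHECKPOINT_FOR = {
    "resume": "checkpoint.pth",
    "inference": "checkpoint_best_total.pth",
}
PERIODIC = re.compile(r"^checkpoint_(\d+)\.pth$")


def pick_checkpoint(output_dir, purpose):
    """Return the checkpoint path suited to 'resume' or 'inference' (illustrative helper)."""
    path = Path(output_dir) / CHECKPOINT_FOR[purpose]
    if not path.is_file():
        raise FileNotFoundError(path)
    return path


def periodic_checkpoints(output_dir):
    """List checkpoint_<number>.pth files sorted by their numeric suffix."""
    found = []
    for path in Path(output_dir).iterdir():
        match = PERIODIC.match(path.name)
        if match:
            found.append((int(match.group(1)), path))
    return [path for _, path in sorted(found, key=lambda item: item[0])]


# Exercise the helpers on an empty stand-in for a real output directory.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("checkpoint.pth", "checkpoint_20.pth", "checkpoint_100.pth",
                 "checkpoint_best_ema.pth", "checkpoint_best_total.pth"):
        (Path(tmp) / name).touch()
    print(pick_checkpoint(tmp, "inference").name)
    print([p.name for p in periodic_checkpoints(tmp)])
```

Sorting on the parsed integer rather than the file name keeps checkpoint_100.pth after checkpoint_20.pth.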

Load and Run Fine-Tuned Model

# Object detection
from rfdetr import RFDETRMedium

model = RFDETRMedium(pretrain_weights="<CHECKPOINT_PATH>")

detections = model.predict("<IMAGE_PATH>")

# Segmentation
from rfdetr import RFDETRSegMedium

model = RFDETRSegMedium(pretrain_weights="<CHECKPOINT_PATH>")

detections = model.predict("<IMAGE_PATH>")

Next Steps

After training your model, you can: