RF-DETR 1.6.0: From Quick Start to Full Control
RF-DETR is a real-time object detection model that combines the accuracy of transformer-based detectors with inference speeds suitable for production. In 1.6.0, the training stack is rebuilt on PyTorch Lightning — giving you composable building blocks you can adopt incrementally, without changing your existing code.
| Building block | Role |
|---|---|
| RFDETRModelModule | LightningModule — model, loss, optimizer, scheduler |
| RFDETRDataModule | LightningDataModule — datasets and dataloaders |
| build_trainer() | Factory that assembles a Trainer with all RF-DETR callbacks |
Key design principle: start simple, then pick up building blocks without losing your trained weights.
- Phase 1 — model.train() one-liner (EPOCHS_PHASE_1 epochs)
- Phase 2 — swap in the PTL components and continue for EPOCHS_PHASE_2 more epochs from the same checkpoint, same output folder — no conversion required
- End — full training curve, single-image inference, and batch inference with trainer.predict()
1. Install RF-DETR 1.6.0
rfdetr[train,loggers] pulls in PyTorch Lightning, torchmetrics, and the full
callback stack. roboflow downloads the demo dataset.
!pip install -q "rfdetr[train,loggers]==1.6.0" roboflow
2. Config
All notebook-level knobs in one place. Adjust EPOCHS_PHASE_* and BATCH_SIZE
to match your hardware — every downstream cell reads from these variables.
num_workers is set to os.cpu_count() inside a Jupyter/Colab kernel where
process forking is safe, and to 0 when running as a plain Python script.
On macOS and Windows, spawn-based multiprocessing would otherwise re-import
this module as __main__ and retrigger training.
import os
from pathlib import Path
DATASET_DIR = os.environ.get("DATASET_DIR", "")
OUTPUT_DIR = "output"
EPOCHS_PHASE_1 = 20
EPOCHS_PHASE_2 = 10
BATCH_SIZE = 12
THRESHOLD = 0.3
os.makedirs(OUTPUT_DIR, exist_ok=True)
try:
    from IPython import get_ipython
    _in_notebook = get_ipython() is not None
except Exception:
    _in_notebook = False

# Outside a notebook kernel, macOS/Windows spawn-based multiprocessing will
# re-import this script as __main__, triggering training again.
# Use 0 workers in that case; inside a kernel the usual forking rules apply safely.
num_workers = (os.cpu_count() or 0) if _in_notebook else 0
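If this config cell ever ends up in a standalone script, the more conventional defence is a __main__ guard rather than zeroing num_workers; a minimal sketch (train_all is an illustrative placeholder, not an RF-DETR API):

```python
# Hypothetical script layout: keep training side effects behind a __main__
# guard so spawn-based dataloader workers, which re-import this module on
# macOS/Windows, do not retrigger training.
def train_all() -> str:
    # In a real script this would call model.train(...); returning a string
    # keeps the sketch self-contained.
    return "training started"

if __name__ == "__main__":
    # Only the process launched directly runs this; re-imported copies skip it.
    print(train_all())
```

With this guard in place, num_workers could stay at os.cpu_count() even when the file runs as a plain script.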
3. Dataset
Aquarium Combined
— 638 images across 7 classes (fish, jellyfish, penguin, puffin,
shark, starfish, stingray). It is small enough to complete a demo run
in a few minutes yet diverse enough to produce meaningful detection results.
Set ROBOFLOW_API_KEY as a Colab secret (Secrets panel, key icon) or as an
environment variable before running this cell. The dataset is downloaded in
COCO format, which RF-DETR reads natively.
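For orientation, here is the rough shape of the categories block inside a COCO _annotations.coco.json file — a hand-written fragment, not the actual aquarium annotations:

```python
import json

# A minimal stand-in for the "categories" section of a COCO annotation file.
coco_fragment = json.loads("""
{
  "categories": [
    {"id": 1, "name": "fish", "supercategory": "creatures"},
    {"id": 2, "name": "jellyfish", "supercategory": "creatures"}
  ]
}
""")

# Sorting by id and extracting names yields the class list used for training.
class_names = [c["name"] for c in sorted(coco_fragment["categories"], key=lambda c: c["id"])]
print(class_names)  # → ['fish', 'jellyfish']
```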
import json
from roboflow import Roboflow
try:
    from google.colab import userdata  # type: ignore[import]
    API_KEY = userdata.get("ROBOFLOW_API_KEY")
except Exception:
    API_KEY = os.environ["ROBOFLOW_API_KEY"]
rf = Roboflow(api_key=API_KEY)
dataset = rf.workspace("brad-dwyer").project("aquarium-combined").version(1).download("coco", location="datasets")
DATASET_DIR = dataset.location
with open(Path(DATASET_DIR) / "train" / "_annotations.coco.json") as f:
    _ann = json.load(f)
CLASS_NAMES = [c["name"] for c in sorted(_ann["categories"], key=lambda c: c["id"])]
NUM_CLASSES = len(CLASS_NAMES)
print(f"Dataset : {DATASET_DIR}")
print(f"Classes : {NUM_CLASSES} — {CLASS_NAMES}")
with open(Path(DATASET_DIR) / "valid" / "_annotations.coco.json") as f:
    _val_ann = json.load(f)
val_images_dir = Path(DATASET_DIR) / "valid"
val_image_files = [img["file_name"] for img in _val_ann["images"]]
4. Phase 1 — model.train() one-liner
The same high-level API that has been in RF-DETR since v1.0 — nothing here changes from previous releases. If you have existing training scripts, they keep working without modification.
pretrain_weights="rf-detr-medium.pth" downloads COCO-pretrained backbone
weights automatically on first run and caches them locally. use_ema=True
maintains an exponential moving average of the weights to stabilise validation
metrics. run_test=False skips the final test-set evaluation to keep Phase 1
fast; Phase 2 turns it back on.
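The moving average behind use_ema=True is the standard per-parameter exponential update; a minimal numeric sketch (the 0.999 decay is illustrative, not necessarily RF-DETR's default):

```python
# EMA update: ema = decay * ema + (1 - decay) * current.
# Applied per-parameter after every optimizer step; validation then reads
# the smoothed ema weights instead of the raw, noisier ones.
def ema_update(ema: float, current: float, decay: float = 0.999) -> float:
    return decay * ema + (1.0 - decay) * current

ema = 0.0
for step_weight in [1.0, 1.0, 1.0]:  # pretend the raw weight settled at 1.0
    ema = ema_update(ema, step_weight)
print(round(ema, 6))  # → 0.002997 — the average trails the raw weight slowly
```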
After this cell completes, OUTPUT_DIR/checkpoint_best_total.pth holds the
best weights seen so far — the starting point for Phase 2.
from rfdetr import RFDETRMedium
model = RFDETRMedium(num_classes=NUM_CLASSES, pretrain_weights="rf-detr-medium.pth")
model.train(
    dataset_dir=DATASET_DIR,
    epochs=EPOCHS_PHASE_1,
    batch_size=BATCH_SIZE,
    grad_accum_steps=4,
    lr=1e-4,
    num_workers=num_workers,
    output_dir=OUTPUT_DIR,
    use_ema=True,
    run_test=False,
    progress_bar="rich",
    tensorboard=True,
    seed=42,
)
5. Phase 2 — PTL building blocks
Pick up the three PTL components and call trainer.fit() pointing at the Phase 1
checkpoint. No weight conversion is needed — RFDETRModelModule.on_load_checkpoint
detects the .pth format and remaps keys automatically.
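To get a feel for the kind of remapping on_load_checkpoint performs, here is a toy version operating on a plain dict in place of a real state dict — the key names and the "model." prefix are invented for illustration:

```python
# Toy remap: a raw .pth state-dict key like "backbone.layer1.weight" becomes
# the Lightning-style "model.backbone.layer1.weight". The real implementation
# also restores the epoch counter and EMA weights; this shows only the idea.
def remap_keys(state_dict: dict, prefix: str = "model.") -> dict:
    return {
        (k if k.startswith(prefix) else prefix + k): v
        for k, v in state_dict.items()
    }

raw = {"backbone.layer1.weight": [0.1], "model.head.bias": [0.2]}
print(sorted(remap_keys(raw)))
# → ['model.backbone.layer1.weight', 'model.head.bias']
```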
A lower learning rate (5e-5) is used because the model is already partially
converged. epochs=EPOCHS_PHASE_1 + EPOCHS_PHASE_2 sets the absolute epoch
ceiling; because the loaded checkpoint records the last completed epoch, PTL
runs exactly EPOCHS_PHASE_2 additional epochs before stopping. The same
OUTPUT_DIR is reused so checkpoints and metrics all land in one place.
import pandas as pd
from rfdetr import RFDETRDataModule, RFDETRModelModule, build_trainer
from rfdetr.config import RFDETRMediumConfig, TrainConfig
# Read Phase 1 metrics before Phase 2 overwrites the CSV.
df1 = pd.read_csv(f"{OUTPUT_DIR}/metrics.csv")
model_config = RFDETRMediumConfig(
    num_classes=NUM_CLASSES,
    pretrain_weights="rf-detr-medium.pth",
)
# epochs = EPOCHS_PHASE_1 + EPOCHS_PHASE_2 so PTL (which resumes the epoch
# counter from the checkpoint) runs exactly EPOCHS_PHASE_2 additional epochs
# before reaching max_epochs.
train_config = TrainConfig(
    dataset_dir=DATASET_DIR,
    epochs=EPOCHS_PHASE_1 + EPOCHS_PHASE_2,
    batch_size=BATCH_SIZE,
    grad_accum_steps=4,
    lr=5e-5,
    num_workers=num_workers,
    output_dir=OUTPUT_DIR,
    use_ema=True,
    run_test=True,
    progress_bar="tqdm",
    tensorboard=True,
    seed=42,
)
module = RFDETRModelModule(model_config=model_config, train_config=train_config)
datamodule = RFDETRDataModule(model_config=model_config, train_config=train_config)
trainer = build_trainer(train_config, model_config)
# Resume directly from the Phase 1 .pth — no conversion needed.
trainer.fit(module, datamodule, ckpt_path=f"{OUTPUT_DIR}/checkpoint_best_total.pth")
6. Training curve
Phase 1 and Phase 2 each emit their own metrics.csv (Phase 2 overwrites Phase 1's
file when it starts). We captured the Phase 1 copy before fitting so we can
concatenate both DataFrames and plot a single continuous curve across all epochs.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import supervision as sv
from IPython.display import Image as IPyImage
from IPython.display import display
from PIL import Image
from rfdetr.visualize.training import plot_metrics
df2 = pd.read_csv(f"{OUTPUT_DIR}/metrics.csv")
combined_csv = f"{OUTPUT_DIR}/metrics_combined.csv"
pd.concat([df1, df2], ignore_index=True).to_csv(combined_csv, index=False)
print(f"Combined CSV: {combined_csv} ({len(df1) + len(df2)} rows)")
plot_path = plot_metrics(combined_csv)
display(IPyImage(plot_path))
print(f"Saved: {plot_path}")
7. Single-image inference — model.predict()
RFDETRMedium can be instantiated directly from a checkpoint path for inference
— no training config needed. model.predict() accepts a PIL Image, runs
preprocessing, the forward pass, and postprocessing internally, and returns a
supervision.Detections object that is ready to annotate and display.
%matplotlib inline
model = RFDETRMedium(pretrain_weights=f"{OUTPUT_DIR}/checkpoint_best_total.pth", num_classes=NUM_CLASSES)
image = Image.open(val_images_dir / val_image_files[0])
detections = model.predict(image, threshold=THRESHOLD)
annotated = sv.BoxAnnotator().annotate(image.copy(), detections)
annotated = sv.LabelAnnotator().annotate(annotated, detections, labels=[CLASS_NAMES[c] for c in detections.class_id])
plt.figure(figsize=(10, 7))
plt.imshow(np.array(annotated))
plt.axis("off")
plt.show()
print(f"Detected {len(detections)} object(s)")
8. Batch inference — trainer.predict()
Instead of calling model.predict() one image at a time, trainer.predict()
streams the entire validation set through the model in batches and collects
all results — useful for dataset-level evaluation or offline export pipelines.
What happens under the hood:
- PTL calls datamodule.setup("predict") — this builds _dataset_val if it does not exist yet. Because trainer.fit() already ran above, the dataset is already in memory and setup is a no-op.
- PTL calls datamodule.predict_dataloader() — this returns the validation dataset wrapped in a SequentialSampler (no shuffle, no augmentation), identical to val_dataloader.
- For each batch, RFDETRModelModule.predict_step() runs a forward pass under torch.no_grad() and returns a list of {"scores", "labels", "boxes"} dicts — one dict per image in the batch.
- trainer.predict() collects all batch results into a List[List[dict]] (outer = batches, inner = images).
Flatten, apply a confidence threshold, and wrap in sv.Detections.
%matplotlib inline
import itertools
# Returns List[List[dict]] — outer: batches, inner: one dict per image
# Each dict has keys: "scores" (N,), "labels" (N,), "boxes" (N, 4) — all tensors
all_preds = trainer.predict(module, datamodule)
# Flatten the batch dimension → one result dict per validation image
flat_preds = [img_result for batch in all_preds for img_result in batch]
print(f"Ran predict on {len(flat_preds)} validation images")
# Build sv.Detections from raw tensors and visualise the first four images
annotated_images = []
for img_file, result in itertools.islice(zip(val_image_files, flat_preds), 4):
    keep = result["scores"] > THRESHOLD
    detections = sv.Detections(
        xyxy=result["boxes"][keep].cpu().float().numpy(),
        confidence=result["scores"][keep].cpu().float().numpy(),
        class_id=result["labels"][keep].cpu().long().numpy(),
    )
    image = Image.open(val_images_dir / img_file)
    annotated = sv.BoxAnnotator().annotate(image.copy(), detections)
    annotated = sv.LabelAnnotator().annotate(
        annotated, detections, labels=[CLASS_NAMES[c] for c in detections.class_id]
    )
    annotated_images.append(np.array(annotated))
    print(f"  {img_file}: {len(detections)} detection(s)")
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for ax, img in zip(axes.flat, annotated_images):
    ax.imshow(img)
    ax.axis("off")
plt.tight_layout()
plt.show()
9. Next steps
You have now seen the complete 1.6.0 stack — from a one-liner model.train()
through composable PTL components to batch inference. From here:
- PyTorch Lightning training docs — custom callbacks, multi-GPU, mixed precision
- Advanced training options — augmentations, EMA, learning rate schedules
- Logger integrations (ClearML, MLflow, W&B) — experiment tracking
- Export your model — ONNX, TensorRT, CoreML