Training with PyTorch Lightning
RF-DETR is a real-time object detection model that combines the accuracy of transformer-based detectors with inference speeds suitable for production. The training stack is built on PyTorch Lightning — giving you composable building blocks you can adopt incrementally, without changing your existing code.
| Building block | Role |
|---|---|
| RFDETRModelModule | LightningModule: model, loss, optimizer, scheduler |
| RFDETRDataModule | LightningDataModule: datasets and dataloaders |
| build_trainer() | Factory that assembles a Trainer with all RF-DETR callbacks |
Key design principle: start simple, then pick up building blocks without losing your trained weights.
- Phase 1: model.train() one-liner (EPOCHS_PHASE_1 epochs)
- Phase 2: swap in the PTL components and continue for EPOCHS_PHASE_2 more epochs from the same checkpoint, same output folder; no conversion required
- End: full training curve, single-image inference, and batch inference with trainer.predict()
1. Install RF-DETR 1.6.0
rfdetr[train,loggers] pulls in PyTorch Lightning, torchmetrics, and the full
callback stack. roboflow downloads the demo dataset.
!pip install -q rfdetr[train,loggers]==1.6.0 roboflow
2. Config
All notebook-level knobs in one place. Adjust EPOCHS_PHASE_* and BATCH_SIZE
to match your hardware; every downstream cell reads from these variables.
num_workers is set to os.cpu_count() inside a Jupyter/Colab kernel where
process forking is safe, and to 0 when running as a plain Python script.
On macOS and Windows, spawn-based multiprocessing would otherwise re-import
the script in each worker process, re-run its top-level code, and retrigger
training.
import os
from pathlib import Path
DATASET_DIR = os.environ.get("DATASET_DIR", "")
OUTPUT_DIR = "output"
EPOCHS_PHASE_1 = 20
EPOCHS_PHASE_2 = 10
BATCH_SIZE = 12
THRESHOLD = 0.3
os.makedirs(OUTPUT_DIR, exist_ok=True)
try:
    from IPython import get_ipython

    _in_notebook = get_ipython() is not None
except Exception:
    _in_notebook = False
# Outside a notebook kernel, macOS/Windows spawn-based multiprocessing
# re-imports this script in each worker process and re-runs its top-level
# code, triggering training again. Use 0 workers in that case; inside a
# kernel the usual forking rules apply safely.
num_workers = (os.cpu_count() or 0) if _in_notebook else 0
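In a plain script, the re-import problem described above is commonly handled with an if __name__ == "__main__": guard, because spawned worker processes import the script under a different module name. The sketch below illustrates that pattern next to the worker-count rule used in this notebook. pick_num_workers and its cap parameter are hypothetical names introduced here for illustration; they are not part of RF-DETR.

```python
import os


def pick_num_workers(in_notebook: bool, cap: int = 8) -> int:
    """Hypothetical helper: choose a DataLoader worker count.

    Returns 0 outside a notebook kernel (safe under spawn-based
    multiprocessing) and min(cpu_count, cap) inside one.
    """
    if not in_notebook:
        return 0
    return min(os.cpu_count() or 0, cap)


def main() -> None:
    # Training calls would go here. Under spawn, worker processes import
    # this file under a different module name, so main() is not re-run.
    print("workers:", pick_num_workers(in_notebook=False))


if __name__ == "__main__":
    main()
```

With such a guard in place, a script could in principle use more than 0 workers on macOS and Windows; the notebook keeps the simpler rule of falling back to 0.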
3. Dataset
Aquarium Combined: 638 images across 7 classes (fish, jellyfish, penguin,
puffin, shark, starfish, stingray). It is small enough to complete a demo run
in a few minutes yet diverse enough to produce meaningful detection results.
Set ROBOFLOW_API_KEY as a Colab secret (Secrets panel, key icon) or as an
environment variable before running this cell. The dataset is downloaded in
COCO format, which RF-DETR reads natively.
import json
from roboflow import Roboflow
try:
    from google.colab import userdata  # type: ignore[import]

    API_KEY = userdata.get("ROBOFLOW_API_KEY")
except Exception:
    API_KEY = os.environ["ROBOFLOW_API_KEY"]
rf = Roboflow(api_key=API_KEY)
dataset = rf.workspace("brad-dwyer").project("aquarium-combined").version(1).download("coco", location="datasets")
DATASET_DIR = dataset.location
with open(Path(DATASET_DIR) / "train" / "_annotations.coco.json") as f:
    _ann = json.load(f)
CLASS_NAMES = [c["name"] for c in sorted(_ann["categories"], key=lambda c: c["id"])]
NUM_CLASSES = len(CLASS_NAMES)
print(f"Dataset : {DATASET_DIR}")
print(f"Classes : {NUM_CLASSES} — {CLASS_NAMES}")
with open(Path(DATASET_DIR) / "valid" / "_annotations.coco.json") as f:
    _val_ann = json.load(f)
val_images_dir = Path(DATASET_DIR) / "valid"
val_image_files = [img["file_name"] for img in _val_ann["images"]]
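Before training, it can help to check how annotations are spread across classes. The following is a minimal sketch that assumes the standard COCO layout (a categories list and an annotations list with category_id fields). count_instances is a hypothetical helper, and the toy dictionary is made up for illustration rather than taken from the Aquarium dataset.

```python
from collections import Counter


def count_instances(coco: dict) -> dict:
    """Map each category name to its number of annotations."""
    id_to_name = {c["id"]: c["name"] for c in coco["categories"]}
    counts = Counter(a["category_id"] for a in coco["annotations"])
    # Include categories with zero annotations so gaps are visible.
    return {name: counts.get(cid, 0) for cid, name in sorted(id_to_name.items())}


# Tiny synthetic COCO-style dict (made up for illustration, not real data).
toy = {
    "categories": [
        {"id": 0, "name": "fish"},
        {"id": 1, "name": "shark"},
        {"id": 2, "name": "puffin"},
    ],
    "annotations": [{"category_id": 0}, {"category_id": 0}, {"category_id": 1}],
}
print(count_instances(toy))  # {'fish': 2, 'shark': 1, 'puffin': 0}
```

In the notebook, the same function could be called as count_instances(_ann) on the training annotations loaded above.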
4. Phase 1: model.train() one-liner
This is the same high-level API that has been in RF-DETR since v1.0; nothing here changes from previous releases. If you have existing training scripts, they keep working without modification.
pretrain_weights="rf-detr-medium.pth" downloads COCO-pretrained backbone
weights automatically on first run and caches them locally. use_ema=True
maintains an exponential moving average of the weights to stabilise validation
metrics. run_test=False skips the final test-set evaluation to keep Phase 1
fast; Phase 2 turns it back on.
After this cell completes, OUTPUT_DIR/checkpoint_best_total.pth holds the
best weights seen so far — the starting point for Phase 2.
from rfdetr import RFDETRMedium
model = RFDETRMedium(num_classes=NUM_CLASSES, pretrain_weights="rf-detr-medium.pth")
model.train(
    dataset_dir=DATASET_DIR,
    epochs=EPOCHS_PHASE_1,
    batch_size=BATCH_SIZE,
    grad_accum_steps=4,
    lr=1e-4,
    num_workers=num_workers,
    output_dir=OUTPUT_DIR,
    use_ema=True,
    run_test=False,
    progress_bar="rich",
    tensorboard=True,
    seed=42,
)
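To make the use_ema description concrete, here is a minimal, framework-free sketch of an exponential moving average update. The decay value 0.9 and the name ema_update are assumptions chosen for illustration; RF-DETR's actual EMA implementation and decay schedule may differ.

```python
def ema_update(ema: list[float], weights: list[float], decay: float = 0.9) -> list[float]:
    """Return decay * ema + (1 - decay) * weights, element-wise."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]


# A noisy scalar "weight" oscillating around 1.0 (made-up values).
raw = [1.4, 0.6, 1.3, 0.7, 1.2, 0.8]
ema = [raw[0]]
for w in raw[1:]:
    ema = ema_update(ema, [w])
    print(f"raw={w:.1f}  ema={ema[0]:.3f}")
```

The averaged value moves much less from step to step than the raw one, which is why metrics computed on EMA weights tend to look smoother.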
5. Phase 2: PTL building blocks
Pick up the three PTL components and call trainer.fit() pointing at the Phase 1
checkpoint. No weight conversion is needed — RFDETRModelModule.on_load_checkpoint
detects the .pth format and remaps keys automatically.
A lower learning rate (5e-5) is used because the model is already partially
converged. epochs=EPOCHS_PHASE_1 + EPOCHS_PHASE_2 sets the absolute epoch
ceiling; because the loaded checkpoint records the last completed epoch, PTL
runs exactly EPOCHS_PHASE_2 additional epochs before stopping. The same
OUTPUT_DIR is reused so checkpoints and metrics all land in one place.
import pandas as pd
from rfdetr import RFDETRDataModule, RFDETRModelModule, build_trainer
from rfdetr.config import RFDETRMediumConfig, TrainConfig
# Read Phase 1 metrics before Phase 2 overwrites the CSV.
df1 = pd.read_csv(f"{OUTPUT_DIR}/metrics.csv")
model_config = RFDETRMediumConfig(
    num_classes=NUM_CLASSES,
    pretrain_weights="rf-detr-medium.pth",
)
# epochs = EPOCHS_PHASE_1 + EPOCHS_PHASE_2 so PTL (which resumes the epoch
# counter from the checkpoint) runs exactly EPOCHS_PHASE_2 additional epochs
# before reaching max_epochs.
train_config = TrainConfig(
    dataset_dir=DATASET_DIR,
    epochs=EPOCHS_PHASE_1 + EPOCHS_PHASE_2,
    batch_size=BATCH_SIZE,
    grad_accum_steps=4,
    lr=5e-5,
    num_workers=num_workers,
    output_dir=OUTPUT_DIR,
    use_ema=True,
    run_test=True,
    progress_bar="tqdm",
    tensorboard=True,
    seed=42,
)
module = RFDETRModelModule(model_config=model_config, train_config=train_config)
datamodule = RFDETRDataModule(model_config=model_config, train_config=train_config)
trainer = build_trainer(train_config, model_config)
# Resume directly from the Phase 1 .pth — no conversion needed.
trainer.fit(module, datamodule, ckpt_path=f"{OUTPUT_DIR}/checkpoint_best_total.pth")
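The epoch arithmetic can be checked with a small sketch. It assumes the behavior described in this section: the resumed run starts from the number of completed epochs recorded in the checkpoint and stops at the epochs ceiling. remaining_epochs is a hypothetical helper, not an RF-DETR or Lightning function.

```python
def remaining_epochs(max_epochs: int, completed_epochs: int) -> int:
    """Epochs left to run when resuming, never negative."""
    return max(max_epochs - completed_epochs, 0)


EPOCHS_PHASE_1, EPOCHS_PHASE_2 = 20, 10
ceiling = EPOCHS_PHASE_1 + EPOCHS_PHASE_2

# Checkpoint written after the final Phase 1 epoch.
print(remaining_epochs(ceiling, completed_epochs=20))  # 10
# Checkpoint written earlier, e.g. a best-so-far snapshot from epoch 15.
print(remaining_epochs(ceiling, completed_epochs=15))  # 15
```

Under this assumption, a best-weights checkpoint saved before the last Phase 1 epoch would leave more than EPOCHS_PHASE_2 epochs to run, so it may be worth confirming the resumed epoch number in the training log.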
6. Training curve
Phase 1 and Phase 2 each emit their own metrics.csv (Phase 2 overwrites Phase 1's
file when it starts). We captured the Phase 1 copy before fitting so we can
concatenate both DataFrames and plot a single continuous curve across all epochs.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import supervision as sv
from IPython.display import display
from PIL import Image
from rfdetr.visualize.training import plot_metrics
df2 = pd.read_csv(f"{OUTPUT_DIR}/metrics.csv")
combined_csv = f"{OUTPUT_DIR}/metrics_combined.csv"
pd.concat([df1, df2], ignore_index=True).to_csv(combined_csv, index=False)
print(f"Combined CSV: {combined_csv} ({len(df1) + len(df2)} rows)")
fig = plot_metrics(combined_csv)
display(fig)
print(f"Saved: {OUTPUT_DIR}/metrics_plot.png")
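Because Phase 2 overwrites metrics.csv, keeping Phase 1 metrics only in memory means a kernel restart loses them. One possible safeguard, sketched here with the standard library only, is to copy the file before the second fit. backup_metrics and the _phase1 suffix are hypothetical names, and the CSV contents are made-up numbers.

```python
import shutil
import tempfile
from pathlib import Path


def backup_metrics(output_dir: Path, suffix: str = "_phase1") -> Path:
    """Copy metrics.csv to metrics<suffix>.csv and return the new path."""
    src = output_dir / "metrics.csv"
    dst = output_dir / f"metrics{suffix}.csv"
    shutil.copy2(src, dst)
    return dst


# Demonstration in a throwaway directory with made-up numbers.
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    (out / "metrics.csv").write_text("epoch,loss\n0,1.5\n1,1.2\n")
    saved = backup_metrics(out)
    (out / "metrics.csv").write_text("epoch,loss\n2,1.0\n")  # simulated overwrite
    backup_text = saved.read_text()
    print(backup_text)
```

In the notebook this would amount to calling backup_metrics(Path(OUTPUT_DIR)) just before trainer.fit().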
7. Single-image inference: model.predict()
RFDETRMedium can be instantiated directly from a checkpoint path for inference
— no training config needed. model.predict() accepts a PIL Image, runs
preprocessing, the forward pass, and postprocessing internally, and returns a
supervision.Detections object that is ready to annotate and display.
%matplotlib inline
model = RFDETRMedium(pretrain_weights=f"{OUTPUT_DIR}/checkpoint_best_total.pth", num_classes=NUM_CLASSES)
image = Image.open(val_images_dir / val_image_files[0])
detections = model.predict(image, threshold=THRESHOLD)
annotated = sv.BoxAnnotator().annotate(image.copy(), detections)
annotated = sv.LabelAnnotator().annotate(annotated, detections, labels=[CLASS_NAMES[c] for c in detections.class_id])
plt.figure(figsize=(10, 7))
plt.imshow(np.array(annotated))
plt.axis("off")
plt.show()
print(f"Detected {len(detections)} object(s)")
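The label list passed to the annotator above indexes CLASS_NAMES directly with each class id. As a hedged alternative, the sketch below builds labels that also show the confidence and fall back to a placeholder when an id is outside the name list. format_labels is a hypothetical helper written against plain sequences; it is not part of supervision or RF-DETR.

```python
def format_labels(class_ids, confidences, names):
    """Build 'name 0.87' strings; unknown ids fall back to 'class_<id>'."""
    labels = []
    for cid, conf in zip(class_ids, confidences):
        cid = int(cid)
        name = names[cid] if 0 <= cid < len(names) else f"class_{cid}"
        labels.append(f"{name} {conf:.2f}")
    return labels


# Made-up ids and scores for illustration.
print(format_labels([0, 2, 9], [0.91, 0.45, 0.30], ["fish", "jellyfish", "penguin"]))
# ['fish 0.91', 'penguin 0.45', 'class_9 0.30']
```

Assuming detections.confidence is populated, the result of format_labels(detections.class_id, detections.confidence, CLASS_NAMES) could be passed as the labels argument of the label annotator.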
8. Batch inference: trainer.predict()
Instead of calling model.predict() one image at a time, trainer.predict()
streams the entire validation set through the model in batches and collects
all results — useful for dataset-level evaluation or offline export pipelines.
What happens under the hood:
- PTL calls datamodule.setup("predict"); this builds _dataset_val if it does not exist yet. Because trainer.fit() already ran above, the dataset is already in memory and setup is a no-op.
- PTL calls datamodule.predict_dataloader(); this returns a dataloader over the validation dataset with a SequentialSampler (no shuffle, no augmentation), identical to val_dataloader.
- For each batch, RFDETRModelModule.predict_step() runs a forward pass under torch.no_grad() and returns a list of {"scores", "labels", "boxes"} dicts, one dict per image in the batch.
- trainer.predict() collects all batch results into a List[List[dict]] (outer = batches, inner = images).
Flatten, apply a confidence threshold, and wrap in sv.Detections.
%matplotlib inline
import itertools
# Returns List[List[dict]] — outer: batches, inner: one dict per image
# Each dict has keys: "scores" (N,), "labels" (N,), "boxes" (N, 4) — all tensors
all_preds = trainer.predict(module, datamodule)
# Flatten the batch dimension → one result dict per validation image
flat_preds = [img_result for batch in all_preds for img_result in batch]
print(f"Ran predict on {len(flat_preds)} validation images")
# Build sv.Detections from raw tensors and visualise the first four images
annotated_images = []
for img_file, result in itertools.islice(zip(val_image_files, flat_preds), 4):
    keep = result["scores"] > THRESHOLD
    detections = sv.Detections(
        xyxy=result["boxes"][keep].cpu().float().numpy(),
        confidence=result["scores"][keep].cpu().float().numpy(),
        class_id=result["labels"][keep].cpu().long().numpy(),
    )
    image = Image.open(val_images_dir / img_file)
    annotated = sv.BoxAnnotator().annotate(image.copy(), detections)
    annotated = sv.LabelAnnotator().annotate(
        annotated, detections, labels=[CLASS_NAMES[c] for c in detections.class_id]
    )
    annotated_images.append(np.array(annotated))
    print(f"  {img_file}: {len(detections)} detection(s)")
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for ax, img in zip(axes.flat, annotated_images):
    ax.imshow(img)
    ax.axis("off")
plt.tight_layout()
plt.show()
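The nested result structure can be explored without a GPU or a trained model. The sketch below mirrors the flatten-then-threshold steps on plain Python lists standing in for tensors. flatten_and_filter and the toy batches are made up for illustration and are not part of RF-DETR.

```python
def flatten_and_filter(batches, threshold):
    """Flatten List[List[dict]] to one dict per image, keeping scores > threshold."""
    results = []
    for batch in batches:
        for pred in batch:
            keep = [i for i, s in enumerate(pred["scores"]) if s > threshold]
            results.append(
                {key: [pred[key][i] for i in keep] for key in ("scores", "labels", "boxes")}
            )
    return results


# Two made-up batches: the first holds two images, the second holds one.
toy_batches = [
    [
        {"scores": [0.9, 0.2], "labels": [0, 3], "boxes": [[0, 0, 10, 10], [5, 5, 8, 8]]},
        {"scores": [0.1], "labels": [2], "boxes": [[1, 1, 2, 2]]},
    ],
    [{"scores": [0.5, 0.4], "labels": [1, 1], "boxes": [[0, 0, 4, 4], [2, 2, 6, 6]]}],
]
flat = flatten_and_filter(toy_batches, threshold=0.3)
print([len(p["scores"]) for p in flat])  # [1, 0, 2]
```

Pairing predictions with file names through zip, as the cell above does, relies on the dataloader visiting images in the same order as the annotation file. Since zip silently truncates to the shorter input, comparing len(flat_preds) with len(val_image_files) first is a cheap guard.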
9. Next steps
You have now seen the complete 1.6.0 stack — from a one-liner model.train()
through composable PTL components to batch inference. From here:
- PyTorch Lightning training docs — custom callbacks, multi-GPU, mixed precision
- Advanced training options — augmentations, EMA, learning rate schedules
- Logger integrations (ClearML, MLflow, W&B) — experiment tracking
- Export your model — ONNX, TensorRT, CoreML