Custom Training API¶
The high-level RFDETR.train() method is the quickest path to fine-tuning, but the underlying training primitives are fully public and are the recommended path for any customisation: custom callbacks, alternative loggers, mixed-precision overrides, multi-GPU strategies, or integration with external training frameworks.
Quickstart vs. customisation
If you want to start training with minimal code, use model.train() — it sets up and runs the full PTL stack automatically. Come here when you need to take direct control over any part of that stack.
How RFDETR.train() relates to PTL¶
When you call model.train(...), three things happen internally:
```python
from rfdetr.training import RFDETRModule, RFDETRDataModule, build_trainer

module = RFDETRModule(model_config, train_config)
datamodule = RFDETRDataModule(model_config, train_config)
trainer = build_trainer(train_config, model_config)
trainer.fit(module, datamodule, ckpt_path=train_config.resume or None)
```
Each of these objects is a standard PTL class. You can construct them directly, modify them, swap out callbacks, or replace the trainer entirely.
RFDETRModule¶
RFDETRModule is a pytorch_lightning.LightningModule. It owns the model weights, the criterion, the postprocessor, and the optimizer/scheduler configuration.
```python
from rfdetr.config import (
    RFDETRMediumConfig,
    TrainConfig,
)  # config classes live in rfdetr.config, not the top-level rfdetr namespace
from rfdetr.training import RFDETRModule

model_config = RFDETRMediumConfig(num_classes=10)
train_config = TrainConfig(
    dataset_dir="path/to/dataset",
    epochs=50,
    batch_size=4,
    grad_accum_steps=4,
    lr=1e-4,
    output_dir="output",
)

module = RFDETRModule(model_config, train_config)
```
Lifecycle hooks¶
| Hook | Behaviour |
|---|---|
| `on_fit_start` | Seeds RNGs when `train_config.seed` is set. |
| `on_train_batch_start` | Applies multi-scale random resize when `train_config.multi_scale=True`. |
| `transfer_batch_to_device` | Moves `NestedTensor` batches to the target device. |
| `training_step` | Computes loss, divides by `accumulate_grad_batches`, and logs `train/loss` and per-term losses. |
| `validation_step` | Runs forward pass and postprocessing; returns `{results, targets}` for `COCOEvalCallback`. |
| `test_step` | Same as `validation_step`, but logs under `test/`. |
| `predict_step` | Runs an inference-only forward pass and returns postprocessed detections. |
| `configure_optimizers` | Builds AdamW with layer-wise LR decay and a `LambdaLR` scheduler (cosine or step). |
| `on_load_checkpoint` | Auto-converts legacy `.pth` checkpoints to PTL format. |
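As a rough illustration of the scheduler half of `configure_optimizers`, the cosine variant of a `LambdaLR` factor can be written as a pure function. This is a minimal sketch, not the library's actual code; the real warmup and decay parameters may differ:

```python
import math

def cosine_lr_lambda(step: int, total_steps: int, min_factor: float = 0.0) -> float:
    """Multiplicative LR factor for a cosine schedule, as used with LambdaLR.

    Returns 1.0 at step 0 and decays smoothly toward min_factor at total_steps.
    """
    progress = min(step / max(total_steps, 1), 1.0)
    return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

A function like this would be passed to `torch.optim.lr_scheduler.LambdaLR` as `lr_lambda`, scaling the base LR per optimizer step.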
Accessing the underlying model¶
The raw nn.Module is module.model. After training completes, RFDETR.train() syncs it back onto self.model.model so predict() and export() continue to work.
RFDETRDataModule¶
RFDETRDataModule is a pytorch_lightning.LightningDataModule. It builds train/val/test datasets and wraps them in DataLoader objects.
```python
from rfdetr.training import RFDETRDataModule

datamodule = RFDETRDataModule(model_config, train_config)
```
Stages¶
| Stage | Datasets built |
|---|---|
| `"fit"` | train + val |
| `"validate"` | val only |
| `"test"` | test (or val for COCO-format datasets) |
The setup(stage) method is lazy — each split is built at most once, even if called multiple times.
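That laziness can be pictured with a small stand-in class (ours, not the library's implementation): each split is cached on first build and skipped on later `setup` calls.

```python
class LazyDataModule:
    """Sketch of a setup(stage) that builds each split at most once."""

    def __init__(self):
        self._splits = {}
        self.build_count = 0

    def _build(self, split: str):
        self.build_count += 1      # track how many real builds happen
        return f"{split}-dataset"  # stand-in for an actual dataset object

    def setup(self, stage: str):
        needed = {"fit": ["train", "val"], "validate": ["val"], "test": ["test"]}[stage]
        for split in needed:
            if split not in self._splits:  # lazy: skip splits already built
                self._splits[split] = self._build(split)
```

Calling `setup("fit")` twice and then `setup("validate")` performs only two builds, because the val split from the fit stage is reused.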
class_names property¶
Returns sorted category names from the COCO annotation file of the first available split, or None if the dataset has not been set up yet.
build_trainer¶
build_trainer assembles a pytorch_lightning.Trainer with the full RF-DETR callback and logger stack. All TrainConfig fields are wired automatically.
What build_trainer configures¶
| Concern | Source |
|---|---|
| Max epochs | train_config.epochs |
| Gradient accumulation | train_config.grad_accum_steps |
| Gradient clipping | train_config.clip_max_norm (default 0.1) |
| Mixed precision | Resolved from model_config.amp and device capability (bf16-mixed on Ampere+, 16-mixed otherwise) |
| Accelerator | train_config.accelerator (default "auto") |
| Strategy | Pass strategy= as a **trainer_kwarg to build_trainer. TrainConfig has no strategy field — setting it on TrainConfig will raise a ValueError. |
| Sync batch norm | train_config.sync_bn |
| Progress bar | train_config.progress_bar |
| Loggers | CSVLogger always; TensorBoard, WandB, MLflow when their train_config flags are True |
| Callbacks | RFDETREMACallback, DropPathCallback, COCOEvalCallback, BestModelCallback, RFDETREarlyStopping (conditional) |
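The mixed-precision resolution rule in the table can be expressed as a small helper. The function name and the exact no-AMP fallback are our assumptions for illustration, not part of the rfdetr API:

```python
def resolve_precision(amp: bool, is_ampere_or_newer: bool) -> str:
    """Pick a PTL precision string: bf16 on Ampere+, fp16 otherwise, fp32 without AMP."""
    if not amp:
        return "32-true"
    return "bf16-mixed" if is_ampere_or_newer else "16-mixed"
```

The returned strings match PyTorch Lightning's `precision` argument values.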
Overriding PTL Trainer kwargs¶
Pass any keyword argument accepted by pytorch_lightning.Trainer via **trainer_kwargs. These override the built configuration:
```python
trainer = build_trainer(
    train_config,
    model_config,
    fast_dev_run=2,             # run 2 batches per epoch for a smoke test
    accumulate_grad_batches=8,  # override TrainConfig.grad_accum_steps
    log_every_n_steps=10,
)
```
Running the training loop¶
Full training run¶
```python
from rfdetr.config import (
    RFDETRMediumConfig,
    TrainConfig,
)  # config classes live in rfdetr.config, not the top-level rfdetr namespace
from rfdetr.training import RFDETRModule, RFDETRDataModule, build_trainer

model_config = RFDETRMediumConfig(num_classes=10)
train_config = TrainConfig(
    dataset_dir="path/to/dataset",
    epochs=100,
    batch_size=4,
    grad_accum_steps=4,
    lr=1e-4,
    output_dir="output",
)

module = RFDETRModule(model_config, train_config)
datamodule = RFDETRDataModule(model_config, train_config)
trainer = build_trainer(train_config, model_config)
trainer.fit(module, datamodule)
```
Resume from checkpoint¶
Pass the checkpoint path to trainer.fit via ckpt_path. The path can be a PTL .ckpt file or a legacy RF-DETR .pth file — RFDETRModule.on_load_checkpoint converts either format automatically.
```python
trainer.fit(module, datamodule, ckpt_path="output/last.ckpt")
# or a legacy checkpoint:
trainer.fit(module, datamodule, ckpt_path="output/checkpoint.pth")
```
If you need to persist a converted checkpoint on disk (for example to inspect it, share it, or use it outside of PTL), convert it explicitly before passing it to trainer.fit:
```python
from rfdetr.training import convert_legacy_checkpoint

convert_legacy_checkpoint("old_checkpoint.pth", "new_checkpoint.ckpt")
trainer.fit(module, datamodule, ckpt_path="new_checkpoint.ckpt")
```
convert_legacy_checkpoint reads a pre-PTL .pth file produced by the legacy engine.py training loop and writes a PTL-compatible .ckpt file. Use it when migrating saved checkpoints to the PTL format rather than relying on on-the-fly conversion at load time.
Validation only¶
trainer.validate runs one full validation pass and logs val/mAP_50_95, val/mAP_50, val/F1, and per-class AP metrics to all active loggers.
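A minimal invocation, assuming the module and datamodule built in the sections above:

```python
trainer = build_trainer(train_config, model_config)
trainer.validate(module, datamodule=datamodule)
```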
Inference with the data pipeline¶
trainer.predict calls module.predict_step on every batch and returns a list of postprocessed detection results. Pass any DataLoader instance — datamodule.val_dataloader(), datamodule.test_dataloader(), or a custom loader — as the dataloaders argument. This is useful for offline evaluation or generating submission files.
predict_dataloader not implemented
RFDETRDataModule does not define a predict_dataloader() method, so trainer.predict(module, datamodule) will raise an error. Always pass a dataloader explicitly via the dataloaders= argument.
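For example, running prediction over the val split (the explicit `setup` call is our assumption about when the split gets built when bypassing the datamodule argument):

```python
datamodule.setup("validate")  # ensure the val split exists before requesting its loader
predictions = trainer.predict(module, dataloaders=datamodule.val_dataloader())
```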
Multi-GPU training¶
build_trainer configures PyTorch Lightning's Trainer directly, so all PTL strategies work out of the box.
Distributed Data Parallel (DDP) — recommended¶
Set train_config.accelerator = "auto" and pass strategy="ddp" to build_trainer, then launch with torchrun:
devices must be overridden for multi-GPU runs
build_trainer defaults to devices=1. To use all available GPUs, pass devices="auto" (or an explicit count) as a **trainer_kwarg. Without this override, torchrun will spawn multiple processes but each process will only see one device, defeating the purpose of the multi-GPU launch.
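For example, on a single node with four GPUs (the process count here is illustrative):

```shell
torchrun --nproc_per_node=4 train.py
```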
where train.py contains:
```python
from rfdetr.config import (
    RFDETRMediumConfig,
    TrainConfig,
)  # config classes live in rfdetr.config, not the top-level rfdetr namespace
from rfdetr.training import RFDETRModule, RFDETRDataModule, build_trainer

model_config = RFDETRMediumConfig(num_classes=10)
train_config = TrainConfig(
    dataset_dir="path/to/dataset",
    epochs=100,
    batch_size=4,        # per-GPU batch size
    grad_accum_steps=1,  # reduce when using more GPUs
    output_dir="output",
    sync_bn=True,        # sync batch norms across GPUs
)

module = RFDETRModule(model_config, train_config)
datamodule = RFDETRDataModule(model_config, train_config)
trainer = build_trainer(train_config, model_config, strategy="ddp", devices="auto")
trainer.fit(module, datamodule)
```
EMA is not compatible with FSDP or DeepSpeed
build_trainer automatically disables RFDETREMACallback when strategy contains "fsdp" or "deepspeed", and emits a UserWarning. Use strategy="ddp" or strategy="auto" to keep EMA active.
Effective batch size¶
Maintain an effective batch size of 16 regardless of GPU count:
| GPUs | batch_size | grad_accum_steps | Effective |
|---|---|---|---|
| 1 | 4 | 4 | 16 |
| 2 | 4 | 2 | 16 |
| 4 | 4 | 1 | 16 |
| 8 | 2 | 1 | 16 |
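Every row above follows the same arithmetic, effective = gpus * batch_size * grad_accum_steps, which a tiny helper (ours, not part of rfdetr) makes explicit:

```python
def effective_batch_size(gpus: int, batch_size: int, grad_accum_steps: int) -> int:
    """Effective (per-optimizer-step) batch size across all DDP processes."""
    return gpus * batch_size * grad_accum_steps

def grad_accum_for_target(target: int, gpus: int, batch_size: int) -> int:
    """Accumulation steps needed to hit a target effective batch size."""
    per_step = gpus * batch_size
    if target % per_step != 0:
        raise ValueError(f"target {target} not divisible by gpus*batch_size={per_step}")
    return target // per_step
```

For instance, moving from 1 GPU to 2 GPUs at batch_size=4 means dropping grad_accum_steps from 4 to 2 to keep the effective size at 16.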
Custom callbacks¶
build_trainer builds the default callback stack. You can pass your own callbacks via trainer_kwargs, but note that doing so replaces the built-in stack rather than extending it:
```python
from pytorch_lightning.callbacks import LearningRateMonitor, ModelSummary

from rfdetr.training import build_trainer

extra_callbacks = [
    LearningRateMonitor(logging_interval="step"),
    ModelSummary(max_depth=3),
]

trainer = build_trainer(
    train_config,
    model_config,
    callbacks=extra_callbacks,  # replaces the default callback list entirely
)
```
Replacing vs. extending callbacks
Passing callbacks= to build_trainer via trainer_kwargs replaces the entire default callback list built inside build_trainer (EMA, COCO eval, best-model checkpointing, etc.). To extend rather than replace, build the extra callbacks separately and merge them after calling build_trainer:
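One way to extend, sketched under the assumption that the returned Trainer's callback list can be mutated before fit begins:

```python
from pytorch_lightning.callbacks import LearningRateMonitor

from rfdetr.training import build_trainer

trainer = build_trainer(train_config, model_config)  # default stack intact
trainer.callbacks.append(LearningRateMonitor(logging_interval="step"))
trainer.fit(module, datamodule)
```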
Built-in callbacks¶
| Class | Purpose | Enabled when |
|---|---|---|
| `RFDETREMACallback` | Maintains an EMA copy of model weights | `train_config.use_ema=True` and strategy is not sharded |
| `DropPathCallback` | Anneals drop-path rate over training | `train_config.drop_path > 0` |
| `COCOEvalCallback` | Computes mAP and F1 after each validation epoch | Always |
| `BestModelCallback` | Saves `checkpoint_best_regular.pth`, `checkpoint_best_ema.pth`, `checkpoint_best_total.pth` | Always |
| `RFDETREarlyStopping` | Stops training when validation mAP stops improving | `train_config.early_stopping=True` |
Custom loggers¶
build_trainer adds loggers based on TrainConfig flags. To attach a logger not supported by TrainConfig (for example a custom Neptune or Comet logger), build it yourself and pass it alongside the defaults:
```python
from pytorch_lightning.loggers import NeptuneLogger  # hypothetical

from rfdetr.training import build_trainer

trainer = build_trainer(train_config, model_config)
trainer.loggers.append(NeptuneLogger(project="my-workspace/rf-detr"))
trainer.fit(module, datamodule)
```
All logged keys (train/loss, val/mAP_50_95, val/F1, val/ema_mAP_50_95, etc.) are written to every active logger in the list.
Logged metrics reference¶
| Key | When logged | Description |
|---|---|---|
| `train/loss` | Every step / epoch | Total weighted training loss |
| `train/<term>` | Every step / epoch | Individual loss terms (e.g. `train/loss_bbox`) |
| `val/loss` | Each epoch | Validation loss (if `train_config.compute_val_loss=True`) |
| `val/mAP_50_95` | Each eval epoch | COCO box mAP@[.50:.05:.95] |
| `val/mAP_50` | Each eval epoch | COCO box mAP@.50 |
| `val/mAP_75` | Each eval epoch | COCO box mAP@.75 |
| `val/mAR` | Each eval epoch | COCO mean average recall |
| `val/ema_mAP_50_95` | Each eval epoch | EMA-model mAP@[.50:.05:.95] (if EMA active) |
| `val/F1` | Each eval epoch | Macro F1 at best confidence threshold |
| `val/precision` | Each eval epoch | Precision at best F1 threshold |
| `val/recall` | Each eval epoch | Recall at best F1 threshold |
| `val/AP/<class>` | Each eval epoch | Per-class AP (if `log_per_class_metrics=True`) |
| `val/segm_mAP_50_95` | Each eval epoch | Segmentation mAP (segmentation models only) |
| `val/segm_mAP_50` | Each eval epoch | Segmentation mAP@.50 (segmentation models only) |
| `test/*` | After `trainer.test()` | Mirror of `val/*` keys |
See also¶
- RFDETR.train() — high-level API — the one-liner training path
- Training parameters — all TrainConfig fields
- Training loggers — TensorBoard, WandB, MLflow setup
- Advanced training — checkpointing, early stopping, memory optimisation
- PTL primitives API reference — full docstring reference