Training API Reference¶
This page documents the training primitives that power RF-DETR. For a narrative guide with runnable examples, see Custom Training API.
RFDETRModule¶
Bases: LightningModule
LightningModule wrapping the RF-DETR model and training loop.
Migrates Model.__init__, train_one_epoch, evaluate, and
optimizer setup from main.py / engine.py into PTL lifecycle hooks.
Coexists with the existing code until Chapter 4 removes the legacy path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_config` | `ModelConfig` | Architecture configuration. | *required* |
| `train_config` | `TrainConfig` | Training hyperparameter configuration. | *required* |
Functions¶
__init__(model_config, train_config)¶
on_fit_start()¶
Seed RNGs at fit start when TrainConfig.seed is set.
This avoids hidden global side-effects in build_trainer while still
preserving deterministic training behaviour for actual fit runs.
on_train_batch_start(batch, batch_idx)¶
Apply optional multi-scale resize to the incoming batch.
Modifications to batch (in-place on NestedTensor) are visible
in training_step because they share the same object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple` | Tuple of (NestedTensor samples, list of target dicts). | *required* |
| `batch_idx` | `int` | Index of the current batch within the epoch. | *required* |
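The scale-selection mechanics can be sketched in pure Python. This is a hypothetical illustration, not the library's actual helper: the function name, the number of candidate scales, and the "step by one patch size" rule are assumptions; the grounded idea is that each batch may be resized to a resolution compatible with the backbone patch size.

```python
import random

def pick_scale(base_resolution, patch_size, num_scales=3, rng=None):
    """Pick a per-batch training resolution near the base resolution.

    Hypothetical sketch: candidate resolutions step by one patch size, so
    every choice stays divisible by the backbone patch size.
    """
    rng = rng or random.Random()
    half = num_scales // 2
    candidates = [base_resolution + o * patch_size for o in range(-half, half + 1)]
    return rng.choice(candidates)
```

Because the resize mutates the `NestedTensor` in place, `training_step` sees the rescaled batch without any extra plumbing.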
transfer_batch_to_device(batch, device, dataloader_idx)¶
Override PTL's default to handle NestedTensor device transfer.
PTL's default iterates tuple elements and calls .to(device); that
works for plain tensors but NestedTensor must be moved explicitly.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple` | Tuple of (NestedTensor samples, list of target dicts). | *required* |
| `device` | `device` | Target device. | *required* |
| `dataloader_idx` | `int` | Index of the dataloader providing this batch. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tuple` | Batch with all tensors on `device`. |
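The reason for the override can be shown with stand-in classes instead of real tensors. `FakeTensor` and `FakeNestedTensor` below are illustrative mocks, not rfdetr types; the grounded point is that a custom container must be moved as one object via its own `.to()`, while the target dicts can be moved field by field:

```python
class FakeTensor:
    """Stand-in for torch.Tensor: records which device it lives on."""
    def __init__(self, device="cpu"):
        self.device = device
    def to(self, device):
        return FakeTensor(device)

class FakeNestedTensor:
    """Stand-in for rfdetr's NestedTensor: a (tensors, mask) pair that must
    be moved as one unit rather than element by element."""
    def __init__(self, tensors, mask):
        self.tensors, self.mask = tensors, mask
    def to(self, device):
        return FakeNestedTensor(self.tensors.to(device), self.mask.to(device))

def transfer_batch_to_device(batch, device):
    """Sketch of the override: move the NestedTensor explicitly, then move
    every tensor inside each target dict."""
    samples, targets = batch
    samples = samples.to(device)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    return samples, targets
```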
training_step(batch, batch_idx)¶
Compute loss for one training step and log metrics.
PTL handles gradient accumulation (accumulate_grad_batches), AMP
(precision), and gradient clipping (gradient_clip_val) — no
manual GradScaler or loss scaling here. The loss is divided by
trainer.accumulate_grad_batches so that the accumulated gradient
magnitude matches the legacy engine (which scales each sub-batch by
1/grad_accum_steps before calling backward()).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple` | Tuple of (NestedTensor samples, list of target dicts). | *required* |
| `batch_idx` | `int` | Batch index within the epoch. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Scalar loss tensor. |
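The scaling argument above reduces to simple arithmetic. A sketch with plain floats (no Lightning, no autograd) shows why dividing each sub-batch loss by `accumulate_grad_batches` makes the accumulated result match the legacy engine's mean over the effective batch:

```python
def scaled_losses(sub_batch_losses, accumulate_grad_batches):
    """Divide each sub-batch loss by the accumulation factor, mirroring the
    legacy engine's 1/grad_accum_steps scaling before backward()."""
    return [loss / accumulate_grad_batches for loss in sub_batch_losses]

# Summing the scaled losses (what accumulated backward() effectively does)
# reproduces the mean loss over the effective batch.
losses = [4.0, 2.0, 6.0, 8.0]
accumulated = sum(scaled_losses(losses, len(losses)))
```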
validation_step(batch, batch_idx)¶
Run forward pass and postprocess for one validation step.
Returns raw results and targets so COCOEvalCallback can accumulate
them across the epoch via on_validation_batch_end.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple` | Tuple of (NestedTensor samples, list of target dicts). | *required* |
| `batch_idx` | `int` | Batch index within the validation epoch. | *required* |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dict with the raw results and targets for this step. |
test_step(batch, batch_idx)¶
Run forward pass and postprocess for one test step.
Mirrors `validation_step` so COCOEvalCallback can accumulate
results via on_test_batch_end when trainer.test() is called (e.g.
from `rfdetr.training.callbacks.BestModelCallback` at end of training).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple` | Tuple of (NestedTensor samples, list of target dicts). | *required* |
| `batch_idx` | `int` | Batch index within the test epoch. | *required* |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dict with the raw results and targets for this step. |
predict_step(batch, batch_idx, dataloader_idx=0)¶
Run inference on a preprocessed batch and return postprocessed results.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple` | Tuple of (NestedTensor samples, list of target dicts). | *required* |
| `batch_idx` | `int` | Batch index. | *required* |
| `dataloader_idx` | `int` | Index of the predict dataloader. | `0` |

Returns:

| Type | Description |
|---|---|
| `Any` | Postprocessed detection results. |
configure_optimizers()¶
Build AdamW optimizer with layer-wise LR decay and LambdaLR scheduler.
Uses trainer.estimated_stepping_batches for total step count so
cosine annealing covers the full training run regardless of dataset
size or accumulation settings.
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | PTL optimizer config dict with optimizer and step-interval scheduler. |
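The annealing idea can be made concrete with a scheduler lambda. The exact shape RF-DETR uses is not shown on this page, so the half-cosine below is an assumption; the grounded part is that the factor spans `trainer.estimated_stepping_batches` total steps so the schedule covers the whole run:

```python
import math

def cosine_lr_factor(step, total_steps, min_factor=0.0):
    """Multiplicative LR factor for a LambdaLR-style scheduler: 1.0 at
    step 0, annealing along a half cosine to min_factor at total_steps."""
    progress = min(step / max(total_steps, 1), 1.0)
    return min_factor + (1 - min_factor) * 0.5 * (1 + math.cos(math.pi * progress))
```

With `torch.optim.lr_scheduler.LambdaLR`, a lambda like `lambda step: cosine_lr_factor(step, total_steps)` would be stepped once per optimizer step (the "step-interval scheduler" in the returned config).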
clip_gradients(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)¶
Override PTL gradient clipping to support fused AdamW.
PTL's AMP precision plugin refuses to clip gradients when the optimizer
declares it handles unscaling internally (fused=True). When fused is
active we are on BF16 (no GradScaler) so clip_grad_norm_ is
correct. For the non-fused path (FP16 + GradScaler or FP32) we
delegate to super() to preserve scaler-aware unscaling.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `optimizer` | `Optimizer` | The current optimizer. | *required* |
| `gradient_clip_val` | `Optional[float]` | Maximum gradient norm. | `None` |
| `gradient_clip_algorithm` | `Optional[str]` | Clipping algorithm; forwarded to super() for the non-fused path. | `None` |
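The dispatch rule above is small enough to state as a pure function. This is an illustrative sketch of the decision only (the name `clipping_path` is hypothetical), not the override's real body:

```python
def clipping_path(fused):
    """Which clipping path the override takes, per the rule above: fused
    AdamW (BF16, no GradScaler, grads already unscaled) clips directly;
    otherwise delegate to PTL so scaler-aware unscaling happens first."""
    if fused:
        return "clip_grad_norm_"
    return "super().clip_gradients"
```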
on_load_checkpoint(checkpoint)¶
Auto-detect and normalise legacy .pth checkpoints at load time.
PTL calls this hook before applying checkpoint["state_dict"] to
the module. Two legacy formats are handled:
1. Raw legacy format — a `*.pth` file loaded directly by `Trainer` (e.g. via `ckpt_path=`). Recognised by the presence of `"model"` without `"state_dict"`. The state dict is rewritten in-place with the `"model."` prefix so PTL can apply it normally.
2. Converted format — a file produced by `rfdetr.training.checkpoint.convert_legacy_checkpoint` that already has `"state_dict"` but also carries `"legacy_ema_state_dict"`. The EMA weights are stashed on `self._pending_legacy_ema_state` for optional restoration by `rfdetr.training.callbacks.ema.RFDETREMACallback`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint` | `dict[str, Any]` | Checkpoint dict passed in by PTL (mutated in-place). | *required* |
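The raw-legacy rewrite can be sketched with plain dicts. The `"model"`/`"state_dict"` keys and the `"model."` prefix come from the docstring above; the function name and everything else is an illustrative assumption:

```python
def normalize_legacy_checkpoint(checkpoint):
    """Rewrite a raw legacy .pth dict in-place so PTL can apply it.

    Raw legacy format: has "model" but no "state_dict". Each weight key
    gains the "model." prefix the LightningModule's state dict expects.
    """
    if "model" in checkpoint and "state_dict" not in checkpoint:
        checkpoint["state_dict"] = {
            "model." + k: v for k, v in checkpoint.pop("model").items()
        }
    return checkpoint
```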
RFDETRDataModule¶
Bases: LightningDataModule
LightningDataModule wrapping RF-DETR dataset construction and data loading.
Migrates Model.train() dataset construction and DataLoader setup from
main.py into PTL lifecycle hooks. Coexists with the existing code until
Chapter 4 removes the legacy path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_config` | `ModelConfig` | Architecture configuration (used for resolution, patch_size, etc.). | *required* |
| `train_config` | `TrainConfig` | Training hyperparameter configuration (used for dataset params). | *required* |
Attributes¶
class_names property¶
Class names from the training or validation dataset annotation file.
Reads category names from the first available COCO-style dataset.
Returns None if no dataset has been set up yet or the dataset
does not expose COCO-style category information.
Returns:

| Type | Description |
|---|---|
| `Optional[List[str]]` | Sorted list of class name strings, or `None`. |
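The extraction from a COCO-style annotation dict can be sketched in a few lines. Ordering by category id is an assumption (the docstring only promises a sorted list), and the function name is hypothetical:

```python
def class_names_from_coco(annotations):
    """Extract class names from a COCO-style annotation dict.

    Returns None when no category information is present, mirroring the
    property's behaviour before setup. Sorting by category id is an
    assumption for illustration.
    """
    categories = annotations.get("categories")
    if not categories:
        return None
    return [c["name"] for c in sorted(categories, key=lambda c: c["id"])]
```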
Functions¶
__init__(model_config, train_config)¶
setup(stage)¶
Build datasets for the requested stage.
PTL calls this on every process before the corresponding
dataloader method. Datasets are built lazily — a dataset is
only constructed once even if setup is called multiple times.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `stage` | `str` | PTL stage identifier — one of `"fit"`, `"validate"`, `"test"`, or `"predict"`. | *required* |
train_dataloader()¶
Return the training DataLoader.
Uses a replacement sampler when the dataset is too small to fill
_MIN_TRAIN_BATCHES effective batches (matching legacy behaviour in
main.py). Otherwise uses a BatchSampler with
drop_last=True to avoid incomplete batches.
Returns:

| Type | Description |
|---|---|
| `DataLoader` | DataLoader for the training dataset. |
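The small-dataset fallback reduces to a size check. The threshold value and the exact comparison below are assumptions (`_MIN_TRAIN_BATCHES` is not defined on this page); the grounded idea is that when too few full batches fit with `drop_last=True`, sampling with replacement keeps the epoch length usable:

```python
def needs_replacement_sampler(dataset_len, batch_size, grad_accum_steps,
                              min_train_batches=25):
    """True when the dataset cannot fill min_train_batches effective
    batches with drop_last=True. The threshold of 25 is a placeholder
    standing in for _MIN_TRAIN_BATCHES."""
    effective_batch = batch_size * grad_accum_steps
    return dataset_len // effective_batch < min_train_batches
```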
val_dataloader()¶
Return the validation DataLoader.
Returns:

| Type | Description |
|---|---|
| `DataLoader` | DataLoader for the validation dataset with sequential sampling. |

test_dataloader()¶
Return the test DataLoader.
Returns:

| Type | Description |
|---|---|
| `DataLoader` | DataLoader for the test dataset with sequential sampling. |
build_trainer¶
Assemble a PTL Trainer with the full RF-DETR callback and logger stack.
Resolves training precision from model_config.amp and device capability,
guards EMA against sharded strategies, wires conditional loggers, and applies
promoted training knobs (gradient clipping, sync_batchnorm, strategy).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `train_config` | `TrainConfig` | Training hyperparameter configuration. | *required* |
| `model_config` | `ModelConfig` | Architecture configuration (used for precision and segmentation). | *required* |
|  | `str \| None` | PTL accelerator string (e.g. `"gpu"`). | `None` |
|  | `Any` | Extra keyword arguments forwarded verbatim to the PTL `Trainer`; any key present in both overrides the computed default. | `{}` |

Returns:

| Type | Description |
|---|---|
| `Trainer` | A configured PTL `Trainer`. |
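Precision resolution can be sketched as a small pure function. The precision strings are real PTL values, and the FP16-vs-BF16 split matches the description in `clip_gradients` above, but the function itself and its exact conditions are assumptions:

```python
def resolve_precision(amp, bf16_supported):
    """Map the amp flag and device capability onto PTL precision strings:
    "bf16-mixed" when the hardware supports bfloat16, "16-mixed"
    (FP16 + GradScaler) otherwise, and "32-true" when AMP is disabled."""
    if not amp:
        return "32-true"
    return "bf16-mixed" if bf16_supported else "16-mixed"
```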
Callbacks¶
RFDETREMACallback¶
Bases: Callback
Exponential Moving Average with optional tau-based warm-up.
Drop-in replacement for rfdetr.util.utils.ModelEma implemented as a
plain Lightning callback around `torch.optim.swa_utils.AveragedModel`.
The _avg_fn reproduces the same formula as ModelEma
(1-indexed updates counter, optional tau warm-up).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `decay` | `float` | Base EMA decay factor. | `0.993` |
| `tau` | `int` | Warm-up time constant (in optimizer steps). When > 0 the effective decay ramps from 0 towards `decay`. | `100` |
|  | `bool` | Whether buffers are averaged in addition to parameters. | `True` |
|  | `int` | Update EMA every N optimizer steps. | `1` |
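The warm-up can be illustrated with a common tau ramp. Whether RF-DETR's `_avg_fn` uses exactly this `1 - exp(-n/tau)` expression is an assumption (the page only says the decay ramps from 0 towards the base value with a 1-indexed counter); the EMA blend itself is the standard formula:

```python
import math

def effective_decay(base_decay, tau, num_updates):
    """Warmed-up EMA decay: ramps from ~0 towards base_decay as the
    1-indexed update counter grows. The exp(-n/tau) shape is an assumption
    borrowed from common ModelEma implementations."""
    if tau <= 0:
        return base_decay
    return base_decay * (1.0 - math.exp(-num_updates / tau))

def ema_update(ema_value, model_value, decay):
    """Standard EMA step: keep `decay` of the average, blend in the rest."""
    return decay * ema_value + (1.0 - decay) * model_value
```

Early updates therefore track the model closely (tiny effective decay), while late updates average heavily, which is the usual motivation for tau warm-up.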
BestModelCallback¶
Bases: ModelCheckpoint
Track best validation mAP and save best checkpoints during training.
Extends `pytorch_lightning.callbacks.ModelCheckpoint` to save
stripped `{model, args, epoch}` `.pth` files (instead of full `.ckpt`
files) and to track a separate EMA checkpoint in parallel.
At the end of training the overall winner (regular vs EMA, strict > for
EMA) is copied to checkpoint_best_total.pth and optimizer/scheduler
state is stripped via `rfdetr.util.misc.strip_checkpoint`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
|  | `str` | Directory where checkpoint files are written. | *required* |
|  | `str` | Metric key for the regular model mAP. | `'val/mAP_50_95'` |
|  | `Optional[str]` | Metric key for the EMA model mAP. | `None` |
|  | `bool` | If … | `True` |
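The "strict > for EMA" tie-break is easy to state as a function. The name `pick_best_checkpoint` and the `None` handling are illustrative assumptions; the comparison rule itself is taken from the description above:

```python
def pick_best_checkpoint(regular_map, ema_map):
    """EMA wins only on a strictly better mAP; ties (and a missing EMA
    metric) go to the regular model, per the 'strict > for EMA' rule."""
    if ema_map is not None and ema_map > regular_map:
        return "ema"
    return "regular"
```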
RFDETREarlyStopping¶
Bases: EarlyStopping
Early stopping callback monitoring validation mAP for RF-DETR.
Extends `pytorch_lightning.callbacks.EarlyStopping` with dual-metric
monitoring: by default it monitors max(regular_mAP, ema_mAP) (legacy
behaviour); set use_ema=True to monitor the EMA metric exclusively.
The effective metric is injected into trainer.callback_metrics under a
synthetic key before delegating to the parent's stopping logic, so all parent
features are available for free: state_dict/load_state_dict for
checkpoint resumption, NaN/inf guard via check_finite, and
stopping_threshold/divergence_threshold.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `patience` | `int` | Number of epochs with no improvement before stopping. | `10` |
| `min_delta` | `float` | Minimum mAP improvement to reset the patience counter. | `0.001` |
| `use_ema` | `bool` | When `True`, monitor the EMA metric exclusively. | `False` |
|  | `str` | Metric key for the regular model mAP. | `'val/mAP_50_95'` |
|  | `str` | Metric key for the EMA model mAP. | `'val/ema_mAP_50_95'` |
|  | `bool` | If … | `True` |
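The dual-metric selection can be sketched directly from the description above. The fallback when the EMA metric is missing is an assumption added for the sketch; the max/exclusive split is grounded in the docstring:

```python
def effective_early_stop_metric(regular_map, ema_map, use_ema=False):
    """Metric the callback injects under its synthetic key: the EMA value
    when use_ema is set, otherwise max(regular, EMA) as in the legacy
    behaviour. A missing EMA metric falls back to the regular one."""
    if use_ema:
        return ema_map
    if ema_map is None:
        return regular_map
    return max(regular_map, ema_map)
```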
DropPathCallback¶
Bases: Callback
Applies per-step drop-path and dropout rate schedules to the model.
Computes the full schedule array in on_train_start using
`rfdetr.util.drop_scheduler.drop_scheduler`, then indexes into it
on every training batch to update the model's stochastic-depth and
dropout rates.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
|  | `float` | Peak drop-path rate. | `0.0` |
|  | `float` | Peak dropout rate. | `0.0` |
|  | `int` | Epoch boundary for early / late modes. | `0` |
|  | `Literal['standard', 'early', 'late']` | Schedule mode forwarded to `drop_scheduler`. | `'standard'` |
|  | `Literal['constant', 'linear']` | Schedule shape forwarded to `drop_scheduler`. | `'constant'` |
|  | `int` | Passed to `drop_scheduler`. | `12` |
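A per-step schedule array of the kind indexed on each batch can be sketched as follows. `drop_scheduler`'s internals are not shown on this page, so the 'early' cutoff and the 'constant'/'linear' shapes below are assumptions for illustration only:

```python
def build_drop_schedule(peak_rate, total_steps, cutoff_step,
                        mode="early", shape="constant"):
    """Hypothetical per-step rate schedule: in 'early' mode the rate is
    active before cutoff_step and zero afterwards; 'constant' holds the
    peak rate, 'linear' decays it to zero across the active window."""
    schedule = []
    for step in range(total_steps):
        if mode == "early" and step >= cutoff_step:
            schedule.append(0.0)
        elif shape == "linear" and cutoff_step > 0:
            schedule.append(peak_rate * (1 - step / cutoff_step))
        else:
            schedule.append(peak_rate)
    return schedule
```

The callback would then look up `schedule[trainer.global_step]` on each training batch and write the rate into the model's drop-path and dropout modules.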
COCOEvalCallback¶
Bases: Callback
Validation callback that computes mAP (via torchmetrics) and macro-F1.
Accumulates predictions and targets across validation batches, then at epoch end computes:

- `val/mAP_50_95`, `val/mAP_50`, `val/mAP_75`, `val/mAR` using `torchmetrics.detection.MeanAveragePrecision`.
- Per-class `val/AP/<name>` when class names are available.
- `val/F1`, `val/precision`, `val/recall` from a confidence-threshold sweep over compact per-class matching data (DDP-safe).

For segmentation models (`segmentation=True`) additional metrics
`val/segm_mAP_50_95` and `val/segm_mAP_50` are logged.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
|  | `int` | Maximum detections per image passed to `MeanAveragePrecision`. | `500` |
|  | `bool` | When … | `False` |
|  | `int` | Run validation metrics every N epochs. Test metrics are always computed when … | `1` |
|  | `bool` | When … | `True` |
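The confidence-threshold sweep can be sketched over flat match records. The `(score, is_true_positive)` layout and the function name are assumptions; the grounded idea is computing precision, recall, and F1 at each candidate threshold and keeping the best:

```python
def f1_sweep(matches, num_gt, thresholds):
    """Sweep confidence thresholds over (score, is_true_positive) records
    and return (best_f1, precision, recall) at the best threshold."""
    best = (0.0, 0.0, 0.0)
    for t in thresholds:
        kept = [m for m in matches if m[0] >= t]
        tp = sum(1 for m in kept if m[1])
        precision = tp / len(kept) if kept else 0.0
        recall = tp / num_gt if num_gt else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        if f1 > best[0]:
            best = (f1, precision, recall)
    return best
```

Keeping only compact per-class counts like these (rather than full prediction tensors) is what makes the sweep cheap to all-gather, hence the "DDP-safe" note above.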
RFDETRCli¶
RFDETRCli is the command-line entry point for RF-DETR. It wraps
RFDETRModule and RFDETRDataModule under a single rfdetr command and
auto-generates four subcommands from the PyTorch Lightning CLI machinery:
    rfdetr fit --config configs/rfdetr_base.yaml
    rfdetr validate --ckpt_path output/best.ckpt
    rfdetr test --ckpt_path output/best.ckpt
    rfdetr predict --ckpt_path output/best.ckpt
Both model_config and train_config are specified once; RFDETRCli
automatically links them to the datamodule so you do not need to repeat the
same arguments under --data.*.
Bases: LightningCLI
LightningCLI subclass for RF-DETR training and evaluation.
Wires RFDETRModule and RFDETRDataModule under a unified CLI,
with argument linking that shares model_config and train_config
between module and datamodule so the user only specifies them once.
Auto-generated subcommands: fit, validate, test, predict.