Training API Reference
This page documents the training primitives that power RF-DETR. For a narrative guide with runnable examples, see Custom Training API.
RFDETRModelModule
Bases: LightningModule
LightningModule wrapping the RF-DETR model and training loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_config | ModelConfig | Architecture configuration. | required |
| train_config | TrainConfig | Training hyperparameter configuration. | required |
Functions
__init__(model_config, train_config)
on_fit_start()
Seed RNGs at fit start when TrainConfig.seed is set.
This avoids hidden global side-effects in build_trainer while still preserving deterministic training
behaviour for actual fit runs.
on_train_batch_start(batch, batch_idx)
Apply optional multi-scale resize to the incoming batch.
Modifications to batch (in-place on NestedTensor) are visible in training_step because they share
the same object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Tuple | Tuple of (NestedTensor samples, list of target dicts). | required |
| batch_idx | int | Index of the current batch within the epoch. | required |
training_step(batch, batch_idx)
Compute loss for one training step and log metrics.
PTL handles AMP (precision) without a manual GradScaler. Keypoint models perform manual optimization so
box-count loss normalization is based on the full accumulated effective batch rather than each microbatch
independently; detection and segmentation models keep Lightning's automatic optimization path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Tuple | Tuple of (NestedTensor samples, list of target dicts). | required |
| batch_idx | int | Batch index within the epoch. | required |
Returns:
| Type | Description |
|---|---|
| Tensor \| dict[str, Any] | Scalar loss tensor by default. When … returns a Lightning-compatible dict containing … detached postprocessed predictions for train mAP logging. |
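The difference between normalizing each microbatch on its own and normalizing over the full accumulated batch can be seen with a small arithmetic sketch. The helper names and the numbers below are hypothetical and are not part of the RF-DETR API; the sketch only illustrates why the two conventions give different loss values when microbatches contain different numbers of boxes.

```python
def per_microbatch_normalized(loss_sums, box_counts):
    """Divide each microbatch loss by its own box count, then average."""
    per_batch = [s / max(n, 1) for s, n in zip(loss_sums, box_counts)]
    return sum(per_batch) / len(per_batch)


def accumulated_normalized(loss_sums, box_counts):
    """Divide the summed loss by the box count of the whole effective batch."""
    return sum(loss_sums) / max(sum(box_counts), 1)


# Two microbatches accumulated into one optimizer step (made-up values).
loss_sums = [8.0, 2.0]   # un-normalized loss per microbatch
box_counts = [8, 1]      # ground-truth boxes per microbatch

micro = per_microbatch_normalized(loss_sums, box_counts)
full = accumulated_normalized(loss_sums, box_counts)
print(f"per-microbatch: {micro:.4f}, accumulated: {full:.4f}")
```

With per-microbatch normalization the single box in the second microbatch carries as much weight as all eight boxes in the first; normalizing over the accumulated batch weights every box equally.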
validation_step(batch, batch_idx)
Run forward pass and postprocess for one validation step.
Returns raw results and targets so COCOEvalCallback can accumulate them across the epoch via
on_validation_batch_end.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Tuple | Tuple of (NestedTensor samples, list of target dicts). | required |
| batch_idx | int | Batch index within the validation epoch. | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, Any] | Dict with … |
test_step(batch, batch_idx)
Run forward pass and postprocess for one test step.
Mirrors validation_step so COCOEvalCallback can accumulate results via on_test_batch_end when
trainer.test() is called (e.g. from rfdetr.training.callbacks.BestModelCallback at the end of
training).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Tuple | Tuple of (NestedTensor samples, list of target dicts). | required |
| batch_idx | int | Batch index within the test epoch. | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, Any] | Dict with … |
predict_step(batch, batch_idx, dataloader_idx=0)
Run inference on a preprocessed batch and return postprocessed results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Tuple | Tuple of (NestedTensor samples, list of target dicts). | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Index of the predict dataloader. | 0 |
Returns:
| Type | Description |
|---|---|
| Any | Postprocessed detection results from … |
configure_optimizers()
Build AdamW optimizer with layer-wise LR decay and LambdaLR scheduler.
Uses trainer.estimated_stepping_batches for total step count so cosine annealing covers the full training
run regardless of dataset size or accumulation settings.
Returns:
| Type | Description |
|---|---|
| Dict[str, Any] | PTL optimizer config dict with optimizer and step-interval scheduler. |
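As a rough illustration of the two ideas named above, the sketch below builds a cosine multiplier of the kind a LambdaLR scheduler accepts, plus a layer-wise learning-rate decay helper. The function names, the `min_factor` parameter, and the exact decay exponent are assumptions made for illustration; the actual RF-DETR schedule (for example any warm-up phase) may differ.

```python
import math


def cosine_lr_lambda(total_steps, min_factor=0.0):
    """Return a step -> LR multiplier function (hypothetical helper).

    total_steps plays the role of trainer.estimated_stepping_batches, so the
    curve spans the whole run regardless of dataset size or accumulation.
    """
    def fn(step):
        progress = min(step, total_steps) / max(total_steps, 1)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_factor + (1.0 - min_factor) * cosine
    return fn


def layerwise_lr(base_lr, decay, layer_id, num_layers):
    """Assumed layer-wise decay: earlier layers get a smaller learning rate."""
    return base_lr * decay ** (num_layers - layer_id)


schedule = cosine_lr_lambda(total_steps=1000)
for step in (0, 250, 500, 1000):
    print(step, round(schedule(step), 4))
print([layerwise_lr(1e-4, 0.8, i, 12) for i in (0, 6, 12)])
```

In PyTorch a function like `schedule` would typically be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` and stepped once per optimizer step.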
clip_gradients(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)
Override PTL gradient clipping to support fused AdamW.
PTL's AMP precision plugin refuses to clip gradients when the optimizer declares it handles unscaling internally
(fused=True). When fused is active we are on BF16 (no GradScaler) so clip_grad_norm_ is correct. For the
non-fused path (FP16 + GradScaler or FP32) we delegate to super() to preserve scaler-aware unscaling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | Optimizer | The current optimizer. | required |
| gradient_clip_val | Optional[float] | Maximum gradient norm. | None |
| gradient_clip_algorithm | Optional[str] | Clipping algorithm; forwarded to super() for the non-fused path. | None |
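The branching described above can be sketched without any framework dependency. The code below is a pure-Python stand-in: `clip_by_global_norm` mimics the scaling rule used by norm-based clipping, and `choose_clip_path` is a hypothetical helper that only names the two paths. Neither is the actual RF-DETR or Lightning implementation.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradients so their global L2 norm does not exceed max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


def choose_clip_path(fused):
    """Hypothetical dispatch: fused optimizers are clipped directly,
    everything else goes through the parent (scaler-aware) implementation."""
    return "direct_norm_clip" if fused else "delegate_to_super"


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(choose_clip_path(fused=True), norm_before, round(norm_after, 4))
```

Gradients whose norm is already below `max_norm` pass through unchanged because the scale factor is capped at 1.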
on_load_checkpoint(checkpoint)
Auto-detect legacy formats and reconcile PE shapes at checkpoint load time.
PTL calls this hook before applying checkpoint["state_dict"] to the module. Three normalisation steps are
applied in order:
- Raw legacy format: a *.pth file loaded directly by Trainer (e.g. via ckpt_path=). Recognised by the presence of "model" without "state_dict". The state dict is rewritten in-place with the "model." prefix so PTL can apply it normally.
- Positional-embedding interpolation: when the checkpoint was saved at a different image resolution than the current model, the DINOv2 position_embeddings tensor shape will mismatch. rfdetr.models.weights.interpolate_position_embeddings is called to bicubic-resize the PE to model_config.positional_encoding_size before PTL applies the state dict. Regression fix for issue 998.
- Converted format: a file produced by rfdetr.training.checkpoint.convert_legacy_checkpoint that already has "state_dict" but also carries "legacy_ema_state_dict". The EMA weights are stashed on self._pending_legacy_ema_state for optional restoration by rfdetr.training.callbacks.ema.RFDETREMACallback.
Note
This hook only fires on Trainer(ckpt_path=...) resume paths. Fresh-train bootstrap from a
pretrain_weights checkpoint runs through rfdetr.models.weights.load_pretrain_weights during
__init__ instead. That helper performs its own PTL .ckpt normalisation (state_dict → model
key, _orig_mod strip) and PE interpolation, so the two code paths intentionally do not share state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint | dict[str, Any] | Checkpoint dict passed in by PTL (mutated in-place). | required |
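The raw-legacy normalisation step described above can be pictured with plain dictionaries. This is a minimal sketch, not the library's code: the function name is hypothetical, and whether the original "model" key is removed after the rewrite is an assumption made here.

```python
def normalize_legacy_checkpoint(checkpoint):
    """Rewrite a raw legacy checkpoint in-place so it carries a state_dict.

    A raw legacy file is recognised by having "model" but no "state_dict".
    """
    if "model" in checkpoint and "state_dict" not in checkpoint:
        # Assumption: the old key is dropped once its contents are re-keyed.
        legacy_weights = checkpoint.pop("model")
        checkpoint["state_dict"] = {
            f"model.{name}": tensor for name, tensor in legacy_weights.items()
        }
    return checkpoint


legacy = {"model": {"backbone.weight": 1.0, "head.bias": 0.5}, "epoch": 3}
normalized = normalize_legacy_checkpoint(legacy)
print(sorted(normalized["state_dict"]))

# A checkpoint that already has a state_dict is left untouched.
modern = {"state_dict": {"model.backbone.weight": 1.0}}
unchanged = normalize_legacy_checkpoint(modern)
```

Because the dict is mutated in-place, the caller sees the rewritten keys without needing the return value, which matches how a load hook receives its checkpoint argument.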
RFDETRDataModule
Bases: LightningDataModule
LightningDataModule wrapping RF-DETR dataset construction and data loading.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_config | ModelConfig | Architecture configuration (used for resolution, patch_size, etc.). | required |
| train_config | TrainConfig | Training hyperparameter configuration (used for dataset params). | required |
Attributes
class_names
property
Class names from the training or validation dataset annotation file.
Reads category names from the first available COCO-style dataset. Returns None if no dataset has been set up
yet or the dataset does not expose COCO-style category information.
Returns:
| Type | Description |
|---|---|
| Optional[List[str]] | Sorted list of class name strings, or … |
Functions
__init__(model_config, train_config)
setup(stage)
Build datasets for the requested stage.
PTL calls this on every process before the corresponding dataloader method. Datasets are built lazily — a
dataset is only constructed once even if setup is called multiple times.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stage | str | PTL stage identifier, one of … | required |
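The build-once behaviour of setup can be sketched with a small cache. The class below is hypothetical and framework-free; in particular, the mapping from stage to the splits it needs is an assumption for illustration and is not taken from the RF-DETR source.

```python
class LazySplits:
    """Builds each dataset split at most once, however often setup runs."""

    # Assumed mapping from Lightning stage names to the splits they need.
    _SPLITS_BY_STAGE = {
        "fit": ("train", "val"),
        "validate": ("val",),
        "test": ("test",),
    }

    def __init__(self, build_split):
        self._build_split = build_split
        self.datasets = {}

    def setup(self, stage):
        for split in self._SPLITS_BY_STAGE.get(stage, ()):
            if split not in self.datasets:  # skip splits that already exist
                self.datasets[split] = self._build_split(split)


built = []


def fake_build(split):
    """Stand-in for real dataset construction; records each call."""
    built.append(split)
    return [f"{split}-sample"]


splits = LazySplits(fake_build)
splits.setup("fit")
splits.setup("fit")       # repeated call: nothing is rebuilt
splits.setup("validate")  # val already exists from the fit stage
print(built)
```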
train_dataloader()
Return the training DataLoader.
Uses a replacement sampler when the dataset is too small to fill _MIN_TRAIN_BATCHES effective batches
(matching legacy behaviour in main.py). Otherwise wraps the dataset with GradAccumAlignedDataset
to ensure its length is an exact multiple of effective_batch_size * world_size (workaround for
https://github.com/Lightning-AI/pytorch-lightning/issues/19987) and then uses shuffle=True, drop_last=True
so that PTL can auto-inject DistributedSampler in DDP mode.
Returns:
| Type | Description |
|---|---|
| DataLoader | DataLoader for the training dataset. |
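One way to make a dataset length an exact multiple of effective_batch_size * world_size is to pad it by wrapping indices. The sketch below does exactly that; it is not the GradAccumAlignedDataset implementation, and the choice to round the length up (rather than truncate) is an assumption made for illustration. The class name and the example sizes are hypothetical.

```python
class AlignedIndexView:
    """Pads a sequence's length up to a multiple by wrapping indices."""

    def __init__(self, data, multiple):
        self._data = data
        n = len(data)
        # Ceiling division, then scale back up to the next multiple.
        self._length = -(-n // multiple) * multiple

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if not 0 <= idx < self._length:
            raise IndexError(idx)
        # Indices past the real end wrap around to the start.
        return self._data[idx % len(self._data)]


effective_batch_size = 4  # made-up value
world_size = 2            # made-up value
view = AlignedIndexView(list(range(10)), effective_batch_size * world_size)
print(len(view), view[9], view[10])
```

With an aligned length, every rank sees the same number of full effective batches, so drop_last=True never discards a partial accumulation window.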
val_dataloader()
Return the validation DataLoader.
Returns:
| Type | Description |
|---|---|
| DataLoader | DataLoader for the validation dataset with sequential sampling. |
test_dataloader()
Return the test DataLoader.
Returns:
| Type | Description |
|---|---|
| DataLoader | DataLoader for the test dataset with sequential sampling. |
build_trainer
Assemble a PTL Trainer with the full RF-DETR callback and logger stack.
Resolves training precision from model_config.amp and device capability, guards EMA against sharded strategies,
wires conditional loggers, and applies promoted training knobs (sync_batchnorm, strategy).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | TrainConfig | Training hyperparameter configuration. | required |
| | ModelConfig | Architecture configuration. Used for precision resolution (…) | required |
| | str \| None | PTL accelerator string (e.g. …) | None |
| | Any | Extra keyword arguments forwarded to … Most keys present in both … | {} |
Returns:
| Type | Description |
|---|---|
| Trainer | A configured … |
Callbacks
RFDETREMACallback
Bases: Callback
Exponential Moving Average with optional tau-based warm-up.
Drop-in replacement for rfdetr.util.utils.ModelEma implemented as a plain Lightning callback around
torch.optim.swa_utils.AveragedModel. The _avg_fn reproduces the exact same formula as ModelEma
(1-indexed updates counter, optional tau warm-up).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | float | Base EMA decay factor. Corresponds to … | 0.993 |
| | int | Warm-up time constant (in optimizer steps). When > 0 the effective decay ramps from 0 towards decay following … | 100 |
| | bool | Whether buffers are averaged in addition to parameters. | True |
| | int | Update EMA every N optimizer steps. | 1 |
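To give a feel for a tau-based warm-up, the sketch below uses the exponential ramp `decay * (1 - exp(-updates / tau))`. This formula is an assumption: it is a common form of EMA warm-up, but this page does not confirm it is the one ModelEma uses. The function names are hypothetical.

```python
import math


def effective_decay(decay, tau, updates):
    """Assumed warm-up: ramp from 0 towards `decay` as updates grow.

    `updates` is 1-indexed, matching the counter described above.
    """
    if tau > 0:
        return decay * (1.0 - math.exp(-updates / tau))
    return decay


def ema_update(averaged, current, d):
    """Standard EMA step with decay d."""
    return d * averaged + (1.0 - d) * current


for updates in (1, 100, 1000):
    print(updates, round(effective_decay(0.993, 100, updates), 4))

# Early in training the average tracks the live weights closely.
early = ema_update(averaged=0.0, current=1.0, d=effective_decay(0.993, 100, 1))
late = ema_update(averaged=0.0, current=1.0, d=effective_decay(0.993, 100, 10_000))
```

The practical effect of the ramp is that the averaged weights follow the live model almost exactly during the first steps and only later settle into slow averaging.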
BestModelCallback
Bases: ModelCheckpoint
Track best validation mAP and save best checkpoints during training.
Extends pytorch_lightning.callbacks.ModelCheckpoint to save stripped {model, args, epoch} .pth
files (instead of full .ckpt files) and to track a separate EMA checkpoint in parallel.
At the end of training the overall winner (regular vs EMA; the EMA model wins only on a strict >) is copied to
checkpoint_best_total.pth and optimizer/scheduler state is stripped via
rfdetr.util.misc.strip_checkpoint.
Checkpoints are only updated on validation epochs where the monitor metric is actually logged. On non-eval epochs
(when eval_interval > 1 causes COCO evaluation to be skipped) the callback is a no-op.
state_dict() and load_state_dict() are overridden to persist _best_ema in the Lightning callback state,
ensuring that trainer.fit(ckpt_path=...) resumes EMA high-water-mark tracking from the correct value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | str | Directory where checkpoint files are written. | required |
| | str | Metric key for the regular model mAP. | 'val/mAP_50_95' |
| | str \| None | Metric key for the EMA model mAP. | None |
| | bool | If … | True |
| | int | Ignore the first N epochs (0..N-1) when tracking best regular and EMA checkpoints. Useful when fine-tuning from … | 0 |
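The end-of-training winner rule described above (EMA must be strictly better to win) can be written as a small comparison. The function below is a hypothetical illustration, not the callback's code; treating a missing EMA score as "regular wins" is an assumption.

```python
def pick_overall_winner(best_regular, best_ema):
    """Return which checkpoint would be copied as the overall best.

    The EMA checkpoint wins only if its mAP is strictly greater, so a tie
    resolves in favour of the regular model.
    """
    if best_ema is not None and best_ema > best_regular:
        return "ema"
    return "regular"


print(pick_overall_winner(0.512, 0.530))  # EMA strictly better
print(pick_overall_winner(0.512, 0.512))  # tie
print(pick_overall_winner(0.512, None))   # EMA not tracked
```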
Examples:
Skip the first 3 epochs so pretrained weights do not dominate:
RFDETREarlyStopping
Bases: EarlyStopping
Early stopping callback monitoring validation mAP for RF-DETR.
Extends pytorch_lightning.callbacks.EarlyStopping with dual-metric
monitoring: by default it monitors max(regular_mAP, ema_mAP) (legacy
behaviour); set use_ema=True to monitor the EMA metric exclusively.
The effective metric is injected into trainer.callback_metrics under a synthetic key before delegating to the
parent's stopping logic, so all parent features are available for free: state_dict/load_state_dict for
checkpoint resumption, NaN/inf guard via check_finite, and stopping_threshold/divergence_threshold.
Early stopping evaluates only on validation epochs where the monitored metrics are logged; non-eval epochs
(eval_interval > 1) are skipped automatically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | int | Number of epochs with no improvement before stopping. | 10 |
| | float | Minimum mAP improvement to reset the patience counter. | 0.001 |
| | bool | When … | False |
| | str | Metric key for the regular model mAP. | 'val/mAP_50_95' |
| | str | Metric key for the EMA model mAP. | 'val/ema_mAP_50_95' |
| | bool | If … | True |
| | int | Ignore the first N epochs (0..N-1) when evaluating patience and best-score baselines. Set this when fine-tuning from … | 0 |
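The dual-metric selection and the patience rule can be sketched in a few lines of plain Python. Both helpers are hypothetical; the patience logic follows the usual "improvement must exceed min_delta" convention and is a simplification of what the parent EarlyStopping class does.

```python
def effective_map(regular_map, ema_map, use_ema=False):
    """Pick the value to monitor: EMA only, or the better of the two."""
    if use_ema:
        return ema_map
    candidates = [m for m in (regular_map, ema_map) if m is not None]
    return max(candidates) if candidates else None


def should_stop(history, patience=10, min_delta=0.001):
    """Simplified patience check over a list of monitored values."""
    best = float("-inf")
    wait = 0
    for value in history:
        if value - min_delta > best:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return True
    return False


print(effective_map(0.41, 0.43))                # default: max of both
print(effective_map(0.45, 0.43, use_ema=True))  # EMA only
print(should_stop([0.40, 0.40, 0.40], patience=2))
```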
Examples:
Fine-tuning from pretrained weights — skip first 3 epochs:
DropPathCallback
Bases: Callback
Applies per-step drop-path and dropout rate schedules to the model.
Computes the full schedule array in on_train_start using rfdetr.util.drop_scheduler.drop_scheduler, then
indexes into it on every training batch to update the model's stochastic-depth and dropout rates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | float | Peak drop-path rate. | 0.0 |
| | float | Peak dropout rate. | 0.0 |
| | int | Epoch boundary for early / late modes. | 0 |
| | Literal['standard', 'early', 'late'] | Schedule mode forwarded to … | 'standard' |
| | Literal['constant', 'linear'] | Schedule shape forwarded to … | 'constant' |
| | int | Passed to … | 12 |
COCOEvalCallback
Bases: Callback
Validation callback that computes mAP (via torchmetrics) and macro-F1.
Accumulates predictions and targets across validation batches, then at epoch end computes:
- val/mAP_50_95, val/mAP_50, val/mAP_75, val/mAR using torchmetrics.detection.MeanAveragePrecision.
- Per-class val/AP/<name> when class names are available.
- val/F1, val/precision, val/recall from a confidence-threshold sweep over compact per-class matching data (DDP-safe).
For segmentation models (segmentation=True) additional metrics val/segm_mAP_50_95 and val/segm_mAP_50
are logged.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | int | Maximum detections per image passed to … | DEFAULT_KEYPOINT_MAX_DETS |
| | bool | When … | False |
| | int | Run validation metrics every N epochs. Test metrics are always computed when … | 1 |
| | bool | When … | True |
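A confidence-threshold sweep for F1 can be illustrated for a single class as follows. This is a simplified sketch under stated assumptions: the input format (score, is-true-positive pairs plus a ground-truth count), the threshold grid, and the function name are hypothetical, and the callback's actual matching data and macro-averaging across classes are not reproduced here.

```python
def best_f1(scored_matches, num_ground_truth, thresholds):
    """Return (f1, precision, recall, threshold) with the highest F1.

    scored_matches: list of (confidence, is_true_positive) for one class.
    """
    best = (0.0, 0.0, 0.0, None)
    for threshold in thresholds:
        kept = [is_tp for score, is_tp in scored_matches if score >= threshold]
        tp = sum(kept)
        precision = tp / len(kept) if kept else 0.0
        recall = tp / num_ground_truth if num_ground_truth else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        if f1 > best[0]:
            best = (f1, precision, recall, threshold)
    return best


# Made-up detections for one class with three ground-truth objects.
matches = [(0.9, True), (0.8, True), (0.4, False), (0.3, True)]
f1, precision, recall, threshold = best_f1(matches, 3, [0.25, 0.5, 0.75])
print(round(f1, 4), precision, recall, threshold)
```

A macro-F1 would repeat this per class and average the results; keeping only compact per-class match records is what makes such a sweep cheap to gather across processes.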
RFDETRCli
RFDETRCli is the command-line entry point for RF-DETR. It wraps
RFDETRModelModule and RFDETRDataModule under a single rfdetr command and
auto-generates four subcommands from the PyTorch Lightning CLI machinery:
rfdetr fit --config configs/rfdetr_base.yaml
rfdetr validate --ckpt_path output/best.ckpt
rfdetr test --ckpt_path output/best.ckpt
rfdetr predict --ckpt_path output/best.ckpt
Both model_config and train_config are specified once; RFDETRCli
automatically links them to the datamodule so you do not need to repeat the
same arguments under --data.*.
Bases: LightningCLI
LightningCLI subclass for RF-DETR training and evaluation.
Wires RFDETRModelModule and RFDETRDataModule under a unified CLI, with argument linking that shares
model_config and train_config between module and datamodule so the user only specifies them once.
Auto-generated subcommands: fit, validate, test, predict.