Training API Reference

This page documents the training primitives that power RF-DETR. For a narrative guide with runnable examples, see Custom Training API.

RFDETRModule

Bases: LightningModule

LightningModule wrapping the RF-DETR model and training loop.

Migrates Model.__init__, train_one_epoch, evaluate, and optimizer setup from main.py / engine.py into PTL lifecycle hooks. Coexists with the existing code until Chapter 4 removes the legacy path.

Parameters:

    model_config (ModelConfig): Architecture configuration. Required.
    train_config (TrainConfig): Training hyperparameter configuration. Required.

Functions

__init__(model_config, train_config)

on_fit_start()

Seed RNGs at fit start when TrainConfig.seed is set.

This avoids hidden global side-effects in build_trainer while still preserving deterministic training behaviour for actual fit runs.

on_train_batch_start(batch, batch_idx)

Apply optional multi-scale resize to the incoming batch.

Modifications to batch (in-place on NestedTensor) are visible in training_step because they share the same object.

Parameters:

    batch (Tuple): Tuple of (NestedTensor samples, list of target dicts). Required.
    batch_idx (int): Index of the current batch within the epoch. Required.

transfer_batch_to_device(batch, device, dataloader_idx)

Override PTL's default to handle NestedTensor device transfer.

PTL's default iterates tuple elements and calls .to(device); that works for plain tensors but NestedTensor must be moved explicitly.

Parameters:

    batch (Tuple): Tuple of (NestedTensor samples, list of target dicts). Required.
    device (device): Target device. Required.
    dataloader_idx (int): Index of the dataloader providing this batch. Required.

Returns:

    Tuple: Batch with all tensors on device.
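The dispatch logic can be sketched with stand-in classes (FakeTensor and FakeNestedTensor below are illustrative stubs, not rfdetr types): the NestedTensor is moved explicitly, then every tensor inside each target dict.

```python
class FakeTensor:
    """Stand-in for torch.Tensor: records which device it was moved to."""
    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        moved = FakeTensor()
        moved.device = device
        return moved


class FakeNestedTensor:
    """Stand-in for rfdetr's NestedTensor (a padded-tensors/mask pair)."""
    def __init__(self, tensors, mask):
        self.tensors, self.mask = tensors, mask

    def to(self, device):
        # A NestedTensor must move both the padded tensors and the mask.
        return FakeNestedTensor(self.tensors.to(device), self.mask.to(device))


def transfer_batch_to_device(batch, device):
    """Structural sketch of the override: move the NestedTensor explicitly,
    then move every tensor inside each target dict."""
    samples, targets = batch
    samples = samples.to(device)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    return samples, targets


batch = (FakeNestedTensor(FakeTensor(), FakeTensor()),
         [{"boxes": FakeTensor(), "labels": FakeTensor()}])
samples, targets = transfer_batch_to_device(batch, "cuda:0")
assert samples.tensors.device == "cuda:0" and samples.mask.device == "cuda:0"
assert targets[0]["boxes"].device == "cuda:0"
```

The real hook additionally receives dataloader_idx and operates on actual torch tensors; only the traversal shape is shown here.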

training_step(batch, batch_idx)

Compute loss for one training step and log metrics.

PTL handles gradient accumulation (accumulate_grad_batches), AMP (precision), and gradient clipping (gradient_clip_val) — no manual GradScaler or loss scaling here. The loss is divided by trainer.accumulate_grad_batches so that the accumulated gradient magnitude matches the legacy engine (which scales each sub-batch by 1/grad_accum_steps before calling backward()).

Parameters:

    batch (Tuple): Tuple of (NestedTensor samples, list of target dicts). Required.
    batch_idx (int): Batch index within the epoch. Required.

Returns:

    Tensor: Scalar loss tensor.
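The scaling can be sanity-checked without PTL: dividing each sub-batch loss by accumulate_grad_batches means the sum of the scaled losses (and hence the accumulated gradient) equals the mean over the effective batch. A minimal sketch in plain Python (no torch):

```python
def accumulated_loss(sub_batch_losses, grad_accum_steps):
    """Scale each sub-batch loss by 1/grad_accum_steps before 'backward'.

    With gradient accumulation, backward() runs once per sub-batch and
    gradients sum across calls; pre-dividing each loss keeps the accumulated
    gradient identical to a single backward() on the mean over the effective
    batch, matching the legacy engine's 1/grad_accum_steps scaling.
    """
    return sum(loss / grad_accum_steps for loss in sub_batch_losses)


# Four sub-batches accumulated into one effective batch.
losses = [2.0, 4.0, 6.0, 8.0]
acc = accumulated_loss(losses, grad_accum_steps=4)
mean = sum(losses) / len(losses)
assert abs(acc - mean) < 1e-12  # accumulated total == mean over effective batch
```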

validation_step(batch, batch_idx)

Run forward pass and postprocess for one validation step.

Returns raw results and targets so COCOEvalCallback can accumulate them across the epoch via on_validation_batch_end.

Parameters:

    batch (Tuple): Tuple of (NestedTensor samples, list of target dicts). Required.
    batch_idx (int): Batch index within the validation epoch. Required.

Returns:

    Dict[str, Any]: Dict with results (postprocessed predictions) and targets.

test_step(batch, batch_idx)

Run forward pass and postprocess for one test step.

Mirrors validation_step so COCOEvalCallback can accumulate results via on_test_batch_end when trainer.test() is called (e.g. from BestModelCallback at end of training).

Parameters:

    batch (Tuple): Tuple of (NestedTensor samples, list of target dicts). Required.
    batch_idx (int): Batch index within the test epoch. Required.

Returns:

    Dict[str, Any]: Dict with results (postprocessed predictions) and targets.

predict_step(batch, batch_idx, dataloader_idx=0)

Run inference on a preprocessed batch and return postprocessed results.

Parameters:

    batch (Tuple): Tuple of (NestedTensor samples, list of target dicts). Required.
    batch_idx (int): Batch index. Required.
    dataloader_idx (int): Index of the predict dataloader. Default: 0.

Returns:

    Any: Postprocessed detection results from PostProcess.

configure_optimizers()

Build AdamW optimizer with layer-wise LR decay and LambdaLR scheduler.

Uses trainer.estimated_stepping_batches for total step count so cosine annealing covers the full training run regardless of dataset size or accumulation settings.

Returns:

    Dict[str, Any]: PTL optimizer config dict with optimizer and step-interval scheduler.
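The description above only pins down "cosine annealing over the full run"; a plausible LambdaLR-style multiplier driven by a total step count playing the role of trainer.estimated_stepping_batches might look like this (a sketch, the real schedule may add warm-up or a minimum LR floor):

```python
import math


def cosine_lr_lambda(total_steps, min_factor=0.0):
    """Return a LambdaLR-style multiplier that completes one cosine arc
    over total_steps.

    Using the trainer's estimated total step count here is what makes the
    annealing cover the whole run regardless of dataset size or gradient
    accumulation settings.
    """
    def lr_lambda(step):
        progress = min(step / max(total_steps, 1), 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_factor + (1.0 - min_factor) * cosine
    return lr_lambda


f = cosine_lr_lambda(total_steps=100)
assert f(0) == 1.0              # full LR at the first step
assert abs(f(50) - 0.5) < 1e-9  # halfway down the cosine
assert abs(f(100)) < 1e-9       # annealed to ~zero at the end
```

In PTL this lambda would be handed to torch.optim.lr_scheduler.LambdaLR and returned with interval="step" in the optimizer config dict.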

clip_gradients(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)

Override PTL gradient clipping to support fused AdamW.

PTL's AMP precision plugin refuses to clip gradients when the optimizer declares it handles unscaling internally (fused=True). When fused is active we are on BF16 (no GradScaler) so clip_grad_norm_ is correct. For the non-fused path (FP16 + GradScaler or FP32) we delegate to super() to preserve scaler-aware unscaling.

Parameters:

    optimizer (Optimizer): The current optimizer. Required.
    gradient_clip_val (Optional[float]): Maximum gradient norm. Default: None.
    gradient_clip_algorithm (Optional[str]): Clipping algorithm; forwarded to super() for the non-fused path. Default: None.

on_load_checkpoint(checkpoint)

Auto-detect and normalise legacy .pth checkpoints at load time.

PTL calls this hook before applying checkpoint["state_dict"] to the module. Two legacy formats are handled:

  1. Raw legacy format — a *.pth file loaded directly by Trainer (e.g. via ckpt_path=). Recognised by the presence of "model" without "state_dict". The state dict is rewritten in-place with the "model." prefix so PTL can apply it normally.

  2. Converted format — a file produced by convert_legacy_checkpoint that already has "state_dict" but also carries "legacy_ema_state_dict". The EMA weights are stashed on self._pending_legacy_ema_state for optional restoration by RFDETREMACallback.

Parameters:

    checkpoint (dict[str, Any]): Checkpoint dict passed in by PTL (mutated in-place). Required.
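The raw-legacy branch can be sketched as a small dict rewrite. Only the "model"/"state_dict" keys and the "model." prefix come from the description above; the exact mutation strategy is an assumption:

```python
def normalise_legacy_checkpoint(checkpoint):
    """Sketch of the raw-legacy branch of on_load_checkpoint.

    A raw legacy *.pth has a "model" key and no "state_dict"; its weights
    are re-keyed with a "model." prefix so PTL can apply them to the
    LightningModule (which holds the network under a `model` attribute).
    Mutates the dict in place, as the PTL hook does. Already-converted
    checkpoints pass through untouched.
    """
    if "model" in checkpoint and "state_dict" not in checkpoint:
        checkpoint["state_dict"] = {
            f"model.{key}": value
            for key, value in checkpoint.pop("model").items()
        }
    return checkpoint


ckpt = {"model": {"backbone.weight": 1, "head.bias": 2}, "epoch": 7}
normalise_legacy_checkpoint(ckpt)
assert ckpt["state_dict"] == {"model.backbone.weight": 1, "model.head.bias": 2}
assert "model" not in ckpt and ckpt["epoch"] == 7  # other keys untouched
```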

reinitialize_detection_head(num_classes)

Reinitialize the detection head for a new class count.

Parameters:

    num_classes (int): New number of classes (excluding background). Required.

RFDETRDataModule

Bases: LightningDataModule

LightningDataModule wrapping RF-DETR dataset construction and data loading.

Migrates Model.train() dataset construction and DataLoader setup from main.py into PTL lifecycle hooks. Coexists with the existing code until Chapter 4 removes the legacy path.

Parameters:

    model_config (ModelConfig): Architecture configuration (used for resolution, patch_size, etc.). Required.
    train_config (TrainConfig): Training hyperparameter configuration (used for dataset params). Required.

Attributes

class_names property

Class names from the training or validation dataset annotation file.

Reads category names from the first available COCO-style dataset. Returns None if no dataset has been set up yet or the dataset does not expose COCO-style category information.

Returns:

    Optional[List[str]]: Sorted list of class name strings, or None.

Functions

__init__(model_config, train_config)

setup(stage)

Build datasets for the requested stage.

PTL calls this on every process before the corresponding dataloader method. Datasets are built lazily — a dataset is only constructed once even if setup is called multiple times.

Parameters:

    stage (str): PTL stage identifier, one of "fit", "validate", "test", or "predict". Required.

train_dataloader()

Return the training DataLoader.

Uses a replacement sampler when the dataset is too small to fill _MIN_TRAIN_BATCHES effective batches (matching legacy behaviour in main.py). Otherwise uses a BatchSampler with drop_last=True to avoid incomplete batches.

Returns:

    DataLoader: DataLoader for the training dataset.
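The sampler decision reduces to a size check. A sketch, assuming the rule compares dataset length against _MIN_TRAIN_BATCHES times the effective batch size (the name comes from the docs above; the exact comparison in the source may differ):

```python
def use_replacement_sampler(dataset_len, effective_batch_size, min_train_batches):
    """Decide whether the train loader needs a with-replacement sampler.

    If the dataset cannot fill min_train_batches full effective batches,
    sample with replacement so an epoch still yields enough batches;
    otherwise a plain BatchSampler with drop_last=True avoids incomplete
    batches, matching the behaviour described above.
    """
    return dataset_len < min_train_batches * effective_batch_size


# Tiny dataset: 50 images cannot fill 10 effective batches of 16.
assert use_replacement_sampler(50, 16, 10) is True
# Normal dataset: plenty of full batches, so drop_last=True is used instead.
assert use_replacement_sampler(5000, 16, 10) is False
```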

val_dataloader()

Return the validation DataLoader.

Returns:

    DataLoader: DataLoader for the validation dataset with sequential sampling.

test_dataloader()

Return the test DataLoader.

Returns:

    DataLoader: DataLoader for the test dataset with sequential sampling.


build_trainer

Assemble a PTL Trainer with the full RF-DETR callback and logger stack.

Resolves training precision from model_config.amp and device capability, guards EMA against sharded strategies, wires conditional loggers, and applies promoted training knobs (gradient clipping, sync_batchnorm, strategy).

Parameters:

    train_config (TrainConfig): Training hyperparameter configuration. Required.
    model_config (ModelConfig): Architecture configuration (used for precision and segmentation). Required.
    accelerator (str | None): PTL accelerator string (e.g. "auto", "cpu", "gpu"). Defaults to None, which reads from train_config.accelerator (itself defaulting to "auto"). Pass "cpu" to override auto-detection (e.g. when the caller explicitly requests CPU training via device="cpu"). Default: None.
    **trainer_kwargs (Any): Extra keyword arguments forwarded verbatim to pytorch_lightning.Trainer. Use this to pass PTL-native flags that are not exposed through TrainConfig, for example build_trainer(tc, mc, fast_dev_run=2). Any key present in both trainer_kwargs and the built config dict is overridden by the value in trainer_kwargs.

Returns:

    Trainer: A configured pytorch_lightning.Trainer instance.
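A sketch of how precision might be resolved from model_config.amp and device capability, using PTL 2.x precision strings; the actual rule in build_trainer may consult more than a single bf16 capability flag:

```python
def resolve_precision(amp, bf16_supported):
    """Resolve a PTL precision string from the AMP flag and device capability.

    Sketch of the rule implied by the docs: bf16 mixed precision when AMP is
    enabled on bf16-capable hardware (no GradScaler needed), fp16 mixed with
    a GradScaler otherwise, and full fp32 when AMP is off.
    """
    if not amp:
        return "32-true"
    return "bf16-mixed" if bf16_supported else "16-mixed"


assert resolve_precision(amp=False, bf16_supported=True) == "32-true"
assert resolve_precision(amp=True, bf16_supported=True) == "bf16-mixed"
assert resolve_precision(amp=True, bf16_supported=False) == "16-mixed"
```

The resulting string would be passed as Trainer(precision=...); on real hardware the capability check would come from something like torch.cuda.is_bf16_supported().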


Callbacks

RFDETREMACallback

Bases: Callback

Exponential Moving Average with optional tau-based warm-up.

Drop-in replacement for rfdetr.util.utils.ModelEma implemented as a plain Lightning callback around torch.optim.swa_utils.AveragedModel. The _avg_fn reproduces exactly the formula used by ModelEma (1-indexed updates counter, optional tau warm-up).

Parameters:

    decay (float): Base EMA decay factor. Corresponds to TrainConfig.ema_decay. Default: 0.993.
    tau (int): Warm-up time constant (in optimizer steps). When > 0 the effective decay ramps from 0 towards decay following decay * (1 - exp(-updates / tau)). Corresponds to TrainConfig.ema_tau. Default: 100.
    use_buffers (bool): Whether buffers are averaged in addition to parameters. Default: True.
    update_interval_steps (int): Update EMA every N optimizer steps. Default: 1.
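The warm-up formula quoted above is easy to check numerically. A plain-Python sketch of the effective decay plus a standard EMA blend (the blend shown is the usual AveragedModel-style update, included for illustration):

```python
import math


def effective_decay(base_decay, tau, updates):
    """Effective EMA decay after `updates` optimizer steps.

    Implements the warm-up formula from the docs:
    decay * (1 - exp(-updates / tau)); with tau <= 0 the decay is constant.
    """
    if tau <= 0:
        return base_decay
    return base_decay * (1.0 - math.exp(-updates / tau))


def ema_update(ema_value, new_value, decay):
    """One EMA step: keep `decay` of the running average, blend in the rest."""
    return decay * ema_value + (1.0 - decay) * new_value


early = effective_decay(0.993, tau=100, updates=1)
late = effective_decay(0.993, tau=100, updates=10_000)
assert early < 0.01                 # early on, the EMA tracks raw weights
assert abs(late - 0.993) < 1e-6     # warm-up converges to the base decay
```

The low effective decay during warm-up is what lets the EMA follow the rapidly moving weights of early training instead of anchoring to the random initialisation.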

Functions

__init__(decay=0.993, tau=100, use_buffers=True, update_interval_steps=1)

BestModelCallback

Bases: ModelCheckpoint

Track best validation mAP and save best checkpoints during training.

Extends pytorch_lightning.callbacks.ModelCheckpoint to save stripped {model, args, epoch} .pth files (instead of full .ckpt files) and to track a separate EMA checkpoint in parallel.

At the end of training the overall winner (regular vs EMA, strict > for EMA) is copied to checkpoint_best_total.pth and optimizer/scheduler state is stripped via rfdetr.util.misc.strip_checkpoint.

Parameters:

    output_dir (str): Directory where checkpoint files are written. Required.
    monitor_regular (str): Metric key for the regular model mAP. Default: 'val/mAP_50_95'.
    monitor_ema (Optional[str]): Metric key for the EMA model mAP. None disables EMA tracking. Default: None.
    run_test (bool): If True, run trainer.test() on the best model at the end of training. Default: True.

Functions

__init__(output_dir, monitor_regular='val/mAP_50_95', monitor_ema=None, run_test=True)

RFDETREarlyStopping

Bases: EarlyStopping

Early stopping callback monitoring validation mAP for RF-DETR.

Extends pytorch_lightning.callbacks.EarlyStopping with dual-metric monitoring: by default it monitors max(regular_mAP, ema_mAP) (legacy behaviour); set use_ema=True to monitor the EMA metric exclusively.

The effective metric is injected into trainer.callback_metrics under a synthetic key before delegating to the parent's stopping logic, so all parent features are available for free: state_dict/load_state_dict for checkpoint resumption, NaN/inf guard via check_finite, and stopping_threshold/divergence_threshold.

Parameters:

    patience (int): Number of epochs with no improvement before stopping. Default: 10.
    min_delta (float): Minimum mAP improvement to reset the patience counter. Default: 0.001.
    use_ema (bool): When True and both regular and EMA metrics are available, monitor only the EMA metric. When False, monitor max(regular, ema). Default: False.
    monitor_regular (str): Metric key for the regular model mAP. Default: 'val/mAP_50_95'.
    monitor_ema (str): Metric key for the EMA model mAP. Default: 'val/ema_mAP_50_95'.
    verbose (bool): If True, log early stopping status each epoch. Default: True.
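The selection rule described above can be sketched as a pure function (illustrative only; the real callback injects the chosen value into trainer.callback_metrics under a synthetic key before delegating to the parent's stopping logic):

```python
def effective_monitor(regular_map, ema_map, use_ema=False):
    """Pick the metric value the early-stopping logic should watch.

    Default (legacy) behaviour lets the better of the two models drive
    patience: max(regular, ema). With use_ema=True the EMA metric is
    watched exclusively. ema_map may be None when EMA is disabled.
    """
    if ema_map is None:
        return regular_map
    if use_ema:
        return ema_map
    return max(regular_map, ema_map)


assert effective_monitor(0.41, 0.43) == 0.43                # legacy: max
assert effective_monitor(0.41, 0.38, use_ema=True) == 0.38  # EMA only
assert effective_monitor(0.41, None) == 0.41                # EMA disabled
```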

Functions

__init__(patience=10, min_delta=0.001, use_ema=False, monitor_regular='val/mAP_50_95', monitor_ema='val/ema_mAP_50_95', verbose=True)

DropPathCallback

Bases: Callback

Applies per-step drop-path and dropout rate schedules to the model.

Computes the full schedule array in on_train_start using rfdetr.util.drop_scheduler.drop_scheduler, then indexes into it on every training batch to update the model's stochastic-depth and dropout rates.

Parameters:

    drop_path (float): Peak drop-path rate. 0.0 disables the schedule. Default: 0.0.
    dropout (float): Peak dropout rate. 0.0 disables the schedule. Default: 0.0.
    cutoff_epoch (int): Epoch boundary for early / late modes. Default: 0.
    mode (Literal['standard', 'early', 'late']): Schedule mode forwarded to drop_scheduler. Default: 'standard'.
    schedule (Literal['constant', 'linear']): Schedule shape forwarded to drop_scheduler. Default: 'constant'.
    vit_encoder_num_layers (int): Passed to model.update_drop_path so the model can distribute rates across ViT encoder layers. Default: 12.
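A plain-Python sketch of such a per-step schedule, with assumed semantics for the modes and shapes (the real drop_scheduler may differ in details):

```python
def build_drop_schedule(peak, total_steps, steps_per_epoch, cutoff_epoch=0,
                        mode="standard", schedule="constant"):
    """Per-step rate schedule in the spirit of drop_scheduler.

    Assumed semantics: "standard" applies the shape over the whole run,
    "early" zeroes the rate from cutoff_epoch onward, "late" zeroes it
    before cutoff_epoch. "constant" holds the peak; "linear" ramps up to it.
    """
    cutoff_step = cutoff_epoch * steps_per_epoch
    rates = []
    for step in range(total_steps):
        if schedule == "linear":
            value = peak * step / max(total_steps - 1, 1)
        else:  # constant
            value = peak
        if mode == "early" and step >= cutoff_step:
            value = 0.0
        elif mode == "late" and step < cutoff_step:
            value = 0.0
        rates.append(value)
    return rates


# 4 epochs of 2 steps; drop-path active only for the first 2 epochs.
rates = build_drop_schedule(0.2, total_steps=8, steps_per_epoch=2,
                            cutoff_epoch=2, mode="early")
assert rates[:4] == [0.2] * 4  # active before the cutoff epoch
assert rates[4:] == [0.0] * 4  # disabled afterwards
```

On each training batch the callback would index this array by the global step and push the value into the model via model.update_drop_path.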

Functions

__init__(drop_path=0.0, dropout=0.0, cutoff_epoch=0, mode='standard', schedule='constant', vit_encoder_num_layers=12)

COCOEvalCallback

Bases: Callback

Validation callback that computes mAP (via torchmetrics) and macro-F1.

Accumulates predictions and targets across validation batches, then at epoch end computes:

  • val/mAP_50_95, val/mAP_50, val/mAP_75, val/mAR using torchmetrics.detection.MeanAveragePrecision.
  • Per-class val/AP/<name> when class names are available.
  • val/F1, val/precision, val/recall from a confidence-threshold sweep over compact per-class matching data (DDP-safe).

For segmentation models (segmentation=True) additional metrics val/segm_mAP_50_95 and val/segm_mAP_50 are logged.

Parameters:

    max_dets (int): Maximum detections per image passed to MeanAveragePrecision. Default: 500.
    segmentation (bool): When True, evaluate both bbox and segm IoU using backend="faster_coco_eval". Default: False.
    eval_interval (int): Run validation metrics every N epochs. Test metrics are always computed when trainer.test() is called. Default: 1.
    log_per_class_metrics (bool): When False, skip per-class AP logging/table. Default: True.
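One way the threshold sweep could work, sketched for a single class with hypothetical (confidence, matched) pairs; the real callback aggregates compact per-class match data across ranks before sweeping:

```python
def best_f1_over_thresholds(scored_matches, num_gt, thresholds):
    """Sketch of a confidence-threshold sweep for detection F1.

    scored_matches: (confidence, is_true_positive) pairs for one class,
    produced by matching predictions to ground truth. For each threshold,
    count TP/FP among predictions at or above it, derive precision, recall,
    and F1 against num_gt ground-truth boxes, and keep the best threshold.
    """
    best = (0.0, 0.0, 0.0, 0.0)  # (f1, precision, recall, threshold)
    for t in thresholds:
        kept = [tp for conf, tp in scored_matches if conf >= t]
        tp = sum(kept)
        precision = tp / len(kept) if kept else 0.0
        recall = tp / num_gt if num_gt else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        if f1 > best[0]:
            best = (f1, precision, recall, t)
    return best


# Two true positives, two false positives; two ground-truth boxes.
matches = [(0.9, True), (0.8, True), (0.6, False), (0.5, False)]
f1, p, r, t = best_f1_over_thresholds(matches, num_gt=2,
                                      thresholds=[0.1, 0.55, 0.7])
assert t == 0.7  # dropping low-confidence noise maximises F1 here
```

Sweeping over compact (confidence, matched) pairs rather than raw boxes is what keeps the all-gather across DDP ranks cheap.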

Functions

__init__(max_dets=500, segmentation=False, eval_interval=1, log_per_class_metrics=True, in_notebook=None)


RFDETRCli

RFDETRCli is the command-line entry point for RF-DETR. It wraps RFDETRModule and RFDETRDataModule under a single rfdetr command and auto-generates four subcommands from the PyTorch Lightning CLI machinery:

rfdetr fit      --config configs/rfdetr_base.yaml
rfdetr validate --ckpt_path output/best.ckpt
rfdetr test     --ckpt_path output/best.ckpt
rfdetr predict  --ckpt_path output/best.ckpt

Both model_config and train_config are specified once; RFDETRCli automatically links them to the datamodule so you do not need to repeat the same arguments under --data.*.

Bases: LightningCLI

LightningCLI subclass for RF-DETR training and evaluation.

Wires RFDETRModule and RFDETRDataModule under a unified CLI, with argument linking that shares model_config and train_config between module and datamodule so the user only specifies them once.

Auto-generated subcommands: fit, validate, test, predict.