Skip to content

Keypoint Train Config

Bases: TrainConfig

Training configuration for keypoint detection models.

Extends `TrainConfig` with keypoint-specific loss coefficients and metric-smoothing defaults tuned for the NLL-Cholesky keypoint head, which produces noisy per-epoch OKS metrics during early fine-tuning.

Attributes:

Name Type Description
cls_loss_coef float

Classification loss weight.

keypoint_l1_loss_coef float

L1 regression loss weight for keypoint coordinates.

keypoint_findable_loss_coef float

Loss weight for the keypoint visibility head.

keypoint_visible_loss_coef float

Loss weight for the keypoint visibility score.

keypoint_nll_loss_coef float

NLL-Cholesky loss weight. Reduced from 1.0 to 0.5 to dampen OKS@75 oscillation caused by precision-coupling in the Cholesky parameterisation.

smooth_alpha float

EMA smoothing factor for `BestModelCallback` metric comparison. Changes the `TrainConfig` default from 0.0 (disabled) to 0.5, which balances responsiveness and noise suppression for noisy keypoint mAP curves.

skip_best_epochs int

Number of epochs to skip before checkpoint selection begins. Changes the `TrainConfig` default from 0 to 10 because val/keypoint_map_50_95 under the NLL-Cholesky loss is noisy in early fine-tuning and can lock checkpoint selection onto a transient peak.

Source code in src/rfdetr/config.py
class KeypointTrainConfig(TrainConfig):
    """Training configuration for keypoint detection models.

    Extends :class:`TrainConfig` with keypoint-specific loss coefficients and
    metric-smoothing defaults tuned for the NLL-Cholesky keypoint head, which
    produces noisy per-epoch OKS metrics during early fine-tuning.

    Attributes:
        cls_loss_coef: Classification loss weight.
        keypoint_l1_loss_coef: L1 regression loss weight for keypoint coordinates.
        keypoint_findable_loss_coef: Loss weight for the keypoint visibility head.
        keypoint_visible_loss_coef: Loss weight for the keypoint visibility score.
        keypoint_nll_loss_coef: NLL-Cholesky loss weight. Reduced from 1.0 to 0.5
            to dampen OKS@75 oscillation caused by precision-coupling in the
            Cholesky parameterisation.
        smooth_alpha: EMA smoothing factor for :class:`BestModelCallback` metric
            comparison. Changes the :class:`TrainConfig` default from ``0.0``
            (disabled) to ``0.5``, which balances responsiveness and noise
            suppression for noisy keypoint mAP curves.
        skip_best_epochs: Number of epochs to skip before checkpoint selection begins.
            Changes the :class:`TrainConfig` default from ``0`` to ``10`` because
            ``val/keypoint_map_50_95`` under the NLL-Cholesky loss is noisy in early
            fine-tuning and can lock checkpoint selection onto a transient peak.
    """

    # TODO: verify empirically before final release; ported as-is from internal recipe.
    cls_loss_coef: float = 2.0
    # Float literals so the default values actually match the ``float`` annotation:
    # pydantic only coerces defaults when ``validate_default`` is enabled, so a bare
    # ``1`` would otherwise surface (and serialize) as an ``int``.
    keypoint_l1_loss_coef: float = 1.0
    keypoint_findable_loss_coef: float = 1.0
    keypoint_visible_loss_coef: float = 1.0
    keypoint_nll_loss_coef: float = 0.5
    smooth_alpha: float = 0.5
    skip_best_epochs: int = Field(default=10, ge=0)

    @model_validator(mode="after")
    def _warn_keypoint_flip_pairs_not_yet_implemented(self) -> "KeypointTrainConfig":
        """Emit a ``UserWarning`` when ``keypoint_flip_pairs`` is set before the feature ships.

        The value is left untouched; the warning only tells the user it will be ignored.

        Returns:
            The validated instance, unchanged.
        """
        # NOTE(review): ``keypoint_flip_pairs`` is not declared in this class;
        # presumably it is inherited from ``TrainConfig`` — confirm.
        if self.keypoint_flip_pairs:
            # NOTE(review): inside a pydantic validator, stacklevel=2 may attribute the
            # warning to pydantic internals rather than the user's call site — confirm.
            warnings.warn(
                "keypoint_flip_pairs is accepted but not yet implemented and will be ignored. "
                "Flip pair swapping (swapping left/right joint indices after a horizontal flip) "
                "is planned for a future release. Training will proceed without semantic joint swapping.",
                UserWarning,
                stacklevel=2,
            )
        return self

Functions

expand_paths(v) classmethod

Expand and normalize dataset/output directory paths via os.fspath → expanduser → realpath.

Source code in src/rfdetr/config.py
@field_validator("dataset_dir", "output_dir", mode="before")
@classmethod
def expand_paths(cls, v: PathLikeStr | None) -> str | None:
    """Canonicalize a dataset/output directory path before field validation.

    The path-like value is first converted with ``os.fspath``, then ``~`` is
    expanded with ``os.path.expanduser``, and finally symlinks and relative
    segments are resolved with ``os.path.realpath``. ``None`` passes through.
    """
    if v is not None:
        raw_path = os.fspath(v)
        home_expanded = os.path.expanduser(raw_path)
        v = os.path.realpath(home_expanded)
    return v

validate_batch_size(v) classmethod

Validate batch_size is a positive integer or the literal 'auto'.

Source code in src/rfdetr/config.py
@field_validator("batch_size", mode="after")
@classmethod
def validate_batch_size(cls, v: int | Literal["auto"]) -> int | Literal["auto"]:
    """Accept ``batch_size`` only if it is the sentinel ``'auto'`` or an integer >= 1.

    Raises:
        ValueError: If a numeric batch size is below 1.
    """
    is_auto = v == "auto"
    if not is_auto and v < 1:
        raise ValueError("batch_size must be >= 1, or 'auto'.")
    return v

validate_ema_headroom(v) classmethod

Validate auto_batch_ema_headroom is in (0, 1].

Source code in src/rfdetr/config.py
@field_validator("auto_batch_ema_headroom", mode="after")
@classmethod
def validate_ema_headroom(cls, v: float) -> float:
    """Accept ``auto_batch_ema_headroom`` only within the half-open range (0, 1].

    Raises:
        ValueError: If the value is outside (0, 1] (NaN included).
    """
    # A chained comparison evaluates False for NaN, so NaN falls through to the raise.
    if 0 < v <= 1.0:
        return v
    raise ValueError("auto_batch_ema_headroom must be in (0, 1].")

validate_positive_intervals(v) classmethod

Validate interval fields are >= 1.

Source code in src/rfdetr/config.py
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
    """Reject ``ema_update_interval`` / ``eval_interval`` values below 1.

    Raises:
        ValueError: If the interval is less than 1.
    """
    below_minimum = v < 1
    if below_minimum:
        raise ValueError("Interval fields must be >= 1.")
    return v

validate_positive_train_steps(v) classmethod

Validate accumulation, target-effective batch, and max targets are >= 1.

Source code in src/rfdetr/config.py
@field_validator(
    "grad_accum_steps", "auto_batch_target_effective", "auto_batch_max_targets_per_image", mode="after"
)
@classmethod
def validate_positive_train_steps(cls, v: int) -> int:
    """Reject accumulation, target-effective batch, or max-targets values below 1.

    One shared validator covers all three fields, so the error message names
    all of them rather than the specific offending field.

    Raises:
        ValueError: If the value is less than 1.
    """
    below_minimum = v < 1
    if below_minimum:
        message = (
            "grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image must be >= 1."
        )
        raise ValueError(message)
    return v

validate_prefetch_factor(v) classmethod

Validate prefetch_factor is None or >= 1.

Source code in src/rfdetr/config.py
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
    """Allow ``prefetch_factor`` to be unset (``None``); otherwise require it to be >= 1.

    Raises:
        ValueError: If a provided value is less than 1.
    """
    if v is None:
        return v
    if v < 1:
        raise ValueError("prefetch_factor must be >= 1 when provided.")
    return v

validate_smooth_alpha(v) classmethod

Validate smooth_alpha is in [0.0, 1.0).

Source code in src/rfdetr/config.py
@field_validator("smooth_alpha", mode="after")
@classmethod
def validate_smooth_alpha(cls, v: float) -> float:
    """Accept ``smooth_alpha`` only within the half-open range [0.0, 1.0).

    Raises:
        ValueError: If the value is outside [0.0, 1.0) (NaN included).
    """
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if 0.0 <= v < 1.0:
        return v
    raise ValueError("smooth_alpha must be in [0.0, 1.0).")