Skip to content

Train Config

Bases: BaseModel

Training hyperparameters and auto-batching configuration.

Notes
  • auto_batch_target_effective is interpreted as the per-device effective batch size target, i.e. the number of images seen by a single process in one optimizer step after accounting for grad_accum_steps. In multi-GPU / multi-node runs the global effective batch size is therefore:

    global_effective_batch = auto_batch_target_effective * devices * num_nodes

This avoids silently changing behavior when scaling from single-GPU to multi-GPU training.

Source code in src/rfdetr/config.py
class TrainConfig(BaseModel):
    """Training hyperparameters and auto-batching configuration.

    Notes:
        * ``auto_batch_target_effective`` is interpreted as the **per-device**
          effective batch size target, i.e. the number of images seen by a
          single process in one optimizer step after accounting for
          ``grad_accum_steps``. In multi-GPU / multi-node runs the global
          effective batch size is therefore:

            ``global_effective_batch = auto_batch_target_effective * devices * num_nodes``

          This avoids silently changing behavior when scaling from single-GPU
          to multi-GPU training.
    """

    # --- Optimizer / schedule ---
    lr: float = 1e-4
    lr_encoder: float = 1.5e-4
    batch_size: int | Literal["auto"] = 4
    grad_accum_steps: int = 4
    auto_batch_target_effective: int = 16  # per-device effective batch size target (before devices * num_nodes)
    # Auto-batch probe: worst-case assumptions when batch_size="auto".
    auto_batch_max_targets_per_image: int = 100
    auto_batch_ema_headroom: float = 0.7  # scale safe batch by this when use_ema=True (EMA uses extra memory)
    epochs: int = 100
    resume: Optional[str] = None
    ema_decay: float = 0.993
    ema_tau: int = 100
    lr_drop: int = 100
    checkpoint_interval: int = Field(default=10, ge=1)
    warmup_epochs: float = 0.0
    lr_vit_layer_decay: float = 0.8
    lr_component_decay: float = 0.7
    drop_path: float = 0.0
    # --- Architecture-adjacent fields (deprecated here; see validator below) ---
    group_detr: int = 13
    ia_bce_loss: bool = True
    cls_loss_coef: float = 1.0
    num_select: int = 300
    # --- Data ---
    dataset_file: Literal["coco", "o365", "roboflow", "yolo"] = "roboflow"
    square_resize_div_64: bool = True
    dataset_dir: str
    output_dir: str = "output"
    multi_scale: bool = True
    expanded_scales: bool = True
    do_random_resize_via_padding: bool = False
    use_ema: bool = True
    ema_update_interval: int = 1
    num_workers: int = 2
    weight_decay: float = 1e-4
    # --- Early stopping ---
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_min_delta: float = 0.001
    early_stopping_use_ema: bool = False
    # --- Logging / experiment tracking ---
    progress_bar: Optional[Literal["tqdm", "rich"]] = None  # Progress bar style: "rich", "tqdm", or None to disable.
    tensorboard: bool = True
    wandb: bool = False
    mlflow: bool = False
    clearml: bool = False  # Not yet implemented — reserved for future use.
    project: Optional[str] = None
    run: Optional[str] = None
    # Fixed: was `List[str] = None`, whose annotation forbids the default value.
    # None is a legal value, so the annotation must be Optional.
    class_names: Optional[List[str]] = None
    run_test: bool = False
    segmentation_head: bool = False
    eval_max_dets: int = 500
    eval_interval: int = 1
    log_per_class_metrics: bool = True
    aug_config: Optional[Dict[str, Any]] = None

    @model_validator(mode="after")
    def _warn_deprecated_train_config_fields(self) -> "TrainConfig":
        """Emit DeprecationWarning for fields whose ownership is moving to ModelConfig.

        The following fields are duplicated between ``ModelConfig`` and ``TrainConfig``
        but ``ModelConfig`` is the authoritative source (Item #3, v1.7).  Setting them
        on ``TrainConfig`` is deprecated.  The fields will be removed in v1.9.

        - ``group_detr``: query group count is an architecture decision → ``ModelConfig``
        - ``ia_bce_loss``: loss type is tied to architecture family → ``ModelConfig``
        - ``segmentation_head``: architecture flag → ``ModelConfig``
        - ``num_select``: postprocessor count is an architecture decision → ``ModelConfig``
        """
        _deprecated = ("group_detr", "ia_bce_loss", "segmentation_head", "num_select")
        for field in _deprecated:
            # model_fields_set only contains fields the user explicitly set,
            # so defaults do not trigger spurious warnings.
            if field in self.model_fields_set:
                # stacklevel=2 points into Pydantic internals; unavoidable with
                # @model_validator(mode="after") in Pydantic v2.
                warnings.warn(
                    f"TrainConfig.{field} is deprecated and will be removed in v1.9. "
                    f"Set {field} on ModelConfig instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
        return self

    @field_validator("progress_bar", mode="before")
    @classmethod
    def _coerce_legacy_progress_bar(cls, value: Any) -> Any:
        """Normalize legacy boolean progress_bar values to the new string/None representation.

        This preserves compatibility with older configs where ``progress_bar`` was a bool.
        """
        if isinstance(value, bool):
            return "tqdm" if value else None
        return value

    # Promoted from populate_args() — PTL migration (T4-2).
    # device is intentionally absent: PTL auto-detects accelerator via Trainer(accelerator="auto").
    accelerator: str = "auto"
    clip_max_norm: float = 0.1
    seed: Optional[int] = None
    sync_bn: bool = False
    # strategy maps to PTL Trainer(strategy=...). Common values: "auto", "ddp",
    # "ddp_spawn", "fsdp", "deepspeed". Invalid values surface as PTL errors.
    strategy: str = "auto"
    devices: Union[int, str] = 1
    # num_nodes maps to PTL Trainer(num_nodes=...) for multi-machine training.
    # Single-machine DDP users should leave this at 1 (the default).
    num_nodes: int = 1
    fp16_eval: bool = False
    lr_scheduler: Literal["step", "cosine"] = "step"
    lr_min_factor: float = 0.0
    dont_save_weights: bool = False
    # PTL runtime/perf tuning knobs.
    train_log_sync_dist: bool = False
    train_log_on_step: bool = False
    compute_val_loss: bool = True
    compute_test_loss: bool = True
    pin_memory: Optional[bool] = None
    persistent_workers: Optional[bool] = None
    prefetch_factor: Optional[int] = None

    @field_validator("batch_size", mode="after")
    @classmethod
    def validate_batch_size(cls, v: int | Literal["auto"]) -> int | Literal["auto"]:
        """Validate batch_size is a positive integer or the literal 'auto'."""
        if v == "auto":
            return v
        if v < 1:
            raise ValueError("batch_size must be >= 1, or 'auto'.")
        return v

    @field_validator(
        "grad_accum_steps", "auto_batch_target_effective", "auto_batch_max_targets_per_image", mode="after"
    )
    @classmethod
    def validate_positive_train_steps(cls, v: int) -> int:
        """Validate accumulation, target-effective batch, and max targets are >= 1."""
        if v < 1:
            raise ValueError(
                "grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image must be >= 1."
            )
        return v

    @field_validator("auto_batch_ema_headroom", mode="after")
    @classmethod
    def validate_ema_headroom(cls, v: float) -> float:
        """Validate auto_batch_ema_headroom is in (0, 1]."""
        if not (0 < v <= 1.0):
            raise ValueError("auto_batch_ema_headroom must be in (0, 1].")
        return v

    @field_validator("ema_update_interval", "eval_interval", mode="after")
    @classmethod
    def validate_positive_intervals(cls, v: int) -> int:
        """Validate interval fields are >= 1."""
        if v < 1:
            raise ValueError("Interval fields must be >= 1.")
        return v

    @field_validator("prefetch_factor", mode="after")
    @classmethod
    def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
        """Validate prefetch_factor is None or >= 1."""
        if v is not None and v < 1:
            raise ValueError("prefetch_factor must be >= 1 when provided.")
        return v

    @field_validator("dataset_dir", "output_dir", mode="after")
    @classmethod
    def expand_paths(cls, v: str) -> str:
        """Expand '~' and resolve the value to a canonical absolute path.

        NOTE(review): the previous docstring claimed simple filenames (like
        'rf-detr-base.pth') are left unchanged, but ``os.path.realpath``
        unconditionally absolutizes every input — there is no such
        short-circuit here. That claim likely belongs to a weights-path
        validator elsewhere; confirm against the original source.
        """
        if v is None:
            return v
        return os.path.realpath(os.path.expanduser(v))

Functions

expand_paths(v) classmethod

Expand user paths (e.g., '~') and resolve the value to a canonical absolute path via os.path.realpath. Note that every value is absolutized — simple filenames are not special-cased.

Source code in src/rfdetr/config.py
@field_validator("dataset_dir", "output_dir", mode="after")
@classmethod
def expand_paths(cls, v: str) -> str:
    """Expand '~' and resolve the value to a canonical absolute path.

    NOTE(review): the previous docstring claimed simple filenames (like
    'rf-detr-base.pth') are left unchanged, but ``os.path.realpath``
    unconditionally absolutizes every input — no such short-circuit exists
    in this body. Confirm against the weights-path validator it was likely
    copied from.
    """
    # Defensive: dataset_dir is a required str, but output_dir could in
    # principle arrive as None from legacy configs; pass None through.
    if v is None:
        return v
    return os.path.realpath(os.path.expanduser(v))

validate_batch_size(v) classmethod

Validate batch_size is a positive integer or the literal 'auto'.

Source code in src/rfdetr/config.py
@field_validator("batch_size", mode="after")
@classmethod
def validate_batch_size(cls, v: int | Literal["auto"]) -> int | Literal["auto"]:
    """Accept the sentinel string 'auto' or any integer >= 1; reject the rest."""
    # Short-circuit keeps the "auto" sentinel from ever reaching the < comparison.
    if v != "auto" and v < 1:
        raise ValueError("batch_size must be >= 1, or 'auto'.")
    return v

validate_ema_headroom(v) classmethod

Validate auto_batch_ema_headroom is in (0, 1].

Source code in src/rfdetr/config.py
@field_validator("auto_batch_ema_headroom", mode="after")
@classmethod
def validate_ema_headroom(cls, v: float) -> float:
    """Reject any headroom value outside the half-open interval (0, 1]."""
    in_range = 0 < v and v <= 1.0
    if not in_range:
        raise ValueError("auto_batch_ema_headroom must be in (0, 1].")
    return v

validate_positive_intervals(v) classmethod

Validate interval fields are >= 1.

Source code in src/rfdetr/config.py
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
    """Ensure interval-style settings are at least 1."""
    if v >= 1:
        return v
    raise ValueError("Interval fields must be >= 1.")

validate_positive_train_steps(v) classmethod

Validate accumulation, target-effective batch, and max targets are >= 1.

Source code in src/rfdetr/config.py
@field_validator(
    "grad_accum_steps", "auto_batch_target_effective", "auto_batch_max_targets_per_image", mode="after"
)
@classmethod
def validate_positive_train_steps(cls, v: int) -> int:
    """Ensure accumulation steps, effective-batch target, and max-targets are all >= 1."""
    # Guard-clause form: return the valid value immediately, raise otherwise.
    if v >= 1:
        return v
    raise ValueError(
        "grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image must be >= 1."
    )

validate_prefetch_factor(v) classmethod

Validate prefetch_factor is None or >= 1.

Source code in src/rfdetr/config.py
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
    """Pass None through unchanged; otherwise require a value of at least 1."""
    if v is None:
        return v
    if v < 1:
        raise ValueError("prefetch_factor must be >= 1 when provided.")
    return v