Skip to content

Train Config

Bases: BaseModel

Source code in src/rfdetr/config.py
class TrainConfig(BaseModel):
    """Validated training configuration for RF-DETR.

    Groups optimizer/schedule hyperparameters, dataset and output locations,
    augmentation flags, experiment-tracking backends, and PyTorch Lightning
    runtime knobs into a single pydantic model.
    """

    # --- Optimizer / schedule ---
    lr: float = 1e-4
    lr_encoder: float = 1.5e-4
    batch_size: int = 4
    grad_accum_steps: int = 4
    epochs: int = 100
    resume: Optional[str] = None  # path to a checkpoint to resume from
    ema_decay: float = 0.993
    ema_tau: int = 100
    lr_drop: int = 100  # epoch at which the "step" scheduler drops the LR
    checkpoint_interval: int = 10
    warmup_epochs: float = 0.0
    lr_vit_layer_decay: float = 0.8  # per-layer LR decay for the ViT encoder
    lr_component_decay: float = 0.7
    drop_path: float = 0.0
    group_detr: int = 13
    ia_bce_loss: bool = True
    cls_loss_coef: float = 1.0
    num_select: int = 300  # top-k predictions kept after decoding

    # --- Dataset / augmentation ---
    dataset_file: Literal["coco", "o365", "roboflow", "yolo"] = "roboflow"
    square_resize_div_64: bool = True
    dataset_dir: str  # required; expanded to an absolute path by expand_paths
    output_dir: str = "output"
    multi_scale: bool = True
    expanded_scales: bool = True
    do_random_resize_via_padding: bool = False

    # --- EMA / dataloading ---
    use_ema: bool = True
    ema_update_interval: int = 1  # must be >= 1 (see validate_positive_intervals)
    num_workers: int = 2
    weight_decay: float = 1e-4

    # --- Early stopping ---
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_min_delta: float = 0.001
    early_stopping_use_ema: bool = False

    # --- Logging / experiment tracking ---
    progress_bar: bool = False  # Enable tqdm progress bars during training and evaluation epochs.
    tensorboard: bool = True
    wandb: bool = False
    mlflow: bool = False
    clearml: bool = False
    project: Optional[str] = None
    run: Optional[str] = None
    # Fixed: was `List[str] = None`, a None default on a non-optional
    # annotation; Optional makes the declared type match the default.
    class_names: Optional[List[str]] = None
    run_test: bool = False
    segmentation_head: bool = False
    eval_max_dets: int = 500
    eval_interval: int = 1  # must be >= 1 (see validate_positive_intervals)
    log_per_class_metrics: bool = True
    aug_config: Optional[Dict[str, Any]] = None

    # Promoted from populate_args() — PTL migration (T4-2).
    # device is intentionally absent: PTL auto-detects accelerator via Trainer(accelerator="auto").
    accelerator: str = "auto"
    clip_max_norm: float = 0.1  # gradient clipping max norm
    seed: Optional[int] = None
    sync_bn: bool = False
    fp16_eval: bool = False
    lr_scheduler: Literal["step", "cosine"] = "step"
    lr_min_factor: float = 0.0
    dont_save_weights: bool = False

    # --- PTL runtime/perf tuning knobs ---
    train_log_sync_dist: bool = False
    train_log_on_step: bool = False
    compute_val_loss: bool = True
    compute_test_loss: bool = True
    pin_memory: Optional[bool] = None
    persistent_workers: Optional[bool] = None
    prefetch_factor: Optional[int] = None  # must be None or >= 1

    @field_validator("ema_update_interval", "eval_interval", mode="after")
    @classmethod
    def validate_positive_intervals(cls, v: int) -> int:
        """Validate interval fields are >= 1."""
        if v < 1:
            raise ValueError("Interval fields must be >= 1.")
        return v

    @field_validator("prefetch_factor", mode="after")
    @classmethod
    def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
        """Validate prefetch_factor is None or >= 1."""
        if v is not None and v < 1:
            raise ValueError("prefetch_factor must be >= 1 when provided.")
        return v

    @field_validator("dataset_dir", "output_dir", mode="after")
    @classmethod
    def expand_paths(cls, v: str) -> str:
        """
        Resolve the path to an absolute, symlink-free real path after
        expanding '~'.

        NOTE(review): the original docstring claimed simple filenames (like
        'rf-detr-base.pth') are left unchanged so they can match hosted model
        keys, but realpath() absolutizes every value, bare filenames included
        — confirm the intended behavior for plain names.
        """
        if v is None:
            # Defensive: pydantic's str validation normally rejects None
            # before this mode="after" validator runs.
            return v
        return os.path.realpath(os.path.expanduser(v))

Functions

expand_paths(v) classmethod

Resolve user paths (e.g., '~') to absolute real paths. Note: the implementation currently absolutizes every value, including simple filenames like 'rf-detr-base.pth', despite the stated intent of leaving them unchanged to match hosted model keys.

Source code in src/rfdetr/config.py
@field_validator("dataset_dir", "output_dir", mode="after")
@classmethod
def expand_paths(cls, v: str) -> str:
    """
    Resolve the path to an absolute, symlink-free real path after
    expanding '~'.

    NOTE(review): the original docstring claimed simple filenames (like
    'rf-detr-base.pth') are left unchanged so they can match hosted model
    keys, but realpath() absolutizes every value, bare filenames included
    — confirm the intended behavior for plain names.
    """
    # Defensive None check: pydantic's str validation normally rejects None
    # before this mode="after" validator runs.
    if v is None:
        return v
    return os.path.realpath(os.path.expanduser(v))

validate_positive_intervals(v) classmethod

Validate interval fields are >= 1.

Source code in src/rfdetr/config.py
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
    """Reject interval values below 1; return the value unchanged otherwise."""
    if v >= 1:
        return v
    raise ValueError("Interval fields must be >= 1.")

validate_prefetch_factor(v) classmethod

Validate prefetch_factor is None or >= 1.

Source code in src/rfdetr/config.py
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
    """Allow None (unset); otherwise require a value of at least 1."""
    if v is None or v >= 1:
        return v
    raise ValueError("prefetch_factor must be >= 1 when provided.")