Skip to content

Segmentation Train Config

Bases: TrainConfig

Source code in src/rfdetr/config.py
class SegmentationTrainConfig(TrainConfig):
    """Training configuration preset for segmentation runs.

    Inherits every option from ``TrainConfig`` and declares the fields below,
    overriding the parent's defaults where the parent defines the same name.
    """

    # Defaults to None (unset). NOTE(review): presumably the number of
    # predictions kept at post-processing time, with None deferring to a
    # model-side default — confirm against the consumer of this field.
    num_select: Optional[int] = None
    # NOTE(review): presumably the ratio used when sampling points for the
    # mask losses — confirm against the loss implementation.
    mask_point_sample_ratio: int = 16
    # Weight for the mask cross-entropy loss term (inferred from the name; verify in the criterion).
    mask_ce_loss_coef: float = 5.0
    # Weight for the mask Dice loss term (inferred from the name; verify in the criterion).
    mask_dice_loss_coef: float = 5.0
    # Weight for the classification loss term (inferred from the name; verify in the criterion).
    cls_loss_coef: float = 5.0
    # Enabled by default for this preset; presumably toggles the mask head on the model.
    segmentation_head: bool = True

Functions

expand_paths(v) classmethod

Expand a leading '~' to the user's home directory, then resolve dataset_dir / output_dir to an absolute, canonical path with os.path.realpath (relative paths are resolved against the current working directory). None is returned unchanged.

Source code in src/rfdetr/config.py
@field_validator("dataset_dir", "output_dir", mode="after")
@classmethod
def expand_paths(cls, v: Optional[str]) -> Optional[str]:
    """
    Normalize ``dataset_dir`` / ``output_dir`` to an absolute, canonical path.

    A leading ``~`` is expanded to the user's home directory, then the result is
    passed through :func:`os.path.realpath`. That makes it absolute relative to the
    current working directory and resolves symlinks and ``..`` components. This
    applies to every non-None value, including bare names such as ``"output"``;
    no value is left as-is.

    ``None`` is returned unchanged so optional fields stay unset.

    Args:
        v: The already-validated field value, or ``None``.

    Returns:
        The canonical absolute path string, or ``None`` if ``v`` was ``None``.
    """
    if v is None:
        # Optional field left unset: nothing to resolve.
        return v
    # expanduser must run first; realpath would otherwise treat "~" as a
    # literal directory name under the current working directory.
    return os.path.realpath(os.path.expanduser(v))

validate_batch_size(v) classmethod

Validate batch_size is a positive integer or the literal 'auto'.

Source code in src/rfdetr/config.py
@field_validator("batch_size", mode="after")
@classmethod
def validate_batch_size(cls, v: int | Literal["auto"]) -> int | Literal["auto"]:
    """Ensure ``batch_size`` is either the sentinel ``"auto"`` or an integer of at least 1.

    Args:
        v: The configured batch size, or ``"auto"``.

    Returns:
        ``v`` unchanged when it is acceptable.

    Raises:
        ValueError: If a numeric batch size below 1 is supplied.
    """
    is_auto = v == "auto"
    # The numeric comparison is only reached for non-"auto" values.
    if not is_auto and v < 1:
        raise ValueError("batch_size must be >= 1, or 'auto'.")
    return v

validate_ema_headroom(v) classmethod

Validate auto_batch_ema_headroom is in (0, 1].

Source code in src/rfdetr/config.py
@field_validator("auto_batch_ema_headroom", mode="after")
@classmethod
def validate_ema_headroom(cls, v: float) -> float:
    """Ensure ``auto_batch_ema_headroom`` lies in the half-open interval (0, 1].

    Args:
        v: The configured headroom fraction.

    Returns:
        ``v`` unchanged when it falls inside the interval.

    Raises:
        ValueError: If ``v`` is not strictly greater than 0 and at most 1
            (this includes NaN, for which both comparisons are False).
    """
    within_range = 0 < v <= 1.0
    if within_range:
        return v
    raise ValueError("auto_batch_ema_headroom must be in (0, 1].")

validate_positive_intervals(v) classmethod

Validate that ema_update_interval and eval_interval are each >= 1.

Source code in src/rfdetr/config.py
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
    """Reject ``ema_update_interval`` / ``eval_interval`` values below 1.

    Args:
        v: The configured interval.

    Returns:
        ``v`` unchanged when it is at least 1.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    message = "Interval fields must be >= 1."
    if v < 1:
        raise ValueError(message)
    return v

validate_positive_train_steps(v) classmethod

Validate that grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image are each >= 1.

Source code in src/rfdetr/config.py
@field_validator(
    "grad_accum_steps", "auto_batch_target_effective", "auto_batch_max_targets_per_image", mode="after"
)
@classmethod
def validate_positive_train_steps(cls, v: int) -> int:
    """Reject step/target counts below 1.

    Applies to ``grad_accum_steps``, ``auto_batch_target_effective`` and
    ``auto_batch_max_targets_per_image``.

    Args:
        v: The configured count.

    Returns:
        ``v`` unchanged when it is at least 1.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    # Guard clause: the accepted case exits early; falling through means v < 1.
    if not v < 1:
        return v
    raise ValueError(
        "grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image must be >= 1."
    )

validate_prefetch_factor(v) classmethod

Validate prefetch_factor is None or >= 1.

Source code in src/rfdetr/config.py
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
    """Allow ``prefetch_factor`` to be unset (``None``) or any value of at least 1.

    Args:
        v: The configured prefetch factor, or ``None``.

    Returns:
        ``v`` unchanged when it is ``None`` or at least 1.

    Raises:
        ValueError: If a value below 1 is supplied.
    """
    # De Morgan form of "v is not None and v < 1".
    if v is None or not v < 1:
        return v
    raise ValueError("prefetch_factor must be >= 1 when provided.")