Skip to content

Segmentation Train Config

Bases: TrainConfig

Source code in src/rfdetr/config.py
class SegmentationTrainConfig(TrainConfig):
    # Training configuration variant that turns the segmentation head on by
    # default and carries the mask-related loss settings.

    # Deprecated at train time; see warn_deprecated_num_select below.
    num_select: Optional[int] = None
    mask_point_sample_ratio: int = 16
    # Loss coefficients for the mask cross-entropy, mask dice, and class terms.
    mask_ce_loss_coef: float = 5.0
    mask_dice_loss_coef: float = 5.0
    cls_loss_coef: float = 5.0
    segmentation_head: bool = True

    @model_validator(mode="after")
    def warn_deprecated_num_select(self) -> "SegmentationTrainConfig":
        """Emit a ``DeprecationWarning`` if ``num_select`` was explicitly given a non-``None`` value.

        Returns:
            The validated instance, unchanged.
        """
        # Stay silent when the field was left at its default or explicitly set to None.
        if "num_select" not in self.model_fields_set or self.num_select is None:
            return self

        message = (
            "TrainConfig.num_select is deprecated and ignored by "
            "PTL/inference; set ModelConfig.num_select instead."
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return self

Functions

expand_paths(v) classmethod

Expand user paths (e.g., a leading '~') and resolve the result to a canonical absolute path with symlinks resolved. Every non-None value is rewritten this way; None is returned unchanged.

Source code in src/rfdetr/config.py
@field_validator("dataset_dir", "output_dir", mode="after")
@classmethod
def expand_paths(cls, v: Optional[str]) -> Optional[str]:
    """Expand and canonicalize the configured directory paths.

    A leading ``~`` is expanded via ``os.path.expanduser`` and the result is passed
    through ``os.path.realpath``, so the returned path is absolute with symlinks
    resolved. Every non-``None`` value is rewritten this way, including bare names
    without a path separator.

    Args:
        v: Path value of the field being validated, or ``None``.

    Returns:
        The canonical path, or ``None`` when ``v`` is ``None``.
    """
    if v is None:
        return v
    return os.path.realpath(os.path.expanduser(v))

validate_batch_size(v) classmethod

Validate batch_size is a positive integer or the literal 'auto'.

Source code in src/rfdetr/config.py
@field_validator("batch_size", mode="after")
@classmethod
def validate_batch_size(cls, v: int | Literal["auto"]) -> int | Literal["auto"]:
    """Accept the literal ``'auto'`` or any integer batch size of at least 1.

    Raises:
        ValueError: If ``v`` is not ``'auto'`` and is smaller than 1.
    """
    # The 'auto' sentinel is checked first so it is never compared numerically.
    if v != "auto" and v < 1:
        raise ValueError("batch_size must be >= 1, or 'auto'.")
    return v

validate_ema_headroom(v) classmethod

Validate auto_batch_ema_headroom is in (0, 1].

Source code in src/rfdetr/config.py
@field_validator("auto_batch_ema_headroom", mode="after")
@classmethod
def validate_ema_headroom(cls, v: float) -> float:
    """Ensure ``auto_batch_ema_headroom`` lies in the half-open interval (0, 1].

    Raises:
        ValueError: If ``v`` is outside (0, 1].
    """
    # Positive range test: any value failing the chained comparison (NaN included) is rejected.
    within_range = 0 < v <= 1.0
    if within_range:
        return v
    raise ValueError("auto_batch_ema_headroom must be in (0, 1].")

validate_positive_intervals(v) classmethod

Validate interval fields are >= 1.

Source code in src/rfdetr/config.py
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
    """Reject interval values smaller than 1.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    below_minimum = v < 1
    if below_minimum:
        message = "Interval fields must be >= 1."
        raise ValueError(message)
    return v

validate_positive_train_steps(v) classmethod

Validate accumulation, target-effective batch, and max targets are >= 1.

Source code in src/rfdetr/config.py
@field_validator(
    "grad_accum_steps",
    "auto_batch_target_effective",
    "auto_batch_max_targets_per_image",
    mode="after",
)
@classmethod
def validate_positive_train_steps(cls, v: int) -> int:
    """Reject values below 1 for the accumulation, target-effective, and max-targets fields.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    below_minimum = v < 1
    if below_minimum:
        message = "grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image must be >= 1."
        raise ValueError(message)
    return v

validate_prefetch_factor(v) classmethod

Validate prefetch_factor is None or >= 1.

Source code in src/rfdetr/config.py
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
    """Allow ``prefetch_factor`` to be ``None`` or an integer of at least 1.

    Raises:
        ValueError: If ``v`` is provided and is less than 1.
    """
    # None passes through untouched; only concrete values are range-checked.
    if v is None:
        return v
    if v < 1:
        raise ValueError("prefetch_factor must be >= 1 when provided.")
    return v

warn_deprecated_num_select()

Warn when callers explicitly set the deprecated train-time num_select field.

Source code in src/rfdetr/config.py
@model_validator(mode="after")
def warn_deprecated_num_select(self) -> "SegmentationTrainConfig":
    """Emit a ``DeprecationWarning`` if ``num_select`` was explicitly given a non-``None`` value.

    Returns:
        The validated instance, unchanged.
    """
    # Stay silent when the field was left at its default or explicitly set to None.
    if "num_select" not in self.model_fields_set or self.num_select is None:
        return self

    message = (
        "TrainConfig.num_select is deprecated and ignored by "
        "PTL/inference; set ModelConfig.num_select instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self