Skip to content

Segmentation Train Config

Bases: TrainConfig

Source code in src/rfdetr/config.py
# Training configuration variant that extends TrainConfig with the
# segmentation-related fields declared below. Any field not listed here
# comes from TrainConfig.
class SegmentationTrainConfig(TrainConfig):
    # Optional integer, unset (None) by default.
    # NOTE(review): presumably the number of top-scoring predictions kept at
    # post-processing time — confirm against the consumer of this field.
    num_select: Optional[int] = None
    # Defaults to 16.
    # NOTE(review): looks like a subsampling ratio for points used by the mask
    # losses — confirm how the loss code interprets it.
    mask_point_sample_ratio: int = 16
    # Default weights of 5.0 for the loss coefficients below.
    # NOTE(review): names suggest mask cross-entropy, mask dice and
    # classification loss terms respectively — verify against the criterion.
    mask_ce_loss_coef: float = 5.0
    mask_dice_loss_coef: float = 5.0
    # NOTE(review): may override a default declared on TrainConfig — confirm.
    cls_loss_coef: float = 5.0
    # Enabled by default for this config.
    segmentation_head: bool = True

Functions

expand_paths(v) classmethod

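Expand a leading '~' in dataset_dir and output_dir and resolve each to a canonical absolute path (symlinks followed). Every non-None value is resolved, including bare names, which are interpreted relative to the current working directory; None is passed through unchanged.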

Source code in src/rfdetr/config.py
@field_validator("dataset_dir", "output_dir", mode="after")
@classmethod
def expand_paths(cls, v: Optional[str]) -> Optional[str]:
    """
    Normalize ``dataset_dir`` and ``output_dir`` to canonical absolute paths.

    A leading ``~`` is expanded to the user's home directory and the result is
    passed through ``os.path.realpath``, which makes it absolute and follows
    symlinks. Every non-None value is rewritten this way, including bare
    names, which are resolved against the current working directory.

    Args:
        v: The field value, or None when the field is unset.

    Returns:
        The resolved path, or None if ``v`` is None.
    """
    if v is None:
        # Unset directories stay unset; there is nothing to resolve.
        return None
    expanded = os.path.expanduser(v)
    return os.path.realpath(expanded)

validate_positive_intervals(v) classmethod

Validate interval fields are >= 1.

Source code in src/rfdetr/config.py
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
    """
    Ensure an interval setting is at least 1.

    Args:
        v: The interval value to check.

    Returns:
        ``v`` unchanged when it passes the check.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    below_minimum = v < 1
    if below_minimum:
        raise ValueError("Interval fields must be >= 1.")
    return v

validate_prefetch_factor(v) classmethod

Validate prefetch_factor is None or >= 1.

Source code in src/rfdetr/config.py
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
    """
    Ensure ``prefetch_factor`` is either unset or at least 1.

    Args:
        v: The prefetch factor, or None when not provided.

    Returns:
        ``v`` unchanged when it is None or >= 1.

    Raises:
        ValueError: If ``v`` is provided and is less than 1.
    """
    # None means "not provided" and is always acceptable.
    if v is None:
        return v
    if v < 1:
        raise ValueError("prefetch_factor must be >= 1 when provided.")
    return v