class TrainConfig(BaseModel):
    """Training hyperparameters and auto-batching configuration.

    Notes:
    * ``auto_batch_target_effective`` is interpreted as the **per-device**
      effective batch size target, i.e. the number of images seen by a
      single process in one optimizer step after accounting for
      ``grad_accum_steps``. In multi-GPU / multi-node runs the global
      effective batch size is therefore:
      ``global_effective_batch = auto_batch_target_effective * devices * num_nodes``
      This avoids silently changing behavior when scaling from single-GPU
      to multi-GPU training.
    """

    # --- Optimizer / schedule ---
    lr: float = 1e-4
    lr_encoder: float = 1.5e-4
    batch_size: int | Literal["auto"] = 4  # "auto" triggers the auto-batch probe
    grad_accum_steps: int = 4
    auto_batch_target_effective: int = 16  # per-device effective batch size target (before devices * num_nodes)
    # Auto-batch probe: worst-case assumptions when batch_size="auto".
    auto_batch_max_targets_per_image: int = 100
    auto_batch_ema_headroom: float = 0.7  # scale safe batch by this when use_ema=True (EMA uses extra memory)
    epochs: int = 100
    resume: Optional[str] = None  # checkpoint path to resume training from
    ema_decay: float = 0.993
    ema_tau: int = 100
    lr_drop: int = 100
    checkpoint_interval: int = Field(default=10, ge=1)
    warmup_epochs: float = 0.0
    lr_vit_layer_decay: float = 0.8
    lr_component_decay: float = 0.7
    drop_path: float = 0.0
    # Deprecated here — owned by ModelConfig (see _warn_deprecated_train_config_fields).
    group_detr: int = 13
    ia_bce_loss: bool = True
    cls_loss_coef: float = 1.0
    num_select: int = 300
    # --- Data / augmentation ---
    dataset_file: Literal["coco", "o365", "roboflow", "yolo"] = "roboflow"
    square_resize_div_64: bool = True
    dataset_dir: str  # required; expanded to an absolute path by expand_paths
    output_dir: str = "output"
    multi_scale: bool = True
    expanded_scales: bool = True
    do_random_resize_via_padding: bool = False
    use_ema: bool = True
    ema_update_interval: int = 1
    num_workers: int = 2
    weight_decay: float = 1e-4
    # --- Early stopping ---
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_min_delta: float = 0.001
    early_stopping_use_ema: bool = False
    # --- Logging / experiment tracking ---
    progress_bar: Optional[Literal["tqdm", "rich"]] = None  # Progress bar style: "rich", "tqdm", or None to disable.
    tensorboard: bool = True
    wandb: bool = False
    mlflow: bool = False
    clearml: bool = False  # Not yet implemented — reserved for future use.
    project: Optional[str] = None
    run: Optional[str] = None
    # FIX: was `List[str] = None` — a non-Optional list annotation with a None
    # default, which is a type error (and rejects an explicit `class_names=None`).
    # Optional[...] legalizes the existing None default; runtime behavior for
    # all previously-valid inputs is unchanged.
    class_names: Optional[List[str]] = None
    run_test: bool = False
    segmentation_head: bool = False
    eval_max_dets: int = 500
    eval_interval: int = 1
    log_per_class_metrics: bool = True
    aug_config: Optional[Dict[str, Any]] = None
@model_validator(mode="after")
def _warn_deprecated_train_config_fields(self) -> "TrainConfig":
"""Emit DeprecationWarning for fields whose ownership is moving to ModelConfig.
The following fields are duplicated between ``ModelConfig`` and ``TrainConfig``
but ``ModelConfig`` is the authoritative source (Item #3, v1.7). Setting them
on ``TrainConfig`` is deprecated. The fields will be removed in v1.9.
- ``group_detr``: query group count is an architecture decision → ``ModelConfig``
- ``ia_bce_loss``: loss type is tied to architecture family → ``ModelConfig``
- ``segmentation_head``: architecture flag → ``ModelConfig``
- ``num_select``: postprocessor count is an architecture decision → ``ModelConfig``
"""
_deprecated = ("group_detr", "ia_bce_loss", "segmentation_head", "num_select")
for field in _deprecated:
if field in self.model_fields_set:
# stacklevel=2 points into Pydantic internals; unavoidable with
# @model_validator(mode="after") in Pydantic v2.
warnings.warn(
f"TrainConfig.{field} is deprecated and will be removed in v1.9. "
f"Set {field} on ModelConfig instead.",
DeprecationWarning,
stacklevel=2,
)
return self
@field_validator("progress_bar", mode="before")
@classmethod
def _coerce_legacy_progress_bar(cls, value: Any) -> Any:
"""Normalize legacy boolean progress_bar values to the new string/None representation.
This preserves compatibility with older configs where ``progress_bar`` was a bool.
"""
if isinstance(value, bool):
return "tqdm" if value else None
return value
# Promoted from populate_args() — PTL migration (T4-2).
# device is intentionally absent: PTL auto-detects accelerator via Trainer(accelerator="auto").
accelerator: str = "auto"
clip_max_norm: float = 0.1
seed: Optional[int] = None
sync_bn: bool = False
# strategy maps to PTL Trainer(strategy=...). Common values: "auto", "ddp",
# "ddp_spawn", "fsdp", "deepspeed". Invalid values surface as PTL errors.
strategy: str = "auto"
devices: Union[int, str] = 1
# num_nodes maps to PTL Trainer(num_nodes=...) for multi-machine training.
# Single-machine DDP users should leave this at 1 (the default).
num_nodes: int = 1
fp16_eval: bool = False
lr_scheduler: Literal["step", "cosine"] = "step"
lr_min_factor: float = 0.0
dont_save_weights: bool = False
# PTL runtime/perf tuning knobs.
train_log_sync_dist: bool = False
train_log_on_step: bool = False
compute_val_loss: bool = True
compute_test_loss: bool = True
pin_memory: Optional[bool] = None
persistent_workers: Optional[bool] = None
prefetch_factor: Optional[int] = None
@field_validator("batch_size", mode="after")
@classmethod
def validate_batch_size(cls, v: int | Literal["auto"]) -> int | Literal["auto"]:
"""Validate batch_size is a positive integer or the literal 'auto'."""
if v == "auto":
return v
if v < 1:
raise ValueError("batch_size must be >= 1, or 'auto'.")
return v
@field_validator(
"grad_accum_steps", "auto_batch_target_effective", "auto_batch_max_targets_per_image", mode="after"
)
@classmethod
def validate_positive_train_steps(cls, v: int) -> int:
"""Validate accumulation, target-effective batch, and max targets are >= 1."""
if v < 1:
raise ValueError(
"grad_accum_steps, auto_batch_target_effective, and auto_batch_max_targets_per_image must be >= 1."
)
return v
@field_validator("auto_batch_ema_headroom", mode="after")
@classmethod
def validate_ema_headroom(cls, v: float) -> float:
"""Validate auto_batch_ema_headroom is in (0, 1]."""
if not (0 < v <= 1.0):
raise ValueError("auto_batch_ema_headroom must be in (0, 1].")
return v
@field_validator("ema_update_interval", "eval_interval", mode="after")
@classmethod
def validate_positive_intervals(cls, v: int) -> int:
"""Validate interval fields are >= 1."""
if v < 1:
raise ValueError("Interval fields must be >= 1.")
return v
@field_validator("prefetch_factor", mode="after")
@classmethod
def validate_prefetch_factor(cls, v: Optional[int]) -> Optional[int]:
"""Validate prefetch_factor is None or >= 1."""
if v is not None and v < 1:
raise ValueError("prefetch_factor must be >= 1 when provided.")
return v
    @field_validator("dataset_dir", "output_dir", mode="after")
    @classmethod
    def expand_paths(cls, v: str) -> str:
        """Expand '~' and resolve the value to a canonical absolute path.

        NOTE(review): the previous docstring claimed simple filenames (e.g.
        'rf-detr-base.pth') were left unchanged "to match hosted model keys",
        but os.path.realpath() below resolves *every* value — including bare
        names like the default 'output' — against the current working
        directory. Always resolving looks correct for these directory fields;
        the old text was presumably copied from a weights-path validator.
        Confirm no caller relies on bare names passing through unresolved.
        """
        if v is None:
            # Defensive only: both fields are annotated as non-Optional str, so
            # None should never reach an after-mode validator.
            return v
        return os.path.realpath(os.path.expanduser(v))