41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
# Module-level logger named after this module (``__name__``); train_yolo logs its
# training configuration through it before starting a run.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
def _run_name(data_yaml: str | Path, model: str) -> str:
    """Return the run folder name ``<dataset dir>_<model stem>``.

    The dataset directory is taken from the *resolved* YAML path, so a bare
    ``data.yaml`` or ``../data.yaml`` yields the real folder name rather than
    ``""`` or ``".."``. Falls back to ``"dataset"`` when the YAML sits at the
    filesystem root (whose name is empty).
    """
    dataset_dir = Path(data_yaml).resolve().parent.name or "dataset"
    return f"{dataset_dir}_{Path(model).stem}"


def _build_train_args(
    data_yaml: str | Path,
    model: str,
    imgsz: int,
    epochs: int,
    batch: int,
    device: str,
    overrides: dict[str, object],
) -> dict[str, object]:
    """Assemble the keyword arguments passed to ``YOLO.train``.

    Entries in *overrides* replace the built-in defaults below.
    """
    args: dict[str, object] = {
        "data": str(data_yaml),
        "imgsz": imgsz,
        "epochs": epochs,
        "batch": batch,
        "device": device,
        "workers": 4,
        "cache": False,
        "patience": 20,
        "project": "runs/bgtopo_bluebox",
        "name": _run_name(data_yaml, model),
        # Augmentation settings; see the Ultralytics train-settings docs for the
        # exact meaning of each key.
        "hsv_h": 0.005,
        "hsv_s": 0.20,
        "hsv_v": 0.18,
        "degrees": 0.0,
        "translate": 0.05,
        "scale": 0.20,
        "fliplr": 0.5,
        "flipud": 0.5,
        "mosaic": 0.25,
        "close_mosaic": 15,
    }
    args.update(overrides)
    return args


def train_yolo(
    data_yaml: str | Path,
    model: str = "yolov8s.pt",
    imgsz: int = 1024,
    epochs: int = 80,
    batch: int = 4,
    device: str = "0",
    **overrides: object,
):
    """Train YOLO on the generated weak-label dataset.

    This function imports ultralytics lazily so the rest of the PoC works without GPU dependencies.

    Review/correct the weak labels before treating this model as useful.

    Args:
        data_yaml: Dataset YAML path, passed to ``YOLO.train`` as ``data``. Its
            (resolved) parent directory name prefixes the run name.
        model: Weights/config given to ``YOLO(...)``; its stem suffixes the run name.
        imgsz: Forwarded to ``YOLO.train``.
        epochs: Forwarded to ``YOLO.train``.
        batch: Forwarded to ``YOLO.train``.
        device: Forwarded to ``YOLO.train``.
        **overrides: Extra ``YOLO.train`` keyword arguments. They take precedence
            over this function's defaults (e.g. ``workers=0`` or
            ``project="runs/other"``).

    Returns:
        Whatever ``YOLO.train`` returns.

    Raises:
        ImportError: If the ``ultralytics`` package is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError as err:
        # Keep the ImportError type for callers, but say what to install.
        raise ImportError(
            "train_yolo requires the 'ultralytics' package; "
            "install it with `pip install ultralytics`"
        ) from err

    train_args = _build_train_args(data_yaml, model, imgsz, epochs, batch, device, overrides)

    yolo = YOLO(model)

    LOG.info("Starting YOLO training: model=%s data=%s imgsz=%d epochs=%d batch=%d device=%s", model, data_yaml, imgsz, epochs, batch, device)
    if overrides:
        LOG.info("YOLO training overrides: %s", overrides)

    return yolo.train(**train_args)
|