rnd-v1
This commit is contained in:
40
bgtopo_poc/train_yolo.py
Normal file
40
bgtopo_poc/train_yolo.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_yolo(data_yaml: str | Path, model: str = "yolov8s.pt", imgsz: int = 1024, epochs: int = 80, batch: int = 4, device: str = "0"):
|
||||
"""Train YOLO on the generated weak-label dataset.
|
||||
|
||||
This function imports ultralytics lazily so the rest of the PoC works without GPU dependencies.
|
||||
Review/correct the weak labels before treating this model as useful.
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
|
||||
yolo = YOLO(model)
|
||||
LOG.info("Starting YOLO training: model=%s data=%s imgsz=%d epochs=%d batch=%d device=%s", model, data_yaml, imgsz, epochs, batch, device)
|
||||
return yolo.train(
|
||||
data=str(data_yaml),
|
||||
imgsz=imgsz,
|
||||
epochs=epochs,
|
||||
batch=batch,
|
||||
device=device,
|
||||
workers=4,
|
||||
cache=False,
|
||||
patience=20,
|
||||
project="runs/bgtopo_bluebox",
|
||||
name=f"{Path(data_yaml).parent.name}_{Path(model).stem}",
|
||||
hsv_h=0.005,
|
||||
hsv_s=0.20,
|
||||
hsv_v=0.18,
|
||||
degrees=0.0,
|
||||
translate=0.05,
|
||||
scale=0.20,
|
||||
fliplr=0.5,
|
||||
flipud=0.5,
|
||||
mosaic=0.25,
|
||||
close_mosaic=15,
|
||||
)
|
||||
Reference in New Issue
Block a user