Vassa007

training yolo

Feb 14th, 2025 (edited)
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.72 KB | None | 0 0
  1. import torch
  2. from ultralytics import YOLO
  3. import os
  4. import glob
  5. import datetime
  6.  
  7. def get_latest_model():
  8.     """Mencari model terbaru di folder hasil training."""
  9.     model_paths = glob.glob("runs/detect/train*/weights/best.pt")
  10.     if not model_paths:
  11.         print("❌ Tidak ada model ditemukan! Pastikan training sudah dilakukan.")
  12.         return None
  13.     latest_model = max(model_paths, key=os.path.getctime)
  14.     print(f"✅ Model terbaru ditemukan: {latest_model}")
  15.     return latest_model
  16.  
  17. def train_models_few_shot():
  18.     """Training YOLO dengan dataset kecil (Few-Shot Learning)."""
  19.     dataset_path = "assets/datasets/dataset.yaml"
  20.     model_path = "yolov8n.pt"  # Bisa diganti dengan yolov8s.pt untuk akurasi lebih baik
  21.  
  22.     # Cek apakah model YOLO tersedia
  23.     if not os.path.exists(model_path):
  24.         print(f"❌ Model {model_path} tidak ditemukan! Harap unduh model terlebih dahulu.")
  25.         return
  26.  
  27.     # Buat direktori training unik berdasarkan timestamp
  28.     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  29.     save_dir = f"runs/detect/train_{timestamp}"
  30.  
  31.     os.makedirs(save_dir, exist_ok=True)
  32.  
  33.     print("📂 Memuat model YOLO untuk Few-Shot Learning...")
  34.     model = YOLO(model_path)
  35.  
  36.     print("🚀 Memulai training YOLO Few-Shot Learning...")
  37.  
  38.     # Training dengan strategi Few-Shot Learning
  39.     model.train(
  40.         data=dataset_path,
  41.         epochs=50,  # Kurangi epoch untuk menghindari overfitting
  42.         imgsz=512,  # Resolusi gambar
  43.         batch=16,   # Ukuran batch kecil untuk dataset kecil
  44.         device="cuda" if torch.cuda.is_available() else "cpu",  # Gunakan GPU jika tersedia
  45.         workers=0,  # Worker kecil karena dataset kecil
  46.         project="runs/detect",
  47.         name=f"train_{timestamp}",
  48.         amp=True,  # Mixed Precision Training (menghemat memori dan mempercepat training)
  49.         exist_ok=True,
  50.         save_period=5,  # Simpan checkpoint lebih sering (setiap 5 epoch)
  51.         freeze=10,  # Freeze 10 layer pertama agar hanya head layer yang belajar
  52.         patience=10,  # Early stopping jika tidak ada peningkatan dalam 10 epoch
  53.         augment=True,  # Aktifkan augmentasi agar model lebih kuat
  54.     )
  55.  
  56.     print(f"✅ Training Few-Shot Learning selesai! Model disimpan di {save_dir}/weights/best.pt")
  57.  
  58. def evaluate_latest_model():
  59.     """Evaluasi model terbaru secara otomatis."""
  60.     latest_model = get_latest_model()
  61.     if not latest_model:
  62.         return
  63.  
  64.     print("📊 Evaluasi model Few-Shot Learning...")
  65.     model = YOLO(latest_model)
  66.     model.val(save_json=True)  # Simpan hasil evaluasi dalam format JSON
  67.     print("✅ Evaluasi selesai!")
  68.  
  69. # Eksekusi training Few-Shot Learning
  70. train_models_few_shot()
  71. evaluate_latest_model()
  72.  
Add Comment
Please, Sign In to add comment