Vassa007

few-shot learning

Feb 14th, 2025 (edited)
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.32 KB | None | 0 0
  1. import os
  2. import torch
  3. import numpy as np
  4. import cv2
  5. from tqdm import tqdm
  6. from torchvision import models, transforms
  7. from sklearn.neighbors import KNeighborsClassifier
  8.  
  9. # Cek apakah GPU tersedia dan kompatibel dengan CUDA 12.2
  10. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  11. print(f"🚀 Menggunakan device: {device}")
  12.  
  13. # Path dataset
  14. DATASET_PATH = "assets/datasets"
  15. TRAIN_PATH = os.path.join(DATASET_PATH, "train")
  16. TEST_PATH = os.path.join(DATASET_PATH, "test")
  17.  
  18. # Load Pretrained Model (EfficientNet-B0 sebagai Feature Extractor)
  19. efficientnet = models.efficientnet_b0(weights="IMAGENET1K_V1")
  20. efficientnet = torch.nn.Sequential(*list(efficientnet.children())[:-1])  # Hapus layer FC
  21. efficientnet = efficientnet.to(device).eval()
  22.  
  23. # Transformasi untuk preprocessing gambar
  24. transform = transforms.Compose([
  25.     transforms.ToPILImage(),
  26.     transforms.Resize((224, 224)),
  27.     transforms.ToTensor(),
  28.     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  29. ])
  30.  
  31. def extract_features(image_paths):
  32.     """Ekstrak fitur dari batch gambar menggunakan EfficientNet-B0."""
  33.     batch_images = []
  34.     valid_paths = []
  35.  
  36.     for image_path in image_paths:
  37.         image = cv2.imread(image_path)
  38.         if image is not None:
  39.             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  40.             image = transform(image)
  41.             batch_images.append(image)
  42.             valid_paths.append(image_path)
  43.  
  44.     if len(batch_images) == 0:
  45.         return None, valid_paths
  46.  
  47.     batch_tensor = torch.stack(batch_images).to(device)  # Konversi ke tensor batch
  48.     with torch.no_grad():
  49.         features = efficientnet(batch_tensor).squeeze().cpu().numpy()
  50.  
  51.     return features, valid_paths
  52.  
  53. def load_dataset(dataset_path, batch_size=32):
  54.     """Muat dataset dan ekstrak fitur dari setiap gambar dalam batch."""
  55.     X, y = [], []
  56.     image_paths = []
  57.     labels = []
  58.  
  59.     for person_id in os.listdir(dataset_path):
  60.         person_folder = os.path.join(dataset_path, person_id)
  61.         if not os.path.isdir(person_folder):
  62.             continue
  63.  
  64.         for file in os.listdir(person_folder):
  65.             if file.endswith(('.jpg', '.jpeg', '.png')):
  66.                 img_path = os.path.join(person_folder, file)
  67.                 image_paths.append(img_path)
  68.                 labels.append(person_id)
  69.  
  70.     # Proses ekstraksi fitur dalam batch
  71.     for i in tqdm(range(0, len(image_paths), batch_size), desc="📂 Ekstraksi Fitur", unit="batch"):
  72.         batch_paths = image_paths[i:i+batch_size]
  73.         features, valid_paths = extract_features(batch_paths)
  74.         if features is not None:
  75.             X.extend(features)
  76.             y.extend(labels[i:i+len(valid_paths)])
  77.  
  78.     return np.array(X), np.array(y)
  79.  
  80. # Load Data Train & Test dengan batch processing
  81. print("📂 Memuat dataset train...")
  82. X_train, y_train = load_dataset(TRAIN_PATH, batch_size=64)
  83.  
  84. print("📂 Memuat dataset test...")
  85. X_test, y_test = load_dataset(TEST_PATH, batch_size=64)
  86.  
  87. # Inisialisasi dan Latih KNN Classifier
  88. print("🚀 Melatih model Few-Shot Learning dengan KNN...")
  89. knn = KNeighborsClassifier(n_neighbors=3, metric="euclidean")
  90. knn.fit(X_train, y_train)
  91.  
  92. # Evaluasi Model
  93. y_pred = knn.predict(X_test)
  94. accuracy = np.mean(y_pred == y_test)
  95.  
  96. print(f"✅ Few-Shot Learning Accuracy: {accuracy:.2%}")
  97.  
Add Comment
Please, Sign In to add comment