Advertisement
ngnhtrg

Untitled

Oct 2nd, 2023
1,332
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.33 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.transforms as transforms
  4. import torchvision.datasets as datasets
  5. from torch.utils.data import DataLoader
  6. import torch.optim as optim
  7.  
  8. # Define your training and validation datasets and data loaders here
  9. # Example: Replace 'YourTrainDataset' and 'YourValDataset' with your actual datasets
  10. # train_dataset = YourTrainDataset(transform=transforms.Compose([transforms.Resize(224), transforms.ToTensor()]))
  11. # val_dataset = YourValDataset(transform=transforms.Compose([transforms.Resize(224), transforms.ToTensor()]))
  12.  
  13. # Define the number of classes for digit prediction (10 classes)
  14. num_classes_digits = 10
  15.  
  16. # Create an instance of the custom model
  17. model = CustomModel(num_classes_digits)
  18.  
  19. # Define loss function and optimizer
  20. criterion = nn.CrossEntropyLoss()
  21. optimizer = optim.Adam(model.parameters(), lr=0.001)
  22.  
  23. # Set the number of training epochs
  24. num_epochs = 10
  25.  
  26. # Training loop
  27. for epoch in range(num_epochs):
  28.     model.train()  # Set the model in training mode
  29.     for batch_data, batch_labels in train_loader:
  30.         optimizer.zero_grad()  # Zero out gradients
  31.         digit_logits = model(batch_data)  # Forward pass for digit prediction
  32.         loss = model.loss(digit_logits, batch_labels)  # Compute loss
  33.         loss.backward()  # Backpropagation
  34.         optimizer.step()  # Update model parameters
  35.  
  36.     # Validation
  37.     model.eval()  # Set the model in evaluation mode
  38.     val_loss = 0.0
  39.     correct_predictions = 0
  40.     total_samples = 0
  41.  
  42.     with torch.no_grad():
  43.         for val_data, val_labels in val_loader:
  44.             val_digit_logits = model(val_data)
  45.             val_loss += model.loss(val_digit_logits, val_labels).item()
  46.  
  47.             _, val_predicted_digits = torch.max(val_digit_logits, 2)
  48.             correct_predictions += torch.sum(val_predicted_digits == val_labels).item()
  49.             total_samples += val_labels.size(0)
  50.  
  51.     # Calculate validation loss and accuracy
  52.     avg_val_loss = val_loss / len(val_loader)
  53.     val_accuracy = correct_predictions / total_samples
  54.  
  55.     # Log and print training and validation progress
  56.     print(f'Epoch [{epoch+1}/{num_epochs}] - Train Loss: {loss:.4f} - Val Loss: {avg_val_loss:.4f} - Val Accuracy: {val_accuracy:.2%}')
  57.  
  58. # Save the trained model
  59. torch.save(model.state_dict(), 'custom_model_weights.pth')
  60.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement