Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torchvision.transforms as transforms
- import torchvision.datasets as datasets
- from torch.utils.data import DataLoader
- import torch.optim as optim
# Number of classes for digit prediction (digits 0-9).
num_classes_digits = 10
# Number of passes over the training data.
num_epochs = 10


def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        model: The ``nn.Module`` being trained; put into training mode here.
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.

    Returns:
        The loss averaged over all batches, or ``nan`` if ``loader`` yields
        no batches (so the caller never sees an unbound/stale loss).
    """
    model.train()  # Training mode (affects layers such as dropout/batch-norm, if present)
    total_loss = 0.0
    num_batches = 0
    for batch_data, batch_labels in loader:
        optimizer.zero_grad()  # Gradients accumulate across backward() calls; clear per step
        digit_logits = model(batch_data)
        loss = loss_fn(digit_logits, batch_labels)
        loss.backward()
        optimizer.step()
        # .item() converts to a Python float so the autograd graph is not retained.
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def evaluate(model, loader, loss_fn):
    """Compute mean batch loss and element-wise accuracy over ``loader``.

    Predictions are the argmax over the LAST dimension of the logits, so this
    handles both ``(batch, classes)`` logits and per-position logits such as
    ``(batch, positions, classes)``. Accuracy is the fraction of label
    elements predicted correctly, which always lies in ``[0, 1]``.

    Args:
        model: The ``nn.Module`` to evaluate; put into eval mode here.
        loader: Iterable of ``(inputs, labels)`` batches.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.

    Returns:
        ``(avg_loss, accuracy)``; both are ``nan`` if ``loader`` is empty.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct_predictions = 0
    total_labels = 0
    with torch.no_grad():  # No gradients needed for evaluation
        for val_data, val_labels in loader:
            val_logits = model(val_data)
            total_loss += loss_fn(val_logits, val_labels).item()
            predicted = val_logits.argmax(dim=-1)
            correct_predictions += (predicted == val_labels).sum().item()
            # numel() rather than size(0): with per-position labels every
            # position is one prediction, keeping accuracy within [0, 1].
            total_labels += val_labels.numel()
            num_batches += 1
    if num_batches == 0:
        return float("nan"), float("nan")
    accuracy = correct_predictions / total_labels if total_labels else float("nan")
    return total_loss / num_batches, accuracy


def main():
    """Train the custom digit model, print per-epoch metrics, then save the weights."""
    # Define your training and validation datasets and data loaders here
    # Example: Replace 'YourTrainDataset' and 'YourValDataset' with your actual datasets
    # train_dataset = YourTrainDataset(transform=transforms.Compose([transforms.Resize(224), transforms.ToTensor()]))
    # val_dataset = YourValDataset(transform=transforms.Compose([transforms.Resize(224), transforms.ToTensor()]))
    # train_loader = DataLoader(train_dataset, batch_size=..., shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=...)

    # CustomModel is defined elsewhere (not visible in this section).
    model = CustomModel(num_classes_digits)
    # CustomModel is expected to expose loss(logits, labels); fall back to
    # plain cross-entropy if it does not define one.
    loss_fn = model.loss if hasattr(model, "loss") else nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn)
        avg_val_loss, val_accuracy = evaluate(model, val_loader, loss_fn)
        print(f'Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f} - Val Loss: {avg_val_loss:.4f} - Val Accuracy: {val_accuracy:.2%}')

    # Save the trained model
    torch.save(model.state_dict(), 'custom_model_weights.pth')


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement