[python] Training script for convolutional image classifier

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import pandas as pd

# Define the hyperparameters
batch_size = 64
num_epochs = 150
input_size = (224, 224)
num_classes = 2
learning_rate = 0.01
weight_decay = 0.01

# Start the training timer
training_start_time = time.time()

# Define the paths to the training and validation data
path_to_train_data = "D:/Workzone/Datasets/bestphoto/train"
path_to_validation_data = "D:/Workzone/Datasets/bestphoto/validation"

# Define the device to use for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the transforms; random augmentation is applied to the training data only
train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

val_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the training data
train_dataset = datasets.ImageFolder(root=path_to_train_data, transform=train_transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Load the validation data (no shuffling or augmentation needed for evaluation)
val_dataset = datasets.ImageFolder(root=path_to_validation_data, transform=val_transform)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

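# Optional sanity check (added sketch, not part of the original paste): confirm the
# class-to-index mapping and the dataset sizes discovered by ImageFolder before training.
print(f"Classes: {train_dataset.classes}")
print(f"Training samples: {len(train_dataset)} | Validation samples: {len(val_dataset)}")
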
# Define the CNN classifier
class CNNClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(CNNClassifier, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # 16 channels * 56 * 56 spatial positions after two 2x2 max-pool stages on a 224x224 input
        self.fc1 = nn.Linear(50176, 128)
        self.fc2 = nn.Linear(128, num_classes)
        # LogSoftmax (rather than Softmax) so the output matches nn.NLLLoss, which expects log-probabilities
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.log_softmax(x)
        return x

# Create an instance of the CNN classifier
model = CNNClassifier(input_size=input_size, num_classes=num_classes).to(device)

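# Optional sanity check (added sketch, not part of the original script): run a dummy
# batch through the network to confirm that the flattened feature size really matches
# the 50176 expected by fc1 and that the output has one score per class.
with torch.no_grad():
    dummy_out = model(torch.zeros(1, 3, *input_size, device=device))
    assert dummy_out.shape == (1, num_classes), f"Unexpected output shape: {dummy_out.shape}"
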
# Define the loss function and optimizer
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Initialize empty lists for the training and validation losses
train_losses = []
val_losses = []

# Define the training function
def train():
    # Train the model
    for epoch in range(num_epochs):
        start = time.time()

        epoch_loss = 0.0

        # Train the model on the training data
        model.train()
        for inputs, labels in tqdm(train_dataloader):
            # Move the data to the correct device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Clear the gradients
            optimizer.zero_grad()

            # Forward pass
            log_probs = model(inputs)

            # Compute the loss
            loss = criterion(log_probs, labels)
            epoch_loss += loss.item()

            # Backward pass
            loss.backward()

            # Update the weights
            optimizer.step()

        # Append the average epoch loss to the training losses list
        train_losses.append(epoch_loss / len(train_dataloader))

        # Evaluate the model on the validation data
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            val_loss = 0
            for inputs, labels in tqdm(val_dataloader):
                # Move the data to the correct device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass
                log_probs = model(inputs)

                # Compute the loss
                loss = criterion(log_probs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(log_probs, 1)

                # Update the correct and total count
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # Append the average validation loss to the validation losses list
            val_losses.append(val_loss / len(val_dataloader))

        # Calculate the accuracy
        accuracy = correct / total

        # Print the epoch losses, accuracy, and time elapsed
        print(f'Epoch {epoch+1} | Loss: {epoch_loss / len(train_dataloader):.4f} | Val Loss: {val_loss / len(val_dataloader):.4f} | Accuracy: {accuracy:.4f} | Time: {time.time() - start:.2f}s')

        # Save a copy of the model every 10 epochs
        if (epoch+1) % 10 == 0:
            # Get the current time in a struct_time object
            now = time.gmtime()

            # Format the time stamp as a string and save the model
            time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", now)
            torch.save(model.state_dict(), f"model_{epoch+1}_{time_stamp}.pt")

            # Save the learning-curve data collected so far
            save_learning_curve_data(train_losses, val_losses, epoch+1)

# Define the function to save the learning-curve data
def save_learning_curve_data(train_losses, val_losses, epoch):
    # Create a dataframe from the lists of losses
    data = {'Epoch': range(1, epoch+1), 'Training Loss': train_losses, 'Validation Loss': val_losses}
    df = pd.DataFrame(data)

    # Get the current time
    now = time.gmtime()
    time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", now)

    # Save the dataframe as a CSV file
    df.to_csv(f'learning_curve_epoch_{epoch}_{time_stamp}.csv', index=False)

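# Added sketch: matplotlib is imported above but never used in the original paste, so
# this optional helper plots the same loss lists; call it manually (for example right
# after save_learning_curve_data) if a PNG of the learning curve is wanted as well as the CSV.
def plot_learning_curve(train_losses, val_losses, epoch):
    # Plot training and validation loss against the epoch number
    epochs = range(1, epoch + 1)
    plt.figure()
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Save the figure with a timestamp, mirroring the CSV naming scheme
    time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.gmtime())
    plt.savefig(f'learning_curve_epoch_{epoch}_{time_stamp}.png')
    plt.close()
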
def save_model():
    # Get the current time in a struct_time object
    now = time.gmtime()

    # Format the time stamp as a string
    time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", now)

    # Generate a file name for the saved model
    model_name = f"model_{time_stamp}.pt"

    # Save the model
    torch.save(model.state_dict(), model_name)
    print(f"Model saved as {model_name}")

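# Added sketch (hypothetical helper, not in the original script): load a saved
# checkpoint and classify a single image, assuming the validation preprocessing
# defined above and the class order discovered by the training ImageFolder.
def predict_image(image_path, checkpoint_path):
    from PIL import Image

    # Rebuild the network and load the saved weights
    net = CNNClassifier(input_size=input_size, num_classes=num_classes).to(device)
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    net.eval()

    # Preprocess the image exactly like the validation data
    image = Image.open(image_path).convert('RGB')
    batch = val_transform(image).unsqueeze(0).to(device)

    # Pick the class with the highest log-probability
    with torch.no_grad():
        predicted = net(batch).argmax(dim=1).item()
    return train_dataset.classes[predicted]
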
if __name__ == "__main__":
    train()
    save_model()
    save_learning_curve_data(train_losses, val_losses, num_epochs)
    print(f"Total training time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - training_start_time))}")
