Skip to content
Snippets Groups Projects
Commit 1882966b authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

small updates for better training

parent da906910
No related branches found
No related tags found
No related merge requests found
......@@ -81,6 +81,6 @@ Install the package used for getting the data: `!pip install gdown`.
Get the data with the following command: `!gdown https://drive.google.com/uc?id=1D6z3UbCoBOhOs8lhasgm-ap58-uPdDY-`
The basis for the tutorial is the content of the `run.py` script in the folder of the cloned repository. You can gradually copy the commands from there and modify them.
The basis for the tutorial is the content of the `run.py` script in the folder of the cloned repository. You can gradually copy the commands from there, modify them, and run them.
......@@ -2,10 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F
#TODO: activation masks after each conv layer
class SimpleCNN(nn.Module):
def __init__(self):
def __init__(self, num_classes=3):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
......@@ -14,7 +12,7 @@ class SimpleCNN(nn.Module):
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(in_features=64 * 32 * 32, out_features=500)
self.fc2 = nn.Linear(in_features=500, out_features=3)
self.fc2 = nn.Linear(in_features=500, out_features=num_classes)
def forward(self, x, return_activations=False):
activations = {}
......
This diff is collapsed.
......@@ -16,7 +16,14 @@ def create_data_loaders(data_path, batch_size=8, num_workers=4):
test_dataset = datasets.ImageFolder(root=data_path + '/test', transform=transform)
# Create a DataLoader for each set
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
......
############################## this section prepares the data for training ##############################
############################## this section prepares the data for training ###############################
import shutil
import os
from prepare_data.clean_file_names import clean_file_names
# RAW_DATA_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/ml-department-workshop-dataset/simple-needles-2-class"
RAW_DATA_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/ml-department-workshop-dataset/simple-needles-3-class"
# RAW_DATA_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/ml-department-workshop-dataset/simple-needles-3-class"
RAW_DATA_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/ml-department-workshop-dataset/simple-needles-4-class"
# Clean file and directory names
......@@ -12,8 +15,13 @@ clean_file_names(RAW_DATA_PATH)
from prepare_data.prepare_train_val_test import PrepareTrainValTest
DATA_IN_PATH = RAW_DATA_PATH
DATA_OUT_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/datasets/data_splited"
# check if DATA_OUT_PATH exists and delete it if it does
if os.path.exists(DATA_OUT_PATH):
shutil.rmtree(DATA_OUT_PATH)
# Create train, validation, and test data sets
prepare_data = PrepareTrainValTest(DATA_IN_PATH, DATA_OUT_PATH)
......@@ -23,7 +31,7 @@ prepare_data.prepare_train_val_test()
from pipeline.data_loader import create_data_loaders
DATA_PATH = DATA_OUT_PATH
BATCH_SIZE = 8
NUM_WORKERS = 4
NUM_WORKERS = 8
# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(DATA_PATH, BATCH_SIZE, NUM_WORKERS)
......@@ -40,6 +48,8 @@ TRAIN = True
if TRAIN:
# Import necessary packages for training
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -49,16 +59,24 @@ if TRAIN:
from models.simple_cnn import SimpleCNN
# Create an instance of the model
model = SimpleCNN()
model = SimpleCNN(num_classes=len(train_loader.dataset.classes))
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
# # use focal loss
# from focal_loss import FocalLoss
# criterion = FocalLoss(alpha=0.25, gamma=2.0)
# Use Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_acc = []
val_acc = []
# Train the model
num_epochs = 5
num_epochs = 15
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(train_loader):
......@@ -68,19 +86,69 @@ if TRAIN:
# Zero the parameter gradients
optimizer.zero_grad()
# Forward + backward + optimize
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Forward + loss + backward + optimize
outputs = model(inputs) # forward
loss = criterion(outputs, labels) # loss
loss.backward() # backward
optimizer.step() # optimize
# Print statistics
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, loss))
print('Train accuracy: %.2f%%' % (100 * (labels == outputs.argmax(dim=1)).sum().item() / len(labels)))
train_acc.append(100 * (labels == outputs.argmax(dim=1)).sum().item() / len(labels))
# Validation step
model.eval() # Set the model to evaluation mode
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for val_data in val_loader:
val_inputs, val_labels = val_data
val_outputs = model(val_inputs)
val_loss += criterion(val_outputs, val_labels).item()
_, val_predicted = torch.max(val_outputs.data, dim=1)
val_total += val_labels.size(0)
val_correct += (val_predicted == val_labels).sum().item()
val_accuracy = 100 * val_correct / val_total
val_acc.append(val_accuracy)
val_loss /= len(val_loader)
print('Validation: loss = %.3f, Validation accuracy = %.2f%%' % (val_loss, val_accuracy))
model.train() # Set the model back to training mode
# save the model
torch.save(model.state_dict(), 'simple_cnn.pth')
# plot the train and validation accuracy
# reset the figure
plt.clf()
plt.plot(train_acc, label='Train accuracy')
plt.plot(val_acc, label='Validation accuracy')
# plot also smooth curves
x = np.arange(len(train_acc))
y = np.array(train_acc)
z = np.polyfit(x, y, 3)
p_train = np.poly1d(z)
x = np.arange(len(val_acc))
y = np.array(val_acc)
z = np.polyfit(x, y, 3)
p_val = np.poly1d(z)
plt.plot(x, p_train(x), "r--", label='Train accuracy smooth')
plt.plot(x, p_val(x), "g--", label='Validation accuracy smooth')
plt.xlabel('Epoch')
plt.xticks(np.arange(0, num_epochs, step=1))
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('train_val_accuracy.png')
############################## this section evaluates the model ##############################
# load the model
......@@ -88,7 +156,7 @@ import torch
from models.simple_cnn import SimpleCNN
model = SimpleCNN()
model = SimpleCNN(num_classes=len(train_loader.dataset.classes))
model.load_state_dict(torch.load('simple_cnn.pth'))
# run the model on the test set and print the accuracy
......@@ -129,7 +197,7 @@ cm = confusion_matrix(y_true, y_pred)
# Plot the confusion matrix
plt.figure(figsize=(10, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues') # 'Blues', 'Greens', 'Greys', 'Purples', 'Reds', etc.
plt.xlabel('Predicted label')
plt.ylabel('True label')
# save the confusion matrix
......@@ -180,7 +248,7 @@ def save_activations(activations, save_dir):
num_features = act.size(1)
for i in range(num_features):
plt.figure()
plt.imshow(act[0, i].detach().numpy(), cmap='hot')
plt.imshow(act[0, i].detach().numpy(), cmap='gray')
plt.axis('off')
# Save each channel's activation with a proper file name
......
import os
import argparse
import shutil
import numpy as np
......@@ -98,9 +99,7 @@ class PrepareTrainValTest:
if __name__ == "__main__":
# use argparse to get command line arguments
import argparse
# parse command-line arguments
parser = argparse.ArgumentParser(
description="Prepare train, validation, and test data sets from raw data."
)
......
......@@ -8,9 +8,14 @@ import torch
def show_sample_images(data_path, save_path=None):
# Function to show an image with labels
def imshow(img, labels, classes, save_path=None):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# Denormalize image
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img = img.numpy().transpose((1, 2, 0)) # Convert from tensor image
img = std * img + mean
img = np.clip(img, 0, 1) # Clip values to be in the range [0, 1]
plt.imshow(img)
# Display labels below the image
plt.xticks([]) # Remove x-axis ticks
plt.yticks([]) # Remove y-axis ticks
......@@ -20,6 +25,7 @@ def show_sample_images(data_path, save_path=None):
else:
plt.show()
# Define transformations
transform = transforms.Compose([
transforms.Resize((256, 256)),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment