diff --git a/.gitignore b/.gitignore
index 4e975f791450428326ec18a07dee02dbcea3e6bc..4bcf49ad9b405ec59dbe51a68c932150c622b456 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,11 @@
 /simple-needles-2-class
 /data
-*.pth
\ No newline at end of file
+*.pth
+*.JPG
+*.jpg
+*.jpeg
+*.JPEG
+*.png
+*.PNG
+.directory
+*.pyc
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/simple_cnn.py b/models/simple_cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e268454f7eba9271c0b9ab886f41ceaa2ec9a48
--- /dev/null
+++ b/models/simple_cnn.py
@@ -0,0 +1,40 @@
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+#TODO: activation masks after each conv layer
+
+class SimpleCNN(nn.Module):
+    def __init__(self):
+        super(SimpleCNN, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
+        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
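+        # with 256x256 inputs, three 2x2 max-pools leave 64 feature maps of 32x32, hence 64 * 32 * 32 in fc1 below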
+        self.fc1 = nn.Linear(in_features=64 * 32 * 32, out_features=500)
+        self.fc2 = nn.Linear(in_features=500, out_features=2)
+
+    def forward(self, x, return_activations=False):
+        activations = {}
+
+        x = F.relu(self.conv1(x))
+        activations['conv1'] = x
+        x = self.pool1(x)
+
+        x = F.relu(self.conv2(x))
+        activations['conv2'] = x
+        x = self.pool2(x)
+
+        x = F.relu(self.conv3(x))
+        activations['conv3'] = x
+        x = self.pool3(x)
+
+        x = x.view(-1, 64 * 32 * 32)  # Flatten
+        x = F.relu(self.fc1(x))
+
+        if return_activations:
+            return self.fc2(x), activations
+        else:
+            return self.fc2(x)
\ No newline at end of file
diff --git a/notebooks/run.ipynb b/notebooks/run.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..709d82cff5441273c179920d46f936095fd29ac8
--- /dev/null
+++ b/notebooks/run.ipynb
@@ -0,0 +1,18 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/train_model.ipynb b/notebooks/train_model.ipynb
index b44acc5705d9529fb6ef8bd28b674a16c464205a..a8ec2e128751917cfdf2d87bb73575a1973dbb8a 100644
--- a/notebooks/train_model.ipynb
+++ b/notebooks/train_model.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -36,9 +36,29 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    },
+    {
+     "ename": "NameError",
+     "evalue": "name 'data_path' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[1], line 30\u001b[0m\n\u001b[1;32m     23\u001b[0m transform \u001b[38;5;241m=\u001b[39m transforms\u001b[38;5;241m.\u001b[39mCompose([\n\u001b[1;32m     24\u001b[0m     transforms\u001b[38;5;241m.\u001b[39mResize((\u001b[38;5;241m256\u001b[39m, \u001b[38;5;241m256\u001b[39m)),\n\u001b[1;32m     25\u001b[0m     transforms\u001b[38;5;241m.\u001b[39mToTensor(),\n\u001b[1;32m     26\u001b[0m     transforms\u001b[38;5;241m.\u001b[39mNormalize((\u001b[38;5;241m0.5\u001b[39m,), (\u001b[38;5;241m0.5\u001b[39m,))\n\u001b[1;32m     27\u001b[0m ])\n\u001b[1;32m     29\u001b[0m \u001b[38;5;66;03m# Create the train_dataset and train_loader as before\u001b[39;00m\n\u001b[0;32m---> 30\u001b[0m train_dataset \u001b[38;5;241m=\u001b[39m datasets\u001b[38;5;241m.\u001b[39mImageFolder(root\u001b[38;5;241m=\u001b[39m\u001b[43mdata_path\u001b[49m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/train\u001b[39m\u001b[38;5;124m'\u001b[39m, transform\u001b[38;5;241m=\u001b[39mtransform)\n\u001b[1;32m     31\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(train_dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m     33\u001b[0m \u001b[38;5;66;03m# Get some random training images\u001b[39;00m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'data_path' is not defined"
+     ]
+    }
+   ],
    "source": [
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
@@ -82,9 +102,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'models'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[2], line 11\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msys\u001b[39;00m\n\u001b[1;32m      9\u001b[0m sys\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/home/nibio/mutable-outside-world/code/ml-department-workshop/models\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodels\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msimple_cnn\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SimpleCNN \n\u001b[1;32m     13\u001b[0m \u001b[38;5;66;03m# Create an instance of the model\u001b[39;00m\n\u001b[1;32m     14\u001b[0m model \u001b[38;5;241m=\u001b[39m SimpleCNN()\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'models'"
+     ]
+    }
+   ],
    "source": [
     "# create a simple CNN model\n",
     "import torch\n",
@@ -92,28 +124,11 @@
     "import torch.nn.functional as F\n",
     "import torch.optim as optim\n",
     "\n",
+    "# add path to models folder to python path\n",
+    "import sys\n",
+    "sys.path.append('models')\n",
     "\n",
-    "class SimpleCNN(nn.Module):\n",
-    "    def __init__(self):\n",
-    "        super(SimpleCNN, self).__init__()\n",
-    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)\n",
-    "        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
-    "        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)\n",
-    "        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
-    "        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)\n",
-    "        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
-    "        self.fc1 = nn.Linear(in_features=64 * 32 * 32, out_features=500)\n",
-    "        self.fc2 = nn.Linear(in_features=500, out_features=2)\n",
-    "\n",
-    "    def forward(self, x):\n",
-    "        x = self.pool1(F.relu(self.conv1(x)))  # 16 x 128 x 128\n",
-    "        x = self.pool2(F.relu(self.conv2(x)))  # 32 x 64 x 64\n",
-    "        x = self.pool3(F.relu(self.conv3(x)))  # 64 x 32 x 32\n",
-    "        x = x.view(-1, 64 * 32 * 32)  # Flatten\n",
-    "        x = F.relu(self.fc1(x))\n",
-    "        x = self.fc2(x)\n",
-    "        return x\n",
-    "    \n",
+    "from models.simple_cnn import SimpleCNN \n",
     "\n",
     "# Create an instance of the model\n",
     "model = SimpleCNN()\n",
@@ -151,14 +166,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Accuracy of the network on the test images: 77 %\n"
+      "Accuracy of the network on the test images: 100 %\n"
      ]
     }
    ],
diff --git a/pipeline/__init__.py b/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pipeline/data_loader.py b/pipeline/data_loader.py
index 5bc70a2ae9e0d3538f82d5b1ea290ace0ed86047..d46908083e91d6b03c1cf0a790093f60be5be352 100644
--- a/pipeline/data_loader.py
+++ b/pipeline/data_loader.py
@@ -1,2 +1,23 @@
-import torch
-# import dataset from torch 
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+
+def create_data_loaders(data_path, batch_size=8, num_workers=4):
+    # Define a transform to apply to each image
+    transform = transforms.Compose([
+        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR),  # bilinear interpolation (the torchvision default)
+        transforms.ToTensor(),
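+        # normalize with the standard ImageNet channel statistics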
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    # Create a dataset for each set: train, validation, and test
+    train_dataset = datasets.ImageFolder(root=data_path + '/train', transform=transform)
+    val_dataset = datasets.ImageFolder(root=data_path + '/val', transform=transform)
+    test_dataset = datasets.ImageFolder(root=data_path + '/test', transform=transform)
+
+    # Create a DataLoader for each set
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    return train_loader, val_loader, test_loader
diff --git a/pipeline/run.py b/pipeline/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8b3aa7be63e9bf8b8f8910111d2575a10bc8454
--- /dev/null
+++ b/pipeline/run.py
@@ -0,0 +1,238 @@
+# pipeline/run.py is the main run script: it runs the entire pipeline.
+# It first runs the prepare_data step, then the train_model step, and finally the evaluate_model step.
+# The prepare_data step calls the clean_file_names.py and prepare_train_val_test.py scripts;
+# the train_model and evaluate_model steps are implemented inline further down in this file.
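+# Note: this script assumes the repository root is on the Python path
+# (e.g. run as `python -m pipeline.run` from the repository root) so that the
+# prepare_data, pipeline, models, and visualization packages are importable.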
+
+
+############################## this section prepares the data for training ##############################
+
+from prepare_data.clean_file_names import clean_file_names
+
+RAW_DATA_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/ml-department-workshop-dataset/simple-needles-2-class"
+
+# Clean file names in the raw data directory and its subfolders
+clean_file_names(RAW_DATA_PATH)
+
+
+from prepare_data.prepare_train_val_test import PrepareTrainValTest
+DATA_IN_PATH = RAW_DATA_PATH
+DATA_OUT_PATH = "/home/nibio/mutable-outside-world/code/ml-department-workshop/datasets/data_splited"
+
+# Create train, validation, and test data sets
+prepare_data = PrepareTrainValTest(DATA_IN_PATH, DATA_OUT_PATH)
+
+prepare_data.prepare_train_val_test()
+
+############################## this section creates the instance of the data readers ##############################
+from pipeline.data_loader import create_data_loaders
+DATA_PATH = DATA_OUT_PATH
+BATCH_SIZE = 8
+NUM_WORKERS = 4
+
+# Create data loaders
+train_loader, val_loader, test_loader = create_data_loaders(DATA_PATH, BATCH_SIZE, NUM_WORKERS)
+
+############################## visualize sample data ##############################################
+from visualization.show_sample_images import show_sample_images
+
+# Show sample images
+show_sample_images(DATA_PATH, 'output_image.png')
+
+
+############################## this section trains the model ##############################
+TRAIN = False
+
+if TRAIN:
+    # Import necessary packages for training
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+    import torch.optim as optim
+
+    # Import the model
+    from models.simple_cnn import SimpleCNN 
+
+    # Create an instance of the model
+    model = SimpleCNN()
+
+    # Define the loss function and optimizer
+    criterion = nn.CrossEntropyLoss()
+
+    # Use Adam optimizer
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+    # Train the model
+    num_epochs = 5
+    for epoch in range(num_epochs):
+        running_loss = 0.0
+        for i, data in enumerate(train_loader):
+            # Get the inputs
+            inputs, labels = data
+
+            # Zero the parameter gradients
+            optimizer.zero_grad()
+
+            # Forward + backward + optimize
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            # Print statistics: running average loss over the epoch so far
+            running_loss += loss.item()
+            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i + 1)))
+
+
+    # save the model
+    torch.save(model.state_dict(), 'simple_cnn.pth')
+
+
+############################## this section evaluates the model ##############################
+# load the model
+import torch
+
+from models.simple_cnn import SimpleCNN
+
+model = SimpleCNN()
+model.load_state_dict(torch.load('simple_cnn.pth'))
+model.eval()  # switch to evaluation mode before running inference
+
+# run the model on the test set and print the accuracy
+correct = 0
+total = 0
+
+with torch.no_grad():
+    for data in test_loader:
+        images, labels = data
+        outputs = model(images)
+        _, predicted = torch.max(outputs.data, dim=1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
+
+
+############################## this section plots the confusion matrix ##############################
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+from sklearn.metrics import confusion_matrix
+
+# Get the predictions for the test data
+y_pred = []
+y_true = []
+
+with torch.no_grad():
+    for data in test_loader:
+        images, labels = data
+        outputs = model(images)
+        _, predicted = torch.max(outputs.data, dim=1)
+        y_pred += predicted.tolist()
+        y_true += labels.tolist()
+
+# Get the confusion matrix
+cm = confusion_matrix(y_true, y_pred)
+
+# Plot the confusion matrix
+plt.figure(figsize=(10, 10))
+sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
+plt.xlabel('Predicted label')
+plt.ylabel('True label')
+# save the confusion matrix
+plt.savefig('confusion_matrix.png')
+
+############################## this section plots the ROC curve ##############################
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import roc_curve, auc
+
+# Get the class scores for the test data (an ROC curve needs continuous scores, not hard labels)
+y_score = []
+y_true = []
+
+with torch.no_grad():
+    for data in test_loader:
+        images, labels = data
+        outputs = model(images)
+        # softmax probability of the positive class (class index 1)
+        y_score += torch.softmax(outputs, dim=1)[:, 1].tolist()
+        y_true += labels.tolist()
+
+# Get the ROC curve
+fpr, tpr, _ = roc_curve(y_true, y_score)
+roc_auc = auc(fpr, tpr)
+
+# Plot the ROC curve
+plt.figure(figsize=(10, 10))
+plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
+plt.plot([0, 1], [0, 1], 'k--')  # Add a diagonal line for reference
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC Curve')
+plt.legend(loc="lower right")
+# save the ROC curve
+plt.savefig('roc_curve.png')
+
+
+############################## this section computes precision, recall and F1-score ##############################
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import precision_recall_fscore_support
+
+# Get the predictions for the test data
+y_pred = []
+y_true = []
+
+with torch.no_grad():
+    for data in test_loader:
+        images, labels = data
+        outputs = model(images)
+        _, predicted = torch.max(outputs.data, dim=1)
+        y_pred += predicted.tolist()
+        y_true += labels.tolist()
+
+# Get the precision, recall, and F1-score
+precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred)
+
+# Plot the precision, recall, and F1-score as a bar plot
+plt.figure(figsize=(10, 10))
+x = np.arange(len(precision))
+width = 0.2
+plt.bar(x, precision, width, label='Precision')
+plt.bar(x + width, recall, width, label='Recall')
+plt.bar(x + 2 * width, f1_score, width, label='F1-score')
+plt.xlabel('Class')
+plt.ylabel('Metric')
+plt.title('Precision, Recall, and F1-score')
+plt.xticks(x + width, range(len(precision)))
+plt.legend()
+# save the precision, recall, and F1-score
+plt.savefig('precision_recall_f1_score.png')
+
+############################## this section shows the activations after each layer of the model ############################
+
+import matplotlib.pyplot as plt
+import os
+
+def save_activations(activations, save_dir):
+    for name, act in activations.items():
+        num_features = act.size(1)
+        for i in range(num_features):
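+            # plot the activation of the first image in the batch for channel i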
+            plt.figure()
+            plt.imshow(act[0, i].detach().numpy(), cmap='hot')
+            plt.axis('off')
+            
+            # Save each channel's activation with a proper file name
+            filename = f"{save_dir}/{name}_channel_{i}.png"
+            plt.savefig(filename, bbox_inches='tight', pad_inches=0)
+            plt.close()  # Close the plot to free up memory
+
+
+# Assuming 'images' is a batch of images
+# And 'model' is an instance of SimpleCNN
+# create a folder to save the activations
+save_dir = 'activations'
+os.makedirs(save_dir, exist_ok=True)
+
+outputs, activations = model(images, return_activations=True)
+save_activations(activations, save_dir)
diff --git a/prepare_data/clean_file_names.py b/prepare_data/clean_file_names.py
index 42096689990569bd5552c9c6900498d5c184c121..684a88a15fd34ecf4f098d07dc933beff59668c8 100644
--- a/prepare_data/clean_file_names.py
+++ b/prepare_data/clean_file_names.py
@@ -5,7 +5,7 @@ import re
 
 def clean_file_names(path):
     """
-    Clean file names in a directory. This function will replace all spaces with underscores,
+    Clean file names in a directory and its subfolders. This function will replace all spaces with underscores,
     replace all dashes with underscores, and change all file names to lowercase. If there are
     numbers in brackets, they will be replaced with an underscore and the number.
 
@@ -20,20 +20,21 @@ def clean_file_names(path):
 
     """
 
-    for filename in os.listdir(path):
-        if filename.lower().endswith((".png", ".jpg")):
-            # replace all spaces with underscores
-            new_filename = re.sub(r"\s+", "_", filename)
-            # replace all dashes with underscores
-            new_filename = re.sub(r"-", "_", new_filename)
-            # if there are numbers in bruckets, change to underscore number
-            new_filename = re.sub(r"\(\d+\)", lambda x: "_" + x.group()[1:-1], new_filename)
-            print(new_filename)
-            # rename file to new filename and change to lowercase
-            os.rename(
-                os.path.join(path, filename), os.path.join(path, new_filename.lower())
-            )
+    for root, dirs, files in os.walk(path):
+        for filename in files:
+            if filename.lower().endswith((".png", ".jpg")):
+                # replace all spaces with underscores
+                new_filename = re.sub(r"\s+", "_", filename)
+                # replace all dashes with underscores
+                new_filename = re.sub(r"-", "_", new_filename)
+                # if there are numbers in brackets, change to underscore number
+                new_filename = re.sub(r"\(\d+\)", lambda x: "_" + x.group()[1:-1], new_filename)
+                # print(new_filename.lower())
+
+                # rename file to new filename and change to lowercase
+                os.rename(
+                    os.path.join(root, filename), os.path.join(root, new_filename.lower())
+                )
 
 if __name__ == "__main__":
-
     clean_file_names(sys.argv[1])
\ No newline at end of file
diff --git a/prepare_data/preparare_train_val_test.py b/prepare_data/prepare_train_val_test.py
similarity index 100%
rename from prepare_data/preparare_train_val_test.py
rename to prepare_data/prepare_train_val_test.py
diff --git a/visualization/__init__.py b/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/visualization/show_sample_images.py b/visualization/show_sample_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..7975ef1d329762ebd215838b7add21d63a38bd4f
--- /dev/null
+++ b/visualization/show_sample_images.py
@@ -0,0 +1,42 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torchvision
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+import torch
+
+def show_sample_images(data_path, save_path=None):
+    # Function to show an image with labels
+    def imshow(img, labels, classes, save_path=None):
+        img = img / 2 + 0.5  # unnormalize
+        npimg = img.numpy()
+        plt.imshow(np.transpose(npimg, (1, 2, 0)))
+        # Display labels below the image
+        plt.xticks([])  # Remove x-axis ticks
+        plt.yticks([])  # Remove y-axis ticks
+        plt.xlabel(' - '.join('%5s' % classes[label] for label in labels), fontsize=10)
+        if save_path:
+            plt.savefig(save_path)
+        else:
+            plt.show()
+
+    # Define transformations
+    transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])
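+    # Normalize((0.5,), (0.5,)) pairs with the img / 2 + 0.5 unnormalization in imshow above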
+
+    # Create the train_dataset and train_loader
+    train_dataset = datasets.ImageFolder(root=data_path + '/train', transform=transform)
+    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
+
+    # Get some random training images
+    dataiter = iter(train_loader)
+    images, labels = next(dataiter)
+
+    # Show images with labels and optionally save the image
+    imshow(torchvision.utils.make_grid(images), labels, train_dataset.classes, save_path)
+
+# Example usage:
+# show_sample_images('/path/to/your/data', 'output_image.png')