Building an Image Classification Model with Transfer Learning in TensorFlow

Kshitij Kutumbe
3 min read · Dec 3, 2023


Introduction

In this blog, we’ll work through a hands-on machine learning project: building an image classification model using transfer learning in TensorFlow. Our goal is to classify images from a publicly available dataset, leveraging a pre-trained model to get good results with relatively little training data and compute. It’s a practical way to get comfortable with neural networks and TensorFlow.

Why Image Classification?

Image classification is a foundational task in computer vision, with applications ranging from facial recognition to medical imaging analysis. It’s an excellent entry point into the world of AI and a fantastic way to learn about neural networks.

Why Transfer Learning?

Transfer learning involves taking a pre-trained neural network and adapting it to a new, but similar, task. It’s incredibly efficient, reducing the need for large datasets and extensive computing resources.

Project Overview

We’ll use the TensorFlow library to implement our image classifier. Our dataset will be the CIFAR-10 dataset, a well-known set of 60,000 32x32 color images in 10 classes, with 6,000 images per class.

Tools and Libraries

  • Python: Our primary programming language.
  • TensorFlow: An open-source machine learning library.
  • Keras: A high-level neural networks API, running on top of TensorFlow.

Step 1: Environment Setup

Ensure you have Python installed on your system. Then, install TensorFlow:

pip install tensorflow
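
Optionally, verify the installation and check whether a GPU is visible. This is purely a sanity check and not required for the rest of the tutorial:

import tensorflow as tf

print(tf.__version__)                           # e.g. 2.x
print(tf.config.list_physical_devices('GPU'))   # an empty list means CPU-only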

Step 2: Import Necessary Libraries

In your Python environment, start by importing the necessary libraries:

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.applications import MobileNetV2
import numpy as np
import matplotlib.pyplot as plt

Step 3: Load and Preprocess the Dataset

The CIFAR-10 dataset is conveniently included in Keras:

(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0
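
Scaling to [0, 1] keeps things simple, and it’s what we’ll use for the rest of this post. One optional refinement: MobileNetV2’s ImageNet weights were trained with inputs scaled to [-1, 1], and Keras ships a matching preprocess_input helper. If you want to mirror that preprocessing, apply it to the raw pixel values instead of the /255 step (a sketch, not required for the tutorial):

from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# preprocess_input expects raw pixels in [0, 255] and rescales them to [-1, 1],
# matching how MobileNetV2's ImageNet weights were originally trained.
(raw_train, _), (raw_test, _) = cifar10.load_data()
train_images_mnv2 = preprocess_input(raw_train.astype('float32'))
test_images_mnv2 = preprocess_input(raw_test.astype('float32'))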

Step 4: Visualize the Data

It’s always a good idea to understand the data you’re working with:

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()
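
It’s also worth confirming the array shapes and class balance before training (optional):

# CIFAR-10 ships 50,000 training and 10,000 test images; labels have shape (N, 1).
print(train_images.shape, train_labels.shape)   # (50000, 32, 32, 3) (50000, 1)
print(test_images.shape, test_labels.shape)     # (10000, 32, 32, 3) (10000, 1)
print(np.unique(train_labels, return_counts=True))  # 5,000 training images per class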

Step 5: Using a Pre-trained Model

We’ll use MobileNetV2, a lightweight, efficient model designed for mobile vision applications. Note that 32x32 is the smallest input size Keras accepts for MobileNetV2, so you’ll see a warning that the default ImageNet weights (trained at 224x224) are being loaded; the frozen convolutional features still transfer, but accuracy will be limited at this low resolution:

base_model = MobileNetV2(input_shape=(32, 32, 3), include_top=False, weights='imagenet')
base_model.trainable = False # Freeze the base model

# Add our own classifier on top
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(10, activation='softmax')
])
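
As noted above, MobileNetV2 downsamples its input by a factor of 32, so a 32x32 image is reduced to a single 1x1 feature map before the pooling layer. One optional variation (a sketch, not part of the main recipe) is to upsample the images inside the model with a Resizing layer (layers.Resizing, available in TF 2.6+), so the backbone sees a resolution closer to what it was pre-trained on:

# Optional variation: upsample CIFAR-10 images to 96x96 before the frozen backbone.
# The 96x96 size is an assumption; any resolution MobileNetV2 supports
# (96, 128, 160, 192, 224) would work, at a higher compute cost.
base_model_96 = MobileNetV2(input_shape=(96, 96, 3), include_top=False, weights='imagenet')
base_model_96.trainable = False

model_resized = models.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Resizing(96, 96),          # bilinear upsampling inside the model
    base_model_96,
    layers.GlobalAveragePooling2D(),
    layers.Dense(10, activation='softmax')
])

# Compile and train it exactly as in Step 6.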

Step 6: Compile and Train the Model

Now, we compile and train our model:

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
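
Ten epochs is a reasonable starting point. If you’d rather let training stop automatically once validation accuracy plateaus, you can optionally pass an EarlyStopping callback to fit (a sketch; the epoch budget of 30 is arbitrary):

# Optional: stop training when validation accuracy hasn't improved for 3 epochs,
# and restore the best weights seen so far.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=3, restore_best_weights=True)

history = model.fit(train_images, train_labels, epochs=30,
                    validation_data=(test_images, test_labels),
                    callbacks=[early_stop])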

Step 7: Evaluate the Model

After training, we evaluate the model’s performance:

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Test accuracy: {test_acc}")

Step 8: Plotting Performance Metrics

Visualize the training process:

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
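
You can plot the loss curves the same way; a gap that keeps widening between training and validation loss is a typical sign of overfitting:

plt.plot(history.history['loss'], label='training loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.show()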

Conclusion

Congratulations! You’ve just built an image classification model using transfer learning in TensorFlow. This project highlights the power of leveraging pre-trained models to create efficient, effective AI solutions.

Stay tuned for more exciting machine learning projects and deep dives into AI concepts. Happy coding!

Written by Kshitij Kutumbe

Data Scientist | NLP | GenAI | RAG | AI agents | Knowledge Graph | Neo4j
kshitijkutumbe@gmail.com · www.linkedin.com/in/kshitijkutumbe/