import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import ResNet50, MobileNetV2
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, GlobalAveragePooling2D
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load Oxford Flowers 102 dataset as (image, label) pairs, plus its metadata.
(dataset_train, dataset_test), dataset_info = tfds.load(
    'oxford_flowers102',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
)


# Dataset preprocessing
def preprocess(image, label, model_type='resnet'):
    """Resize an image to 224x224 and apply backbone-specific preprocessing.

    Args:
        image: Image tensor from the dataset.
        label: Label tensor; returned unchanged.
        model_type: 'resnet' to apply ResNet50's ``preprocess_input`` or
            'mobilenet' to apply MobileNetV2's ``preprocess_input``.

    Returns:
        Tuple ``(image, label)`` with the resized, preprocessed image.

    Raises:
        ValueError: If ``model_type`` is not 'resnet' or 'mobilenet'.
    """
    image = tf.image.resize(image, (224, 224))  # Resize for input to models
    if model_type == 'resnet':
        image = resnet_preprocess(image)  # Preprocess for ResNet
    elif model_type == 'mobilenet':
        image = mobilenet_preprocess(image)  # Preprocess for MobileNet
    else:
        # Fail loudly: an unrecognized type would otherwise feed raw,
        # un-normalized pixels to a pretrained backbone without any error.
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    return image, label


# Apply preprocessing
batch_size = 32
# Shuffle buffer for the training split; tf.data reshuffles on every epoch
# by default, so training batches no longer arrive in a fixed order.
shuffle_buffer = 1000

train_dataset = (
    dataset_train
    .shuffle(shuffle_buffer)
    .map(lambda x, y: preprocess(x, y, model_type='resnet'),
         num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# The test split is deliberately NOT shuffled so evaluation/visualization
# order is stable (parallel map still preserves element order).
test_dataset = (
    dataset_test
    .map(lambda x, y: preprocess(x, y, model_type='resnet'),
         num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Display a few raw (un-preprocessed) training samples.
label_names = dataset_info.features['label'].names
fig, axes = plt.subplots(1, 5, figsize=(15, 5))
for ax, (image, label) in zip(axes, dataset_train.take(5)):
    ax.imshow(image.numpy())
    # label is a scalar tensor: convert to int so the title shows the class id
    # and name rather than the tensor's repr.
    label_id = int(label)
    ax.set_title(f"Label: {label_id} ({label_names[label_id]})")
    ax.axis('off')
plt.tight_layout()
plt.show()
# Downloading and preparing dataset 328.90 MiB (download: 328.90 MiB, generated: 331.34 MiB, total: 660.25 MiB) to /root/tensorflow_datasets/oxford_flowers102/2.1.1...
# Dataset oxford_flowers102 downloaded and prepared to /root/tensorflow_datasets/oxford_flowers102/2.1.1. Subsequent calls will reuse this data.
from tensorflow.keras.optimizers import Adam

# Number of output classes, read from the dataset metadata rather than
# hard-coded, so the head always matches the loaded dataset.
num_classes = dataset_info.features['label'].num_classes

# Load ResNet50 pre-trained on ImageNet, without its classification head.
resnet_base = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# Freeze base layers so only the newly added head is trained.
resnet_base.trainable = False

# Add custom layers
resnet_model = Sequential([
    resnet_base,
    GlobalAveragePooling2D(),  # Reduces feature maps to a single vector
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),  # One unit per flower category
])

# Compile the model. Labels are fed as integer class ids (no one-hot
# encoding is applied in the input pipeline), hence the sparse loss.
resnet_model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the model.
# NOTE(review): validation_data is the *test* split, so the reported val_*
# metrics are test-set numbers; consider tfds' 'validation' split if these
# metrics are used for any tuning decisions.
resnet_history = resnet_model.fit(train_dataset, epochs=10, validation_data=test_dataset)
import numpy as np
import matplotlib.pyplot as plt

# Get class names from the dataset info
class_names = dataset_info.features['label'].names

# Per-channel means, in BGR order, that ResNet50's 'caffe'-style
# preprocess_input subtracts after converting RGB -> BGR.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def _undo_resnet_preprocess(images):
    """Map ResNet50-preprocessed images back to displayable uint8 RGB.

    ``resnet50.preprocess_input`` flips channels to BGR and subtracts the
    ImageNet channel means, leaving negative values. Casting that directly to
    uint8 wraps around and shows swapped colours; this reverses both steps.

    Args:
        images: Array-like of preprocessed images with channels last.

    Returns:
        A uint8 numpy array of RGB images with values clipped to [0, 255].
    """
    bgr = np.asarray(images, dtype=np.float32) + IMAGENET_BGR_MEAN
    rgb = bgr[..., ::-1]  # BGR -> RGB
    return np.clip(rgb, 0, 255).astype('uint8')


# Select a batch of images and labels from the test dataset
test_batch = next(iter(test_dataset))
test_images, true_labels = test_batch

# Get predictions for the batch
predictions = resnet_model.predict(test_images)
predicted_labels = np.argmax(predictions, axis=1)

# test_images hold model inputs, not pixels; convert them back for display.
display_images = _undo_resnet_preprocess(test_images)

# Display the first 5 images in the batch
fig, axes = plt.subplots(1, 5, figsize=(20, 5))
for i, ax in enumerate(axes):
    # Display the image
    ax.imshow(display_images[i])
    # Set the title with true and predicted labels (int() turns the scalar
    # label tensor into a plain list index).
    true_label = class_names[int(true_labels[i])]
    predicted_label = class_names[int(predicted_labels[i])]
    ax.set_title(f"True: {true_label}\nPredicted: {predicted_label}", fontsize=12)
    ax.axis('off')  # Hide axes for better visualization
plt.tight_layout()
plt.show()