|
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 2199 images belonging to 2 classes.\n",
- "Found 549 images belonging to 2 classes.\n",
- "Training samples: 2199\n",
- "Validation samples: 549\n"
- ]
- }
- ],
- "source": [
- "import tensorflow as tf\n",
- "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
- "\n",
- "dataset_dir = 'C:/Users/vsavelev/GITHUB/DS_projet/jan24_cds_mushrooms/data'\n",
- "\n",
- "# Create ImageDataGenerator with validation split\n",
- "datagen = ImageDataGenerator(rescale=1.0/255, validation_split=0.2)\n",
- "\n",
- "train_generator = datagen.flow_from_directory(\n",
- " dataset_dir,\n",
- " target_size=(224, 224),\n",
- " batch_size=32,\n",
- " class_mode='categorical',\n",
- " subset='training' # Set as training data\n",
- ")\n",
- "\n",
- "validation_generator = datagen.flow_from_directory(\n",
- " dataset_dir,\n",
- " target_size=(224, 224),\n",
- " batch_size=32,\n",
- " class_mode='categorical',\n",
- " subset='validation' # Set as validation data\n",
- ")\n",
- "\n",
- "print(f'Training samples: {train_generator.samples}')\n",
- "print(f'Validation samples: {validation_generator.samples}')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"sequential\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ resnet50 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Functional</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">23,587,712</span> │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ global_average_pooling2d │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
- "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePooling2D</span>) │ │ │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dropout (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
- "</pre>\n"
- ],
- "text/plain": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ resnet50 (\u001b[38;5;33mFunctional\u001b[0m) │ ? │ \u001b[38;5;34m23,587,712\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ global_average_pooling2d │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
- "│ (\u001b[38;5;33mGlobalAveragePooling2D\u001b[0m) │ │ │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">23,587,712</span> (89.98 MB)\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m23,587,712\u001b[0m (89.98 MB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">23,587,712</span> (89.98 MB)\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m23,587,712\u001b[0m (89.98 MB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "from tensorflow.keras.applications import ResNet50\n",
- "from tensorflow.keras import layers, models\n",
- "\n",
- "# Load and Configure the Pre-trained ResNet50 Model\n",
- "base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
- "\"\"\"\n",
- "weights='imagenet': Loads the pre-trained weights from the ImageNet dataset.\n",
- "include_top=False: Excludes the top fully-connected layers of the ResNet50 model, enabling you to add your own custom layers.\n",
- "input_shape=(224, 224, 3): Specifies the input shape of the images (224x224 pixels, with 3 color channels - RGB).\n",
- "\"\"\"\n",
- "\n",
- "# Freeze the base model (to freeze the pre-trained layers)\n",
- "base_model.trainable = False\n",
- "\n",
- "# Add custom layers on top of the base model\n",
- "model = models.Sequential([ #allows to stack layers linearly\n",
- " base_model,\n",
- " layers.GlobalAveragePooling2D(),\n",
- " layers.Dense(1024, activation='relu'),\n",
- " layers.Dropout(0.5),\n",
- " layers.Dense(train_generator.num_classes, activation='softmax')\n",
- "])\n",
- "\n",
- "\"\"\"\n",
- "GlobalAveragePooling2D(): Reduces each feature map to a single number by taking the average, \n",
- "which helps to reduce the size of the model and prevent overfitting.\n",
- "Dense(1024, activation='relu'): Adds a fully connected layer with 1024 units and ReLU activation function.\n",
- "Dropout(0.5): Adds a dropout layer with a 50% dropout rate to prevent overfitting by randomly setting half of the input units \n",
- "to 0 at each update during training.\n",
- "Dense(train_generator.num_classes, activation='softmax'): \n",
- "Adds the final output layer with units equal to the number of classes in your dataset, using the softmax activation function for multi-class classification.\n",
- "\"\"\"\n",
- "model.compile(optimizer=tf.keras.optimizers.Adam(),\n",
- " loss='categorical_crossentropy',\n",
- " metrics=['accuracy'])\n",
- "\n",
- "\"\"\"\n",
- "optimizer=tf.keras.optimizers.Adam(): Uses the Adam optimizer, which is an adaptive learning rate optimization algorithm.\n",
- "loss='categorical_crossentropy': Uses categorical cross-entropy as the loss function, suitable for multi-class classification.\n",
- "metrics=['accuracy']: Tracks accuracy as the metric to evaluate the model's performance during training and testing.\n",
- "\"\"\"\n",
- "model.summary()\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 1/10\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "c:\\Users\\vsavelev\\AppData\\Local\\anaconda3\\Lib\\site-packages\\keras\\src\\trainers\\data_adapters\\py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
- " self._warn_if_super_not_called()\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m135s\u001b[0m 2s/step - accuracy: 0.8657 - loss: 0.5402 - val_accuracy: 0.8860 - val_loss: 0.3811\n",
- "Epoch 2/10\n",
- "\u001b[1m 1/68\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1:36\u001b[0m 1s/step - accuracy: 0.8125 - loss: 0.5545"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "c:\\Users\\vsavelev\\AppData\\Local\\anaconda3\\Lib\\contextlib.py:158: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.\n",
- " self.gen.throw(typ, value, traceback)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 7ms/step - accuracy: 0.8125 - loss: 0.5545 - val_accuracy: 1.0000 - val_loss: 0.2601\n",
- "Epoch 3/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m147s\u001b[0m 2s/step - accuracy: 0.8812 - loss: 0.3830 - val_accuracy: 0.8860 - val_loss: 0.3604\n",
- "Epoch 4/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.9375 - loss: 0.2276 - val_accuracy: 1.0000 - val_loss: 0.1036\n",
- "Epoch 5/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m186s\u001b[0m 3s/step - accuracy: 0.8776 - loss: 0.3826 - val_accuracy: 0.8860 - val_loss: 0.3568\n",
- "Epoch 6/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 9ms/step - accuracy: 0.9688 - loss: 0.1941 - val_accuracy: 1.0000 - val_loss: 0.1264\n",
- "Epoch 7/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m199s\u001b[0m 3s/step - accuracy: 0.8844 - loss: 0.3665 - val_accuracy: 0.8879 - val_loss: 0.3532\n",
- "Epoch 8/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 10ms/step - accuracy: 0.9062 - loss: 0.2978 - val_accuracy: 0.8000 - val_loss: 0.5479\n",
- "Epoch 9/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m190s\u001b[0m 3s/step - accuracy: 0.8929 - loss: 0.3426 - val_accuracy: 0.8879 - val_loss: 0.3531\n",
- "Epoch 10/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 10ms/step - accuracy: 0.8125 - loss: 0.4996 - val_accuracy: 0.8000 - val_loss: 0.5142\n"
- ]
- }
- ],
- "source": [
- "history = model.fit(\n",
- " train_generator,\n",
- " steps_per_epoch=train_generator.samples // train_generator.batch_size,\n",
- " validation_data=validation_generator,\n",
- " validation_steps=validation_generator.samples // validation_generator.batch_size,\n",
- " epochs=10\n",
- ")\n",
- "\n",
- "#This specifies the number of complete passes through the training dataset. Here, the model will train for 10 epochs."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 1/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m330s\u001b[0m 4s/step - accuracy: 0.8855 - loss: 0.4046 - val_accuracy: 0.8879 - val_loss: 0.3791\n",
- "Epoch 2/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 10ms/step - accuracy: 0.8750 - loss: 0.3539 - val_accuracy: 0.8000 - val_loss: 0.5808\n",
- "Epoch 3/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m289s\u001b[0m 4s/step - accuracy: 0.8960 - loss: 0.2870 - val_accuracy: 0.8860 - val_loss: 0.3903\n",
- "Epoch 4/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 7ms/step - accuracy: 0.9688 - loss: 0.1397 - val_accuracy: 1.0000 - val_loss: 0.2601\n",
- "Epoch 5/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m283s\u001b[0m 4s/step - accuracy: 0.9162 - loss: 0.2052 - val_accuracy: 0.8787 - val_loss: 0.4098\n",
- "Epoch 6/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - accuracy: 0.8438 - loss: 0.3068 - val_accuracy: 0.6000 - val_loss: 0.7204\n",
- "Epoch 7/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m290s\u001b[0m 4s/step - accuracy: 0.9429 - loss: 0.1480 - val_accuracy: 0.8382 - val_loss: 0.4598\n",
- "Epoch 8/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 6ms/step - accuracy: 0.9062 - loss: 0.1594 - val_accuracy: 1.0000 - val_loss: 0.2145\n",
- "Epoch 9/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m299s\u001b[0m 4s/step - accuracy: 0.9658 - loss: 0.1009 - val_accuracy: 0.8548 - val_loss: 0.4892\n",
- "Epoch 10/10\n",
- "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 6ms/step - accuracy: 1.0000 - loss: 0.0651 - val_accuracy: 0.8000 - val_loss: 0.9206\n"
- ]
- }
- ],
- "source": [
- "# Unfreeze some layers\n",
- "base_model.trainable = True\n",
- "fine_tune_at = 100 # fine-tune from this layer onwards\n",
- "\n",
- "for layer in base_model.layers[:fine_tune_at]:\n",
- " layer.trainable = False\n",
- "\n",
- "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),\n",
- " loss='categorical_crossentropy', #The categorical cross-entropy loss function is used because this is a multi-class classification problem\n",
- " metrics=['accuracy'])\n",
- "\n",
- "\"\"\"\n",
- "The Adam optimizer is used with a very small learning rate (1e-5). Fine-tuning typically \n",
- "uses a smaller learning rate to prevent large updates to the weights, which could potentially destroy the learned features in the pre-trained model.\n",
- "\"\"\"\n",
- "history_fine = model.fit(\n",
- " train_generator,\n",
- " steps_per_epoch=train_generator.samples // train_generator.batch_size,\n",
- " validation_data=validation_generator,\n",
- " validation_steps=validation_generator.samples // validation_generator.batch_size,\n",
- " epochs=10\n",
- ")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m36s\u001b[0m 2s/step - accuracy: 0.8559 - loss: 0.4854\n",
- "Validation loss: 0.4950728416442871\n",
- "Validation accuracy: 0.8561019897460938\n"
- ]
- }
- ],
- "source": [
- "loss, accuracy = model.evaluate(validation_generator)\n",
- "print(f'Validation loss: {loss}')\n",
- "print(f'Validation accuracy: {accuracy}')\n"
- ]
- }
- ],
- "metadata": {
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|