You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

404 lines
25 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stdout",
  10. "output_type": "stream",
  11. "text": [
  12. "Found 2199 images belonging to 2 classes.\n",
  13. "Found 549 images belonging to 2 classes.\n",
  14. "Training samples: 2199\n",
  15. "Validation samples: 549\n"
  16. ]
  17. }
  18. ],
  19. "source": [
  20. "import tensorflow as tf\n",
  21. "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
  22. "\n",
  23. "dataset_dir = 'C:/Users/vsavelev/GITHUB/DS_projet/jan24_cds_mushrooms/data'\n",
  24. "\n",
  25. "# Create ImageDataGenerator with validation split\n",
  26. "datagen = ImageDataGenerator(rescale=1.0/255, validation_split=0.2)\n",
  27. "\n",
  28. "train_generator = datagen.flow_from_directory(\n",
  29. " dataset_dir,\n",
  30. " target_size=(224, 224),\n",
  31. " batch_size=32,\n",
  32. " class_mode='categorical',\n",
  33. " subset='training' # Set as training data\n",
  34. ")\n",
  35. "\n",
  36. "validation_generator = datagen.flow_from_directory(\n",
  37. " dataset_dir,\n",
  38. " target_size=(224, 224),\n",
  39. " batch_size=32,\n",
  40. " class_mode='categorical',\n",
  41. " subset='validation' # Set as validation data\n",
  42. ")\n",
  43. "\n",
  44. "print(f'Training samples: {train_generator.samples}')\n",
  45. "print(f'Validation samples: {validation_generator.samples}')"
  46. ]
  47. },
  48. {
  49. "cell_type": "code",
  50. "execution_count": 2,
  51. "metadata": {},
  52. "outputs": [
  53. {
  54. "data": {
  55. "text/html": [
  56. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
  57. "</pre>\n"
  58. ],
  59. "text/plain": [
  60. "\u001b[1mModel: \"sequential\"\u001b[0m\n"
  61. ]
  62. },
  63. "metadata": {},
  64. "output_type": "display_data"
  65. },
  66. {
  67. "data": {
  68. "text/html": [
  69. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
  70. "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
  71. "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
  72. "│ resnet50 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Functional</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">23,587,712</span> │\n",
  73. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  74. "│ global_average_pooling2d │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
  75. "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePooling2D</span>) │ │ │\n",
  76. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  77. "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
  78. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  79. "│ dropout (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
  80. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  81. "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ ? │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (unbuilt) │\n",
  82. "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
  83. "</pre>\n"
  84. ],
  85. "text/plain": [
  86. "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
  87. "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
  88. "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
  89. "│ resnet50 (\u001b[38;5;33mFunctional\u001b[0m) │ ? │ \u001b[38;5;34m23,587,712\u001b[0m │\n",
  90. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  91. "│ global_average_pooling2d │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
  92. "│ (\u001b[38;5;33mGlobalAveragePooling2D\u001b[0m) │ │ │\n",
  93. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  94. "│ dense (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
  95. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  96. "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
  97. "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
  98. "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n",
  99. "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
  100. ]
  101. },
  102. "metadata": {},
  103. "output_type": "display_data"
  104. },
  105. {
  106. "data": {
  107. "text/html": [
  108. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">23,587,712</span> (89.98 MB)\n",
  109. "</pre>\n"
  110. ],
  111. "text/plain": [
  112. "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m23,587,712\u001b[0m (89.98 MB)\n"
  113. ]
  114. },
  115. "metadata": {},
  116. "output_type": "display_data"
  117. },
  118. {
  119. "data": {
  120. "text/html": [
  121. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
  122. "</pre>\n"
  123. ],
  124. "text/plain": [
  125. "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
  126. ]
  127. },
  128. "metadata": {},
  129. "output_type": "display_data"
  130. },
  131. {
  132. "data": {
  133. "text/html": [
  134. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">23,587,712</span> (89.98 MB)\n",
  135. "</pre>\n"
  136. ],
  137. "text/plain": [
  138. "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m23,587,712\u001b[0m (89.98 MB)\n"
  139. ]
  140. },
  141. "metadata": {},
  142. "output_type": "display_data"
  143. }
  144. ],
  145. "source": [
  146. "from tensorflow.keras.applications import ResNet50\n",
  147. "from tensorflow.keras import layers, models\n",
  148. "\n",
  149. "# Load and Configure the Pre-trained ResNet50 Model\n",
  150. "base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
  151. "\"\"\"\n",
  152. "weights='imagenet': Loads the pre-trained weights from the ImageNet dataset.\n",
  153. "include_top=False: Excludes the top fully-connected layers of the ResNet50 model, enabling you to add your own custom layers.\n",
  154. "input_shape=(224, 224, 3): Specifies the input shape of the images (224x224 pixels, with 3 color channels - RGB).\n",
  155. "\"\"\"\n",
  156. "\n",
  157. "# Freeze the base model (to freeze the pre-trained layers)\n",
  158. "base_model.trainable = False\n",
  159. "\n",
  160. "# Add custom layers on top of the base model\n",
  161. "model = models.Sequential([ #allows to stack layers linearly\n",
  162. " base_model,\n",
  163. " layers.GlobalAveragePooling2D(),\n",
  164. " layers.Dense(1024, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01)), # L2 regularization\n",
  165. " layers.Dropout(0.5),\n",
  166. " layers.Dense(train_generator.num_classes, activation='softmax')\n",
  167. "])\n",
  168. "\n",
  169. "\"\"\"\n",
  170. "GlobalAveragePooling2D(): Reduces each feature map to a single number by taking the average, \n",
  171. "which helps to reduce the size of the model and prevent overfitting.\n",
  172. "Dense(1024, activation='relu'): Adds a fully connected layer with 1024 units and ReLU activation function.\n",
  173. "Dropout(0.5): Adds a dropout layer with a 50% dropout rate to prevent overfitting by randomly setting half of the input units \n",
  174. "to 0 at each update during training.\n",
  175. "Dense(train_generator.num_classes, activation='softmax'): \n",
  176. "Adds the final output layer with units equal to the number of classes in your dataset, using the softmax activation function for multi-class classification.\n",
  177. "\"\"\"\n",
  178. "model.compile(optimizer=tf.keras.optimizers.Adam(),\n",
  179. " loss='categorical_crossentropy',\n",
  180. " metrics=['accuracy'])\n",
  181. "\n",
  182. "\"\"\"\n",
  183. "optimizer=tf.keras.optimizers.Adam(): Uses the Adam optimizer, which is an adaptive learning rate optimization algorithm.\n",
  184. "loss='categorical_crossentropy': Uses categorical cross-entropy as the loss function, suitable for multi-class classification.\n",
  185. "metrics=['accuracy']: Tracks accuracy as the metric to evaluate the model's performance during training and testing.\n",
  186. "\"\"\"\n",
  187. "model.summary()\n",
  188. "\n"
  189. ]
  190. },
  191. {
  192. "cell_type": "code",
  193. "execution_count": 3,
  194. "metadata": {},
  195. "outputs": [
  196. {
  197. "name": "stdout",
  198. "output_type": "stream",
  199. "text": [
  200. "Epoch 1/10\n"
  201. ]
  202. },
  203. {
  204. "name": "stderr",
  205. "output_type": "stream",
  206. "text": [
  207. "c:\\Users\\vsavelev\\AppData\\Local\\anaconda3\\Lib\\site-packages\\keras\\src\\trainers\\data_adapters\\py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
  208. " self._warn_if_super_not_called()\n"
  209. ]
  210. },
  211. {
  212. "name": "stdout",
  213. "output_type": "stream",
  214. "text": [
  215. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 2s/step - accuracy: 0.8540 - loss: 6.8043 - val_accuracy: 0.8860 - val_loss: 0.6636\n",
  216. "Epoch 2/10\n",
  217. "\u001b[1m 1/68\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m2:09\u001b[0m 2s/step - accuracy: 0.8438 - loss: 0.8550"
  218. ]
  219. },
  220. {
  221. "name": "stderr",
  222. "output_type": "stream",
  223. "text": [
  224. "c:\\Users\\vsavelev\\AppData\\Local\\anaconda3\\Lib\\contextlib.py:158: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.\n",
  225. " self.gen.throw(typ, value, traceback)\n"
  226. ]
  227. },
  228. {
  229. "name": "stdout",
  230. "output_type": "stream",
  231. "text": [
  232. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 11ms/step - accuracy: 0.8438 - loss: 0.8550 - val_accuracy: 1.0000 - val_loss: 0.4526\n",
  233. "Epoch 3/10\n",
  234. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m146s\u001b[0m 2s/step - accuracy: 0.8873 - loss: 0.6507 - val_accuracy: 0.8879 - val_loss: 0.4938\n",
  235. "Epoch 4/10\n",
  236. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 6ms/step - accuracy: 0.8750 - loss: 0.5112 - val_accuracy: 0.8000 - val_loss: 0.6924\n",
  237. "Epoch 5/10\n",
  238. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m149s\u001b[0m 2s/step - accuracy: 0.8857 - loss: 0.5040 - val_accuracy: 0.8879 - val_loss: 0.4471\n",
  239. "Epoch 6/10\n",
  240. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 6ms/step - accuracy: 0.8750 - loss: 0.4322 - val_accuracy: 0.8000 - val_loss: 0.6413\n",
  241. "Epoch 7/10\n",
  242. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m166s\u001b[0m 2s/step - accuracy: 0.8926 - loss: 0.4378 - val_accuracy: 0.8860 - val_loss: 0.4396\n",
  243. "Epoch 8/10\n",
  244. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 7ms/step - accuracy: 0.9375 - loss: 0.3477 - val_accuracy: 1.0000 - val_loss: 0.2697\n",
  245. "Epoch 9/10\n",
  246. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m173s\u001b[0m 3s/step - accuracy: 0.8890 - loss: 0.4132 - val_accuracy: 0.8879 - val_loss: 0.4444\n",
  247. "Epoch 10/10\n",
  248. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 7ms/step - accuracy: 0.9688 - loss: 0.3492 - val_accuracy: 0.8000 - val_loss: 0.5367\n"
  249. ]
  250. }
  251. ],
  252. "source": [
  253. "history = model.fit(\n",
  254. " train_generator,\n",
  255. " steps_per_epoch=train_generator.samples // train_generator.batch_size,\n",
  256. " validation_data=validation_generator,\n",
  257. " validation_steps=validation_generator.samples // validation_generator.batch_size,\n",
  258. " epochs=10\n",
  259. ")\n",
  260. "\n",
  261. "#This specifies the number of complete passes through the training dataset. Here, the model will train for 10 epochs."
  262. ]
  263. },
  264. {
  265. "cell_type": "code",
  266. "execution_count": 12,
  267. "metadata": {},
  268. "outputs": [
  269. {
  270. "name": "stdout",
  271. "output_type": "stream",
  272. "text": [
  273. "Model compiled successfully.\n",
  274. "Callbacks created successfully.\n",
  275. "Epoch 1/20\n",
  276. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m267s\u001b[0m 4s/step - accuracy: 0.8686 - loss: 0.6860 - val_accuracy: 0.8897 - val_loss: 0.5719\n",
  277. "Epoch 2/20\n",
  278. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - accuracy: 0.8750 - loss: 0.6368 - val_accuracy: 0.6000 - val_loss: 0.7365\n",
  279. "Epoch 3/20\n",
  280. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m286s\u001b[0m 4s/step - accuracy: 0.8856 - loss: 0.6046 - val_accuracy: 0.8897 - val_loss: 0.6145\n",
  281. "Epoch 4/20\n",
  282. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - accuracy: 0.8438 - loss: 0.5484 - val_accuracy: 0.6000 - val_loss: 0.7114\n",
  283. "Epoch 5/20\n",
  284. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m323s\u001b[0m 5s/step - accuracy: 0.8840 - loss: 0.5082 - val_accuracy: 0.8860 - val_loss: 0.5360\n",
  285. "Epoch 6/20\n",
  286. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 7ms/step - accuracy: 0.8750 - loss: 0.4655 - val_accuracy: 1.0000 - val_loss: 0.4314\n",
  287. "Epoch 7/20\n",
  288. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m299s\u001b[0m 4s/step - accuracy: 0.8770 - loss: 0.4487 - val_accuracy: 0.8879 - val_loss: 0.4546\n",
  289. "Epoch 8/20\n",
  290. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 6ms/step - accuracy: 0.9375 - loss: 0.3386 - val_accuracy: 0.8000 - val_loss: 0.5507\n",
  291. "Epoch 9/20\n",
  292. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m276s\u001b[0m 4s/step - accuracy: 0.8946 - loss: 0.3906 - val_accuracy: 0.8879 - val_loss: 0.4199\n",
  293. "Epoch 10/20\n",
  294. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 7ms/step - accuracy: 0.8438 - loss: 0.4413 - val_accuracy: 0.8000 - val_loss: 0.5716\n",
  295. "Epoch 11/20\n",
  296. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m293s\u001b[0m 4s/step - accuracy: 0.8840 - loss: 0.3867 - val_accuracy: 0.8860 - val_loss: 0.4153\n",
  297. "Epoch 12/20\n",
  298. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.9565 - loss: 0.2604 - val_accuracy: 1.0000 - val_loss: 0.2164\n",
  299. "Epoch 13/20\n",
  300. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m290s\u001b[0m 4s/step - accuracy: 0.8951 - loss: 0.3485 - val_accuracy: 0.8915 - val_loss: 0.3989\n",
  301. "Epoch 14/20\n",
  302. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 7ms/step - accuracy: 0.9375 - loss: 0.2921 - val_accuracy: 0.4000 - val_loss: 1.3309\n",
  303. "Epoch 15/20\n",
  304. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m273s\u001b[0m 4s/step - accuracy: 0.8950 - loss: 0.3413 - val_accuracy: 0.8860 - val_loss: 0.4095\n",
  305. "Epoch 16/20\n",
  306. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 7ms/step - accuracy: 0.9062 - loss: 0.3136 - val_accuracy: 1.0000 - val_loss: 0.1990\n",
  307. "Epoch 17/20\n",
  308. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m270s\u001b[0m 4s/step - accuracy: 0.8919 - loss: 0.3295 - val_accuracy: 0.8860 - val_loss: 0.4054\n",
  309. "Epoch 18/20\n",
  310. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - accuracy: 0.8750 - loss: 0.3683 - val_accuracy: 1.0000 - val_loss: 0.1749\n",
  311. "Epoch 19/20\n",
  312. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m279s\u001b[0m 4s/step - accuracy: 0.8906 - loss: 0.3170 - val_accuracy: 0.8860 - val_loss: 0.4067\n",
  313. "Epoch 20/20\n",
  314. "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - accuracy: 0.9375 - loss: 0.2421 - val_accuracy: 1.0000 - val_loss: 0.2140\n"
  315. ]
  316. }
  317. ],
  318. "source": [
  319. "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint\n",
  320. "\n",
  321. "\n",
  322. "# Unfreeze some layers\n",
  323. "base_model.trainable = True\n",
  324. "fine_tune_at = 100 # fine-tune from this layer onwards\n",
  325. "\n",
  326. "for layer in base_model.layers[:fine_tune_at]:\n",
  327. " layer.trainable = False\n",
  328. "\n",
  329. "# # Unfreeze more layers gradually\n",
  330. "# for layer in base_model.layers[:-10]: # Unfreeze all layers except the last 10 layers\n",
  331. "# layer.trainable = False\n",
  332. "\n",
  333. "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6), #learning_rate=1e-5\n",
  334. " loss='categorical_crossentropy', #The categorical cross-entropy loss function is used because this is a multi-class classification problem\n",
  335. " metrics=['accuracy'])\n",
  336. "\n",
  337. "print(\"Model compiled successfully.\")\n",
  338. "\n",
  339. "\"\"\"\n",
  340. "The Adam optimizer is used with a very small learning rate (1e-5). Fine-tuning typically \n",
  341. "uses a smaller learning rate to prevent large updates to the weights, which could potentially destroy the learned features in the pre-trained model.\n",
  342. "\"\"\"\n",
  343. "\n",
  344. "early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
  345. "model_checkpoint = ModelCheckpoint('best_model.keras', save_best_only=True, monitor='val_loss')\n",
  346. "\n",
  347. "print(\"Callbacks created successfully.\")\n",
  348. "\n",
  349. "\n",
  350. "history_fine = model.fit(\n",
  351. " train_generator,\n",
  352. " steps_per_epoch=train_generator.samples // train_generator.batch_size,\n",
  353. " validation_data=validation_generator,\n",
  354. " validation_steps=validation_generator.samples // validation_generator.batch_size,\n",
  355. " epochs=20,\n",
  356. " #callbacks=[early_stopping, model_checkpoint]\n",
  357. ")\n"
  358. ]
  359. },
  360. {
  361. "cell_type": "code",
  362. "execution_count": 13,
  363. "metadata": {},
  364. "outputs": [
  365. {
  366. "name": "stdout",
  367. "output_type": "stream",
  368. "text": [
  369. "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 1s/step - accuracy: 0.8704 - loss: 0.4312\n",
  370. "Validation loss: 0.4050602614879608\n",
  371. "Validation accuracy: 0.8870673775672913\n"
  372. ]
  373. }
  374. ],
  375. "source": [
  376. "loss, accuracy = model.evaluate(validation_generator)\n",
  377. "print(f'Validation loss: {loss}')\n",
  378. "print(f'Validation accuracy: {accuracy}')\n"
  379. ]
  380. }
  381. ],
  382. "metadata": {
  383. "kernelspec": {
  384. "display_name": "base",
  385. "language": "python",
  386. "name": "python3"
  387. },
  388. "language_info": {
  389. "codemirror_mode": {
  390. "name": "ipython",
  391. "version": 3
  392. },
  393. "file_extension": ".py",
  394. "mimetype": "text/x-python",
  395. "name": "python",
  396. "nbconvert_exporter": "python",
  397. "pygments_lexer": "ipython3",
  398. "version": "3.11.7"
  399. }
  400. },
  401. "nbformat": 4,
  402. "nbformat_minor": 2
  403. }