# Mushroom Classification Model - JarviSpore

This repository contains **JarviSpore**, a mushroom image classification model trained on a multi-class dataset covering 23 types of mushrooms. The model was developed from scratch with TensorFlow and Keras, aims to identify mushrooms accurately, and uses *Grad-CAM* to interpret its predictions. The project explores how a from-scratch model performs compared to transfer learning.

## Model Details

- **Architecture**: Custom CNN (Convolutional Neural Network)
- **Number of Classes**: 23 mushroom classes
- **Input Format**: RGB images resized to 224x224 pixels
- **Framework**: TensorFlow & Keras
- **Training**: Conducted on a machine with an Intel Core i9-14900K processor, 192 GB of RAM, and an RTX 3090 GPU

## Key Features

1. **Multi-Class Classification**: The model predicts among 23 mushroom species.
2. **Regularization**: Includes L2 regularization and Dropout to reduce overfitting.
3. **Class Weighting**: Handles dataset imbalance by applying a specific weight to each class.
4. **Grad-CAM Visualization**: Uses Grad-CAM to generate heatmaps that show which image regions influence the model's predictions.
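
The class weighting mentioned above can be illustrated with a small sketch. The README does not say how JarviSpore's weights were computed, so this uses the common "balanced" heuristic (weight = number of samples / (number of classes × samples in that class)) as an assumption; `balanced_class_weights` is a hypothetical helper name.

```python
import numpy as np


def balanced_class_weights(labels, num_classes):
    """Return {class_index: weight} so that rarer classes get larger weights.

    Assumption: the "balanced" heuristic n_samples / (n_classes * class_count).
    The weighting actually used for JarviSpore may differ.
    """
    labels = np.asarray(labels)
    counts = np.bincount(labels, minlength=num_classes)
    if np.any(counts == 0):
        raise ValueError("every class needs at least one training sample")
    weights = labels.size / (num_classes * counts)
    return {index: float(weight) for index, weight in enumerate(weights)}


# Toy example: class 0 has three samples, class 1 has one.
weights = balanced_class_weights([0, 0, 0, 1], num_classes=2)
print(weights)
```

A dictionary of this shape can be passed to Keras through the `class_weight` argument of `model.fit`.
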

## Model Training

The model was trained on a structured dataset directory, split as follows:

- `train`: Balanced training dataset
- `validation`: Validation set to monitor performance
- `test`: Test set to evaluate final accuracy
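
As a quick sanity check of such a layout, the sketch below counts images per class inside one split directory. It assumes each split contains one sub-folder per mushroom class, which the README does not state explicitly; `count_images_per_class` and the accepted file extensions are hypothetical choices.

```python
from pathlib import Path

# Assumption: images are stored as JPEG or PNG files.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir):
    """Return {class_folder_name: image_count} for one split directory.

    Assumption: split_dir (e.g. "train") holds one sub-folder per class.
    """
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for file in class_dir.iterdir()
                if file.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts
```

Calling `count_images_per_class("train")` and comparing the counts across classes gives a fast view of how balanced a split actually is.
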

Main training hyperparameters include:

- **Batch Size**: 32
- **Epochs**: 20 with Early Stopping
- **Learning Rate**: 0.0001

Training was tracked and logged with MLflow, including accuracy and loss curves, and the best model weights were saved automatically.
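
The hyperparameters above could be wired into Keras roughly as follows. This is a sketch under stated assumptions, not the project's training script: the Adam optimizer, the `val_loss` monitor, the patience value, and the checkpoint file name are all illustrative choices that the README does not specify.

```python
# Values taken from the README.
HYPERPARAMS = {"batch_size": 32, "epochs": 20, "learning_rate": 1e-4}


def build_training_setup(patience=5, checkpoint_path="best_model.h5"):
    """Return (optimizer, callbacks) for model.compile / model.fit.

    Assumptions: Adam optimizer, monitoring val_loss, patience of 5 epochs,
    and the checkpoint file name. None of these are confirmed by the README.
    """
    import tensorflow as tf  # imported here so the constants load without TensorFlow

    optimizer = tf.keras.optimizers.Adam(
        learning_rate=HYPERPARAMS["learning_rate"]
    )
    callbacks = [
        # Stop when validation loss stops improving and keep the best weights.
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=patience, restore_best_weights=True
        ),
        # Save the model only when validation loss improves.
        tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, monitor="val_loss", save_best_only=True
        ),
    ]
    return optimizer, callbacks
```

The returned objects would then be passed to `model.compile(optimizer=...)` and `model.fit(..., epochs=HYPERPARAMS["epochs"], callbacks=callbacks)`.
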

## Model Usage

### Prerequisites

Ensure the following libraries are installed:

```bash
pip install tensorflow pillow matplotlib numpy
```

### Loading the Model

To load the model and run a prediction:

```python
import numpy as np
import tensorflow as tf
from PIL import Image

# Load the trained model (replace the placeholder with the path to your .h5 file)
model = tf.keras.models.load_model("path_to_model.h5")


def prepare_image(image_path):
    """Load an image and return a batch of shape (1, 224, 224, 3)."""
    img = Image.open(image_path).convert("RGB")
    img = img.resize((224, 224))
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    # Add the batch dimension expected by model.predict
    return np.expand_dims(img_array, axis=0)


# Run a prediction on a single image
image_path = "path_to_image.jpg"
img_array = prepare_image(image_path)
predictions = model.predict(img_array)
predicted_class = np.argmax(predictions[0])
print(f"Predicted Class: {predicted_class}")
```
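
The snippet above prints a class index only. A minimal sketch of turning the output vector into readable top-k results follows; `top_k_predictions` is a hypothetical helper, and the list of class names is an assumption that must be supplied in the same order used during training (the README does not list the 23 names).

```python
import numpy as np


def top_k_predictions(probabilities, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs, best first.

    Assumption: class_names is ordered exactly like the model's output units.
    """
    probabilities = np.asarray(probabilities, dtype=float)
    if probabilities.shape != (len(class_names),):
        raise ValueError("expected one probability per class name")
    top_indices = np.argsort(probabilities)[::-1][:k]
    return [(class_names[i], float(probabilities[i])) for i in top_indices]


# Toy example with three placeholder class names.
print(top_k_predictions([0.1, 0.7, 0.2], ["class_a", "class_b", "class_c"], k=2))
```

With the real model, the first argument would be `predictions[0]` from the prediction code.
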

### Grad-CAM Visualization

The integrated *Grad-CAM* functionality helps interpret the model's predictions. Select an image and apply the Grad-CAM function to produce a heatmap overlaid on the original image, highlighting the areas that influence the prediction.

Grad-CAM example usage (reuses `model`, `img_array`, and `image_path` from the prediction code):

```python
# Compute the heatmap; replace the placeholder with the name of the
# model's last convolutional layer
heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name="last_conv_layer_name")

# Superimpose the heatmap on the original image and display it
superimposed_img = superimpose_heatmap(Image.open(image_path), heatmap)
superimposed_img.show()
```
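
The body of `make_gradcam_heatmap` is not shown in this README. The sketch below follows the widely used Keras Grad-CAM recipe and is an assumption about how such a function might look, not the project's actual code. `gradcam_from_activations` is a hypothetical helper that isolates the core step: average the gradients per channel, use them to weight the feature maps, keep positive values, and normalize. Superimposing the heatmap on the image is omitted here.

```python
import numpy as np


def gradcam_from_activations(activations, gradients):
    """Combine conv activations and gradients, both (H, W, C), into an (H, W) heatmap."""
    # One importance weight per channel: the spatial mean of its gradients.
    channel_weights = gradients.mean(axis=(0, 1))
    # Weighted sum of feature maps, keeping only positive evidence (ReLU).
    heatmap = np.maximum(activations @ channel_weights, 0.0)
    peak = heatmap.max()
    # Scale to [0, 1]; an all-zero map is returned unchanged.
    return heatmap / peak if peak > 0 else heatmap


def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Sketch of a Grad-CAM heatmap for one image batch of shape (1, H, W, 3)."""
    import tensorflow as tf  # imported here so the NumPy helper works without TensorFlow

    conv_layer = model.get_layer(last_conv_layer_name)
    # Model that returns both the conv feature maps and the final predictions.
    grad_model = tf.keras.models.Model(model.inputs, [conv_layer.output, model.output])
    with tf.GradientTape() as tape:
        conv_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = int(tf.argmax(preds[0]))
        class_score = preds[:, pred_index]
    # Gradient of the chosen class score with respect to the feature maps.
    grads = tape.gradient(class_score, conv_output)
    return gradcam_from_activations(conv_output[0].numpy(), grads[0].numpy())
```

The returned heatmap has the spatial size of the chosen convolutional layer, so it needs to be resized to the image size before being overlaid.
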

## Evaluation

The model was evaluated on the test set and reached an average accuracy above random chance (about 4.3% for 23 equally likely classes), a promising result for a first from-scratch version.
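
To make the comparison with random chance concrete, the sketch below computes accuracy from true and predicted labels alongside the uniform-guess baseline. `accuracy_vs_chance` is a hypothetical helper, and the baseline assumes all classes are equally likely; no actual JarviSpore results are reproduced here.

```python
import numpy as np


def accuracy_vs_chance(y_true, y_pred, num_classes=23):
    """Return (accuracy, chance_level) for integer class labels.

    Assumption: chance level is 1 / num_classes, i.e. uniform guessing
    over equally likely classes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    accuracy = float(np.mean(y_true == y_pred))
    return accuracy, 1.0 / num_classes


# Toy example with four samples and four classes.
print(accuracy_vs_chance([0, 1, 2, 2], [0, 1, 1, 2], num_classes=4))
```
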

## Contributing

Contributions that improve accuracy or add new features (e.g., other visualization techniques or advanced optimization) are welcome. Please submit a pull request with the relevant changes.

## License

This model is released under a restricted license; refer to the `LICENSE` file for details. You may use the model for personal projects, but any modification or redistribution requires prior approval.