This project implements an image classification system using Convolutional Neural Networks (CNNs). The system is built with PyTorch and Flask, allowing for both training models on the CIFAR-10 dataset and serving a web interface for image classification.
The repository is organized as follows:
- models/: Contains the model definitions for both the simple and advanced CNNs.
  - simple_cnn_model.py: Defines the SimpleCNN model.
  - advanced_cnn_model.py: Defines the AdvancedCNN model.
- controllers/: Contains the Flask application.
  - prediction_controller.py: Defines the Flask app and routes.
- services/: Contains the prediction service.
  - prediction_service.py: Provides methods for image preprocessing and prediction.
- templates/: Contains the HTML templates for the web interface.
  - index.html: Main page for image upload and URL input.
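The repository name suggests a three-layer design: the controller handles HTTP, the service handles preprocessing and inference, and the model produces class scores. The sketch below illustrates that separation in plain Python. All names here (`PredictionService`, `classify_handler`, `toy_model`) are hypothetical stand-ins, not the project's actual classes, and the [0, 1] pixel scaling is an assumed preprocessing step.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a trained model: maps a flat list of pixel
# values to one raw score per class. In the project this role is played
# by SimpleCNN or AdvancedCNN from models/.
Model = Callable[[List[float]], List[float]]


class PredictionService:
    """Service layer: preprocessing plus inference, no HTTP concerns."""

    def __init__(self, model: Model, class_names: List[str]) -> None:
        self.model = model
        self.class_names = class_names

    def preprocess(self, pixels: List[int]) -> List[float]:
        # Scale 8-bit pixel values to [0, 1] (assumed preprocessing step).
        return [p / 255.0 for p in pixels]

    def predict(self, pixels: List[int]) -> str:
        scores = self.model(self.preprocess(pixels))
        # Pick the index of the highest score and map it to a label.
        best = max(range(len(scores)), key=scores.__getitem__)
        return self.class_names[best]


def classify_handler(service: PredictionService, pixels: List[int]) -> Dict[str, str]:
    """Controller layer: in the real app this would be a Flask route."""
    return {"prediction": service.predict(pixels)}


# Toy model that scores "bright" by mean intensity and "dark" by its complement.
def toy_model(x: List[float]) -> List[float]:
    mean = sum(x) / len(x)
    return [mean, 1.0 - mean]


service = PredictionService(toy_model, ["bright", "dark"])
print(classify_handler(service, [250, 240, 255]))  # {'prediction': 'bright'}
print(classify_handler(service, [10, 0, 5]))  # {'prediction': 'dark'}
```

Because the controller only talks to the service, the model can be swapped (for example SimpleCNN for AdvancedCNN) without touching the routing code.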
To set up the project:
- Clone the repository:
  git clone https://github.com/svbuh/showcase_architecture_3-layer.git
  cd showcase_architecture_3-layer
- Create a virtual environment:
  python3 -m venv venv
  source venv/bin/activate  # On Windows use `venv\Scripts\activate`
- Install the dependencies:
  pip install -r requirements.txt
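Installing into the wrong interpreter is a common setup mistake. A minimal check, assuming the environment was created with the standard `venv` module as above: inside a venv, `sys.prefix` points at the environment while `sys.base_prefix` still points at the base Python installation. The helper name `in_virtualenv` is illustrative.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when running inside a venv created by `python3 -m venv`."""
    # venv changes sys.prefix to the environment directory, while
    # sys.base_prefix keeps pointing at the base interpreter.
    return sys.prefix != sys.base_prefix


if in_virtualenv():
    print(f"Using virtual environment at {sys.prefix}")
else:
    print("Not in a virtual environment; activate venv before running pip.")
```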
To train the model, run the train_model.py script. You can choose between the SimpleCNN and AdvancedCNN models by uncommenting the desired model in the script.
python train_model.py
This script will:
- Load the CIFAR-10 dataset.
- Train the selected model.
- Save the trained model to a .pth file.
- Evaluate the model on the test set and print the accuracy.
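The evaluation step amounts to counting how many predicted classes match the true labels. A minimal sketch in plain Python, assuming predictions and labels have already been collected as integer class indices (in `train_model.py` they would presumably come from the model's highest-scoring output for each test image; that loop is not shown in this README):

```python
from typing import Sequence


def accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Percentage of predictions that match the ground-truth labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty test set")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return 100.0 * correct / len(labels)


# Toy batch of CIFAR-10 class indices (0-9); 3 of the 4 predictions are correct.
print(f"Accuracy on the test set: {accuracy([3, 8, 8, 0], [3, 8, 1, 0]):.1f}%")
```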
To run the web interface:
- Ensure the Flask app configuration points to the correct model path in services/prediction_service.py.
- Run the Flask app:
  python controllers/prediction_controller.py
- Open your browser and navigate to http://127.0.0.1:5000/ to access the web interface.
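A wrong model path usually surfaces only when the first prediction fails. One way to fail fast at startup is sketched below. It is an assumption for illustration only: the `MODEL_PATH` environment variable and the `simple_cnn.pth` default are hypothetical, since the project itself sets the path inside services/prediction_service.py.

```python
import os
from pathlib import Path

# Hypothetical default; use whatever file train_model.py actually saved.
DEFAULT_MODEL_PATH = "simple_cnn.pth"


def resolve_model_path(env_var: str = "MODEL_PATH") -> Path:
    """Return the absolute model path, raising early if it is unusable.

    MODEL_PATH is a hypothetical environment variable used for illustration.
    """
    path = Path(os.environ.get(env_var, DEFAULT_MODEL_PATH)).expanduser().resolve()
    if path.suffix != ".pth":
        raise ValueError(f"expected a .pth file, got {path.name}")
    if not path.is_file():
        raise FileNotFoundError(f"model file not found: {path}")
    return path
```

Calling this once when the app starts turns a missing checkpoint into an immediate, readable error instead of a failed request.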
The web interface allows users to classify images either by uploading a file or by providing an image URL.
- Provide an image URL:
  - Enter the image URL in the provided input field.
  - Click the "Classify" button to see the predicted class.
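Since the URL field accepts arbitrary user input, it is worth rejecting obviously unusable URLs before the service tries to download anything (presumably with `requests`, which is among the dependencies). A minimal sketch using only the standard library; `is_valid_image_url` is a hypothetical helper, not part of the project:

```python
from urllib.parse import urlparse

ALLOWED_SCHEMES = {"http", "https"}


def is_valid_image_url(url: str) -> bool:
    """Accept only http(s) URLs that name a host."""
    parsed = urlparse(url.strip())
    # Rejects file://, data:, and schemeless strings before any download.
    return parsed.scheme in ALLOWED_SCHEMES and bool(parsed.netloc)


print(is_valid_image_url("https://example.com/cat.jpg"))  # True
print(is_valid_image_url("file:///etc/passwd"))  # False
print(is_valid_image_url("not a url"))  # False
```

This only checks the URL's shape; the download itself can still fail or return non-image content, so the service should also handle errors when opening the result with Pillow.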
The train_model.py script can be run to train the model and evaluate it on the CIFAR-10 test set.
The pinned dependencies are:
torch==2.3.1
torchvision==0.18.1
numpy==1.26.4
Pillow==10.3.0
matplotlib==3.9.0
Werkzeug==3.0.3
Flask==3.0.3
requests==2.32.3
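To confirm that the active environment matches these pins, the standard library's `importlib.metadata` can report installed versions. A minimal sketch; `parse_pins` and `check_pins` are illustrative helpers that handle only plain `name==version` lines:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict


def parse_pins(text: str) -> Dict[str, str]:
    """Map package name -> pinned version for `name==version` lines."""
    pins: Dict[str, str] = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and blank lines
        if not line:
            continue
        name, sep, ver = line.partition("==")
        if sep:
            pins[name.strip()] = ver.strip()
    return pins


def check_pins(pins: Dict[str, str]) -> Dict[str, str]:
    """Return 'ok', 'missing', or a mismatch description per package."""
    report: Dict[str, str] = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if installed == pinned else f"installed {installed}, pinned {pinned}"
    return report
```

For example, `check_pins(parse_pins(open("requirements.txt").read()))` run from the repository root lists any drift. The comparison is a strict string match, so a CUDA build of torch reporting a local suffix such as `2.3.1+cu121` will show as a mismatch even though pip considers it compatible with `torch==2.3.1`.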
- The CIFAR-10 dataset is used for training and evaluation.
- The project uses pre-trained models from torchvision.models.
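CIFAR-10 has ten classes, so a classifier for it conventionally outputs ten raw scores (logits), and the predicted class is the one with the highest score. The sketch below assumes that output shape and shows how scores map to a label with a softmax probability. The class order matches `torchvision.datasets.CIFAR10`; the helper names and the example logits are illustrative.

```python
import math
from typing import List, Tuple

# The ten CIFAR-10 classes, in the index order used by torchvision.datasets.CIFAR10.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(logits: List[float]) -> List[float]:
    # Subtract the largest logit first so math.exp cannot overflow.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_class(logits: List[float]) -> Tuple[str, float]:
    """Return the predicted CIFAR-10 label and its softmax probability."""
    if len(logits) != len(CIFAR10_CLASSES):
        raise ValueError("expected one logit per CIFAR-10 class")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CIFAR10_CLASSES[best], probs[best]


label, prob = top_class([0.1, 0.0, 0.2, 2.5, 0.3, 1.9, 0.0, 0.1, 0.0, 0.0])
print(label)  # cat
```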