Summary
This project implements a Message Passing Neural Network (MPNN) for 3D model classification. It uses the PyTorch Geometric library to build the graph neural network and PyTorch Lightning to manage training. The MPNN model implemented in this project is based on .
Features
- Utilizes Message Passing Neural Networks (MPNNs).
- Offers three configurations:
  - Basic: No edge attributes.
  - Dist: Uses the Euclidean distance between connected nodes as the edge attribute.
  - Relpos: Uses the relative position vector between connected nodes as the edge attribute.
- Integrates PyTorch Lightning for training management.
- Loads a custom dataset built from the preprocessed ModelNet10 dataset.
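The three configurations differ only in what each edge carries. The following is a minimal pure-Python sketch of that difference, not the project's own code. It assumes node positions are 3D coordinate tuples and edges are (source, target) index pairs; the function name edge_attributes and the relpos sign convention (target minus source) are illustrative choices.

```python
import math
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float, float]
Edge = Tuple[int, int]  # (source index, target index)

MODES = ("basic", "dist", "relpos")


def edge_attributes(
    pos: Sequence[Point],
    edge_index: Sequence[Edge],
    mode: str,
) -> Optional[List[List[float]]]:
    """Return one attribute vector per edge, or None for the basic mode."""
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {MODES}")
    if mode == "basic":
        return None  # basic: edges carry no attributes
    attrs: List[List[float]] = []
    for src, dst in edge_index:
        # Relative position vector, taken here (by assumption) as target minus source.
        rel = [pos[dst][k] - pos[src][k] for k in range(3)]
        if mode == "relpos":
            attrs.append(rel)  # 3 values per edge
        else:
            # dist: Euclidean length of the relative vector, 1 value per edge
            attrs.append([math.sqrt(sum(c * c for c in rel))])
    return attrs


if __name__ == "__main__":
    pos = [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0), (0.0, 0.0, 1.0)]
    edges = [(0, 1), (1, 0), (0, 2)]
    print(edge_attributes(pos, edges, "dist"))  # [[5.0], [5.0], [1.0]]
    print(edge_attributes(pos, edges, "relpos"))
```

In PyTorch Geometric, the same information would typically be stored as an edge_attr tensor of shape [num_edges, 1] for dist or [num_edges, 3] for relpos.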
Getting Started
Prerequisites
- Python 3.x
- PyTorch and PyTorch Geometric
- PyTorch Lightning
- PyYAML, for reading the YAML configuration files
Setup
- Clone this repository.
- Install the necessary libraries.
- Modify configs/paths.yml and configs/hyperparameters.yml to set your dataset paths and desired hyperparameters, respectively.
git clone https://github.com/admir-selimovic/mpnn.git
cd mpnn
pip install torch torch_geometric pytorch_lightning pyyaml
Usage
Execute the preprocessing script:
python data_preprocessing.py
Execute the training script:
python train_model.py
When prompted, enter the model type you want to train (basic, dist, or relpos).
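The model type entered at the prompt determines how many features each edge carries. A hypothetical helper like the one below could validate the input and report the edge attribute width, assuming 3D node positions. The name parse_model_type and the mapping are illustrative and are not taken from train_model.py.

```python
from typing import Dict, Tuple

# Hypothetical mapping: model type -> number of edge attribute features,
# assuming 3D node positions (distance is a scalar, relative position is 3D).
EDGE_DIMS: Dict[str, int] = {"basic": 0, "dist": 1, "relpos": 3}


def parse_model_type(raw: str) -> Tuple[str, int]:
    """Normalize user input and return (model_type, edge_attribute_width)."""
    key = raw.strip().lower()
    if key not in EDGE_DIMS:
        raise ValueError(
            f"unknown model type {raw!r}; choose one of {sorted(EDGE_DIMS)}"
        )
    return key, EDGE_DIMS[key]


if __name__ == "__main__":
    model_type, edge_dim = parse_model_type(input("Model type (basic, dist, relpos): "))
    print(f"Training {model_type} model with {edge_dim} edge feature(s)")
```

Failing early on an unrecognized name avoids starting a training run with the wrong configuration.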
Structure
The project contains the following key components:
- models/: Contains the MPNN and MNISTClassifier implementations.
- datasets/: Houses the custom dataset loader for the ModelNet10 dataset.
- configs/: Contains YAML configurations for hyperparameters and dataset paths.
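As a rough picture of what an MPNN layer computes, here is a framework-free sketch of one generic message passing step, followed by a mean-pooling readout that produces one vector per graph for classification. It illustrates the technique under stated assumptions and is not the repository's MPNN. The assumptions are sum aggregation, caller-supplied message and update functions, and the hypothetical names message_passing_step and mean_pool. The message(h_i, h_j, e_ij) argument order mirrors the x_i (target) and x_j (source) naming used by PyTorch Geometric's MessagePassing.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
Edge = Tuple[int, int]  # (source j, target i): messages flow from j to i


def message_passing_step(
    h: Sequence[Vector],
    edge_index: Sequence[Edge],
    message: Callable[[Vector, Vector, Optional[Vector]], Vector],
    update: Callable[[Vector, Vector], Vector],
    msg_dim: int,
    edge_attr: Optional[Sequence[Vector]] = None,
) -> List[Vector]:
    """One round of message passing with sum aggregation."""
    # Start every node with a zero message so nodes without neighbours are handled.
    agg: List[Vector] = [[0.0] * msg_dim for _ in h]
    for k, (j, i) in enumerate(edge_index):
        # Edge attribute for this edge (None in the "basic" configuration).
        e_ij = edge_attr[k] if edge_attr is not None else None
        m_ij = message(list(h[i]), list(h[j]), e_ij)
        # Sum incoming messages at the target node i.
        agg[i] = [a + b for a, b in zip(agg[i], m_ij)]
    # Update each node from its previous state and its aggregated message.
    return [update(list(h[i]), agg[i]) for i in range(len(h))]


def mean_pool(h: Sequence[Vector]) -> Vector:
    """Average node features into one graph-level vector."""
    return [sum(col) / len(h) for col in zip(*h)]


if __name__ == "__main__":
    h = [[1.0], [2.0], [4.0]]
    edges = [(0, 1), (1, 0), (2, 0)]
    out = message_passing_step(
        h,
        edges,
        message=lambda h_i, h_j, e: h_j,  # copy the neighbour's features
        update=lambda h_i, m_i: [a + b for a, b in zip(h_i, m_i)],
        msg_dim=1,
    )
    print(out)  # [[7.0], [3.0], [4.0]]
    print(mean_pool(out))
```

In a trained model, message and update would be learned networks (for example small MLPs), and in the dist and relpos configurations e_ij would carry the edge attribute vector.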
Contributions
Contributions are welcome! Please fork the repository and create a pull request with your changes.