Admir Selimovic Research+Dev Index

Message Passing Neural Networks: 3D Shape Class Prediction


Summary

This project implements Message Passing Neural Networks (MPNNs) for 3D shape classification. It is built on the PyTorch Geometric library, with PyTorch Lightning handling model management and training. The MPNN model follows the framework of Gilmer et al. (2017).
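For reference, the general MPNN framework of Gilmer et al. (2017) runs T message-passing steps, each built from a message function M_t and a vertex update function U_t, followed by a readout function R that maps the final node states to a graph-level prediction:

```latex
\begin{aligned}
m_v^{t+1} &= \sum_{w \in N(v)} M_t\left(h_v^t,\, h_w^t,\, e_{vw}\right) \\
h_v^{t+1} &= U_t\left(h_v^t,\, m_v^{t+1}\right) \\
\hat{y} &= R\left(\{\, h_v^T \mid v \in G \,\}\right)
\end{aligned}
```

Here h_v^t is the hidden state of node v at step t, N(v) is the set of neighbors of v, and e_vw is the edge attribute, which is absent in the basic configuration and is a distance or a relative-position vector in the other two.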

Features

  • Utilizes Message Passing Neural Networks (MPNNs).
  • Offers three configurations:
    • Basic: No edge attributes.
    • Dist: Uses the Euclidean distance between connected nodes as the edge attribute.
    • Relpos: Uses relative position vectors as edge attributes.
  • Integrates PyTorch Lightning for training management.
  • Loads a custom, preprocessed version of the ModelNet10 dataset.
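The three configurations differ only in what is attached to each edge. A minimal sketch in plain PyTorch, assuming node positions `pos` of shape [num_nodes, 3] and a PyTorch Geometric-style `edge_index` of shape [2, num_edges]; the function name `edge_attributes` and the sign convention (target minus source) are assumptions for illustration, not the repository's code:

```python
import torch


def edge_attributes(pos: torch.Tensor, edge_index: torch.Tensor, mode: str):
    """Return edge attributes for one configuration (hypothetical helper).

    pos:        [num_nodes, 3] node coordinates
    edge_index: [2, num_edges], row 0 = source node, row 1 = target node
    mode:       "basic", "dist" or "relpos"
    """
    src, dst = edge_index
    # Relative position per edge; target minus source is an assumed convention
    rel = pos[dst] - pos[src]
    if mode == "basic":
        return None  # no edge attributes at all
    if mode == "dist":
        return rel.norm(dim=-1, keepdim=True)  # [num_edges, 1] Euclidean distance
    if mode == "relpos":
        return rel  # [num_edges, 3] relative position vector
    raise ValueError(f"unknown mode: {mode!r}")
```

The distance is rotation-invariant but discards direction, while the relative-position vector keeps direction at the cost of depending on the shape's orientation.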

Getting Started

Prerequisites

  • Python 3.x
  • PyTorch and PyTorch Geometric
  • PyTorch Lightning
  • PyYAML for configuration management

Setup

  1. Clone this repository:
     git clone https://github.com/admir-selimovic/mpnn.git
     cd mpnn
  2. Install the necessary libraries:
     pip install torch torch_geometric pytorch_lightning pyyaml
  3. Edit configs/paths.yml and configs/hyperparameters.yml to set your dataset paths and desired hyperparameters, respectively.
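The key names inside configs/paths.yml and configs/hyperparameters.yml are not documented here, so the following hypothetical helper, `load_configs`, only shows the general pattern: parse both files with PyYAML's `yaml.safe_load` and return plain dictionaries. How the project's scripts actually read these files may differ.

```python
import yaml  # installed via the PyYAML package


def load_configs(paths_file="configs/paths.yml",
                 hparams_file="configs/hyperparameters.yml"):
    """Read the two YAML config files and return (paths, hyperparameters) as dicts."""
    with open(paths_file) as f:
        paths = yaml.safe_load(f) or {}  # an empty file yields None, so fall back to {}
    with open(hparams_file) as f:
        hparams = yaml.safe_load(f) or {}
    return paths, hparams
```

`safe_load` is preferred over `yaml.load` because it only constructs plain Python types (dicts, lists, strings, numbers) from the file.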

Usage

Execute the preprocessing script:

python data_preprocessing.py
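The steps performed by data_preprocessing.py are not described here. Purely as an illustration, and not necessarily what the script does, one common way to turn a ModelNet10 mesh into a graph is to sample a point cloud from its surface and connect each point to its k nearest neighbors. The sketch below shows only the k-NN step in plain PyTorch, assuming the points are already sampled; the function name `knn_graph_edges` is hypothetical.

```python
import torch


def knn_graph_edges(pos: torch.Tensor, k: int) -> torch.Tensor:
    """Directed k-NN graph: each point receives edges from its k nearest points.

    pos: [num_nodes, 3] coordinates (requires k < num_nodes).
    Returns edge_index of shape [2, num_nodes * k]:
    row 0 = source (neighbor), row 1 = target (center point).
    """
    dist = torch.cdist(pos, pos)                 # [N, N] pairwise Euclidean distances
    dist.fill_diagonal_(float("inf"))            # exclude self-loops
    nbrs = dist.topk(k, largest=False).indices   # [N, k] indices of the k closest points
    target = torch.arange(pos.size(0)).repeat_interleave(k)  # center point of each edge
    source = nbrs.reshape(-1)                    # matching neighbor of each edge
    return torch.stack([source, target], dim=0)
```

The dense distance matrix is quadratic in the number of points, which is acceptable for a few thousand points per shape; larger clouds would call for a spatial index instead.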

Execute the training script:

python train_model.py

When prompted, enter the model type you want to train (basic, dist, or relpos).
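A minimal, hypothetical sketch of how the prompted value could be validated and mapped to an edge-attribute width (0 for basic, 1 for the scalar distance, 3 for the relative-position vector). The names `EDGE_DIMS` and `parse_model_type` are illustrative and not taken from train_model.py.

```python
# Hypothetical mapping from model type to the width of the edge-attribute vector
EDGE_DIMS = {"basic": 0, "dist": 1, "relpos": 3}


def parse_model_type(raw: str) -> str:
    """Normalize the text typed at the prompt and reject unknown model types."""
    choice = raw.strip().lower()
    if choice not in EDGE_DIMS:
        raise ValueError(f"model type must be one of {sorted(EDGE_DIMS)}, got {raw!r}")
    return choice
```

In an interactive script this would be called as `parse_model_type(input("Model type: "))`, and `EDGE_DIMS[...]` of the result gives the edge dimension to pass to the model.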

Structure

The project contains the following key components:

  • models/: Contains the MPNN and MNISTClassifier implementations.
  • datasets/: Houses the custom dataset loader for the ModelNet10 dataset.
  • configs/: Contains YAML configurations for hyperparameters and dataset paths.
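The models/ directory holds the MPNN and a classifier. The following is a self-contained sketch, not the repository's code, of one MPNN layer following the equations of Gilmer et al. (2017) with sum aggregation, plus a graph-level classifier on top. It deliberately uses plain PyTorch rather than PyTorch Geometric's MessagePassing base class so that it runs with torch alone. The class names, hidden size, SiLU activation, residual update, and mean-pooling readout are all assumptions.

```python
import torch
from torch import nn


class MPNNLayer(nn.Module):
    """One message-passing step with sum aggregation (hypothetical sketch)."""

    def __init__(self, hidden: int, edge_dim: int = 0):
        super().__init__()
        # Message function M(h_v, h_w, e_vw); edge_dim = 0 means no edge attributes
        self.message_mlp = nn.Sequential(nn.Linear(2 * hidden + edge_dim, hidden), nn.SiLU())
        # Update function U(h_v, m_v)
        self.update_mlp = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.SiLU())

    def forward(self, h, edge_index, edge_attr=None):
        src, dst = edge_index  # each edge carries a message from src (w) to dst (v)
        parts = [h[dst], h[src]]
        if edge_attr is not None:
            parts.append(edge_attr)
        msg = self.message_mlp(torch.cat(parts, dim=-1))
        # m_v = sum of messages over incoming edges
        agg = torch.zeros_like(h).index_add_(0, dst, msg)
        # Residual update (a design choice, not required by the framework)
        return h + self.update_mlp(torch.cat([h, agg], dim=-1))


class GraphClassifier(nn.Module):
    """Embed node features, apply MPNN layers, mean-pool each graph, classify."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int,
                 edge_dim: int = 0, num_layers: int = 3):
        super().__init__()
        self.embed = nn.Linear(in_dim, hidden)
        self.layers = nn.ModuleList([MPNNLayer(hidden, edge_dim) for _ in range(num_layers)])
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x, edge_index, batch, num_graphs: int, edge_attr=None):
        h = self.embed(x)
        for layer in self.layers:
            h = layer(h, edge_index, edge_attr)
        # Readout R: mean of node states per graph; batch[i] is the graph id of node i
        pooled = torch.zeros(num_graphs, h.size(1), dtype=h.dtype, device=h.device)
        pooled = pooled.index_add_(0, batch, h)
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(1)
        return self.head(pooled / counts)  # [num_graphs, num_classes] logits
```

For ModelNet10 one would set num_classes=10. If the node features are the 3D point coordinates (an assumption), in_dim=3, and edge_dim is 0, 1, or 3 for the basic, dist, or relpos configuration respectively.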

Contributions

Contributions are welcome! Please fork the repository and create a pull request with your changes.

Acknowledgments

References

Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, George E. Dahl. "Neural Message Passing for Quantum Chemistry". Proceedings of the 34th International Conference on Machine Learning, 2017. pp. 1263-1272.