snnTorch Documentation
Introduction
The brain is the perfect place to look for inspiration to develop more efficient neural networks. One of the main differences from modern deep learning is that the brain encodes information in spikes rather than continuous activations. snnTorch is a Python package for performing gradient-based learning with spiking neural networks. It extends the capabilities of PyTorch, taking advantage of its GPU-accelerated tensor computation and applying it to networks of spiking neurons. Pre-designed spiking neuron models are seamlessly integrated within the PyTorch framework and can be treated as recurrent activation units.
snnTorch Structure
snnTorch contains the following components:
Component | Description
snntorch | a spiking neuron library like torch.nn, deeply integrated with autograd
snntorch.backprop | variations of backpropagation commonly used with SNNs
snntorch.functional | common arithmetic operations on spikes, e.g., loss, regularization etc.
snntorch.spikegen | a library for spike generation and data conversion
snntorch.spikeplot | visualization tools for spike-based data using matplotlib and celluloid
snntorch.spikevision | contains popular neuromorphic datasets
snntorch.surrogate | optional surrogate gradient functions
snntorch.utils | dataset utility functions
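Spike generation converts ordinary data into spike trains. As a rough illustration of one common scheme, rate coding, the sketch below treats each normalized input value as the probability of firing at every time step, and decodes by averaging spikes over time. It is plain Python so it runs without snnTorch installed; `rate_encode` and `firing_rate` are hypothetical helpers, not the snnTorch spike generation API.

```python
import random


def rate_encode(intensities, num_steps, seed=None):
    """Convert values in [0, 1] into Bernoulli spike trains.

    At every time step, element i fires (1) with probability intensities[i].
    Returns a list of num_steps lists, i.e. shape [num_steps][len(intensities)].
    """
    for p in intensities:
        if not 0.0 <= p <= 1.0:
            raise ValueError("intensities must lie in [0, 1]")
    rng = random.Random(seed)  # local RNG so results are reproducible
    return [[1 if rng.random() < p else 0 for p in intensities]
            for _ in range(num_steps)]


def firing_rate(spike_train):
    """Decode by averaging each element's spikes over all time steps."""
    num_steps = len(spike_train)
    return [sum(step[i] for step in spike_train) / num_steps
            for i in range(len(spike_train[0]))]


# Usage: three "pixels" of increasing brightness
spikes = rate_encode([0.0, 0.5, 1.0], num_steps=1000, seed=0)
rates = firing_rate(spikes)
print(rates)  # roughly [0.0, 0.5, 1.0]
```

Brighter inputs produce denser spike trains, so the information carried by a continuous value is spread across time rather than held in a single activation.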
snnTorch is designed to be used intuitively with PyTorch, as though each spiking neuron were simply another activation in a sequence of layers. It is therefore agnostic to fully-connected layers, convolutional layers, residual connections, etc.
At present, the neuron models are represented by recursive functions, which remove the need to store membrane potential traces for all neurons in a system in order to calculate the gradient. The lean requirements of snnTorch enable both small and large networks to be trained on CPU where needed. Provided that the network models and tensors are loaded onto CUDA, snnTorch takes advantage of GPU acceleration in the same way as PyTorch.
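To make the recursive formulation concrete, here is a plain-Python sketch of a leaky integrate-and-fire neuron that, like an activation function, is simply called once per time step. Only the current membrane potential is kept between calls. It assumes a strict threshold comparison and reset by subtraction; snnTorch's own neuron models offer more options and their exact update ordering may differ. `LeakySketch` is a hypothetical name.

```python
class LeakySketch:
    """Minimal leaky integrate-and-fire neuron (hypothetical, for illustration).

    The membrane potential is overwritten at every step instead of being
    stored as a trace over time.
    """

    def __init__(self, beta=0.5, threshold=1.0):
        self.beta = beta            # decay rate: fraction of potential kept per step
        self.threshold = threshold  # firing threshold
        self.mem = 0.0              # the neuron's only state

    def __call__(self, current):
        # Recursive update: U[t] = beta * U[t-1] + I[t]
        self.mem = self.beta * self.mem + current
        spike = 1 if self.mem > self.threshold else 0
        # Reset by subtraction: remove one threshold's worth of potential after a spike
        self.mem -= spike * self.threshold
        return spike

    def reset(self):
        """Clear the state before presenting a new input sequence."""
        self.mem = 0.0


neuron = LeakySketch(beta=0.5, threshold=1.0)
spikes = [neuron(0.6) for _ in range(6)]
print(spikes)  # [0, 0, 1, 0, 0, 1]
```

With a constant input of 0.6 the potential climbs 0.6, 0.9, 1.05, crosses the threshold, drops back by 1.0, and the cycle repeats every three steps.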
Citation
If you find snnTorch useful in your work, please cite the following source:
@article{eshraghian2021training,
title={Training spiking neural networks using lessons from deep learning},
author={Eshraghian, Jason K and Ward, Max and Neftci, Emre and Wang, Xinxin
and Lenz, Gregor and Dwivedi, Girish and Bennamoun, Mohammed and Jeong, Doo Seok
and Lu, Wei D},
journal={arXiv preprint arXiv:2109.12894},
year={2021}
}
Let us know if you are using snnTorch in any interesting work, research or blogs, as we would love to hear more about it! Reach out at snntorch@gmail.com.
Requirements
The following packages need to be installed to use snnTorch:
torch >= 1.1.0
numpy >= 1.17
pandas
matplotlib
math (part of the Python standard library; no separate installation needed)
They are automatically installed if snnTorch is installed using the pip command. Ensure the correct version of torch is installed for your system to enable CUDA compatibility.
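To confirm that an environment meets these minimums, a small script can query installed package versions with the standard library's `importlib.metadata`. This is a hypothetical helper, not part of snnTorch; it compares only the numeric release part of a version (build tags such as `+cu118` are dropped) and assumes stable release numbers. `torch.cuda.is_available()` is the standard PyTorch call for checking CUDA support.

```python
from importlib import metadata

# Minimums from the requirements list; None means "any version".
# math ships with Python, so it is not checked here.
MINIMUMS = {"torch": "1.1.0", "numpy": "1.17", "pandas": None, "matplotlib": None}


def version_tuple(version):
    """Turn '1.17.3' or '2.1.0+cu118' into (1, 17, 3) or (2, 1, 0)."""
    release = version.split("+")[0]  # drop local build tags such as +cu118
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1'
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Numeric comparison, so '1.10' correctly counts as newer than '1.9'."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so '1.17' equals '1.17.0'
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements(minimums=MINIMUMS):
    """Return {package: (installed_version_or_None, ok)}."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        ok = minimum is None or meets_minimum(installed, minimum)
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_requirements().items():
        status = "OK" if ok else "MISSING/TOO OLD"
        print(f"{name:12} {installed or 'not installed':16} {status}")
    try:
        import torch
        print("CUDA available:", torch.cuda.is_available())
    except ImportError:
        print("torch not installed; CUDA check skipped")
```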
Installation
Run the following to install:
$ pip install snntorch
To install snnTorch from source instead:
$ git clone https://github.com/jeshraghian/snnTorch
$ cd snnTorch
$ python setup.py install
To install snntorch with conda:
$ conda install -c conda-forge snntorch
API & Examples
A complete API is available here. Examples, tutorials and Colab notebooks are provided.
Quickstart
Here are a few ways you can get started with snnTorch:
For a quick example to run snnTorch, see the following snippet, or test the quickstart notebook:
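import torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate
num_steps = 25  # number of time steps
batch_size = 1
beta = 0.5  # neuron decay rate
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient used during backprop
# A small convolutional SNN: each snn.Leaky layer acts as a spiking activation.
# init_hidden=True lets each neuron layer store its own membrane potential,
# so the layers can be chained inside nn.Sequential.
net = nn.Sequential(
    nn.Conv2d(1, 8, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
    nn.Conv2d(8, 16, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 10),  # a 28x28 input shrinks to 4x4 after two conv/pool stages
    # output=True makes the final layer return both spikes and membrane potential
    snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad, output=True)
)
# random input data: [time steps, batch, channels, height, width]
data_in = torch.rand(num_steps, batch_size, 1, 28, 28)
spike_recording = []
# feed one time step at a time; hidden states carry over between steps
for step in range(num_steps):
    spike, state = net(data_in[step])
    spike_recording.append(spike)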
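The snippet passes `surrogate.fast_sigmoid()` as `spike_grad`. The reason: a spike is a step function of the membrane potential, whose derivative is zero almost everywhere, so gradient descent would receive no signal. A surrogate keeps the step in the forward pass but substitutes the derivative of a smooth function in the backward pass. The plain-Python sketch below shows that substitution; the function names are hypothetical, and the slope of 25 is an assumption that we believe matches the library default.

```python
def heaviside(u):
    """Forward pass: spike (1.0) when the shifted membrane u = U - threshold is positive."""
    return 1.0 if u > 0 else 0.0


def fast_sigmoid(u, slope=25.0):
    """Smooth stand-in for the step function: u / (1 + slope * |u|)."""
    return u / (1.0 + slope * abs(u))


def fast_sigmoid_grad(u, slope=25.0):
    """Derivative of fast_sigmoid, used in place of the step's derivative."""
    return 1.0 / (slope * abs(u) + 1.0) ** 2


for u in (-0.2, -0.04, 0.0, 0.04, 0.2):
    print(f"u={u:+.2f}  spike={heaviside(u):.0f}  surrogate grad={fast_sigmoid_grad(u):.3f}")
```

The surrogate gradient peaks at the threshold (u = 0) and decays on either side, so neurons close to firing receive the largest weight updates.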
If you’re feeling lazy and want the training process to be taken care of:
import snntorch.functional as SF
from snntorch import backprop
# target firing: correct class on 80% of time steps, incorrect classes on 20%
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
# train for one epoch using the backprop through time algorithm
# assume train_loader is a DataLoader with time-varying input
avg_loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                         num_steps=num_steps, criterion=loss_fn)
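The `correct_rate` and `incorrect_rate` arguments are easiest to read as target spike counts. The plain-Python sketch below assumes a mean-squared error between each output neuron's total spike count and its target (correct_rate * num_steps for the true class, incorrect_rate * num_steps for the others). It illustrates the idea for a single sample only; `SF.mse_count_loss` works on batched tensors and its exact normalization may differ. The function names are hypothetical.

```python
def count_targets(label, num_classes, num_steps, correct_rate=0.8, incorrect_rate=0.2):
    """Target spike count for each output neuron of one sample."""
    on = correct_rate * num_steps
    off = incorrect_rate * num_steps
    return [on if c == label else off for c in range(num_classes)]


def mse_count_loss_sketch(spike_record, label, correct_rate=0.8, incorrect_rate=0.2):
    """spike_record: list over time steps, each a list of 0/1 spikes per class."""
    num_steps = len(spike_record)
    num_classes = len(spike_record[0])
    # total spikes emitted by each output neuron over the whole sequence
    counts = [sum(step[c] for step in spike_record) for c in range(num_classes)]
    targets = count_targets(label, num_classes, num_steps, correct_rate, incorrect_rate)
    return sum((n - t) ** 2 for n, t in zip(counts, targets)) / num_classes


# Usage: 25 steps, 3 classes, true class 0.
num_steps = 25
# Ideal output: class 0 fires on 20 steps, the others on 5.
ideal = [[int(t < 20), int(t < 5), int(t < 5)] for t in range(num_steps)]
# Under-active output: class 0 fires on only 10 steps.
weak = [[int(t < 10), int(t < 5), int(t < 5)] for t in range(num_steps)]
print(mse_count_loss_sketch(ideal, label=0))  # 0.0
print(mse_count_loss_sketch(weak, label=0))   # (10 - 20)**2 / 3, about 33.33
```

Because the target for incorrect classes is nonzero, the loss does not push their neurons to fall completely silent, which keeps some gradient flowing through every output.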
A Deep Dive into SNNs
If you wish to learn all the fundamentals of training spiking neural networks, from neuron models, to the neural code, up to backpropagation, the snnTorch tutorial series is a great place to begin. It consists of interactive notebooks with complete explanations that can get you up to speed.
Contributing
If you’re ready to contribute to snnTorch, instructions to do so can be found here.
Acknowledgments
snnTorch was initially developed by Jason K. Eshraghian in the Lu Group (University of Michigan).
Additional contributions were made by Xinxin Wang, Vincent Sun, and Emre Neftci.
Several features in snnTorch were inspired by the work of Friedemann Zenke, Emre Neftci, Doo Seok Jeong, Sumit Bam Shrestha and Garrick Orchard.
License & Copyright
snnTorch is licensed under the GNU General Public License v3.0: https://www.gnu.org/licenses/gpl-3.0.en.html.