Skip to content

Commit

Permalink
Merge pull request #11 from QVPR/reconfig
Browse files Browse the repository at this point in the history
Reconfiguration of codebase
  • Loading branch information
AdamDHines authored Dec 12, 2023
2 parents 29597f7 + f405246 commit 519ff7b
Show file tree
Hide file tree
Showing 24 changed files with 54 additions and 246 deletions.
26 changes: 9 additions & 17 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
__pycache__/
.ipynb_checkpoints/
src/__pycache__/
dataset/conv/
dataset/Dusk/
dataset/fall/
dataset/model/
dataset/output_database/
dataset/output_query/
dataset/Rain/
dataset/spring/
dataset/summer/
dataset/Sun/
dataset/winter/
dataset/event.csv/
models/VPRTempo78415685001.pth
models/VPRTempoQuant78415685001.pth
VPRTempo.egg-info/
.pyest_cache/
vprtempo/__pycache__/
vprtempo/dataset/fall/
vprtempo/dataset/spring/
vprtempo/dataset/summer/
vprtempo/dataset/winter/
vprtempo/dataset/event.csv
vprtempo/output/
vprtempo/src/__pycache__/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Once downloaded, please install the required dependencies to run the network thr
Dependencies for VPRTempo can downloaded from our [PyPi package](https://pypi.org/project/VPRTempo/).

```python
pip3 install VPRTempo
pip install vprtempo
```
If you wish to enable CUDA, please follow the instructions on the [PyTorch - Get Started](https://pytorch.org/get-started/locally/) page to install the required software versions for your hardware and operating system.

Expand Down
160 changes: 0 additions & 160 deletions dataset/event.csv

This file was deleted.

Binary file removed dataset/test/images-00202.png
Binary file not shown.
Binary file removed dataset/test/images-07028.png
Binary file not shown.
17 changes: 8 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,16 @@
'''
Imports
'''
import argparse
import sys
sys.path.append('./src')
sys.path.append('./vprtempo')
import argparse

import torch.quantization as quantization

from VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model
from VPRTempo import VPRTempo, run_inference
from VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
from VPRTempoQuant import VPRTempoQuant, run_inference_quant
from loggers import model_logger, model_logger_quant
from vprtempo.VPRTempo import VPRTempo, run_inference
from vprtempo.src.loggers import model_logger, model_logger_quant
from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
from vprtempo.VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model

def initialize_and_run_model(args,dims):
# If user wants to train a new network
Expand Down Expand Up @@ -109,7 +108,7 @@ def parse_network(use_quantize=False, train_new_model=False):
# Define the dataset arguments
parser.add_argument('--dataset', type=str, default='nordland',
help="Dataset to use for training and/or inferencing")
parser.add_argument('--data_dir', type=str, default='./dataset/',
parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
help="Directory where dataset files are stored")
parser.add_argument('--num_places', type=int, default=500,
help="Number of places to use for training and/or inferencing")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# define the setup
setup(
name="VPRTempo",
version="1.1.2",
version="1.1.3",
description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
Binary file removed src/.DS_Store
Binary file not shown.
20 changes: 7 additions & 13 deletions vprtempo/VPRTempo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,16 @@

import os
import torch
import gc
import sys
sys.path.append('./src')
sys.path.append('./models')
sys.path.append('./output')
sys.path.append('./dataset')

import blitnet as bn

import numpy as np
import torch.nn as nn
import vprtempo.src.blitnet as bn

from dataset import CustomImageDataset, ProcessImage
from torch.utils.data import DataLoader
from tqdm import tqdm
from prettytable import PrettyTable
from metrics import recallAtK
from torch.utils.data import DataLoader
from vprtempo.src.metrics import recallAtK
from vprtempo.src.dataset import CustomImageDataset, ProcessImage

class VPRTempo(nn.Module):
def __init__(self, dims, args=None, logger=None):
Expand All @@ -60,7 +54,7 @@ def __init__(self, dims, args=None, logger=None):

self.logger = logger
# Set the dataset file
self.dataset_file = os.path.join('./dataset', self.dataset + '.csv')
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')

# Layer dict to keep track of layer names and their order
self.layer_dict = {}
Expand Down Expand Up @@ -220,7 +214,7 @@ def run_inference(models, model_name):
persistent_workers=True)

# Load the model
models[0].load_model(models, os.path.join('./models', model_name))
models[0].load_model(models, os.path.join('./vprtempo/models', model_name))

# Retrieve layer names for inference
layer_names = list(models[0].layer_dict.keys())
Expand Down
25 changes: 8 additions & 17 deletions vprtempo/VPRTempoQuant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,17 @@

import os
import torch
import subprocess
import sys
sys.path.append('./src')
sys.path.append('./models')
sys.path.append('./output')
sys.path.append('./dataset')

import blitnet as bn

import numpy as np
import torch.nn as nn
import torch.quantization as quantization
import vprtempo.src.blitnet as bn

from loggers import model_logger_quant
from VPRTempoQuantTrain import generate_model_name_quant
from dataset import CustomImageDataset, ProcessImage
from torch.utils.data import DataLoader
from torch.ao.quantization import QuantStub, DeQuantStub
from tqdm import tqdm
from prettytable import PrettyTable
from metrics import recallAtK
from torch.utils.data import DataLoader
from vprtempo.src.metrics import recallAtK
from torch.ao.quantization import QuantStub, DeQuantStub
from vprtempo.src.dataset import CustomImageDataset, ProcessImage

#from main import parse_network

Expand All @@ -59,7 +50,7 @@ def __init__(self, dims, args=None, logger=None):
setattr(self, arg, getattr(args, arg))
setattr(self, 'dims', dims)
# Set the dataset file
self.dataset_file = os.path.join('./dataset', self.dataset + '.csv')
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')

# Set the model logger and return the device
self.logger = logger
Expand Down Expand Up @@ -240,7 +231,7 @@ def run_inference_quant(models, model_name, qconfig):
persistent_workers=True)

# Load the model
models[0].load_model(models, os.path.join('./models', model_name))
models[0].load_model(models, os.path.join('./vprtempo/models', model_name))

# Use evaluate method for inference accuracy
with torch.no_grad():
Expand Down
16 changes: 5 additions & 11 deletions vprtempo/VPRTempoQuantTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,17 @@

import os
import torch
import gc
import sys
sys.path.append('./src')
sys.path.append('./models')
sys.path.append('./output')
sys.path.append('./dataset')

import blitnet as bn
import numpy as np
import torch.nn as nn
import vprtempo.src.blitnet as bn
import torch.quantization as quantization
import torchvision.transforms as transforms

from dataset import CustomImageDataset, ProcessImage
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.ao.quantization import QuantStub, DeQuantStub
from tqdm import tqdm
from vprtempo.src.dataset import CustomImageDataset, ProcessImage

class VPRTempoQuantTrain(nn.Module):
def __init__(self, args, dims, logger):
Expand All @@ -61,7 +55,7 @@ def __init__(self, args, dims, logger):
self.logger = logger

# Set the dataset file
self.dataset_file = os.path.join('./dataset', self.dataset + '.csv')
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')

# Add quantization stubs for Quantization Aware Training (QAT)
self.quant = QuantStub()
Expand Down Expand Up @@ -288,4 +282,4 @@ def train_new_model_quant(models, model_name, qconfig):
# After training the current layer, add it to the list of trained layer

# Save the model
model.save_model(trained_models,os.path.join('./models', model_name))
model.save_model(trained_models,os.path.join('./vprtempo/models', model_name))
Loading

0 comments on commit 519ff7b

Please sign in to comment.