Skip to content

Commit

Permalink
Added some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasrebra committed Jul 12, 2024
1 parent a00438e commit b4a73b2
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 0 deletions.
4 changes: 4 additions & 0 deletions scripts/yolov1_package/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch
torchvision
pytest
flake8
Empty file.
13 changes: 13 additions & 0 deletions scripts/yolov1_package/tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import unittest
import torch
from yolov1.loss import yolo_loss

class TestLoss(unittest.TestCase):
def test_yolo_loss(self):
y_true = torch.randn(1, 7 * 7 * 30)
y_pred = torch.randn(1, 7 * 7 * 30)
loss = yolo_loss(y_true, y_pred)
self.assertGreaterEqual(loss.item(), 0)

if __name__ == '__main__':
unittest.main()
17 changes: 17 additions & 0 deletions scripts/yolov1_package/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import unittest
import torch
from yolov1.model import YOLOv1

class TestYOLOv1(unittest.TestCase):
def test_model_creation(self):
model = YOLOv1()
self.assertIsNotNone(model)

def test_model_forward_pass(self):
model = YOLOv1()
input_tensor = torch.randn(1, 3, 448, 448)
output = model(input_tensor)
self.assertEqual(output.shape, (1, 7 * 7 * 30))

if __name__ == '__main__':
unittest.main()
20 changes: 20 additions & 0 deletions scripts/yolov1_package/tests/test_postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
import torch
from yolov1.postprocessing import non_max_suppression, parse_yolo_output

class TestPostprocessing(unittest.TestCase):
def test_non_max_suppression(self):
boxes = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
scores = torch.tensor([0.9, 0.8])
selected_indices = non_max_suppression(boxes, scores, iou_threshold=0.5)
self.assertGreaterEqual(len(selected_indices), 0)

def test_parse_yolo_output(self):
predictions = torch.randn(1, 7 * 7 * 30)
boxes, scores, classes = parse_yolo_output(predictions)
self.assertIsInstance(boxes, list)
self.assertIsInstance(scores, list)
self.assertIsInstance(classes, list)

if __name__ == '__main__':
unittest.main()
12 changes: 12 additions & 0 deletions scripts/yolov1_package/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import unittest
from yolov1.preprocessing import create_dataset

class TestPreprocessing(unittest.TestCase):
def test_create_dataset(self):
filenames = ['dummy_path1', 'dummy_path2']
labels = [0, 1]
dataset = create_dataset(filenames, labels)
self.assertEqual(len(dataset), 2)

if __name__ == '__main__':
unittest.main()

0 comments on commit b4a73b2

Please sign in to comment.