Skip to content

Commit

Permalink
chore: Update dataset paths in code files
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashguptas committed Jun 23, 2024
1 parent 2be4fe9 commit 9bb95d7
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 11 deletions.
2 changes: 0 additions & 2 deletions CAPTCHA Images/.gitignore

This file was deleted.

20 changes: 17 additions & 3 deletions combined_code.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,25 @@ def split_image(image):
return output


# class CaptchaImagesDataset(Dataset):
# def __init__(self, root, augment=False):
# super(CaptchaImagesDataset, self).__init__()
# self.root = root
# self.augment = augment

# self.image_list = []
# for ext in ('*.png', '*.jpg'):
# self.image_list.extend(glob.glob(os.path.join(root, ext)))

class CaptchaImagesDataset(Dataset):
def __init__(self, root, augment=False):
super(CaptchaImagesDataset, self).__init__()
self.root = root
self.augment = augment

self.image_list = []
for ext in ('*.png', '*.jpg'):
self.image_list.extend(glob.glob(os.path.join(root, ext)))
print(f"Found {len(self.image_list)} images in {root}")

def __len__(self):
return len(self.image_list)
Expand All @@ -54,8 +64,13 @@ class CaptchaImagesDataset(Dataset):
def get_loader(root, batch_size):
train_dataset = CaptchaImagesDataset(root + '/train', augment=True)
val_dataset = CaptchaImagesDataset(root + '/val', augment=False)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

return train_loader, val_loader


Expand Down Expand Up @@ -384,5 +399,4 @@ if __name__ == '__main__':
if ground_truth == decoded_text:
correct += 1

print('accuracy =', correct/len(image_list))

print('accuracy =', correct/len(image_list))
2 changes: 1 addition & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_loader(root, batch_size):


if __name__ == '__main__':
train, val = get_loader('data/CAPTCHA Images/', batch_size=2)
train, val = get_loader('/home/dev/dev_work_shrey/playing_around/data/CAPTCHA Images/', batch_size=2)
for image, labels in train:
print()
print()
2 changes: 1 addition & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
model.eval()

correct = 0.0
image_list = glob.glob('/app/data/CAPTCHA Images/test/*')
image_list = glob.glob('/home/dev/dev_work_shrey/playing_around/data/CAPTCHA Images/test/*')
for image in tqdm(image_list):
ground_truth = image.split('/')[-1].split('.')[0]
image = Image.open(image).convert('RGB')
Expand Down
2 changes: 1 addition & 1 deletion split_train_val_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ def split_train_val_test(root):


if __name__ == '__main__':
split_train_val_test('data/CAPTCHA Images')
split_train_val_test('/home/dev/dev_work_shrey/playing_around/data/CAPTCHA Images')

print()
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def train():
model = CRNN(vocab_size=vocab_size).to(device)
# model.load_state_dict(torch.load('output/weight.pth', map_location=device))

train_loader, val_loader = get_loader('/app/data/CAPTCHA Images', batch_size=batch_size)
train_loader, val_loader = get_loader('/home/dev/dev_work_shrey/playing_around/data/CAPTCHA Images', batch_size=batch_size)

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
Expand All @@ -77,7 +77,7 @@ def train():
print('-----------------------------------------')

if epoch == 0 or val_epoch_loss <= np.min(val_losses):
torch.save(model.state_dict(), '/app/output/weight.pth')
torch.save(model.state_dict(), '/home/dev/dev_work_shrey/playing_around/output/weight.pth')

train_losses.append(train_epoch_loss)
val_losses.append(val_epoch_loss)
Expand All @@ -90,4 +90,4 @@ def train():


if __name__ == "__main__":
train()
train()

0 comments on commit 9bb95d7

Please sign in to comment.