Skip to content

Commit

Permalink
Merge pull request #106 from sarvghotra/bug_fix_dist_sampler
Browse files Browse the repository at this point in the history
Bug fix in dist sampler that caused same data order in each epoch
  • Loading branch information
cszn authored Apr 21, 2022
2 parents 7e51c16 + fc2f79e commit f4573f3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
5 changes: 4 additions & 1 deletion main_train_drunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main(json_path='options/train_drunet.json'):
if opt['dist']:
init_dist('pytorch')
opt['rank'], opt['world_size'] = get_dist_info()

if opt['rank'] == 0:
util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

Expand Down Expand Up @@ -160,6 +160,9 @@ def main(json_path='options/train_drunet.json'):
'''

for epoch in range(1000000): # keep running
if opt['dist']:
train_sampler.set_epoch(epoch)

for i, train_data in enumerate(train_loader):

current_step += 1
Expand Down
3 changes: 3 additions & 0 deletions main_train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ def main(json_path='options/train_msrresnet_gan.json'):
'''

for epoch in range(1000000): # keep running
if opt['dist']:
train_sampler.set_epoch(epoch)

for i, train_data in enumerate(train_loader):

current_step += 1
Expand Down
5 changes: 4 additions & 1 deletion main_train_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def main(json_path='options/train_msrresnet_psnr.json'):
if opt['dist']:
init_dist('pytorch')
opt['rank'], opt['world_size'] = get_dist_info()

if opt['rank'] == 0:
util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

Expand Down Expand Up @@ -165,6 +165,9 @@ def main(json_path='options/train_msrresnet_psnr.json'):
'''

for epoch in range(1000000): # keep running
if opt['dist']:
train_sampler.set_epoch(epoch)

for i, train_data in enumerate(train_loader):

current_step += 1
Expand Down

0 comments on commit f4573f3

Please sign in to comment.