Skip to content

Commit

Permalink
Merge pull request #8 from arplaboratory/baseline
Browse files Browse the repository at this point in the history
fix name
  • Loading branch information
xjh19971 authored May 4, 2024
2 parents 10c3853 + 4e65810 commit 58c5576
Showing 1 changed file with 49 additions and 49 deletions.
98 changes: 49 additions & 49 deletions local_pipeline/datasets_4cor_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,49 +80,49 @@ def __init__(self, args, augment=False):
self.database_transform = base_transform
self.database_transform_ori = base_transform_ori

def rotate_transform(self, rotation, four_point_org, four_point_1, four_point_org_permute, four_point_1_permute):
def rotate_transform(self, rotation, four_point_org, four_point_1, four_point_org_augment, four_point_1_augment):
center_x_org = torch.tensor((self.args.resize_width - 1)/2)
center_x_1 = (four_point_1[0, 0, :] + four_point_1[0, 3, :])/2
four_point_org_permute[0, 0, 0] = (four_point_org[0, 0, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 0, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_permute[0, 0, 1] = (four_point_org[0, 0, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 0, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_org_permute[0, 1, 0] = (four_point_org[0, 1, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 1, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_permute[0, 1, 1] = (four_point_org[0, 1, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 1, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_org_permute[0, 2, 0] = (four_point_org[0, 2, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 2, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_permute[0, 2, 1] = (four_point_org[0, 2, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 2, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_org_permute[0, 3, 0] = (four_point_org[0, 3, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 3, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_permute[0, 3, 1] = (four_point_org[0, 3, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 3, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_1_permute[0, 0, 0] = (four_point_1[0, 0, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 0, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_permute[0, 0, 1] = (four_point_1[0, 0, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 0, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_1_permute[0, 1, 0] = (four_point_1[0, 1, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 1, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_permute[0, 1, 1] = (four_point_1[0, 1, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 1, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_1_permute[0, 2, 0] = (four_point_1[0, 2, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 2, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_permute[0, 2, 1] = (four_point_1[0, 2, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 2, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_1_permute[0, 3, 0] = (four_point_1[0, 3, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 3, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_permute[0, 3, 1] = (four_point_1[0, 3, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 3, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_org_augment[0, 0, 0] = (four_point_org[0, 0, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 0, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_augment[0, 0, 1] = (four_point_org[0, 0, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 0, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_org_augment[0, 1, 0] = (four_point_org[0, 1, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 1, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_augment[0, 1, 1] = (four_point_org[0, 1, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 1, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_org_augment[0, 2, 0] = (four_point_org[0, 2, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 2, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_augment[0, 2, 1] = (four_point_org[0, 2, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 2, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_org_augment[0, 3, 0] = (four_point_org[0, 3, 0] - center_x_org) * torch.cos(rotation) - (four_point_org[0, 3, 1] - center_x_org) * torch.sin(rotation) + center_x_org
four_point_org_augment[0, 3, 1] = (four_point_org[0, 3, 0] - center_x_org) * torch.sin(rotation) + (four_point_org[0, 3, 1] - center_x_org) * torch.cos(rotation) + center_x_org
four_point_1_augment[0, 0, 0] = (four_point_1[0, 0, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 0, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_augment[0, 0, 1] = (four_point_1[0, 0, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 0, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_1_augment[0, 1, 0] = (four_point_1[0, 1, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 1, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_augment[0, 1, 1] = (four_point_1[0, 1, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 1, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_1_augment[0, 2, 0] = (four_point_1[0, 2, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 2, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_augment[0, 2, 1] = (four_point_1[0, 2, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 2, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
four_point_1_augment[0, 3, 0] = (four_point_1[0, 3, 0] - center_x_1[0]) * torch.cos(rotation) - (four_point_1[0, 3, 1] - center_x_1[1]) * torch.sin(rotation) + center_x_1[0]
four_point_1_augment[0, 3, 1] = (four_point_1[0, 3, 0] - center_x_1[0]) * torch.sin(rotation) + (four_point_1[0, 3, 1] - center_x_1[1]) * torch.cos(rotation) + center_x_1[1]
# print("ori:", four_point_org[0, 0, 0], four_point_org[0, 0, 1], four_point_1[0, 0, 0], four_point_1[0, 0, 1])
# print("now:", four_point_org_permute[0, 0, 0], four_point_org_permute[0, 0, 1], four_point_1_permute[0, 0, 0], four_point_1_permute[0, 0, 1])
# print("now:", four_point_org_augment[0, 0, 0], four_point_org_augment[0, 0, 1], four_point_1_augment[0, 0, 0], four_point_1_augment[0, 0, 1])
# print("center:", center_x_1, four_point_1[0, 0, :], four_point_1[0, 3, :])
return four_point_org_permute, four_point_1_permute
return four_point_org_augment, four_point_1_augment

def resize_transform(self, scale_factor, beta, alpha, four_point_org_permute, four_point_1_permute):
def resize_transform(self, scale_factor, beta, alpha, four_point_org_augment, four_point_1_augment):
offset = self.args.resize_width * (1 - scale_factor) / 2
four_point_org_permute[0, 0, 0] += offset
four_point_org_permute[0, 0, 1] += offset
four_point_org_permute[0, 1, 0] -= offset
four_point_org_permute[0, 1, 1] += offset
four_point_org_permute[0, 2, 0] += offset
four_point_org_permute[0, 2, 1] -= offset
four_point_org_permute[0, 3, 0] -= offset
four_point_org_permute[0, 3, 1] -= offset
four_point_1_permute[0, 0, 0] += offset * beta / alpha
four_point_1_permute[0, 0, 1] += offset * beta / alpha
four_point_1_permute[0, 1, 0] -= offset * beta / alpha
four_point_1_permute[0, 1, 1] += offset * beta / alpha
four_point_1_permute[0, 2, 0] += offset * beta / alpha
four_point_1_permute[0, 2, 1] -= offset * beta / alpha
four_point_1_permute[0, 3, 0] -= offset * beta / alpha
four_point_1_permute[0, 3, 1] -= offset * beta / alpha
return four_point_org_permute, four_point_1_permute
four_point_org_augment[0, 0, 0] += offset
four_point_org_augment[0, 0, 1] += offset
four_point_org_augment[0, 1, 0] -= offset
four_point_org_augment[0, 1, 1] += offset
four_point_org_augment[0, 2, 0] += offset
four_point_org_augment[0, 2, 1] -= offset
four_point_org_augment[0, 3, 0] -= offset
four_point_org_augment[0, 3, 1] -= offset
four_point_1_augment[0, 0, 0] += offset * beta / alpha
four_point_1_augment[0, 0, 1] += offset * beta / alpha
four_point_1_augment[0, 1, 0] -= offset * beta / alpha
four_point_1_augment[0, 1, 1] += offset * beta / alpha
four_point_1_augment[0, 2, 0] += offset * beta / alpha
four_point_1_augment[0, 2, 1] -= offset * beta / alpha
four_point_1_augment[0, 3, 0] -= offset * beta / alpha
four_point_1_augment[0, 3, 1] -= offset * beta / alpha
return four_point_org_augment, four_point_1_augment

def __getitem__(self, query_PIL_image, database_PIL_image, query_utm, database_utm):
if hasattr(self, "rng") and self.rng is None:
Expand Down Expand Up @@ -183,48 +183,48 @@ def __getitem__(self, query_PIL_image, database_PIL_image, query_utm, database_u

if self.augnent:
#augnent
four_point_org_permute = four_point_org.clone()
four_point_1_permute = four_point_1.clone()
four_point_org_augment = four_point_org.clone()
four_point_1_augment = four_point_1.clone()
beta = 512/self.args.resize_width
if self.args.eval_model is None: # EVAL
permute_type_single = random.choice(self.permute_type)
if permute_type_single == "rotate":
rotation = torch.tensor(random.random() - 0.5) * 2 * self.args.rotate_max # on 256x256
four_point_org_permute, four_point_1_permute = self.rotate_transform(rotation, four_point_org, four_point_1, four_point_org_permute, four_point_1_permute)
four_point_org_augment, four_point_1_augment = self.rotate_transform(rotation, four_point_org, four_point_1, four_point_org_augment, four_point_1_augment)
elif permute_type_single == "resize":
scale_factor = 1 + (random.random() - 0.5) * 2 * self.args.resize_max # on 256x256
assert scale_factor > 0
four_point_org_permute, four_point_1_permute = self.resize_transform(scale_factor, beta, alpha, four_point_org_permute, four_point_1_permute)
four_point_org_augment, four_point_1_augment = self.resize_transform(scale_factor, beta, alpha, four_point_org_augment, four_point_1_augment)
elif permute_type_single == "perspective":
for p in range(4):
for xy in range(2):
t1 = random.randint(-self.args.perspective_max, self.args.perspective_max)
four_point_org_permute[0, p, xy] += t1 # original for 256
four_point_1_permute[0, p, xy] += t1 * beta / alpha # original for 256 then to 512 in 1536 scale then to 256 in 1536 scale
four_point_org_augment[0, p, xy] += t1 # original for 256
four_point_1_augment[0, p, xy] += t1 * beta / alpha # original for 256 then to 512 in 1536 scale then to 256 in 1536 scale
elif permute_type_single == "no":
pass
else:
raise NotImplementedError()
else:
if self.args.rotate_max!=0:
rotation = torch.tensor(self.rng.random() - 0.5) * 2 * self.args.rotate_max # on 256x256
four_point_org_permute, four_point_1_permute = self.rotate_transform(rotation, four_point_org, four_point_1, four_point_org_permute, four_point_1_permute)
four_point_org_augment, four_point_1_augment = self.rotate_transform(rotation, four_point_org, four_point_1, four_point_org_augment, four_point_1_augment)
elif self.args.resize_max!=0:
scale_factor = 1 + (self.rng.random() - 0.5) * 2 * self.args.resize_max # on 256x256
assert scale_factor > 0
four_point_org_permute, four_point_1_permute = self.resize_transform(scale_factor, beta, alpha, four_point_org_permute, four_point_1_permute)
four_point_org_augment, four_point_1_augment = self.resize_transform(scale_factor, beta, alpha, four_point_org_augment, four_point_1_augment)
elif self.args.perspective_max!=0:
for p in range(4):
for xy in range(2):
t1 = self.rng.integers(-self.args.perspective_max, self.args.perspective_max) # on 256x256
four_point_org_permute[0, p, xy] += t1 # original for 256
four_point_1_permute[0, p, xy] += t1 * beta / alpha # original for 256 then to 512 in 1536 scale then to 256 in 1536 scale
four_point_org_augment[0, p, xy] += t1 # original for 256
four_point_1_augment[0, p, xy] += t1 * beta / alpha # original for 256 then to 512 in 1536 scale then to 256 in 1536 scale
else:
raise NotImplementedError()
H = tgm.get_perspective_transform(four_point_org, four_point_org_permute)
H = tgm.get_perspective_transform(four_point_org, four_point_org_augment)
H_inverse = torch.inverse(H)
img1 = tgm.warp_perspective(img1.unsqueeze(0), H_inverse, (self.args.resize_width, self.args.resize_width)).squeeze(0)
four_point_1 = four_point_1_permute
four_point_1 = four_point_1_augment

H = tgm.get_perspective_transform(four_point_org, four_point_1)
H = H.squeeze()
Expand Down

0 comments on commit 58c5576

Please sign in to comment.