Skip to content

Commit

Permalink
Merge pull request #4 from andreped/gpu-support
Browse files Browse the repository at this point in the history
added GPU support
  • Loading branch information
carloalbertobarbano authored Sep 12, 2021
2 parents 0296439 + ef59f58 commit efa86d9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
12 changes: 4 additions & 8 deletions torchstain/normalizers/torch_macenko_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,15 @@ def __find_HE(self, ODhat, eigvecs, alpha):
That = torch.matmul(ODhat, eigvecs)
phi = torch.atan2(That[:, 1], That[:, 0])

minPhi = torch.tensor(percentile(phi, alpha))
maxPhi = torch.tensor(percentile(phi, 100 - alpha))
minPhi = percentile(phi, alpha)
maxPhi = percentile(phi, 100 - alpha)

vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi))).T).unsqueeze(1)
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi))).T).unsqueeze(1)

# a heuristic to make the vector corresponding to hematoxylin first and the
# one corresponding to eosin second
if vMin[0] > vMax[0]:
HE = torch.cat((vMin, vMax), dim=1)

else:
HE = torch.cat((vMax, vMin), dim=1)
HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1))

return HE

Expand All @@ -66,7 +62,7 @@ def __compute_matrices(self, I, Io, alpha, beta):
HE = self.__find_HE(ODhat, eigvecs, alpha)

C = self.__find_concentration(OD, HE)
maxC = torch.tensor([percentile(C[0, :], 99), percentile(C[1, :], 99)])
maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])

return HE, C, maxC

Expand Down
2 changes: 1 addition & 1 deletion torchstain/utils/percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def percentile(t: torch.tensor, q: float) -> Union[int, float]:
# indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
# so that ``round()`` returns an integer, even if q is a np.float32.
k = 1 + round(.01 * float(q) * (t.numel() - 1))
result = t.view(-1).kthvalue(k).values.item()
result = t.view(-1).kthvalue(k).values
return result

0 comments on commit efa86d9

Please sign in to comment.