feat: Implemented support for MultiMacenkoNormalizer across all backends #66

Open · wants to merge 6 commits into main
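This PR adds MultiMacenkoNormalizer support for the NumPy and TensorFlow backends, alongside the existing torch implementation. For context, a minimal usage sketch mirroring the tests in this PR (the image paths and the helper function below are placeholders, not part of the diff):

```python
import cv2
import numpy as np
import torchstain

def read_rgb_chw(path):
    """Load an image as RGB with channels first (paths are placeholders)."""
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    return np.moveaxis(img, -1, 0)

targets = [read_rgb_chw("target1.png"), read_rgb_chw("target2.png")]
source = read_rgb_chw("source.png")

# Fit on a list of target images, then normalize a source image.
# The same calls work with backend="torch" or backend="tensorflow",
# given tensors in the same channel-first layout (see the tests below).
normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend="numpy")
normalizer.fit(targets)
norm, H, E = normalizer.normalize(I=source, stains=True)
```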
2 changes: 1 addition & 1 deletion README.md
@@ -56,7 +56,7 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)
| Macenko | ✓ | ✓ | ✓ |
| Reinhard | ✓ | ✓ | ✓ |
| Modified Reinhard | ✓ | ✓ | ✓ |
-| Multi-target Macenko | ✗ | ✓ | ✗ |
+| Multi-target Macenko | ✓ | ✓ | ✓ |
| Macenko-Aug | ✓ | ✓ | ✓ |

## Backend comparison
38 changes: 38 additions & 0 deletions tests/test_tf.py
@@ -5,13 +5,15 @@
import tensorflow as tf
import numpy as np


def test_cov():
x = np.random.randn(10, 10)
cov_np = np.cov(x)
cov_t = torchstain.tf.utils.cov(x)

np.testing.assert_almost_equal(cov_np, cov_t.numpy())


def test_percentile():
x = np.random.randn(10, 10)
p = 20
@@ -20,6 +22,7 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)


def test_macenko_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
@@ -48,6 +51,7 @@ def test_macenko_tf():
# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)


def test_reinhard_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
@@ -75,3 +79,37 @@ def test_reinhard_tf():

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)


def test_multistain_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = lambda x: tf.convert_to_tensor(np.moveaxis(x, -1, 0).astype("float32")) # * 255
t_to_transform = T(to_transform)
target_transformed = T(target)

# move channel to first
target_numpy = np.moveaxis(target, -1, 0)
to_transform_numpy = np.moveaxis(to_transform, -1, 0)

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend='numpy')
normalizer.fit([target_numpy, target_numpy, target_numpy])

tf_normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend='tensorflow')
tf_normalizer.fit([target_transformed, target_transformed, target_transformed])

# transform
result_numpy, _, _ = normalizer.normalize(I=to_transform_numpy, stains=True)
result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32") / 255.
result_tf = result_tf.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)
41 changes: 41 additions & 0 deletions tests/test_torch.py
@@ -11,13 +11,15 @@
def setup_function(fn):
print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__)


def test_cov():
x = np.random.randn(10, 10)
cov_np = np.cov(x)
cov_t = torchstain.torch.utils.cov(torch.tensor(x))

np.testing.assert_almost_equal(cov_np, cov_t.numpy())


def test_percentile():
x = np.random.randn(10, 10)
p = 20
@@ -26,6 +28,7 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)


def test_macenko_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
@@ -57,6 +60,7 @@ def test_macenko_torch():
# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)


def test_multitarget_macenko_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
@@ -122,3 +126,40 @@ def test_reinhard_torch():

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)


def test_multistain_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])
t_to_transform = T(to_transform)
target_transformed = T(target)

# move channel to first
target_numpy = np.moveaxis(target, -1, 0)
to_transform_numpy = np.moveaxis(to_transform, -1, 0)

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend='numpy')
normalizer.fit([target_numpy, target_numpy, target_numpy])

torch_normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend='torch')
torch_normalizer.fit([target_transformed, target_transformed, target_transformed])

# transform
result_numpy, _, _ = normalizer.normalize(I=to_transform_numpy, stains=True)
result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32") / 255.
result_torch = result_torch.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)
6 changes: 4 additions & 2 deletions torchstain/base/normalizers/multitarget.py
@@ -1,10 +1,12 @@
 def MultiMacenkoNormalizer(backend="torch", **kwargs):
     if backend == "numpy":
-        raise NotImplementedError("MultiMacenkoNormalizer is not implemented for NumPy backend")
+        from torchstain.numpy.normalizers import NumpyMultiMacenkoNormalizer
+        return NumpyMultiMacenkoNormalizer(**kwargs)
     elif backend == "torch":
         from torchstain.torch.normalizers import TorchMultiMacenkoNormalizer
         return TorchMultiMacenkoNormalizer(**kwargs)
     elif backend == "tensorflow":
-        raise NotImplementedError("MultiMacenkoNormalizer is not implemented for TensorFlow backend")
+        from torchstain.tf.normalizers import TensorFlowMultiMacenkoNormalizer
+        return TensorFlowMultiMacenkoNormalizer(**kwargs)
     else:
         raise Exception(f"Unsupported backend {backend}")
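The dispatcher above resolves each backend string to its concrete class; a quick sketch (assuming all three backends are installed):

```python
import torchstain

# Each supported backend name returns its own implementation of the same interface;
# an unrecognized name raises an exception.
for backend in ("numpy", "torch", "tensorflow"):
    normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend=backend)
    print(backend, "->", type(normalizer).__name__)
```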
3 changes: 2 additions & 1 deletion torchstain/numpy/normalizers/__init__.py
@@ -1,2 +1,3 @@
from .macenko import NumpyMacenkoNormalizer
-from .reinhard import NumpyReinhardNormalizer
+from .reinhard import NumpyReinhardNormalizer
+from .multitarget import NumpyMultiMacenkoNormalizer
117 changes: 117 additions & 0 deletions torchstain/numpy/normalizers/multitarget.py
@@ -0,0 +1,117 @@
import numpy as np

class NumpyMultiMacenkoNormalizer:
def __init__(self, norm_mode="avg-post"):
self.norm_mode = norm_mode
self.HERef = np.array([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = np.array([1.9705, 1.0308])

def __convert_rgb2od(self, I, Io, beta):
I = np.transpose(I, (1, 2, 0))
OD = -np.log((I.reshape(-1, I.shape[-1]).astype(float) + 1) / Io)
ODhat = OD[~np.any(OD < beta, axis=1)]
return OD, ODhat

def __find_phi_bounds(self, ODhat, eigvecs, alpha):
That = np.dot(ODhat, eigvecs)
phi = np.arctan2(That[:, 1], That[:, 0])

minPhi = np.percentile(phi, alpha)
maxPhi = np.percentile(phi, 100 - alpha)

return minPhi, maxPhi

def __find_HE_from_bounds(self, eigvecs, minPhi, maxPhi):
vMin = np.dot(eigvecs, [np.cos(minPhi), np.sin(minPhi)]).reshape(-1, 1)
vMax = np.dot(eigvecs, [np.cos(maxPhi), np.sin(maxPhi)]).reshape(-1, 1)

HE = np.concatenate([vMin, vMax], axis=1) if vMin[0] > vMax[0] else np.concatenate([vMax, vMin], axis=1)
return HE

def __find_HE(self, ODhat, eigvecs, alpha):
minPhi, maxPhi = self.__find_phi_bounds(ODhat, eigvecs, alpha)
return self.__find_HE_from_bounds(eigvecs, minPhi, maxPhi)

def __find_concentration(self, OD, HE):
Y = OD.T
C, _, _, _ = np.linalg.lstsq(HE, Y, rcond=None)
return C

def __compute_matrices_single(self, I, Io, alpha, beta):
OD, ODhat = self.__convert_rgb2od(I, Io, beta)

cov_matrix = np.cov(ODhat.T)
eigvals, eigvecs = np.linalg.eigh(cov_matrix)
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
C = self.__find_concentration(OD, HE)
maxC = np.array([np.percentile(C[0, :], 99), np.percentile(C[1, :], 99)])

return HE, C, maxC

def fit(self, Is, Io=240, alpha=1, beta=0.15):
if self.norm_mode == "avg-post":
HEs, _, maxCs = zip(*[self.__compute_matrices_single(I, Io, alpha, beta) for I in Is])

self.HERef = np.mean(HEs, axis=0)
self.maxCRef = np.mean(maxCs, axis=0)
elif self.norm_mode == "concat":
ODs, ODhats = zip(*[self.__convert_rgb2od(I, Io, beta) for I in Is])
OD = np.vstack(ODs)
ODhat = np.vstack(ODhats)

cov_matrix = np.cov(ODhat.T)
eigvals, eigvecs = np.linalg.eigh(cov_matrix)
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
C = self.__find_concentration(OD, HE)
maxCs = np.array([np.percentile(C[0, :], 99), np.percentile(C[1, :], 99)])

self.HERef = HE
self.maxCRef = maxCs
elif self.norm_mode == "avg-pre":
ODs, ODhats = zip(*[self.__convert_rgb2od(I, Io, beta) for I in Is])

covs = [np.cov(ODhat.T) for ODhat in ODhats]
eigvecs = np.mean([np.linalg.eigh(cov)[1][:, [1, 2]] for cov in covs], axis=0)

OD = np.vstack(ODs)
ODhat = np.vstack(ODhats)

HE = self.__find_HE(ODhat, eigvecs, alpha)
C = self.__find_concentration(OD, HE)
maxCs = np.array([np.percentile(C[0, :], 99), np.percentile(C[1, :], 99)])

self.HERef = HE
self.maxCRef = maxCs
elif self.norm_mode in ["fixed-single", "stochastic-single"]:
self.HERef, _, self.maxCRef = self.__compute_matrices_single(Is[0], Io, alpha, beta)
else:
raise ValueError("Unknown norm mode")

def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
c, h, w = I.shape

HE, C, maxC = self.__compute_matrices_single(I, Io, alpha, beta)
C = (self.maxCRef / maxC).reshape(-1, 1) * C

Inorm = Io * np.exp(-np.dot(self.HERef, C))
Inorm[Inorm > 255] = 255
Inorm = np.transpose(Inorm, (1, 0)).reshape(h, w, c).astype(np.int32)

H, E = None, None

if stains:
H = Io * np.exp(-np.dot(self.HERef[:, 0].reshape(-1, 1), C[0, :].reshape(1, -1)))
H[H > 255] = 255
H = np.transpose(H, (1, 0)).reshape(h, w, c).astype(np.int32)

E = Io * np.exp(-np.dot(self.HERef[:, 1].reshape(-1, 1), C[1, :].reshape(1, -1)))
E[E > 255] = 255
E = np.transpose(E, (1, 0)).reshape(h, w, c).astype(np.int32)

return Inorm, H, E
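The norm_mode argument controls how the target list passed to fit is combined: "avg-post" averages the per-target HE matrices and 99th-percentile concentrations, "concat" pools the optical-density pixels of all targets before a single fit, "avg-pre" averages the per-target eigenvectors before projecting the pooled pixels, and "fixed-single"/"stochastic-single" fit on the first target only. A small sketch of using the NumPy class directly (the random arrays merely stand in for real channel-first H&E tiles):

```python
import numpy as np
from torchstain.numpy.normalizers import NumpyMultiMacenkoNormalizer

# Channel-first uint8 stand-ins for three target tiles and one source tile.
targets = [np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8) for _ in range(3)]
source = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)

# Pool the optical-density pixels of all targets before estimating the stain basis.
normalizer = NumpyMultiMacenkoNormalizer(norm_mode="concat")
normalizer.fit(targets)
norm, H, E = normalizer.normalize(I=source, stains=True)
```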
1 change: 1 addition & 0 deletions torchstain/tf/normalizers/__init__.py
@@ -1,2 +1,3 @@
from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer
from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer
from torchstain.tf.normalizers.multitarget import TensorFlowMultiMacenkoNormalizer