pl_transporter.py
import os
import sys
import json
import time
import socket
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from skimage.transform import radon, iradon, rescale

import torch
import torch.nn.functional as F  # used by VQVAEPerceptualLoss.forward
import torch.utils.data
from torch.utils.data import DataLoader
import torchvision
import torchvision.io
from torchvision import transforms
import av
import pytorch_lightning as pl

# Monogenic-signal / log-Gabor helpers used by analyticEstimator.
from phasepack.tools import rayleighmode as _rayleighmode
from phasepack.tools import lowpassfilter as _lowpassfilter
from phasepack.filtergrid import filtergrid
# fft2/ifft2 use the faster pyfftw implementations when that module is available.
from phasepack.tools import fft2, ifft2

from torchradon import *

# Project-local pairing dataset/sampler, required by _get_data_loader below.
from data import Dataset, Sampler
#from generate_lus_data_wrist_radon import *
import transporter_orig
import transporter
import utils
def normalise(img):
    # Min-max normalisation; the small constant guards against division by zero.
    return (img - img.min()) / (img.max() - img.min() + 0.001)
def integrated_backscatter_energy(img):
    # img: single-channel numpy image. The cumulative squared intensity down
    # each column approximates the integrated backscatter energy along the beam.
    return np.cumsum(img ** 2, axis=0)
def indices(i, rows):
    # Offsets 0..rows-i, i.e. distances from row i-1 down to the image bottom.
    return np.arange(rows - i + 1, dtype=np.float64)
def shadow(img):
    # Gaussian-weighted shadow map: each pixel becomes the weighted mean of the
    # pixels below it in the same column, weighted by a half-Gaussian of
    # standard deviation rows/4.
    rows, cols = img.shape
    stdImg = round(rows / 4)
    sh = np.zeros_like(img)
    for j in range(cols):
        for i in range(rows):
            gaussWin = np.exp(-(indices(i + 1, rows) ** 2) / (2 * stdImg ** 2))
            sh[i, j] = np.sum(img[i:rows, j] * gaussWin) / np.sum(gaussWin)
    return sh
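# A minimal usage sketch for the shadow map (the input array here is a
# hypothetical placeholder). Note that the double loop makes this
# O(rows^2 * cols), so it is slow on full-size frames:
#
#     frame = np.random.rand(64, 64)
#     sh = shadow(frame)  # same shape as frame; Gaussian-weighted column means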
def analyticEstimator(img, nscale=5, minWaveLength=10, mult=2.1, sigmaOnf=0.55,
                      k=2., polarity=0, noiseMethod=-1):
    # Monogenic-signal estimator (after Kovesi's phase symmetry): returns
    # local phase (LP), feature symmetry (FS) and local energy (LE) maps.
    if img.dtype not in ['float32', 'float64']:
        img = np.float64(img)
        imgdtype = 'float64'
    else:
        imgdtype = img.dtype
    if img.ndim == 3:
        img = img.mean(2)
    rows, cols = img.shape

    epsilon = 1E-4  # used to prevent division by zero
    IM = fft2(img)  # Fourier transformed image
    zeromat = np.zeros((rows, cols), dtype=imgdtype)
    # Matrix for accumulating weighted phase congruency values (energy).
    totalEnergy = zeromat.copy()
    # Matrix for accumulating filter response amplitude values.
    sumAn = zeromat.copy()
    radius, u1, u2 = filtergrid(rows, cols)
    # Get rid of the 0 radius value at the 0 frequency point (at the top-left
    # corner after fftshift) so that taking the log of the radius will not
    # cause trouble.
    radius[0, 0] = 1.
    # Riesz-transform (monogenic) filter in the frequency domain.
    H = (1j * u1 - u2) / radius
    lp = _lowpassfilter([rows, cols], .4, 10)  # Radius .4, 'sharpness' 10
    logGaborDenom = 2. * np.log(sigmaOnf) ** 2.

    for ss in range(nscale):
        wavelength = minWaveLength * mult ** ss
        fo = 1. / wavelength  # Centre frequency of filter
        logRadOverFo = np.log(radius / fo)
        logGabor = np.exp(-(logRadOverFo * logRadOverFo) / logGaborDenom)
        logGabor *= lp       # Apply the low-pass filter
        logGabor[0, 0] = 0.  # Undo the radius fudge

        IMF = IM * logGabor      # Frequency bandpassed image
        f = np.real(ifft2(IMF))  # Spatially bandpassed image
        # Bandpassed monogenic filtering: the real part of h holds the
        # convolution result with h1, the imaginary part that with h2.
        h = ifft2(IMF * H)
        # Squared amplitude of the h1 and h2 filter responses.
        hAmp2 = h.real * h.real + h.imag * h.imag
        # Magnitude of energy.
        sumAn += np.sqrt(f * f + hAmp2)

        # At the smallest scale, estimate the noise characteristics from the
        # distribution of the filter amplitude responses stored in sumAn.
        # tau is the Rayleigh parameter describing that distribution.
        if ss == 0:
            if noiseMethod == -1:
                # Use the median to estimate the noise statistics.
                tau = np.median(sumAn.flatten()) / np.sqrt(np.log(4))
            elif noiseMethod == -2:
                # Use the mode to estimate the noise statistics.
                tau = _rayleighmode(sumAn.flatten())

        # Accumulate the phase symmetry measure.
        if polarity == 0:     # look for both 'white' and 'black' spots
            totalEnergy += np.abs(f) - np.sqrt(hAmp2)
        elif polarity == 1:   # just look for 'white' spots
            totalEnergy += f - np.sqrt(hAmp2)
        elif polarity == -1:  # just look for 'black' spots
            totalEnergy += -f - np.sqrt(hAmp2)

    if noiseMethod >= 0:  # fixed noise threshold supplied by the caller
        T = noiseMethod
    else:
        # Sum tau over all scales (a geometric series in 1/mult).
        totalTau = tau * (1. - (1. / mult) ** nscale) / (1. - (1. / mult))
        # Calculate the mean and std dev from tau using the fixed relationship
        # between these parameters and tau. See
        # <http://mathworld.wolfram.com/RayleighDistribution.html>
        EstNoiseEnergyMean = totalTau * np.sqrt(np.pi / 2.)
        EstNoiseEnergySigma = totalTau * np.sqrt((4 - np.pi) / 2.)
        # Noise threshold, must be >= epsilon.
        T = np.maximum(EstNoiseEnergyMean + k * EstNoiseEnergySigma, epsilon)

    phaseSym = np.maximum(totalEnergy - T, 0)
    phaseSym /= sumAn + epsilon

    # LP and LE are built from the responses at the coarsest scale (the last
    # loop iteration); FS is the accumulated, noise-thresholded symmetry.
    LP = 1 - np.arctan2(np.sqrt(hAmp2), f)
    FS = phaseSym
    LE = hAmp2 + f * f
    return LP, FS, LE
def bone_prob_map(img, minwl=10):
    # Combine local phase, feature symmetry and (inverted) integrated
    # backscatter into a bone probability map, then suppress responses below
    # 1.5x the mean of the positive values.
    ibs = normalise(integrated_backscatter_energy(img))
    LP, FS, LE = analyticEstimator(normalise(img) ** 4, minWaveLength=minwl)
    final = normalise(normalise(LP) * normalise(FS) * (1 - ibs))
    meanFinal = (final * (final > 0)).mean()
    final = final * (final > 1.5 * meanFinal)
    return final
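# A minimal usage sketch of the bone-probability pipeline above; 'frame.png'
# is a hypothetical single-channel ultrasound frame:
#
#     frame = np.asarray(Image.open('frame.png').convert('L'), dtype=np.float64)
#     prob = bone_prob_map(frame, minwl=10)
#     plt.imshow(prob, cmap='gray'); plt.show()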
"""
Assume that this class generates pairs of adjacents frames (not necessarily consecutive,
depending on 'sample_rate' variable) of US video sequences
(with similar visual qualities, due to them being from the same video as well as same jittering applied to both........)
"""
print("Check 1")
args = utils.ConfigDict({})
args.metric = 'mse'

def get_config():
    config = utils.ConfigDict({})
    # Both set to 10 by default.
    config.image_channels = 10
    config.k = 10  # number of keypoints
    config.htmaplam = 0.1  # std of the keypoint heatmaps
    return config
def _get_model_orig(config):
    feature_encoder = transporter_orig.FeatureEncoder(config.image_channels)
    pose_regressor = transporter_orig.PoseRegressor(config.image_channels, config.k)
    refine_net = transporter_orig.RefineNet(config.image_channels)
    return transporter_orig.Transporter(feature_encoder, pose_regressor,
                                        refine_net, std=config.htmaplam)

def _get_model(config):
    feature_encoder = transporter.FeatureEncoder(config.image_channels)
    pose_regressor = transporter.PoseRegressor(config.image_channels, config.k)
    refine_net = transporter.RefineNet(config.image_channels)
    return transporter.Transporter(feature_encoder, pose_regressor,
                                   refine_net, std=config.htmaplam)
def _get_data_loader(config):
    # Dataset and Sampler are the project-local pairing classes imported from
    # the `data` module above, not torch.utils.data.Dataset.
    transform = transforms.ToTensor()
    dataset = Dataset(config.dataset_root, transform=transform)
    sampler = Sampler(dataset)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=config.batch_size, sampler=sampler,
        pin_memory=True, num_workers=4)
    return loader
class VQVAEPerceptualLoss(torch.nn.Module):
    # Perceptual loss: MSE between intermediate activations of a pre-trained
    # VQ-VAE encoder, accumulated block by block.
    def __init__(self, vqvae_path='VQVAE_unnorm_trained.pth'):
        super(VQVAEPerceptualLoss, self).__init__()
        encoder = torch.load(vqvae_path)._encoder
        encoder.eval()
        # Disable the in-place ReLUs inside the residual stack so that the
        # intermediate activations are preserved for autograd.
        encoder._residual_stack._layers[0]._block[0].inplace = False
        encoder._residual_stack._layers[0]._block[2].inplace = False
        encoder._residual_stack._layers[1]._block[0].inplace = False
        encoder._residual_stack._layers[1]._block[2].inplace = False
        blocks = []
        for module_name, module in encoder.named_modules():
            if module_name == '':
                continue
            # Skip the residual-stack containers; keep only their leaf blocks.
            if 'residual_stack' in module_name and 'block.' not in module_name:
                continue
            blocks.append(module)
        self.blocks = torch.nn.ModuleList(blocks)

    def forward(self, input, target):  # both (batch_size, 10, 256, 256)
        # Treat each channel as an independent single-channel image.
        x = input.view(-1, 256, 256).unsqueeze(1)
        x_ = target.view(-1, 256, 256).unsqueeze(1)
        loss = 0.0
        for block in self.blocks:
            x = block(x)
            x_ = block(x_)
            loss = loss + F.mse_loss(x, x_)
        return loss
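# Hypothetical usage sketch, assuming 'VQVAE_unnorm_trained.pth' was saved as
# a full model object exposing an `_encoder` attribute:
#
#     perc = VQVAEPerceptualLoss('VQVAE_unnorm_trained.pth')
#     loss = perc(pred, target)  # both of shape (batch_size, 10, 256, 256)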
class plTransporter(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = _get_model(config)
        self.model.train()
        if args.metric == 'mse':
            self.metric = torch.nn.MSELoss()
        elif args.metric == 'perc':
            self.metric = VQVAEPerceptualLoss(args.vq_path)
        print("Initial Hlam weights are:", self.model.hlam_weights)

    def forward(self, x1, x2):
        return self.model(x1, x2)
class plTransporter_orig(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = _get_model_orig(config)
        self.model.train()
        if args.metric == 'mse':
            self.metric = torch.nn.MSELoss()
        elif args.metric == 'perc':
            self.metric = VQVAEPerceptualLoss(args.vq_path)

    def forward(self, x1, x2):
        return self.model(x1, x2)
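# A minimal end-to-end sketch (shapes assumed from the default config; the
# random tensors are placeholders for a real pair of adjacent frames):
#
#     config = get_config()
#     module = plTransporter(config)
#     x1 = torch.randn(2, config.image_channels, 256, 256)
#     x2 = torch.randn(2, config.image_channels, 256, 256)
#     out = module(x1, x2)  # Transporter reconstruction of x2 from x1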