Commit

Bloods features added to training
ricshaw committed Jul 21, 2020
1 parent 91a79a5 commit e513d7f
Showing 9 changed files with 3,403 additions and 6,009 deletions.
3,475 changes: 0 additions & 3,475 deletions KCH_CXR.csv

This file was deleted.

4,568 changes: 2,117 additions & 2,451 deletions KCH_CXR_JPG.csv

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions confusion_matrix_plotter.py
@@ -0,0 +1,131 @@
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, precision_recall_curve, auc
import seaborn as sn
import os
import matplotlib.pyplot as plt
import matplotlib
#matplotlib.use('TkAgg')

outputs_saved = True
plot_AUCs = False

num_classes = 2
if outputs_saved:
    #experiment = 'preds-efficientnet-b3-bs32-512-tta-ranger.csv'
    experiment = 'ensemble.csv'
    experiments_output_file = os.path.join('/nfs/home/richard', experiment)
    experiments_output = pd.read_csv(experiments_output_file)

    # Separate network outputs from labels
    filenames = experiments_output.Filename
    df_labels = np.stack((experiments_output.Died, 1-experiments_output.Died), axis=1)
    df_preds = np.stack((experiments_output.Pred, 1-experiments_output.Pred), axis=1)
    print(df_labels.shape, df_preds.shape)
    print(df_labels)
    print(df_preds)

    # Confusion matrix calcs
    # Convert OHE Labels, outputs
    standard_labels = []
    standard_preds = []

    for i in range(len(df_labels)):
        standard_labels.append(np.argmax(df_labels[i]))
        standard_preds.append(np.argmax(df_preds[i]))

    # AUC calcs
    # Convert OHE Labels, outputs
    labels = []
    preds = []

    for i in range(len(df_labels)):
        labels.append(df_labels[i])
        preds.append(df_preds[i])

    labels = np.array(labels)
    preds = np.array(preds)

    # Calculate confusion matrix
    class_names = ['Died', 'Survived']
    #class_names = ['48H', '1 week -', '1 week +', 'Survived', 'micro']
    conf_mat = confusion_matrix(standard_labels, standard_preds)
    l = len(conf_mat[0])
    diags = [conf_mat[i][i] for i in range(l)]
    conf_mat = conf_mat / np.array(diags)[None, ...].T
    index = [i+'_lab' for i in class_names]
    columns = [i for i in class_names]
    print(index)
    print(columns)
    df_cm = pd.DataFrame(conf_mat, index=index, columns=columns)
    plt.figure(figsize=(10, 7))
    plt.title('Confusion matrix for experiment: ' + experiment)
    ax = sn.heatmap(df_cm, annot=True)
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom + 0.5, top - 0.5)
    plt.savefig('confusion.png', dpi=300)
    exit(0)

    # Also plot AUCs
    if plot_AUCs:
        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for classID in range(num_classes):
            fpr[classID], tpr[classID], _ = roc_curve(labels[:, classID], preds[:, classID])
            roc_auc[classID] = auc(fpr[classID], tpr[classID])

        # Compute micro-average ROC curve and ROC area
        fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), preds.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # Compute PR curve and PR area for each class
        precision_tot = dict()
        recall_tot = dict()
        pr_auc = dict()
        for classID in range(num_classes):
            precision_tot[classID], recall_tot[classID], _ = precision_recall_curve(labels[:, classID],
                                                                                    preds[:, classID])
            pr_auc[classID] = auc(recall_tot[classID], precision_tot[classID])

        # Compute micro-average precision-recall curve and PR area
        precision_tot["micro"], recall_tot["micro"], _ = precision_recall_curve(labels.ravel(), preds.ravel())
        pr_auc["micro"] = auc(recall_tot["micro"], precision_tot["micro"])
        no_skill = len(labels[labels == 1]) / len(labels)

        colors = ['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'red']
        # Plot ROC-AUC for different classes:
        plt.figure()
        plt.axis('square')
        for classID, key in enumerate(fpr.keys()):
            lw = 2
            plt.plot(fpr[key], tpr[key], color=colors[classID],  # 'darkorange',
                     lw=lw, label=f'{class_names[classID]} ROC curve (area = {roc_auc[key]: .2f})')
        plt.title(f'Class ROC-AUC for ALL classes: {experiment}', fontsize=18)  # 'somedir' was undefined; the experiment name is assumed here
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=16)
        plt.ylabel('True Positive Rate', fontsize=16)
        plt.legend(loc="lower right")

        plt.figure()
        plt.axis('square')
        for classID, key in enumerate(precision_tot.keys()):
            lw = 2
            plt.plot(recall_tot[key], precision_tot[key], color=colors[classID],  # color='darkblue',
                     lw=lw, label=f'{class_names[classID]} PR curve (area = {pr_auc[key]: .2f})')
        plt.title(f'Class PR-AUC for ALL classes: {experiment}', fontsize=18)  # likewise, 'somedir' was undefined
        # plt.plot([0, 1], [0, 0], lw=lw, linestyle='--', label='No Skill')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall', fontsize=16)
        plt.ylabel('Precision', fontsize=16)
        plt.legend(loc="lower right")

else:
    # Need to load model and get outputs
    print('Not supported right now')
plt.show()
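
Note on the confusion-matrix normalisation above: the script divides each row by its diagonal entry (diags), so the diagonal becomes 1 and off-diagonal cells are expressed relative to the number of correct predictions for that class. The more common per-class view normalises each row by its row sum (per-class recall). A minimal sketch of the difference, not part of the commit, using made-up toy labels:

import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 1, 1, 1, 0]  # toy data, illustrative only
y_pred = [0, 1, 1, 1, 0, 0]

cm = confusion_matrix(y_true, y_pred)
row_norm = cm / cm.sum(axis=1, keepdims=True)  # rows sum to 1 (per-class recall)
diag_norm = cm / np.diag(cm)[:, None]          # what the script does: diagonal == 1
print(cm)
print(row_norm)
print(diag_norm)

Recent scikit-learn versions also accept confusion_matrix(y_true, y_pred, normalize='true') for the row-sum form.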
146 changes: 100 additions & 46 deletions dicom_to_jpg.py
@@ -8,49 +8,100 @@

print(pd.__version__)

labels = pd.read_csv('KCH_CXR_labels.csv')
labels = labels.sort_values('AccessionID')
#labels = pd.read_csv('KCH_CXR_labels.csv')
labels = pd.read_csv('cxr_news2_pseudonymised.csv')
#labels = labels.sort_values('AccessionID')
print('Labels', labels.shape)

labels = labels.drop_duplicates(subset=['AccessionID'], ignore_index=True)
print('Unique patients', labels.shape)
#labels = labels.drop_duplicates(subset=['Accession'], ignore_index=True)
#print('Unique Accession', labels.shape)
#exit(0)

PATH = '/nfs/project/covid/CXR/KCH_CXR'
save_path = '/nfs/project/richard/COVID/KCH_CXR_JPG2'

#csv = [f for f in Path(PATH).rglob('*.dcm')]
#print('Dicoms', len(csv))
#exit(0)
save_path = '/nfs/project/covid/CXR/KCH_CXR_JPG'

csv = [f for f in Path(PATH).rglob('*.dcm')]
print('Dicoms', len(csv))
Accession = []
CXR_datetime = []
Pixel_data = []
acc_count = 0
pix_count = 0
combined_count = 0
for f in csv:
    a = False
    p = False
    print(f)
    ds = pydicom.dcmread(f)
    acc = ds.AccessionNumber
    if labels.Accession.str.contains(acc).any():
        acc_count += 1
        a = True
        Accession.append(acc)
        print(acc)
    if "AcquisitionDateTime" in ds:
        datetime = ds.AcquisitionDateTime.split('.')[0]
    elif "ContributionDateTime" in ds:
        datetime = ds.ContributionDateTime.split('.')[0]
    name = acc
    fname = str(name) + '_' + str(datetime)
    datetime = pd.to_datetime(datetime, format='%Y%m%d%H%M%S')
    CXR_datetime.append(datetime)
    try:
        img = ds.pixel_array.astype(np.float32)
        img -= img.min()
        img /= img.max()
        img = np.uint8(255.0*img)
        print(img.shape)
        save_name = os.path.join(save_path, (fname + '.jpg'))
        print(save_name)
        cv2.imwrite(save_name, img)
        Pixel_data.append('Y')
        pix_count += 1
        p = True
    except:
        Pixel_data.append('N')
    if a and p:
        combined_count += 1
print('Accession', acc_count)
print('Pixel data', pix_count)
print('Combined', combined_count)
#df = pd.DataFrame({'Accession':Accession, 'CXR_datetime':CXR_datetime, 'Pixel_data':Pixel_data})
#print(df)
#df.to_csv('all_data.csv', index=False)
exit(0)

files = 0
count =0
for i, name in enumerate(labels.AccessionID):
for i, name in enumerate(labels.Accession):
    #print(i, name)
    img_dir = os.path.join(PATH, name)
    tmp = os.path.exists(img_dir)
    if not tmp:
        print('Cant find', name, tmp)
        print('Cant find', img_dir)
    else:
        print('Found', img_dir)
        count += 1
        csv = [f for f in Path(img_dir).rglob('*.dcm')]
        files += len(csv)
print('Matching files', count)
print('Matching dicoms', files)
#exit(0)

exit(0)

PatientID = []
Filename = []
AccessionID = []
Examination_Title = []
Age = []
Gender = []
SymptomOnset_DTM = []
Death_DTM = []
Died = []
Accession = []
#Examination_Title = []
#Age = []
#Gender = []
#SymptomOnset_DTM = []
#Death_DTM = []
#Died = []
count=0

for i, name in enumerate(labels.AccessionID):
    print(i, name)
for i, name in enumerate(labels.Accession):
    pid = labels.patient_pseudo_id[i]
    print(i, name, pid)
    img_dir = os.path.join(PATH, name)
    files = [f for f in Path(img_dir).rglob('*.dcm')]
    #print(files)
@@ -63,6 +114,7 @@
        #print(os.fspath(f.absolute()))
        ds = pydicom.dcmread(f)
        #print(ds)
        #exit(0)

        if "AcquisitionDateTime" in ds:
            datetime = ds.AcquisitionDateTime.split('.')[0]
@@ -75,13 +127,13 @@
        #print(year, month, day, time)
        fname = name + '_' + datetime

        acc = labels.AccessionID[i]
        ext = labels.Examination_Title[i]
        age = labels.Age[i]
        gen = labels.Gender[i]
        sym = labels.SymptomOnset_DTM[i]
        dtm = labels.Death_DTM[i]
        ddd = labels.Died[i]
        acc = labels.Accession[i]
        #ext = labels.Examination_Title[i]
        #age = labels.Age[i]
        #gen = labels.Gender[i]
        #sym = labels.SymptomOnset_DTM[i]
        #dtm = labels.Death_DTM[i]
        #ddd = labels.Died[i]

        try:
            img = ds.pixel_array.astype(np.float32)
@@ -95,32 +147,34 @@
            print(save_name)
            cv2.imwrite(save_name, img)

            PatientID.append(pid)
            Filename.append(fname)
            AccessionID.append(acc)
            Examination_Title.append(ext)
            Age.append(age)
            Gender.append(gen)
            SymptomOnset_DTM.append(sym)
            Death_DTM.append(dtm)
            Died.append(ddd)
            print(len(Filename), len(AccessionID), len(Died))
            Accession.append(acc)
            #Examination_Title.append(ext)
            #Age.append(age)
            #Gender.append(gen)
            #SymptomOnset_DTM.append(sym)
            #Death_DTM.append(dtm)
            #Died.append(ddd)
            #print(len(Filename), len(AccessionID), len(Died))
        except:
            print('Cannot load image')

print('Total', count)

print(len(Filename), len(AccessionID), len(Died))
#print(len(Filename), len(AccessionID), len(Died))

df = pd.DataFrame({'Filename':Filename,
                   'AccessionID':AccessionID,
                   'Examination_Title':Examination_Title,
                   'Age':Age,
                   'Gender':Gender,
                   'SymptomOnset_DTM':SymptomOnset_DTM,
                   'Death_DTM':Death_DTM,
                   'Died':Died
                   'PatientID':PatientID,
                   'Accession':Accession,
                   #'Examination_Title':Examination_Title,
                   #'Age':Age,
                   #'Gender':Gender,
                   #'SymptomOnset_DTM':SymptomOnset_DTM,
                   #'Death_DTM':Death_DTM,
                   #'Died':Died
                   })
df.to_csv('KCH_CXR_JPG2.csv')
df.to_csv('KCH_CXR_JPG.csv')

exit(0)

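A side note on the pixel conversion in dicom_to_jpg.py above: after img -= img.min(), a constant-valued DICOM leaves img.max() at 0, so img /= img.max() produces NaNs with only a runtime warning rather than an exception, and the bare except will not flag it. A minimal guarded version of the same min-max scaling, as a sketch only (the helper name and paths are illustrative, not from the commit):

import numpy as np
import cv2
import pydicom

def dicom_to_uint8(dcm_path):
    # Read a DICOM and min-max scale its pixel data to 8-bit, as the script does,
    # but return a zero image instead of NaNs when the frame is constant.
    ds = pydicom.dcmread(dcm_path)
    img = ds.pixel_array.astype(np.float32)
    img -= img.min()
    rng = img.max()
    if rng == 0:
        return np.zeros(img.shape, dtype=np.uint8)
    return np.uint8(255.0 * img / rng)

# Example (placeholder paths):
# cv2.imwrite('example.jpg', dicom_to_uint8('example.dcm'))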
