"""
script adapted from https://github.com/jbdel/vilmedic/blob/main/vilmedic/blocks/scorers/CheXbert/chexbert.py
"""
import torch
import os
import logging
import torch.nn as nn
import pandas as pd
from collections import OrderedDict
from transformers import BertTokenizer
from transformers import BertModel, AutoModel, AutoConfig
from sklearn.metrics import classification_report, accuracy_score
from sklearn.metrics._classification import _check_targets
from huggingface_hub import hf_hub_download
import numpy as np
from sklearn.utils.sparsefuncs import count_nonzero
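# Note: _check_targets is a private sklearn helper; together with count_nonzero it is
# used below to compute per-example (non-averaged) multilabel exact-match accuracy.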
def generate_attention_masks(batch, source_lengths):
"""Generate masks for padded batches to avoid self-attention over pad tokens
@param batch (Tensor): tensor of token indices of shape (batch_size, max_len)
where max_len is length of longest sequence in the batch
@param source_lengths (List[Int]): List of actual lengths for each of the
sequences in the batch
    @returns masks (Tensor): Tensor of masks of shape (batch_size, max_len), moved to the GPU
"""
masks = torch.ones(batch.size(0), batch.size(1), dtype=torch.float)
for idx, src_len in enumerate(source_lengths):
masks[idx, src_len:] = 0
return masks.cuda()
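# Illustrative example (not part of the original script): for a padded batch of shape
# (2, 4) with source_lengths [2, 4], generate_attention_masks returns the float mask
# [[1, 1, 0, 0], [1, 1, 1, 1]] on the GPU.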
class bert_labeler(nn.Module):
def __init__(self, p=0.1, clinical=False, freeze_embeddings=False, pretrain_path=None, inference=False, **kwargs):
""" Init the labeler module
        @param p (float): p to use for dropout in the linear heads, 0.1 by default is consistent with
transformers.BertForSequenceClassification
@param clinical (boolean): True if Bio_Clinical BERT desired, False otherwise. Ignored if
pretrain_path is not None
@param freeze_embeddings (boolean): true to freeze bert embeddings during training
        @param pretrain_path (string): path to load checkpoint from
        @param inference (boolean): true to build bert-base-uncased from its config only (randomly
            initialized; the actual weights are expected to be loaded from a checkpoint afterwards)
"""
super(bert_labeler, self).__init__()
if pretrain_path is not None:
self.bert = BertModel.from_pretrained(pretrain_path)
elif clinical:
self.bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
elif inference:
config = AutoConfig.from_pretrained('bert-base-uncased')
self.bert = AutoModel.from_config(config)
else:
self.bert = BertModel.from_pretrained('bert-base-uncased')
if freeze_embeddings:
for param in self.bert.embeddings.parameters():
param.requires_grad = False
self.dropout = nn.Dropout(p)
# size of the output of transformer's last layer
hidden_size = self.bert.pooler.dense.in_features
# classes: present, absent, unknown, blank for 12 conditions + support devices
self.linear_heads = nn.ModuleList([nn.Linear(hidden_size, 4, bias=True) for _ in range(13)])
# classes: yes, no for the 'no finding' observation
self.linear_heads.append(nn.Linear(hidden_size, 2, bias=True))
def forward(self, source_padded, attention_mask):
""" Forward pass of the labeler
@param source_padded (torch.LongTensor): Tensor of word indices with padding, shape (batch_size, max_len)
@param attention_mask (torch.Tensor): Mask to avoid attention on padding tokens, shape (batch_size, max_len)
@returns out (List[torch.Tensor])): A list of size 14 containing tensors. The first 13 have shape
(batch_size, 4) and the last has shape (batch_size, 2)
"""
# shape (batch_size, max_len, hidden_size)
final_hidden = self.bert(source_padded, attention_mask=attention_mask)[0]
# shape (batch_size, hidden_size)
cls_hidden = final_hidden[:, 0, :].squeeze(dim=1)
cls_hidden = self.dropout(cls_hidden)
out = []
for i in range(14):
out.append(self.linear_heads[i](cls_hidden))
return out
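# Usage sketch (illustrative, not part of the original script): the labeler maps a
# batch of token ids to 14 per-condition logit tensors.
#
#   model = bert_labeler(inference=True)
#   ids = torch.randint(0, 30522, (2, 16))          # dummy token ids, batch of 2
#   mask = torch.ones(2, 16, dtype=torch.float)     # no padding in this toy batch
#   logits = model(ids, mask)                       # list of 14 tensors
#   # logits[0].shape == (2, 4), ..., logits[13].shape == (2, 2)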
def tokenize(impressions, tokenizer):
imp = impressions.str.strip()
imp = imp.replace('\n', ' ', regex=True)
    imp = imp.replace(r'\s+', ' ', regex=True)
impressions = imp.str.strip()
new_impressions = []
    for i in range(impressions.shape[0]):
tokenized_imp = tokenizer.tokenize(impressions.iloc[i])
if tokenized_imp: # not an empty report
res = tokenizer.encode_plus(tokenized_imp)['input_ids']
if len(res) > 512: # length exceeds maximum size
# print("report length bigger than 512")
res = res[:511] + [tokenizer.sep_token_id]
new_impressions.append(res)
else: # an empty report
new_impressions.append([tokenizer.cls_token_id, tokenizer.sep_token_id])
return new_impressions
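# Usage sketch (illustrative, not part of the original script):
#
#   tok = BertTokenizer.from_pretrained('bert-base-uncased')
#   tokenize(pd.Series(['No acute cardiopulmonary process.']), tok)
#   # -> [[101, ..., 102]]  one [CLS] ... [SEP] id list per report, truncated to 512 tokens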
class CheXbert(nn.Module):
def __init__(self, refs_filename=None, hyps_filename=None, **kwargs):
super(CheXbert, self).__init__()
self.refs_filename = refs_filename
self.hyps_filename = hyps_filename
# Model and tok
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
self.model = bert_labeler(inference=True)
        # Download the pretrained CheXbert checkpoint from the Hugging Face Hub and load its weights
state_dict = torch.load(hf_hub_download(repo_id='StanfordAIMI/RRG_scorers', filename="chexbert.pth"))['model_state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k.replace('module.', '') # remove `module.`
new_state_dict[name] = v
# Load params
self.model.load_state_dict(new_state_dict, strict=False)
self.model = self.model.cuda().eval()
for name, param in self.model.named_parameters():
param.requires_grad = False
        # The 14 CheXbert conditions (label classes)
self.target_names = [
"Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion", "Edema",
"Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion", "Pleural Other",
"Fracture", "Support Devices", "No Finding"]
self.target_names_5 = ["Cardiomegaly", "Edema", "Consolidation", "Atelectasis", "Pleural Effusion"]
self.target_names_5_index = np.where(np.isin(self.target_names, self.target_names_5))[0]
def get_label(self, report, mode="rrg"):
impressions = pd.Series([report])
out = tokenize(impressions, self.tokenizer)
batch = torch.LongTensor([o for o in out])
src_len = [b.shape[0] for b in batch]
attn_mask = generate_attention_masks(batch, src_len)
out = self.model(batch.cuda(), attn_mask)
out = [out[j].argmax(dim=1).item() for j in range(len(out))]
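        # Each head's argmax follows the CheXbert class convention:
        # 0 = blank (not mentioned), 1 = positive, 2 = negative, 3 = uncertain.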
v = []
if mode == "rrg":
for c in out:
if c == 0:
v.append('')
if c == 3:
v.append(1)
if c == 2:
v.append(0)
if c == 1:
v.append(1)
v = [1 if (isinstance(l, int) and l > 0) else 0 for l in v]
elif mode == "classification":
# https://github.com/stanfordmlgroup/CheXbert/blob/master/src/label.py#L124
for c in out:
if c == 0:
v.append('')
if c == 3:
v.append(-1)
if c == 2:
v.append(0)
if c == 1:
v.append(1)
else:
raise NotImplementedError(mode)
return v
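    # Usage sketch (illustrative, assumes a CUDA-capable device is available):
    #
    #   labeler = CheXbert()
    #   labels = labeler.get_label('No pleural effusion. Normal heart size.')
    #   # labels is a list of 14 binary values ordered as in self.target_names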
def forward(self, hyps, refs):
if self.refs_filename is None:
refs_chexbert = [self.get_label(l.strip()) for l in refs]
else:
if os.path.exists(self.refs_filename):
refs_chexbert = [eval(l.strip()) for l in open(self.refs_filename).readlines()]
else:
refs_chexbert = [self.get_label(l.strip()) for l in refs]
open(self.refs_filename, 'w').write('\n'.join(map(str, refs_chexbert)))
hyps_chexbert = [self.get_label(l.strip()) for l in hyps]
if self.hyps_filename is not None:
open(self.hyps_filename, 'w').write('\n'.join(map(str, hyps_chexbert)))
refs_chexbert_5 = [np.array(r)[self.target_names_5_index] for r in refs_chexbert]
hyps_chexbert_5 = [np.array(h)[self.target_names_5_index] for h in hyps_chexbert]
        # Exact-match accuracy over the 5-condition subset
accuracy = accuracy_score(y_true=refs_chexbert_5, y_pred=hyps_chexbert_5)
        # Per-example accuracy (exact match over the 5-condition subset, not averaged)
y_type, y_true, y_pred = _check_targets(refs_chexbert_5, hyps_chexbert_5)
differing_labels = count_nonzero(y_true - y_pred, axis=1)
pe_accuracy = (differing_labels == 0).astype(np.float32)
cr = classification_report(refs_chexbert, hyps_chexbert, target_names=self.target_names, output_dict=True)
cr_5 = classification_report(refs_chexbert_5, hyps_chexbert_5, target_names=self.target_names_5,
output_dict=True)
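        # Returned tuple: micro-averaged F1 over the 14 conditions, exact-match accuracy
        # and per-example accuracies over the 5-condition subset, and both
        # classification reports (14-condition and 5-condition).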
return cr['micro avg']['f1-score'], accuracy, pe_accuracy, cr, cr_5
def train(self, mode: bool = True):
        mode = False  # CheXbert is used as a frozen scorer, so always stay in eval mode
self.training = mode
for module in self.children():
module.train(mode)
return self
if __name__ == '__main__':
import json
import time
m = CheXbert()
t = time.time()
    micro_avg_f1, accuracy, pe_accuracy, class_report, class_report_5 = m(hyps=['No pleural effusion. Normal heart size.',
'Normal heart size.',
'Increased mild pulmonary edema and left basal atelectasis.',
'Bilateral lower lobe bronchiectasis with improved right lower medial lung peribronchial consolidation.',
'Elevated left hemidiaphragm and blunting of the left costophrenic angle although no definite evidence of pleural effusion seen on the lateral view.',
],
refs=['No pleural effusions.',
'Enlarged heart.',
'No evidence of pneumonia. Stable cardiomegaly.',
'Bilateral lower lobe bronchiectasis with improved right lower medial lung peribronchial consolidation.',
'No acute cardiopulmonary process. No significant interval change. Please note that peribronchovascular ground-glass opacities at the left greater than right lung bases seen on the prior chest CT of ___ were not appreciated on prior chest radiography on the same date and may still be present. Additionally, several pulmonary nodules measuring up to 3 mm are not not well appreciated on the current study-CT is more sensitive.'
])
print(json.dumps(class_report, indent=4))
print(f"Micro_avg_f1_score_14classes: {class_report['micro avg']['f1-score']}")
    print(f"Time elapsed: {time.time() - t:.2f}s")