clustering.py
from transformers import BertModel, AutoTokenizer, AutoModelForMaskedLM
import torch
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
#model = BertModel.from_pretrained("data/japanese_med/ja_med_bert/checkpoint-10")
model = AutoModelForMaskedLM.from_pretrained("data/japanese_med/ja_med_bert/checkpoint-10")
def get_embedding(model, tokenizer, text):
    inputs = tokenizer(text, return_tensors='pt')
    # Run the model without gradient tracking; request hidden states so we can
    # pool over the encoder output rather than the MLM vocabulary logits
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # Last hidden layer: shape (batch, seq_len, hidden_size)
    last_hidden_states = outputs.hidden_states[-1]
    # Compute a sentence embedding by mean-pooling over the token embeddings
    return last_hidden_states.mean(dim=1).squeeze()
concepts = []
embeddings = []
for line in open('data/japanese_med/concepts.tsv', 'r'):
    concept = line.strip()
    concepts.append(concept)
    embeddings.append(get_embedding(model, tokenizer, concept))
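
# The script collects concept embeddings but stops before the clustering step the
# filename suggests. Below is a minimal sketch of how the collected embeddings
# might be clustered, assuming scikit-learn is available; KMeans and the cluster
# count are illustrative assumptions, not part of the original code.
from sklearn.cluster import KMeans

# Stack per-concept tensors into a (num_concepts, hidden_size) matrix
embedding_matrix = torch.stack(embeddings).numpy()

# n_clusters=10 is an arbitrary illustrative choice
kmeans = KMeans(n_clusters=10, random_state=0, n_init=10)
labels = kmeans.fit_predict(embedding_matrix)

# Group concept strings by assigned cluster for inspection
clusters = {}
for concept, label in zip(concepts, labels):
    clusters.setdefault(int(label), []).append(concept)
for label, members in sorted(clusters.items()):
    print(label, members[:5])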