fix model loading with flash-attention
chenkenbio committed Jan 5, 2024
1 parent 9bb9120 commit a461ef0
Showing 1 changed file with 7 additions and 8 deletions.
README.md: 7 additions & 8 deletions
@@ -22,7 +22,6 @@ SpliceBERT is implemented with [Huggingface](https://huggingface.co/docs/transfo
- Install PyTorch: https://pytorch.org/get-started/locally/
- Install Huggingface transformers: https://huggingface.co/docs/transformers/installation
- Install FlashAttention (optional): https://github.com/Dao-AILab/flash-attention
-  - FlashAttention v2 does not support Turing GPUs, please use FlashAttention v1 instead.

SpliceBERT can be easily used for a series of downstream tasks through the official API.
See [official guide](https://huggingface.co/docs/transformers/model_doc/bert) for more details.
@@ -39,7 +38,7 @@ We recommend running SpliceBERT on a Linux system with a NVIDIA GPU of at least
**Examples**
We provide a demo script to show how to use SpliceBERT through the official API of Huggingface transformers in the first part of the following code block.
Users can also use SpliceBERT with FlashAttention by replacing the official API with the custom API, as shown in the second part of the following code block.
-**Note that flash-attention requires automatic mixed precision (amp) mode to be enabled and currently it does not support `attention_mask`**
+**Note that flash-attention requires automatic mixed precision (amp) mode to be enabled and currently it does not support [`attention_mask`](https://huggingface.co/docs/transformers/glossary#attention-mask)**
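Because `attention_mask` is unavailable on the flash-attention path, padded batches of mixed-length sequences cannot be used directly. Below is a minimal workaround sketch, not part of this diff: sequences of equal length are grouped so no padding is needed, and every forward pass runs under `autocast`. The weights path is a placeholder, and `transformers.BertModel` stands in for the repository's custom FlashAttention class.

```python
# Sketch only: bucket sequences by length so no attention_mask/padding is needed,
# since flash-attention does not support attention_mask.
# Assumptions: SPLICEBERT_PATH is a placeholder for the downloaded weights, and
# transformers.BertModel is a stand-in for the repository's custom FlashAttention
# BertModel (swap in the custom class to actually use flash-attention).
from collections import defaultdict

import torch
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, BertModel

SPLICEBERT_PATH = "/path/to/SpliceBERT/models/510nt"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)
model = BertModel.from_pretrained(SPLICEBERT_PATH).cuda().eval()

# Toy sequences; real inputs should be >= 64 nt (see the demo below).
seqs = ["ACGTACGTACGTACGT", "ACGTACGT", "TTTTACGTACGTACGT"]

# Bucket sequences by length so every batch is rectangular without padding.
buckets = defaultdict(list)
for s in seqs:
    buckets[len(s)].append(" ".join(s.upper().replace("U", "T")))

with torch.no_grad(), autocast():
    for length, group in buckets.items():
        input_ids = torch.as_tensor([tokenizer.encode(s) for s in group]).cuda()
        hidden = model(input_ids).last_hidden_state  # (batch, length + 2, hidden_size)
```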

Use SpliceBERT through the official API of Huggingface transformers:
```python
@@ -53,7 +52,7 @@ tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)
# prepare input sequence
seq = "ACGUACGuacguaCGu" ## WARNING: this is just a demo. SpliceBERT may not work on sequences shorter than 64nt as it was trained on sequences of 64-1024nt in length
seq = ' '.join(list(seq.upper().replace("U", "T"))) # U -> T and add whitespace
-input_ids = tokenizer.encode(seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. warning: a [CLS] and a [SEP] token will be added to the start and the end of seq
+input_ids = tokenizer.encode(seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
input_ids = torch.as_tensor(input_ids) # convert python list to Tensor
input_ids = input_ids.unsqueeze(0) # add batch dimension, shape: (batch_size, sequence_length)

@@ -88,29 +87,29 @@ tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)
# prepare input sequence
seq = "ACGUACGuacguaCGu" ## WARNING: this is just a demo. SpliceBERT may not work on sequences shorter than 64nt as it was trained on sequences of 64-1024nt in length
seq = ' '.join(list(seq.upper().replace("U", "T"))) # U -> T and add whitespace
-input_ids = tokenizer.encode(seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. warning: a [CLS] and a [SEP] token will be added to the start and the end of seq
+input_ids = tokenizer.encode(seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
input_ids = torch.as_tensor(input_ids) # convert python list to Tensor
input_ids = input_ids.unsqueeze(0) # add batch dimension, shape: (batch_size, sequence_length)

# Or use custom BertModel with FlashAttention
# get nucleotide embeddings (hidden states)
-model = AutoModel.from_pretrained(SPLICEBERT_PATH) # load model
+model = BertModel.from_pretrained(SPLICEBERT_PATH) # load model
with autocast():
    last_hidden_state = model(input_ids).last_hidden_state # get hidden states from last layer
    hidden_states = model(input_ids, output_hidden_states=True).hidden_states # hidden states from the embedding layer (nn.Embedding) and the 6 transformer encoder layers

# get nucleotide type logits in masked language modeling
-model = AutoModelForMaskedLM.from_pretrained(SPLICEBERT_PATH) # load model
+model = BertForMaskedLM.from_pretrained(SPLICEBERT_PATH) # load model
with autocast():
    logits = model(input_ids).logits # shape: (batch_size, sequence_length, vocab_size)

# finetuning SpliceBERT for token classification tasks
with autocast():
-    model = AutoModelForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3, shape: (batch_size, sequence_length, num_labels)
+    model = BertForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3, shape: (batch_size, sequence_length, num_labels)

# finetuning SpliceBERT for sequence classification tasks
with autocast():
-    model = AutoModelForSequenceClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3, shape: (batch_size, sequence_length, num_labels)
+    model = BertForSequenceClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3, shape: (batch_size, sequence_length, num_labels)
```
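Since the tokenizer prepends [CLS] and appends [SEP], the per-nucleotide embeddings correspond to positions 1 through L of `last_hidden_state`. A minimal sketch of stripping those special-token positions follows (not part of this diff; the weights path is a placeholder, and the variable names mirror the demo above):

```python
# Sketch only: align hidden states with individual nucleotides by dropping the
# [CLS]/[SEP] positions that the tokenizer adds.
import torch
from transformers import AutoModel, AutoTokenizer

SPLICEBERT_PATH = "/path/to/SpliceBERT/models/510nt"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)
model = AutoModel.from_pretrained(SPLICEBERT_PATH)

seq = "ACGUACGUACGUACGU"  # demo only; real inputs should be >= 64 nt
input_ids = torch.as_tensor(
    tokenizer.encode(" ".join(seq.upper().replace("U", "T")))
).unsqueeze(0)  # shape: (1, len(seq) + 2)

with torch.no_grad():
    last_hidden_state = model(input_ids).last_hidden_state  # (1, len(seq) + 2, hidden_size)

nt_embeddings = last_hidden_state[0, 1:-1]  # drop [CLS]/[SEP]; shape: (len(seq), hidden_size)
assert nt_embeddings.shape[0] == len(seq)
```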

