Encoder.py
# Import PyTorch and the Transformer sub-modules defined in this repository
import torch.nn as nn
from FeedForward import FeedForward
from MultiHeadAttention import MultiHeadAttention
from Normalization import Normalization
'''Define the EncoderLayer of the Transformer; it inherits from nn.Module
and implements a single encoder layer.'''
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        # The encoder layer has two normalization layers, one multi-head
        # attention block, and one feed-forward block
        self.normalization_1 = Normalization(d_model)
        self.normalization_2 = Normalization(d_model)
        self.attention = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.feedforward = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Pre-norm residual block 1: normalize, self-attend, add back to x
        x2 = self.normalization_1(x)
        x = x + self.dropout_1(self.attention(x2, x2, x2, mask))
        # Pre-norm residual block 2: normalize, feed-forward, add back to x
        x2 = self.normalization_2(x)
        x = x + self.dropout_2(self.feedforward(x2))
        return x
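
A quick way to sanity-check this layer is a smoke test on dummy input. The repository's FeedForward, MultiHeadAttention, and Normalization modules live in sibling files that are not shown on this page, so the sketch below substitutes minimal stand-ins with matching constructor and forward signatures; these stand-ins are assumptions for illustration, not the repository's actual implementations.

import torch
import torch.nn as nn

# Hypothetical stand-ins matching the signatures EncoderLayer expects;
# the real modules live in FeedForward.py, MultiHeadAttention.py, and
# Normalization.py in this repository.
class Normalization(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, heads,
                                          dropout=dropout, batch_first=True)

    def forward(self, q, k, v, mask=None):
        out, _ = self.attn(q, k, v)  # mask handling omitted in this sketch
        return out

class FeedForward(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        return self.net(x)

# Smoke test: one encoder layer on a dummy batch
layer = EncoderLayer(d_model=512, heads=8)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
out = layer(x, mask=None)
print(out.shape)  # torch.Size([2, 10, 512])

The output keeps the input shape, which is what allows encoder layers to be stacked N times in the full Transformer.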