-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
105 lines (87 loc) · 3.07 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Conv2D, LeakyReLU
from tensorflow.keras.layers import Add, ZeroPadding2D
class InstanceNorm(Layer):
    """Instance normalization with a learnable per-channel scale and offset.

    Each sample is normalized independently over axes 1 and 2 of its input
    (presumably the spatial H, W axes of a 4-D channels-last tensor -- TODO
    confirm against callers). The result is multiplied by ``gamma`` and
    shifted by ``beta``, both of which hold one value per channel (last axis).

    Args:
        epsilon: Small constant added to the variance before the square root
            so the division never hits zero.
        **kwargs: Forwarded to the Keras ``Layer`` base (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One scale/offset per channel. ``input_shape[-1]`` is the same as the
        # previous ``input_shape[3]`` for 4-D input without hard-coding rank.
        channels = input_shape[-1]
        # ``add_weight`` registers these as layer weights, so they appear in
        # ``trainable_weights`` and are saved by ``save_weights`` on every
        # Keras version. Raw ``tf.Variable`` attributes are not registered
        # as weights by Keras 3.
        self.beta = self.add_weight(
            name="beta",
            shape=(channels,),
            initializer="zeros",
            trainable=True,
        )
        self.gamma = self.add_weight(
            name="gamma",
            shape=(channels,),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs):
        """Return ``gamma * (inputs - mean) / sqrt(var + epsilon) + beta``.

        ``mean`` and ``var`` are taken per sample and per channel over axes
        1 and 2 (kept as size-1 dims so they broadcast back over ``inputs``).
        """
        mean, var = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        x = (inputs - mean) / tf.sqrt(var + self.epsilon)
        return self.gamma * x + self.beta

    def get_config(self):
        """Return the layer config, including ``epsilon``, for serialization."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
class ConvBlock(Layer):
    """3x3 stride-1 'valid' convolution, then InstanceNorm, then LeakyReLU(0.2).

    Because the convolution uses ``padding="valid"``, each block trims one
    pixel from both borders of axes 1 and 2 (two pixels per axis in total),
    assuming the default channels-last layout.

    Args:
        num_filters: Number of output channels produced by the convolution.
    """

    def __init__(self, num_filters):
        super().__init__()
        self.num_filters = num_filters
        # Kernels drawn from a normal distribution with mean 0.0, stddev 0.02.
        self.initializer = tf.random_normal_initializer(0.0, 0.02)
        conv_settings = dict(
            filters=num_filters,
            kernel_size=3,
            strides=1,
            padding="valid",
            use_bias=True,
            kernel_initializer=self.initializer,
        )
        # Sub-layers are assigned in this exact order; it fixes the order in
        # which their weights are tracked.
        self.conv_2d = Conv2D(**conv_settings)
        self.instance_norm = InstanceNorm()
        self.leaky_relu = LeakyReLU(alpha=0.2)

    def call(self, x):
        """Apply conv -> instance norm -> leaky ReLU to ``x``."""
        return self.leaky_relu(self.instance_norm(self.conv_2d(x)))
class Generator(Model):
    """Residual generator that refines ``prev`` using ``noise``.

    ``prev`` and ``noise`` are each zero-padded by 5 pixels on both sides of
    axes 1 and 2 and summed. The sum then passes through five 3x3 stride-1
    'valid' convolutions: ``head``, three ``ConvBlock``s, and a 3-filter
    ``tanh`` tail. Each valid 3x3 convolution trims 2 pixels per axis, so the
    10 pixels added by padding are removed again. The tail output therefore
    has the same spatial size as ``prev``, and it is added back to ``prev``
    as a residual.

    Args:
        num_filters: Channel width of the ``ConvBlock`` layers.
        name: Model name, forwarded to the Keras ``Model`` base.
        **kwargs: Any further ``Model`` keyword arguments.
    """

    def __init__(self, num_filters, name="Generator", **kwargs):
        # ``name`` used to be accepted but never forwarded, so it was ignored.
        super().__init__(name=name, **kwargs)
        self.initializer = tf.random_normal_initializer(0.0, 0.02)
        # 5 = number of 3x3 'valid' convolutions below (1 pixel trimmed per
        # side each), so the output matches the size of ``prev``.
        self.padding = ZeroPadding2D(5)
        self.head = ConvBlock(num_filters)
        self.conv_block1 = ConvBlock(num_filters)
        self.conv_block2 = ConvBlock(num_filters)
        self.conv_block3 = ConvBlock(num_filters)
        self.tail = Conv2D(
            filters=3,
            kernel_size=3,
            strides=1,
            padding="valid",
            activation="tanh",
            kernel_initializer=self.initializer,
        )

    def call(self, prev, noise):
        """Return ``tail(blocks(pad(prev) + pad(noise))) + prev``.

        NOTE(review): assumes ``prev`` and ``noise`` are 4-D channels-last
        tensors of matching shape, and that ``prev`` has 3 channels so that
        the residual sum with the 3-filter tail is valid. Confirm with callers.
        """
        # Plain tensor addition replaces building a fresh ``Add()`` layer on
        # every call. Layers belong in ``__init__``, not in the forward pass.
        x = self.padding(prev) + self.padding(noise)
        x = self.head(x)
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.tail(x)
        return x + prev
class Discriminator(Model):
    """Fully convolutional discriminator producing a 1-channel score map.

    The input passes through five 3x3 stride-1 'valid' convolutions: ``head``,
    three ``ConvBlock``s, and a 1-filter tail with no activation (linear
    output). No padding is applied, so the output is 10 pixels smaller than
    the input along axes 1 and 2 (assuming the default channels-last layout).

    Args:
        num_filters: Channel width of the ``ConvBlock`` layers.
        name: Model name, forwarded to the Keras ``Model`` base.
        **kwargs: Any further ``Model`` keyword arguments.
    """

    def __init__(self, num_filters, name="Discriminator", **kwargs):
        # ``name`` used to be accepted but never forwarded, so it was ignored.
        super().__init__(name=name, **kwargs)
        self.initializer = tf.random_normal_initializer(0.0, 0.02)
        self.head = ConvBlock(num_filters)
        self.conv_block1 = ConvBlock(num_filters)
        self.conv_block2 = ConvBlock(num_filters)
        self.conv_block3 = ConvBlock(num_filters)
        self.tail = Conv2D(
            filters=1,
            kernel_size=3,
            strides=1,
            padding="valid",
            kernel_initializer=self.initializer,
        )

    def call(self, x):
        """Return the unactivated 1-channel output of the convolution stack."""
        x = self.head(x)
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        return self.tail(x)