-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfft_convolution_effect.cpp
274 lines (226 loc) · 9.47 KB
/
fft_convolution_effect.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#include <epoxy/gl.h>
#include <string.h>
#include "complex_modulate_effect.h"
#include "effect_chain.h"
#include "fft_convolution_effect.h"
#include "fft_input.h"
#include "fft_pass_effect.h"
#include "multiply_effect.h"
#include "padding_effect.h"
#include "slice_effect.h"
#include "util.h"
using namespace std;
namespace movit {
// Sets up a convolution of an input_width x input_height image with a
// convolve_width x convolve_height kernel.
//
// Two helper effects are allocated up front and owned by this object until
// rewrite_graph() adds them to a chain (see owns_effects):
//
//  - fft_input:   an FFTInput sized to the convolution kernel.
//  - crop_effect: a PaddingEffect configured with the input dimensions and a
//                 zero top/left offset; rewrite_graph() appends it as the
//                 final stage of the rewritten graph.
FFTConvolutionEffect::FFTConvolutionEffect(int input_width, int input_height, int convolve_width, int convolve_height)
	: input_width(input_width),
	  input_height(input_height),
	  convolve_width(convolve_width),
	  convolve_height(convolve_height),
	  fft_input(new FFTInput(convolve_width, convolve_height)),
	  crop_effect(new PaddingEffect()),
	  owns_effects(true)
{
	// Output size of the final crop equals the original input size.
	CHECK(crop_effect->set_int("width", input_width));
	CHECK(crop_effect->set_int("height", input_height));

	// No offset: keep the region starting at the top-left corner.
	CHECK(crop_effect->set_float("top", 0));
	CHECK(crop_effect->set_float("left", 0));
}
// Frees the helper effects, but only while this object still owns them.
// rewrite_graph() clears owns_effects once it has added them to a chain;
// after that they must not be deleted here.
FFTConvolutionEffect::~FFTConvolutionEffect()
{
	if (!owns_effects) {
		return;
	}
	delete fft_input;
	delete crop_effect;
}
namespace {
// Appends one overlap stage (a SliceEffect) followed by the forward FFT
// passes along the given direction, all hanging off last_effect.
//
// The slice stage turns input slices of (fft_size - pad_size) pixels into
// output slices of fft_size pixels, starting pad_size pixels before each
// input slice (offset = -pad_size).
//
// The number of FFT passes is ffs(fft_size) - 1, i.e. the index of the lowest
// set bit; this equals log2(fft_size) provided fft_size is a power of two
// (assumed, not checked here).
//
// Returns the last Effect in the new chain.
Effect *add_overlap_and_fft(EffectChain *chain, Effect *last_effect, int fft_size, int pad_size, FFTPassEffect::Direction direction)
{
	// Overlap.
	Effect *slicer = chain->add_effect(new SliceEffect(), last_effect);
	CHECK(slicer->set_int("input_slice_size", fft_size - pad_size));
	CHECK(slicer->set_int("output_slice_size", fft_size));
	CHECK(slicer->set_int("offset", -pad_size));
	if (direction != FFTPassEffect::HORIZONTAL) {
		assert(direction == FFTPassEffect::VERTICAL);
		CHECK(slicer->set_int("direction", SliceEffect::VERTICAL));
	} else {
		CHECK(slicer->set_int("direction", SliceEffect::HORIZONTAL));
	}

	// FFT: one FFTPassEffect per pass, chained one after the other.
	Effect *tail = slicer;
	const int pass_count = ffs(fft_size) - 1;
	for (int pass = 1; pass <= pass_count; ++pass) {
		tail = chain->add_effect(new FFTPassEffect(), tail);
		CHECK(tail->set_int("pass_number", pass));
		CHECK(tail->set_int("fft_size", fft_size));
		CHECK(tail->set_int("direction", direction));
		CHECK(tail->set_int("inverse", 0));
	}
	return tail;
}
// Appends the inverse FFT passes along the given direction, followed by one
// discard stage (a SliceEffect), all hanging off last_effect.
//
// The number of passes is ffs(fft_size) - 1, which equals log2(fft_size)
// provided fft_size is a power of two (assumed, not checked here).
//
// The slice stage turns input slices of fft_size pixels into output slices of
// (fft_size - pad_size) pixels, starting pad_size pixels into each input
// slice (offset = pad_size).
//
// Returns the last Effect in the new chain.
Effect *add_ifft_and_discard(EffectChain *chain, Effect *last_effect, int fft_size, int pad_size, FFTPassEffect::Direction direction)
{
	// IFFT: one FFTPassEffect per pass, with the "inverse" flag set.
	Effect *tail = last_effect;
	const int pass_count = ffs(fft_size) - 1;
	for (int pass = 1; pass <= pass_count; ++pass) {
		tail = chain->add_effect(new FFTPassEffect(), tail);
		CHECK(tail->set_int("pass_number", pass));
		CHECK(tail->set_int("fft_size", fft_size));
		CHECK(tail->set_int("direction", direction));
		CHECK(tail->set_int("inverse", 1));
	}

	// Discard.
	Effect *slicer = chain->add_effect(new SliceEffect(), tail);
	CHECK(slicer->set_int("input_slice_size", fft_size));
	CHECK(slicer->set_int("output_slice_size", fft_size - pad_size));
	if (direction != FFTPassEffect::HORIZONTAL) {
		assert(direction == FFTPassEffect::VERTICAL);
		CHECK(slicer->set_int("direction", SliceEffect::VERTICAL));
	} else {
		CHECK(slicer->set_int("direction", SliceEffect::HORIZONTAL));
	}
	CHECK(slicer->set_int("offset", pad_size));
	return slicer;
}
} // namespace
// Replaces this effect's node in the graph with the actual convolution
// pipeline and disables the node itself. The pipeline is:
//
//   overlap + FFT in both directions, normalization (multiply by
//   1 / (fft_width * fft_height)), complex modulation with fft_input,
//   IFFT + discard in both directions, and a final crop.
//
// The FFT block size and the order in which the two axes are processed are
// chosen by an exhaustive search over power-of-two sizes, minimizing an
// estimated number of texel fetches.
//
// chain: the chain that receives the new effects.
// self:  this effect's node; must have exactly one incoming link.
void FFTConvolutionEffect::rewrite_graph(EffectChain *chain, Node *self)
{
	const int pad_width = convolve_width - 1;
	const int pad_height = convolve_height - 1;

	// Try all possible FFT widths and heights to see which one is the
	// cheapest. As a proxy for real performance, we use number of texel
	// fetches; this isn't perfect by any means, but it's easy to work with
	// and should be approximately correct.
	//
	// The smallest candidate holds the padding plus one payload pixel; the
	// largest holds the entire input plus padding in a single block. Each
	// bound must use the padding of its own axis: mixing them up can make
	// max smaller than min for asymmetric kernels, leaving the search below
	// with nothing to try.
	const int min_x = next_power_of_two(1 + pad_width);
	const int min_y = next_power_of_two(1 + pad_height);
	const int max_x = next_power_of_two(input_width + pad_width);
	const int max_y = next_power_of_two(input_height + pad_height);

	// Do all cost arithmetic in size_t so that large images cannot overflow
	// an intermediate int product.
	const size_t in_width = input_width;
	const size_t in_height = input_height;

	size_t best_cost = numeric_limits<size_t>::max();
	int best_x = -1, best_y = -1, best_x_before_y_fft = -1, best_x_before_y_ifft = -1;

	// Try both
	//
	//   overlap(X), FFT(X), overlap(Y), FFT(Y), modulate, IFFT(Y), discard(Y), IFFT(X), discard(X) and
	//   overlap(Y), FFT(Y), overlap(X), FFT(X), modulate, IFFT(X), discard(X), IFFT(Y), discard(Y)
	//
	// For simplicity, call them the XY-YX and YX-XY orders. In theory, we
	// could have XY-XY and YX-YX orders as well, and I haven't found a
	// convincing argument that they will never be optimal (although it
	// sounds odd and should be rare), so we test all four possible ones.
	//
	// We assume that the kernel FFT is for free, since it is typically done
	// only once and not per frame.
	// NOTE(review): the original comment read "only once and per frame";
	// confirm the intended wording.
	for (int x_before_y_fft = 0; x_before_y_fft <= 1; ++x_before_y_fft) {
		for (int x_before_y_ifft = 0; x_before_y_ifft <= 1; ++x_before_y_ifft) {
			for (int y = min_y; y <= max_y; y *= 2) {
				const int y_pixels_per_block = y - pad_height;
				const int num_vertical_blocks = div_round_up(input_height, y_pixels_per_block);
				const size_t output_height = size_t(y) * num_vertical_blocks;
				const size_t y_passes = ffs(y) - 1;

				for (int x = min_x; x <= max_x; x *= 2) {
					const int x_pixels_per_block = x - pad_width;
					const int num_horizontal_blocks = div_round_up(input_width, x_pixels_per_block);
					const size_t output_width = size_t(x) * num_horizontal_blocks;
					const size_t x_passes = ffs(x) - 1;

					size_t cost = 0;

					if (x_before_y_fft) {
						// First, the cost of the horizontal padding.
						cost = output_width * in_height;

						// log(X) FFT passes. Each pass reads two inputs per pixel,
						// plus the support texture.
						cost += x_passes * 3 * output_width * in_height;

						// Now, vertical padding.
						cost += output_width * output_height;

						// log(Y) FFT passes, now at full resolution.
						cost += y_passes * 3 * output_width * output_height;
					} else {
						// First, the cost of the vertical padding.
						cost = in_width * output_height;

						// log(Y) FFT passes. Each pass reads two inputs per pixel,
						// plus the support texture.
						cost += y_passes * 3 * in_width * output_height;

						// Now, horizontal padding.
						cost += output_width * output_height;

						// log(X) FFT passes, now at full resolution.
						cost += x_passes * 3 * output_width * output_height;
					}

					// The actual modulation. Reads one pixel each from two textures.
					cost += 2 * output_width * output_height;

					if (x_before_y_ifft) {
						// log(X) IFFT passes.
						cost += x_passes * 3 * output_width * output_height;

						// Discard horizontally.
						cost += in_width * output_height;

						// log(Y) IFFT passes.
						cost += y_passes * 3 * in_width * output_height;

						// Discard vertically.
						cost += in_width * in_height;
					} else {
						// log(Y) IFFT passes.
						cost += y_passes * 3 * output_width * output_height;

						// Discard vertically.
						cost += output_width * in_height;

						// log(X) IFFT passes.
						cost += x_passes * 3 * output_width * in_height;

						// Discard horizontally.
						cost += in_width * in_height;
					}

					if (cost < best_cost) {
						best_x = x;
						best_y = y;
						best_x_before_y_fft = x_before_y_fft;
						best_x_before_y_ifft = x_before_y_ifft;
						best_cost = cost;
					}
				}
			}
		}
	}

	// The search must have found at least one configuration; otherwise we
	// would go on to build FFT passes with a size of -1.
	assert(best_x > 0 && best_y > 0);
	const int fft_width = best_x, fft_height = best_y;

	// Detach ourselves from our single input; the new effects are attached
	// to that input instead.
	assert(self->incoming_links.size() == 1);
	Node *last_node = self->incoming_links[0];
	self->incoming_links.clear();
	last_node->outgoing_links.clear();

	// Do FFT.
	Effect *last_effect = last_node->effect;
	if (best_x_before_y_fft) {
		last_effect = add_overlap_and_fft(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
		last_effect = add_overlap_and_fft(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
	} else {
		last_effect = add_overlap_and_fft(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
		last_effect = add_overlap_and_fft(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
	}

	// Normalizer.
	Effect *multiply_effect;
	float fft_size = fft_width * fft_height;
	float factor[4] = { 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size };
	last_effect = multiply_effect = chain->add_effect(new MultiplyEffect(), last_effect);
	CHECK(multiply_effect->set_vec4("factor", factor));

	// Multiply by the FFT of the convolution kernel.
	CHECK(fft_input->set_int("fft_width", fft_width));
	CHECK(fft_input->set_int("fft_height", fft_height));
	chain->add_input(fft_input);
	// From here on our destructor must not delete fft_input or crop_effect
	// (presumably the chain takes ownership of what is added to it -- confirm
	// against EffectChain).
	owns_effects = false;

	Effect *modulate_effect = chain->add_effect(new ComplexModulateEffect(), multiply_effect, fft_input);
	CHECK(modulate_effect->set_int("num_repeats_x", div_round_up(input_width, fft_width - pad_width)));
	CHECK(modulate_effect->set_int("num_repeats_y", div_round_up(input_height, fft_height - pad_height)));
	last_effect = modulate_effect;

	// Finally, do IFFT.
	if (best_x_before_y_ifft) {
		last_effect = add_ifft_and_discard(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
		last_effect = add_ifft_and_discard(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
	} else {
		last_effect = add_ifft_and_discard(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
		last_effect = add_ifft_and_discard(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
	}

	// ...and crop away any extra padding we have added.
	// NOTE(review): no explicit input is passed here; presumably add_effect()
	// then attaches to the most recently added effect -- confirm.
	last_effect = chain->add_effect(crop_effect);

	// Route everything that used to read from us to the end of the new
	// pipeline, and take ourselves out of the graph.
	chain->replace_sender(self, chain->find_node_for_effect(last_effect));
	self->disabled = true;
}
} // namespace movit