-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgmres_sdr.m
447 lines (368 loc) · 11.7 KB
/
gmres_sdr.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
function [x,out] = gmres_sdr(A,b,param)
% [x,out] = GMRES_SDR(A,b,param)
%
% Solves a linear system Ax = b using GMRES with sketching and
% deflated restarting.
% The recycling subspace is updated using (harmonic) Ritz extraction.
%
% INPUT:  A      matrix or function handle v -> A*v
%         b      right-hand side vector
%         param  input struct; every field is optional and falls back
%                to the default listed below
%
%  param.x0                 initial guess (default: zero vector)
%  param.max_it             maximum number of inner Arnoldi iterations (50)
%  param.max_restarts       maximum number of outer restarts
%                           (min(ceil(n/max_it),10))
%  param.tol                absolute target residual norm (1e-6)
%  param.U                  basis of the recycling subspace ([])
%  param.SU, SAU            sketched basis of recycling subspace ([])
%  param.pert               0 if A is the matrix for which param.SAU was
%                           computed, so SAU can be reused; nonzero makes
%                           each cycle re-sketch A*U (0)
%  param.k                  target dimension of recycling subspace
%                           (min(10,max_it))
%  param.t                  Arnoldi truncation parameter (2)
%  param.hS                 handle to subspace embedding operator
%                           (srft(n,s))
%  param.s                  subspace embedding dimension
%                           (min(n,8*(max_it+k)))
%  param.ssa                0: truncated Arnoldi, 1: sketched truncated
%                           Arnoldi, 2: sketch-and-select (0)
%  param.ls_solve           least squares solver: 'mgs', 'hh', 'pinv'
%                           or '\' ('mgs')
%  param.svd_tol            SVD truncation tolerance used when updating
%                           the recycling subspace (1e-15)
%  param.harmonic           1: harmonic, 0: standard Ritz extraction (1)
%  param.sketch_distortion  assumed upper bound on distortion factor (1.4)
%  param.d                  sketched residual is checked every d inner
%                           iterations (1)
%  param.reorth             stored with default 0; not read in this file
%  param.verbose            control verbosity, can be 0,1,2 (1)
%
% OUTPUT: x    approximate solution
%         out  output struct of the last cycle, extended by
%  out.residuals   true residual norms; entry 1 is norm(b), followed by
%                  one entry per restart cycle
%  out.iters       inner iterations per cycle (entry 1 is 0)
%  out.sres        sketched residual norms of all inner checks
%                  (entry 1 is NaN)
%  out.U           updated recycling subspace
%  out.SU, SAU     sketches of updated recycling subspace
%  out.mv          number of matrix-vector products (the product used to
%                  form the initial residual from param.x0 is not counted)
%  out.ip          number of inner products
%  out.sv          number of sketched vectors

if nargin < 3
    param = struct();
end

% Initialize counter variables
mv = 0;     % Number of matrix vector products
sv = 0;     % Number of vector sketches

if isnumeric(A)
    A = @(v) A*v;
end
n = size(b,1);

% ---- fill in defaults for all missing options ----
if ~isfield(param,'verbose')
    param.verbose = 1;
end
if ~isfield(param,'U')
    param.U = []; param.SU = []; param.SAU = [];
end
if ~isfield(param,'pert')
    % srgmres_cycle reads param.pert whenever param.U is non-empty, which
    % is the case from the second restart on. Within one call A does not
    % change, so reusing SAU is the appropriate default.
    param.pert = 0;
end
if ~isfield(param,'max_it')
    param.max_it = 50;
end
if ~isfield(param,'max_restarts')
    param.max_restarts = min(ceil(n/param.max_it), 10);
end
if ~isfield(param,'tol')
    param.tol = 1e-6;
end
if ~isfield(param,'x0')
    x = zeros(n,1);     % Assuming initial zero guess
    r = b;
else
    x = param.x0;
    r = b - A(x);
end
if ~isfield(param,'ssa')
    param.ssa = 0;      % Sketch-and-select?
end
if ~isfield(param,'t')
    param.t = 2;        % truncation param
end
if ~isfield(param,'k')
    param.k = min(10, param.max_it);
end
if ~isfield(param,'s')
    param.s = min(n, 8*(param.max_it + param.k));
end
if param.verbose
    % please stop commenting these out! use verbose=0
    disp([' using sketching dimension of s = ' num2str(param.s)])
end
if ~isfield(param,'reorth')
    param.reorth = 0;
end
if ~isfield(param,'hS') || isempty(param.hS)
    param.hS = srft(n,param.s);
end
if ~isfield(param, 'sketch_distortion')
    param.sketch_distortion = 1.4;  % assumed upper bound on distortion factor
end
if ~isfield(param, 'ls_solve')
    param.ls_solve = 'mgs';         % how to solve the least squares problems
end
if ~isfield(param, 'svd_tol')
    param.svd_tol = 1e-15;          % SVD stabilization for recycling subspace
end
if ~isfield(param, 'harmonic')
    param.harmonic = 1;             % harmonic/standard Ritz for recycling subspace
end
if ~isfield(param, 'd')
    param.d = 1;    % TODO: modify inner loop to deal properly with d>1 case
end

% Vector to store the true residual norm at end of each cycle
residuals = norm(b);
ip = 1;
% store number of inner iterations per restart cycle
iters = 0;
% store sketched residual at every inner iteration
sres = NaN;     % initial sketch not available here

% Keeps "out" well defined even if no cycle is executed (max_restarts = 0)
cycle_out = struct();

for restart = 1:param.max_restarts
    % Call cycle of srgmres
    [ e, r, cycle_out ] = srgmres_cycle(A, r, param);

    % update solution approximation
    x = x + e;

    % Compute and store norm of residual
    residnorm = norm(r);
    residuals(restart+1) = residnorm;
    iters(restart+1) = cycle_out.m;
    sres = [ sres , cycle_out.sres ];

    % Increment counter variables appropriately
    mv = mv + cycle_out.mv;
    ip = ip + cycle_out.ip;
    sv = sv + cycle_out.sv;

    % Output recycling matrices for next cycle
    param.U = cycle_out.U; param.SU = cycle_out.SU;
    param.SAU = cycle_out.SAU;

    % Potentially increase bound on sketch distortion
    param.sketch_distortion = cycle_out.sketch_distortion;

    if residnorm < param.tol
        break
    end
end

% Output relevant parameters
out = cycle_out;
out.mv = mv;
out.ip = ip;
out.sv = sv;
out.residuals = residuals;
out.iters = iters;
out.sres = sres;
end
function [e,r,out] = srgmres_cycle(A,r0,param)
% [e,r,out] = SRGMRES_CYCLE(A,r0,param)
%
% Runs one restart cycle: builds a (truncated / sketched) Arnoldi basis V
% starting from r0, solves the sketched least squares problem
%     min_y || S*r0 - S*A*[U V]*y ||
% and afterwards extracts a new recycling subspace from [U V] via an SVD
% of the sketched basis followed by an ordered QZ decomposition.
%
% INPUT:  A      function handle v -> A*v
%         r0     residual vector the cycle starts from
%         param  struct; fields read here: max_it, tol, hS, t, U, SU, SAU,
%                k, d, sketch_distortion, ssa, ls_solve, svd_tol, harmonic,
%                verbose, and optionally pert (a missing pert is treated
%                like pert = 0, i.e. param.SAU is reused)
%
% OUTPUT: e      correction such that r = r0 - A(e)
%         r      true residual after the cycle
%         out    struct with fields
%                U, SU, SAU         updated recycling subspace and sketches
%                hS                 sketching operator that was used
%                k                  dimension of the new recycling subspace
%                m                  number of inner iterations performed
%                mv, ip, sv         matvec / inner product / sketch counts
%                sres               sketched residual norm at every check
%                sketch_distortion  possibly increased distortion bound

max_it = param.max_it;
tol = param.tol;
hS = param.hS;
t = param.t;
U = param.U;
k = param.k;
d = param.d;
sketch_distortion = param.sketch_distortion;
ls_solve = param.ls_solve;
nU = size(U,2);

if ~any(strcmp(ls_solve, {'mgs','hh','pinv','\'}))
    error('gmres_sdr:ls_solve', ...
        'param.ls_solve must be one of ''mgs'', ''hh'', ''pinv'' or ''\''.');
end

% Reset count parameters for each new cycle
mv = 0;
ip = 0;
sv = 0;

% Sketches of the recycling subspace used throughout this cycle
if isempty(U)
    SU0 = [];
    SAU0 = [];
elseif ~isfield(param,'pert') || param.pert == 0
    % In the special case when the matrix does not change,
    % we can re-use SU and SAU from the previous problem
    SU0 = param.SU;
    SAU0 = param.SAU;
else
    % The matrix changed: S*U is still valid, S*A*U must be re-sketched
    SU0 = param.SU;
    SAU0 = hS(A(U));
    mv = mv + nU;
    sv = sv + nU;
end
SW = SU0;
SAW = SAU0;

% Arnoldi for (A,b)
Sr = hS(r0);
sv = sv + 1;
if param.ssa
    nrm = norm(Sr);
else
    nrm = norm(r0);
    ip = ip + 1;
end
SV(:,1) = Sr/nrm;
V(:,1) = r0/nrm;

% NOTE: Interestingly, the vectors for which the distortion
% norm(Sv)/norm(v) is largest away from 1, happen to be the
% residual vectors after each restart (at least with dct).
% What happens is that within a cycle, norm(Sr)/norm(r)
% typically starts to deviate more and more from 1 and
% as the next cycle is restarted with a residual vector,
% it has large distortion.

% Initialize QR factorization of SAW (recycling subspace).
% qrupdate_hh / qrupdate_gs are not defined in this file; presumably they
% extend an existing QR factorization by the newly appended columns.
if strcmp(ls_solve,'mgs')       % modified GS
    [Q,R] = qr(SAW,0);
elseif strcmp(ls_solve,'hh')    % Householder
    [W,R,QtSr] = qrupdate_hh(SAW,[],[],Sr);
end

d_it = 0; sres = [];

for j = 1:max_it

    w = A(V(:,j));
    mv = mv + 1;

    if param.ssa == 0   % standard t-truncated Arnoldi
        for i = max(j-t+1,1):j
            H(i,j) = V(:,i)'*w;
            ip = ip + 1;
            w = w - V(:,i)*H(i,j);
        end
        H(j+1,j) = norm(w);
        ip = ip + 1;
        V(:,j+1) = w/H(j+1,j);
        SV(:,j+1) = hS(V(:,j+1));
        sv = sv + 1;
        % No need to sketch A*V since S*A*V = (S*V)*H
        SAV(:,j) = SV(:,1:j+1)*H(1:j+1,j);
    end

    if param.ssa == 1   % sketched t-truncated Arnoldi
        sw = hS(w); sv = sv + 1;
        SAV(:,j) = sw;
        % quasi-orthogonalise against U
        if nU > 0
            coeffs = pinv(SU0)*sw;
            w = w - U*coeffs;
            sw = sw - SU0*coeffs;
        end
        % get coeffs with respect to previous t vectors
        ind = max(j-t+1,1):j;
        coeffs = SV(:,ind)'*sw;
        w = w - V(:,ind)*coeffs;
        sw = sw - SV(:,ind)*coeffs;
        nsw = norm(sw);
        SV(:,j+1) = sw/nsw; V(:,j+1) = w/nsw;
        H(ind,j) = coeffs; H(j+1,j) = nsw;
    end

    if param.ssa == 2   % sketch-and-select
        sw = hS(w); sv = sv + 1;
        SAV(:,j) = sw;
        % Select the t basis vectors with the largest least squares
        % coefficients in modulus. Only the indices are taken from maxk;
        % the coefficients themselves must keep their sign/phase.
        coeffs = pinv(SV(:,1:j))*sw;
        [~,ind] = maxk(abs(coeffs),t);
        coeffs = coeffs(ind);
        w = w - V(:,ind)*coeffs;
        sw = sw - SV(:,ind)*coeffs;
        nsw = norm(sw);
        SV(:,j+1) = sw/nsw; V(:,j+1) = w/nsw;
        H(ind,j) = coeffs; H(j+1,j) = nsw;
    end

    % Every d iterations, compute the sketched residual
    % If sres is small enough, compute full residual
    % If this is small enough, break the inner loop
    if rem(j,d) == 0 || j == max_it

        d_it = d_it + 1;

        % TODO: Both could be updated column-wise
        % SU0/SAU0 (not param.SU/param.SAU) so that a re-sketched S*A*U
        % is actually used and matches the QR factor initialised above.
        SW = [ SU0, SV(:,1:j) ];
        SAW = [ SAU0, SAV(:,1:j) ];

        % Incrementally extend QR factorization and get LS coeffs
        if strcmp(ls_solve,'mgs')
            [Q,R] = qrupdate_gs(SAW,Q,R);
            y = R\(Q'*Sr);
        elseif strcmp(ls_solve,'hh')
            [W,R,QtSr] = qrupdate_hh(SAW,W,R,QtSr);
            y = triu(R)\(QtSr);
        elseif strcmp(ls_solve,'pinv')
            y = pinv(SAW)*Sr;
        else    % '\'
            y = SAW\Sr;
        end

        % Compute residual estimate (without forming full approximation)
        sres(d_it) = norm(Sr - SAW*y);

        % If the residual estimate is small enough (or we reached the max
        % number of iterations), then we form the full approximation
        % correction (without explicitly forming [U V(:,1:j)])
        if sres(d_it) < tol/sketch_distortion || j == max_it
            e = V(:,1:j)*y(nU+1:end,1);
            if nU > 0
                e = U*y(1:nU,1) + e;
            end

            % Compute true residual
            r = r0 - A(e);
            mv = mv + 1;
            nrmr = norm(r);
            ip = ip + 1;

            % potentially increase sketch_distortion
            if nrmr/sres(d_it) > sketch_distortion
                sketch_distortion = nrmr/sres(d_it);
                if param.verbose >= 1
                    % please stop commenting these out! use verbose=0
                    disp([' sketch distortion increased to ' num2str(sketch_distortion)])
                end
            end

            if nrmr < tol || j == max_it
                break
            end
        end
    end
end

% Compute economic SVD of SW or SAW
if param.harmonic
    [Lfull,Sigfull,Jfull] = svd(SAW,'econ');    % harmonic
else
    [Lfull,Sigfull,Jfull] = svd(SW,'econ');     % non-harmonic
end

if param.verbose >= 2
    fprintf(' cond(SAU) = %4.1e\n', cond(SAU0))
    fprintf(' cond(SV) = %4.1e\n', cond(SV(:,1:j)))
    fprintf(' full subspace condition number = %4.1e\n', Sigfull(1,1)/Sigfull(end,end))
end

% Truncate SVD: drop singular values below svd_tol relative to the largest
ell = find(diag(Sigfull) > param.svd_tol*Sigfull(1,1), 1, 'last');
k = min(ell,k);
L = Lfull(:,1:ell);
Sig = Sigfull(1:ell,1:ell);
J = Jfull(:,1:ell);
if param.harmonic
    HH = L'*SW*J;   % harmonic
else
    HH = L'*SAW*J;  % non-harmonic
end

% update augmentation space using QZ
if isreal(HH) && isreal(Sig)
    [AA, BB, Q, Z] = qz(HH,Sig,'real');     % Q*A*Z = AA, Q*B*Z = BB
else
    [AA, BB, Q, Z] = qz(HH,Sig);
end
ritz = ordeig(AA,BB);
if param.harmonic
    [~,ind] = sort(abs(ritz),'descend');    % harmonic
else
    [~,ind] = sort(abs(ritz),'ascend');     % non-harmonic
end
select = false(length(ritz),1);
select(ind(1:k)) = true;
[AA,BB,~,Z] = ordqz(AA,BB,Q,Z,select);
if k>0 && k<size(AA,1) && (AA(k+1,k)~=0 || BB(k+1,k)~=0)  % don't tear apart 2x2 diagonal blocks
    keep = k+1;
else
    keep = k;
end
if param.verbose >= 2
    disp([' recycling subspace dimension k = ' num2str(keep)])
end

% cheap update of recycling subspace without explicitly constructing [U V(:,1:j)]
JZ = J*Z(:,1:keep);
out.U = V(:,1:j)*JZ(nU+1:end,:);
if nU > 0
    out.U = U*JZ(1:nU,:) + out.U;
end
out.SU = SW*JZ;
out.SAU = SAW*JZ;
out.hS = hS;
out.k = keep;
out.m = j;
out.mv = mv;
out.ip = ip;
out.sv = sv;
out.sres = sres;
out.sketch_distortion = sketch_distortion;
end