-
Notifications
You must be signed in to change notification settings - Fork 0
/
WoodburyLS_timings.m
95 lines (82 loc) · 2.55 KB
/
WoodburyLS_timings.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
% WoodburyLS_timings  Benchmark a low-rank-updated least-squares solve.
%
% For every column count n in NN and every update rank r in RR, this script
% builds a random m-by-n matrix A, a right-hand side b and random factors
% U (m-by-r), V (n-by-r), and times four computations (descriptions follow
% the progress messages printed below):
%   T_qr0    - WoodburyLS(A,b):  original problem  min || b - A*x ||
%   T_qr1    - economy-size QR of A + U*V' formed and factored from scratch
%   T_chol   - Cholesky factorisation of the normal equations of A + U*V'
%   T_update - WoodburyLS(A,b,U,V,x0,AtAsolver), reusing x0 and AtAsolver
% Each T_* array is length(NN)-by-length(RR)-by-runs and holds tic/toc
% times in seconds. ERR and ERR2 hold the relative difference between the
% update-formula solution and the from-scratch QR solution, before and
% after one step of iterative refinement.
%
% WoodburyLS is defined outside this file. Its second output AtAsolver is
% used below as a function handle applied to A'*r2.
% NOTE(review): presumably AtAsolver solves with A'*A -- confirm in WoodburyLS.

% 'clear' (not 'clear all') removes workspace variables without flushing
% cached functions.
close all; clear; clc
rng('default')

runs = 10;              % timed repetitions per (n,r) configuration
m    = 1e5;             % number of rows
NN   = 100:100:1000;    % column counts n
RR   = [10,20,30];      % update ranks r

numN = length(NN);
numR = length(RR);

% Preallocate all result arrays instead of growing them inside the loops.
T_qr0    = zeros(numN,numR,runs);
T_qr1    = zeros(numN,numR,runs);
T_chol   = zeros(numN,numR,runs);
T_update = zeros(numN,numR,runs);
ERR  = zeros(numN,numR);
ERR2 = zeros(numN,numR);

for i = 1:numN
    n = NN(i);
    disp(['cols n = ' num2str(n) ])

    for j = 1:numR
        r = RR(j);
        disp(['rank r = ' num2str(r) ])

        % Reseed so that A and b are identical for every rank at a given n.
        rng('default')
        A = randn(m,n); b = randn(m,1);

        disp('solve min || b - Ax || using QR')
        tt = zeros(1,runs);
        for k = 1:runs
            tic
            [x0,AtAsolver] = WoodburyLS(A,b); % original LS problem
            tt(k) = toc;
        end
        T_qr0(i,j,:) = tt;

        U = randn(m,r); V = randn(n,r);

        disp('solve min || b - (A+UV'')x || using QR')
        tt = zeros(1,runs);
        for k = 1:runs
            tic
            Ahat = A + U*V';
            [Qhat,Rhat] = qr(Ahat,0);
            x1 = Rhat\(Qhat'*b); % updated problem (from scratch)
            tt(k) = toc;
        end
        T_qr1(i,j,:) = tt;

        disp('solve min || b - (A+UV'')x || using Chol-QR')
        tt = zeros(1,runs);
        for k = 1:runs
            tic
            Ahat = A + U*V';
            M = Ahat'*Ahat;
            Rc = chol(M);
            % Alternative via an explicit Q factor (disabled):
            %Qc = Ahat/Rc; %xc = Rc\(Qc'*b);
            xc = Rc\((Rc')\(Ahat'*b)); % normal equations: M*xc = Ahat'*b
            tt(k) = toc;
        end
        T_chol(i,j,:) = tt;

        % compute update like this instead:
        disp('solve min || b - (A+UV'')x || using update formula')
        tt = zeros(1,runs);
        for k = 1:runs
            tic
            x2 = WoodburyLS(A,b,U,V,x0,AtAsolver);
            tt(k) = toc;
        end
        T_update(i,j,:) = tt;

        % x1 (from-scratch QR) is the reference solution.
        ERR(i,j) = norm(x2 - x1)/norm(x1); % error check

        % iterative refinement: solve the updated problem again with the
        % residual of x2 as right-hand side and add the correction.
        r2 = b - A*x2 - U*(V'*x2);
        x0r = AtAsolver(A'*r2);
        e2r = WoodburyLS(A,r2,U,V,x0r,AtAsolver);
        x2r = x2 + e2r; % refined x2
        ERR2(i,j) = norm(x2r - x1)/norm(x1); % error check

        % Checkpoint after every configuration. Entries not yet measured
        % are still zero in the saved arrays.
        save timings NN RR T_qr0 T_qr1 T_chol T_update ERR ERR2
    end
end

disp('maximal relative forward error WITHOUT iterative refinement')
max(ERR(:))
disp('maximal relative forward error WITH iterative refinement')
max(ERR2(:))
%% Plot the speedup of the update formula over a from-scratch QR solve
% Reads NN, RR, T_qr1 and T_update from timings.mat, so this cell can be
% run on its own once the benchmark has produced that file.
% mydefaults and mypdf are helpers defined outside this file.
% NOTE(review): presumably they set figure defaults and export the figure
% to a PDF -- confirm.
mydefaults
load timings

% Ratio of mean run times; one curve per column of the ratio (per rank in RR).
speedup = mean(T_qr1,3)./mean(T_update,3);
plot(NN,speedup,'-+')

% The row count is not stored in timings.mat, so it stays hard-coded here.
title('number of rows m = 1e5')
xlabel('number of columns n')

% Build the legend from RR so the labels always match the benchmarked ranks.
rankLabels = arrayfun(@(rk) sprintf('r = %d',rk), RR, 'UniformOutput', false);
legend(rankLabels,'Location','NorthWest','FontSize',18)

ylabel('speedup over QR')
grid on
shg
mypdf('WoodburyLS_timings',.6,0.8)