forked from sjgershm/reward-complexity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalyze_collins.m
65 lines (51 loc) · 1.97 KB
/
analyze_collins.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
function results = analyze_collins(data, beta)
% Analyze Collins (2018) data.
%
%   results = analyze_collins()           loads load_data('collins18')
%   results = analyze_collins(data)
%   results = analyze_collins(data, beta)
%
%   For every subject and every block in data(s).learningblock, the trials
%   with phase == 0 (presumably the learning phase -- confirm) are used to
%   compute:
%     - R_data: mutual information between state and action, as computed by
%       mutual_information(state, action, 0.7) (the 0.7 is passed through
%       unchanged; its meaning is defined by mutual_information),
%     - V_data: mean reward of those trials,
%     - R, V:   curves returned by blahut_arimoto(Ps, Q, beta), where Ps is
%       the empirical state frequency and Q(i,a) = 1 marks the correct
%       action (first corchoice seen) for state i.
%   Blocks are grouped into two conditions: condition 1 = blocks with
%   exactly 3 distinct states, condition 2 = all other blocks.
%
%   Inputs:
%     data - struct array, one element per subject, with trial-vector fields
%            learningblock, phase, state, corchoice, action, reward.
%     beta - vector passed to blahut_arimoto (default linspace(0.1,15,50)).
%
%   Output fields of results:
%     R, V          - nSubjects-by-numel(beta)-by-2 block-averaged curves
%                     (assumes blahut_arimoto returns one value per beta)
%     R_data,V_data - nSubjects-by-2 block-averaged empirical values
%     bias          - nSubjects-by-2, V_data minus the V interpolated from
%                     the subject-averaged curve at the subject's R_data
%     stats         - the printed test statistics:
%                     p_R_data (signrank of R_data, condition 1 vs 2),
%                     r_V/p_V (corr of V_data vs interpolated V),
%                     r_bias/p_bias (corr of R_data vs |bias|)

if nargin < 1
    data = load_data('collins18');
end
if nargin < 2 || isempty(beta)
    beta = linspace(0.1,15,50);
end
nActions = 3;   % columns of the reward matrix Q (assumed number of responses)
nCond = 2;      % condition 1: 3-state blocks, condition 2: all others

for s = 1:numel(data)
    B = unique(data(s).learningblock);
    nB = numel(B);
    cond   = zeros(nB,1);
    R_data = zeros(nB,1);
    V_data = zeros(nB,1);
    R = [];     % per-block curves, rebuilt from scratch for every subject
    V = [];
    for b = 1:nB
        ix = data(s).learningblock==B(b) & data(s).phase==0;
        state = data(s).state(ix);
        cor_choice = data(s).corchoice(ix);
        action = data(s).action(ix);

        R_data(b) = mutual_information(state,action,0.7);
        V_data(b) = mean(data(s).reward(ix));

        % State distribution Ps and 0/1 reward matrix Q (1 = correct action).
        S = unique(state);
        Q = zeros(numel(S),nActions);
        Ps = zeros(1,numel(S));
        for i = 1:numel(S)
            ii = state==S(i);
            Ps(i) = mean(ii);
            a = cor_choice(ii);
            Q(i,a(1)) = 1;
        end
        [R(b,:),V(b,:)] = blahut_arimoto(Ps,Q,beta);

        if numel(S)==3
            cond(b) = 1;
        else
            cond(b) = 2;
        end
    end
    for k = 1:nCond
        in = cond==k;
        % Average over blocks along dimension 1 explicitly: with a single
        % block, nanmean on a 1-by-N row would otherwise average over beta
        % and return a scalar.
        results.R(s,:,k) = nanmean(R(in,:),1);
        results.V(s,:,k) = nanmean(V(in,:),1);
        results.R_data(s,k) = nanmean(R_data(in));
        results.V_data(s,k) = nanmean(V_data(in));
    end
end

% Signed-rank test of empirical R_data, condition 1 vs condition 2.
p = signrank(results.R_data(:,1),results.R_data(:,2))
results.stats.p_R_data = p;

% Subject-averaged curves as numel(beta)-by-nCond. Average over dimension 1
% explicitly and reshape (not squeeze) so one subject or one beta value
% still gives the right shape.
R = reshape(nanmean(results.R,1),[],nCond);
V = reshape(nanmean(results.V,1),[],nCond);

Vd2 = nan(size(results.V_data));
for k = 1:nCond
    % interp1 needs distinct sample points; drop non-finite points and
    % repeated R values so a flat stretch of the curve cannot make it error.
    ok = isfinite(R(:,k)) & isfinite(V(:,k));
    Rk = R(ok,k);
    Vk = V(ok,k);
    [Rk,iu] = unique(Rk);
    Vd2(:,k) = interp1(Rk,Vk(iu),results.R_data(:,k));
    results.bias(:,k) = results.V_data(:,k) - Vd2(:,k);
end

% Pool both conditions. interp1 returns NaN outside the curve's range, so
% use complete rows only instead of letting one NaN void the correlation.
[r,p] = corr(results.V_data(:),Vd2(:),'rows','complete')
results.stats.r_V = r;
results.stats.p_V = p;
[r,p] = corr(results.R_data(:),abs(results.bias(:)),'rows','complete')
results.stats.r_bias = r;
results.stats.p_bias = p;