import torch

from constants import device
from environment import Environment
from utils import gini

def od_utility(tour_idx: torch.Tensor, environment: Environment):
    """Total sum of satisfied Origin-Destination (OD) flows.

    Args:
        tour_idx (torch.Tensor): the generated line.
        environment (Environment): the environment where the line is generated.

    Returns:
        torch.Tensor: sum of satisfied Origin-Destination flows.
    """
    sat_od_mask = environment.satisfied_od_mask(tour_idx)
    reward = (environment.od_mx * sat_od_mask).sum().to(device)
    return reward
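
# A minimal usage sketch (not part of the original module): od_utility boils
# down to a masked sum over the OD matrix. The 3x3 OD matrix and mask below
# are invented for illustration; no Environment is needed.
def _demo_od_utility():
    od_mx = torch.tensor([[0., 5., 2.],
                          [5., 0., 1.],
                          [2., 1., 0.]])
    # The mask marks OD pairs whose demand the generated line satisfies.
    sat_od_mask = torch.tensor([[0., 1., 0.],
                                [1., 0., 0.],
                                [0., 0., 0.]])
    return (od_mx * sat_od_mask).sum()  # -> tensor(10.)
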
def group_utility(tour_idx: torch.Tensor, environment: Environment, var_lambda=0, use_pct=True, mult_gini=False):
    """Sum of the satisfied Origin-Destination flows of all groups
    (equal to od_utility when every square with a group also has OD flows),
    minus a lambda-weighted variance term (a minimization-of-differences notion of fairness).

    Args:
        tour_idx (torch.Tensor): the generated line.
        environment (Environment): the environment where the line is generated.
        var_lambda (int, optional): weight of the variance term subtracted from the sum. Defaults to 0.
        use_pct (bool, optional): if True, the reward uses the percentage of satisfied OD flows per group; if False, absolute values. Defaults to True.
        mult_gini (bool, optional): if True, multiply the group utility by 1 - gini_index(group utility), as in the AI Economist paper.

    Returns:
        torch.Tensor: total reward.
    """
    assert environment.group_od_mx, 'Cannot use group_utility reward without group definitions. Provide --groups_file argument'

    sat_od_mask = environment.satisfied_od_mask(tour_idx)
    sat_group_ods = torch.zeros(len(environment.group_od_mx), device=device)
    sat_group_ods_pct = torch.zeros(len(environment.group_od_mx), device=device)
    for i, g_od in enumerate(environment.group_od_mx):
        sat_group_ods[i] = (g_od * sat_od_mask).sum().item()
        sat_group_ods_pct[i] = sat_group_ods[i] / g_od.sum()

    if use_pct:
        group_rw = sat_group_ods_pct
    else:
        group_rw = sat_group_ods

    if mult_gini:
        rw = group_rw.sum() * (1 - gini(group_rw.detach().cpu().numpy()))
        if torch.isnan(rw):
            # The Gini index is undefined when no OD flow is satisfied; fall back to zero.
            return 0
        return rw
    else:
        return group_rw.sum() - var_lambda * group_rw.var()
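
# A hedged sketch of the two fairness variants in group_utility, on an
# invented per-group reward vector: the lambda-weighted variance penalty
# versus the (1 - Gini) multiplier from the AI Economist paper.
def _demo_group_utility_variants(var_lambda=0.5):
    group_rw = torch.tensor([0.8, 0.5, 0.2])
    # Minimization of differences: subtract a weighted variance.
    variance_penalized = group_rw.sum() - var_lambda * group_rw.var()
    # Gini multiplier, using the same utils.gini as group_utility.
    gini_scaled = group_rw.sum() * (1 - gini(group_rw.numpy()))
    return variance_penalized, gini_scaled
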
def lowest_quintile_utility(tour_idx: torch.Tensor, environment: Environment, use_pct=True, group_idx=0):
    """Based on Rawls's theory of justice: returns the satisfied OD flow
    (or percentage, depending on use_pct) of the lowest-quintile group.

    Args:
        tour_idx (torch.Tensor): the generated line.
        environment (Environment): the environment where the line is generated.
        use_pct (bool, optional): if True, the reward uses the percentage of satisfied OD flows per group; if False, absolute values. Defaults to True.
        group_idx (int, optional): which group to optimize for. Defaults to the first one.

    Returns:
        torch.Tensor: total reward.
    """
    assert environment.group_od_mx, 'Cannot use lowest_quintile_utility reward without group definitions. Provide --groups_file argument'

    sat_od_mask = environment.satisfied_od_mask(tour_idx)
    sat_od = (environment.group_od_mx[group_idx] * sat_od_mask).sum().item()
    if use_pct:
        return sat_od / environment.group_od_mx[group_idx].sum()
    else:
        return sat_od
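
# Hedged note: lowest_quintile_utility assumes the target group sits at a
# known index (group_idx). A variant in the same Rawlsian spirit would reward
# the worst-off group directly, as in this sketch with invented values:
def _demo_rawlsian_min():
    sat_group_ods_pct = torch.tensor([0.8, 0.5, 0.2])
    return sat_group_ods_pct.min()  # -> tensor(0.2), the worst-off group
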
def discounted_development_utility(tour_idx: torch.Tensor, environment: Environment, p=2.0):
    """Total utility as defined in the "City Metro Network Expansion with Reinforcement Learning" paper.

    For each covered square in the generated line, calculate the distance to the other
    covered squares, discount it, and multiply it by the average house price of each square.

    Args:
        tour_idx (torch.Tensor): the generated line.
        environment (Environment): the environment where the line is generated.
        p (float, optional): p-norm used for the distance: 1 is Manhattan, 2 is Euclidean, etc. Defaults to 2.0.

    Returns:
        torch.Tensor: sum of total discounted development utility.
    """
    tour_idx_g = environment.vector_to_grid(tour_idx).transpose(0, 1)
    tour_ses = environment.price_mx_norm[tour_idx_g[:, 0], tour_idx_g[:, 1]]

    total_util = torch.zeros(1, device=device)
    for i in range(tour_idx_g.shape[0]):
        # Distance from each origin square to every other square covered by the line.
        distance = torch.cdist(tour_idx_g[i][None, :].float(), tour_idx_g.float(), p=p).squeeze()
        # Discount squares based on their distance (the -0.5 factor is a tunable parameter).
        discount = torch.exp(-0.5 * distance)
        discount[i] = 0  # the origin square carries no weight in its own utility.
        total_util += (discount * tour_ses).sum()
    return total_util
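
# A hedged sketch of the distance-discounting loop above, on invented grid
# coordinates and prices: each covered square contributes
# exp(-0.5 * distance) times the normalized price of the other squares.
def _demo_distance_discount(p=2.0):
    coords = torch.tensor([[0., 0.], [0., 1.], [3., 4.]])  # toy line squares
    prices = torch.tensor([1.0, 0.5, 0.2])  # toy price_mx_norm values
    total = torch.zeros(1)
    for i in range(coords.shape[0]):
        distance = torch.cdist(coords[i][None, :], coords, p=p).squeeze()
        discount = torch.exp(-0.5 * distance)
        discount[i] = 0  # a square does not count toward its own utility
        total += (discount * prices).sum()
    return total
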
def ggi(tour_idx: torch.Tensor, environment: Environment, weight, use_pct=True):
    """Generalized Gini Index reward (see paper for more information).

    Exponentially smaller weights are assigned to the groups with the highest
    satisfied origin-destination flows: group i (in ascending order of reward)
    receives a weight proportional to 1 / weight**i.

    Args:
        tour_idx (torch.Tensor): the generated line.
        environment (Environment): the environment where the line is generated.
        weight (int): weight base used in the GGI calculation.
        use_pct (bool, optional): if True, the reward uses the percentage of satisfied OD flows per group; if False, absolute values. Defaults to True.

    Returns:
        torch.Tensor: total GGI reward.
    """
    sat_od_mask = environment.satisfied_od_mask(tour_idx)
    sat_group_ods = torch.zeros(len(environment.group_od_mx), device=device)
    sat_group_ods_pct = torch.zeros(len(environment.group_od_mx), device=device)
    for i, g_od in enumerate(environment.group_od_mx):
        sat_group_ods[i] = (g_od * sat_od_mask).sum().item()
        sat_group_ods_pct[i] = sat_group_ods[i] / g_od.sum()

    if use_pct:
        group_rw = sat_group_ods_pct
    else:
        group_rw = sat_group_ods

    # Generate exponentially decaying weights for each group.
    weights = torch.tensor([1 / (weight ** i) for i in range(group_rw.shape[0])], device=device)
    # Normalize the weights to sum to 1.
    weights = weights / weights.sum()
    # Sort rewards in ascending order so the worst-off group gets the largest weight.
    group_rw, _ = torch.sort(group_rw)
    reward = torch.sum(group_rw * weights)
    if use_pct:
        # Percentage-based rewards are small; scale them up to a comparable magnitude.
        reward *= 1000
    return reward
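
# A hedged worked example of the GGI weighting, with invented group rewards:
# with weight base 2 and three groups, the normalized weights are
# [4/7, 2/7, 1/7], so after the ascending sort the worst-off group counts most.
def _demo_ggi_weights(weight=2):
    group_rw = torch.tensor([0.8, 0.2, 0.5])
    weights = torch.tensor([1 / (weight ** i) for i in range(group_rw.shape[0])])
    weights = weights / weights.sum()   # -> tensor([0.5714, 0.2857, 0.1429])
    group_rw, _ = torch.sort(group_rw)  # -> tensor([0.2000, 0.5000, 0.8000])
    return torch.sum(group_rw * weights)  # -> tensor(0.3714)
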
# def group_weighted_utility(tour_idx: torch.Tensor, environment: Environment, var_lambda=0, use_pct=True, mult_gini=False):
#     """
#     Args:
#         tour_idx (torch.Tensor): the generated line.
#         environment (Environment): the environment where the line is generated.
#         var_lambda (int, optional): weight of the variance term subtracted from the sum. Defaults to 0.
#         use_pct (bool, optional): if True, the reward uses the percentage of satisfied OD flows per group; if False, absolute values. Defaults to True.
#         mult_gini (bool, optional): if True, multiply the group utility by 1 - gini_index(group utility), as in the AI Economist paper.
#     Returns:
#         torch.Tensor: total reward.
#     """
#     assert environment.group_weights, 'Cannot use group_weighted_utility reward without group weights. Provide --group_weights_files argument'
#     sat_od_mask = environment.satisfied_od_mask(tour_idx)
#     sat_group_ods = torch.zeros(len(environment.group_od_mx), device=device)
#     sat_group_ods_pct = torch.zeros(len(environment.group_od_mx), device=device)
#     for i, g_od in enumerate(environment.group_od_mx):
#         sat_group_ods[i] = (g_od * sat_od_mask).sum().item()
#         sat_group_ods_pct[i] = sat_group_ods[i] / g_od.sum()
#     if use_pct:
#         group_rw = sat_group_ods_pct
#     else:
#         group_rw = sat_group_ods
#     if mult_gini:
#         rw = group_rw.sum() * (1 - gini(group_rw.detach().cpu().numpy()))
#         if torch.isnan(rw):
#             return 0
#         return rw
#     else:
#         return group_rw.sum() - var_lambda * group_rw.var()