-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUtils.py
157 lines (137 loc) · 5.24 KB
/
Utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
import torch.distributed.rpc as rpc
import torchvision
from torchvision import datasets, transforms
import json
import numpy as np
from Models.CNN import Net as CNN
from Models.AlexNet import Net as AlexNet
from Models.ResNet18 import Net as ResNet
# import http.client
def _load_config(path='config.json'):
    """Read and parse a JSON configuration file.

    The file is opened with a context manager so the handle is closed even
    if parsing fails, and decoded as UTF-8 (the JSON standard encoding)
    rather than with the platform's locale encoding.

    Args:
        path: Path of the JSON file, relative to the current working directory.

    Returns:
        The parsed JSON document.
    """
    with open(path, encoding='utf-8') as config_file:
        return json.load(config_file)


# Global experiment configuration, read once at import time from
# 'config.json' in the current working directory.
config = _load_config()

# For convenience, read the config.json from another server
# NOTE(review): this helper saves the download to `save_dir` but then
# re-reads 'config.json'; confirm which path was intended before reviving it.
# def read_config_from_internet(ip, port, save_dir):
#     global config
#     conn = http.client.HTTPConnection(ip, port)
#     conn.request("GET", "/")
#     r1 = conn.getresponse()
#     data1 = r1.read()
#     print(data1)
#     f = open(save_dir, 'wb')
#     f.write(data1)
#     f.close()
#     conn.close()
#     config = json.load(open('config.json'))  # re-read the file content
#     return config
def get_ps_lr(cur_iteration, all_iteration):
    """Return the learning rate (presumably for the parameter server, "ps").

    The schedule is constant: both iteration arguments are accepted for
    interface compatibility but ignored; the value is config['initial_lr'].
    """
    learning_rate = config['initial_lr']
    return learning_rate
def get_worker_gamma(cur_iteration, all_iteration):
    """Return the worker-side gamma coefficient from the configuration.

    Constant over training: the iteration arguments are ignored and the
    value is config['gamma'].
    """
    gamma = config['gamma']
    return gamma
# Parameters of the uniform quantizer shared by quant() / dequant().
bit_width = 16
quant_type = np.uint16     # dtype of the quantized bin indices
dequant_type = np.float32  # dtype of the reconstructed gradient values
quant_max = config['quant_range']  # e.g. 0.05
quant_min = -quant_max
quant_range = quant_max - quant_min
# 2**bit_width evenly spaced bin edges; bins[0] == quant_min and the last
# edge lies one step below quant_max.
bins = (np.arange(2 ** bit_width) / (2 ** bit_width) - 0.5) * quant_range
def quant(arr):
    """Quantize ``arr`` to bin indices after normalizing it by its L2 norm.

    Args:
        arr: Numeric array (e.g. one parameter's gradient). It is no longer
            modified in place.

    Returns:
        ``(indices, norm)``:
        - ``indices`` is a ``quant_type`` array with the ``np.digitize``
          position of each element of ``arr / norm`` in ``bins``, clipped to
          ``[0, len(bins) - 1]``.
        - ``norm`` is the L2 norm of ``arr``; dequant() needs it to undo the
          scaling.
    """
    norm = np.linalg.norm(arr)
    # Divide out of place so the caller's array is left untouched. For an
    # all-zero input, skip the 0/0 -> NaN division: dequant() multiplies by
    # norm == 0, so the indices do not matter in that case.
    scaled = arr / norm if norm != 0 else arr
    indices = np.digitize(scaled, bins)
    # np.digitize returns len(bins) (== 2**bit_width) for values >= bins[-1].
    # Under the uint16 cast that would wrap to 0 and decode as the most
    # negative bin, so clamp into the representable index range.
    indices = np.clip(indices, 0, len(bins) - 1)
    return indices.astype(quant_type), norm
def dequant(arr, norm):
    """Map quantized bin indices back to values and undo quant()'s scaling.

    Args:
        arr: Integer bin indices as produced by quant().
        norm: The L2 norm that quant() returned alongside the indices.

    Returns:
        A float64 numpy array ``bins[arr] * norm`` with the same shape as ``arr``.
    """
    # A single vectorized lookup replaces the per-element Python loop. Fancy
    # indexing gives the same values, shape and dtype.
    indices = np.asarray(arr, dtype=np.intp)
    return bins[indices] * norm
def quant_grad(grad):
    """Quantize every entry of a gradient mapping.

    Returns a new dict with the same keys as ``grad``. Each value is a
    two-element list ``[indices, norm]`` as produced by quant().
    """
    return {name: list(quant(grad[name])) for name in grad}
def dequant_grad(grad):
    """Inverse of quant_grad(): rebuild each gradient from ``[indices, norm]``.

    Returns a new dict with the same keys. Each value is dequantized and cast
    to ``dequant_type``.
    """
    def _restore(entry):
        # entry[0] holds the bin indices, entry[1] the saved L2 norm.
        indices, scale = entry[0], entry[1]
        return dequant(indices, scale).astype(dequant_type)

    return {name: _restore(grad[name]) for name in grad}
# --------- Helper Methods --------------------
def call_method(method, rref, *args, **kwargs):
    """Invoke ``method`` on the object held locally by ``rref``.

    Meant to run on the node that holds the RRef's value.
    ``rref.local_value()`` is passed as the first argument (``self`` for an
    instance method), followed by ``args`` and ``kwargs``.

    Returns:
        Whatever ``method`` returns.
    """
    target = rref.local_value()
    return method(target, *args, **kwargs)
def remote_method(method, rref, *args, **kwargs):
    """Run ``method`` on the RRef's value, on the node that owns the RRef.

    The call is sent synchronously to ``rref.owner()``, where call_method
    unwraps the RRef locally. For example, if the RRef holds a ``Foo``,
    ``remote_method(Foo.bar, rref, arg1, arg2)`` is equivalent to running
    ``<foo_instance>.bar(arg1, arg2)`` on the owner and returning the result.
    """
    payload = [method, rref, *args]
    owner = rref.owner()
    return rpc.rpc_sync(owner, call_method, args=payload, kwargs=kwargs)
def get_test_loader(type):
    """Build a DataLoader over the test split of the named dataset.

    Args:
        type: Dataset name, 'MNIST' or 'Cifar10'. (The name shadows the
            builtin ``type``; it is kept for caller compatibility.)

    Returns:
        A ``torch.utils.data.DataLoader`` using ``config['batch_size']``.
        Data is stored under '../data' and downloaded if missing.

    Raises:
        ValueError: If ``type`` is not a supported dataset name.
    """
    if type == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        testset = datasets.MNIST('../data', train=False, download=True,
                                 transform=transform)
        # NOTE: this MNIST test loader shuffles but the Cifar10 one does not;
        # kept as-is to preserve existing behavior.
        return torch.utils.data.DataLoader(
            testset, batch_size=config['batch_size'], shuffle=True)
    if type == 'Cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        testset = torchvision.datasets.CIFAR10(
            root='../data', train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=config['batch_size'], shuffle=False,
            num_workers=2)
        print('Cifar test set size: ', len(testloader.dataset))
        return testloader
    # Raising a plain string is itself a TypeError in Python 3, so raise a
    # real exception that names the offending value.
    raise ValueError('Unsupported Dataset Type: {!r}'.format(type))
def get_train_loader(type):
    """Build a shuffled DataLoader over the training split of the named dataset.

    Args:
        type: Dataset name, 'MNIST' or 'Cifar10'. (The name shadows the
            builtin ``type``; it is kept for caller compatibility.)

    Returns:
        A ``torch.utils.data.DataLoader`` using ``config['batch_size']`` with
        ``shuffle=True``. Data is stored under '../data' and downloaded if
        missing.

    Raises:
        ValueError: If ``type`` is not a supported dataset name.
    """
    if type == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        trainset = datasets.MNIST('../data', train=True, download=True,
                                  transform=transform)
        return torch.utils.data.DataLoader(
            trainset, batch_size=config['batch_size'], shuffle=True)
    if type == 'Cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        trainset = torchvision.datasets.CIFAR10(
            root='../data', train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=config['batch_size'], shuffle=True,
            num_workers=2)
        print('Cifar train set size: ', len(trainloader.dataset))
        return trainloader
    # Raising a plain string is itself a TypeError in Python 3, so raise a
    # real exception that names the offending value.
    raise ValueError('Unsupported Dataset Type: {!r}'.format(type))
def get_model(type):
    """Return the model constructor registered under ``type``, not an instance.

    Args:
        type: 'CNN', 'AlexNet' or 'ResNet'. Each maps to the ``Net`` imported
            from the matching ``Models`` module.

    Returns:
        The selected ``Net`` object itself (not called).

    Raises:
        ValueError: If ``type`` is not a supported model name.
    """
    if type == 'CNN':
        return CNN
    if type == 'AlexNet':
        return AlexNet
    if type == 'ResNet':
        return ResNet
    # Raising a plain string is itself a TypeError in Python 3, so raise a
    # real exception that names the offending value.
    raise ValueError('Unsupported Model Type: {!r}'.format(type))