#!/usr/bin/python3
# dli-lsmodel -- executable script; forked from NVlabs/dlinputs.
import argparse
import glob
import itertools
import re
import numpy as np
import psutil
import pylab
import torch
import torch.optim as optim
from dlinputs import paths, utils
from pylab import *
from torch import nn
from torch.autograd import Variable
# matplotlib.use("GTK")
# Use the "hot" colormap as the default for image plots drawn via pylab.
pylab.rc("image", cmap="hot")

# Command-line interface.  The summary text is passed as `description`;
# the first positional parameter of ArgumentParser is `prog`, which would
# make the sentence show up as the program name in the usage line.
parser = argparse.ArgumentParser(description="Show model training graphs.")
parser.add_argument(
    "-r", "--reload", type=float, default=60.0,
    help="timeout handed to ginput between redraws in reload mode "
         "(default: %(default)s)",
)
parser.add_argument(
    "-x", "--linear_x", action="store_true",
    help="use a linear x axis instead of the default log scale",
)
parser.add_argument(
    "-y", "--linear_y", action="store_true",
    help="use a linear y axis instead of the default log scale",
)
parser.add_argument(
    "-n", "--mintrain", type=int, default=1,
    help="drop points whose x value (first number in the file name times "
         "1000) is below this (default: %(default)s)",
)
parser.add_argument(
    "-t", "--tail", type=int, default=1000000,
    help="keep only the last N points of each sorted curve "
         "(default: %(default)s)",
)
parser.add_argument(
    "files", nargs="*",
    help="file names or glob patterns to plot; a single '-' reads file "
         "names from stdin; with no arguments '*' is used",
)
args = parser.parse_args()
def plotfiles(fnames):
    """Redraw the current figure with one curve per model key.

    Every name in *fnames* is matched against ``<key>-<n1>-<n2>.pt`` where
    ``n1`` and ``n2`` are runs of at least six digits; names that do not
    match are silently ignored.  Each match contributes the point
    ``(int(n1) * 1e3, int(n2) * 1e-6)`` to the curve named ``key``
    (presumably training-sample count and error rate -- confirm against
    the code that writes these files).

    Reads the module-level ``args`` for ``linear_x``, ``linear_y``,
    ``tail`` and ``mintrain``.  Clears the current figure, draws the
    curves and adds a legend; returns nothing.
    """
    # The dot before "pt" is escaped so that only a literal ".pt" matches.
    pattern = re.compile(r"^(.*)-([0-9]{6,})-([0-9]{6,})\.pt")

    graphs = {}
    for fname in fnames:
        match = pattern.search(fname)
        if not match:
            continue
        key, ntrain, err = match.groups()
        graphs.setdefault(key, []).append((int(ntrain) * 1e3, int(err) * 1e-6))

    clf()
    if not args.linear_x:
        xscale("log")
    if not args.linear_y:
        yscale("log")

    handles = []
    for key, graph in graphs.items():
        # Sort by x, keep the trailing `tail` points, then apply the x cutoff.
        points = sorted(graph)[-args.tail:]
        points = [(x, y) for x, y in points if x >= args.mintrain]
        if not points:
            # Nothing survived the filters; plot() would return an empty
            # list and indexing it would raise IndexError.
            continue
        xs, ys = zip(*points)
        handles.append(plot(xs, ys, label=key)[0])
    legend(handles=handles)
# `sys` is used below for stdin and exit.  Import it explicitly instead of
# relying on the name leaking in through `from pylab import *`.
import sys

if args.files == ["-"]:
    # One-shot mode: take the file names from stdin (one per line), plot
    # them once, block in ginput with a very long timeout, then exit.
    args.files = [line.strip() for line in sys.stdin]
    plotfiles(args.files)
    ginput(1, 1000000)
    sys.exit(0)

if not args.files:
    # No arguments given: consider every entry in the current directory.
    args.files = ["*"]

# Reload mode: re-expand the glob patterns and redraw on every pass, so
# files that appear later are picked up.  ginput is called with
# args.reload as its timeout between passes.
while True:
    fnames = []
    for file_pattern in args.files:
        fnames += glob.glob(file_pattern)
    plotfiles(fnames)
    ginput(1, args.reload)