-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrod_cutting.py
151 lines (120 loc) · 3.7 KB
/
rod_cutting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python
"""
=================================
Dynamic programming - Rod cutting
=================================
Refs:
[1]https://en.wikipedia.org/wiki/Dynamic_programming
"""
import sys
import time
import argparse
import functools
from collections import defaultdict
def benchmark(func):
    """
    Calculate and print the running time of each call to func.

    The elapsed time is printed after ``func`` returns; its return value is
    passed through unchanged. If ``func`` raises, nothing is printed.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Start the clock per call. Capturing it once at decoration time would
        # also count everything that ran between the definition and the call.
        # perf_counter is monotonic and meant for measuring intervals.
        start = time.perf_counter()
        rc = func(*args, **kwargs)
        print('Running time: {}'.format(time.perf_counter() - start))
        return rc
    return wrapper
def memo(func):
    """
    Memoize func in an unbounded dict.

    Each call is keyed by ``str(args) + str(kwargs)``. Distinct arguments that
    stringify identically therefore share one entry. Entries are never evicted.
    """
    store = {}
    missing = object()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        signature = str(args) + str(kwargs)
        hit = store.get(signature, missing)
        if hit is missing:
            # First time we see these arguments: compute and remember.
            hit = func(*args, **kwargs)
            store[signature] = hit
        return hit
    return wrapper
def memo_rod_cutting(price_table, cache_size=20):
    """
    Memoization version of rod cutting.

    :param dict price_table: The rod pricing table. ``price_table[k]`` is the
        price of a piece of length ``k + 1``. Lengths without an entry must be
        handled by the table itself (e.g. a ``defaultdict``); a plain dict
        raises ``KeyError`` for them.
    :param int cache_size: ``maxsize`` of the ``functools.lru_cache`` that
        memoizes sub-results (``None`` means unbounded). It should be at least
        the largest rod length queried. A smaller cache can evict sub-results
        that are still needed, forcing them to be recomputed.
    :returns: A function ``f(n)`` returning the optimal revenue for a rod of
        length ``n`` (``f(0) == 0``).
    """
    @functools.lru_cache(maxsize=cache_size)
    def wrapper(n):
        if n == 0:
            return 0
        # Either sell the rod whole, or cut off a first piece of length i
        # (1 <= i < n) and solve the remainder optimally. Selling whole is kept
        # as its own candidate so wrapper(0) is never cached.
        candidates = [price_table[n - 1]]
        candidates.extend(price_table[i - 1] + wrapper(n - i) for i in range(1, n))
        return max(candidates)
    return wrapper
def memo_rod_cutting_with_solution(price_table, cache, cache_size=20):
    """
    Memoization version of rod cutting that also records the solution.

    :param dict price_table: The rod pricing table. ``price_table[k]`` is the
        price of a piece of length ``k + 1``.
    :param dict cache: Filled in place. ``cache[n]`` is set to the length of the
        first piece to cut from a rod of length ``n``; a value equal to ``n``
        means "sell it whole". Nothing is recorded for ``n == 0``.
    :param int cache_size: ``maxsize`` of the ``functools.lru_cache`` that
        memoizes revenues. It should be at least the largest length queried.
    :returns: A function ``f(n)`` returning the optimal revenue (a number) for
        a rod of length ``n``; ``f(0) == 0``.
    """
    @functools.lru_cache(maxsize=cache_size)
    def wrapper(n):
        if n == 0:
            # An empty rod earns nothing. Return a plain number like every
            # other branch, since callers use the result as the revenue.
            return 0
        # Candidates are (first cut, revenue). max() keeps the first maximum,
        # so ties favour selling whole, then the shortest first piece.
        candidates = [(n, price_table[n - 1])]
        candidates.extend((i, price_table[i - 1] + wrapper(n - i)) for i in range(1, n))
        solution, revenue = max(candidates, key=lambda c: c[1])
        cache[n] = solution
        return revenue
    return wrapper
def solutions_builder(cache):
    """
    Build the solution from a solution cache.

    :param dict cache: Maps a rod length to the length of the first piece to
        cut from it, e.g. as filled in by ``memo_rod_cutting_with_solution``.
    :returns: A function ``f(n)`` that follows the cache from length ``n``. It
        returns a new list of piece lengths in cut order. The walk stops at
        the first length missing from ``cache``, so an unknown ``n`` gives
        ``[]``.
    :raises ValueError: If the cache leads the walk back to a length it has
        already visited (e.g. a cut of 0), which would otherwise never end.
    """
    def wrapper(n):
        # Fresh list per call so repeated calls don't accumulate results.
        pieces = []
        visited = set()
        while n in cache:
            if n in visited:
                raise ValueError('solution cache loops at length {}'.format(n))
            visited.add(n)
            cut = cache[n]
            pieces.append(cut)
            n -= cut
        return pieces
    return wrapper
@benchmark
def find_optimal_revenue(price_table, n):
    """
    Print and return the optimal revenue for a rod of length n.

    :param dict price_table: The rod pricing table (see ``memo_rod_cutting``).
    :param int n: The rod length. It is also used as the LRU cache size, so
        every sub-length 1..n stays cached during the computation.
    :returns: The optimal revenue.
    """
    # Use the parameter, not the script's global ``args``. Otherwise the
    # function breaks when imported and ignores the n it was given.
    cutting = memo_rod_cutting(price_table, n)
    revenue = cutting(n)
    print('Rod with length {} optimal cutting revenue is {}'.format(n, revenue))
    return revenue
@benchmark
def find_optimal_revenue_and_solution(price_table, n):
    """
    Print the optimal revenue and the cutting plan for a rod of length n.

    :param dict price_table: The rod pricing table.
    :param int n: The rod length, also used as the LRU cache size.
    """
    first_cuts = {}
    solver = memo_rod_cutting_with_solution(price_table, first_cuts, n)
    revenue = solver(n)
    # first_cuts is populated as a side effect of solver(n) above.
    pieces = solutions_builder(first_cuts)(n)
    plan = ' -> '.join(str(piece) for piece in pieces)
    print('Rod with length {} optimal cutting solution: {}, revenue: {}'.format(
        n, plan, revenue
    ))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('length', type=int, help='The length of the rod')
    args = parser.parse_args(sys.argv[1:])
    if args.length < 0:
        # A negative rod length is meaningless; reject it up front instead of
        # printing a nonsensical revenue and cutting plan.
        parser.error('length must be non-negative')
    # Prices for lengths 1..10 (key k is the price of length k + 1). Any longer
    # piece defaults to -inf, so it is never preferred over a real price.
    price_table = defaultdict(lambda: -float('inf'))
    price_table.update(enumerate([1, 5, 8, 9, 10, 17, 17, 20, 24, 30]))
    # Find the optimal value without solution
    find_optimal_revenue(price_table, args.length)
    # Find the optimal value and solution path
    find_optimal_revenue_and_solution(price_table, args.length)