-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SOT] add simulation support for user-defined iterable objects #70620
base: develop
Are you sure you want to change the base?
Conversation
test case: class IterableWithList():
def __init__(self):
self._list = [1, 2, 3]
def __iter__(self):
return self._list.__iter__()
def list_within_class(x: paddle.Tensor):
my_iterable = IterableWithList()
for i in my_iterable:
x += i
return x |
def load_sequence(self, obj): | ||
self.stack.push(obj.get_iter()) | ||
# skip call | ||
while self._instructions[self._lasti].opname != "RETURN_VALUE": | ||
self._lasti += 1 | ||
|
||
def load_method(self, method_name): | ||
obj = self.stack.pop() | ||
if isinstance(obj, ContainerVariable) and method_name == "__iter__": | ||
self.load_sequence(obj) | ||
return | ||
method_name_var = ConstantVariable.wrap_literal( | ||
method_name, self._graph | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段逻辑对上下文有较大的修改,如何保证全场景的正确性?
每条字节码应该只做自己的事情,只看当前 instruction
,不应该去访问 self._instructions
另外这里特判 ContainerVariable
和 __iter__
的原因是?
@@ -1049,6 +1049,35 @@ def main_info(self) -> dict[str, Any]: | |||
def get_py_value(self, allow_tensor=False) -> Any: | |||
return self.value | |||
|
|||
def get_iter(self): | |||
""" | |||
To simplify the problem, we only support the case where the __iter__ method returns a list. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不止 list,应该是现在全部已经支持的 builtin types
另外这个实现在 VariableBase 上会有什么问题么
self._list = [1, 2, 3] | ||
|
||
def __iter__(self): | ||
return self._list.__iter__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议使用 iter(self._list)
,这是更加惯用的形式,不过当前写法也测一下吧
@@ -375,6 +391,9 @@ def test_list_extend_range(self): | |||
def test_list_extend_dict(self): | |||
self.assert_results(list_extend_dict) | |||
|
|||
def test_list_within_class(self): | |||
self.assert_results(list_within_class, paddle.to_tensor(1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以新建一个文件用来测 iter,比如 test_iter.py
PR Category
User Experience
PR Types
New features
Description
如果用户自定义的
Iterable
对象,其__iter()__
方法返回的结果是list
,dict
,tuple
,tensor
,range
,那么就能够模拟iterPCard-66972