Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] add simulation support for user-defined iterable objects #70620

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from

Conversation

GoldenStain
Copy link
Contributor

@GoldenStain GoldenStain commented Jan 3, 2025

PR Category

User Experience

PR Types

New features

Description

如果用户自定义的Iterable对象,其__iter()__方法返回的结果是listdicttupletensorrange,那么就能够模拟iter
PCard-66972

@GoldenStain
Copy link
Contributor Author

test case:

class IterableWithList():
    def __init__(self):
        self._list = [1, 2, 3]
    def __iter__(self):
        return self._list.__iter__()

def list_within_class(x: paddle.Tensor):
    my_iterable = IterableWithList()
    for i in my_iterable:
        x += i
    return x

Comment on lines +863 to 876
def load_sequence(self, obj):
self.stack.push(obj.get_iter())
# skip call
while self._instructions[self._lasti].opname != "RETURN_VALUE":
self._lasti += 1

def load_method(self, method_name):
obj = self.stack.pop()
if isinstance(obj, ContainerVariable) and method_name == "__iter__":
self.load_sequence(obj)
return
method_name_var = ConstantVariable.wrap_literal(
method_name, self._graph
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段逻辑对上下文有较大的修改,如何保证全场景的正确性?

每条字节码应该只做自己的事情,只看当前 instruction,不应该去访问 self._instructions

另外这里特判 ContainerVariable__iter__ 的原因是?

@@ -1049,6 +1049,35 @@ def main_info(self) -> dict[str, Any]:
def get_py_value(self, allow_tensor=False) -> Any:
return self.value

def get_iter(self):
"""
To simplify the problem, we only support the case where the __iter__ method returns a list.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不止 list,应该是现在全部已经支持的 builtin types

另外这个实现在 VariableBase 上会有什么问题么

self._list = [1, 2, 3]

def __iter__(self):
return self._list.__iter__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用 iter(self._list),这是更加惯用的形式,不过当前写法也测一下吧

@@ -375,6 +391,9 @@ def test_list_extend_range(self):
def test_list_extend_dict(self):
self.assert_results(list_extend_dict)

def test_list_within_class(self):
self.assert_results(list_within_class, paddle.to_tensor(1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以新建一个文件用来测 iter,比如 test_iter.py

@SigureMo SigureMo changed the title [3.13] add simulation support for user-defined iterable objects [SOT] add simulation support for user-defined iterable objects Jan 6, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants