diff --git a/openhands_aci/__init__.py b/openhands_aci/__init__.py index 9aba37c..a4c62f3 100644 --- a/openhands_aci/__init__.py +++ b/openhands_aci/__init__.py @@ -1,3 +1,4 @@ from .editor import file_editor +from .navigator import symbol_navigator -__all__ = ['file_editor'] +__all__ = ['file_editor', 'symbol_navigator'] diff --git a/openhands_aci/editor/exceptions.py b/openhands_aci/core/exceptions.py similarity index 88% rename from openhands_aci/editor/exceptions.py rename to openhands_aci/core/exceptions.py index c53141a..1e3eb9f 100644 --- a/openhands_aci/editor/exceptions.py +++ b/openhands_aci/core/exceptions.py @@ -5,7 +5,7 @@ def __init__(self, message): self.message = message -class EditorToolParameterMissingError(ToolError): +class MultiCommandToolParameterMissingError(ToolError): """Raised when a required parameter is missing for a tool command.""" def __init__(self, command, parameter): @@ -14,7 +14,7 @@ def __init__(self, command, parameter): self.message = f'Parameter `{parameter}` is required for command: {command}.' 
-class EditorToolParameterInvalidError(ToolError): +class ToolParameterInvalidError(ToolError): """Raised when a parameter is invalid for a tool command.""" def __init__(self, parameter, value, hint=None): diff --git a/openhands_aci/editor/results.py b/openhands_aci/core/results.py similarity index 74% rename from openhands_aci/editor/results.py rename to openhands_aci/core/results.py index 83dca91..e406570 100644 --- a/openhands_aci/editor/results.py +++ b/openhands_aci/core/results.py @@ -1,7 +1,7 @@ from dataclasses import asdict, dataclass, fields -from .config import MAX_RESPONSE_LEN_CHAR -from .prompts import CONTENT_TRUNCATED_NOTICE +from ..editor.config import MAX_RESPONSE_LEN_CHAR +from ..editor.prompts import CONTENT_TRUNCATED_NOTICE @dataclass @@ -45,3 +45,12 @@ def maybe_truncate( if not truncate_after or len(content) <= truncate_after else content[:truncate_after] + CONTENT_TRUNCATED_NOTICE ) + + +def make_api_tool_result(tool_result: ToolResult) -> str: + """Convert an agent ToolResult to an API ToolResultBlockParam.""" + if tool_result.error: + return f'ERROR:\n{tool_result.error}' + + assert tool_result.output, 'Expected output in file_editor.' 
+ return tool_result.output diff --git a/openhands_aci/editor/__init__.py b/openhands_aci/editor/__init__.py index f3636d8..4098c63 100644 --- a/openhands_aci/editor/__init__.py +++ b/openhands_aci/editor/__init__.py @@ -1,20 +1,36 @@ import json import uuid +from openhands_aci.core.exceptions import ToolError +from openhands_aci.core.results import ToolResult, make_api_tool_result + from .editor import Command, OHEditor -from .exceptions import ToolError -from .results import ToolResult _GLOBAL_EDITOR = OHEditor() -def _make_api_tool_result(tool_result: ToolResult) -> str: - """Convert an agent ToolResult to an API ToolResultBlockParam.""" - if tool_result.error: - return f'ERROR:\n{tool_result.error}' +TOOL_DESCRIPTION = """Custom editing tool for viewing, creating and editing files +* State is persistent across command calls and discussions with the user +* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep +* The `create` command cannot be used if the specified `path` already exists as a file +* If a `command` generates a long output, it will be truncated and marked with `` +* The `undo_edit` command will revert the last edit made to the file at `path` + +Notes for using the `str_replace` command: +* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces! +* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique +* The `new_str` parameter should contain the edited lines that should replace the `old_str` +""" - assert tool_result.output, 'Expected output in file_editor.' - return tool_result.output +PARAMS_DESCRIPTION = { + 'command': 'The commands to run. 
Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.', + 'path': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.', + 'file_text': 'Required parameter of `create` command, with the content of the file to be created.', + 'old_str': 'Required parameter of `str_replace` command containing the string in `path` to replace.', + 'new_str': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.', + 'insert_line': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.', + 'view_range': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.', +} def file_editor( @@ -42,7 +58,7 @@ def file_editor( except ToolError as e: result = ToolResult(error=e.message) - formatted_output_and_error = _make_api_tool_result(result) + formatted_output_and_error = make_api_tool_result(result) marker_id = uuid.uuid4().hex return f""" {json.dumps(result.to_dict(extra_field={'formatted_output_and_error': formatted_output_and_error}), indent=2)} diff --git a/openhands_aci/editor/cli.py b/openhands_aci/editor/cli.py new file mode 100644 index 0000000..1e66f0a --- /dev/null +++ b/openhands_aci/editor/cli.py @@ -0,0 +1,113 @@ +import argparse +import json +import re +import sys +from pathlib import Path +from typing import Any, NoReturn + +from .editor import Command, get_args + + +def parse_view_range(value: str) -> list[int]: + try: + start, end = map(int, value.split(',')) + return [start, end] + except ValueError: + raise argparse.ArgumentTypeError( + 'view-range must be two 
comma-separated integers, e.g. "1,10"' + ) + + +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description='OpenHands Editor CLI - A tool for viewing and editing files' + ) + parser.add_argument( + 'command', + type=str, + choices=list(get_args(Command)), + help='The command to execute', + ) + parser.add_argument( + 'path', + type=str, + help='Path to the file or directory to operate on', + ) + parser.add_argument( + '--file-text', + type=str, + help='Content for the file when using create command', + ) + parser.add_argument( + '--view-range', + type=parse_view_range, + help='Line range to view in format "start,end", e.g. "1,10"', + ) + parser.add_argument( + '--old-str', + type=str, + help='String to replace when using str_replace command', + ) + parser.add_argument( + '--new-str', + type=str, + help='New string to insert when using str_replace or insert commands', + ) + parser.add_argument( + '--insert-line', + type=int, + help='Line number after which to insert when using insert command', + ) + parser.add_argument( + '--enable-linting', + action='store_true', + help='Enable linting for file modifications', + ) + parser.add_argument( + '--raw', + action='store_true', + help='Output raw JSON response instead of formatted text', + ) + return parser + + +def extract_result(output: str) -> dict[str, Any]: + match = re.search( + r'(.*?)', + output, + re.DOTALL, + ) + assert match, f'Output does not contain the expected tags in the correct format: {output}' + result_dict = json.loads(match.group(1)) + return result_dict + + +def main() -> NoReturn: + parser = create_parser() + args = parser.parse_args() + + # Import here to avoid circular imports + from . 
import file_editor + + try: + output = file_editor( + command=args.command, + path=str(Path(args.path).absolute()), + file_text=args.file_text, + view_range=args.view_range, + old_str=args.old_str, + new_str=args.new_str, + insert_line=args.insert_line, + enable_linting=args.enable_linting, + ) + + result = extract_result(output) + print(result['formatted_output_and_error']) + sys.exit(0) + except Exception as e: + print(f'ERROR: {str(e)}', file=sys.stderr) + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/openhands_aci/editor/editor.py b/openhands_aci/editor/editor.py index 07d38a5..e47f47c 100644 --- a/openhands_aci/editor/editor.py +++ b/openhands_aci/editor/editor.py @@ -3,16 +3,16 @@ from pathlib import Path from typing import Literal, get_args +from openhands_aci.core.exceptions import ( + MultiCommandToolParameterMissingError, + ToolError, + ToolParameterInvalidError, +) +from openhands_aci.core.results import CLIResult, maybe_truncate from openhands_aci.linter import DefaultLinter from openhands_aci.utils.shell import run_shell_cmd from .config import SNIPPET_CONTEXT_WINDOW -from .exceptions import ( - EditorToolParameterInvalidError, - EditorToolParameterMissingError, - ToolError, -) -from .results import CLIResult, maybe_truncate Command = Literal[ 'view', @@ -20,8 +20,6 @@ 'str_replace', 'insert', 'undo_edit', - # 'jump_to_definition', TODO: - # 'find_references' TODO: ] @@ -62,7 +60,7 @@ def __call__( return self.view(_path, view_range) elif command == 'create': if file_text is None: - raise EditorToolParameterMissingError(command, 'file_text') + raise MultiCommandToolParameterMissingError(command, 'file_text') self.write_file(_path, file_text) self._file_history[_path].append(file_text) return CLIResult( @@ -73,9 +71,9 @@ def __call__( ) elif command == 'str_replace': if old_str is None: - raise EditorToolParameterMissingError(command, 'old_str') + raise MultiCommandToolParameterMissingError(command, 'old_str') if new_str == old_str: 
- raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'new_str', new_str, 'No replacement was performed. `new_str` and `old_str` must be different.', @@ -83,9 +81,9 @@ def __call__( return self.str_replace(_path, old_str, new_str, enable_linting) elif command == 'insert': if insert_line is None: - raise EditorToolParameterMissingError(command, 'insert_line') + raise MultiCommandToolParameterMissingError(command, 'insert_line') if new_str is None: - raise EditorToolParameterMissingError(command, 'new_str') + raise MultiCommandToolParameterMissingError(command, 'new_str') return self.insert(_path, insert_line, new_str, enable_linting) elif command == 'undo_edit': return self.undo_edit(_path) @@ -167,7 +165,7 @@ def view(self, path: Path, view_range: list[int] | None = None) -> CLIResult: """ if path.is_dir(): if view_range: - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'view_range', view_range, 'The `view_range` parameter is not allowed when `path` points to a directory.', @@ -195,7 +193,7 @@ def view(self, path: Path, view_range: list[int] | None = None) -> CLIResult: ) if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range): - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'view_range', view_range, 'It should be a list of two integers.', @@ -205,21 +203,21 @@ def view(self, path: Path, view_range: list[int] | None = None) -> CLIResult: num_lines = len(file_content_lines) start_line, end_line = view_range if start_line < 1 or start_line > num_lines: - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'view_range', view_range, f'Its first element `{start_line}` should be within the range of lines of the file: {[1, num_lines]}.', ) if end_line > num_lines: - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'view_range', view_range, f'Its second element `{end_line}` should be smaller than the number of lines in the file: 
`{num_lines}`.', ) if end_line != -1 and end_line < start_line: - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'view_range', view_range, f'Its second element `{end_line}` should be greater than or equal to the first element `{start_line}`.', @@ -262,7 +260,7 @@ def insert( num_lines = len(file_text_lines) if insert_line < 0 or insert_line > num_lines: - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'insert_line', insert_line, f'It should be within the range of lines of the file: {[0, num_lines]}', @@ -315,26 +313,26 @@ def validate_path(self, command: Command, path: Path) -> None: # Check if its an absolute path if not path.is_absolute(): suggested_path = Path('') / path - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'path', path, f'The path should be an absolute path, starting with `/`. Maybe you meant {suggested_path}?', ) # Check if path and command are compatible if command == 'create' and path.exists(): - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'path', path, f'File already exists at: {path}. Cannot overwrite files using command `create`.', ) if command != 'create' and not path.exists(): - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'path', path, f'The path {path} does not exist. Please provide a valid path.', ) if command != 'view' and path.is_dir(): - raise EditorToolParameterInvalidError( + raise ToolParameterInvalidError( 'path', path, f'The path {path} is a directory and only the `view` command can be used on directories.', diff --git a/openhands_aci/editor/prompts.py b/openhands_aci/editor/prompts.py index 57d2a02..58b9f13 100644 --- a/openhands_aci/editor/prompts.py +++ b/openhands_aci/editor/prompts.py @@ -1 +1,3 @@ CONTENT_TRUNCATED_NOTICE: str = 'To save on context only part of this file has been shown to you. 
You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.' + +NAVIGATION_TIPS: str = 'Use the navigation tool to investigate more about a particular class, function/method and how it is used in the codebase.' diff --git a/openhands_aci/navigator/__init__.py b/openhands_aci/navigator/__init__.py new file mode 100644 index 0000000..5f51419 --- /dev/null +++ b/openhands_aci/navigator/__init__.py @@ -0,0 +1,35 @@ +from openhands_aci.core.exceptions import ToolError +from openhands_aci.core.results import ToolResult, make_api_tool_result + +from .navigator import Command, SymbolNavigator + +_GLOBAL_NAVIGATOR = SymbolNavigator() + + +TOOL_DESCRIPTION = """Custom navigation tool for navigating to symbols in a codebase +* If there are multiple symbols with the same name, the tool will print all of them +* The `jump_to_definition` command will print the FULL definition of the symbol, along with the absolute path to the file +* The `find_references` command will print only the file content at the line where the symbol is referenced, along with the absolute path to the file +* It is more preferable to use this tool for user-defined symbols. For built-in symbols, consider using other tools like `grep` +""" + +PARAMS_DESCRIPTION = { + 'command': 'The command to run. 
Allowed options are: `jump_to_definition`, `find_references`.', + 'symbol_name': 'The symbol name to navigate to.', +} + + +def symbol_navigator( + command: Command, + symbol_name: str, +) -> str: + result: ToolResult | None = None + try: + result = _GLOBAL_NAVIGATOR( + command=command, + symbol_name=symbol_name, + ) + except ToolError as e: + result = ToolResult(error=e.message) + + return make_api_tool_result(result) # Return as default IPython output diff --git a/openhands_aci/navigator/cli.py b/openhands_aci/navigator/cli.py new file mode 100644 index 0000000..7d24380 --- /dev/null +++ b/openhands_aci/navigator/cli.py @@ -0,0 +1,47 @@ +import argparse +import sys +from typing import NoReturn + +from .navigator import Command, get_args + + +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description='OpenHands Navigator CLI - A tool for navigating in a codebase' + ) + parser.add_argument( + 'command', + type=str, + choices=list(get_args(Command)), + help='The command to execute', + ) + parser.add_argument( + 'symbol_name', + type=str, + help='The symbol name to navigate to', + ) + return parser + + +def main() -> NoReturn: + parser = create_parser() + args = parser.parse_args() + + # Import here to avoid circular imports + from . 
import symbol_navigator + + try: + result = symbol_navigator( + command=args.command, + symbol_name=args.symbol_name, + ) + + print(result) + sys.exit(0) + except Exception as e: + print(f'ERROR: {str(e)}', file=sys.stderr) + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/openhands_aci/navigator/navigator.py b/openhands_aci/navigator/navigator.py new file mode 100644 index 0000000..0399393 --- /dev/null +++ b/openhands_aci/navigator/navigator.py @@ -0,0 +1,297 @@ +import os +from collections import defaultdict +from typing import Literal, get_args + +from grep_ast import TreeContext +from rapidfuzz import process +from tqdm import tqdm + +from openhands_aci.core.exceptions import ToolError, ToolParameterInvalidError +from openhands_aci.core.results import CLIResult +from openhands_aci.tree_sitter.parser import ParsedTag, TagKind, TreeSitterParser +from openhands_aci.utils.file import GitRepoUtils, get_modified_time, read_text +from openhands_aci.utils.path import PathUtils + +Command = Literal[ + 'jump_to_definition', + 'find_references', +] + + +class SymbolNavigator: + """ + A symbol navigator that allows the agent to: + - jump to the definition of a symbol + - find references to a symbol + """ + + TOOL_NAME = 'oh_navigator' + + def __init__(self, show_progress=False) -> None: + self.show_progress = show_progress + + # Lazy-initialized attributes + self._git_utils: GitRepoUtils | None = None + self._path_utils: PathUtils | None = None + self._ts_parser: TreeSitterParser | None = None + self._git_repo_found: bool | None = None + + # Caching + self._file_context_cache: dict = {} # (rel_file) -> {'context': TreeContext_obj, 'mtime': mtime} + self._rendered_tree_cache: dict = {} # (rel_file, lines_of_interest, mtime) -> rendered_tree + + @property + def git_utils(self): + if self._git_repo_found is None: + pwd = os.getcwd() + try: + self._git_utils = GitRepoUtils( + pwd + ) # pwd is set to the workspace automatically + self._git_repo_found = True 
+ except Exception: + self._git_repo_found = False + return None + + return self._git_utils + + if not self._git_repo_found: + return None + return self._git_utils + + @property + def path_utils(self): + if self._path_utils is None: + pwd = os.getcwd() + self._path_utils = PathUtils(pwd) + return self._path_utils + + @property + def ts_parser(self): + if self._ts_parser is None: + pwd = os.getcwd() + self._ts_parser = TreeSitterParser(pwd) + return self._ts_parser + + @property + def is_enabled(self): + if self._git_repo_found is None: + self.git_utils # Initialize the git_utils + return bool(self._git_repo_found) + + def __call__(self, *, command: Command, symbol_name: str, **kwargs) -> CLIResult: + if not symbol_name: + raise ToolParameterInvalidError( + 'symbol_name', symbol_name, 'Symbol name cannot be empty.' + ) + + if command == 'jump_to_definition': + return CLIResult(output=self.get_definitions_tree(symbol_name)) + elif command == 'find_references': + return CLIResult(output=self.get_references_tree(symbol_name)) + + raise ToolError( + f'Unrecognized command {command}. The allowed commands for the {self.TOOL_NAME} tool are: {", ".join(get_args(Command))}' + ) + + def get_definitions_tree( + self, symbol: str, rel_file_path: str | None = None, use_end_line=True + ): + if not self.git_utils: + return 'No git repository found. Navigation commands are disabled. Please use bash commands instead.' + + ident2defrels, _, identwrel2deftags, _ = self._get_parsed_tags() + + # Extract definitions for the symbol + def_tags = set() + if symbol: + def_rels = ident2defrels.get(symbol, set()) + for def_rel in def_rels: + if rel_file_path is not None and rel_file_path not in def_rel: + continue + def_tags.update(identwrel2deftags.get((def_rel, symbol), set())) + + if not def_tags: + # Perform a fuzzy search for the symbol + choices = list(ident2defrels.keys()) + suggested_matches = process.extract(symbol, choices, limit=5) + return f"No definitions found for `{symbol}`. 
Maybe you meant one of these: {', '.join(match[0] for match in suggested_matches)}?" + + # Concatenate the definitions to another tree representation + defs_repr = '' + defs_repr += f'Definition(s) of `{symbol}`:\n' + # Sort the tags by file path and line number + def_tags_list = list(def_tags) + def_tags_list.sort(key=lambda tag: (tag.rel_path, tag.start_line)) + defs_repr += self._tag_list_to_tree(def_tags_list, use_end_line=use_end_line) + defs_repr += '\n' + + return defs_repr + + def get_references_tree(self, symbol: str): + if not self.git_utils: + return 'No git repository found. Navigation commands are disabled. Please use bash commands instead.' + + _, ident2refrels, _, identwrel2reftags = self._get_parsed_tags() + + # Extract references for the symbol + ref_tags = set() + ref_rels = ident2refrels.get(symbol, set()) + for ref_rel in ref_rels: + ref_tags.update(identwrel2reftags.get((ref_rel, symbol), set())) + + if not ref_tags: + # Perform a fuzzy search for the symbol + choices = list(ident2refrels.keys()) + suggested_matches = process.extract(symbol, choices, limit=5) + return f"No references found for `{symbol}`. Maybe you meant one of these: {', '.join(match[0] for match in suggested_matches)}?" 
+ + # Concatenate the direct references to another tree representation + direct_refs_repr = '' + direct_refs_repr += f'References to `{symbol}`:\n' + # Sort the tags by file path and line number + ref_tags_list = list(ref_tags) + ref_tags_list.sort(key=lambda tag: (tag.rel_path, tag.start_line)) + direct_refs_repr += self._tag_list_to_tree(ref_tags_list, use_end_line=False) + direct_refs_repr += '\n' + + return direct_refs_repr + + def _get_parsed_tags( + self, + depth: int | None = None, + rel_dir_path: str | None = None, + ) -> tuple[dict, dict, dict, dict]: + """ + Parse all tags in the tracked files and return the following dictionaries: + - ident2defrels: symbol identifier -> set of its definitions' relative file paths + - ident2refrels: symbol identifier -> list of its references' relative file paths + - identwrel2deftags: (relative file, symbol identifier) -> set of its DEF tags + - identwrel2reftags: (relative file, symbol identifier) -> set of its REF tags + """ + if rel_dir_path: + all_abs_files = self.git_utils.get_absolute_tracked_files_in_directory( + rel_dir_path=rel_dir_path, + depth=depth, + ) + else: + all_abs_files = self.git_utils.get_all_absolute_tracked_files(depth=depth) + + ident2defrels = defaultdict( + set + ) # symbol identifier -> set of its definitions' relative file paths + ident2refrels = defaultdict( + list + ) # symbol identifier -> list of its references' relative file paths + identwrel2deftags = defaultdict( + set + ) # (relative file, symbol identifier) -> set of its DEF tags + identwrel2reftags = defaultdict( + set + ) # (relative file, symbol identifier) -> set of its REF tags + + all_abs_files_iter = ( + tqdm(all_abs_files, desc='Parsing tags', unit='file') + if self.show_progress + else all_abs_files + ) + for abs_file in all_abs_files_iter: + rel_file = self.path_utils.get_relative_path_str(abs_file) + parsed_tags = self.ts_parser.get_tags_from_file(abs_file, rel_file) + + for parsed_tag in parsed_tags: + if 
parsed_tag.tag_kind == TagKind.DEF: + ident2defrels[parsed_tag.node_content].add(rel_file) + identwrel2deftags[(rel_file, parsed_tag.node_content)].add( + parsed_tag + ) + if parsed_tag.tag_kind == TagKind.REF: + ident2refrels[parsed_tag.node_content].append(rel_file) + identwrel2reftags[(rel_file, parsed_tag.node_content)].add( + parsed_tag + ) + + return ident2defrels, ident2refrels, identwrel2deftags, identwrel2reftags + + def _tag_list_to_tree(self, tags: list[ParsedTag], use_end_line=False) -> str: + """ + Convert a list of ParsedTag objects to a tree str representation. + """ + if not tags: + return '' + + cur_rel_file, cur_abs_file = '', '' + lines_of_interest: list[int] = [] + output = '' + + dummy_tag = ParsedTag( + abs_path='', + rel_path='', + node_content='', + tag_kind=TagKind.DEF, + start_line=0, + end_line=0, + ) + for tag in tags + [dummy_tag]: # Add dummy tag to trigger last file output + if tag.rel_path != cur_rel_file: + if lines_of_interest: + output += cur_rel_file + ':\n' + output += self._render_tree( + cur_abs_file, cur_rel_file, lines_of_interest + ) + lines_of_interest = [] + elif cur_rel_file: # No line of interest + output += '\n' + cur_rel_file + ':\n' + + cur_abs_file = tag.abs_path + cur_rel_file = tag.rel_path + + lines_of_interest += ( + list(range(tag.start_line, tag.end_line + 1)) + if use_end_line + else [tag.start_line] + ) + + # Truncate long lines in case we get minified js or something else crazy + output = '\n'.join(line[:150] for line in output.splitlines()) + return output + + def _render_tree( + self, abs_file: str, rel_file: str, lines_of_interest: list + ) -> str: + mtime = get_modified_time(abs_file) + tree_cache_key = (rel_file, tuple(sorted(lines_of_interest)), mtime) + if tree_cache_key in self._rendered_tree_cache: + return self._rendered_tree_cache[tree_cache_key] + + if ( + rel_file not in self._file_context_cache + or self._file_context_cache[rel_file]['mtime'] < mtime + ): + code = read_text(abs_file) or '' + 
if not code.endswith('\n'): + code += '\n' + + context = TreeContext( + filename=rel_file, + code=code, + color=False, + line_number=True, + child_context=False, + last_line=False, + margin=0, + mark_lois=False, + loi_pad=0, + # header_max=30, + show_top_of_file_parent_scope=False, + ) + self._file_context_cache[rel_file] = {'context': context, 'mtime': mtime} + else: + context = self._file_context_cache[rel_file]['context'] + + context.lines_of_interest = set() + context.add_lines_of_interest(lines_of_interest) + context.add_context() + res = context.format() + self._rendered_tree_cache[tree_cache_key] = res + return res diff --git a/openhands_aci/tree_sitter/parser.py b/openhands_aci/tree_sitter/parser.py new file mode 100644 index 0000000..58dce3d --- /dev/null +++ b/openhands_aci/tree_sitter/parser.py @@ -0,0 +1,138 @@ +import tempfile +import warnings +from collections import namedtuple +from enum import Enum +from pathlib import Path + +from diskcache import Cache +from grep_ast import filename_to_lang +from tree_sitter_languages import get_language, get_parser + +from openhands_aci.utils.file import get_modified_time, read_text +from openhands_aci.utils.logger import oh_aci_logger as logger + +warnings.filterwarnings('ignore', category=FutureWarning, module='tree_sitter') + +ParsedTag = namedtuple( + 'ParsedTag', + ('rel_path', 'abs_path', 'start_line', 'end_line', 'node_content', 'tag_kind'), +) + + +class TagKind(Enum): + DEF = 'def' + REF = 'ref' + DEF_WITH_BODY = 'def_with_body' + + +class TreeSitterParser: + TAGS_CACHE_DIR = '.oh_aci.cache.tags' + + def __init__(self, cache_root_dir: str) -> None: + self._load_tags_cache(cache_root_dir) + + def get_tags_from_file(self, abs_path: str, rel_path: str) -> list[ParsedTag]: + mtime = get_modified_time(abs_path) + cache_key = abs_path + cache_val = self.tags_cache.get(cache_key) + if cache_val and cache_val.get('mtime') == mtime: + return cache_val.get('data') + + data = self._get_tags_raw(abs_path, rel_path) 
+ # Update cache + self.tags_cache[cache_key] = {'mtime': mtime, 'data': data} + return data + + def _get_tags_raw(self, abs_path: str, rel_path: str) -> list[ParsedTag]: + lang = filename_to_lang(abs_path) + if not lang: + return [] + + ts_language = get_language(lang) + ts_parser = get_parser(lang) + + tags_file_path = ( + Path(__file__).resolve().parent / 'queries' / f'tree-sitter-{lang}-tags.scm' + ) + if not tags_file_path.exists(): + return [] + tags_query = tags_file_path.read_text() + + if not Path(abs_path).exists(): + return [] + code = read_text(abs_path) + if not code: + return [] + + parsed_tree = ts_parser.parse(bytes(code, 'utf-8')) + + # Run the tags queries + query = ts_language.query(tags_query) + captures = query.captures(parsed_tree.root_node) + + parsed_tags = [] + for node, tag_str in captures: + if tag_str.startswith('name.definition.'): + tag_kind = TagKind.DEF + elif tag_str.startswith('name.reference.'): + tag_kind = TagKind.REF + elif tag_str.startswith('definition.'): + tag_kind = TagKind.DEF_WITH_BODY + else: + # Skip other tags + continue + + result_tag = ParsedTag( + rel_path=rel_path, + abs_path=abs_path, + start_line=node.start_point[0], + end_line=node.end_point[0], + node_content=node.text.decode('utf-8'), + tag_kind=tag_kind, + ) + parsed_tags.append(result_tag) + + parsed_tags = self._update_end_lines_for_def_using_def_with_body(parsed_tags) + return parsed_tags + + def _update_end_lines_for_def_using_def_with_body( + self, parsed_tags: list[ParsedTag] + ) -> list[ParsedTag]: + # Create a dictionary to quickly look up end_line for DEF_WITH_BODY tags + def_with_body_lookup = { + (tag.abs_path, tag.start_line): tag.end_line + for tag in parsed_tags + if tag.tag_kind == TagKind.DEF_WITH_BODY + } + + # Iterate over tags and update end_line if a matching DEF_WITH_BODY exists + result_tags = [] + for tag in parsed_tags: + if ( + tag.tag_kind == TagKind.DEF + and (tag.abs_path, tag.start_line) in def_with_body_lookup + ): + updated_tag 
= ParsedTag( + rel_path=tag.rel_path, + abs_path=tag.abs_path, + start_line=tag.start_line, + end_line=def_with_body_lookup[(tag.abs_path, tag.start_line)], + node_content=tag.node_content, + tag_kind=tag.tag_kind, + ) + result_tags.append(updated_tag) + else: + result_tags.append(tag) + + return result_tags + + def _load_tags_cache(self, abs_root_dir: str) -> None: + safe_path = str(Path(abs_root_dir).resolve()).replace('/', '_').lstrip('_') + cache_path = Path(tempfile.gettempdir()) / safe_path / self.TAGS_CACHE_DIR + try: + self.tags_cache = Cache(cache_path) + except Exception: + logger.warning( + f'Could not load tags cache from {cache_path}, try deleting cache directory.' + ) + self.tags_cache = dict() diff --git a/openhands_aci/tree_sitter/queries/README.md b/openhands_aci/tree_sitter/queries/README.md new file mode 100644 index 0000000..d7c8df0 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/README.md @@ -0,0 +1,22 @@ +# Credits + +`openhands-aci` uses [aider](https://github.com/Aider-AI/aider/tree/main/aider/queries)'s modified versions of the tags.scm files from these open source +tree-sitter language implementations: + +* [https://github.com/tree-sitter/tree-sitter-c](https://github.com/tree-sitter/tree-sitter-c) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-c-sharp](https://github.com/tree-sitter/tree-sitter-c-sharp) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-cpp](https://github.com/tree-sitter/tree-sitter-cpp) — licensed under the MIT License. +* [https://github.com/Wilfred/tree-sitter-elisp](https://github.com/Wilfred/tree-sitter-elisp) — licensed under the MIT License. +* [https://github.com/elixir-lang/tree-sitter-elixir](https://github.com/elixir-lang/tree-sitter-elixir) — licensed under the Apache License, Version 2.0. +* [https://github.com/elm-tooling/tree-sitter-elm](https://github.com/elm-tooling/tree-sitter-elm) — licensed under the MIT License. 
+* [https://github.com/tree-sitter/tree-sitter-go](https://github.com/tree-sitter/tree-sitter-go) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-java](https://github.com/tree-sitter/tree-sitter-java) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-javascript](https://github.com/tree-sitter/tree-sitter-javascript) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-ocaml](https://github.com/tree-sitter/tree-sitter-ocaml) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-php](https://github.com/tree-sitter/tree-sitter-php) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-python](https://github.com/tree-sitter/tree-sitter-python) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-ql](https://github.com/tree-sitter/tree-sitter-ql) — licensed under the MIT License. +* [https://github.com/r-lib/tree-sitter-r](https://github.com/r-lib/tree-sitter-r) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-ruby](https://github.com/tree-sitter/tree-sitter-ruby) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-rust](https://github.com/tree-sitter/tree-sitter-rust) — licensed under the MIT License. +* [https://github.com/tree-sitter/tree-sitter-typescript](https://github.com/tree-sitter/tree-sitter-typescript) — licensed under the MIT License. 
diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-c-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-c-tags.scm new file mode 100644 index 0000000..1035aa2 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-c-tags.scm @@ -0,0 +1,9 @@ +(struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class + +(declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class + +(function_declarator declarator: (identifier) @name.definition.function) @definition.function + +(type_definition declarator: (type_identifier) @name.definition.type) @definition.type + +(enum_specifier name: (type_identifier) @name.definition.type) @definition.type diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-c_sharp-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-c_sharp-tags.scm new file mode 100644 index 0000000..58e9199 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-c_sharp-tags.scm @@ -0,0 +1,46 @@ +(class_declaration + name: (identifier) @name.definition.class + ) @definition.class + +(class_declaration + bases: (base_list (_) @name.reference.class) + ) @reference.class + +(interface_declaration + name: (identifier) @name.definition.interface + ) @definition.interface + +(interface_declaration + bases: (base_list (_) @name.reference.interface) + ) @reference.interface + +(method_declaration + name: (identifier) @name.definition.method + ) @definition.method + +(object_creation_expression + type: (identifier) @name.reference.class + ) @reference.class + +(type_parameter_constraints_clause + target: (identifier) @name.reference.class + ) @reference.class + +(type_constraint + type: (identifier) @name.reference.class + ) @reference.class + +(variable_declaration + type: (identifier) @name.reference.class + ) @reference.class + +(invocation_expression + function: + (member_access_expression + name: (identifier) @name.reference.send + ) +) @reference.send + 
+(namespace_declaration + name: (identifier) @name.definition.module +) @definition.module diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-cpp-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-cpp-tags.scm new file mode 100644 index 0000000..7a7ad0b --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-cpp-tags.scm @@ -0,0 +1,15 @@ +(struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class + +(declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class + +(function_declarator declarator: (identifier) @name.definition.function) @definition.function + +(function_declarator declarator: (field_identifier) @name.definition.function) @definition.function + +(function_declarator declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method)) @definition.method + +(type_definition declarator: (type_identifier) @name.definition.type) @definition.type + +(enum_specifier name: (type_identifier) @name.definition.type) @definition.type + +(class_specifier name: (type_identifier) @name.definition.class) @definition.class diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-elisp-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-elisp-tags.scm new file mode 100644 index 0000000..743c8d8 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-elisp-tags.scm @@ -0,0 +1,8 @@ +;; defun/defsubst +(function_definition name: (symbol) @name.definition.function) @definition.function + +;; Treat macros as function definitions for the sake of TAGS. 
+(macro_definition name: (symbol) @name.definition.function) @definition.function + +;; Match function calls +(list (symbol) @name.reference.function) @reference.function diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-elixir-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-elixir-tags.scm new file mode 100644 index 0000000..9eb39d9 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-elixir-tags.scm @@ -0,0 +1,54 @@ +; Definitions + +; * modules and protocols +(call + target: (identifier) @ignore + (arguments (alias) @name.definition.module) + (#match? @ignore "^(defmodule|defprotocol)$")) @definition.module + +; * functions/macros +(call + target: (identifier) @ignore + (arguments + [ + ; zero-arity functions with no parentheses + (identifier) @name.definition.function + ; regular function clause + (call target: (identifier) @name.definition.function) + ; function clause with a guard clause + (binary_operator + left: (call target: (identifier) @name.definition.function) + operator: "when") + ]) + (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @definition.function + +; References + +; ignore calls to kernel/special-forms keywords +(call + target: (identifier) @ignore + (#match? 
@ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp|defmodule|defprotocol|defimpl|defstruct|defexception|defoverridable|alias|case|cond|else|for|if|import|quote|raise|receive|require|reraise|super|throw|try|unless|unquote|unquote_splicing|use|with)$")) + +; ignore module attributes +(unary_operator + operator: "@" + operand: (call + target: (identifier) @ignore)) + +; * function call +(call + target: [ + ; local + (identifier) @name.reference.call + ; remote + (dot + right: (identifier) @name.reference.call) + ]) @reference.call + +; * pipe into function call +(binary_operator + operator: "|>" + right: (identifier) @name.reference.call) @reference.call + +; * modules +(alias) @name.reference.module @reference.module diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-elm-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-elm-tags.scm new file mode 100644 index 0000000..8b1589e --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-elm-tags.scm @@ -0,0 +1,19 @@ +(value_declaration (function_declaration_left (lower_case_identifier) @name.definition.function)) @definition.function + +(function_call_expr (value_expr (value_qid) @name.reference.function)) @reference.function +(exposed_value (lower_case_identifier) @name.reference.function) @reference.function +(type_annotation ((lower_case_identifier) @name.reference.function) (colon)) @reference.function + +(type_declaration ((upper_case_identifier) @name.definition.type)) @definition.type + +(type_ref (upper_case_qid (upper_case_identifier) @name.reference.type)) @reference.type +(exposed_type (upper_case_identifier) @name.reference.type) @reference.type + +(type_declaration (union_variant (upper_case_identifier) @name.definition.union)) @definition.union + +(value_expr (upper_case_qid (upper_case_identifier) @name.reference.union)) @reference.union + + +(module_declaration + (upper_case_qid (upper_case_identifier)) @name.definition.module +) @definition.module diff 
--git a/openhands_aci/tree_sitter/queries/tree-sitter-go-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-go-tags.scm new file mode 100644 index 0000000..a32d03a --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-go-tags.scm @@ -0,0 +1,30 @@ +( + (comment)* @doc + . + (function_declaration + name: (identifier) @name.definition.function) @definition.function + (#strip! @doc "^//\\s*") + (#set-adjacent! @doc @definition.function) +) + +( + (comment)* @doc + . + (method_declaration + name: (field_identifier) @name.definition.method) @definition.method + (#strip! @doc "^//\\s*") + (#set-adjacent! @doc @definition.method) +) + +(call_expression + function: [ + (identifier) @name.reference.call + (parenthesized_expression (identifier) @name.reference.call) + (selector_expression field: (field_identifier) @name.reference.call) + (parenthesized_expression (selector_expression field: (field_identifier) @name.reference.call)) + ]) @reference.call + +(type_spec + name: (type_identifier) @name.definition.type) @definition.type + +(type_identifier) @name.reference.type @reference.type diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-java-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-java-tags.scm new file mode 100644 index 0000000..3b7290d --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-java-tags.scm @@ -0,0 +1,20 @@ +(class_declaration + name: (identifier) @name.definition.class) @definition.class + +(method_declaration + name: (identifier) @name.definition.method) @definition.method + +(method_invocation + name: (identifier) @name.reference.call + arguments: (argument_list) @reference.call) + +(interface_declaration + name: (identifier) @name.definition.interface) @definition.interface + +(type_list + (type_identifier) @name.reference.implementation) @reference.implementation + +(object_creation_expression + type: (type_identifier) @name.reference.class) @reference.class + +(superclass (type_identifier) 
@name.reference.class) @reference.class diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-javascript-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-javascript-tags.scm new file mode 100644 index 0000000..3bc55c5 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-javascript-tags.scm @@ -0,0 +1,88 @@ +( + (comment)* @doc + . + (method_definition + name: (property_identifier) @name.definition.method) @definition.method + (#not-eq? @name.definition.method "constructor") + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.method) +) + +( + (comment)* @doc + . + [ + (class + name: (_) @name.definition.class) + (class_declaration + name: (_) @name.definition.class) + ] @definition.class + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.class) +) + +( + (comment)* @doc + . + [ + (function + name: (identifier) @name.definition.function) + (function_declaration + name: (identifier) @name.definition.function) + (generator_function + name: (identifier) @name.definition.function) + (generator_function_declaration + name: (identifier) @name.definition.function) + ] @definition.function + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.function) +) + +( + (comment)* @doc + . + (lexical_declaration + (variable_declarator + name: (identifier) @name.definition.function + value: [(arrow_function) (function)]) @definition.function) + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.function) +) + +( + (comment)* @doc + . + (variable_declaration + (variable_declarator + name: (identifier) @name.definition.function + value: [(arrow_function) (function)]) @definition.function) + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! 
@doc @definition.function) +) + +(assignment_expression + left: [ + (identifier) @name.definition.function + (member_expression + property: (property_identifier) @name.definition.function) + ] + right: [(arrow_function) (function)] +) @definition.function + +(pair + key: (property_identifier) @name.definition.function + value: [(arrow_function) (function)]) @definition.function + +( + (call_expression + function: (identifier) @name.reference.call) @reference.call + (#not-match? @name.reference.call "^(require)$") +) + +(call_expression + function: (member_expression + property: (property_identifier) @name.reference.call) + arguments: (_) @reference.call) + +(new_expression + constructor: (_) @name.reference.class) @reference.class diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-ocaml-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-ocaml-tags.scm new file mode 100644 index 0000000..52d5a85 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-ocaml-tags.scm @@ -0,0 +1,115 @@ +; Modules +;-------- + +( + (comment)? @doc . + (module_definition (module_binding (module_name) @name.definition.module) @definition.module) + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(module_path (module_name) @name.reference.module) @reference.module + +; Module types +;-------------- + +( + (comment)? @doc . + (module_type_definition (module_type_name) @name.definition.interface) @definition.interface + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(module_type_path (module_type_name) @name.reference.implementation) @reference.implementation + +; Functions +;---------- + +( + (comment)? @doc . + (value_definition + [ + (let_binding + pattern: (value_name) @name.definition.function + (parameter)) + (let_binding + pattern: (value_name) @name.definition.function + body: [(fun_expression) (function_expression)]) + ] @definition.function + ) + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +( + (comment)? @doc . 
+ (external (value_name) @name.definition.function) @definition.function + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(application_expression + function: (value_path (value_name) @name.reference.call)) @reference.call + +(infix_expression + left: (value_path (value_name) @name.reference.call) + operator: (concat_operator) @reference.call + (#eq? @reference.call "@@")) + +(infix_expression + operator: (rel_operator) @reference.call + right: (value_path (value_name) @name.reference.call) + (#eq? @reference.call "|>")) + +; Operator +;--------- + +( + (comment)? @doc . + (value_definition + (let_binding + pattern: (parenthesized_operator (_) @name.definition.function)) @definition.function) + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +[ + (prefix_operator) + (sign_operator) + (pow_operator) + (mult_operator) + (add_operator) + (concat_operator) + (rel_operator) + (and_operator) + (or_operator) + (assign_operator) + (hash_operator) + (indexing_operator) + (let_operator) + (let_and_operator) + (match_operator) +] @name.reference.call @reference.call + +; Classes +;-------- + +( + (comment)? @doc . + [ + (class_definition (class_binding (class_name) @name.definition.class) @definition.class) + (class_type_definition (class_type_binding (class_type_name) @name.definition.class) @definition.class) + ] + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +[ + (class_path (class_name) @name.reference.class) + (class_type_path (class_type_name) @name.reference.class) +] @reference.class + +; Methods +;-------- + +( + (comment)? @doc . + (method_definition (method_name) @name.definition.method) @definition.method + (#strip! 
@doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(method_invocation (method_name) @name.reference.call) @reference.call diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-php-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-php-tags.scm new file mode 100644 index 0000000..61c86fc --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-php-tags.scm @@ -0,0 +1,26 @@ +(class_declaration + name: (name) @name.definition.class) @definition.class + +(function_definition + name: (name) @name.definition.function) @definition.function + +(method_declaration + name: (name) @name.definition.function) @definition.function + +(object_creation_expression + [ + (qualified_name (name) @name.reference.class) + (variable_name (name) @name.reference.class) + ]) @reference.class + +(function_call_expression + function: [ + (qualified_name (name) @name.reference.call) + (variable_name (name)) @name.reference.call + ]) @reference.call + +(scoped_call_expression + name: (name) @name.reference.call) @reference.call + +(member_call_expression + name: (name) @name.reference.call) @reference.call diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-python-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-python-tags.scm new file mode 100644 index 0000000..3be5bed --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-python-tags.scm @@ -0,0 +1,12 @@ +(class_definition + name: (identifier) @name.definition.class) @definition.class + +(function_definition + name: (identifier) @name.definition.function) @definition.function + +(call + function: [ + (identifier) @name.reference.call + (attribute + attribute: (identifier) @name.reference.call) + ]) @reference.call diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-ql-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-ql-tags.scm new file mode 100644 index 0000000..3164aa2 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-ql-tags.scm @@ -0,0 +1,26 @@ +(classlessPredicate + 
name: (predicateName) @name.definition.function) @definition.function + +(memberPredicate + name: (predicateName) @name.definition.method) @definition.method + +(aritylessPredicateExpr + name: (literalId) @name.reference.call) @reference.call + +(module + name: (moduleName) @name.definition.module) @definition.module + +(dataclass + name: (className) @name.definition.class) @definition.class + +(datatype + name: (className) @name.definition.class) @definition.class + +(datatypeBranch + name: (className) @name.definition.class) @definition.class + +(qualifiedRhs + name: (predicateName) @name.reference.call) @reference.call + +(typeExpr + name: (className) @name.reference.type) @reference.type diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-ruby-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-ruby-tags.scm new file mode 100644 index 0000000..79e71d2 --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-ruby-tags.scm @@ -0,0 +1,64 @@ +; Method definitions + +( + (comment)* @doc + . + [ + (method + name: (_) @name.definition.method) @definition.method + (singleton_method + name: (_) @name.definition.method) @definition.method + ] + (#strip! @doc "^#\\s*") + (#select-adjacent! @doc @definition.method) +) + +(alias + name: (_) @name.definition.method) @definition.method + +(setter + (identifier) @ignore) + +; Class definitions + +( + (comment)* @doc + . + [ + (class + name: [ + (constant) @name.definition.class + (scope_resolution + name: (_) @name.definition.class) + ]) @definition.class + (singleton_class + value: [ + (constant) @name.definition.class + (scope_resolution + name: (_) @name.definition.class) + ]) @definition.class + ] + (#strip! @doc "^#\\s*") + (#select-adjacent! 
@doc @definition.class) +) + +; Module definitions + +( + (module + name: [ + (constant) @name.definition.module + (scope_resolution + name: (_) @name.definition.module) + ]) @definition.module +) + +; Calls + +(call method: (identifier) @name.reference.call) @reference.call + +( + [(identifier) (constant)] @name.reference.call @reference.call + (#is-not? local) + (#not-match? @name.reference.call "^(lambda|load|require|require_relative|__FILE__|__LINE__)$") +) diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-rust-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-rust-tags.scm new file mode 100644 index 0000000..dadfa7a --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-rust-tags.scm @@ -0,0 +1,60 @@ +; ADT definitions + +(struct_item + name: (type_identifier) @name.definition.class) @definition.class + +(enum_item + name: (type_identifier) @name.definition.class) @definition.class + +(union_item + name: (type_identifier) @name.definition.class) @definition.class + +; type aliases + +(type_item + name: (type_identifier) @name.definition.class) @definition.class + +; method definitions + +(declaration_list + (function_item + name: (identifier) @name.definition.method)) @definition.method + +; function definitions + +(function_item + name: (identifier) @name.definition.function) @definition.function + +; trait definitions +(trait_item + name: (type_identifier) @name.definition.interface) @definition.interface + +; module definitions +(mod_item + name: (identifier) @name.definition.module) @definition.module + +; macro definitions + +(macro_definition + name: (identifier) @name.definition.macro) @definition.macro + +; references + +(call_expression + function: (identifier) @name.reference.call) @reference.call + +(call_expression + function: (field_expression + field: (field_identifier) @name.reference.call)) @reference.call + +(macro_invocation + macro: (identifier) @name.reference.call) @reference.call + +; implementations + 
+(impl_item + trait: (type_identifier) @name.reference.implementation) @reference.implementation + +(impl_item + type: (type_identifier) @name.reference.implementation + !trait) @reference.implementation diff --git a/openhands_aci/tree_sitter/queries/tree-sitter-typescript-tags.scm b/openhands_aci/tree_sitter/queries/tree-sitter-typescript-tags.scm new file mode 100644 index 0000000..8a73dcc --- /dev/null +++ b/openhands_aci/tree_sitter/queries/tree-sitter-typescript-tags.scm @@ -0,0 +1,41 @@ +(function_signature + name: (identifier) @name.definition.function) @definition.function + +(method_signature + name: (property_identifier) @name.definition.method) @definition.method + +(abstract_method_signature + name: (property_identifier) @name.definition.method) @definition.method + +(abstract_class_declaration + name: (type_identifier) @name.definition.class) @definition.class + +(module + name: (identifier) @name.definition.module) @definition.module + +(interface_declaration + name: (type_identifier) @name.definition.interface) @definition.interface + +(type_annotation + (type_identifier) @name.reference.type) @reference.type + +(new_expression + constructor: (identifier) @name.reference.class) @reference.class + +(function_declaration + name: (identifier) @name.definition.function) @definition.function + +(method_definition + name: (property_identifier) @name.definition.method) @definition.method + +(class_declaration + name: (type_identifier) @name.definition.class) @definition.class + +(interface_declaration + name: (type_identifier) @name.definition.class) @definition.class + +(type_alias_declaration + name: (type_identifier) @name.definition.type) @definition.type + +(enum_declaration + name: (identifier) @name.definition.enum) @definition.enum diff --git a/openhands_aci/utils/file.py b/openhands_aci/utils/file.py new file mode 100644 index 0000000..3491b43 --- /dev/null +++ b/openhands_aci/utils/file.py @@ -0,0 +1,69 @@ +from pathlib import Path + +from .logger 
import oh_aci_logger as logger +from .path import get_depth_of_rel_path, has_image_extension + + +class GitRepoUtils: + def __init__(self, abs_repo_path: str) -> None: + from git import Repo + + if not Path(abs_repo_path).is_absolute(): + raise ValueError('The path must be absolute') + + self.repo_path = Path(abs_repo_path) + try: + self.repo = Repo(self.repo_path) + except Exception: + logger.warning(f'Could not find git repository at {abs_repo_path}.') + raise Exception( + 'Could not find any git repository in the working directory.' + ) + + def get_all_absolute_tracked_files(self, depth: int | None = None) -> list[str]: + return [ + str(self.repo_path / item.path) + for item in self.repo.tree().traverse() + if item.type == 'blob' + and (not depth or get_depth_of_rel_path(item.path) <= depth) + ] + + def get_all_relative_tracked_files(self, depth: int | None = None) -> list[str]: + return [ + item.path + for item in self.repo.tree().traverse() + if item.type == 'blob' + and (not depth or get_depth_of_rel_path(item.path) <= depth) + ] + + def get_all_absolute_staged_files(self) -> list[str]: + return [ + str(self.repo_path / item.a_path) for item in self.repo.index.diff('HEAD') + ] + + def get_absolute_tracked_files_in_directory( + self, rel_dir_path: str, depth: int | None = None + ) -> list[str]: + rel_dir_path = rel_dir_path.rstrip('/') + return [ + str(self.repo_path / item.path) + for item in self.repo.tree().traverse() + if item.path.startswith(rel_dir_path + '/') + and item.type == 'blob' + and (not depth or get_depth_of_rel_path(item.path) <= depth) + ] + + +def get_modified_time(abs_path: str) -> int: + if not Path(abs_path).exists(): + return -1 + + return int(Path(abs_path).stat().st_mtime) + + +def read_text(abs_path: str) -> str: + if has_image_extension(abs_path): + return '' # Not support image files yet! 
# Lower-case suffixes treated as image files; built once, not per call.
_IMAGE_EXTENSIONS = frozenset(
    {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
)


class PathUtils:
    """Convert between absolute paths and paths relative to a fixed root."""

    def __init__(self, root: str) -> None:
        self.root = root

    def get_absolute_path_str(self, rel_path: str) -> str:
        """Return ``rel_path`` joined onto the root and resolved, as a string."""
        return str(Path(self.root).joinpath(rel_path).resolve())

    def get_relative_path_str(self, abs_path: str) -> str:
        """Return ``abs_path`` expressed relative to the root, as a string.

        Raises:
            ValueError: If ``abs_path`` is not located under the root.
        """
        return str(Path(abs_path).relative_to(self.root))

    def get_depth_from_root(self, abs_path: str) -> int:
        """Return how many path components separate ``abs_path`` from the root.

        Raises:
            ValueError: If ``abs_path`` is not located under the root.
        """
        return len(Path(abs_path).relative_to(self.root).parts)


def has_image_extension(path: str) -> bool:
    """Return True when the final suffix of ``path`` is a known image suffix.

    The comparison ignores case, so ``photo.PNG`` is recognised the same way
    as ``photo.png``.
    """
    return Path(path).suffix.lower() in _IMAGE_EXTENSIONS


def get_depth_of_rel_path(rel_path: str) -> int:
    """Return the number of components in ``rel_path`` (``'a/b.py'`` -> 2)."""
    return len(Path(rel_path).parts)
"sha256:231c8b2efbd7f8d2ecd1ae900363ba168b8870644bb8f2b5aa96e4a7573bde19"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54e7f442fb9cca81e9df32333fb075ef729052bcabe05b0afc0441f462299114"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:906f1f2a1b91c06599b3dd1be207449c5d4fc7bd1e1fa2f6aef161ea6223f165"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ed59044aea9eb6c663112170f2399b040d5d7b162828b141f2673e822093fa8"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cb1965a28b0fa64abdee130c788a0bc0bb3cf9ef7e3a70bf055c086c14a3d7e"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b488b244931d0291412917e6e46ee9f6a14376625e150056fe7c4426ef28225"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f0ba13557fec9d5ffc0a22826754a7457cc77f1b25145be10b7bb1d143ce84c6"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3871fa7dfcef00bad3c7e8ae8d8fd58089bad6fb21f608d2bf42832267ca9663"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:b2669eafee38c5884a6e7cc9769d25c19428549dcdf57de8541cf9e82822e7db"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:ffa1bb0e26297b0f22881b219ffc82a33a3c84ce6174a9d69406239b14575bd5"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:45b15b8a118856ac9caac6877f70f38b8a0d310475d50bc814698659eabc1cdb"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-win32.whl", hash = "sha256:22033677982b9c4c49676f215b794b0404073f8974f98739cb7234e4a9ade9ad"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:be15496e7244361ff0efcd86e52559bacda9cd975eccf19426a0025f9547c792"}, + {file = 
"rapidfuzz-3.11.0-cp310-cp310-win_arm64.whl", hash = "sha256:714a7ba31ba46b64d30fccfe95f8013ea41a2e6237ba11a805a27cdd3bce2573"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8724a978f8af7059c5323d523870bf272a097478e1471295511cf58b2642ff83"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8b63cb1f2eb371ef20fb155e95efd96e060147bdd4ab9fc400c97325dfee9fe1"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82497f244aac10b20710448645f347d862364cc4f7d8b9ba14bd66b5ce4dec18"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:339607394941801e6e3f6c1ecd413a36e18454e7136ed1161388de674f47f9d9"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84819390a36d6166cec706b9d8f0941f115f700b7faecab5a7e22fc367408bc3"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eea8d9e20632d68f653455265b18c35f90965e26f30d4d92f831899d6682149b"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b659e1e2ea2784a9a397075a7fc395bfa4fe66424042161c4bcaf6e4f637b38"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1315cd2a351144572e31fe3df68340d4b83ddec0af8b2e207cd32930c6acd037"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a7743cca45b4684c54407e8638f6d07b910d8d811347b9d42ff21262c7c23245"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5bb636b0150daa6d3331b738f7c0f8b25eadc47f04a40e5c23c4bfb4c4e20ae3"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:42f4dd264ada7a9aa0805ea0da776dc063533917773cf2df5217f14eb4429eae"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:51f24cb39e64256221e6952f22545b8ce21cacd59c0d3e367225da8fc4b868d8"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-win32.whl", hash = "sha256:aaf391fb6715866bc14681c76dc0308f46877f7c06f61d62cc993b79fc3c4a2a"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:ebadd5b8624d8ad503e505a99b8eb26fe3ea9f8e9c2234e805a27b269e585842"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-win_arm64.whl", hash = "sha256:d895998fec712544c13cfe833890e0226585cf0391dd3948412441d5d68a2b8c"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f382fec4a7891d66fb7163c90754454030bb9200a13f82ee7860b6359f3f2fa8"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dfaefe08af2a928e72344c800dcbaf6508e86a4ed481e28355e8d4b6a6a5230e"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92ebb7c12f682b5906ed98429f48a3dd80dd0f9721de30c97a01473d1a346576"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a1b3ebc62d4bcdfdeba110944a25ab40916d5383c5e57e7c4a8dc0b6c17211a"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c6d7fea39cb33e71de86397d38bf7ff1a6273e40367f31d05761662ffda49e4"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99aebef8268f2bc0b445b5640fd3312e080bd17efd3fbae4486b20ac00466308"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4469307f464ae3089acf3210b8fc279110d26d10f79e576f385a98f4429f7d97"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:eb97c53112b593f89a90b4f6218635a9d1eea1d7f9521a3b7d24864228bbc0aa"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ef8937dae823b889c0273dfa0f0f6c46a3658ac0d851349c464d1b00e7ff4252"}, + {file = 
"rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d95f9e9f3777b96241d8a00d6377cc9c716981d828b5091082d0fe3a2924b43e"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:b1d67d67f89e4e013a5295e7523bc34a7a96f2dba5dd812c7c8cb65d113cbf28"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d994cf27e2f874069884d9bddf0864f9b90ad201fcc9cb2f5b82bacc17c8d5f2"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-win32.whl", hash = "sha256:ba26d87fe7fcb56c4a53b549a9e0e9143f6b0df56d35fe6ad800c902447acd5b"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:b1f7efdd7b7adb32102c2fa481ad6f11923e2deb191f651274be559d56fc913b"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-win_arm64.whl", hash = "sha256:ed78c8e94f57b44292c1a0350f580e18d3a3c5c0800e253f1583580c1b417ad2"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e60814edd0c9b511b5f377d48b9782b88cfe8be07a98f99973669299c8bb318a"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3f28952da055dbfe75828891cd3c9abf0984edc8640573c18b48c14c68ca5e06"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e8f93bc736020351a6f8e71666e1f486bb8bd5ce8112c443a30c77bfde0eb68"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76a4a11ba8f678c9e5876a7d465ab86def047a4fcc043617578368755d63a1bc"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc0e0d41ad8a056a9886bac91ff9d9978e54a244deb61c2972cc76b66752de9c"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e8ea35f2419c7d56b3e75fbde2698766daedb374f20eea28ac9b1f668ef4f74"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:cd340bbd025302276b5aa221dccfe43040c7babfc32f107c36ad783f2ffd8775"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:494eef2c68305ab75139034ea25328a04a548d297712d9cf887bf27c158c388b"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5a167344c1d6db06915fb0225592afdc24d8bafaaf02de07d4788ddd37f4bc2f"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:8c7af25bda96ac799378ac8aba54a8ece732835c7b74cfc201b688a87ed11152"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d2a0f7e17f33e7890257367a1662b05fecaf56625f7dbb6446227aaa2b86448b"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4d0d26c7172bdb64f86ee0765c5b26ea1dc45c52389175888ec073b9b28f4305"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-win32.whl", hash = "sha256:6ad02bab756751c90fa27f3069d7b12146613061341459abf55f8190d899649f"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:b1472986fd9c5d318399a01a0881f4a0bf4950264131bb8e2deba9df6d8c362b"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:c408f09649cbff8da76f8d3ad878b64ba7f7abdad1471efb293d2c075e80c822"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bac4873f6186f5233b0084b266bfb459e997f4c21fc9f029918f44a9eccd304"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f9f12c2d0aa52b86206d2059916153876a9b1cf9dfb3cf2f344913167f1c3d4"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dd501de6f7a8f83557d20613b58734d1cb5f0be78d794cde64fe43cfc63f5f2"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4416ca69af933d4a8ad30910149d3db6d084781d5c5fdedb713205389f535385"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:f0821b9bdf18c5b7d51722b906b233a39b17f602501a966cfbd9b285f8ab83cd"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0edecc3f90c2653298d380f6ea73b536944b767520c2179ec5d40b9145e47aa"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4513dd01cee11e354c31b75f652d4d466c9440b6859f84e600bdebfccb17735a"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d9727b85511b912571a76ce53c7640ba2c44c364e71cef6d7359b5412739c570"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ab9eab33ee3213f7751dc07a1a61b8d9a3d748ca4458fffddd9defa6f0493c16"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6b01c1ddbb054283797967ddc5433d5c108d680e8fa2684cf368be05407b07e4"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:3857e335f97058c4b46fa39ca831290b70de554a5c5af0323d2f163b19c5f2a6"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d98a46cf07c0c875d27e8a7ed50f304d83063e49b9ab63f21c19c154b4c0d08d"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-win32.whl", hash = "sha256:c36539ed2c0173b053dafb221458812e178cfa3224ade0960599bec194637048"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:ec8d7d8567e14af34a7911c98f5ac74a3d4a743cd848643341fc92b12b3784ff"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-win_arm64.whl", hash = "sha256:62171b270ecc4071be1c1f99960317db261d4c8c83c169e7f8ad119211fe7397"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:f06e3c4c0a8badfc4910b9fd15beb1ad8f3b8fafa8ea82c023e5e607b66a78e4"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fe7aaf5a54821d340d21412f7f6e6272a9b17a0cbafc1d68f77f2fc11009dcd5"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:25398d9ac7294e99876a3027ffc52c6bebeb2d702b1895af6ae9c541ee676702"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a52eea839e4bdc72c5e60a444d26004da00bb5bc6301e99b3dde18212e41465"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c87319b0ab9d269ab84f6453601fd49b35d9e4a601bbaef43743f26fabf496c"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3048c6ed29d693fba7d2a7caf165f5e0bb2b9743a0989012a98a47b975355cca"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b04f29735bad9f06bb731c214f27253bd8bedb248ef9b8a1b4c5bde65b838454"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:7864e80a0d4e23eb6194254a81ee1216abdc53f9dc85b7f4d56668eced022eb8"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3794df87313dfb56fafd679b962e0613c88a293fd9bd5dd5c2793d66bf06a101"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d71da0012face6f45432a11bc59af19e62fac5a41f8ce489e80c0add8153c3d1"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff38378346b7018f42cbc1f6d1d3778e36e16d8595f79a312b31e7c25c50bd08"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:6668321f90aa02a5a789d4e16058f2e4f2692c5230252425c3532a8a62bc3424"}, + {file = "rapidfuzz-3.11.0.tar.gz", hash = "sha256:a53ca4d3f52f00b393fab9b5913c5bafb9afc27d030c8a1db1283da6917a860f"}, +] + +[package.extras] +all = ["numpy"] + [[package]] name = "referencing" version = "0.35.1" @@ -2490,4 +2590,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "0e7374c6ab359765af97ca2f296aa6bafcfc8cae8559dcbf420f635e55e7bc53" +content-hash = 
"6869ca59b7be433c3813460bc7e5d45e25332d6e6d9ac3d7ad4668dc4b9443ca" diff --git a/pyproject.toml b/pyproject.toml index 1f654e0..f0224c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,10 @@ packages = [ { include = "openhands_aci/**/*" } ] +[tool.poetry.scripts] +oh-editor = "openhands_aci.editor.cli:main" +oh-nav = "openhands_aci.navigator.cli:main" + [tool.poetry.dependencies] python = "^3.12" numpy = "*" @@ -22,6 +26,8 @@ grep-ast = "0.3.3" diskcache = "^5.6.3" flake8 = "*" whatthepatch = "^1.0.6" +rapidfuzz = "*" + [tool.poetry.group.dev.dependencies] diff --git a/tests/integration/test_oh_editor.py b/tests/integration/test_oh_editor.py index 44c6a73..4fac36e 100644 --- a/tests/integration/test_oh_editor.py +++ b/tests/integration/test_oh_editor.py @@ -1,12 +1,12 @@ import pytest -from openhands_aci.editor.editor import OHEditor -from openhands_aci.editor.exceptions import ( - EditorToolParameterInvalidError, - EditorToolParameterMissingError, +from openhands_aci.core.exceptions import ( + MultiCommandToolParameterMissingError, ToolError, + ToolParameterInvalidError, ) -from openhands_aci.editor.results import CLIResult, ToolResult +from openhands_aci.core.results import CLIResult, ToolResult +from openhands_aci.editor.editor import OHEditor @pytest.fixture @@ -67,7 +67,7 @@ def test_create_with_empty_string(editor): def test_create_with_none_file_text(editor): editor, test_file = editor new_file = test_file.parent / 'none_content.txt' - with pytest.raises(EditorToolParameterMissingError) as exc_info: + with pytest.raises(MultiCommandToolParameterMissingError) as exc_info: editor(command='create', path=str(new_file), file_text=None) assert 'file_text' in str(exc_info.value.message) @@ -221,7 +221,7 @@ def test_str_replace_with_empty_string(editor): def test_str_replace_with_none_old_str(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterMissingError) as exc_info: + with pytest.raises(MultiCommandToolParameterMissingError) as 
exc_info: editor( command='str_replace', path=str(test_file), @@ -275,7 +275,7 @@ def test_insert_with_linting(editor): def test_insert_invalid_line(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterInvalidError) as exc_info: + with pytest.raises(ToolParameterInvalidError) as exc_info: editor( command='insert', path=str(test_file), @@ -304,7 +304,7 @@ def test_insert_with_empty_string(editor): def test_insert_with_none_new_str(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterMissingError) as exc_info: + with pytest.raises(MultiCommandToolParameterMissingError) as exc_info: editor( command='insert', path=str(test_file), @@ -333,25 +333,25 @@ def test_undo_edit(editor): def test_validate_path_invalid(editor): editor, test_file = editor invalid_file = test_file.parent / 'nonexistent.txt' - with pytest.raises(EditorToolParameterInvalidError): + with pytest.raises(ToolParameterInvalidError): editor(command='view', path=str(invalid_file)) def test_create_existing_file_error(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterInvalidError): + with pytest.raises(ToolParameterInvalidError): editor(command='create', path=str(test_file), file_text='New content') def test_str_replace_missing_old_str(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterMissingError): + with pytest.raises(MultiCommandToolParameterMissingError): editor(command='str_replace', path=str(test_file), new_str='sample') def test_str_replace_new_str_and_old_str_same(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterInvalidError) as exc_info: + with pytest.raises(ToolParameterInvalidError) as exc_info: editor( command='str_replace', path=str(test_file), @@ -366,7 +366,7 @@ def test_str_replace_new_str_and_old_str_same(editor): def test_insert_missing_line_param(editor): editor, test_file = editor - with pytest.raises(EditorToolParameterMissingError): + with 
pytest.raises(MultiCommandToolParameterMissingError): editor(command='insert', path=str(test_file), new_str='Missing insert line') diff --git a/tests/integration/test_symbol_navigator.py b/tests/integration/test_symbol_navigator.py new file mode 100644 index 0000000..8b422be --- /dev/null +++ b/tests/integration/test_symbol_navigator.py @@ -0,0 +1,109 @@ +import os +import tempfile +from pathlib import Path + +import pytest + +from openhands_aci.core.exceptions import ToolParameterInvalidError +from openhands_aci.navigator.navigator import SymbolNavigator + + +@pytest.fixture +def temp_git_repo(): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a git repo + os.chdir(temp_dir) + os.system('git init') + os.system('git config user.name "test"') + os.system('git config user.email "test@test.com"') + + # Create some Python files with known symbols + # main.py defines MyClass and uses utils.helper_func + main_content = """ +from utils import helper_func + +class MyClass: + def __init__(self): + self.value = 42 + + def process(self): + return helper_func(self.value) +""" + Path('main.py').write_text(main_content) + + # utils.py defines helper_func and uses MyClass + utils_content = """ +from main import MyClass + +def helper_func(x): + obj = MyClass() + return x + obj.value +""" + Path('utils.py').write_text(utils_content) + + # Add files to git + os.system('git add *.py') + os.system('git commit -m "Initial commit"') + + yield temp_dir + + +def test_jump_to_definition_finds_class(temp_git_repo): + navigator = SymbolNavigator() + result = navigator(command='jump_to_definition', symbol_name='MyClass') + + assert 'Definition(s) of `MyClass`:' in result.output + assert 'main.py:' in result.output + assert 'class MyClass:' in result.output + + +def test_jump_to_definition_finds_function(temp_git_repo): + navigator = SymbolNavigator() + result = navigator(command='jump_to_definition', symbol_name='helper_func') + + assert 'Definition(s) of `helper_func`:' in 
result.output + assert 'utils.py:' in result.output + assert 'def helper_func(x):' in result.output + + +def test_find_references_finds_class_usages(temp_git_repo): + navigator = SymbolNavigator() + result = navigator(command='find_references', symbol_name='MyClass') + + assert 'References to `MyClass`:' in result.output + assert 'utils.py:' in result.output + assert 'obj = MyClass()' in result.output + + +def test_find_references_finds_function_usages(temp_git_repo): + navigator = SymbolNavigator() + result = navigator(command='find_references', symbol_name='helper_func') + + assert 'References to `helper_func`:' in result.output + assert 'main.py:' in result.output + assert 'return helper_func(self.value)' in result.output + + +def test_fuzzy_matching_for_nonexistent_symbol(temp_git_repo): + navigator = SymbolNavigator() + result = navigator(command='jump_to_definition', symbol_name='MyClss') # Typo + + assert 'No definitions found for `MyClss`' in result.output + assert 'Maybe you meant one of these:' in result.output + assert 'MyClass' in result.output + + +def test_empty_symbol_raises_error(temp_git_repo): + navigator = SymbolNavigator() + with pytest.raises(ToolParameterInvalidError) as exc_info: + navigator(command='jump_to_definition', symbol_name='') + + assert 'Symbol name cannot be empty' in str(exc_info.value) + + +def test_invalid_command_raises_error(temp_git_repo): + navigator = SymbolNavigator() + with pytest.raises(Exception) as exc_info: + navigator(command='invalid_command', symbol_name='MyClass') + + assert 'Unrecognized command' in str(exc_info.value) diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py index a12324e..891980c 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/test_exceptions.py @@ -1,9 +1,9 @@ import pytest -from openhands_aci.editor.exceptions import ( - EditorToolParameterInvalidError, - EditorToolParameterMissingError, +from openhands_aci.core.exceptions import ( + 
MultiCommandToolParameterMissingError, ToolError, + ToolParameterInvalidError, ) @@ -18,8 +18,8 @@ def test_editor_tool_parameter_missing_error(): """Test EditorToolParameterMissingError for missing parameter error message.""" command = 'str_replace' parameter = 'old_str' - with pytest.raises(EditorToolParameterMissingError) as exc_info: - raise EditorToolParameterMissingError(command, parameter) + with pytest.raises(MultiCommandToolParameterMissingError) as exc_info: + raise MultiCommandToolParameterMissingError(command, parameter) assert exc_info.value.command == command assert exc_info.value.parameter == parameter assert ( @@ -33,8 +33,8 @@ def test_editor_tool_parameter_invalid_error_with_hint(): parameter = 'timeout' value = -10 hint = 'Must be a positive integer.' - with pytest.raises(EditorToolParameterInvalidError) as exc_info: - raise EditorToolParameterInvalidError(parameter, value, hint) + with pytest.raises(ToolParameterInvalidError) as exc_info: + raise ToolParameterInvalidError(parameter, value, hint) assert exc_info.value.parameter == parameter assert exc_info.value.value == value assert exc_info.value.message == f'Invalid `{parameter}` parameter: {value}. {hint}' @@ -44,8 +44,8 @@ def test_editor_tool_parameter_invalid_error_without_hint(): """Test EditorToolParameterInvalidError without hint.""" parameter = 'timeout' value = -10 - with pytest.raises(EditorToolParameterInvalidError) as exc_info: - raise EditorToolParameterInvalidError(parameter, value) + with pytest.raises(ToolParameterInvalidError) as exc_info: + raise ToolParameterInvalidError(parameter, value) assert exc_info.value.parameter == parameter assert exc_info.value.value == value assert exc_info.value.message == f'Invalid `{parameter}` parameter: {value}.' 
diff --git a/tests/unit/test_file_utils.py b/tests/unit/test_file_utils.py new file mode 100644 index 0000000..8eae388 --- /dev/null +++ b/tests/unit/test_file_utils.py @@ -0,0 +1,204 @@ +from pathlib import Path + +import pytest +from git import Repo + +from openhands_aci.utils.file import GitRepoUtils, read_text + + +@pytest.fixture +def temp_git_repo(tmp_path): + """Create a temporary git repository with some test files.""" + repo_dir = tmp_path / 'test_repo' + repo_dir.mkdir() + + # Initialize git repo + repo = Repo.init(repo_dir) + + # Create some test files and directories + (repo_dir / 'file1.txt').write_text('content1') + (repo_dir / 'file2.txt').write_text('content2') + + # Create a subdirectory with files + test_dir = repo_dir / 'test_dir' + test_dir.mkdir() + (test_dir / 'file3.txt').write_text('content3') + (test_dir / 'file4.txt').write_text('content4') + + # Stage and commit initial files + repo.index.add( + ['file1.txt', 'file2.txt', 'test_dir/file3.txt', 'test_dir/file4.txt'] + ) + repo.index.commit('Initial commit') + + # Create an unstaged file + (repo_dir / 'unstaged.txt').write_text('unstaged') + + # Create a staged but not committed file + (repo_dir / 'staged.txt').write_text('staged') + repo.index.add(['staged.txt']) + + return repo_dir + + +@pytest.fixture +def git_utils(temp_git_repo): + """Create a GitRepoUtils instance with the temporary repository.""" + return GitRepoUtils(str(temp_git_repo)) + + +def test_init_valid_repo(temp_git_repo): + """Test initialization with a valid repository.""" + utils = GitRepoUtils(str(temp_git_repo)) + assert utils.repo_path == Path(temp_git_repo) + assert utils.repo is not None + + +def test_get_all_abs_tracked_files(git_utils): + """Test getting all tracked files.""" + tracked_files = git_utils.get_all_absolute_tracked_files() + + # Check if all expected files are present + expected_files = { + 'file1.txt', + 'file2.txt', + 'test_dir/file3.txt', + 'test_dir/file4.txt', + } + # Convert to absolute paths + 
expected_files = {str(git_utils.repo_path / file) for file in expected_files} + assert set(tracked_files) == expected_files + + # Verify that unstaged and untracked files are not included + assert str(git_utils.repo_path / 'unstaged.txt') not in tracked_files + + +def test_get_all_rel_tracked_files(git_utils): + """Test getting all relative tracked files.""" + rel_tracked_files = git_utils.get_all_relative_tracked_files() + + # Check if all expected files are present + expected_files = { + 'file1.txt', + 'file2.txt', + 'test_dir/file3.txt', + 'test_dir/file4.txt', + } + assert set(rel_tracked_files) == expected_files + + # Verify that unstaged and untracked files are not included + assert str(git_utils.repo_path / 'unstaged.txt') not in rel_tracked_files + + +def test_get_all_abs_tracked_files_with_depth(git_utils): + """Test getting all tracked files with a depth limit of 1.""" + tracked_files = git_utils.get_all_absolute_tracked_files(depth=1) + + # Only files in the root directory should be returned + expected_files = { + 'file1.txt', + 'file2.txt', + } + # Convert to absolute paths + expected_files = {str(git_utils.repo_path / file) for file in expected_files} + assert set(tracked_files) == expected_files + + # Ensure that files in the 'test_dir' are not included + assert str(git_utils.repo_path / 'test_dir/file3.txt') not in tracked_files + assert str(git_utils.repo_path / 'test_dir/file4.txt') not in tracked_files + + +def test_get_all_abs_staged_files(git_utils): + """Test getting all staged files.""" + staged_files = git_utils.get_all_absolute_staged_files() + + # Only staged.txt should be in the staged files list + assert str(git_utils.repo_path / 'staged.txt') in staged_files + assert str(git_utils.repo_path / 'file1.txt') not in staged_files + assert str(git_utils.repo_path / 'unstaged.txt') not in staged_files + + +def test_get_abs_tracked_files_in_directory(git_utils): + """Test getting tracked files in a specific directory.""" + # Test files in 
test_dir + test_dir_files = git_utils.get_absolute_tracked_files_in_directory('test_dir') + expected_files = {'test_dir/file3.txt', 'test_dir/file4.txt'} + # Convert to absolute paths + expected_files = {str(git_utils.repo_path / file) for file in expected_files} + assert set(test_dir_files) == expected_files + + # Test files in root directory (should be empty when specifying a non-existent directory) + nonexistent_dir_files = git_utils.get_absolute_tracked_files_in_directory( + 'nonexistent' + ) + assert len(nonexistent_dir_files) == 0 + + +def test_get_abs_tracked_files_in_directory_with_trailing_slash(git_utils): + """Test getting tracked files in a directory with trailing slash.""" + test_dir_files = git_utils.get_absolute_tracked_files_in_directory('test_dir/') + expected_files = {'test_dir/file3.txt', 'test_dir/file4.txt'} + # Convert to absolute paths + expected_files = {str(git_utils.repo_path / file) for file in expected_files} + assert set(test_dir_files) == expected_files + + +def test_get_abs_tracked_files_in_subdirectory(git_utils): + """Test getting tracked files in a subdirectory.""" + # Test files in test_dir + test_dir_files = git_utils.get_absolute_tracked_files_in_directory('test_dir') + expected_files = {'test_dir/file3.txt', 'test_dir/file4.txt'} + # Convert to absolute paths + expected_files = {str(git_utils.repo_path / file) for file in expected_files} + assert set(test_dir_files) == expected_files + + +def test_empty_directory(git_utils, temp_git_repo): + """Test getting tracked files in an empty directory.""" + # Create an empty directory + empty_dir = temp_git_repo / 'empty_dir' + empty_dir.mkdir() + + files = git_utils.get_absolute_tracked_files_in_directory('empty_dir') + assert len(files) == 0 + + +def test_read_text_with_regular_file(tmp_path): + # Test reading a regular text file + test_file = tmp_path / 'test.txt' + content = 'Hello, World!' 
+ test_file.write_text(content) + + result = read_text(str(test_file)) + assert result == content + + +def test_read_text_with_image_file(tmp_path): + # Test reading an image file (should return empty string) + image_file = tmp_path / 'test.jpg' + image_file.write_bytes(b'fake image content') + + result = read_text(str(image_file)) + assert result == '' + + +@pytest.fixture +def temp_git_repo_in_subdir(tmp_path): + """Create a temporary git repository inside a subdirectory.""" + # Create the parent directory and the subdirectory + parent_dir = tmp_path / 'parent_dir' + sub_dir = parent_dir / 'sub_repo' + sub_dir.mkdir(parents=True) + + # Initialize git repo in the subdirectory + repo = Repo.init(sub_dir) + + # Create some test files in the subdirectory + (sub_dir / 'file1.txt').write_text('content1') + (sub_dir / 'file2.txt').write_text('content2') + + # Stage and commit initial files + repo.index.add(['file1.txt', 'file2.txt']) + repo.index.commit('Initial commit in subdirectory repo') + + return parent_dir diff --git a/tests/unit/test_path_utils.py b/tests/unit/test_path_utils.py new file mode 100644 index 0000000..f47769f --- /dev/null +++ b/tests/unit/test_path_utils.py @@ -0,0 +1,62 @@ +from pathlib import Path + +import pytest + +from openhands_aci.utils.path import ( + PathUtils, + get_depth_of_rel_path, + has_image_extension, +) + + +@pytest.fixture +def path_utils(): + return PathUtils(root='/home/user/project') + + +def test_get_absolute_path_str(path_utils): + relative_path = 'src/module/file.py' + expected_absolute = str( + Path('/home/user/project').joinpath(relative_path).resolve() + ) + + result = path_utils.get_absolute_path_str(relative_path) + + assert result == expected_absolute, f'Expected {expected_absolute}, got {result}' + + +def test_get_relative_path_str(path_utils): + absolute_path = '/home/user/project/src/module/file.py' + expected_relative = 'src/module/file.py' + + result = path_utils.get_relative_path_str(absolute_path) + + assert 
result == expected_relative, f'Expected {expected_relative}, got {result}' + + +def test_get_absolute_path_str_with_non_existing_path(path_utils): + relative_path = 'non_existent_folder/file.txt' + expected_absolute = str( + Path('/home/user/project').joinpath(relative_path).resolve() + ) + + result = path_utils.get_absolute_path_str(relative_path) + + assert result == expected_absolute, 'Unexpected result for a non-existing path' + + +def test_get_relative_path_str_with_non_project_path(path_utils): + with pytest.raises(ValueError): + path_utils.get_relative_path_str('/home/user/other_project/file.py') + + +def test_has_image_extension_png(): + assert has_image_extension('some/dir/image.png') is True + + +def test_get_depth_of_rel_path(): + assert get_depth_of_rel_path('folder/subfolder/file.txt') == 3 + assert get_depth_of_rel_path('folder/file.txt') == 2 + assert get_depth_of_rel_path('file.txt') == 1 + assert get_depth_of_rel_path('') == 0 + assert get_depth_of_rel_path('folder/subfolder/subsubfolder/') == 3 diff --git a/tests/unit/test_results_utils.py b/tests/unit/test_results_utils.py index e269e43..4cf3075 100644 --- a/tests/unit/test_results_utils.py +++ b/tests/unit/test_results_utils.py @@ -1,6 +1,6 @@ +from openhands_aci.core.results import ToolResult, maybe_truncate from openhands_aci.editor.config import MAX_RESPONSE_LEN_CHAR from openhands_aci.editor.prompts import CONTENT_TRUNCATED_NOTICE -from openhands_aci.editor.results import ToolResult, maybe_truncate def test_tool_result_bool(): diff --git a/tests/unit/test_ts_parser.py b/tests/unit/test_ts_parser.py new file mode 100644 index 0000000..4e00f33 --- /dev/null +++ b/tests/unit/test_ts_parser.py @@ -0,0 +1,140 @@ +import os +import tempfile + +import pytest + +from openhands_aci.tree_sitter.parser import TagKind, TreeSitterParser + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmp_dir: + yield tmp_dir + + +@pytest.fixture +def parser(temp_dir): + return 
TreeSitterParser(temp_dir) + + +def test_parse_python_function_definition(temp_dir, parser): + # Create a temporary Python file with a function definition + code = """def hello_world(): + print("Hello, World!") + return 42 + +def another_func(): + result = hello_world() + return result +""" + file_path = os.path.join(temp_dir, 'test.py') + with open(file_path, 'w') as f: + f.write(code) + + tags = parser.get_tags_from_file(file_path, 'test.py') + + # Should find both function definitions and references + def_tags = [t for t in tags if t.tag_kind == TagKind.DEF] + ref_tags = [t for t in tags if t.tag_kind == TagKind.REF] + def_with_body_tags = [t for t in tags if t.tag_kind == TagKind.DEF_WITH_BODY] + + assert len(def_tags) == 2 # hello_world and another_func + assert len(ref_tags) >= 1 # At least hello_world reference + assert len(def_with_body_tags) == 2 + + # Verify specific references + hello_world_refs = [t for t in ref_tags if t.node_content == 'hello_world'] + assert len(hello_world_refs) == 1 + + # Check first function definition + hello_def = next(t for t in def_tags if t.node_content == 'hello_world') + assert hello_def.start_line == 0 + assert hello_def.end_line > 0 # Should be updated with body end line + + # Check function reference + hello_ref = next(t for t in ref_tags if t.node_content == 'hello_world') + assert hello_ref.start_line == 5 + + +def test_parse_unsupported_file(temp_dir, parser): + # Create a file with unsupported extension + file_path = os.path.join(temp_dir, 'test.xyz') + with open(file_path, 'w') as f: + f.write('some content') + + tags = parser.get_tags_from_file(file_path, 'test.xyz') + assert len(tags) == 0 + + +def test_parse_empty_file(temp_dir, parser): + # Create an empty Python file + file_path = os.path.join(temp_dir, 'empty.py') + with open(file_path, 'w') as f: + f.write('') + + tags = parser.get_tags_from_file(file_path, 'empty.py') + assert len(tags) == 0 + + +def test_cache_functionality(temp_dir, parser): + # Create a 
Python file + file_path = os.path.join(temp_dir, 'cached.py') + with open(file_path, 'w') as f: + f.write('def test_func():\n pass') + + # First call should parse the file + tags1 = parser.get_tags_from_file(file_path, 'cached.py') + assert len(tags1) > 0 + + # Second call should use cache + tags2 = parser.get_tags_from_file(file_path, 'cached.py') + assert tags1 == tags2 + + # Modify file should invalidate cache + with open(file_path, 'w') as f: + f.write('def another_func():\n pass') + + # Use os.utime to explicitly set a new modification time + import time + + new_time = time.time() + 1 # 1 second in the future + os.utime(file_path, (new_time, new_time)) + + tags3 = parser.get_tags_from_file(file_path, 'cached.py') + + # Compare the function names to verify cache invalidation + func_names1 = {t.node_content for t in tags1 if t.tag_kind == TagKind.DEF} + func_names3 = {t.node_content for t in tags3 if t.tag_kind == TagKind.DEF} + assert func_names1 != func_names3 + + +def test_parse_python_class_definition(temp_dir, parser): + # Create a temporary Python file with a class definition + code = """class TestClass: + def method1(self): + pass + + def method2(self): + self.method1() +""" + file_path = os.path.join(temp_dir, 'test_class.py') + with open(file_path, 'w') as f: + f.write(code) + + tags = parser.get_tags_from_file(file_path, 'test_class.py') + + def_tags = [t for t in tags if t.tag_kind == TagKind.DEF] + ref_tags = [t for t in tags if t.tag_kind == TagKind.REF] + + # Should find class and both method definitions + assert len(def_tags) >= 3 + # Should find the method1 reference + assert len(ref_tags) >= 1 + + # Verify class definition + class_def = next(t for t in def_tags if t.node_content == 'TestClass') + assert class_def.start_line == 0 + + # Verify method reference + method_ref = next(t for t in ref_tags if t.node_content == 'method1') + assert method_ref.start_line == 5