Merge branch 'princeton-nlp:main' into main
xingyaoww authored Mar 20, 2024
2 parents 8653c7a + 8ff6e0b commit cbf19bd
Showing 38 changed files with 562 additions and 186 deletions.
13 changes: 7 additions & 6 deletions README.md
@@ -73,12 +73,13 @@ Contact person: [Carlos E. Jimenez](http://www.carlosejimenez.com/) and [John Ya
## ✍️ Citation
If you find our work helpful, please use the following citation.
```
-@inproceedings{jimenez2024swebench,
-    title={SWE-bench: Can Language Models Resolve Real-World GitHub Issues?},
-    author={Carlos E. Jimenez and John Yang and Alexander Wettig and Shunyu Yao and Kexin Pei and Ofir Press and Karthik Narasimhan},
-    booktitle={The Twelfth International Conference on Learning Representations},
-    year={2024},
-    url={https://openreview.net/forum?id=VTF8yNQM66}
+@inproceedings{
+    jimenez2024swebench,
+    title={{SWE}-bench: Can Language Models Resolve Real-world Github Issues?},
+    author={Carlos E Jimenez and John Yang and Alexander Wettig and Shunyu Yao and Kexin Pei and Ofir Press and Karthik R Narasimhan},
+    booktitle={The Twelfth International Conference on Learning Representations},
+    year={2024},
+    url={https://openreview.net/forum?id=VTF8yNQM66}
}
```

12 changes: 7 additions & 5 deletions collect/build_dataset.py
100644 → 100755
@@ -1,8 +1,10 @@
+#!/usr/bin/env python3
+
import argparse
import json
import logging
import os
-from typing import Dict, Optional
+from typing import Optional


from utils import Repo, extract_patches, extract_problem_statement_and_hints
@@ -13,7 +15,7 @@
logger = logging.getLogger(__name__)


-def create_instance(repo: Repo, pull: Dict) -> Dict:
+def create_instance(repo: Repo, pull: dict) -> dict:
"""
Create a single task instance from a pull request, where task instance is:
@@ -43,7 +45,7 @@ def create_instance(repo: Repo, pull: Dict) -> Dict:
}


-def is_valid_pull(pull: Dict) -> bool:
+def is_valid_pull(pull: dict) -> bool:
"""
Check whether PR has an associated issue and is merged
@@ -59,7 +61,7 @@ def is_valid_pull(pull: Dict) -> bool:
return True


-def is_valid_instance(instance: Dict) -> bool:
+def is_valid_instance(instance: dict) -> bool:
"""
Check whether task instance has all required fields for task instance creation
@@ -75,7 +77,7 @@ def is_valid_instance(instance: Dict) -> bool:
return True


-def has_test_patch(instance: Dict) -> bool:
+def has_test_patch(instance: dict) -> bool:
"""
Check whether task instance has a test suite
6 changes: 4 additions & 2 deletions collect/build_dataset_ft.py
100644 → 100755
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
import argparse
import glob
import json
@@ -29,7 +31,7 @@ def main(instances_path: str, output_path: str, eval_path: str, seed: int):
# Gather Evaluation Set Task Instances
eval_instances = []
for x in glob.glob(os.path.join(eval_path, "*-task-instances.jsonl")):
-        with open(x, "r") as f:
+        with open(x) as f:
eval_instances.extend(f.readlines())
eval_instances = set(eval_instances)

@@ -39,7 +41,7 @@ def main(instances_path: str, output_path: str, eval_path: str, seed: int):
glob.glob(os.path.join(instances_path, "*-task-instances.jsonl.all"))
):
total_repos += 1
-        with open(dataset_path, "r") as f:
+        with open(dataset_path) as f:
lines = f.readlines()

# Remove data from evaluation dataset
2 changes: 2 additions & 0 deletions collect/cleanup/delete_gh_workflows.py
100644 → 100755
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
import argparse
import os
import subprocess
2 changes: 2 additions & 0 deletions collect/cleanup/remove_envs.py
100644 → 100755
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+
import argparse
import os
import subprocess
8 changes: 3 additions & 5 deletions collect/get_tasks_pipeline.py
@@ -4,13 +4,12 @@
from build_dataset import main as build_dataset
from print_pulls import main as print_pulls
from multiprocessing import Pool
-from typing import Dict, List


load_dotenv()


-def split_instances(input_list: List, n: int) -> List:
+def split_instances(input_list: list, n: int) -> list:
"""
Split a list into n approximately equal length sublists
@@ -33,7 +32,7 @@ def split_instances(input_list: List, n: int) -> List:
return result


-def construct_data_files(data: Dict):
+def construct_data_files(data: dict):
"""
Logic for combining multiple .all PR files into a single fine tuning dataset
@@ -77,10 +76,9 @@ def construct_data_files(data: Dict):
)
except Exception as e:
print(f"Something went wrong for {repo}, skipping: {e}")
-            pass


-def main(repos: List, path_prs: str, path_tasks: str):
+def main(repos: list, path_prs: str, path_tasks: str):
"""
Spawns multiple threads given multiple GitHub tokens for collecting fine tuning data
19 changes: 15 additions & 4 deletions collect/get_top_pypi.py
100644 → 100755
@@ -1,12 +1,19 @@
+#!/usr/bin/env python3
+
import os, json
+import argparse

from bs4 import BeautifulSoup
from ghapi.core import GhApi
from selenium import webdriver
from selenium.webdriver.common.by import By


-api = GhApi(token="GITHUB TOKEN HERE")
+gh_token = os.environ.get("GITHUB_TOKEN")
+if not gh_token:
+    msg = "Please set the GITHUB_TOKEN environment variable."
+    raise ValueError(msg)
+api = GhApi(token=gh_token)


def get_package_stats(data_tasks, f):
@@ -21,7 +28,7 @@ def get_package_stats(data_tasks, f):
content = None
access_type = "w"
if os.path.exists(f):
-        with open(f, "r") as fp_:
+        with open(f) as fp_:
content = fp_.read()
access_type = "a"
fp_.close()
@@ -84,16 +91,20 @@


if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--max-repos", help="Maximum number of repos to get", type=int, default=5000)
+    args = parser.parse_args()
+
# Start selenium driver to get top 5000 pypi page
url_top_pypi = "https://hugovk.github.io/top-pypi-packages/"
driver = webdriver.Chrome()
driver.get(url_top_pypi)
-    button = driver.find_element(By.CSS_SELECTOR, 'button[ng-click="show(5000)"]')
+    button = driver.find_element(By.CSS_SELECTOR, 'button[ng-click="show(8000)"]')
button.click()

# Retrieve HTML for packages from page
soup = BeautifulSoup(driver.page_source, "html.parser")
package_list = soup.find("div", {"class": "list"})
packages = package_list.find_all("a", class_="ng-scope")

-    get_package_stats(packages, "pypi_rankings.jsonl")
+    get_package_stats(packages[:args.max_repos], "pypi_rankings.jsonl")
13 changes: 13 additions & 0 deletions collect/make_lite/README.md
@@ -0,0 +1,13 @@
## SWE-bench *lite*
This directory contains the scripts used to make the *lite* version of SWE-bench. The *lite* version is a subset of the full SWE-bench that filters out certain types of instances, making evaluation a bit cheaper and more accessible.

SWE-bench *lite* consists of 300 test instances and 23 development instances, both subsets of the corresponding full SWE-bench splits. We filter the full SWE-bench according to the following criteria to get *lite*:
- We remove instances with images, external hyperlinks, references to specific commit SHAs, and references to other pull requests or issues.
- We remove instances that have fewer than 40 words in the problem statement.
- We remove instances that edit more than one file.
- We remove instances where the gold patch has more than 3 edit hunks (see [patch](https://man7.org/linux/man-pages/man1/patch.1.html)).
- We remove instances that create or remove files.
- We remove instances that contain tests with error message checks.
- Finally, we sample 300 test instances and 23 development instances from the remaining instances.

See `make_lite.py` for the script that makes the *lite* version of SWE-bench (a sketch of how these criteria compose follows below), or download the *lite* version from the Hugging Face dataset [princeton-nlp/SWE-bench_lite](https://huggingface.co/datasets/princeton-nlp/SWE-bench_lite).
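
For illustration (not part of this commit), here is a minimal sketch of how filters like the ones defined in `criteria.py` below might compose into a single predicate. The field names `problem_statement`, `repo`, `patch`, and `test_patch` are assumed from the SWE-bench task instance schema, and the thresholds follow the criteria listed above; the exact logic in `make_lite.py` may differ.

```
# A hedged sketch, not the actual make_lite.py logic.
from criteria import (
    contains_git_commit_hash,
    contains_hyperlinks,
    contains_image,
    contains_issue_reference,
    contains_non_modified_files,
    contains_pytest_match_arg,
    leq_n_files,
    leq_n_hunks,
    leq_n_words,
)

def is_lite_candidate(instance: dict) -> bool:
    """Return True if a task instance passes the *lite* filtering criteria."""
    text = instance["problem_statement"]
    repo = instance["repo"]
    # Problem statement: no images, external links, commit SHAs, or cross-references
    if contains_image(text) or contains_hyperlinks(text, repo):
        return False
    if contains_git_commit_hash(text) or contains_issue_reference(text, repo):
        return False
    # Problem statement must not be too short (~40-word threshold from the README)
    if leq_n_words(text, 40):
        return False
    # Gold patch: one file, at most 3 hunks, no file creations or removals
    if not leq_n_files(instance["patch"], 1) or not leq_n_hunks(instance["patch"], 3):
        return False
    if contains_non_modified_files(instance["patch"]):
        return False
    # Tests must not assert on exact error messages
    if contains_pytest_match_arg(instance["test_patch"]):
        return False
    return True
```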
152 changes: 152 additions & 0 deletions collect/make_lite/criteria.py
@@ -0,0 +1,152 @@
import re
import requests

from unidiff import PatchSet


def contains_git_commit_hash(text: str) -> bool:
"""
Returns True if the text contains a git commit hash (40 character SHA-1 hash).
* Excludes commit hashes that are part of a URL.
"""
pattern_git_commit_hash = re.compile(r'(?<!/)\b[0-9a-f]{40}\b')
if re.search(pattern_git_commit_hash, text) is not None:
return True
pattern_django_commit_hash = re.compile(r'\[[0-9a-f]{23}\]')
if re.search(pattern_django_commit_hash, text) is not None:
return True
return False


def contains_hyperlinks(text: str, repo: str = None) -> bool:
"""
Returns True if the text contains a URL. Excludes URLs that are part of the repository.
"""
if repo:
        # Exclude links into the repository itself; keep the prefix scheme-less
        # so the negative lookahead applies after "https?://" has been consumed
        repo_prefix = f"github.com/{repo}"
        pattern_repo = re.escape(repo_prefix)
        pattern_urls = r'(?:https?://(?!{0}).+)|(?:www\.(?!{0}).+)'.format(pattern_repo)
else:
pattern_urls = r'https?://(?:www\.)?\S+'

return bool(re.search(pattern_urls, text))


def contains_image(text: str) -> bool:
"""
Returns True if the text contains an image or video file extension.
"""
image_extensions = ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.svg', '.webp', '.ico', '.heif', '.bpg', '.avif']
video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']

pattern_image = '|'.join(re.escape(ext) for ext in image_extensions)
pattern_video = '|'.join(re.escape(ext) for ext in video_extensions)

image_regex = re.compile(r'\b({})\b'.format(pattern_image), flags=re.IGNORECASE)
video_regex = re.compile(r'\b({})\b'.format(pattern_video), flags=re.IGNORECASE)

return image_regex.search(text) is not None or video_regex.search(text) is not None


def contains_issue_reference(text: str, repo: str) -> bool:
"""
Returns True if text (problem statement) contains a reference to another issue (e.g. #1234).
"""
# Look for GitHub style issue references
pattern_issue_ref = re.compile(r"(\w+)\s+\#(\d+)")
keywords = {
"close", "closes", "closed",
"fix", "fixes", "fixed",
"resolve", "resolves", "resolved",
}
references = dict(pattern_issue_ref.findall(text))
if references:
for word, _ in references.items():
if word.lower() in keywords:
return True

# Look for GitLab style issue references
    pattern_gitlab = re.compile(r"https?://gitlab\.com/.*/issues")
if re.search(pattern_gitlab, text):
return True

# Look for GitHub `#` style references + verify if the issue exists
pattern_issue_ref = re.compile(r'#\d+')
matches = pattern_issue_ref.findall(text)
for match in matches:
url = f"http://github.com/{repo}/issues/{match[1:]}"
if repo == "django/django":
url = f"https://code.djangoproject.com/ticket/{match[1:]}"
if requests.get(url).status_code == 200:
return True

return False


def contains_non_modified_files(patch_text: str) -> bool:
"""
Returns True if the patch contains files that are not modified.
"""
patch = PatchSet(patch_text)
return len(patch.removed_files) > 0 or len(patch.added_files) > 0


def contains_pytest_match_arg(patch_test_text: str) -> bool:
"""
Returns True if the test patch contains a pytest.raises() call with a match argument.
"""
if any([x in patch_test_text for x in [
'pytest.raises',
'pytest.warns',
'pytest.deprecated_call',
]]):
return 'match' in patch_test_text
# Django style assertions:
if any([x in patch_test_text for x in [
'assertOutput',
'assertRaises',
'checks.Error',
]]):
return True
return False


def leq_n_code_lines(patch_text: str, n: int = 25) -> bool:
"""
Returns True if the patch has at most n lines of code changed.
"""
lines = 0
patch = PatchSet(patch_text)
for file in patch:
for hunk in file:
lines += hunk.added
lines += hunk.removed
return lines <= n


def leq_n_files(patch_text: str, n: int = 1) -> bool:
"""
Returns True if the patch has at most n files.
"""
patch = PatchSet(patch_text)
return len(patch.modified_files) <= n


def leq_n_hunks(patch_text: str, n: int = 3) -> bool:
"""
Returns True if the patch has at most n hunks.
"""
patch = PatchSet(patch_text)
    num_hunks = sum(len(f) for f in patch.modified_files)
    return 0 < num_hunks <= n


def leq_n_words(text: str, n: int = 50) -> bool:
"""
Returns True if the text has at most n words.
"""
return len(text.split()) <= n
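
As a quick, hypothetical illustration (not part of this commit), the `unidiff`-based helpers above can be exercised on a tiny single-file, single-hunk patch:

```
# Minimal fabricated patch text for demonstration purposes only.
example_patch = """\
diff --git a/foo.py b/foo.py
--- a/foo.py
+++ b/foo.py
@@ -1,2 +1,2 @@
 def add(a, b):
-    return a - b
+    return a + b
"""

print(leq_n_files(example_patch, 1))               # True: one modified file
print(leq_n_hunks(example_patch, 3))               # True: exactly one hunk
print(leq_n_code_lines(example_patch, 25))         # True: 2 changed lines
print(contains_non_modified_files(example_patch))  # False: nothing created or removed
```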
