Skip to content

Commit

Permalink
Refactoring SQL of upgrade-db
Browse files Browse the repository at this point in the history
  • Loading branch information
EdgesFTW committed Nov 17, 2023
1 parent b7d4557 commit 5ca70c5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 55 deletions.
5 changes: 1 addition & 4 deletions bin/upgrade-db
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,8 @@ def main() -> None:
suffix = args.suffix or "_upgrade"

for dbfile in args.dbfiles:
new_db = upgrade_db(
dbfile, suffix=suffix, overwrite=args.overwrite
)
upgrade_db(dbfile, suffix=suffix, overwrite=args.overwrite)

new_db.commit()

if __name__ == "__main__":
main()
25 changes: 15 additions & 10 deletions logpyle/runalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,20 @@ def my_sprintf(format: str, arg: str) -> str:
# }}}


def is_gathered(conn: sqlite3.Connection):
gathered = False
# get a list of tables with the name of 'runs'
res = list(conn.execute("""
SELECT name
FROM sqlite_master
WHERE type='table' AND name='runs'
"""))
if len(res) == 1:
gathered = True

return gathered


def auto_gather(filenames: List[str]) -> sqlite3.Connection:
# allow for creating ungathered files.
# Check if database has been gathered, if not, create one in memory
Expand All @@ -467,16 +481,7 @@ def auto_gather(filenames: List[str]) -> sqlite3.Connection:
# check if any of the provided files have been gathered
for f in filenames:
db = sqlite3.connect(f)
cur = db.cursor()

# get a list of tables with the name of 'runs'
res = list(cur.execute("""
SELECT name
FROM sqlite_master
WHERE type='table' AND name='runs'
"""))
# there exists a table with the name of 'runs'
if len(res) == 1:
if is_gathered(db):
gathered = True

if gathered:
Expand Down
76 changes: 35 additions & 41 deletions logpyle/upgrade_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,61 @@


def upgrade_conn(conn: sqlite3.Connection) -> sqlite3.Connection:
from logpyle.runalyzer import is_gathered
tmp = conn.execute("select * from warnings").description
warning_columns = [col[0] for col in tmp]

# check if any of the provided files have been gathered
gathered = False
# get a list of tables with the name of 'runs'
res = list(conn.execute("""
SELECT name
FROM sqlite_master
WHERE type='table' AND name='runs'
"""))
if len(res) == 1:
gathered = True
# check if the provided connection has been gathered
gathered = is_gathered(conn)

# ensure that warnings table has unixtime column
if ("unixtime" not in warning_columns):
if "unixtime" not in warning_columns:
print("Adding a unixtime column in the warnings table")
conn.execute("""
ALTER TABLE warnings
ADD unixtime integer DEFAULT NULL;
ALTER TABLE warnings
ADD unixtime integer DEFAULT NULL;
""")

# ensure that warnings table has rank column
# nowhere to grab the rank of the process that generated
# the warning
if ("rank" not in warning_columns):
if "rank" not in warning_columns:
print("Adding a rank column in the warnings table")
conn.execute("""
ALTER TABLE warnings
ADD rank integer DEFAULT NULL;
ALTER TABLE warnings
ADD rank integer DEFAULT NULL;
""")

tables = [col[0] for col in conn.execute("""
SELECT name
FROM sqlite_master
WHERE type='table'
""")]

print("Ensuring a logging table exists")
if gathered:
conn.execute("""
CREATE TABLE IF NOT EXISTS logging (
run_id integer,
rank integer,
step integer,
unixtime integer,
level text,
message text,
filename text,
lineno integer
)""")
else:
if "logging" not in tables:
conn.execute("""
CREATE TABLE IF NOT EXISTS logging (
rank integer,
step integer,
unixtime integer,
level text,
message text,
filename text,
lineno integer
)""")
CREATE TABLE logging (
rank integer,
step integer,
unixtime integer,
level text,
message text,
filename text,
lineno integer
)""")
if gathered:
conn.execute("""
ALTER TABLE logging
ADD run_id integer;
""")

return conn


def upgrade_db(
dbfile: str, suffix: str, overwrite: bool
) -> sqlite3.Connection:
) -> None:

# original db files
old_conn = sqlite3.connect(dbfile)
Expand All @@ -92,6 +84,8 @@ def upgrade_db(

new_conn = upgrade_conn(new_conn)

old_conn.close()
if old_conn != new_conn:
old_conn.close()

return new_conn
new_conn.commit()
new_conn.close()

0 comments on commit 5ca70c5

Please sign in to comment.