diff --git a/bin/upgrade-db b/bin/upgrade-db index 253ef6e..11dd72b 100644 --- a/bin/upgrade-db +++ b/bin/upgrade-db @@ -32,11 +32,8 @@ def main() -> None: suffix = args.suffix or "_upgrade" for dbfile in args.dbfiles: - new_db = upgrade_db( - dbfile, suffix=suffix, overwrite=args.overwrite - ) + upgrade_db(dbfile, suffix=suffix, overwrite=args.overwrite) - new_db.commit() if __name__ == "__main__": main() diff --git a/logpyle/runalyzer.py b/logpyle/runalyzer.py index 825c011..e6dc6a7 100644 --- a/logpyle/runalyzer.py +++ b/logpyle/runalyzer.py @@ -458,6 +458,20 @@ def my_sprintf(format: str, arg: str) -> str: # }}} +def is_gathered(conn: sqlite3.Connection): + gathered = False + # get a list of tables with the name of 'runs' + res = list(conn.execute(""" + SELECT name + FROM sqlite_master + WHERE type='table' AND name='runs' + """)) + if len(res) == 1: + gathered = True + + return gathered + + def auto_gather(filenames: List[str]) -> sqlite3.Connection: # allow for creating ungathered files. # Check if database has been gathered, if not, create one in memory @@ -467,16 +481,7 @@ def auto_gather(filenames: List[str]) -> sqlite3.Connection: # check if any of the provided files have been gathered for f in filenames: db = sqlite3.connect(f) - cur = db.cursor() - - # get a list of tables with the name of 'runs' - res = list(cur.execute(""" - SELECT name - FROM sqlite_master - WHERE type='table' AND name='runs' - """)) - # there exists a table with the name of 'runs' - if len(res) == 1: + if is_gathered(db): gathered = True if gathered: diff --git a/logpyle/upgrade_db.py b/logpyle/upgrade_db.py index 04396ef..43ee095 100644 --- a/logpyle/upgrade_db.py +++ b/logpyle/upgrade_db.py @@ -3,69 +3,61 @@ def upgrade_conn(conn: sqlite3.Connection) -> sqlite3.Connection: + from logpyle.runalyzer import is_gathered tmp = conn.execute("select * from warnings").description warning_columns = [col[0] for col in tmp] - # check if any of the provided files have been gathered - gathered = False - # get a list of tables with the name of 'runs' - res = list(conn.execute(""" - SELECT name - FROM sqlite_master - WHERE type='table' AND name='runs' - """)) - if len(res) == 1: - gathered = True + # check if the provided connection has been gathered + gathered = is_gathered(conn) # ensure that warnings table has unixtime column - if ("unixtime" not in warning_columns): + if "unixtime" not in warning_columns: print("Adding a unixtime column in the warnings table") conn.execute(""" - ALTER TABLE warnings - ADD unixtime integer DEFAULT NULL; + ALTER TABLE warnings + ADD unixtime integer DEFAULT NULL; """) # ensure that warnings table has rank column # nowhere to grab the rank of the process that generated # the warning - if ("rank" not in warning_columns): + if "rank" not in warning_columns: print("Adding a rank column in the warnings table") conn.execute(""" - ALTER TABLE warnings - ADD rank integer DEFAULT NULL; + ALTER TABLE warnings + ADD rank integer DEFAULT NULL; """) + tables = [col[0] for col in conn.execute(""" + SELECT name + FROM sqlite_master + WHERE type='table' + """)] + print("Ensuring a logging table exists") - if gathered: - conn.execute(""" - CREATE TABLE IF NOT EXISTS logging ( - run_id integer, - rank integer, - step integer, - unixtime integer, - level text, - message text, - filename text, - lineno integer - )""") - else: + if "logging" not in tables: conn.execute(""" - CREATE TABLE IF NOT EXISTS logging ( - rank integer, - step integer, - unixtime integer, - level text, - message text, - filename text, - lineno integer - )""") + CREATE TABLE logging ( + rank integer, + step integer, + unixtime integer, + level text, + message text, + filename text, + lineno integer + )""") + if gathered: + conn.execute(""" + ALTER TABLE logging + ADD run_id integer; + """) return conn def upgrade_db( dbfile: str, suffix: str, overwrite: bool - ) -> sqlite3.Connection: + ) -> None: # original db files old_conn = sqlite3.connect(dbfile) @@ -92,6 +84,8 @@ def upgrade_db( new_conn = upgrade_conn(new_conn) - old_conn.close() + if old_conn != new_conn: + old_conn.close() - return new_conn + new_conn.commit() + new_conn.close()