diff --git a/mxgap/ML.py b/mxgap/ML.py index d94e2f7..dfa61b4 100644 --- a/mxgap/ML.py +++ b/mxgap/ML.py @@ -30,7 +30,7 @@ def model_prediction(model_str,data_array): return model_pred -def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"): +def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap",output=None): """ Main function for predicting bandgap with ML model, from CONTCAR and DOSCAR paths. @@ -40,9 +40,9 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"): """ #! TODO: Implement _edge models and results - - base_path = os.path.dirname(contcar_path) - base_file = os.path.join(base_path,"mxgap.info") + if output is None: + base_path = os.path.dirname(contcar_path) + output = os.path.join(base_path,"mxgap.info") norm_x_contcar,norm_x_doscar, norm_y = load_normalization(norm_path) @@ -68,10 +68,10 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"): ML_isgap, ML_gap = clf_pred, round(reg_pred_rescaled,3) - print_clf(base_file,ML_isgap) - print_reg(base_file,ML_gap) + print_clf(output,ML_isgap) + print_reg(output,ML_gap) - return ML_isgap, ML_gap + return [ML_isgap, ML_gap] elif len(model_list) == 1: #### C or R only case @@ -79,9 +79,9 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"): if m_type[0] == "R": pred = round(rescale(pred,norm_y,0), 3) #! Adapt for _edges - print_reg(base_file,pred) + print_reg(output,pred) elif m_type[0] == "C": - print_clf(base_file,pred) + print_clf(output,pred) pass return [pred] @@ -90,36 +90,40 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"): raise ValueError(f"Model {model} not available. Use {PACKAGE_NAME} -l tu get the full list of models.") -def run_prediction(path:str=None, model:str=None, files:list=None): +def run_prediction(path:str=None, model:str=None, files:list=None, output:str=None): """Main function for predicting bandgap with ML model. Does the validation of inputs. Parameters ---------- - `path` : Optional. Path of the folder of a calculation, where the CONTCAR and DOSCAR are found. By default cwd. - `model` : Optional. ML model to use. By default GBC+RFR_onlygap (best). - `files` : Optional. Specify the paths for the CONTCAR and DOSCAR files, in a list. By default None. - Use either `paths` or `files`, if both are specified, `path` will take preference. + `path` : Optional. Path of the folder of a calculation, where the CONTCAR and DOSCAR are found. By default cwd. + `model` : Optional. ML model to use. By default GBC+RFR_onlygap (best). + `files` : Optional. Specify the paths for the CONTCAR and DOSCAR files, in a list. By default None. + Use either `paths` or `files`, if both are specified, `path` will take preference. + `output` : Optional. Specify the output file. By default it will generate a mxgap.info in the CONTCAR folder. + + Returns + --------- + `pred` : Result of the prediction in a list. The length will vary depending on the used model. + Can be either 1 (single Classifier or Regressor), 2 for combination of C+R, or +2 more for each when using the R_edges approach. + """ print() initial_time = time() - - contcar_path, doscar_path, model = validate_user_input(path, model, files, default_path, default_model) - input_path_exists(contcar_path,doscar_path) + contcar_path, doscar_path, model, output = validate_user_input(path, model, files, output, default_path, default_model, default_output) + + input_path_exists(contcar_path, doscar_path) #! validate_files() (validate they are actual CONTCAR/DOSCAR files, maybe not necessary) - #! open file and write intro + results + # Open output file and write report (#! verbosity?) base_path = os.path.dirname(contcar_path) - base_file = os.path.join(base_path,f"{PACKAGE_NAME}.info") - print_header(base_file,path,model,contcar_path,doscar_path) + output = os.path.join(base_path,output).replace("\\","/") if output == default_output else output + print_header(output,path,model,contcar_path,doscar_path,output) - - pred = ML_prediction(contcar_path,doscar_path,model) + pred = ML_prediction(contcar_path,doscar_path,model,output) - #! better output print/file - final_time = time() - print2(base_file,f"\nFinished successfully in {final_time-initial_time:.2f}s") + print2(output,f"\nFinished successfully in {final_time-initial_time:.2f}s") return pred @@ -127,6 +131,7 @@ def run_prediction(path:str=None, model:str=None, files:list=None): # Initialization of some paths default_path = "./" default_model = "GBC+RFR_onlygap" +default_output = "mxgap.info" models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', 'models/') norm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', 'NORM_INFO.txt') model_list_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', 'MODELS_LIST.txt') diff --git a/mxgap/cli.py b/mxgap/cli.py index 5aa5a69..fce41f3 100644 --- a/mxgap/cli.py +++ b/mxgap/cli.py @@ -14,9 +14,9 @@ def cli(): """Command Line Interface. Get user inputs from terminal (ArgParse) and feed them to the main ML prediction.""" - path, model, files = parse_user_input() + path, model, files, output = parse_user_input() - run_prediction(path, model, files) + run_prediction(path, model, files, output) ########################################################################## diff --git a/mxgap/input.py b/mxgap/input.py index d5c27ae..7b34ec8 100644 --- a/mxgap/input.py +++ b/mxgap/input.py @@ -32,6 +32,7 @@ def parse_user_input(): parser.add_argument("path",type=str,nargs="?",default=None,help="Specify the path to the directory containing the calculation output files, if empty, will select the current directory. Must contain at least the optimized CONTCAR, and the PBE DOSCAR for the PBE trained models.") parser.add_argument("-f","--files",type=str,nargs="+",required=False,help="Specify in order the direct CONTCAR and DOSCAR (if needed) paths manually. The path positional argument has preference over this.") parser.add_argument("-m","--model",type=str,default=None,help="Choose the trained MXene-Learning model to use. By default, the most accurate version is selected (RFR).") + parser.add_argument("-o","--output",type=str,default=None,help="Path of the output file. By default it will generate a mxgap.info in the CONTCAR folder.") parser.add_argument("-l","--list", action="store_true",help="List of all trained ML models available to choose.") args = parser.parse_args() @@ -39,7 +40,7 @@ def parse_user_input(): print(models_list_string) sys.exit(0) - return args.path, args.model, args.files + return args.path, args.model, args.files, args.output ######################################################################## @@ -47,7 +48,7 @@ def parse_user_input(): ######################################################################## def input_path_exists(*paths): - """Asserts that the paths given bythe suer exist.""" + """Asserts that the paths given by the user exist.""" for path in paths: if path is None: continue assert os.path.exists(path), f"The provided path {path} does not exist." @@ -60,9 +61,15 @@ def model_exists(model:str,models_list): assert m in models_list, f"The provided model {model} does not exist. Use {PACKAGE_NAME} -l to get the full list." -def validate_user_input(path,model,files,default_path="./", default_model="GBC+RFR_onlygap"): +def validate_user_input(path,model,files,output,default_path="./", default_model="GBC+RFR_onlygap",default_output="mxgap.info"): """Validates the input given by the user. Checks input incompatibility, errors, etc. - If valid, returns the CONTCAR and DOSCAR paths.""" + If valid, returns the CONTCAR, DOSCAR, and output paths.""" + + if output is None: + output = default_output + else: + if os.path.dirname(output) == "": output = "./" + output + input_path_exists(os.path.dirname(output)) if model is None: print(f"No ML model detected. The {default_model} model (most accurate) will be used.") @@ -99,7 +106,7 @@ def validate_user_input(path,model,files,default_path="./", default_model="GBC+R else: raise ValueError("File paths not detected properly. Indicate the CONTCAR and DOSCAR (if needed) paths.") - return contcar_path, doscar_path, model + return contcar_path, doscar_path, model, output def validate_user_files(): diff --git a/mxgap/utils.py b/mxgap/utils.py index 0b02efe..ed1cee4 100644 --- a/mxgap/utils.py +++ b/mxgap/utils.py @@ -41,7 +41,7 @@ def print_reg(file,pred): print2(file,text) -def print_header(file,path,model,contcar_path,doscar_path): +def print_header(file,path,model,contcar_path,doscar_path,output): current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") report = f""" ==================================================================== @@ -53,6 +53,7 @@ def print_header(file,path,model,contcar_path,doscar_path): Folder Path: {path} CONTCAR file: {contcar_path} DOSCAR file: {doscar_path} +Output Path: {output} ==================================================================== """ @@ -78,6 +79,7 @@ def add_path_ending(path): elif "/" in path: path = path + "/" elif "\\" in path: path = path + "\\" elif path == ".": path = "./" + else: path = path + "/" return path