Skip to content

Commit

Permalink
Added output file name selection
Browse files Browse the repository at this point in the history
  • Loading branch information
diegonti committed Nov 26, 2024
1 parent 14c41ee commit d828d6e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 33 deletions.
55 changes: 30 additions & 25 deletions mxgap/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def model_prediction(model_str,data_array):
return model_pred


def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"):
def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap",output=None):
"""
Main function for predicting bandgap with ML model, from CONTCAR and DOSCAR paths.
Expand All @@ -40,9 +40,9 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"):
"""

#! TODO: Implement _edge models and results

base_path = os.path.dirname(contcar_path)
base_file = os.path.join(base_path,"mxgap.info")
if output is None:
base_path = os.path.dirname(contcar_path)
output = os.path.join(base_path,"mxgap.info")

norm_x_contcar,norm_x_doscar, norm_y = load_normalization(norm_path)

Expand All @@ -68,20 +68,20 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"):

ML_isgap, ML_gap = clf_pred, round(reg_pred_rescaled,3)

print_clf(base_file,ML_isgap)
print_reg(base_file,ML_gap)
print_clf(output,ML_isgap)
print_reg(output,ML_gap)

return ML_isgap, ML_gap
return [ML_isgap, ML_gap]

elif len(model_list) == 1: #### C or R only case

pred = model_prediction(model_list[0], data_array_dict[model_needsDOS(model_list[0])])

if m_type[0] == "R":
pred = round(rescale(pred,norm_y,0), 3) #! Adapt for _edges
print_reg(base_file,pred)
print_reg(output,pred)
elif m_type[0] == "C":
print_clf(base_file,pred)
print_clf(output,pred)
pass

return [pred]
Expand All @@ -90,43 +90,48 @@ def ML_prediction(contcar_path:str,doscar_path:str,model:str="GBC+RFR_onlygap"):
raise ValueError(f"Model {model} not available. Use {PACKAGE_NAME} -l tu get the full list of models.")


def run_prediction(path:str=None, model:str=None, files:list=None):
def run_prediction(path:str=None, model:str=None, files:list=None, output:str=None):
"""Main function for predicting bandgap with ML model. Does the validation of inputs.
Parameters
----------
`path` : Optional. Path of the folder of a calculation, where the CONTCAR and DOSCAR are found. By default cwd.
`model` : Optional. ML model to use. By default GBC+RFR_onlygap (best).
`files` : Optional. Specify the paths for the CONTCAR and DOSCAR files, in a list. By default None.
Use either `paths` or `files`, if both are specified, `path` will take preference.
`path` : Optional. Path of the folder of a calculation, where the CONTCAR and DOSCAR are found. By default cwd.
`model` : Optional. ML model to use. By default GBC+RFR_onlygap (best).
`files` : Optional. Specify the paths for the CONTCAR and DOSCAR files, in a list. By default None.
Use either `paths` or `files`, if both are specified, `path` will take preference.
`output` : Optional. Specify the output file. By default it will generate a mxgap.info in the CONTCAR folder.
Returns
---------
`pred` : Result of the prediction in a list. The length will vary depending on the used model.
Can be either 1 (single Classifier or Regressor), 2 for combination of C+R, or +2 more for each when using the R_edges approach.
"""
print()
initial_time = time()

contcar_path, doscar_path, model = validate_user_input(path, model, files, default_path, default_model)

input_path_exists(contcar_path,doscar_path)
contcar_path, doscar_path, model, output = validate_user_input(path, model, files, output, default_path, default_model, default_output)

input_path_exists(contcar_path, doscar_path)
#! validate_files() (validate they are actual CONTCAR/DOSCAR files, maybe not necessary)

#! open file and write intro + results
# Open output file and write report (#! verbosity?)
base_path = os.path.dirname(contcar_path)
base_file = os.path.join(base_path,f"{PACKAGE_NAME}.info")
print_header(base_file,path,model,contcar_path,doscar_path)
output = os.path.join(base_path,output).replace("\\","/") if output == default_output else output
print_header(output,path,model,contcar_path,doscar_path,output)


pred = ML_prediction(contcar_path,doscar_path,model)
pred = ML_prediction(contcar_path,doscar_path,model,output)

#! better output print/file

final_time = time()
print2(base_file,f"\nFinished successfully in {final_time-initial_time:.2f}s")
print2(output,f"\nFinished successfully in {final_time-initial_time:.2f}s")

return pred


# Initialization of some paths
default_path = "./"
default_model = "GBC+RFR_onlygap"
default_output = "mxgap.info"
models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', 'models/')
norm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', 'NORM_INFO.txt')
model_list_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', 'MODELS_LIST.txt')
Expand Down
4 changes: 2 additions & 2 deletions mxgap/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
def cli():
"""Command Line Interface. Get user inputs from terminal (ArgParse) and feed them to the main ML prediction."""

path, model, files = parse_user_input()
path, model, files, output = parse_user_input()

run_prediction(path, model, files)
run_prediction(path, model, files, output)


##########################################################################
Expand Down
17 changes: 12 additions & 5 deletions mxgap/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,23 @@ def parse_user_input():
parser.add_argument("path",type=str,nargs="?",default=None,help="Specify the path to the directory containing the calculation output files, if empty, will select the current directory. Must contain at least the optimized CONTCAR, and the PBE DOSCAR for the PBE trained models.")
parser.add_argument("-f","--files",type=str,nargs="+",required=False,help="Specify in order the direct CONTCAR and DOSCAR (if needed) paths manually. The path positional argument has preference over this.")
parser.add_argument("-m","--model",type=str,default=None,help="Choose the trained MXene-Learning model to use. By default, the most accurate version is selected (RFR).")
parser.add_argument("-o","--output",type=str,default=None,help="Path of the output file. By default it will generate a mxgap.info in the CONTCAR folder.")
parser.add_argument("-l","--list", action="store_true",help="List of all trained ML models available to choose.")
args = parser.parse_args()

if args.list:
print(models_list_string)
sys.exit(0)

return args.path, args.model, args.files
return args.path, args.model, args.files, args.output


########################################################################
########################### Input Validation ###########################
########################################################################

def input_path_exists(*paths):
"""Asserts that the paths given bythe suer exist."""
"""Asserts that the paths given by the user exist."""
for path in paths:
if path is None: continue
assert os.path.exists(path), f"The provided path {path} does not exist."
Expand All @@ -60,9 +61,15 @@ def model_exists(model:str,models_list):
assert m in models_list, f"The provided model {model} does not exist. Use {PACKAGE_NAME} -l to get the full list."


def validate_user_input(path,model,files,default_path="./", default_model="GBC+RFR_onlygap"):
def validate_user_input(path,model,files,output,default_path="./", default_model="GBC+RFR_onlygap",default_output="mxgap.info"):
"""Validates the input given by the user. Checks input incompatibility, errors, etc.
If valid, returns the CONTCAR and DOSCAR paths."""
If valid, returns the CONTCAR, DOSCAR, and output paths."""

if output is None:
output = default_output
else:
if os.path.dirname(output) == "": output = "./" + output
input_path_exists(os.path.dirname(output))

if model is None:
print(f"No ML model detected. The {default_model} model (most accurate) will be used.")
Expand Down Expand Up @@ -99,7 +106,7 @@ def validate_user_input(path,model,files,default_path="./", default_model="GBC+R
else:
raise ValueError("File paths not detected properly. Indicate the CONTCAR and DOSCAR (if needed) paths.")

return contcar_path, doscar_path, model
return contcar_path, doscar_path, model, output


def validate_user_files():
Expand Down
4 changes: 3 additions & 1 deletion mxgap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def print_reg(file,pred):
print2(file,text)


def print_header(file,path,model,contcar_path,doscar_path):
def print_header(file,path,model,contcar_path,doscar_path,output):
current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
report = f"""
====================================================================
Expand All @@ -53,6 +53,7 @@ def print_header(file,path,model,contcar_path,doscar_path):
Folder Path: {path}
CONTCAR file: {contcar_path}
DOSCAR file: {doscar_path}
Output Path: {output}
====================================================================
"""
Expand All @@ -78,6 +79,7 @@ def add_path_ending(path):
elif "/" in path: path = path + "/"
elif "\\" in path: path = path + "\\"
elif path == ".": path = "./"
else: path = path + "/"

return path

Expand Down

0 comments on commit d828d6e

Please sign in to comment.