diff --git a/SpatialQARules/graph_spartun_rel.py b/SpatialQARules/graph_spartun_rel.py
index dd86cb69..ebd50cf8 100644
--- a/SpatialQARules/graph_spartun_rel.py
+++ b/SpatialQARules/graph_spartun_rel.py
@@ -28,8 +28,8 @@
     inside = question(name="inside")
     cover = question(name="cover")
     contain = question(name="contain")
-    # answer_class = question(name="answer_class", ConceptClass=EnumConcept,
-    #                          values=["yes", "no"])
+    output_for_loss = question(name="output_for_loss")
+    # Only one label of opposite concepts
     exactL(left, right)
     exactL(above, below)
@@ -52,9 +52,9 @@
     ifL(andL(ans2('x'), existsL(inverse('s', path=('x', inverse)))),
         ans1(path=('s', inv_question2)))

-    # Only inverse one way
-    inverse_list2 = [(near, near), (far, far), (touch, touch), (disconnected, disconnected), (overlap, overlap),
-                     (coveredby, inside), (cover, contain)]
+    # 2) PMD: loss = entropy + beta * constraint_loss (train without constraints first, then enable them)
+    # symmetric
+    inverse_list2 = [(near, near), (far, far), (touch, touch), (disconnected, disconnected), (overlap, overlap)]
     for ans1, ans2 in inverse_list2:
         ifL(andL(ans1('x'), existsL(inverse('s', path=('x', inverse)))),
             ans2(path=('s', inv_question2)))
@@ -64,6 +64,7 @@
     tran_quest1, tran_quest2, tran_quest3 = transitive.has_a(arg11=question, arg22=question, arg33=question)

     transitive_1 = [left, right, above, below, behind, front, inside, contain]
+
     for rel in transitive_1:
         ifL(andL(rel('x'),
                  existsL(transitive("t", path=('x', transitive))),
@@ -103,4 +104,14 @@
                  rel1(path=('to', tran_topo_quest2)),
                  rel2(path=('to', tran_topo_quest3))
                  ),
-            rel2(path=('to', tran_topo_quest4)))
\ No newline at end of file
+            rel2(path=('to', tran_topo_quest4)))
+
+    tran_topo_3_1 = [left, right, above, below, behind, front, near, far, disconnected]
+    tran_topo_3_2 = [contain, cover]
+    for rel1 in tran_topo_3_1:
+        for rel2 in tran_topo_3_2:
+            ifL(andL(rel1('x'),
+                     existsL(tran_topo('to', path=('x', tran_topo))),
+                     rel1(path=('to', tran_topo_quest2)),
+                     rel2(path=('to', tran_topo_quest3))),
+                rel1(path=('to', tran_topo_quest4)))
\ No newline at end of file
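Reviewer note: a minimal, self-contained pure-Python sketch of the two properties the logical constraints above encode (opposite labels are mutually exclusive; symmetric relations must hold in both directions). The names below are illustrative only, not the DomiKnowS API:

    # Opposite pairs mirror the exactL constraints; symmetric relations mirror inverse_list2.
    OPPOSITES = [("left", "right"), ("above", "below"), ("behind", "front")]
    SYMMETRIC = {"near", "far", "touch", "disconnected", "overlap"}

    def consistent(preds):
        # preds maps an (entity1, entity2) pair to a set of predicted relation names
        for rels in preds.values():
            if any(a in rels and b in rels for a, b in OPPOSITES):
                return False  # opposite labels predicted for the same pair
        for (e1, e2), rels in preds.items():
            for r in rels & SYMMETRIC:
                if r not in preds.get((e2, e1), set()):
                    return False  # symmetric relation missing in the flipped direction
        return True

    assert not consistent({("x", "y"): {"left", "right"}})
    assert consistent({("x", "y"): {"near"}, ("y", "x"): {"near"}})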
diff --git a/SpatialQARules/main.py b/SpatialQARules/main.py
index 29f60d21..59b62c3e 100644
--- a/SpatialQARules/main.py
+++ b/SpatialQARules/main.py
@@ -171,7 +171,7 @@ def main(args):
                                     augmented=args.train_file.upper() == "SPARTUN",
                                     batch_size=args.batch_size,
                                     boolQL=boolQ,
-                                    rule=args.text_rules)
+                                    rule_text=args.text_rules)

     test_file = "human_test.json" if args.test_file.upper() == "HUMAN" \
         else "test.json"
@@ -180,7 +180,7 @@ def main(args):
                                    size=args.test_size,
                                    augmented=False,
                                    batch_size=args.batch_size,
-                                   rule=args.text_rules)
+                                   rule_text=args.text_rules)

     eval_file = "DataSet/human_dev.json" if args.test_file.upper() == "HUMAN" \
         else "DataSet/boolQ/train.json" if args.train_file.upper() == "BOOLQ" else "DataSet/dev_Spartun.json"
@@ -190,7 +190,7 @@ def main(args):
                                 augmented=False,
                                 batch_size=args.batch_size,
                                 boolQL=boolQ,
-                                rule=args.text_rules)
+                                rule_text=args.text_rules)

     program_name = "PMD" if args.pmd else "Sampling" if args.sampling else "Base"
     if args.loaded:
         program.load("Models/" + args.loaded_file, map_location={'cuda:0': cur_device, 'cuda:1': cur_device})
diff --git a/SpatialQARules/main_rel.py b/SpatialQARules/main_rel.py
index f2bcf581..4fa0e62d 100644
--- a/SpatialQARules/main_rel.py
+++ b/SpatialQARules/main_rel.py
@@ -10,22 +10,31 @@
 import torch
 import argparse
 import numpy as np
+import transformers

 from domiknows.graph import Graph, Concept, Relation
-from program_declaration import program_declaration_spartun_fr, program_declaration_StepGame
+from program_declaration import program_declaration_spartun_fr, program_declaration_StepGame, \
+    program_declaration_spartun_fr_T5, program_declaration_spartun_fr_T5_v2, \
+    program_declaration_spartun_fr_T5_v3, program_declaration_spartun_fr_T5_v4
 from reader import DomiKnowS_reader
 import tqdm
 from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


-def eval(program, testing_set, cur_device, args, print_result=False, StepGame_number=None):
+def eval(program, testing_set, cur_device, args, print_result=False, StepGame_number=None, multilabel=False):
     if args.test_file.upper() != "STEPGAME":
         from graph_spartun_rel import left, right, above, below, behind, front, near, far, disconnected, touch, \
-            overlap, coveredby, inside, cover, contain
+            overlap, coveredby, inside, cover, contain, output_for_loss
         all_labels = [left, right, above, below, behind, front, near, far, disconnected, touch, overlap, coveredby,
                       inside, cover, contain]
+
+        all_labels_text = ["left", "right", "above", "below", "behind", "front",
+                           "near", "far", "disconnect", "touch", "overlap", "covered by",
+                           "inside", "cover", "contain"]
     else:
         from graph_stepgame import left, right, above, below, lower_left, lower_right, upper_left, upper_right, overlap
         all_labels = [left, right, above, below, lower_left, lower_right, upper_left, upper_right, overlap]
+        all_labels_text = ["left", "right", "above", "below", "lower-left",
+                           "lower-right", "upper-left", "upper-right", "overlap"]

     def remove_opposite(ind1, ind2, result_set, result_list):
         if ind1 in pred_set and ind2 in pred_set:
@@ -38,21 +47,20 @@ def remove_opposite(ind1, ind2, result_set, result_list):
     correct = 0
     total = 0
     pred_set = set()
-    for datanode in tqdm.tqdm(program.populate(testing_set, device=cur_device), "Manually Testing"):
+    for datanode in tqdm.tqdm(program.populate(testing_set, device=cur_device), "Checking accuracy"):
         for question in datanode.getChildDataNodes():
             pred_set.clear()
             pred_list.clear()
             total += 1
+            # Getting predicted label
            for ind, label in enumerate(all_labels):
                pred = question.getAttribute(label, 'local/softmax')
                if pred.argmax().item() == 1:
                    pred_set.add(ind)
                pred_list.append(pred[1].item())
-
            if args.train_file.upper() == "STEPGAME":
                pred = np.array(pred_list).argmax()
-               pred_set.clear()
-               pred_set.add(pred)
+               pred_set = {pred}
            else:
                remove_opposite(0, 1, pred_set, pred_list)
                remove_opposite(2, 3, pred_set, pred_list)
@@ -60,11 +68,24 @@ def remove_opposite(ind1, ind2, result_set, result_list):
                remove_opposite(6, 7, pred_set, pred_list)
                remove_opposite(8, 9, pred_set, pred_list)
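            # Reviewer note on the pruning above: indices 0-9 form the five opposite
            # pairs (left/right, above/below, behind/front, near/far, disconnected/touch).
            # remove_opposite's body is truncated by the hunk above, but the pred_list
            # scores suggest it keeps only the higher-scoring index of a co-predicted
            # pair, e.g. pred_set == {0, 1} with scores [0.9, 0.7, ...] becomes {0},
            # since a pair of objects cannot be both left and right of each other.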
steps".format(StepGame_number), file=result_file) print("Accuracy:", accuracy, file=result_file) - # print("Constrains Satisfied rate:", satisfy_constrain_rate, "%", file=result_file) - # print("Precious:", precision_score(actual, pred, average=None), file=result_file) - # print("Recall:", recall_score(actual, pred, average=None), file=result_file) - # print("F1:", f1_score(actual, pred, average=None), file=result_file) - # print("F1 Macro:", f1_score(actual, pred, average='macro'), file=result_file) - # print("Confusion Matrix:\n", confusion_matrix(actual, pred), file=result_file) - # result_file.close() return accuracy -def train(program, train_set, eval_set, cur_device, limit, lr, check_epoch=4, program_name="DomiKnow", args=None): +def train(program, train_set, eval_set, cur_device, limit, lr, check_epoch=1, program_name="DomiKnow", args=None): + def get_avg_loss(): + from domiknows.program.model.base import Mode + if cur_device is not None: + program.model.to(cur_device) + program.model.mode(Mode.TEST) + program.model.reset() + train_loss = 0 + total_loss = 0 + with torch.no_grad(): + for data_item in tqdm.tqdm(train_set, "Calculating Loss of training"): + loss, _, *output = program.model(data_item) + total_loss += 1 + train_loss += loss + return train_loss / total_loss + best_accuracy = 0 best_epoch = 0 old_file = None + check_epoch = args.check_epoch training_file = open("training.txt", 'a') print("-" * 10, file=training_file) print("Training by {:s} of ({:s} {:s})".format(program_name, args.train_file, "FR"), file=training_file) print("Learning Rate:", args.lr, file=training_file) training_file.close() cur_epoch = 0 + if args.model == "t5-adapter": + optimizer = lambda param: transformers.optimization.Adafactor(param, lr=lr, scale_parameter=False, + relative_step=False) + else: + optimizer = lambda param: torch.optim.AdamW(param, lr=lr) for epoch in range(check_epoch, limit, check_epoch): - training_file = open("training.txt", 'a') print("Training") - program.train(train_set, train_epoch_num=check_epoch, - Optim=lambda param: torch.optim.Adam(param, lr=lr, amsgrad=True), - device=cur_device) - cur_epoch += limit + if args.pmd: + program.train(train_set, c_warmup_iters=0, train_epoch_num=check_epoch, + Optim=optimizer, + device=cur_device) + else: + program.train(train_set, train_epoch_num=check_epoch, + Optim=optimizer, + device=cur_device) + cur_epoch += check_epoch + # loss = get_avg_loss() + training_file = open("training.txt", 'a') accuracy = eval(program, eval_set, cur_device, args) print("Epoch:", epoch, file=training_file) + # print("Loss:", loss, file=training_file) print("Dev Accuracy:", accuracy * 100, "%", file=training_file) - if accuracy > best_accuracy: + if accuracy >= best_accuracy: best_epoch = epoch best_accuracy = accuracy # if old_file: @@ -125,23 +167,34 @@ def train(program, train_set, eval_set, cur_device, limit, lr, check_epoch=4, pr program_addition = "_beta_" + str(args.beta) else: program_addition = "_size_" + str(args.sampling_size) - new_file = program_name + "_" + str(epoch) + "epoch" + "_lr_" + str(args.lr) + program_addition + new_file = program_name + "_" + str(epoch) + "epoch" + "_lr_" + str( + args.lr) + program_addition + "_model_" + args.model program.save("Models/" + new_file) training_file.close() training_file = open("training.txt", 'a') if cur_epoch < limit: - program.train(train_set, train_epoch_num=limit - cur_epoch, - Optim=lambda param: torch.optim.Adam(param, lr=lr, amsgrad=True), - device=cur_device) + if args.pmd: + 
    if cur_epoch < limit:
-        program.train(train_set, train_epoch_num=limit - cur_epoch,
-                      Optim=lambda param: torch.optim.Adam(param, lr=lr, amsgrad=True),
-                      device=cur_device)
+        if args.pmd:
+            program.train(train_set, c_warmup_iters=0, train_epoch_num=limit - cur_epoch,
+                          Optim=optimizer,
+                          device=cur_device)
+        else:
+            program.train(train_set, train_epoch_num=limit - cur_epoch,
+                          Optim=optimizer,
+                          device=cur_device)
        accuracy = eval(program, eval_set, cur_device, args)
        print("Epoch:", limit, file=training_file)
        print("Dev Accuracy:", accuracy * 100, "%", file=training_file)
-       if accuracy > best_accuracy:
+       if accuracy >= best_accuracy:
            best_epoch = limit
        # if old_file:
        #     os.remove(old_file)
-       new_file = program_name + "_" + str(limit) + "epoch" + "_lr_" + str(args.lr)
+       if program_name == "PMD":
+           program_addition = "_beta_" + str(args.beta)
+       else:
+           program_addition = "_size_" + str(args.sampling_size)
+       new_file = program_name + "_" + str(limit) + "epoch" + "_lr_" + str(
+           args.lr) + program_addition + "_model_" + args.model
        old_file = new_file
        program.save("Models/" + new_file)
    print("Best epoch ", best_epoch, file=training_file)
@@ -153,7 +206,6 @@ def main(args):
     SEED = 382
     np.random.seed(SEED)
     random.seed(SEED)
-    # pl.seed_everything(SEED)
     torch.manual_seed(SEED)

     cuda_number = args.cuda
@@ -162,43 +214,62 @@
     else:
         cur_device = "cuda:" + str(cuda_number) if torch.cuda.is_available() else 'cpu'

-    program = None
     if args.train_file.upper() == "STEPGAME":
         program = program_declaration_StepGame(cur_device,
-                                               pmd=args.pmd, beta=args.beta,
-                                               sampling=args.sampling, sampleSize=args.sampling_size,
-                                               dropout=args.dropout, constrains=args.constrains)
+                                               pmd=args.pmd, beta=args.beta,
+                                               sampling=args.sampling, sampleSize=args.sampling_size,
+                                               dropout=args.dropout, constraints=args.constrains)
     else:
-        program = program_declaration_spartun_fr(cur_device,
-                                                 pmd=args.pmd, beta=args.beta,
-                                                 sampling=args.sampling, sampleSize=args.sampling_size,
-                                                 dropout=args.dropout, constrains=args.constrains)
+        if args.model == "t5-adapter":
+            print("call T5")
+            program_declaration_function = None
+            if args.version == 2:
+                program_declaration_function = program_declaration_spartun_fr_T5_v2
+            elif args.version == 3:
+                program_declaration_function = program_declaration_spartun_fr_T5_v3
+            elif args.version == 4:
+                program_declaration_function = program_declaration_spartun_fr_T5_v4
+            else:
+                program_declaration_function = program_declaration_spartun_fr_T5
+
+            program = program_declaration_function(cur_device,
+                                                   pmd=args.pmd, beta=args.beta,
+                                                   sampling=args.sampling, sampleSize=args.sampling_size,
+                                                   dropout=args.dropout, constraints=args.constrains)
+        else:
+            program = program_declaration_spartun_fr(cur_device,
+                                                     pmd=args.pmd, beta=args.beta,
+                                                     sampling=args.sampling, sampleSize=args.sampling_size,
+                                                     dropout=args.dropout, constraints=args.constrains,
+                                                     model=args.model)

     boolQ = args.train_file.upper() == "BOOLQ"
     train_file = "train.json" if args.train_file.upper() == "ORIGIN" \
-        else "train_with_rules.json" if args.train_file.upper() == "SPARTUN" \
+        else "train_FR_v3.json" if args.train_file.upper() == "SPARTUN" \
         else "boolQ/train.json" if args.train_file.upper() == "BOOLQ" \
         else "StepGame" if args.train_file.upper() == "STEPGAME" \
         else "human_train.json"
     training_set = DomiKnowS_reader("DataSet/" + train_file, "FR",
-                                    size=args.train_size, upward_level=8,
+                                    type_dataset=args.train_file.upper(),
+                                    size=args.train_size,
+                                    upward_level=12,
                                     augmented=args.train_file.upper() == "SPARTUN",
                                     batch_size=args.batch_size,
-                                    boolQL=boolQ,
-                                    rule=args.text_rules,
-                                    StepGame_status="train" if args.train_file.upper() == "STEPGAME" else None)
+                                    rule_text=args.text_rules,
+                                    STEPGAME_status="train" if args.train_file.upper() == "STEPGAME" else None)

     test_file = "human_test.json" if args.test_file.upper() == "HUMAN" \
         else "StepGame" if args.train_file.upper() == "STEPGAME" \
         else "test.json"
     testing_set = DomiKnowS_reader("DataSet/" + test_file, "FR",
+                                   type_dataset=args.train_file.upper(),
                                    size=args.test_size,
                                    augmented=False,
                                    batch_size=args.batch_size,
-                                   rule=args.text_rules,
+                                   rule_text=args.text_rules,
+                                   STEPGAME_status="test" if args.train_file.upper() == "STEPGAME" else None,
                                    )

     eval_file = "human_dev.json" if args.test_file.upper() == "HUMAN" \
@@ -206,12 +277,12 @@
         else "boolQ/train.json" if args.train_file.upper() == "BOOLQ" else "dev_Spartun.json"

     eval_set = DomiKnowS_reader("DataSet/" + eval_file, "FR",
+                                type_dataset=args.train_file.upper(),
                                 size=args.test_size,
                                 augmented=False,
                                 batch_size=args.batch_size,
-                                boolQL=boolQ,
-                                rule=args.text_rules,
-                                StepGame_status="dev" if args.train_file.upper() == "STEPGAME" else None)
+                                rule_text=args.text_rules,
+                                STEPGAME_status="dev" if args.train_file.upper() == "STEPGAME" else None)

     program_name = "PMD" if args.pmd else "Sampling" if args.sampling else "Base"

@@ -219,37 +290,48 @@
     if args.loaded:
         if args.model_change:
             pretrain_model = torch.load("Models/" + args.loaded_file,
-                                        map_location={'cuda:0': cur_device, 'cuda:1': cur_device})
+                                        map_location={'cuda:0': cur_device, 'cuda:1': cur_device, 'cuda:2': cur_device,
+                                                      'cuda:3': cur_device, 'cuda:4': cur_device, 'cuda:5': cur_device})
             pretrain_dict = pretrain_model.state_dict()
             current_dict = program.model.state_dict()
             # Filter out unnecessary keys
             pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in current_dict}
             program.model.load_state_dict(pretrain_dict)
         else:
-            program.load("Models/" + args.loaded_file, map_location={'cuda:0': cur_device, 'cuda:1': cur_device})
+            program.load("Models/" + args.loaded_file,
+                         map_location={'cuda:0': cur_device, 'cuda:1': cur_device, 'cuda:2': cur_device,
+                                       'cuda:3': cur_device, 'cuda:4': cur_device, 'cuda:5': cur_device})
         if args.test_each:
             for i in range(10):
+                print("Testing {:} steps".format(i))
                 testing_set = DomiKnowS_reader("DataSet/" + test_file, "FR",
+                                               type_dataset=args.train_file.upper(),
                                                size=args.test_size,
                                                augmented=False,
                                                batch_size=args.batch_size,
-                                               rule=args.text_rules,
-                                               StepGame_status="test" if args.train_file.upper() == "STEPGAME" else None,
-                                               StepGame_number=i+1)
+                                               rule_text=args.text_rules,
+                                               STEPGAME_status="test" if args.train_file.upper() == "STEPGAME" else None,
+                                               reasoning_steps=i)
                 eval(program, testing_set, cur_device, args, print_result=True)
         else:
             eval(program, testing_set, cur_device, args, print_result=True)
     elif args.loaded_train:
         if args.model_change:
             pretrain_model = torch.load("Models/" + args.loaded_file,
-                                        map_location={'cuda:0': cur_device, 'cuda:1': cur_device})
-            pretrain_dict = pretrain_model.state_dict()
+                                        map_location={'cuda:0': cur_device, 'cuda:1': cur_device, 'cuda:2': cur_device,
+                                                      'cuda:3': cur_device, 'cuda:4': cur_device, 'cuda:5': cur_device})
+            pretrain_dict = pretrain_model
             current_dict = program.model.state_dict()
             # Filter out unnecessary keys
-            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in current_dict}
-            program.model.load_state_dict(pretrain_dict)
+            # pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in current_dict}
+            # Load matching parameters from the checkpoint, keep the rest as-is
+            new_state_dict = {k: v if k not in pretrain_dict else pretrain_dict[k]
+                              for k, v in current_dict.items()}
+            program.model.load_state_dict(new_state_dict)
         else:
-            program.load("Models/" + args.loaded_file, map_location={'cuda:0': cur_device, 'cuda:1': cur_device})
+            program.load("Models/" + args.loaded_file,
+                         map_location={'cuda:0': cur_device, 'cuda:1': cur_device, 'cuda:2': cur_device,
+                                       'cuda:3': cur_device, 'cuda:4': cur_device, 'cuda:5': cur_device})
         train(program, training_set, eval_set, cur_device, args.epoch, args.lr, program_name=program_name, args=args)
     else:
         train(program, training_set, eval_set, cur_device, args.epoch, args.lr, program_name=program_name, args=args)

@@ -260,10 +342,10 @@
     parser.add_argument("--epoch", dest="epoch", type=int, default=1)
     parser.add_argument("--lr", dest="lr", type=float, default=1e-5)
     parser.add_argument("--cuda", dest="cuda", type=int, default=0)
-    parser.add_argument("--test_size", dest="test_size", type=int, default=2)
-    parser.add_argument("--train_size", dest="train_size", type=int, default=2)
-    parser.add_argument("--batch_size", dest="batch_size", type=int, default=1)
-    parser.add_argument("--train_file", type=str, default="origin", help="Option: SpaRTUN or Human")
+    parser.add_argument("--test_size", dest="test_size", type=int, default=12)
+    parser.add_argument("--train_size", dest="train_size", type=int, default=16)
+    parser.add_argument("--batch_size", dest="batch_size", type=int, default=4)
+    parser.add_argument("--train_file", type=str, default="SPARTUN", help="Option: SpaRTUN or Human")
     parser.add_argument("--test_file", type=str, default="SPARTUN", help="Option: SpaRTUN or Human")
     parser.add_argument("--text_rules", type=bool, default=False, help="Including rules as text or not")
     parser.add_argument("--dropout", dest="dropout", type=bool, default=False)
@@ -275,9 +357,13 @@
     parser.add_argument("--loaded", dest="loaded", type=bool, default=False)
     parser.add_argument("--loaded_file", dest="loaded_file", type=str, default="train_model")
     parser.add_argument("--loaded_train", type=bool, default=False, help="Option to load and then further train")
+    parser.add_argument("--model_change", type=bool, default=False, help="Option to load a checkpoint trained with a different model")
     parser.add_argument("--save", dest="save", type=bool, default=False)
     parser.add_argument("--save_file", dest="save_file", type=str, default="train_model")
     parser.add_argument("--step_game_test_each", dest="test_each", type=bool, default=False)
+    parser.add_argument("--model", dest="model", type=str, default="bert")
+    parser.add_argument("--check_epoch", dest="check_epoch", type=int, default=1)
+    parser.add_argument("--version", dest="version", type=int, default=0)

     args = parser.parse_args()
     main(args)
diff --git a/SpatialQARules/models.py b/SpatialQARules/models.py
index 3b32ba56..9c897f2f 100644
--- a/SpatialQARules/models.py
+++ b/SpatialQARules/models.py
@@ -1,7 +1,10 @@
-from transformers import BertModel, BertPreTrainedModel, BertTokenizer
+from torch.nn.modules.module import T
+from transformers import BertModel, BertPreTrainedModel, BertTokenizer, RobertaTokenizer, RobertaModel, \
+    RobertaPreTrainedModel, AutoTokenizer, T5ForConditionalGeneration, T5PreTrainedModel, AutoModelForSeq2SeqLM
 from torch import nn
 import torch
 from torch.autograd import Variable
+from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training


 class BERTTokenizer:
@@ -14,6 +17,16 @@ def __call__(self, _, question, story):
         return torch.LongTensor(input_ids)


+class RoBERTaTokenizer:
+    def __init__(self):
+        self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+
+    def __call__(self, _, question, story):
+        encoded_input = self.tokenizer(question, story, padding="max_length", truncation=True)
+        input_ids = encoded_input["input_ids"]
+        return torch.LongTensor(input_ids)
+
+
 class MultipleClassYN(BertPreTrainedModel):
     def __init__(self, config, device="cpu", drp=False):
         super().__init__(config)
@@ -38,6 +51,32 @@ def forward(self, input_ids):
         return output


+class MultipleClassYNRoberta(RobertaPreTrainedModel):
+    def __init__(self, config, device="cpu", drp=False):
+        super().__init__(config)
+
+        if drp:
+            config.hidden_dropout_prob = 0.0
+            config.attention_probs_dropout_prob = 0.0
+
+        self.cur_device = device
+        self.roberta = RobertaModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.num_classes = 2
+        self.classifier = nn.Linear(config.hidden_size, self.num_classes)
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax()
+
+    def forward(self, input_ids):
+        outputs = self.roberta(input_ids)
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        output = self.classifier(pooled_output)
+
+        return output
+
+
 class MultipleClassYN_Hidden(BertPreTrainedModel):
     def __init__(self, config, device="cpu", drp=False):
         super().__init__(config)
@@ -58,6 +97,28 @@ def forward(self, input_ids):
         return pooled_output


+class MultipleClassYN_Hidden_Roberta(RobertaPreTrainedModel):
+    def __init__(self, config, device="cpu", drp=False):
+        super().__init__(config)
+
+        if drp:
+            config.hidden_dropout_prob = 0.0
+            config.attention_probs_dropout_prob = 0.0
+
+        self.cur_device = device
+        self.bert = RobertaModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.hidden_size = config.hidden_size
+
+    def forward(self, input_ids):
+        outputs = self.bert(input_ids)
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+
+        return pooled_output
+
+
 class ClassifyLayer(nn.Module):
     def __init__(self, hidden_size, device="cpu", drp=False):
         super().__init__()
@@ -70,4 +131,464 @@ def __init__(self, hidden_size, device="cpu", drp=False):

     def forward(self, pooled_output):
         output = self.classifier(pooled_output)
-        return output
\ No newline at end of file
+        return output
+
+
+class ClassifyLayer2(nn.Module):
+    def __init__(self, hidden_size, hidden_layer=1, device="cpu", drp=False):
+        super().__init__()
+
+        self.num_classes = 2
+        layer_parameters = [hidden_size] + [256 for i in range(hidden_layer - 1)] + [self.num_classes]
+
+        all_layer = []
+        for i in range(len(layer_parameters) - 1):
+            all_layer.append(nn.Linear(layer_parameters[i], layer_parameters[i + 1]))
+        self.classifier = nn.Sequential(*all_layer)
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax()
+
+    def forward(self, pooled_output):
+        logits = self.classifier(pooled_output)
+        # logits = self.sigmoid(logits)
+
+        return logits
+
+
+class MultipleClassYNT5(nn.Module):
+    def __init__(self, config, device="cpu", adapter=False):
+        super().__init__()
+
+        self.cur_device = device
+        if adapter:
+            print("Using Lora")
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(config)
+
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["q", "v"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type=TaskType.SEQ_2_SEQ_LM
+            )
+            # prepare int-8 model for training
+            self.model = prepare_model_for_kbit_training(self.model)
+            self.model = get_peft_model(self.model, lora_config)
+            self.model.config.use_cache = False
+        else:
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(config)
+        self.tokenizer = AutoTokenizer.from_pretrained(config)
+        self.num_classes = 2
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax(dim=1)
+        self.output_size = 1
+
+    def forward(self, input_ids):
+        decoder_id = torch.tensor([[self.tokenizer.pad_token_id] * self.output_size] * input_ids.size(0)).to(
+            self.cur_device)
+        logits = self.model(input_ids, decoder_input_ids=decoder_id)[0]
+        tokens = torch.argmax(logits, dim=2)
+        # Yes token is 2163, No token is 465
+        # Output ["Yes", "No"]
+        logits = logits.squeeze(1)
+        selected_logits = logits[:, [2163, 465]]
+        output = self.softmax(selected_logits)
+
+        return output
+
+
+class T5Tokenizer:
+    def __init__(self, model_id):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    def __call__(self, _, questions, stories):
+        prompts = []
+        for ind, question in enumerate(questions):
+            prompts.append("You will answer the question based on the following context: " + stories[
+                ind] + "\n Question: " + question)
+        encoded_input = self.tokenizer(prompts, padding="max_length", truncation=True)
+        input_ids = encoded_input["input_ids"]
+        return torch.LongTensor(input_ids)
+
+
+class MultipleClassFRT5(nn.Module):
+    def __init__(self, model_name, expected_label, device="cpu", adapter=False):
+        super().__init__()
+
+        self.cur_device = device
+        if adapter:
+            print("Using Lora")
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+            lora_config = LoraConfig(
+                r=64,
+                lora_alpha=64,
+                target_modules=["q", "v"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type=TaskType.SEQ_2_SEQ_LM
+            )
+            # prepare int-8 model for training
+            self.model = prepare_model_for_kbit_training(self.model)
+            self.model = get_peft_model(self.model, lora_config)
+            self.model.config.use_cache = False
+        else:
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.label_tokens = self.tokenizer(expected_label)["input_ids"]
+        self.map_token = {}
+        self.unique_token = {}
+        # FIXTHIS
+        for token_label in self.label_tokens:
+            for token in token_label:
+                self.unique_token[token] = 1
+        self.unique_token = [token for token in self.unique_token.keys()]
+        for i, token in enumerate(self.unique_token):
+            self.map_token[token] = i
+        self.num_classes = 2
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax(dim=1)
+        self.map_label = {label: i for i, label in enumerate(expected_label)}
+        self.second_model = nn.Sequential(nn.Linear(len(self.unique_token) * 2, len(expected_label)),
+                                          nn.Sigmoid())
+
+    def forward(self, input_ids):
+        # Force the decoder to output 2 tokens
+        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]] * input_ids.size()[0]).to(
+            self.cur_device)
+        logits = self.model(input_ids, decoder_input_ids=decoder_input_ids)[0]
+        first_word = logits.argmax(dim=2)
+
+        decoder_input_ids = torch.concat((decoder_input_ids, first_word), dim=-1)
+
+        logits = self.model(input_ids, decoder_input_ids=decoder_input_ids)[0]
+        # Keep only the logits of the tokens in unique_token
+        logits = logits[:, :, self.unique_token]
+        logits = torch.concat((logits[:, 0, :], logits[:, 1, :]), dim=1)
+        output = self.second_model(logits)
+        # tokens = torch.argmax(logits, dim=2)
+        # Yes token is 2163, No token is 465
+        # Output ["Yes", "No"]
+        # logits = logits.squeeze(1)
+
+        return output
+
+
+class ClassifyLabelT5(nn.Module):
+    def __init__(self, label_word, map_index, device="cpu", drp=False):
+        super().__init__()
+        self.map_index = map_index[label_word]
+
+    def forward(self, logits):
+        output = logits[:, self.map_index]
+        output = output.reshape(-1, 1)
+        output = torch.cat((torch.sub(torch.ones_like(output), output), output), dim=-1)
+        return output
+
+
+class T5WithLora(nn.Module):
+    def __init__(self, model_name, device="cpu", adapter=False):
+        super().__init__()
+
+        self.cur_device = device
+        if adapter:
+            print("Using Lora")
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+            lora_config = LoraConfig(
+                r=32,
+                lora_alpha=32,
+                target_modules=["q", "v"],
+                lora_dropout=0.01,
+                bias="lora_only",
+                task_type=TaskType.SEQ_2_SEQ_LM
+            )
+            # prepare int-8 model for training
+            self.model = prepare_model_for_kbit_training(self.model)
+            self.model = get_peft_model(self.model, lora_config)
+            self.model.config.use_cache = False
+        else:
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.num_classes = 2
+
+    def forward(self, _, cat_input_ids):
+        input_ids = cat_input_ids[0, :, :]
+        attention_mask = cat_input_ids[1, :, :]
+        logits = self.model.generate(**{'input_ids': input_ids, 'attention_mask': attention_mask}, max_new_tokens=20)
+        return logits
+
+    def loss(self, cat_input_ids, cat_encoded_label):
+        input_ids = cat_input_ids[0, :, :]
+        attention_mask = cat_input_ids[1, :, :]
+        label_input_ids = cat_encoded_label[0, :, :]
+        label_attention_mask = cat_encoded_label[1, :, :]
+        loss_t5 = self.model(input_ids, attention_mask=attention_mask, labels=label_input_ids,
+                             decoder_attention_mask=label_attention_mask).loss
+        return loss_t5
+
+
+class T5TokenizerInput:
+    def __init__(self, model_id):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    def __call__(self, _, questions, stories, return_long=False):
+        prompts = []
+        for ind, question in enumerate(questions):
+            prompts.append("Answer based on the context:\n\n" + stories[ind] + "\n\n" + question)
+        encoded_input = self.tokenizer(prompts, padding="max_length", truncation=True)
+        input_ids = encoded_input["input_ids"]
+        attention_mask = encoded_input["attention_mask"]
+        input_ids = torch.Tensor(input_ids) if not return_long else torch.LongTensor(input_ids)
+        attention_mask = torch.Tensor(attention_mask) if not return_long else torch.LongTensor(attention_mask)
+        return torch.stack((input_ids, attention_mask))
+
+
+class T5TokenizerOutput:
+    def __init__(self, model_id):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    def __call__(self, _, labels, return_long=False):
+        encoded_input = self.tokenizer(labels, padding="max_length", truncation=True)
+        input_ids = encoded_input["input_ids"]
+        attention_mask = encoded_input["attention_mask"]
+        input_ids = torch.Tensor(input_ids) if not return_long else torch.LongTensor(input_ids)
+        attention_mask = torch.Tensor(attention_mask) if not return_long else torch.LongTensor(attention_mask)
+        return torch.stack((input_ids, attention_mask))
+
+
+class T5TokenizerDecoder:
+    def __init__(self, model_id):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    def __call__(self, _, encoded):
+        decoded = self.tokenizer.batch_decode(encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        return decoded
+
+
+class T5LossFunction(torch.nn.Module):
+    def __init__(self, T5_model):
+        super().__init__()
+        self.T5_model = T5_model
+
+    def forward(self, input, target):
+        input = input.long()
+        target = target.long()
+        loss = self.T5_model.loss(input, target)
+        return loss
+
+
+class T5WithLoraGenerativeCLF(nn.Module):
+    def __init__(self, model_name, label, tokenizer, output_length=32, device="cpu", adapter=False):
+        super().__init__()
+
+        self.cur_device = device
+        if adapter:
+            print("Using Lora")
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["q", "v"],
+                lora_dropout=0.01,
+                bias="lora_only",
+                task_type=TaskType.SEQ_2_SEQ_LM
+            )
+            # prepare int-8 model for training
+            self.model = prepare_model_for_kbit_training(self.model)
+            self.model = get_peft_model(self.model, lora_config)
+            self.model.config.use_cache = False
+        else:
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        self.train_t5_mode = True
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.num_classes = 2
+
+        label_tokens = tokenizer(label + [","])["input_ids"]
+        self.interested_tokens = []
+        for tokens in label_tokens:
+            self.interested_tokens.extend(tokens)
+        self.interested_tokens = list(set(self.interested_tokens))
+        self.output_length = output_length
+        self.hidden_size = len(self.interested_tokens) * self.output_length
+
+    def forward(self, _, cat_input_ids, cat_encoded_label):
+        if self.train_t5_mode:
+            return self.train_forward(cat_input_ids, cat_encoded_label)
+        return self.inference_forward(cat_input_ids)
+
+    def train_forward(self, cat_input_ids, cat_encoded_label):
+        input_ids = cat_input_ids[0, :, :]
+        attention_mask = cat_input_ids[1, :, :]
+        label_input_ids = cat_encoded_label[0, :, :]
+        label_attention_mask = cat_encoded_label[1, :, :]
+        logits = self.model(input_ids, attention_mask=attention_mask,
+                            labels=label_input_ids, decoder_attention_mask=label_attention_mask).logits
+        return self.transform_logits(logits)
+
+    def inference_forward(self, cat_input_ids):
+        input_ids = cat_input_ids[0, :, :]
+        attention_mask = cat_input_ids[1, :, :]
+
+        seq = self.model.generate(
+            **{'input_ids': input_ids, 'attention_mask': attention_mask, 'min_new_tokens': self.output_length,
+               'max_new_tokens': self.output_length + 1})
+        logits = self.model(input_ids, attention_mask=attention_mask,
+                            decoder_input_ids=seq).logits
+        return self.transform_logits(logits)
+
+    def transform_logits(self, logit):
+        logit = logit[:, :self.output_length, self.interested_tokens].flatten(1, 2)  # Combine last two dimensions
+        return logit
+
+    def train(self: T, mode: bool = True) -> T:
+        return_val = super().train(mode)
+        print("Setting Training on T5 to {:}".format(mode))
+        self.train_t5_mode = mode
+        return return_val
+
+
+class T5WithLoraGenerativeCLF2(nn.Module):
+    def __init__(self, model_name, label, max_group, group_label, tokenizer, device="cpu", adapter=False):
+        super().__init__()
+
+        self.cur_device = device
+        if adapter:
+            print("Using Lora")
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["q", "v"],
+                lora_dropout=0.01,
+                bias="lora_only",
+                task_type=TaskType.SEQ_2_SEQ_LM
+            )
+            # prepare int-8 model for training
+            self.model = prepare_model_for_kbit_training(self.model)
+            self.model = get_peft_model(self.model, lora_config)
+            self.model.config.use_cache = False
+        else:
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        self.train_t5_mode = True
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.group_label = group_label
+        self.max_group = max_group
+
+        self.token_each_label = 0
+
+        self._space_token = tokenizer(" ")["input_ids"][0]
+        self._comma_token = tokenizer(",")["input_ids"][1]
+        self._eos_token = tokenizer(" ")["input_ids"][-1]
+        self.empty_pred_end = None
+        self.empty_pred = None
+        self.label_token_map, self.label_token_map_normalize, self.interested_tokens = self.tokenize_label(label,
+                                                                                                           tokenizer)
+        self.output_length = self.max_group * self.token_each_label + 1  # (+ End of sentence)
+
+    def tokenize_label(self, labels, tokenizer):
+        labels = labels + [" "]
+        label_tokens = tokenizer(labels)["input_ids"]
+        self.token_each_label = max([len(tokens) for tokens in label_tokens])
+        label_tokens_map = {}
+        interested_tokens = []
+        for label, label_token in zip(labels, label_tokens):
+            # Format the token
+            label_token[-1] = self._space_token
+            label_token += [self._space_token] * (self.token_each_label - len(label_token))
+            label_token[-1] = self._comma_token if self.group_label.get(label,
+                                                                        0) != self.max_group - 1 else self._eos_token
+            interested_tokens.extend(label_token)
+            label_tokens_map[label] = label_token
+
+        interested_tokens = sorted(list(set(interested_tokens)))
+        map_token_loc = {tokens: i for i, tokens in enumerate(interested_tokens)}
+        label_tokens_map_normalize = {}
+        for label, tokens in label_tokens_map.items():
+            new_tokens = [map_token_loc[token] for token in tokens]
+            label_tokens_map_normalize[label] = new_tokens
+
+        self.empty_pred_end = ([map_token_loc[self._space_token]] * (self.token_each_label - 1) +
+                               [map_token_loc[self._comma_token]])
+
+        self.empty_pred = [map_token_loc[self._space_token]] * (self.token_each_label - 1) + [map_token_loc[self._eos_token]]
+
+        return label_tokens_map, label_tokens_map_normalize, interested_tokens
+
+    def forward(self, _, cat_input_ids, cat_encoded_label):
+        if self.train_t5_mode:
+            return self._train_forward(cat_input_ids, cat_encoded_label)
+        return self._inference_forward(cat_input_ids)
+
+    def _train_forward(self, cat_input_ids, cat_encoded_label):
+        input_ids = cat_input_ids[0, :, :]
+        attention_mask = cat_input_ids[1, :, :]
+        label_input_ids = cat_encoded_label[0, :, :]
+        label_attention_mask = cat_encoded_label[1, :, :]
+        # Need label to generate
+        logits = self.model(input_ids, attention_mask=attention_mask,
+                            labels=label_input_ids, decoder_attention_mask=label_attention_mask).logits
+        return self.transform_logits(logits)
+
+    def _inference_forward(self, cat_input_ids):
+        input_ids = cat_input_ids[0, :, :]
+        attention_mask = cat_input_ids[1, :, :]
+
+        seq = self.model.generate(
+            **{'input_ids': input_ids, 'attention_mask': attention_mask, 'min_new_tokens': self.output_length,
+               'max_new_tokens': self.output_length + 1})
+
+        logits = self.model(input_ids, attention_mask=attention_mask,
+                            decoder_input_ids=seq).logits
+
+        return self.transform_logits(logits)
+
+    def transform_logits(self, logit):
+        logits = logit[:, :self.output_length, self.interested_tokens]
+        return logits
+
+    def train(self: T, mode: bool = True) -> T:
+        return_val = super().train(mode)
+        print("Setting Training on T5 to {:}".format(mode))
+        self.train_t5_mode = mode
+        return return_val
+
+
+class T5LocationClassification(nn.Module):
+    def __init__(self, token_loc, candidate_output_token, device="cpu"):
+        super().__init__()
+        self.st_token, self.ed_token = token_loc
+        self.candidate_output_token = candidate_output_token
+        print(self.candidate_output_token)
+        self.softmax = nn.Softmax(dim=-1)
+        self.device = device
+
+    def forward(self, _, logits):
+        logits = logits[:, self.st_token:self.ed_token, :]
+        all_prob = torch.Tensor().requires_grad_().to(self.device)
+        for token_label in self.candidate_output_token:
+            label_prob = logits[:, 0, token_label[0]]
+            for i, label_token in enumerate(token_label):
+                if i == 0:
+                    continue
+                label_prob = torch.mul(label_prob, logits[:, i, token_label[i]])
+
+            label_prob = label_prob.reshape(-1, 1)
+            all_prob = torch.concat((all_prob, label_prob), dim=-1)
+
+        all_prob = self.softmax(all_prob)
+        return all_prob
+
+
+class LabelClassification(nn.Module):
+    def __init__(self, index_label):
+        super().__init__()
+        self.index_label = index_label
+
+    def forward(self, _, all_prob):
+        label_prob = all_prob[:, self.index_label].reshape(-1, 1)
+        prob = torch.concat((torch.ones_like(label_prob) - label_prob, label_prob), dim=-1)
+        return prob
diff --git a/SpatialQARules/program_declaration.py b/SpatialQARules/program_declaration.py
index 8c7eda16..5bc9775d 100644
--- a/SpatialQARules/program_declaration.py
+++ b/SpatialQARules/program_declaration.py
@@ -115,12 +115,21 @@ def read_label(_, label):
     return program


+from domiknows.sensor.pytorch.sensors import ReaderSensor, ConcatSensor, FunctionalSensor, JointSensor
+from domiknows.sensor.pytorch.learners import ModuleLearner, LSTMLearner
+from models import *
+from utils import *
+from domiknows.sensor.pytorch.relation_sensors import CompositionCandidateSensor
+
+
 def program_declaration_spartun_fr(device, *, pmd=False, beta=0.5, sampling=False, sampleSize=1, dropout=False,
-                                   constrains=False, spartun=True):
+                                   constraints=False, spartun=True, model="bert"):
     program = None
     from graph_spartun_rel import graph, story, story_contain, question, \
         left, right, above, below, behind, front, near, far, disconnected, touch, \
-        overlap, coveredby, inside, cover, contain, inverse, inv_question1, inv_question2
+        overlap, coveredby, inside, cover, contain, inverse, inv_question1, inv_question2, \
+        transitive, tran_quest1, tran_quest2, tran_quest3, tran_topo, tran_topo_quest1, \
+        tran_topo_quest2, tran_topo_quest3, tran_topo_quest4

     story["questions"] = ReaderSensor(keyword="questions")
     story["stories"] = ReaderSensor(keyword="stories")
@@ -150,113 +159,1130 @@ def make_question(questions, stories, relations, q_ids, labels):
         all_labels = make_labels(labels)
         ids = to_int_list(q_ids.split("@@"))
         left_list, right_list, above_list, below_list, behind_list, \
-        front_list, near_list, far_list, dc_list, ec_list, po_list, \
-        tpp_list, ntpp_list, tppi_list, ntppi_list = all_labels
+            front_list, near_list, far_list, dc_list, ec_list, po_list, \
+            tpp_list, ntpp_list, tppi_list, ntppi_list = all_labels
         return torch.ones(len(questions.split("@@")), 1), questions.split("@@"), stories.split("@@"), \
-        relations.split("@@"), ids, left_list, right_list, above_list, below_list, behind_list, \
-        front_list, near_list, far_list, dc_list, ec_list, po_list, \
-        tpp_list, ntpp_list, tppi_list, ntppi_list
+            relations.split("@@"), ids, left_list, right_list, above_list, below_list, behind_list, \
+            front_list, near_list, far_list, dc_list, ec_list, po_list, \
+            tpp_list, ntpp_list, tppi_list, ntppi_list

     question[story_contain, "question", "story", "relation", "id", "left_label", "right_label",
-             "above_label", "below_label", "behind_label", "front_label", "near_label", "far_label", "dc_label", "ec_label", "po_label",
-             "tpp_label", "ntpp_label", "tppi_label", "ntppi_label"] = \
+             "above_label", "below_label", "behind_label", "front_label", "near_label", "far_label", "dc_label", "ec_label", "po_label",
+             "tpp_label", "ntpp_label", "tppi_label", "ntppi_label"] = \
         JointSensor(story["questions"], story["stories"], story["relations"],
                     story["question_ids"], story["labels"], forward=make_question, device=device)

     def read_label(_, label):
         return label

-    # question[answer_class] =
-    #     FunctionalSensor(story_contain, "label", forward=read_label, label=True, device=cur_device)
-    # Replace with all classes
+    # Model
+    if model == "t5-adapter":
+        t5_model_id = "google/flan-t5-base"
+        print("Using", t5_model_id)
+        question["input_ids"] = JointSensor(story_contain, 'question', "story",
+                                            forward=T5Tokenizer(t5_model_id), device=device)
+
+        all_answers = [left, right, above, below, behind, front,
+                       near, far, disconnected, touch, overlap, coveredby,
+                       inside, cover, contain]
+        expected_label = ["left", "right", "above", "below", "behind", "front",
+                          "near", "far", "disconnected", "touch", "overlap", "covered by",
+                          "inside", "cover", "contain"]
+
+        clf1 = MultipleClassFRT5(t5_model_id, expected_label, device=device, adapter=True)
+        question["hidden_layer"] = ModuleLearner("input_ids", module=clf1, device=device)
+        question[left] = ModuleLearner("hidden_layer",
+                                       module=ClassifyLabelT5(expected_label[0], map_index=clf1.map_label,
+                                                              device=device),
+                                       device=device)
+        question[right] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLabelT5(expected_label[1], map_index=clf1.map_label,
+                                                               device=device),
+                                        device=device)
+        question[above] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLabelT5(expected_label[2], map_index=clf1.map_label,
+                                                               device=device),
+                                        device=device)
+        question[below] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLabelT5(expected_label[3], map_index=clf1.map_label,
+                                                               device=device),
+                                        device=device)
+        question[behind] = ModuleLearner("hidden_layer",
+                                         module=ClassifyLabelT5(expected_label[4], map_index=clf1.map_label,
+                                                                device=device),
+                                         device=device)
+        question[front] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLabelT5(expected_label[5], map_index=clf1.map_label,
+                                                               device=device),
+                                        device=device)
+        question[near] = ModuleLearner("hidden_layer",
+                                       module=ClassifyLabelT5(expected_label[6], map_index=clf1.map_label,
+                                                              device=device),
+                                       device=device)
+        question[far] = ModuleLearner("hidden_layer",
+                                      module=ClassifyLabelT5(expected_label[7], map_index=clf1.map_label,
+                                                             device=device),
+                                      device=device)
+        question[disconnected] = ModuleLearner("hidden_layer",
+                                               module=ClassifyLabelT5(expected_label[8], map_index=clf1.map_label,
+                                                                      device=device),
+                                               device=device)
+        question[touch] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLabelT5(expected_label[9], map_index=clf1.map_label,
+                                                               device=device),
+                                        device=device)
+        question[overlap] = ModuleLearner("hidden_layer",
+                                          module=ClassifyLabelT5(expected_label[10], map_index=clf1.map_label,
+                                                                 device=device),
+                                          device=device)
+        question[coveredby] = ModuleLearner("hidden_layer",
+                                            module=ClassifyLabelT5(expected_label[11], map_index=clf1.map_label,
+                                                                   device=device),
+                                            device=device)
+        question[inside] = ModuleLearner("hidden_layer",
+                                         module=ClassifyLabelT5(expected_label[12], map_index=clf1.map_label,
+                                                                device=device),
+                                         device=device)
+        question[cover] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLabelT5(expected_label[13], map_index=clf1.map_label,
+                                                               device=device),
+                                        device=device)
+        question[contain] = ModuleLearner("hidden_layer",
+                                          module=ClassifyLabelT5(expected_label[14], map_index=clf1.map_label,
+                                                                 device=device),
+                                          device=device)
+    else:
+        print("Using BERT")
+        question["input_ids"] = JointSensor(story_contain, 'question', "story",
+                                            forward=BERTTokenizer(), device=device)
+        clf1 = MultipleClassYN_Hidden.from_pretrained('bert-base-uncased', device=device, drp=dropout)
+        question["hidden_layer"] = ModuleLearner("input_ids", module=clf1, device=device)
+        all_answers = [left, right, above, below, behind, front,
+                       near, far, disconnected, touch, overlap, coveredby,
+                       inside, cover, contain]
+        question[left] = ModuleLearner("hidden_layer",
+                                       module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                       device=device)
+        question[right] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                        device=device)
+        question[above] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                        device=device)
+        question[below] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                        device=device)
+        question[behind] = ModuleLearner("hidden_layer",
+                                         module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                         device=device)
+        question[front] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                        device=device)
+        question[near] = ModuleLearner("hidden_layer",
+                                       module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                       device=device)
+        question[far] = ModuleLearner("hidden_layer",
+                                      module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                      device=device)
+        question[disconnected] = ModuleLearner("hidden_layer",
+                                               module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                               device=device)
+        question[touch] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                        device=device)
+        question[overlap] = ModuleLearner("hidden_layer",
+                                          module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                          device=device)
+        question[coveredby] = ModuleLearner("hidden_layer",
+                                            module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                            device=device)
+        question[inside] = ModuleLearner("hidden_layer",
+                                         module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                         device=device)
+        question[cover] = ModuleLearner("hidden_layer",
+                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                        device=device)
+        question[contain] = ModuleLearner("hidden_layer",
+                                          module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+                                          device=device)

-    question["input_ids"] = JointSensor(story_contain, 'question', "story",
-                                        forward=BERTTokenizer(), device=device)
+    # Reading label
+    question[left] = FunctionalSensor(story_contain, "left_label", forward=read_label, label=True, device=device)
+    question[right] = FunctionalSensor(story_contain, "right_label", forward=read_label, label=True, device=device)
+    question[above] = FunctionalSensor(story_contain, "above_label", forward=read_label, label=True, device=device)
+    question[below] = FunctionalSensor(story_contain, "below_label", forward=read_label, label=True, device=device)
+    question[behind] = FunctionalSensor(story_contain, "behind_label", forward=read_label, label=True, device=device)
+    question[front] = FunctionalSensor(story_contain, "front_label", forward=read_label, label=True, device=device)
+    question[near] = FunctionalSensor(story_contain, "near_label", forward=read_label, label=True, device=device)
+    question[far] = FunctionalSensor(story_contain, "far_label", forward=read_label, label=True, device=device)
+    question[disconnected] = FunctionalSensor(story_contain, "dc_label", forward=read_label, label=True, device=device)
+    question[touch] = FunctionalSensor(story_contain, "ec_label", forward=read_label, label=True, device=device)
+    question[overlap] = FunctionalSensor(story_contain, "po_label", forward=read_label, label=True, device=device)
+    question[coveredby] = FunctionalSensor(story_contain, "tpp_label", forward=read_label, label=True, device=device)
+    question[inside] = FunctionalSensor(story_contain, "ntpp_label", forward=read_label, label=True, device=device)
+    question[cover] = FunctionalSensor(story_contain, "tppi_label", forward=read_label, label=True, device=device)
+    question[contain] = FunctionalSensor(story_contain, "ntppi_label", forward=read_label, label=True, device=device)

-    clf1 = MultipleClassYN_Hidden.from_pretrained('bert-base-uncased', device=device, drp=dropout)
+    # question[left] = ModuleLearner("hidden_layer",
+    #                                module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                device=device)
+    # question[right] = ModuleLearner("hidden_layer",
+    #                                 module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                 device=device)
+    # question[above] = ModuleLearner("hidden_layer",
+    #                                 module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                 device=device)
+    # question[below] = ModuleLearner("hidden_layer",
+    #                                 module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                 device=device)
+    # question[behind] = ModuleLearner("hidden_layer",
+    #                                  module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                  device=device)
+    # question[front] = ModuleLearner("hidden_layer",
+    #                                 module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                 device=device)
+    # question[near] = ModuleLearner("hidden_layer",
+    #                                module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                device=device)
+    # question[far] = ModuleLearner("hidden_layer",
+    #                               module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                               device=device)
+    # question[disconnected] = ModuleLearner("hidden_layer",
+    #                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                        device=device)
+    # question[touch] = ModuleLearner("hidden_layer",
+    #                                 module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), device=device)
+    # question[overlap] = ModuleLearner("hidden_layer",
+    #                                   module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), device=device)
+    # question[coveredby] = ModuleLearner("hidden_layer",
+    #                                     module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                     device=device)
+    # question[inside] = ModuleLearner("hidden_layer",
+    #                                  module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                  device=device)
+    #
+    # question[cover] = ModuleLearner("hidden_layer",
+    #                                 module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                 device=device)
+    #
+    # question[contain] = ModuleLearner("hidden_layer",
+    #                                   module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    #                                   device=device)

-    question["hidden_layer"] = ModuleLearner("input_ids", module=clf1, device=device)
+    poi_list = [question, left, right, above, below, behind, front, near, far,
+                disconnected, touch, overlap, coveredby, inside, cover, contain]

-    question[left] = ModuleLearner("hidden_layer",
-                                   module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
-                                   device=device)
+    if constraints:
+        print("Included constraints")
+        inverse[inv_question1.reversed, inv_question2.reversed] = \
+            CompositionCandidateSensor(
+                relations=(inv_question1.reversed, inv_question2.reversed),
+                forward=check_symmetric, device=device)
+
+        transitive[tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed] = \
+            CompositionCandidateSensor(
+                relations=(tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed),
+                forward=check_transitive, device=device)
+
+        tran_topo[tran_topo_quest1.reversed, tran_topo_quest2.reversed,
+                  tran_topo_quest3.reversed, tran_topo_quest4.reversed] = \
+            CompositionCandidateSensor(
+                relations=(tran_topo_quest1.reversed, tran_topo_quest2.reversed
+                           , tran_topo_quest3.reversed, tran_topo_quest4.reversed),
+                forward=check_transitive_topo, device=device)
+        poi_list.extend([inverse, transitive, tran_topo])
+
+    from domiknows.program.metric import PRF1Tracker, PRF1Tracker, DatanodeCMMetric, MacroAverageTracker, ValueTracker
+    from domiknows.program.loss import NBCrossEntropyLoss, BCEWithLogitsIMLoss, BCEFocalLoss
+    from domiknows.program import LearningBasedProgram, SolverPOIProgram
+    from domiknows.program.lossprogram import SampleLossProgram, PrimalDualProgram
+    from domiknows.program.model.pytorch import model_helper, PoiModel, SolverModel
+
+    infer_list = ['ILP', 'local/argmax']  # ['ILP', 'local/argmax']
+    if pmd:
+        print("Using PMD program")
+        program = PrimalDualProgram(graph, SolverModel, poi=poi_list,
+                                    inferTypes=infer_list,
+                                    loss=MacroAverageTracker(NBCrossEntropyLoss()),
+                                    beta=beta,
+                                    metric={
+                                        'ILP': PRF1Tracker(DatanodeCMMetric()),
+                                        'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))},
+                                    device=device)
+    elif sampling:
+        program = SampleLossProgram(graph, SolverModel, poi=poi_list,
+                                    inferTypes=infer_list,
+                                    loss=MacroAverageTracker(NBCrossEntropyLoss()),
+                                    metric={
+                                        'ILP': PRF1Tracker(DatanodeCMMetric()),
+                                        'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))},
+                                    sample=True,
+                                    sampleSize=sampleSize,
+                                    sampleGlobalLoss=False,
+                                    beta=1,
+                                    device=device)
+    else:
+        print("Using Base program")
+        program = SolverPOIProgram(graph,
+                                   poi=poi_list,
+                                   inferTypes=infer_list,
+                                   loss=MacroAverageTracker(NBCrossEntropyLoss()),
+                                   metric={
+                                       'ILP': PRF1Tracker(DatanodeCMMetric()),
+                                       'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))}, device=device)

-    question[left] = FunctionalSensor(story_contain, "left_label", forward=read_label, label=True, device=device)
-    question[right] = ModuleLearner("hidden_layer",
-                                    module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
-                                    device=device)
+    return program
+
+
+def program_declaration_spartun_fr_T5(device, *, pmd=False, beta=0.5, sampling=False, sampleSize=1, dropout=False,
+                                      constraints=False, spartun=True):
+    from graph_spartun_rel import graph, story, story_contain, question, \
+        left, right, above, below, behind, front, near, far, disconnected, touch, \
+        overlap, coveredby, inside, cover, contain, inverse, inv_question1, inv_question2, \
+        transitive, tran_quest1, tran_quest2, tran_quest3, tran_topo, tran_topo_quest1, \
+        tran_topo_quest2, tran_topo_quest3, tran_topo_quest4, output_for_loss
+
+    story["questions"] = ReaderSensor(keyword="questions")
+    story["stories"] = ReaderSensor(keyword="stories")
+    story["relations"] = ReaderSensor(keyword="relation")
+    story["question_ids"] = ReaderSensor(keyword="question_ids")
+    story["labels"] = ReaderSensor(keyword="labels")
+    all_labels = ["left", "right", "above", "below", "behind", "front",
+                  "near", "far", "disconnected", "touch", "overlap", "covered by",
+                  "inside", "cover", "contain"]
+    map_label_index = {text: i for i, text in enumerate(all_labels)}
+
+    def to_int_list(x):
+        return torch.LongTensor([int(i) for i in x])
+
+    def to_float_list(x):
+        return torch.Tensor([float(i) for i in x])
+
+    def make_labels(label_list):
+        labels = label_list.split("@@")
+        text_label = ["" for _ in range(len(labels))]
+
+        for ind, bits_label in enumerate(labels):
+            bits_label = int(bits_label)
+            cur_bit = 1
+            for label in all_labels:
+                if bits_label & cur_bit:
+                    text_label[ind] += label if text_label[ind] == "" else (", " + label)
+                cur_bit *= 2
+        # label_nums = [0 if label == "Yes" else 1 if label == "No" else 2 for label in labels]
+        return text_label
+
+    def make_question(questions, stories, relations, q_ids, labels):
+        text_label = make_labels(labels)
+        ids = to_int_list(q_ids.split("@@"))
+
+        return torch.ones(len(questions.split("@@")), 1), questions.split("@@"), stories.split("@@"), relations.split(
+            "@@"), ids, text_label
+
+    question[story_contain, "question", "story", "relation", "id", "text_labels"] = \
+        JointSensor(story["questions"], story["stories"], story["relations"],
+                    story["question_ids"], story["labels"], forward=make_question, device=device)
+
+    T5_model = T5WithLora("google/flan-t5-base", device=device, adapter=True)
+    # defined loss based on the model
+    LossT5 = T5LossFunction(T5_model=T5_model)
+    t5_outTokenizer = T5TokenizerOutput('google/flan-t5-base')
+    t5_inTokenizer = T5TokenizerInput('google/flan-t5-base')
+    question[output_for_loss] = JointSensor(story_contain, 'question', "story",
+                                            forward=t5_inTokenizer, device=device)
+
+    question["input_ids"] = JointSensor(story_contain, 'question', "story", True,
+                                        forward=t5_inTokenizer, device=device)
+
+    question[output_for_loss] = FunctionalSensor(story_contain,
+                                                 'text_labels',
+                                                 forward=t5_outTokenizer,
+                                                 label=True,
+                                                 device=device)
+
+    all_answers = [left, right, above, below, behind, front,
+                   near, far, disconnected, touch, overlap, coveredby,
+                   inside, cover, contain]
+
+    question["output_encoder"] = ModuleLearner(story_contain, "input_ids", module=T5_model, device=device)
+    question["output_decoder"] = FunctionalSensor(story_contain, "output_encoder",
+                                                  forward=T5TokenizerDecoder('google/flan-t5-base'), device=device)
+
+    def read_decoder(_, decoder_list):
+        text_label = [[0] * 15 for _ in range(len(decoder_list))]
+        for ind, text_decode in enumerate(decoder_list):
+            text_decode = text_decode.replace("and", "")
+            all_relations = text_decode.strip().split(", ")
+            for relation in all_relations:
+                relation = relation.strip()
+                if relation not in map_label_index:
+                    continue
+                text_label[ind][map_label_index[relation]] = 1
+        list_tensor = [to_float_list(labels_list) for labels_list in text_label]
+        return torch.stack(list_tensor)
+
+    def read_label(_, relation_list, index):
+        label = relation_list[:, index].reshape((-1, 1))
+        label = torch.concat((torch.ones_like(label) - label, label), dim=-1)
+        return label
+
+    question["output_relations"] = FunctionalSensor(story_contain, "output_decoder", forward=read_decoder,
+                                                    device=device)
+
+    question[left] = FunctionalSensor(story_contain, "output_relations", 0, forward=read_label, device=device)
+    question[right] = FunctionalSensor(story_contain, "output_relations", 1, forward=read_label, device=device)
+    question[above] = FunctionalSensor(story_contain, "output_relations", 2, forward=read_label, device=device)
+    question[below] = FunctionalSensor(story_contain, "output_relations", 3, forward=read_label, device=device)
+    question[behind] = FunctionalSensor(story_contain, "output_relations", 4, forward=read_label, device=device)
+    question[front] = FunctionalSensor(story_contain, "output_relations", 5, forward=read_label, device=device)
+    question[near] = FunctionalSensor(story_contain, "output_relations", 6, forward=read_label, device=device)
+    question[far] = FunctionalSensor(story_contain, "output_relations", 7, forward=read_label, device=device)
+    question[disconnected] = FunctionalSensor(story_contain, "output_relations", 8, forward=read_label, device=device)
+    question[touch] = FunctionalSensor(story_contain, "output_relations", 9, forward=read_label, device=device)
+    question[overlap] = FunctionalSensor(story_contain, "output_relations", 10, forward=read_label, device=device)
+    question[coveredby] = FunctionalSensor(story_contain, "output_relations", 11, forward=read_label, device=device)
+    question[inside] = FunctionalSensor(story_contain, "output_relations", 12, forward=read_label, device=device)
+    question[cover] = FunctionalSensor(story_contain, "output_relations", 13, forward=read_label, device=device)
+    question[contain] = FunctionalSensor(story_contain, "output_relations", 14, forward=read_label, device=device)
+
+    poi_list = [question, left, right, above, below, behind, front, near, far,
+                disconnected, touch, overlap, coveredby, inside, cover, contain, output_for_loss]
+
+    if constraints:
+        inverse[inv_question1.reversed, inv_question2.reversed] = \
+            CompositionCandidateSensor(
+                relations=(inv_question1.reversed, inv_question2.reversed),
+                forward=check_symmetric, device=device)
+
+        transitive[tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed] = \
+            CompositionCandidateSensor(
+                relations=(tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed),
+                forward=check_transitive, device=device)
+
+        tran_topo[tran_topo_quest1.reversed, tran_topo_quest2.reversed,
+                  tran_topo_quest3.reversed, tran_topo_quest4.reversed] = \
+            CompositionCandidateSensor(
+                relations=(tran_topo_quest1.reversed, tran_topo_quest2.reversed
+                           , tran_topo_quest3.reversed, tran_topo_quest4.reversed),
+                forward=check_transitive_topo, device=device)
+        poi_list.extend([inverse, transitive, tran_topo])
+
+    from domiknows.program.metric import PRF1Tracker, PRF1Tracker, DatanodeCMMetric, MacroAverageTracker, ValueTracker
+    from domiknows.program.loss import NBCrossEntropyLoss, BCEWithLogitsIMLoss, BCEFocalLoss
+    from domiknows.program import LearningBasedProgram, SolverPOIProgram
+    from domiknows.program.lossprogram import SampleLossProgram, PrimalDualProgram
+    from domiknows.program.model.pytorch import model_helper, PoiModel, SolverModel
+
+    infer_list = ['local/argmax']  # ['ILP', 'local/argmax']
+    if pmd:
+        print("Using PMD program")
+        program = PrimalDualProgram(graph, SolverModel, poi=poi_list,
+                                    inferTypes=infer_list,
+                                    loss=ValueTracker(LossT5),
+                                    beta=beta,
+                                    device=device)
-    question[right] = FunctionalSensor(story_contain, "right_label", forward=read_label, label=True, device=device)
+    elif sampling:
+        program = SampleLossProgram(graph, SolverModel, poi=poi_list,
+                                    inferTypes=infer_list,
+                                    loss=ValueTracker(LossT5),
+                                    sample=True,
+                                    sampleSize=sampleSize,
+                                    sampleGlobalLoss=False,
+                                    beta=1,
+                                    device=device)
+    else:
+        print("Using Base program")
+        program = SolverPOIProgram(graph,
+                                   poi=poi_list,
+                                   inferTypes=infer_list,
+                                   loss=ValueTracker(LossT5),
+                                   device=device)

-    question[above] = ModuleLearner("hidden_layer",
-                                    module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    return program
+
+
+def program_declaration_spartun_fr_T5_v2(device, *, pmd=False, beta=0.5, sampling=False, sampleSize=1, dropout=False,
+                                         constraints=False, spartun=True):
+    from graph_spartun_rel import graph, story, story_contain, question, \
+        left, right, above, below, behind, front, near, far, disconnected, touch, \
+        overlap, coveredby, inside, cover, contain, inverse, inv_question1, inv_question2, \
tran_quest1, tran_quest2, tran_quest3, tran_topo, tran_topo_quest1, \
+        tran_topo_quest2, tran_topo_quest3, tran_topo_quest4, output_for_loss
+
+    story["questions"] = ReaderSensor(keyword="questions")
+    story["stories"] = ReaderSensor(keyword="stories")
+    story["relations"] = ReaderSensor(keyword="relation")
+    story["question_ids"] = ReaderSensor(keyword="question_ids")
+    story["labels"] = ReaderSensor(keyword="labels")
+    all_labels = ["left", "right", "above", "below", "behind", "front",
+                  "near", "far", "disconnected", "touch", "overlap", "covered by",
+                  "inside", "cover", "contain"]
+    map_label_index = {text: i for i, text in enumerate(all_labels)}
+
+    def to_int_list(x):
+        return torch.LongTensor([int(i) for i in x])
+
+    def to_float_list(x):
+        return torch.Tensor([float(i) for i in x])
+
+    def make_labels(label_list):
+        labels = label_list.split("@@")
+        text_label = ["" for _ in range(len(labels))]
+
+        for ind, bits_label in enumerate(labels):
+            bits_label = int(bits_label)
+            cur_bit = 1
+            for label in all_labels:
+                text_label[ind] += label + ":" + ("yes" if bits_label & cur_bit else "no") + " "
+                cur_bit *= 2
+        return text_label
+
+    def make_question(questions, stories, relations, q_ids, labels):
+        text_label = make_labels(labels)
+        ids = to_int_list(q_ids.split("@@"))
+
+        return torch.ones(len(questions.split("@@")), 1), questions.split("@@"), stories.split("@@"), relations.split(
+            "@@"), ids, text_label
+
+    question[story_contain, "question", "story", "relation", "id", "text_labels"] = \
+        JointSensor(story["questions"], story["stories"], story["relations"],
+                    story["question_ids"], story["labels"], forward=make_question, device=device)
+
+    T5_model = T5WithLora("google/flan-t5-base", device=device, adapter=True)
+    # define the loss on top of the T5 model output
+    LossT5 = T5LossFunction(T5_model=T5_model)
+    t5_outTokenizer = T5TokenizerOutput('google/flan-t5-base')
+    t5_inTokenizer = T5TokenizerInput('google/flan-t5-base')
+    question[output_for_loss] = JointSensor(story_contain, 'question', "story",
+                                            forward=t5_inTokenizer, device=device)
+
+    question["input_ids"] = JointSensor(story_contain, 'question', "story", True,
+                                        forward=t5_inTokenizer, device=device)
+
+    question[output_for_loss] = FunctionalSensor(story_contain,
+                                                 'text_labels',
+                                                 forward=t5_outTokenizer,
+                                                 label=True,
+                                                 device=device)
+
+    all_answers = [left, right, above, below, behind, front,
+                   near, far, disconnected, touch, overlap, coveredby,
+                   inside, cover, contain]
+
+    question["output_encoder"] = ModuleLearner(story_contain, "input_ids", module=T5_model, device=device)
+    question["output_decoder"] = FunctionalSensor(story_contain, "output_encoder",
+                                                  forward=T5TokenizerDecoder('google/flan-t5-base'), device=device)
+
+    def read_decoder(_, decoder_list):
+        # Parse the decoded "label:yes"/"label:no" pairs into a 15-way multi-hot vector per question
+        text_label = [[0] * 15 for _ in range(len(decoder_list))]
+        for ind, text_decode in enumerate(decoder_list):
+            all_relations = text_decode.strip()
+            for label in all_labels:
+                # str.find() returns -1 (truthy) on a miss, so a substring test is the correct check here
+                if (label + ":yes") in all_relations:
+                    text_label[ind][map_label_index[label]] = 1
+        list_tensor = [to_float_list(labels_list) for labels_list in text_label]
+        return torch.stack(list_tensor)
+
+    def read_label(_, relation_list, index):
+        label = relation_list[:, index].reshape((-1, 1))
+        label = torch.concat((torch.ones_like(label) - label, label), dim=-1)
+        return label
+
+    question["output_relations"] = FunctionalSensor(story_contain, "output_decoder",
forward=read_decoder, + device=device) + + question[left] = FunctionalSensor(story_contain, "output_relations", 0, forward=read_label, device=device) + question[right] = FunctionalSensor(story_contain, "output_relations", 1, forward=read_label, device=device) + question[above] = FunctionalSensor(story_contain, "output_relations", 2, forward=read_label, device=device) + question[below] = FunctionalSensor(story_contain, "output_relations", 3, forward=read_label, device=device) + question[behind] = FunctionalSensor(story_contain, "output_relations", 4, forward=read_label, device=device) + question[front] = FunctionalSensor(story_contain, "output_relations", 5, forward=read_label, device=device) + question[near] = FunctionalSensor(story_contain, "output_relations", 6, forward=read_label, device=device) + question[far] = FunctionalSensor(story_contain, "output_relations", 7, forward=read_label, device=device) + question[disconnected] = FunctionalSensor(story_contain, "output_relations", 8, forward=read_label, device=device) + question[touch] = FunctionalSensor(story_contain, "output_relations", 9, forward=read_label, device=device) + question[overlap] = FunctionalSensor(story_contain, "output_relations", 10, forward=read_label, device=device) + question[coveredby] = FunctionalSensor(story_contain, "output_relations", 11, forward=read_label, device=device) + question[inside] = FunctionalSensor(story_contain, "output_relations", 12, forward=read_label, device=device) + question[cover] = FunctionalSensor(story_contain, "output_relations", 13, forward=read_label, device=device) + question[contain] = FunctionalSensor(story_contain, "output_relations", 14, forward=read_label, device=device) + + poi_list = [question, left, right, above, below, behind, front, near, far, + disconnected, touch, overlap, coveredby, inside, cover, contain, output_for_loss] + + if constraints: + inverse[inv_question1.reversed, inv_question2.reversed] = \ + CompositionCandidateSensor( + relations=(inv_question1.reversed, inv_question2.reversed), + forward=check_symmetric, device=device) + + transitive[tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed] = \ + CompositionCandidateSensor( + relations=(tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed), + forward=check_transitive, device=device) + + tran_topo[tran_topo_quest1.reversed, tran_topo_quest2.reversed, + tran_topo_quest3.reversed, tran_topo_quest4.reversed] = \ + CompositionCandidateSensor( + relations=(tran_topo_quest1.reversed, tran_topo_quest2.reversed + , tran_topo_quest3.reversed, tran_topo_quest4.reversed), + forward=check_transitive_topo, device=device) + poi_list.extend([inverse, transitive, tran_topo]) + + from domiknows.program.metric import PRF1Tracker, PRF1Tracker, DatanodeCMMetric, MacroAverageTracker, ValueTracker + from domiknows.program.loss import NBCrossEntropyLoss, BCEWithLogitsIMLoss, BCEFocalLoss + from domiknows.program import LearningBasedProgram, SolverPOIProgram + from domiknows.program.lossprogram import SampleLossProgram, PrimalDualProgram + from domiknows.program.model.pytorch import model_helper, PoiModel, SolverModel + + infer_list = ['local/argmax'] # ['ILP', 'local/argmax'] + if pmd: + print("Using PMD program") + program = PrimalDualProgram(graph, SolverModel, poi=poi_list, + inferTypes=infer_list, + loss=ValueTracker(LossT5), + beta=beta, device=device) - question[above] = FunctionalSensor(story_contain, "above_label", forward=read_label, label=True, device=device) + elif sampling: + program = 
SampleLossProgram(graph, SolverModel, poi=poi_list, + inferTypes=infer_list, + loss=ValueTracker(LossT5), + sample=True, + sampleSize=sampleSize, + sampleGlobalLoss=False, + beta=1, + device=device) + else: + print("Using Base program") + program = SolverPOIProgram(graph, + poi=poi_list, + inferTypes=infer_list, + loss=ValueTracker(LossT5), + device=device) + return program + + +def program_declaration_spartun_fr_T5_v3(device, *, pmd=False, beta=0.5, sampling=False, sampleSize=1, dropout=False, + constraints=False, spartun=True): + program = None + from graph_spartun_rel import graph, story, story_contain, question, \ + left, right, above, below, behind, front, near, far, disconnected, touch, \ + overlap, coveredby, inside, cover, contain, inverse, inv_question1, inv_question2, \ + transitive, tran_quest1, tran_quest2, tran_quest3, tran_topo, tran_topo_quest1, \ + tran_topo_quest2, tran_topo_quest3, tran_topo_quest4 + + story["questions"] = ReaderSensor(keyword="questions") + story["stories"] = ReaderSensor(keyword="stories") + story["relations"] = ReaderSensor(keyword="relation") + story["question_ids"] = ReaderSensor(keyword="question_ids") + story["labels"] = ReaderSensor(keyword="labels") + all_labels = ["left", "right", "above", "below", "behind", "front", "near", "far", "dc", "ec", "po", "tpp", "ntpp", + "tppi", "ntppi"] + + def to_int_list(x): + return torch.LongTensor([int(i) for i in x]) + + def make_labels(label_list): + labels = label_list.split("@@") + all_labels_list = [[] for _ in range(15)] + for bits_label in labels: + bits_label = int(bits_label) + cur_label = 1 + for ind, label in enumerate(all_labels): + all_labels_list[ind].append(1 if bits_label & cur_label else 0) + cur_label *= 2 + + # label_nums = [0 if label == "Yes" else 1 if label == "No" else 2 for label in labels] + return [to_int_list(labels_list) for labels_list in all_labels_list] + + def make_text_labels(label_list): + labels = label_list.split("@@") + text_label = ["" for _ in range(len(labels))] + + for ind, bits_label in enumerate(labels): + bits_label = int(bits_label) + cur_bit = 1 + for label in all_labels: + if bits_label & cur_bit: + text_label[ind] += label if text_label[ind] == "" else (", " + label) + cur_bit *= 2 + # label_nums = [0 if label == "Yes" else 1 if label == "No" else 2 for label in labels] + return text_label + + def make_question(questions, stories, relations, q_ids, labels): + all_labels = make_labels(labels) + text_labels = make_text_labels(labels) + ids = to_int_list(q_ids.split("@@")) + left_list, right_list, above_list, below_list, behind_list, \ + front_list, near_list, far_list, dc_list, ec_list, po_list, \ + tpp_list, ntpp_list, tppi_list, ntppi_list = all_labels + return torch.ones(len(questions.split("@@")), 1), questions.split("@@"), stories.split("@@"), \ + relations.split("@@"), ids, left_list, right_list, above_list, below_list, behind_list, \ + front_list, near_list, far_list, dc_list, ec_list, po_list, \ + tpp_list, ntpp_list, tppi_list, ntppi_list, text_labels + + question[story_contain, "question", "story", "relation", "id", "left_label", "right_label", + "above_label", "below_label", "behind_label", "front_label", "near_label", "far_label", "dc_label", "ec_label", "po_label", + "tpp_label", "ntpp_label", "tppi_label", "ntppi_label", "text_label"] = \ + JointSensor(story["questions"], story["stories"], story["relations"], + story["question_ids"], story["labels"], forward=make_question, device=device) + + def read_label(_, label): + return label + + print("Using 
T5") + + t5_outtokenizer = T5TokenizerOutput('google/flan-t5-base') + t5_inTokenizer = T5TokenizerInput('google/flan-t5-base') + + t5_normal_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base') + + t5_model = T5WithLoraGenerativeCLF("google/flan-t5-base", + label=all_labels, + tokenizer=t5_normal_tokenizer, + device=device, + adapter=True) + + question["input_ids"] = JointSensor(story_contain, 'question', "story", True, + forward=t5_inTokenizer, device=device) + + question["label_input_ids"] = JointSensor(story_contain, "text_label", True, + forward=t5_outtokenizer, device=device) + + question["hidden_layer"] = ModuleLearner(story_contain, "input_ids", "label_input_ids", module=t5_model, + device=device) + + all_answers = [left, right, above, below, behind, front, + near, far, disconnected, touch, overlap, coveredby, + inside, cover, contain] + hidden_layers = 2 + question[left] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), + device=device) + question[right] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), + device=device) + question[above] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), + device=device) question[below] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), device=device) - question[below] = FunctionalSensor(story_contain, "below_label", forward=read_label, label=True, device=device) - question[behind] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), device=device) - question[behind] = FunctionalSensor(story_contain, "behind_label", forward=read_label, label=True, device=device) - question[front] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), device=device) - question[front] = FunctionalSensor(story_contain, "front_label", forward=read_label, label=True, device=device) - question[near] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), device=device) - question[near] = FunctionalSensor(story_contain, "near_label", forward=read_label, label=True, device=device) - question[far] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), device=device) + question[disconnected] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, + device=device, drp=dropout), + device=device) + question[touch] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), + device=device) + question[overlap] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, 
hidden_layer=hidden_layers, device=device, + drp=dropout), + device=device) + question[coveredby] = ModuleLearner("hidden_layer", + module= + ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, device=device, + drp=dropout), + device=device) + question[inside] = ModuleLearner("hidden_layer", + module=ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, + device=device, drp=dropout), + device=device) + question[cover] = ModuleLearner("hidden_layer", + module=ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, + device=device, drp=dropout), + device=device) + question[contain] = ModuleLearner("hidden_layer", + module=ClassifyLayer2(t5_model.hidden_size, hidden_layer=hidden_layers, + device=device, drp=dropout), + device=device) + + # Reading label + question[left] = FunctionalSensor(story_contain, "left_label", forward=read_label, label=True, device=device) + question[right] = FunctionalSensor(story_contain, "right_label", forward=read_label, label=True, device=device) + question[above] = FunctionalSensor(story_contain, "above_label", forward=read_label, label=True, device=device) + question[below] = FunctionalSensor(story_contain, "below_label", forward=read_label, label=True, device=device) + question[behind] = FunctionalSensor(story_contain, "behind_label", forward=read_label, label=True, device=device) + question[front] = FunctionalSensor(story_contain, "front_label", forward=read_label, label=True, device=device) + question[near] = FunctionalSensor(story_contain, "near_label", forward=read_label, label=True, device=device) question[far] = FunctionalSensor(story_contain, "far_label", forward=read_label, label=True, device=device) + question[disconnected] = FunctionalSensor(story_contain, "dc_label", forward=read_label, label=True, device=device) + question[touch] = FunctionalSensor(story_contain, "ec_label", forward=read_label, label=True, device=device) + question[overlap] = FunctionalSensor(story_contain, "po_label", forward=read_label, label=True, device=device) + question[coveredby] = FunctionalSensor(story_contain, "tpp_label", forward=read_label, label=True, device=device) + question[inside] = FunctionalSensor(story_contain, "ntpp_label", forward=read_label, label=True, device=device) + question[cover] = FunctionalSensor(story_contain, "tppi_label", forward=read_label, label=True, device=device) + question[contain] = FunctionalSensor(story_contain, "ntppi_label", forward=read_label, label=True, device=device) - question[disconnected] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + poi_list = [question, left, right, above, below, behind, front, near, far, + disconnected, touch, overlap, coveredby, inside, cover, contain] + + if constraints: + print("Included constraints") + inverse[inv_question1.reversed, inv_question2.reversed] = \ + CompositionCandidateSensor( + relations=(inv_question1.reversed, inv_question2.reversed), + forward=check_symmetric, device=device) + + transitive[tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed] = \ + CompositionCandidateSensor( + relations=(tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed), + forward=check_transitive, device=device) + + tran_topo[tran_topo_quest1.reversed, tran_topo_quest2.reversed, + tran_topo_quest3.reversed, tran_topo_quest4.reversed] = \ + CompositionCandidateSensor( + relations=(tran_topo_quest1.reversed, tran_topo_quest2.reversed + , tran_topo_quest3.reversed, tran_topo_quest4.reversed), + 
forward=check_transitive_topo, device=device) + poi_list.extend([inverse, transitive, tran_topo]) + + from domiknows.program.metric import PRF1Tracker, PRF1Tracker, DatanodeCMMetric, MacroAverageTracker, ValueTracker + from domiknows.program.loss import NBCrossEntropyLoss, BCEWithLogitsIMLoss, BCEFocalLoss + from domiknows.program import LearningBasedProgram, SolverPOIProgram + from domiknows.program.lossprogram import SampleLossProgram, PrimalDualProgram + from domiknows.program.model.pytorch import model_helper, PoiModel, SolverModel + + infer_list = ['ILP', 'local/argmax'] # ['ILP', 'local/argmax'] + if pmd: + print("Using PMD program") + program = PrimalDualProgram(graph, SolverModel, poi=poi_list, + inferTypes=infer_list, + loss=MacroAverageTracker(NBCrossEntropyLoss()), + beta=beta, + metric={ + 'ILP': PRF1Tracker(DatanodeCMMetric()), + 'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))}, + device=device) + elif sampling: + program = SampleLossProgram(graph, SolverModel, poi=poi_list, + inferTypes=infer_list, + loss=MacroAverageTracker(NBCrossEntropyLoss()), + metric={ + 'ILP': PRF1Tracker(DatanodeCMMetric()), + 'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))}, + sample=True, + sampleSize=sampleSize, + sampleGlobalLoss=False, + beta=1, + device=device) + else: + print("Using Base program") + program = SolverPOIProgram(graph, + poi=poi_list, + inferTypes=infer_list, + loss=MacroAverageTracker(NBCrossEntropyLoss()), + metric={ + 'ILP': PRF1Tracker(DatanodeCMMetric()), + 'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))}, + device=device) + + return program + + +def program_declaration_spartun_fr_T5_v4(device, *, pmd=False, beta=0.5, sampling=False, sampleSize=1, dropout=False, + constraints=False, spartun=True): + program = None + from graph_spartun_rel import graph, story, story_contain, question, \ + left, right, above, below, behind, front, near, far, disconnected, touch, \ + overlap, coveredby, inside, cover, contain, inverse, inv_question1, inv_question2, \ + transitive, tran_quest1, tran_quest2, tran_quest3, tran_topo, tran_topo_quest1, \ + tran_topo_quest2, tran_topo_quest3, tran_topo_quest4 + + story["questions"] = ReaderSensor(keyword="questions") + story["stories"] = ReaderSensor(keyword="stories") + story["relations"] = ReaderSensor(keyword="relation") + story["question_ids"] = ReaderSensor(keyword="question_ids") + story["labels"] = ReaderSensor(keyword="labels") + all_labels = ["left", "right", "above", "below", "behind", "front", "near", "far", "dc", "ec", "po", "tpp", "ntpp", + "tppi", "ntppi"] + + label_bit = {label: 2 ** i for i, label in enumerate(all_labels)} + + # 6 Group of answer, detail below + all_labels_text = ["left", "right", "above", "below", "behind", "front", + "near", "far", "disconnected", "touch", "overlap", "covered by", + "inside", "cover", "contain"] + + group_label = {"left": 0, "right": 0, + "above": 1, "below": 1, + "behind": 2, "front": 2, + "disconnected": 3, "touch": 3, "overlap": 3, + "near": 4, "far": 4, + "covered by": 5, "inside": 5, "cover": 5, "contain": 5} + + group_number = 6 + + print("Using T5") + + t5_outtokenizer = T5TokenizerOutput('google/flan-t5-base') + t5_inTokenizer = T5TokenizerInput('google/flan-t5-base') + + t5_normal_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base') + + t5_model = T5WithLoraGenerativeCLF2("google/flan-t5-base", + group_label=group_label, + max_group=group_number, + label=all_labels_text, + tokenizer=t5_normal_tokenizer, + device=device, + adapter=True) + + token_each_label 
= t5_model.token_each_label + token_map_normalize = t5_model.label_token_map_normalize + token_map = t5_model.label_token_map + + def to_int_list(x): + return torch.LongTensor([int(i) for i in x]) + + def make_labels(label_list): + labels = label_list.split("@@") + all_labels_list = [[] for _ in range(15)] + for bits_label in labels: + bits_label = int(bits_label) + cur_label = 1 + for ind, label in enumerate(all_labels): + all_labels_list[ind].append(1 if bits_label & cur_label else 0) + cur_label *= 2 + + # label_nums = [0 if label == "Yes" else 1 if label == "No" else 2 for label in labels] + return [to_int_list(labels_list) for labels_list in all_labels_list] + + def make_text_labels(label_list): + labels = label_list.split("@@") + text_label = ["" for _ in range(len(labels))] + + for ind, bits_label in enumerate(labels): + bits_label = int(bits_label) + labels = [all_labels_text[i] for i, label in enumerate(all_labels) if bits_label & label_bit[label]] + all_labels_group = [" "] * group_number + for label in labels: + label_group = group_label[label] + if all_labels_group[label_group] != " ": + print("ERROR") + all_labels_group[label_group] = label + + all_labels_group = [t5_normal_tokenizer.decode(token_map[label], skip_special_tokens=True) for label in all_labels_group] + + if all_labels_group[-1] == ",": + all_labels_group[-1] = "" + text_label[ind] = "".join(all_labels_group) + # label_nums = [0 if label == "Yes" else 1 if label == "No" else 2 for label in labels] + return text_label + + def make_question(questions, stories, relations, q_ids, labels): + all_labels = make_labels(labels) + text_labels = make_text_labels(labels) + ids = to_int_list(q_ids.split("@@")) + left_list, right_list, above_list, below_list, behind_list, \ + front_list, near_list, far_list, dc_list, ec_list, po_list, \ + tpp_list, ntpp_list, tppi_list, ntppi_list = all_labels + return torch.ones(len(questions.split("@@")), 1), questions.split("@@"), stories.split("@@"), \ + relations.split("@@"), ids, left_list, right_list, above_list, below_list, behind_list, \ + front_list, near_list, far_list, dc_list, ec_list, po_list, \ + tpp_list, ntpp_list, tppi_list, ntppi_list, text_labels + + question[story_contain, "question", "story", "relation", "id", "left_label", "right_label", + "above_label", "below_label", "behind_label", "front_label", "near_label", "far_label", "dc_label", "ec_label", "po_label", + "tpp_label", "ntpp_label", "tppi_label", "ntppi_label", "text_label"] = \ + JointSensor(story["questions"], story["stories"], story["relations"], + story["question_ids"], story["labels"], forward=make_question, device=device) + + def read_label(_, label): + return label + + def transform_label_token(labels, token_map): + return [token_map[label] for label in labels] + + question["input_ids"] = JointSensor(story_contain, 'question', "story", True, + forward=t5_inTokenizer, device=device) + + question["label_input_ids"] = JointSensor(story_contain, "text_label", True, + forward=t5_outtokenizer, device=device) + + question["hidden_layer"] = ModuleLearner(story_contain, "input_ids", "label_input_ids", module=t5_model, + device=device) + + # Empty prediction (both classes are 0) + # 1. 
Left - Right (Further extend with lower-left lower-right in STEPGAME) + candidate_output = ["left", "right"] + index_labels = {"left": 0, "right": 1} + candidate_output_token = transform_label_token(candidate_output, token_map_normalize) + [t5_model.empty_pred] + + first_token_clf = T5LocationClassification(token_loc=[0, token_each_label], + candidate_output_token=candidate_output_token, device=device) + question["first_token_prob"] = ModuleLearner(story_contain, "hidden_layer", + module=first_token_clf) + + question[left] = ModuleLearner(story_contain, "first_token_prob", + module=LabelClassification(index_label=index_labels["left"]), device=device) + question[left] = FunctionalSensor(story_contain, "left_label", forward=read_label, label=True, device=device) + + question[right] = ModuleLearner(story_contain, "first_token_prob", + module=LabelClassification(index_label=index_labels["right"]), device=device) + question[right] = FunctionalSensor(story_contain, "right_label", forward=read_label, label=True, device=device) + # 2. Above - Below + candidate_output = ["above", "below"] + index_labels = {"above": 0, "below": 1} + candidate_output_token = transform_label_token(candidate_output, token_map_normalize) + [t5_model.empty_pred] + + second_token_clf = T5LocationClassification(token_loc=[token_each_label, token_each_label * 2], + candidate_output_token=candidate_output_token, device=device) + question["second_token_prob"] = ModuleLearner(story_contain, "hidden_layer", + module=second_token_clf) + + question[above] = ModuleLearner(story_contain, "second_token_prob", + module=LabelClassification(index_label=index_labels["above"]), device=device) + question[above] = FunctionalSensor(story_contain, "above_label", forward=read_label, label=True, device=device) + + question[below] = ModuleLearner(story_contain, "second_token_prob", + module=LabelClassification(index_label=index_labels["below"]), device=device) + question[below] = FunctionalSensor(story_contain, "below_label", forward=read_label, label=True, device=device) + + # 3.Behind - Front + candidate_output = ["behind", "front"] + index_labels = {"behind": 0, "front": 1} + candidate_output_token = transform_label_token(candidate_output, token_map_normalize) + [t5_model.empty_pred] + + third_token_clf = T5LocationClassification(token_loc=[token_each_label * 2, token_each_label * 3], + candidate_output_token=candidate_output_token, device=device) + question["third_token_prob"] = ModuleLearner(story_contain, "hidden_layer", + module=third_token_clf) + question[behind] = ModuleLearner(story_contain, "third_token_prob", + module=LabelClassification(index_label=index_labels["behind"]), device=device) + question[behind] = FunctionalSensor(story_contain, "behind_label", forward=read_label, label=True, device=device) + + question[front] = ModuleLearner(story_contain, "third_token_prob", + module=LabelClassification(index_label=index_labels["front"]), device=device) + question[front] = FunctionalSensor(story_contain, "front_label", forward=read_label, label=True, device=device) + + # 4. 
Disconnected, touch, overlap
+    candidate_output = ["disconnected", "touch", "overlap"]
+    index_labels = {"disconnected": 0, "touch": 1, "overlap": 2}
+    candidate_output_token = transform_label_token(candidate_output, token_map_normalize) + [t5_model.empty_pred]
+
+    fourth_token_clf = T5LocationClassification(token_loc=[token_each_label * 3, token_each_label * 4],
+                                                candidate_output_token=candidate_output_token, device=device)
+    question["fourth_token_prob"] = ModuleLearner(story_contain, "hidden_layer",
+                                                  module=fourth_token_clf)
+    question[disconnected] = ModuleLearner(story_contain, "fourth_token_prob",
+                                           module=LabelClassification(index_label=index_labels["disconnected"]), device=device)
     question[disconnected] = FunctionalSensor(story_contain, "dc_label", forward=read_label, label=True, device=device)
-    question[touch] = ModuleLearner("hidden_layer",
-                                    module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), device=device)
+    question[touch] = ModuleLearner(story_contain, "fourth_token_prob",
+                                    module=LabelClassification(index_label=index_labels["touch"]),
+                                    device=device)
     question[touch] = FunctionalSensor(story_contain, "ec_label", forward=read_label, label=True, device=device)
-    question[overlap] = ModuleLearner("hidden_layer",
-                                      module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), device=device)
+    question[overlap] = ModuleLearner(story_contain, "fourth_token_prob",
+                                      module=LabelClassification(index_label=index_labels["overlap"]),
+                                      device=device)
     question[overlap] = FunctionalSensor(story_contain, "po_label", forward=read_label, label=True, device=device)
-    question[coveredby] = ModuleLearner("hidden_layer",
-                                        module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout),
+    # 5. Near - Far
+    candidate_output = ["near", "far"]
+    index_labels = {"near": 0, "far": 1}
+    candidate_output_token = transform_label_token(candidate_output, token_map_normalize) + [t5_model.empty_pred]
+
+    fifth_token_clf = T5LocationClassification(token_loc=[token_each_label * 4, token_each_label * 5],
+                                               candidate_output_token=candidate_output_token, device=device)
+    question["fifth_token_prob"] = ModuleLearner(story_contain, "hidden_layer",
+                                                 module=fifth_token_clf)
+    question[near] = ModuleLearner(story_contain, "fifth_token_prob",
+                                   module=LabelClassification(index_label=index_labels["near"]),
+                                   device=device)
+    question[near] = FunctionalSensor(story_contain, "near_label", forward=read_label, label=True, device=device)
+    question[far] = ModuleLearner(story_contain, "fifth_token_prob",
+                                  module=LabelClassification(index_label=index_labels["far"]),
+                                  device=device)
+    question[far] = FunctionalSensor(story_contain, "far_label", forward=read_label, label=True, device=device)
+
+    # 6.
Covered by, Inside, Cover, Contain + candidate_output = ["covered by", "inside", "cover", "contain"] + index_labels = {"coveredby": 0, "inside": 1, "cover": 2, "contain": 3} + candidate_output_token = transform_label_token(candidate_output, token_map_normalize) + [t5_model.empty_pred_end] + + sixth_token_clf = T5LocationClassification(token_loc=[token_each_label * 5, token_each_label * 6], + candidate_output_token=candidate_output_token, device=device) + question["sixth_token_prob"] = ModuleLearner(story_contain, "hidden_layer", + module=sixth_token_clf) + question[coveredby] = ModuleLearner(story_contain, "sixth_token_prob", + module=LabelClassification(index_label=index_labels["coveredby"]), device=device) question[coveredby] = FunctionalSensor(story_contain, "tpp_label", forward=read_label, label=True, device=device) - question[inside] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + question[inside] = ModuleLearner(story_contain, "sixth_token_prob", + module=LabelClassification(index_label=index_labels["inside"]), device=device) question[inside] = FunctionalSensor(story_contain, "ntpp_label", forward=read_label, label=True, device=device) - - question[cover] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + question[cover] = ModuleLearner(story_contain, "sixth_token_prob", + module=LabelClassification(index_label=index_labels["cover"]), device=device) question[cover] = FunctionalSensor(story_contain, "tppi_label", forward=read_label, label=True, device=device) - - question[contain] = ModuleLearner("hidden_layer", - module=ClassifyLayer(clf1.hidden_size, device=device, drp=dropout), + question[contain] = ModuleLearner(story_contain, "sixth_token_prob", + module=LabelClassification(index_label=index_labels["contain"]), device=device) question[contain] = FunctionalSensor(story_contain, "ntppi_label", forward=read_label, label=True, device=device) - inverse[inv_question1.reversed, inv_question2.reversed] = \ - CompositionCandidateSensor( - relations=(inv_question1.reversed, inv_question2.reversed), - forward=check_symmetric, device=device) - poi_list = [question, left, right, above, below, behind, front, near, far, - disconnected, touch, overlap, coveredby, inside, cover, contain, inverse] + disconnected, touch, overlap, coveredby, inside, cover, contain] + + if constraints: + print("Included constraints") + inverse[inv_question1.reversed, inv_question2.reversed] = \ + CompositionCandidateSensor( + relations=(inv_question1.reversed, inv_question2.reversed), + forward=check_symmetric, device=device) + + transitive[tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed] = \ + CompositionCandidateSensor( + relations=(tran_quest1.reversed, tran_quest2.reversed, tran_quest3.reversed), + forward=check_transitive, device=device) + + tran_topo[tran_topo_quest1.reversed, tran_topo_quest2.reversed, + tran_topo_quest3.reversed, tran_topo_quest4.reversed] = \ + CompositionCandidateSensor( + relations=(tran_topo_quest1.reversed, tran_topo_quest2.reversed + , tran_topo_quest3.reversed, tran_topo_quest4.reversed), + forward=check_transitive_topo, device=device) + poi_list.extend([inverse, transitive, tran_topo]) from domiknows.program.metric import PRF1Tracker, PRF1Tracker, DatanodeCMMetric, MacroAverageTracker, ValueTracker from domiknows.program.loss import NBCrossEntropyLoss, BCEWithLogitsIMLoss, BCEFocalLoss @@ -266,6 +1292,7 @@ def read_label(_, label): infer_list = ['ILP', 
'local/argmax']  # ['ILP', 'local/argmax']
     if pmd:
+        print("Using PMD program")
         program = PrimalDualProgram(graph, SolverModel, poi=poi_list,
                                     inferTypes=infer_list,
                                     loss=MacroAverageTracker(NBCrossEntropyLoss()),
@@ -287,6 +1314,7 @@ def read_label(_, label):
                                     beta=1,
                                     device=device)
     else:
+        print("Using Base program")
         program = SolverPOIProgram(graph,
                                    poi=poi_list,
                                    inferTypes=infer_list,
@@ -299,6 +1327,7 @@ def read_label(_, label):

     return program

+
 def program_declaration_StepGame(device, *, pmd=False, beta=0.5, sampling=False, sampleSize=1, dropout=False,
                                  constrains=False, spartun=True):
     program = None
diff --git a/SpatialQARules/reader.py b/SpatialQARules/reader.py
index 87908511..fd000115 100644
--- a/SpatialQARules/reader.py
+++ b/SpatialQARules/reader.py
@@ -1,5 +1,6 @@
 import json
 import tqdm
+import random

 VOCABULARY = {
     "LEFT": ["to the left of"],
@@ -38,32 +39,39 @@
 }


-def create_key(obj1, obj2, relation):
-    key = str(obj1) + ":" + str(obj2) + ":" + str(relation)
-    return key
+def create_key(obj1, obj2, relation, question_type):
+    if question_type == "YN":
+        return str(obj1) + ":" + str(obj2) + ":" + relation
+    return str(obj1) + ":" + str(obj2)


-def create_simple_question(obj1, obj2, relation, obj_info):
-    return "Is " + obj_info[obj1]["full_name"] + " " + \
-           (VOCABULARY[relation][0] if isinstance(VOCABULARY[relation], list) else VOCABULARY[relation]) \
-           + " " + obj_info[obj2]["full_name"] + "?"
+def create_simple_question(obj1, obj2, relation, obj_info, question_type):
+    if question_type == "YN":
+        return "Is " + obj_info[obj1]["full_name"] + " " + \
+               (VOCABULARY[relation][0] if isinstance(VOCABULARY[relation], list) else VOCABULARY[relation]) \
+               + " " + obj_info[obj2]["full_name"] + "?"
+
+    question_fr1 = "Where is {:} relative to the {:}?".format(obj_info[obj1]["full_name"],
+                                                              obj_info[obj2]["full_name"])
+    question_fr2 = "What is the position of the {:} regarding {:}?".format(obj_info[obj1]["full_name"],
+                                                                           obj_info[obj2]["full_name"])
+    return question_fr1 if random.random() < 0.5 else question_fr2


 def label_fr_to_int(labels: list):
     result = 0
     for label in labels:
-        result += LABELS_INT[label]
+        result += LABELS_INT[label.upper()]
     return result


-def train_reader(file, question_type, size=None, upward_level=0):
+def train_reader(file, question_type, *, limit_questions=300000, upward_level=0):
     with open(file) as json_file:
         data = json.load(json_file)
-    size = 300000 if not size else size
     print("level:", upward_level)
-
+    print("Reading SPARTUN training data with reasoning-chain augmentation")
     dataset = []
-    count = 0
+    count_questions = 0
     count_original = 0
     all_batch_dynamic_info = {}
     for story in data["data"]:
@@ -72,11 +80,12 @@
         obj_info = story["objects_info"]
         relation_info = {}
         question_id = {}
-        run_id = 0
+        run_id_within_q = 0
         for question in story["questions"]:
-            if count >= size:
+            if count_questions >= limit_questions:
                 break
+
             question_txt = question["question"]

             q_type = question["q_type"]
@@ -85,11 +94,13 @@

             candidates = question['candidate_answers']

+            # Find the target relation (it may be a list with more than one entry)
             target_relation = question['question_info']['target_relation'][0] \
                 if isinstance(question['question_info']['target_relation'], list) \
                 else question['question_info']['target_relation']
             target_relation = target_relation.upper()

+            # Find the asked relation (it may be a list with more than one entry)
asked_relation = question['question_info']['asked_relation'][0] \ if isinstance(question['question_info']['asked_relation'], list) \ else question['question_info']['asked_relation'] @@ -99,47 +110,67 @@ def train_reader(file, question_type, size=None, upward_level=0): obj1, obj2 = question['query'] target_question = (obj1, obj2, target_relation) asked_question = (obj1, obj2, asked_relation) - current_key = create_key(*asked_question) + current_key = create_key(*asked_question, question_type) - added_questions = [] - deep_level = upward_level + added_questions = [] # questions to be added to the model + reasoning_steps_from_target = upward_level + # Create question id of current answer if current_key not in question_id: - question_id[current_key] = run_id - run_id += 1 - added_questions.append((question_txt, question['answer'][0], current_key)) + question_id[current_key] = run_id_within_q + run_id_within_q += 1 + + if question_type == "YN": + label = question["answer"][0] + else: + label = label_fr_to_int(question["answer"]) - if question['answer'][0] == "No": - target_key = create_key(*target_question) - added_questions.append((create_simple_question(*target_question, obj_info), "Yes", target_key)) + added_questions.append((question_txt, label, current_key)) - if target_key not in question_id: - question_id[target_key] = run_id - run_id += 1 - relation_info[current_key] = "reverse," + str(question_id[target_key]) + if question_type == "YN": + # If the answer of question is no, adding another question asking the same thing but "Yes" input + if question['answer'][0].lower() == "no": + target_key = create_key(*target_question, question_type) + added_questions.append((create_simple_question(*target_question, obj_info, question_type), + "Yes", + target_key)) - deep_level -= 1 + if target_key not in question_id: + question_id[target_key] = run_id_within_q + run_id_within_q += 1 + relation_info[current_key] = "reverse," + str(question_id[target_key]) + + reasoning_steps_from_target -= 1 current_level = [target_question] - for _ in range(deep_level): + for _ in range(reasoning_steps_from_target): new_level = [] for current_fact in current_level: - current_key = create_key(*current_fact) + current_key = create_key(*current_fact, question_type) + fact_info_key = create_key(*current_fact, "") previous_ids = [] if current_key not in question_id: - question_id[current_key] = run_id - run_id += 1 - previous_facts = facts_info[current_key]["previous"] + question_id[current_key] = run_id_within_q + run_id_within_q += 1 + previous_facts = facts_info[fact_info_key][current_fact[2]]["previous"] for previous in previous_facts: - previous_key = create_key(*previous) + previous_key = create_key(*previous, question_type) + fact_info_prev_key = create_key(*previous, "") if previous_key not in question_id: - question_id[previous_key] = run_id - run_id += 1 - previous_ids.append(question_id[previous_key]) + question_id[previous_key] = run_id_within_q + run_id_within_q += 1 + previous_ids.append(str(question_id[previous_key])) new_level.append(previous) - added_questions.append((create_simple_question(*previous, obj_info), "Yes", previous_key)) + if question_type == "YN": + added_questions.append((create_simple_question(*previous, obj_info, question_type), + "Yes", + previous_key)) + else: + added_questions.append((create_simple_question(*previous, obj_info, question_type), + label_fr_to_int(list(facts_info[fact_info_prev_key].keys())), + previous_key)) current_level = new_level size_relation = len(previous_ids) @@ 
-149,8 +180,7 @@ def train_reader(file, question_type, size=None, upward_level=0):
                 relation_type = "symmetric" if size_relation == 1 \
                     else "transitive" if size_relation == 2 \
                     else "transitive_topo"
-                for previous_id in previous_ids:
-                    relation_type += "," + str(previous_id)
+                relation_type = relation_type + ',' + ','.join(previous_ids)
                 relation_info[current_key] = relation_type

             if len(added_questions) not in all_batch_dynamic_info:
@@ -164,25 +194,25 @@ def train_reader(file, question_type, size=None, upward_level=0):
                                        candidates,
                                        relation_info[question_key] if question_key in relation_info else "",
                                        label, question_id[question_key]))
-                count += 1
+                count_questions += 1
         dataset.append(batch_question)

-    # print("Original questions", count_original)
-    # print("Total questions", count)
-    # print(all_batch_dynamic_info)
+    print("Original questions", count_original)
+    print("Total questions", count_questions)
+    print(all_batch_dynamic_info)

     # Return Type need to be list of dict with name of variable as key
     return dataset


-def test_reader(file, question_type, size=None):
+def general_reader(file, question_type, size=None):
     with open(file) as json_file:
         data = json.load(json_file)
-    size = 300000 if not size else size
+    size = 10 ** 6 if not size else size

     dataset = []
     count = 0
     for story in data["data"]:
-        story_txt = story['story'][0]
+        story_txt = " ".join(story['story'])
         question_id = {}
         run_id = 0

@@ -228,6 +258,34 @@ def test_reader(file, question_type, size=None):
     return dataset


+def RESQ_reader(file, question_type, size=None, reasoning=None):
+    with open(file) as json_file:
+        data = json.load(json_file)
+    size = 300000 if not size else size
+
+    dataset = []
+    count = 0
+    for story in data["data"]:
+        story_txt = " ".join(story['story'])
+        run_id = 0
+        for question in story["questions"]:
+            if count >= size:
+                break
+            if reasoning is not None:
+                if reasoning == 0 and isinstance(question["step_of_reasoning"], int):
+                    continue
+                if reasoning != 0 and question["step_of_reasoning"] != reasoning:
+                    continue
+            question_txt = question["question"]
+            candidates = question['candidate_answers']
+            label = question["answer"][0] if question["answer"][0] != "DK" else "NO"
+            dataset.append([[question_txt, story_txt, "YN", candidates, "", label, run_id]])
+            run_id += 1
+            count += 1
+
+    return dataset
+
+
 def boolQ_reader(file, size=None):
     with open(file) as json_file:
         data = json.load(json_file)
@@ -260,10 +318,11 @@ def StepGame_reader(prefix, train_dev_test="train", size=None, file_number=None)
     else:
         files = ["qa" + str(file_number + 1) + "_test.json"]

     dataset = []
+    print(prefix, files)
     for file in files:
-        with open(prefix+ "/" + file) as json_file:
+        with open(prefix + "/" + file) as json_file:
             data = json.load(json_file)
         size = 300000 if not size else size
         run_id = 0
@@ -282,20 +340,40 @@ def StepGame_reader(prefix, train_dev_test="train", size=None, file_number=None)
     return dataset


-def DomiKnowS_reader(file, question_type, size=None, upward_level=0, augmented=True, boolQL=False, batch_size=8,
-                     rule=False, StepGame_status=None, StepGame_number=None):
-    dataset = StepGame_reader(file, StepGame_status, size, file_number=StepGame_number) if StepGame_status \
-        else train_reader(file, question_type, size, upward_level) if augmented \
-        else boolQ_reader(file, size) if boolQL else test_reader(file, question_type, size)
+def DomiKnowS_reader(file, question_type, size=300000, *,
+                     type_dataset=None,
+                     upward_level=0,
+                     augmented=True,
+                     batch_size=8,
+                     rule_text=False,
+                     reasoning_steps=None,
+                     STEPGAME_status="train"):
+    print(type_dataset, reasoning_steps)
+    if type_dataset == "STEPGAME":
+        dataset = StepGame_reader(file, STEPGAME_status, size, file_number=reasoning_steps)
+    elif type_dataset == "BOOLQ":
+        dataset = boolQ_reader(file, size)
+    elif type_dataset == "RESQ":
+        dataset = RESQ_reader(file, question_type, size, reasoning=reasoning_steps)
+    elif augmented:  # SPARTUN with chain-of-reasoning augmentation for training
+        dataset = train_reader(file, question_type, limit_questions=size, upward_level=upward_level)
+    else:
+        dataset = general_reader(file, question_type, size)
+
     additional_text = ""
-    if rule:
+    if rule_text:
         with open("DataSet/rules.txt", 'r') as rules:
             additional_text = rules.readline()
     return_dataset = []
     current_batch_size = 0
+    count_question = 0
     batch_data = {'questions': [], 'stories': [], 'relation': [], 'labels': [], "question_ids": []}
-    for batch in tqdm.tqdm(dataset, desc="Reading " + file + " " + (str(StepGame_status) if StepGame_status is not None else "")):
-        if current_batch_size + len(batch) > batch_size and current_batch_size != 0:
+    for batch in tqdm.tqdm(dataset, desc="Reading " + file + " " + (
+            str(STEPGAME_status) if STEPGAME_status is not None else "")):
+        count_question += len(batch)
+        # A batch must only hold questions from one story, so question ids do not mix across stories
+        check_same_story = current_batch_size != 0 and batch[0][1] == batch_data["stories"][0]
+        if current_batch_size != 0 and (current_batch_size + len(batch) > batch_size or not check_same_story):
             current_batch_size = 0
             return_dataset.append({"questions": "@@".join(batch_data['questions']),
                                    "stories": "@@".join(batch_data['stories']),
@@ -317,5 +395,14 @@
                                "relation": "@@".join(batch_data['relation']),
                                "question_ids": "@@".join(batch_data['question_ids']),
                                "labels": "@@".join(batch_data['labels'])})
-
+    print("Total questions:", count_question)
     return return_dataset
+
+
+if __name__ == "__main__":
+    # print(os.path.abspath("."))
+
+    # dataset = train_reader("DataSet/train_v3.json", "FR", limit_questions=300000, upward_level=100)
+    datasets = DomiKnowS_reader("DataSet/train_v3.json", "FR", upward_level=14, augmented=True, batch_size=8)
+    print(len(datasets))
+    print(datasets[0])
diff --git a/SpatialQARules/utils.py b/SpatialQARules/utils.py
index 7a36c10a..c3ec5cf5 100644
--- a/SpatialQARules/utils.py
+++ b/SpatialQARules/utils.py
@@ -6,12 +6,9 @@ def check_symmetric(arg1, arg2):
         return False
     relation_describe = relation_arg2.split(',')
     if relation_describe[0] == "symmetric":
-        story1 = arg1.getAttribute("story")
-        story2 = arg2.getAttribute("story")
-        if story1 == story2:
-            qid1 = arg1.getAttribute("id").item()
-            if qid1 == int(relation_describe[1]):
-                return True
+        qid1 = arg1.getAttribute("id").item()
+        if qid1 == int(relation_describe[1]):
+            return True
     return False


@@ -23,12 +20,9 @@ def check_reverse(arg10, arg20):
         return False
     relation_describe = relation_arg2.split(',')
     if relation_describe[0] == "reverse":
-        story1 = arg10.getAttribute("story")
-        story2 = arg20.getAttribute("story")
-        if story1 == story2:
-            qid1 = arg10.getAttribute("id").item()
-            if qid1 == int(relation_describe[1]):
-                return True
+        qid1 = arg10.getAttribute("id").item()
+        if qid1 == int(relation_describe[1]):
+            return True
     return False


@@ -40,14 +34,10 @@ def check_transitive(arg11, arg22, arg33):
         return False
     relation_describe = relation_arg3.split(',')
     if relation_describe[0] == "transitive":
-        story1 = arg11.getAttribute("story")
-        story2 = arg22.getAttribute("story")
-        story3 = arg33.getAttribute("story")
-        if story1 == story2 and story2 == story3:
-            qid1 = arg11.getAttribute("id").item()
-            qid2 = arg22.getAttribute("id").item()
-            if qid1 == int(relation_describe[1]) and qid2 == int(relation_describe[2]):
-                return True
+        qid1 = arg11.getAttribute("id").item()
+        qid2 = arg22.getAttribute("id").item()
+        if qid1 == int(relation_describe[1]) and qid2 == int(relation_describe[2]):
+            return True
     return False


@@ -59,14 +49,9 @@ def check_transitive_topo(arg111, arg222, arg333, arg444):
         return False
     relation_describe = relation_arg4.split(',')
     if relation_describe[0] == "transitive_topo":
-        story1 = arg111.getAttribute("story")
-        story2 = arg222.getAttribute("story")
-        story3 = arg333.getAttribute("story")
-        story4 = arg444.getAttribute("story")
-        if story1 == story2 and story2 == story3 and story3 == story4:
-            qid1 = arg111.getAttribute("id").item()
-            qid2 = arg222.getAttribute("id").item()
-            qid3 = arg333.getAttribute("id").item()
-            if qid1 == int(relation_describe[1]) and qid2 == int(relation_describe[2]) and qid3 == int(relation_describe[3]):
-                return True
+        qid1 = arg111.getAttribute("id").item()
+        qid2 = arg222.getAttribute("id").item()
+        qid3 = arg333.getAttribute("id").item()
+        if qid1 == int(relation_describe[1]) and qid2 == int(relation_describe[2]) and qid3 == int(relation_describe[3]):
+            return True
     return False
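
Note on the label encoding this patch relies on: an FR example's gold relations are packed into one integer whose i-th bit marks the i-th relation in the fixed 15-relation ordering. make_labels and make_text_labels walk that bit pattern with cur_bit *= 2, and label_fr_to_int sums the per-relation constants from LABELS_INT. Below is a minimal standalone round-trip sketch, assuming LABELS_INT assigns consecutive powers of two in the same order; ALL_LABELS, encode, and decode are illustrative names, not part of the repo.

    # Hypothetical round-trip check for the bit-flag label scheme used above.
    # ALL_LABELS mirrors the fixed relation ordering in program_declaration.
    ALL_LABELS = ["left", "right", "above", "below", "behind", "front",
                  "near", "far", "disconnected", "touch", "overlap", "covered by",
                  "inside", "cover", "contain"]

    def encode(relations):
        # mirrors label_fr_to_int: one power of two per gold relation
        return sum(1 << ALL_LABELS.index(rel) for rel in relations)

    def decode(bits):
        # mirrors make_text_labels: walk the bits in label order (cur_bit *= 2)
        return [rel for i, rel in enumerate(ALL_LABELS) if bits & (1 << i)]

    assert encode(["left", "near"]) == 65  # 2**0 + 2**6
    assert decode(65) == ["left", "near"]
    assert decode(encode([])) == []

If LABELS_INT ever deviated from consecutive powers of two in this order, the readers and the graph-side decoders would silently disagree, so a cheap assertion like this is worth keeping near the data pipeline.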