Skip to content

Commit

Permalink
test executors compile but tdf validation fails
Browse files Browse the repository at this point in the history
  • Loading branch information
sonalgoyal committed Oct 16, 2024
1 parent ab9c266 commit 164aa5c
Show file tree
Hide file tree
Showing 15 changed files with 82 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,14 @@ public Labeller<S, D, R, C, T> getLabeller(){
return this.labeller;
}

/**
 * Returns the training data finder currently held by this executor.
 *
 * @return the value of the {@code finder} field
 */
public TrainingDataFinder<S, D, R, C, T> getFinder() {
    return this.finder;
}

/**
 * Replaces the training data finder held by this executor.
 *
 * @param finder the finder to store in the {@code finder} field
 */
public void setFinder(TrainingDataFinder<S, D, R, C, T> finder) {
this.finder = finder;
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,23 @@ public void execute() throws ZinggClientException {
trainer.execute();
matcher.execute();
}

/**
 * Returns the trainer that {@code execute()} runs before the matcher.
 *
 * @return the value of the {@code trainer} field
 */
public Trainer<S, D, R, C, T> getTrainer() {
    return this.trainer;
}

/**
 * Replaces the trainer held by this executor.
 *
 * @param trainer the trainer to store in the {@code trainer} field
 */
public void setTrainer(Trainer<S, D, R, C, T> trainer) {
this.trainer = trainer;
}

/**
 * Returns the matcher that {@code execute()} runs after the trainer.
 *
 * @return the value of the {@code matcher} field
 */
public Matcher<S, D, R, C, T> getMatcher() {
    return this.matcher;
}

/**
 * Replaces the matcher held by this executor.
 *
 * @param matcher the matcher to store in the {@code matcher} field
 */
public void setMatcher(Matcher<S, D, R, C, T> matcher) {
this.matcher = matcher;
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ public abstract class ExecutorValidator<S, D, R, C, T> {

public static final Log LOG = LogFactory.getLog(ExecutorValidator.class);

public ZinggBase<S,D, R, C, T> executorObj;
public ZinggBase<S,D, R, C, T> executor;

public ExecutorValidator(ZinggBase<S, D, R, C, T> executorObj) {
this.executorObj = executorObj;
public ExecutorValidator(ZinggBase<S, D, R, C, T> executor) {
this.executor = executor;
}

public abstract void validateResults() throws ZinggClientException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ public class FindAndLabelValidator<S, D, R, C, T> extends ExecutorValidator<S, D
public TrainingDataFinderValidator<S, D, R, C, T> tdfv;
public LabellerValidator <S, D, R, C, T> lv;

public FindAndLabelValidator(FindAndLabeller<S, D, R, C, T> validator) {
super(validator);
public FindAndLabelValidator(FindAndLabeller<S, D, R, C, T> executor) {
super(executor);
this.tdfv = new TrainingDataFinderValidator<S, D, R, C, T>(executor.getFinder());
this.lv = new LabellerValidator <S, D, R, C, T>(executor.getLabeller());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ public class LabellerValidator<S, D, R, C, T> extends ExecutorValidator<S, D, R,

public static final Log LOG = LogFactory.getLog(LabellerValidator.class);

public LabellerValidator(Labeller<S, D, R, C, T> validator) {
super(validator);
public LabellerValidator(Labeller<S, D, R, C, T> executor) {
super(executor);
}

@Override
public void validateResults() throws ZinggClientException {
// check that marked data has at least 1 match row and 1 unmatch row
ZFrame<D, R, C> dfMarked = validator.getContext().getPipeUtil().
read(false, false, validator.getContext().getPipeUtil().getTrainingDataMarkedPipe(validator.getArgs()));
ZFrame<D, R, C> dfMarked = executor.getContext().getPipeUtil().
read(false, false, executor.getContext().getPipeUtil().getTrainingDataMarkedPipe(executor.getArgs()));

C matchCond = dfMarked.equalTo(ColName.MATCH_FLAG_COL, 1);
C notMatchCond = dfMarked.equalTo(ColName.MATCH_FLAG_COL, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class LinkerValidator<S, D, R, C, T> extends MatcherValidator<S, D, R, C,

public static final Log LOG = LogFactory.getLog(LinkerValidator.class);

public LinkerValidator(Matcher<S, D, R, C, T> validator) {
super(validator);
public LinkerValidator(Matcher<S, D, R, C, T> executor) {
super(executor);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class MatcherValidator<S, D, R, C, T> extends ExecutorValidator<S, D, R,

public static final Log LOG = LogFactory.getLog(MatcherValidator.class);

public MatcherValidator(Matcher<S, D, R, C, T> validator) {
super(validator);
public MatcherValidator(Matcher<S, D, R, C, T> executor) {
super(executor);
}

@Override
Expand Down Expand Up @@ -70,7 +70,7 @@ protected void testAccuracy(ZFrame<D, R, C> gold, ZFrame<D, R, C> result) throws


public ZFrame<D, R, C> getOutputData() throws ZinggClientException {
ZFrame<D, R, C> output = validator.getContext().getPipeUtil().read(false, false, validator.getArgs().getOutput()[0]);
ZFrame<D, R, C> output = executor.getContext().getPipeUtil().read(false, false, executor.getArgs().getOutput()[0]);
return output;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ public TestExecutorsCompound() {

@Override
public List<ExecutorTester<S, D, R, C, T>> getExecutors() throws ZinggClientException{
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getFindAndLabeller(), new FindAndLabelValidator<S, D, R, C, T>(getFindAndLabeller())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getFindAndLabeller(), new FindAndLabelValidator<S, D, R, C, T>(getFindAndLabeller())));
FindAndLabeller<S, D, R, C, T> findAndLabel = getFindAndLabeller();
FindAndLabelValidator<S, D, R, C, T> falValidator = new FindAndLabelValidator<S, D, R, C, T>(findAndLabel);
ExecutorTester<S, D, R, C, T> et = new ExecutorTester<S, D, R, C, T>(findAndLabel, falValidator);
executorTesterList.add(et);
executorTesterList.add(et);
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getTrainMatcher(),getTrainMatchValidator(getTrainMatcher())));
return executorTesterList;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ public List<ExecutorTester<S, D, R, C, T>> getExecutors() throws ZinggClientExce
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getLabeller(), new LabellerValidator<S, D, R, C, T>(getLabeller())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getTrainingDataFinder(), new TrainingDataFinderValidator<S, D, R, C, T>(getTrainingDataFinder())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getLabeller(), new LabellerValidator<S, D, R, C, T>(getLabeller())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getTrainingDataFinder(), new TrainingDataFinderValidator<S, D, R, C, T>(getTrainingDataFinder())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getLabeller(), new LabellerValidator<S, D, R, C, T>(getLabeller())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getTrainer(),getTrainerValidator(getTrainer())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getMatcher(),new MatcherValidator(getMatcher())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getLinker(),new LinkerValidator(getLinker())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getMatcher(),new MatcherValidator<S, D, R, C, T>(getMatcher())));
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(getLinker(),new LinkerValidator<S, D, R, C, T>(getLinker())));
return executorTesterList;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ public class TrainMatchValidator<S, D, R, C, T> extends ExecutorValidator<S, D,
TrainerValidator<S, D, R, C, T> tv;
MatcherValidator<S, D, R, C, T> mv;

public TrainMatchValidator(TrainMatcher<S, D, R, C, T> validator, IArguments args) {
super(validator);
public TrainMatchValidator(TrainMatcher<S, D, R, C, T> executor, IArguments args) {
super(executor);
this.args = args;
tv = new TrainerValidator<S, D, R, C, T>(executor.getTrainer(), args);
mv = new MatcherValidator<S, D, R, C, T>(executor.getMatcher());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@
import org.apache.commons.logging.LogFactory;

import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;

public abstract class TrainerValidator<S, D, R, C, T> extends ExecutorValidator<S, D, R, C, T> {
public class TrainerValidator<S, D, R, C, T> extends ExecutorValidator<S, D, R, C, T> {

public static final Log LOG = LogFactory.getLog(TrainerValidator.class);

protected IArguments args;

public TrainerValidator(Trainer<S, D, R, C, T> validator,IArguments args) {
super(validator);
public TrainerValidator(Trainer<S, D, R, C, T> executor,IArguments args) {
super(executor);
this.args = args;
}

/**
 * Validates the outcome of a trainer run.
 * The body is currently empty, so no checks are performed.
 *
 * @throws ZinggClientException declared by the overridden method; this implementation does not throw it
 */
@Override
public void validateResults() throws ZinggClientException {
// Intentionally a no-op for now.
// TODO - add model existence checks
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ public class TrainingDataFinderValidator<S, D, R, C, T> extends ExecutorValidato

public static final Log LOG = LogFactory.getLog(TrainingDataFinderValidator.class);

public TrainingDataFinderValidator(TrainingDataFinder<S, D, R, C, T> validator) {
super(validator);
public TrainingDataFinderValidator(TrainingDataFinder<S, D, R, C, T> executor) {
super(executor);
}

@Override
public void validateResults() throws ZinggClientException {
// check that unmarked data has more than 10 rows
ZFrame<D, R, C> df = validator.getContext().getPipeUtil().read(false, false, validator.getContext().getPipeUtil().getTrainingDataUnmarkedPipe(validator.getArgs()));
ZFrame<D, R, C> df = executor.getContext().getPipeUtil().read(false, false, executor.getContext().getPipeUtil().getTrainingDataUnmarkedPipe(executor.getArgs()));

long trainingDataCount = df.count();
assertTrue(trainingDataCount > 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,46 +35,12 @@ public class ZinggSparkContext extends Context<SparkSession, Dataset<Row>, Row,C
public static final Log LOG = LogFactory.getLog(ZinggSparkContext.class);


<<<<<<< HEAD
public void initSessionAndContext(SparkSession session)
throws ZinggClientException {
try{
// if (session==null) {
// session = SparkSession
// .builder()
// .appName("Zingg")
// .getOrCreate();
//
// //session = new SparkSession(spark, license);
// }
this.session = session;
if (ctx==null) {
ctx = JavaSparkContext.fromSparkContext(session.sparkContext());
JavaSparkContext.jarOfClass(IZingg.class);
LOG.debug("Context " + ctx.toString());
//initHashFns();
ctx.setCheckpointDir("/tmp/checkpoint");
}
}
catch(Throwable e) {
if (LOG.isDebugEnabled()) e.printStackTrace();
throw new ZinggClientException(e.getMessage());
}
=======
@Override
public void init(SparkSession session)
throws ZinggClientException {
this.session = session;
setUtils();

>>>>>>> 622a907d (init changes to spark)
}

@Override
public void init(SparkSession session)
throws ZinggClientException {
initSessionAndContext(session);
setUtils();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.junit.jupiter.api.AfterEach;

import zingg.common.client.IZingg;
import zingg.common.client.ZinggClientException;
import zingg.common.core.executor.Labeller;
import zingg.common.core.executor.TestExecutorsCompound;
Expand All @@ -21,11 +23,11 @@
import zingg.common.core.executor.TrainerValidator;
import zingg.spark.core.context.ZinggSparkContext;

public class TestSparkExecutors extends TestExecutorsCompound<SparkSession,Dataset<Row>,Row,Column,DataType> {
public class TestSparkExecutors extends TestExecutorsGeneric<SparkSession,Dataset<Row>,Row,Column,DataType> {
protected static final String CONFIG_FILE = "zingg/spark/core/executor/configSparkIntTest.json";
protected static final String TEST_DATA_FILE = "zingg/spark/core/executor/test.csv";

protected static final String CONFIGLINK_FILE = "ingg/spark/core/executor/configSparkLinkTest.json";
protected static final String CONFIGLINK_FILE = "zingg/spark/core/executor/configSparkLinkTest.json";
protected static final String TEST1_DATA_FILE = "zingg/spark/core/executor/test1.csv";
protected static final String TEST2_DATA_FILE = "zingg/spark/core/executor/test2.csv";

Expand All @@ -39,6 +41,11 @@ public TestSparkExecutors() throws IOException, ZinggClientException {
.master("local[*]")
.appName("Zingg" + "Junit")
.getOrCreate();

JavaSparkContext ctx1 = new JavaSparkContext(spark.sparkContext());
JavaSparkContext.jarOfClass(IZingg.class);
ctx1.setCheckpointDir("/tmp/checkpoint");

this.ctx = new ZinggSparkContext();
this.ctx.setSession(spark);
this.ctx.setUtils();
Expand Down Expand Up @@ -84,6 +91,7 @@ protected SparkLinker getLinker() throws ZinggClientException {
protected SparkTrainerTester getTrainerValidator(Trainer<SparkSession,Dataset<Row>,Row,Column,DataType> trainer) {
    // Pair the supplied trainer with this test's arguments in a Spark-specific tester.
    SparkTrainerTester trainerTester = new SparkTrainerTester(trainer, args);
    return trainerTester;
}
/*
@Override
protected SparkFindAndLabeller getFindAndLabeller() throws ZinggClientException {
Expand All @@ -102,6 +110,7 @@ protected SparkTrainMatcher getTrainMatcher() throws ZinggClientException {
protected SparkTrainMatchTester getTrainMatchValidator(TrainMatcher<SparkSession,Dataset<Row>,Row,Column,DataType> trainMatch) {
return new SparkTrainMatchTester(trainMatch,args);
}
*/

@Override
public String setupArgs() throws ZinggClientException, IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public static void setup() {
.getOrCreate();
ctx = new JavaSparkContext(spark.sparkContext());
JavaSparkContext.jarOfClass(IZingg.class);
ctx.setCheckpointDir("/tmp/checkpoint");
args = new Arguments();
zsCTX = new ZinggSparkContext();
zsCTX.init(spark);
Expand Down

0 comments on commit 164aa5c

Please sign in to comment.