diff --git a/.DS_Store b/.DS_Store index 0d364ad20..a52ee2cd7 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 9a7d5edcf..dafebe945 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -13,11 +13,15 @@ name: "CodeQL" on: push: - + branches: [ main, 0.4.0 ] pull_request: + # The branches below must be a subset of the branches above + branches: [ main, 0.4.0 ] paths-ignore: - '**/*.md' - '**/*.txt' + schedule: + - cron: '22 3 * * 5' jobs: analyze: diff --git a/.gitignore b/.gitignore index 18b75246c..e7f285da9 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,30 @@ python/docs/_build/_doctrees **/python/build/* **/assembly/.classpath **/.DS_Store + +# Python stuff +.env +.venv + +# Sphinx _build +**/_build + +# Helix stuff +.helix + +# Emacs stuff +.dir-locals.el + +# JDTLS stuff +.package +.classpath +.project +.settings +.factorypath + +# Metals LSP +.metals +.bloop + +# Hadoop & Spark binaries +spark-* \ No newline at end of file diff --git a/assembly/dependency-reduced-pom.xml b/assembly/dependency-reduced-pom.xml index 2e7dabf83..041aa6573 100644 --- a/assembly/dependency-reduced-pom.xml +++ b/assembly/dependency-reduced-pom.xml @@ -45,7 +45,7 @@ maven-assembly-plugin - 2.4.1 + 3.6.0 make-assembly diff --git a/assembly/pom.xml b/assembly/pom.xml index b1282eda8..7027c77a8 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -112,7 +112,7 @@ maven-assembly-plugin - 2.4.1 + 3.6.0 ${project.basedir}/src/assembly/dist.xml diff --git a/buf.gen.yaml b/buf.gen.yaml new file mode 100644 index 000000000..3265e6a2c --- /dev/null +++ b/buf.gen.yaml @@ -0,0 +1,13 @@ +version: v1 +plugins: + # Building the Java classes + - plugin: buf.build/protocolbuffers/java:v25.3 + out: spark/client/src/main/java + # Building the Python classes and the mypy interfaces.
+ - plugin: buf.build/protocolbuffers/python:v25.3 + out: python/zingg_v2/proto + - plugin: buf.build/grpc/python:v1.62.0 + out: python/zingg_v2/proto + - plugin: buf.build/community/nipunn1313-mypy:v3.5.0 + out: python/zingg_v2/proto + diff --git a/buf.work.yaml b/buf.work.yaml new file mode 100644 index 000000000..540a3936d --- /dev/null +++ b/buf.work.yaml @@ -0,0 +1,3 @@ +version: v1 +directories: + - protobuf diff --git a/common/client/src/main/java/zingg/common/client/Arguments.java b/common/client/src/main/java/zingg/common/client/Arguments.java index a47323eef..4ec0bda44 100644 --- a/common/client/src/main/java/zingg/common/client/Arguments.java +++ b/common/client/src/main/java/zingg/common/client/Arguments.java @@ -4,6 +4,7 @@ import java.io.Serializable; import java.io.StringWriter; import java.util.List; +import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -16,6 +17,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import zingg.common.client.pipe.Pipe; +import zingg.common.client.util.JsonStringify; /** @@ -169,7 +171,7 @@ public void setLabelDataSampleSize(float labelDataSampleSize) throws ZinggClient public List getFieldDefinition() { return fieldDefinition; } - + /** * Set the field definitions consisting of match field indices, types and * classes @@ -308,18 +310,7 @@ public void checkNullBlankEmpty(Pipe[] field, String fieldName) throws ZinggClie @Override public String toString() { - ObjectMapper mapper = new ObjectMapper(); - mapper.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, - true); - //mapper.configure(JsonParser.Feature.FAIL_ON_EMPTY_BEANS, true) - try { - StringWriter writer = new StringWriter(); - return mapper.writeValueAsString(this); - } catch (IOException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - return null; - } + return JsonStringify.toString(this); } /** diff --git a/common/client/src/main/java/zingg/common/client/ArgumentsUtil.java b/common/client/src/main/java/zingg/common/client/ArgumentsUtil.java index df3a7c3c5..97a55c4f2 100644 --- a/common/client/src/main/java/zingg/common/client/ArgumentsUtil.java +++ b/common/client/src/main/java/zingg/common/client/ArgumentsUtil.java @@ -19,7 +19,7 @@ public class ArgumentsUtil { - protected Class argsClass; + protected Class argsClass; private static final String ENV_VAR_MARKER_START = "$"; private static final String ENV_VAR_MARKER_END = "$"; private static final String ESC = "\\"; @@ -31,7 +31,7 @@ public ArgumentsUtil() { this(Arguments.class); } - public ArgumentsUtil( Class argsClass) { + public ArgumentsUtil( Class argsClass) { this.argsClass = argsClass; } diff --git a/common/client/src/main/java/zingg/common/client/Client.java b/common/client/src/main/java/zingg/common/client/Client.java index 8610e2089..88a85ef78 100644 --- a/common/client/src/main/java/zingg/common/client/Client.java +++ b/common/client/src/main/java/zingg/common/client/Client.java @@ -5,9 +5,17 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.event.events.IEvent; +import zingg.common.client.event.events.ZinggStartEvent; +import zingg.common.client.event.events.ZinggStopEvent; +import zingg.common.client.event.listeners.EventsListener; +import zingg.common.client.event.listeners.IEventListener; +import zingg.common.client.event.listeners.ZinggStartListener; +import 
zingg.common.client.event.listeners.ZinggStopListener; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.Email; import zingg.common.client.util.EmailBody; +import zingg.common.client.util.PipeUtilBase; /** * This is the main point of interface with the Zingg matching product. @@ -22,9 +30,11 @@ public abstract class Client implements Serializable { protected IZingg zingg; protected ClientOptions options; protected S session; - + protected PipeUtilBase pipeUtil; public static final Log LOG = LogFactory.getLog(Client.class); + protected String zFactoryClassName; + /** * Construct a client to Zingg using provided arguments and spark master. @@ -36,10 +46,14 @@ public abstract class Client implements Serializable { * if issue connecting to master */ - public Client() {} + public Client(String zFactory) { + setZFactoryClassName(zFactory); + } - public Client(IArguments args, ClientOptions options) throws ZinggClientException { - setOptions(options); + public Client(IArguments args, ClientOptions options, String zFactory) throws ZinggClientException { + setZFactoryClassName(zFactory); + this.options = options; + setOptions(options); try { buildAndSetArguments(args, options); printAnalyticsBanner(arguments.getCollectMetrics()); @@ -51,14 +65,28 @@ public Client(IArguments args, ClientOptions options) throws ZinggClientExceptio } } - public Client(IArguments args, ClientOptions options, S s) throws ZinggClientException { - this(args, options); + + public String getZFactoryClassName() { + return zFactoryClassName; + } + + public void setZFactoryClassName(String s) { + this.zFactoryClassName = s; + } + + public Client(IArguments args, ClientOptions options, S s, String zFactory) throws ZinggClientException { + this(args, options, zFactory); this.session = s; LOG.debug("Session passed is " + s); if (session != null) zingg.setSession(session); } - public abstract IZinggFactory getZinggFactory() throws Exception;//(IZinggFactory) Class.forName("zingg.ZFactory").newInstance(); + + public IZinggFactory getZinggFactory() throws InstantiationException, IllegalAccessException, ClassNotFoundException{ + LOG.debug("z factory is " + getZFactoryClassName()); + return (IZinggFactory) Class.forName(getZFactoryClassName()).newInstance(); + } + @@ -70,9 +98,10 @@ public void setZingg(IArguments args, ClientOptions options) throws Exception{ catch(Exception e) { e.printStackTrace(); //set default - setZingg(zf.get(ZinggOptions.getByValue(ZinggOptions.PEEK_MODEL.getValue()))); + setZingg(zf.get(ZinggOptions.getByValue(ZinggOptions.PEEK_MODEL.getName()))); } } + public void setZingg(IZingg zingg) { this.zingg = zingg; @@ -120,7 +149,7 @@ else if (args.getJobId() != -1) { } public void printBanner() { - String versionStr = "0.4.0"; + String versionStr = "0.4.1-SNAPSHOT"; LOG.info(""); LOG.info("********************************************************"); LOG.info("* Zingg AI *"); @@ -155,23 +184,27 @@ public void printAnalyticsBanner(boolean collectMetrics) { } public abstract Client getClient(IArguments args, ClientOptions options) throws ZinggClientException; + + public ClientOptions getClientOptions(String ... args){ + return new ClientOptions(args); + } public void mainMethod(String... 
args) { printBanner(); Client client = null; ClientOptions options = null; try { + for (String a: args) LOG.debug("args " + a); - options = new ClientOptions(args); + options = getClientOptions(args); setOptions(options); - + if (options.has(options.HELP) || options.has(options.HELP1) || options.get(ClientOptions.PHASE) == null) { LOG.warn(options.getHelp()); System.exit(0); } String phase = options.get(ClientOptions.PHASE).value.trim(); ZinggOptions.verifyPhase(phase); - IArguments arguments = null; if (options.get(ClientOptions.CONF).value.endsWith("json")) { arguments = getArgsUtil().createArgumentsFromJSON(options.get(ClientOptions.CONF).value, phase); } @@ -184,6 +217,7 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) { client = getClient(arguments, options); client.init(); + // after setting arguments etc. as some of the listeners need it client.execute(); client.postMetrics(); LOG.warn("Zingg processing has completed"); @@ -212,6 +246,7 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) { } finally { try { + EventsListener.getInstance().fireEvent(new ZinggStopEvent()); if (client != null) { //client.postMetrics(); client.stop(); @@ -228,14 +263,12 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) { } public void init() throws ZinggClientException { - zingg.setClientOptions(options); - zingg.init(getArguments(), getLicense(options.get(ClientOptions.LICENSE).value.trim())); + zingg.init(getArguments(), getSession()); if (session != null) zingg.setSession(session); - + initializeListeners(); + EventsListener.getInstance().fireEvent(new ZinggStartEvent()); } - protected abstract IZinggLicense getLicense(String license) throws ZinggClientException ; - /** * Stop the Spark job running context */ @@ -305,5 +338,26 @@ protected ArgumentsUtil getArgsUtil() { } return argsUtil; } + + public void addListener(Class eventClass, IEventListener listener) { + EventsListener.getInstance().addListener(eventClass, listener); + } + + public void initializeListeners() { + addListener(ZinggStartEvent.class, new ZinggStartListener()); + addListener(ZinggStopEvent.class, new ZinggStopListener()); + } + + public abstract S getSession(); + + public void setSession(S s) { + this.session = s; + } + + public abstract PipeUtilBase getPipeUtil(); + + public void setPipeUtil(PipeUtilBase pipeUtil) { + this.pipeUtil = pipeUtil; + } } \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/ClientOptions.java b/common/client/src/main/java/zingg/common/client/ClientOptions.java index 8fb8dbf8f..cb1aa0929 100644 --- a/common/client/src/main/java/zingg/common/client/ClientOptions.java +++ b/common/client/src/main/java/zingg/common/client/ClientOptions.java @@ -12,6 +12,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.Util; public class ClientOptions { @@ -42,7 +43,7 @@ public class ClientOptions { protected String[] commandLineArgs; - protected static Map optionMaster = new HashMap(); + protected Map optionMaster = new HashMap(); /* * String optionName; //String alias; @@ -51,7 +52,7 @@ public class ClientOptions { boolean isExit; boolean isMandatory; */ - static { //This is the canonical list of Zingg options. + protected void loadOptions() { //This is the canonical list of Zingg options. 
optionMaster.put(CONF, new Option(CONF, true, "JSON configuration with data input output locations and field definitions", false, true)); optionMaster.put(PHASE, new Option(PHASE, true, Util.join(ZinggOptions.getAllZinggOptions(), "|"), false, true, ZinggOptions.getAllZinggOptions())); optionMaster.put(LICENSE, new Option(LICENSE, true, "location of license file", false, true)); @@ -75,13 +76,19 @@ public class ClientOptions { } protected Map options = new HashMap (); + + public ClientOptions(){ + loadOptions(); + } public ClientOptions(String... args) { + this(); this.commandLineArgs = args; parse(Arrays.asList(args)); } public ClientOptions(List args) { + this(); this.commandLineArgs = args.toArray(new String[args.size()]); parse(args); } @@ -89,8 +96,15 @@ public ClientOptions(List args) { public String[] getCommandLineArgs() { return this.commandLineArgs; } + + public Map getOptionMaster(){ + return optionMaster; + } - + public void setOptionMaster(Map optionMaster) { + this.optionMaster = optionMaster; + } + /** * Parse a list of Zingg command line options. *

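Note on the ClientOptions refactor above: the canonical option table moves from a static initializer into the overridable instance method loadOptions(), which every constructor now reaches via this(). Together with the new getOptionMaster()/setOptionMaster() accessors, this opens an extension point for custom options. A minimal sketch of such a subclass follows — the class and flag names are hypothetical, and the Option constructor arguments (name, expectsValue, description, isExit, isMandatory) are inferred from the puts above:

```java
package zingg.common.client;

// Hypothetical subclass adding one custom flag on top of the canonical Zingg options.
public class AuditingClientOptions extends ClientOptions {

    public static final String AUDIT_LOG = "--auditLog"; // assumed flag name

    public AuditingClientOptions(String... args) {
        super(args); // ClientOptions(String...) calls this(), which calls loadOptions()
    }

    @Override
    protected void loadOptions() {
        super.loadOptions(); // keep the canonical option table
        optionMaster.put(AUDIT_LOG,
                new Option(AUDIT_LOG, true, "location of the audit log file", false, false));
    }
}
```

Such a subclass pairs with the getClientOptions(String...) hook added to Client above, which a Client subclass can override to return its own options type.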
@@ -249,12 +263,13 @@ public final static String getHelp() { s.append("options\n"); int maxlo = 0; - for (Option o: optionMaster.values()){ + ClientOptions co = new ClientOptions(); + for (Option o: co.optionMaster.values()){ maxlo=Math.max(maxlo,o.optionName.length()); } int maxld = 0; - for (Option o: optionMaster.values()){ + for (Option o: co.optionMaster.values()){ maxld=Math.max(maxld,o.desc.length()); } @@ -262,7 +277,7 @@ public final static String getHelp() { formatBuilder.append("\t").append("%-").append(maxlo + 5).append("s").append(": ").append("%-").append(maxld + 5).append("s").append("\n"); String format = formatBuilder.toString(); - for (Option o: optionMaster.values()) { + for (Option o: co.optionMaster.values()) { s.append(String.format(format,o.optionName, o.desc)); } return s.toString(); @@ -284,9 +299,13 @@ public String getOptionValue(String a) { return get(a).getValue(); //throw new IllegalArgumentException("Wrong argument"); } - - - - + /** A helper that allows to modify ClientOptions by changing values */ + public void setOptionValue(String key, String value) { + if (has(key)) { + OptionWithVal optionWithVal = get(key); + optionWithVal.setValue(value); + options.put(key, optionWithVal); + } + } } diff --git a/common/client/src/main/java/zingg/common/client/FieldDefUtil.java b/common/client/src/main/java/zingg/common/client/FieldDefUtil.java new file mode 100644 index 000000000..c8b06a55f --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/FieldDefUtil.java @@ -0,0 +1,30 @@ +package zingg.common.client; + +import java.io.Serializable; +import java.util.List; +import java.util.stream.Collectors; + +/** + * + * Util methods related to FieldDefinition objects + * + */ +public class FieldDefUtil implements Serializable{ + + private static final long serialVersionUID = 1L; + + public List getFieldDefinitionDontUse(List fieldDefinition) { + return fieldDefinition.stream() + .filter(x->x.matchType.contains(MatchType.DONT_USE)) + .collect(Collectors.toList()); + } + + public List getFieldDefinitionToUse(List fieldDefinition) { + return fieldDefinition.stream() + .filter(x->!x.matchType.contains(MatchType.DONT_USE)) + .collect(Collectors.toList()); + } + + + +} diff --git a/common/client/src/main/java/zingg/common/client/FieldDefinition.java b/common/client/src/main/java/zingg/common/client/FieldDefinition.java index 314c6d868..676829a88 100644 --- a/common/client/src/main/java/zingg/common/client/FieldDefinition.java +++ b/common/client/src/main/java/zingg/common/client/FieldDefinition.java @@ -10,6 +10,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; @@ -22,6 +23,8 @@ import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import zingg.common.client.cols.Named; + /** * This class defines each field that we use in matching We can use this to @@ -30,7 +33,7 @@ * @author sgoyal * */ -public class FieldDefinition implements +public class FieldDefinition implements Named, Serializable { private static final long serialVersionUID = 1L; @@ -119,6 +122,21 @@ public void setFieldName(String fieldName) { this.fieldName = fieldName; } + @JsonIgnore + public boolean isDontUse() { + return (matchType != null && matchType.contains(MatchType.DONT_USE)); + } + 
+ @Override + public String getName() { + return getFieldName(); + } + + @Override + public void setName(String name) { + setFieldName(name); + } + @Override public int hashCode() { final int prime = 31; diff --git a/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java b/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java index 89e2ae44f..6385bc7f0 100644 --- a/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java +++ b/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java @@ -8,7 +8,7 @@ public interface ILabelDataViewHelper { List getClusterIds(ZFrame lines); - List getDisplayColumns(ZFrame lines, IArguments args); +// List getDisplayColumns(ZFrame lines, IArguments args); ZFrame getCurrentPair(ZFrame lines, int index, List clusterIds, ZFrame clusterLines); diff --git a/common/client/src/main/java/zingg/common/client/IZingg.java b/common/client/src/main/java/zingg/common/client/IZingg.java index 5e77a04db..61bd8133e 100644 --- a/common/client/src/main/java/zingg/common/client/IZingg.java +++ b/common/client/src/main/java/zingg/common/client/IZingg.java @@ -1,17 +1,15 @@ package zingg.common.client; -import zingg.common.client.license.IZinggLicense; - public interface IZingg { - public void init(IArguments args, IZinggLicense license) + public void init(IArguments args, S session) throws ZinggClientException; public void execute() throws ZinggClientException; public void cleanup() throws ZinggClientException; - public ZinggOptions getZinggOptions(); + //public ZinggOptions getZinggOptions(); public String getName(); diff --git a/common/client/src/main/java/zingg/common/client/IZinggFactory.java b/common/client/src/main/java/zingg/common/client/IZinggFactory.java index 427cbf35d..02a4b8d9c 100644 --- a/common/client/src/main/java/zingg/common/client/IZinggFactory.java +++ b/common/client/src/main/java/zingg/common/client/IZinggFactory.java @@ -1,9 +1,9 @@ package zingg.common.client; -import zingg.common.client.IZingg; +import zingg.common.client.options.ZinggOption; public interface IZinggFactory { - public IZingg get(ZinggOptions z) throws InstantiationException, IllegalAccessException, ClassNotFoundException; + public IZingg get(ZinggOption z) throws InstantiationException, IllegalAccessException, ClassNotFoundException; } diff --git a/common/client/src/main/java/zingg/common/client/Samples.java b/common/client/src/main/java/zingg/common/client/Samples.java index 1a74c3874..c93fa249a 100644 --- a/common/client/src/main/java/zingg/common/client/Samples.java +++ b/common/client/src/main/java/zingg/common/client/Samples.java @@ -1,3 +1,7 @@ + + + + package zingg.common.client; import java.io.Serializable; diff --git a/common/client/src/main/java/zingg/common/client/ZFrame.java b/common/client/src/main/java/zingg/common/client/ZFrame.java index 1a0861917..b07a264c0 100644 --- a/common/client/src/main/java/zingg/common/client/ZFrame.java +++ b/common/client/src/main/java/zingg/common/client/ZFrame.java @@ -20,7 +20,7 @@ public interface ZFrame { public ZFrame selectExpr(String... 
col); public ZFrame distinct(); public List collectAsList(); - public List collectAsListOfStrings(); + public List collectFirstColumn(); public ZFrame toDF(String[] cols); public ZFrame toDF(String col1, String col2); @@ -81,6 +81,9 @@ public interface ZFrame { public ZFrame repartition(int num); public ZFrame repartition(int num, C c); + public ZFrame repartition(int num,scala.collection.Seq partitionExprs); + public ZFrame repartition(scala.collection.Seq partitionExprs); + public ZFrame sample(boolean repartition, float num); public ZFrame sample(boolean repartition, double num); @@ -170,5 +173,10 @@ public interface ZFrame { public ZFrame groupByCount(String groupByCol1, String groupByCol2, String countColName); - + public ZFrame intersect(ZFrame other); + + public C substr(C col, int startPos, int len); + + public C gt(C column1, C column2); + } diff --git a/common/client/src/main/java/zingg/common/client/ZSession.java b/common/client/src/main/java/zingg/common/client/ZSession.java deleted file mode 100644 index 1b778bad7..000000000 --- a/common/client/src/main/java/zingg/common/client/ZSession.java +++ /dev/null @@ -1,16 +0,0 @@ -package zingg.common.client; - -import zingg.common.client.license.IZinggLicense; - -public interface ZSession { - - public S getSession(); - - public void setSession(S session); - - public IZinggLicense getLicense(); - - public void setLicense(IZinggLicense license); - - -} diff --git a/common/client/src/main/java/zingg/common/client/ZinggOptions.java b/common/client/src/main/java/zingg/common/client/ZinggOptions.java deleted file mode 100644 index 8c3d32173..000000000 --- a/common/client/src/main/java/zingg/common/client/ZinggOptions.java +++ /dev/null @@ -1,57 +0,0 @@ -package zingg.common.client; - -import zingg.common.client.util.Util; - -public enum ZinggOptions { - - TRAIN("train"), - MATCH("match"), - TRAIN_MATCH("trainMatch"), - FIND_TRAINING_DATA("findTrainingData"), - LABEL("label"), - LINK("link"), - GENERATE_DOCS("generateDocs"), - RECOMMEND("recommend"), - UPDATE_LABEL("updateLabel"), - FIND_AND_LABEL("findAndLabel"), - ASSESS_MODEL("assessModel"), - PEEK_MODEL("peekModel"), - EXPORT_MODEL("exportModel"), - APPROVE_CLUSTERS("approveClusters"), - RUN_INCREMENTAL("runIncremental"); - - private String value; - - ZinggOptions(String s) { - this.value = s; - } - - public static String[] getAllZinggOptions() { - ZinggOptions[] zo = ZinggOptions.values(); - int i = 0; - String[] s = new String[zo.length]; - for (ZinggOptions z: zo) { - s[i++] = z.getValue(); - } - return s; - } - - public String getValue() { - return value; - } - - public static final ZinggOptions getByValue(String value){ - for (ZinggOptions zo: ZinggOptions.values()) { - if (zo.value.equals(value)) return zo; - } - return null; - } - - public static void verifyPhase(String phase) throws ZinggClientException { - if (getByValue(phase) == null) { - String message = "'" + phase + "' is not a valid phase. 
" - + "Valid phases are: " + Util.join(getAllZinggOptions(), "|"); - throw new ZinggClientException(message); - } - } -} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java b/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java new file mode 100644 index 000000000..af5f615a0 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java @@ -0,0 +1,44 @@ +package zingg.common.client.cols; + +import java.util.ArrayList; +import java.util.List; + +import zingg.common.client.FieldDefinition; +import zingg.common.client.MatchType; + +public class FieldDefSelectedCols extends SelectedCols { + + protected FieldDefSelectedCols() { + + } + + public FieldDefSelectedCols(List fieldDefs, boolean showConcise) { + List colList = getColList(fieldDefs, showConcise); + setCols(colList); + } + + protected List getColList(List fieldDefs) { + return getColList(fieldDefs,false); + } + + protected List getColList(List fieldDefs, boolean showConcise) { + List namedList = new ArrayList(); + + for (FieldDefinition fieldDef : fieldDefs) { + if (showConcise && fieldDef.isDontUse()) { + continue; + } + namedList.add(fieldDef); + } + List stringList = convertNamedListToStringList(namedList); + return stringList; + } + + protected List convertNamedListToStringList(List namedList) { + List stringList = new ArrayList(); + for (FieldDefinition named : namedList) { + stringList.add(named.getName()); + } + return stringList; + } +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/cols/ISelectedCols.java b/common/client/src/main/java/zingg/common/client/cols/ISelectedCols.java new file mode 100644 index 000000000..1d48fc945 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/cols/ISelectedCols.java @@ -0,0 +1,16 @@ +package zingg.common.client.cols; + +import java.util.List; + +public interface ISelectedCols { + + String[] getCols(List n); + + String[] getCols(); + + void setCols(List cols); + + void setNamedCols(List n); + + void setStringCols(List cols); +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/cols/Named.java b/common/client/src/main/java/zingg/common/client/cols/Named.java new file mode 100644 index 000000000..1fbe2a0a6 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/cols/Named.java @@ -0,0 +1,8 @@ +package zingg.common.client.cols; + +public interface Named { + + String getName(); + + void setName(String name); +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/cols/PredictionColsSelector.java b/common/client/src/main/java/zingg/common/client/cols/PredictionColsSelector.java new file mode 100644 index 000000000..71baf980c --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/cols/PredictionColsSelector.java @@ -0,0 +1,23 @@ +package zingg.common.client.cols; + +import java.util.ArrayList; +import java.util.List; + +import zingg.common.client.util.ColName; + +public class PredictionColsSelector extends SelectedCols { + + public PredictionColsSelector() { + + List cols = new ArrayList(); + cols.add(ColName.ID_COL); + cols.add(ColName.COL_PREFIX + ColName.ID_COL); + cols.add(ColName.PREDICTION_COL); + cols.add(ColName.SCORE_COL); + + setCols(cols); + + } + + +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/cols/SelectedCols.java 
b/common/client/src/main/java/zingg/common/client/cols/SelectedCols.java new file mode 100644 index 000000000..106afa534 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/cols/SelectedCols.java @@ -0,0 +1,37 @@ +package zingg.common.client.cols; + +import java.util.List; + +public class SelectedCols implements ISelectedCols { + + private String[] cols; + + @Override + public String[] getCols(List n) { + String[] result = new String[n.size()]; + for (int i = 0; i < n.size(); i++) { + result[i] = n.get(i).getName(); + } + return result; + } + + @Override + public String[] getCols() { + return cols; + } + + @Override + public void setCols(List strings) { + this.cols = strings.toArray(new String[0]); + } + + @Override + public void setNamedCols(List n) { + this.cols = getCols(n); + } + + @Override + public void setStringCols(List columnNames) { + this.cols = columnNames.toArray(new String[0]); + } +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java b/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java new file mode 100644 index 000000000..62f5aac70 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java @@ -0,0 +1,24 @@ +package zingg.common.client.cols; + +import java.util.List; + +import zingg.common.client.FieldDefinition; +import zingg.common.client.util.ColName; + +public class ZidAndFieldDefSelector extends FieldDefSelectedCols { + + public ZidAndFieldDefSelector(List fieldDefs) { + this(fieldDefs, true, false); + } + + public ZidAndFieldDefSelector(List fieldDefs, boolean includeZid, boolean showConcise) { + List colList = getColList(fieldDefs, showConcise); + + if (includeZid) colList.add(0, ColName.ID_COL); + + colList.add(ColName.SOURCE_COL); + + setCols(colList); + } + +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/event/events/DataCountEvent.java b/common/client/src/main/java/zingg/common/client/event/events/DataCountEvent.java new file mode 100644 index 000000000..667364863 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/events/DataCountEvent.java @@ -0,0 +1,6 @@ +package zingg.common.client.event.events; + +public class DataCountEvent extends IEvent{ + + public static final String INPUT_DATA_COUNT = "INPUT_DATA_COUNT"; +} diff --git a/common/client/src/main/java/zingg/common/client/event/events/IEvent.java b/common/client/src/main/java/zingg/common/client/event/events/IEvent.java new file mode 100644 index 000000000..6fe90d0f2 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/events/IEvent.java @@ -0,0 +1,25 @@ +package zingg.common.client.event.events; + +import java.util.HashMap; + +public class IEvent { + + protected HashMap eventDataProps; + + public IEvent() { + super(); + } + + public IEvent(HashMap eventDataProps) { + super(); + this.eventDataProps = eventDataProps; + } + + public HashMap getProps(){ + return eventDataProps; + } + + public void setProps(HashMap props){ + this.eventDataProps = props; + } +} diff --git a/common/client/src/main/java/zingg/common/client/event/events/ZinggStartEvent.java b/common/client/src/main/java/zingg/common/client/event/events/ZinggStartEvent.java new file mode 100644 index 000000000..40c15775f --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/events/ZinggStartEvent.java @@ -0,0 +1,5 @@ +package zingg.common.client.event.events; + +public class ZinggStartEvent 
extends IEvent{ + +} diff --git a/common/client/src/main/java/zingg/common/client/event/events/ZinggStopEvent.java b/common/client/src/main/java/zingg/common/client/event/events/ZinggStopEvent.java new file mode 100644 index 000000000..dedeb37bd --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/events/ZinggStopEvent.java @@ -0,0 +1,5 @@ +package zingg.common.client.event.events; + +public class ZinggStopEvent extends IEvent{ + +} diff --git a/common/client/src/main/java/zingg/common/client/event/listeners/EventsListener.java b/common/client/src/main/java/zingg/common/client/event/listeners/EventsListener.java new file mode 100644 index 000000000..df4bd73a6 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/listeners/EventsListener.java @@ -0,0 +1,39 @@ +package zingg.common.client.event.listeners; + +import java.util.List; + +import zingg.common.client.event.events.IEvent; +import zingg.common.client.util.ListMap; + +public class EventsListener { + private static EventsListener _eventsListener = new EventsListener(); + private final ListMap eventListenersList; + + private EventsListener() { + eventListenersList = new ListMap(); + } + + public static EventsListener getInstance() { + return _eventsListener; + } + + public void addListener(Class eventClass, IEventListener listener) { + eventListenersList.add(eventClass.getCanonicalName(), listener); + } + + public void fireEvent(IEvent event) { + listen(event); + } + + private void listen(IEvent event) { + Class eventClass = event.getClass(); + List listenerList = eventListenersList.get(eventClass.getCanonicalName()); + if (listenerList != null) { + for (IEventListener listener : listenerList) { + if (listener != null) { + listener.listen(event); + } + } + } + } +} diff --git a/common/client/src/main/java/zingg/common/client/event/listeners/IEventListener.java b/common/client/src/main/java/zingg/common/client/event/listeners/IEventListener.java new file mode 100644 index 000000000..5f45e5082 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/listeners/IEventListener.java @@ -0,0 +1,10 @@ +package zingg.common.client.event.listeners; + +import zingg.common.client.event.events.IEvent; + +public class IEventListener { + + public void listen(IEvent event) { + + } +} diff --git a/common/client/src/main/java/zingg/common/client/event/listeners/ZinggStartListener.java b/common/client/src/main/java/zingg/common/client/event/listeners/ZinggStartListener.java new file mode 100644 index 000000000..06ed396c7 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/listeners/ZinggStartListener.java @@ -0,0 +1,11 @@ +package zingg.common.client.event.listeners; + +import zingg.common.client.event.events.IEvent; + +public class ZinggStartListener extends IEventListener { + + @Override + public void listen(IEvent event) { + } + +} diff --git a/common/client/src/main/java/zingg/common/client/event/listeners/ZinggStopListener.java b/common/client/src/main/java/zingg/common/client/event/listeners/ZinggStopListener.java new file mode 100644 index 000000000..9d161dfb9 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/event/listeners/ZinggStopListener.java @@ -0,0 +1,10 @@ +package zingg.common.client.event.listeners; + +import zingg.common.client.event.events.IEvent; + +public class ZinggStopListener extends IEventListener { + + @Override + public void listen(IEvent event) { + } +} diff --git 
a/common/client/src/main/java/zingg/common/client/license/ILicenseValidator.java b/common/client/src/main/java/zingg/common/client/license/ILicenseValidator.java deleted file mode 100644 index 92fa47a37..000000000 --- a/common/client/src/main/java/zingg/common/client/license/ILicenseValidator.java +++ /dev/null @@ -1,25 +0,0 @@ -package zingg.common.client.license; - -import java.util.Properties; - -public interface ILicenseValidator { - - public boolean validate(); - - public Properties getLicenseProps(); - - public void setLicenseProps(Properties licenseProps); - - public String getKey(); - - public void setKey(String key); - - public String getValToCheck(); - - public void setValToCheck(String valToCheck); - - public String getName(); - - public void setName(String name); - -} diff --git a/common/client/src/main/java/zingg/common/client/license/IZinggLicense.java b/common/client/src/main/java/zingg/common/client/license/IZinggLicense.java deleted file mode 100644 index 761b5aedb..000000000 --- a/common/client/src/main/java/zingg/common/client/license/IZinggLicense.java +++ /dev/null @@ -1,11 +0,0 @@ -package zingg.common.client.license; - -import java.util.Properties; - -public interface IZinggLicense { - - public ILicenseValidator getValidator(String name); - - public Properties getLicenseProps(); - -} diff --git a/common/client/src/main/java/zingg/common/client/options/ZinggOption.java b/common/client/src/main/java/zingg/common/client/options/ZinggOption.java new file mode 100644 index 000000000..2b3ba2999 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/options/ZinggOption.java @@ -0,0 +1,19 @@ +package zingg.common.client.options; + +public class ZinggOption { + String name; + + public ZinggOption(String name) { + this.name = name; + ZinggOptions.put(this); + } + + public String getName() { + return name; + } + + @Override + public String toString(){ + return name; + } +} diff --git a/common/client/src/main/java/zingg/common/client/options/ZinggOptions.java b/common/client/src/main/java/zingg/common/client/options/ZinggOptions.java new file mode 100644 index 000000000..d4c98ed1e --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/options/ZinggOptions.java @@ -0,0 +1,66 @@ +package zingg.common.client.options; + +import java.util.HashMap; +import java.util.Map; + +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.Util; + +public class ZinggOptions { + + public final static ZinggOption TRAIN = new ZinggOption("train"); + public final static ZinggOption MATCH = new ZinggOption("match"); + public final static ZinggOption TRAIN_MATCH = new ZinggOption("trainMatch"); + public final static ZinggOption FIND_TRAINING_DATA = new ZinggOption("findTrainingData"); + public final static ZinggOption LABEL = new ZinggOption("label"); + public final static ZinggOption LINK = new ZinggOption("link"); + public final static ZinggOption GENERATE_DOCS = new ZinggOption("generateDocs"); + public final static ZinggOption RECOMMEND = new ZinggOption("recommend"); + public final static ZinggOption UPDATE_LABEL = new ZinggOption("updateLabel"); + public final static ZinggOption FIND_AND_LABEL = new ZinggOption("findAndLabel"); + public final static ZinggOption ASSESS_MODEL = new ZinggOption("assessModel"); + public final static ZinggOption PEEK_MODEL = new ZinggOption("peekModel"); + public final static ZinggOption EXPORT_MODEL = new ZinggOption("exportModel"); + + public static Map allZinggOptions;// = new HashMap(); + + + + protected 
ZinggOptions() { + } + + public static final void put(ZinggOption o) { + if (allZinggOptions == null) { + allZinggOptions = new HashMap(); + } + allZinggOptions.put(o.getName(), o); + } + + + + public static String[] getAllZinggOptions() { + ZinggOption[] zo = allZinggOptions.values().toArray(new ZinggOption[allZinggOptions.size()]); + int i = 0; + String[] s = new String[zo.length]; + for (ZinggOption z: zo) { + s[i++] = z.getName(); + } + return s; + } + + + public static final ZinggOption getByValue(String value){ + for (ZinggOption zo: ZinggOptions.allZinggOptions.values()) { + if (zo.name.equals(value)) return zo; + } + return null; + } + + public static void verifyPhase(String phase) throws ZinggClientException { + if (getByValue(phase) == null) { + String message = "'" + phase + "' is not a valid phase. " + + "Valid phases are: " + Util.join(getAllZinggOptions(), "|"); + throw new ZinggClientException(message); + } + } +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/util/DFObjectUtil.java b/common/client/src/main/java/zingg/common/client/util/DFObjectUtil.java new file mode 100644 index 000000000..c0ae8bd89 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/util/DFObjectUtil.java @@ -0,0 +1,17 @@ +package zingg.common.client.util; + +import java.util.List; + +import zingg.common.client.ZFrame; + +public abstract class DFObjectUtil { + + protected final IWithSession iWithSession; + + protected DFObjectUtil(IWithSession iWithSession) { + this.iWithSession = iWithSession; + } + + public abstract ZFrame getDFFromObjectList(List objList, Class objClass) throws Exception; + +} diff --git a/common/core/src/main/java/zingg/common/core/util/DFReader.java b/common/client/src/main/java/zingg/common/client/util/DFReader.java similarity index 93% rename from common/core/src/main/java/zingg/common/core/util/DFReader.java rename to common/client/src/main/java/zingg/common/client/util/DFReader.java index 6bf84940b..89f867752 100644 --- a/common/core/src/main/java/zingg/common/core/util/DFReader.java +++ b/common/client/src/main/java/zingg/common/client/util/DFReader.java @@ -1,4 +1,4 @@ -package zingg.common.core.util; +package zingg.common.client.util; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; diff --git a/common/core/src/main/java/zingg/common/core/util/DFWriter.java b/common/client/src/main/java/zingg/common/client/util/DFWriter.java similarity index 87% rename from common/core/src/main/java/zingg/common/core/util/DFWriter.java rename to common/client/src/main/java/zingg/common/client/util/DFWriter.java index c41e97196..9ddbfc88f 100644 --- a/common/core/src/main/java/zingg/common/core/util/DFWriter.java +++ b/common/client/src/main/java/zingg/common/client/util/DFWriter.java @@ -1,4 +1,4 @@ -package zingg.common.core.util; +package zingg.common.client.util; public interface DFWriter { diff --git a/common/core/src/main/java/zingg/common/core/util/DSUtil.java b/common/client/src/main/java/zingg/common/client/util/DSUtil.java similarity index 97% rename from common/core/src/main/java/zingg/common/core/util/DSUtil.java rename to common/client/src/main/java/zingg/common/client/util/DSUtil.java index 6c6d0721b..f8d2f8108 100644 --- a/common/core/src/main/java/zingg/common/core/util/DSUtil.java +++ b/common/client/src/main/java/zingg/common/client/util/DSUtil.java @@ -1,4 +1,4 @@ -package zingg.common.core.util; +package zingg.common.client.util; import zingg.common.client.FieldDefinition; @@ -7,8 
+7,6 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; import zingg.common.client.pipe.Pipe; -import zingg.common.client.util.ColName; -import zingg.common.client.util.ColValues; import java.util.ArrayList; import java.util.List; @@ -43,7 +41,12 @@ public static final String[] getPrefixedColumns(String[] cols) { } public ZFrame getPrefixedColumnsDS(ZFrame lines) { - return lines.toDF(getPrefixedColumns(lines.columns())); + try { + return lines.toDF(getPrefixedColumns(lines.columns())); + } catch (Exception e) { + LOG.error("Please ensure that the 'ftd' and 'label' processes are executed before initiating the training phase"); + throw e; + } } diff --git a/common/client/src/main/java/zingg/common/client/util/IWithSession.java b/common/client/src/main/java/zingg/common/client/util/IWithSession.java new file mode 100644 index 000000000..470405c38 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/util/IWithSession.java @@ -0,0 +1,9 @@ +package zingg.common.client.util; + +public interface IWithSession { + + public void setSession(S s); + + public S getSession(); + +} \ No newline at end of file diff --git a/common/client/src/main/java/zingg/common/client/util/JsonStringify.java b/common/client/src/main/java/zingg/common/client/util/JsonStringify.java new file mode 100644 index 000000000..848155e83 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/util/JsonStringify.java @@ -0,0 +1,27 @@ +package zingg.common.client.util; + +import java.io.IOException; +import java.io.StringWriter; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.ObjectMapper; + +import zingg.common.client.Arguments; +import zingg.common.client.ArgumentsUtil; + +public class JsonStringify { + public static String toString (Object o){ + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, true); + //mapper.configure(JsonParser.Feature.FAIL_ON_EMPTY_BEANS, true) + try { + StringWriter writer = new StringWriter(); + return mapper.writeValueAsString(o); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + return null; + } + } + +} diff --git a/common/core/src/main/java/zingg/common/core/util/PipeUtil.java b/common/client/src/main/java/zingg/common/client/util/PipeUtil.java similarity index 99% rename from common/core/src/main/java/zingg/common/core/util/PipeUtil.java rename to common/client/src/main/java/zingg/common/client/util/PipeUtil.java index b76f8a371..415a4e36a 100644 --- a/common/core/src/main/java/zingg/common/core/util/PipeUtil.java +++ b/common/client/src/main/java/zingg/common/client/util/PipeUtil.java @@ -1,4 +1,4 @@ -package zingg.common.core.util; +package zingg.common.client.util; import java.util.Arrays; import java.util.stream.Collectors; @@ -12,7 +12,6 @@ import zingg.common.client.pipe.FilePipe; //import zingg.common.client.pipe.InMemoryPipe; import zingg.common.client.pipe.Pipe; -import zingg.common.client.util.ColName; //import com.datastax.spark.connector.cql.*; //import org.elasticsearch.spark.sql.api.java.JavaEsSparkSQL; @@ -185,7 +184,7 @@ public ZFrame read(boolean addExtraCol, boolean addLineNo, int numPartit return rows; } - public void write(ZFrame toWriteOrig, IArguments args, + public void write(ZFrame toWriteOrig, Pipe... 
pipes) throws ZinggClientException { try { for (Pipe p: pipes) { diff --git a/common/core/src/main/java/zingg/common/core/util/PipeUtilBase.java b/common/client/src/main/java/zingg/common/client/util/PipeUtilBase.java similarity index 93% rename from common/core/src/main/java/zingg/common/core/util/PipeUtilBase.java rename to common/client/src/main/java/zingg/common/client/util/PipeUtilBase.java index bdb363a2b..b293d0b71 100644 --- a/common/core/src/main/java/zingg/common/core/util/PipeUtilBase.java +++ b/common/client/src/main/java/zingg/common/client/util/PipeUtilBase.java @@ -1,4 +1,4 @@ -package zingg.common.core.util; +package zingg.common.client.util; import zingg.common.client.IArguments; import zingg.common.client.ZFrame; @@ -29,7 +29,7 @@ public ZFrame read(boolean addLineNo, int numPartitions, public ZFrame read(boolean addExtraCol, boolean addLineNo, int numPartitions, boolean addSource, Pipe... pipes) throws ZinggClientException; - public void write(ZFrame toWriteOrig, IArguments args, Pipe... pipes) + public void write(ZFrame toWriteOrig, Pipe... pipes) throws ZinggClientException; diff --git a/common/client/src/main/java/zingg/common/client/util/PojoToArrayConverter.java b/common/client/src/main/java/zingg/common/client/util/PojoToArrayConverter.java new file mode 100644 index 000000000..a04e60b68 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/util/PojoToArrayConverter.java @@ -0,0 +1,40 @@ +package zingg.common.client.util; + +import java.lang.reflect.Field; + +public class PojoToArrayConverter { + + public static Object[] getObjectArray(Object object) throws IllegalAccessException { + Field[] fieldsInChildClass = object.getClass().getDeclaredFields(); + Field[] fieldsInParentClass = null; + + int fieldCountInChildClass = fieldsInChildClass.length; + int fieldCount = fieldCountInChildClass; + + if (object.getClass().getSuperclass() != null) { + fieldCount += object.getClass().getSuperclass().getDeclaredFields().length; + fieldsInParentClass = object.getClass().getSuperclass().getDeclaredFields(); + } + + //fieldCount = fieldCountChild + fieldCountParent + Object[] objArr = new Object[fieldCount]; + + int idx = 0; + + //iterate through child class fields + for (; idx < fieldCountInChildClass; idx++) { + Field field = fieldsInChildClass[idx]; + field.setAccessible(true); + objArr[idx] = field.get(object); + } + + //iterate through super class fields + for (; idx < fieldCount; idx++) { + Field field = fieldsInParentClass[idx - fieldCountInChildClass]; + field.setAccessible(true); + objArr[idx] = field.get(object); + } + + return objArr; + } +} diff --git a/common/client/src/main/java/zingg/common/client/util/StructTypeFromPojoClass.java b/common/client/src/main/java/zingg/common/client/util/StructTypeFromPojoClass.java new file mode 100644 index 000000000..4b3de89bb --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/util/StructTypeFromPojoClass.java @@ -0,0 +1,34 @@ +package zingg.common.client.util; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; + +public abstract class StructTypeFromPojoClass { + + public abstract ST getStructType(Class objClass) throws Exception; + + public List getFields(Class objClass) { + List structFields = new ArrayList(); + Field[] fields = objClass.getDeclaredFields(); + + //add child class fields in struct + for (Field f : fields) { + structFields.add(getStructField(f)); + } + + //add parent class fields in struct + if (objClass.getSuperclass() != null) { + Field[] 
fieldsSuper = objClass.getSuperclass().getDeclaredFields(); + for (Field f : fieldsSuper) { + structFields.add(getStructField(f)); + } + } + return structFields; + } + + public abstract SF getStructField(Field field); + + public abstract T getSFType(Class t); + +} diff --git a/common/client/src/main/java/zingg/common/client/util/WithSession.java b/common/client/src/main/java/zingg/common/client/util/WithSession.java new file mode 100644 index 000000000..e3d0612b9 --- /dev/null +++ b/common/client/src/main/java/zingg/common/client/util/WithSession.java @@ -0,0 +1,15 @@ +package zingg.common.client.util; + +public class WithSession implements IWithSession { + + S session; + @Override + public void setSession(S session) { + this.session = session; + } + + @Override + public S getSession() { + return session; + } +} diff --git a/common/client/src/test/java/zingg/common/client/TestArguments.java b/common/client/src/test/java/zingg/common/client/TestArguments.java index 231464c44..0a75b4d2d 100644 --- a/common/client/src/test/java/zingg/common/client/TestArguments.java +++ b/common/client/src/test/java/zingg/common/client/TestArguments.java @@ -9,6 +9,7 @@ import java.nio.file.Paths; import java.util.Arrays; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -25,6 +26,7 @@ public class TestArguments { public static final Log LOG = LogFactory.getLog(TestArguments.class); protected ArgumentsUtil argsUtil = new ArgumentsUtil(); + @Test public void testSubstituteVariablesWithAllEnvVarSet() { try { @@ -242,8 +244,33 @@ public void testMatchTypeWrong() { } - - - + + @Test + public void testJsonStringify(){ + IArguments argsFromJsonFile; + try{ + //Converting to JSON using toString() + argsFromJsonFile = argsUtil.createArgumentsFromJSON(getClass().getResource("../../../testArguments/configWithMultipleMatchTypesUnsupported.json").getFile(), "test"); + String strFromJsonFile = argsFromJsonFile.toString(); + + IArguments argsFullCycle = argsUtil.createArgumentsFromJSONString(strFromJsonFile, ""); + + assertEquals(argsFullCycle.getFieldDefinition().get(0), argsFromJsonFile.getFieldDefinition().get(0)); + assertEquals(argsFullCycle.getFieldDefinition().get(2), argsFromJsonFile.getFieldDefinition().get(2)); + assertEquals(argsFullCycle.getModelId(), argsFromJsonFile.getModelId()); + assertEquals(argsFullCycle.getZinggModelDir(), argsFromJsonFile.getZinggModelDir()); + assertEquals(argsFullCycle.getNumPartitions(), argsFromJsonFile.getNumPartitions()); + assertEquals(argsFullCycle.getLabelDataSampleSize() ,argsFromJsonFile.getLabelDataSampleSize()); + assertEquals(argsFullCycle.getTrainingSamples(),argsFromJsonFile.getTrainingSamples()); + assertEquals(argsFullCycle.getOutput(),argsFromJsonFile.getOutput()); + assertEquals(argsFullCycle.getData(),argsFromJsonFile.getData()); + assertEquals(argsFullCycle.getZinggDir(),argsFromJsonFile.getZinggDir()); + assertEquals(argsFullCycle.getJobId(),argsFromJsonFile.getJobId()); + + } catch (Exception | ZinggClientException e) { + e.printStackTrace(); + } + + } } diff --git a/common/client/src/test/java/zingg/common/client/TestClient.java b/common/client/src/test/java/zingg/common/client/TestClient.java index e22ff3c21..5a3befd85 100644 --- a/common/client/src/test/java/zingg/common/client/TestClient.java +++ b/common/client/src/test/java/zingg/common/client/TestClient.java @@ -6,6 +6,8 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; +import zingg.common.client.options.ZinggOptions; + 
public class TestClient { public static final Log LOG = LogFactory.getLog(TestClient.class); diff --git a/common/client/src/test/java/zingg/common/client/TestFieldDefUtil.java b/common/client/src/test/java/zingg/common/client/TestFieldDefUtil.java new file mode 100644 index 000000000..3d78d4618 --- /dev/null +++ b/common/client/src/test/java/zingg/common/client/TestFieldDefUtil.java @@ -0,0 +1,41 @@ +package zingg.common.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; + + +public class TestFieldDefUtil { + + public static final Log LOG = LogFactory.getLog(TestFieldDefUtil.class); + protected ArgumentsUtil argsUtil = new ArgumentsUtil(); + + protected FieldDefUtil fieldDefUtil = new FieldDefUtil(); + + @Test + public void testMatchTypeFilter() { + IArguments args; + try { + args = argsUtil.createArgumentsFromJSON(getClass().getResource("../../../testArguments/configTestDontUse.json").getFile(), "test"); + + List dontUseList = fieldDefUtil.getFieldDefinitionDontUse(args.getFieldDefinition()); + assertEquals(dontUseList.size(), 3); + + List matchList = fieldDefUtil.getFieldDefinitionToUse(args.getFieldDefinition()); + assertEquals(matchList.size(), 4); + + } catch (Exception | ZinggClientException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + fail("Could not read config"); + } + + } + + +} diff --git a/common/client/src/test/java/zingg/common/client/util/TestStringRedactor.java b/common/client/src/test/java/zingg/common/client/util/TestStringRedactor.java index 10220a5f1..07aff4f66 100644 --- a/common/client/src/test/java/zingg/common/client/util/TestStringRedactor.java +++ b/common/client/src/test/java/zingg/common/client/util/TestStringRedactor.java @@ -1,7 +1,5 @@ package zingg.common.client.util; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.stream.Stream; @@ -10,14 +8,10 @@ import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.Arguments; import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.junit.jupiter.api.Named.named; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; -import org.junit.jupiter.api.Test; - public class TestStringRedactor { @ParameterizedTest(name="{0}") diff --git a/common/client/src/test/resources/testArguments/configTestDontUse.json b/common/client/src/test/resources/testArguments/configTestDontUse.json new file mode 100644 index 000000000..f1f1ed225 --- /dev/null +++ b/common/client/src/test/resources/testArguments/configTestDontUse.json @@ -0,0 +1,70 @@ +{ + "fieldDefinition":[ + { + "fieldName" : "fname", + "matchType" : "fuzzy,null_or_blank", + "fields" : "fname", + "dataType": "string" + }, + { + "fieldName" : "lname", + "matchType" : "fuzzy", + "fields" : "lname", + "dataType": "string" + }, + { + "fieldName" : "stNo", + "matchType": "exact", + "fields" : "stNo", + "dataType": "string" + }, + { + "fieldName" : "add1", + "matchType": "fuzzy,dont_use", + "fields" : "add1", + "dataType": "string" + }, + { + "fieldName" : "add2", + "matchType": "dont_use", + "fields" : "add2", + "dataType": "string" + }, + { + "fieldName" : "city", + "matchType": "dont_use,fuzzy", + "fields" : 
"city", + "dataType": "string" + }, + { + "fieldName" : "state", + "matchType": "fuzzy", + "fields" : "state", + "dataType": "string" + } + ], + "output" : [{ + "name":"output", + "format":"csv", + "props": { + "location": "/tmp/zinggOutput", + "delimiter": ",", + "header":true + } + }], + "data" : [{ + "name":"test", + "format":"csv", + "props": { + "location": "examples/febrl/test.csv", + "delimiter": ",", + "header":false + }, + "schema": "id string, fname string, lname string, stNo string, add1 string, add2 string, city string, areacode string, state string, dob string, ssn string" + }], + "labelDataSampleSize" : 0.5, + "numPartitions":4, + "modelId": 100, + "zinggDir": "models" + +} diff --git a/common/core/pom.xml b/common/core/pom.xml index 926187b45..40d61e0c4 100644 --- a/common/core/pom.xml +++ b/common/core/pom.xml @@ -27,7 +27,41 @@ org.apache.httpcomponents httpclient - 4.5.2 + 4.5.14 + + org.junit.jupiter + junit-jupiter-engine + 5.8.1 + test + + + org.junit.jupiter + junit-jupiter-api + 5.8.1 + test + + + org.junit.jupiter + junit-jupiter-params + 5.8.1 + test + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.3.2 + + + + test-jar + + + + + + diff --git a/common/core/src/main/java/zingg/common/core/Context.java b/common/core/src/main/java/zingg/common/core/Context.java deleted file mode 100644 index d475708ee..000000000 --- a/common/core/src/main/java/zingg/common/core/Context.java +++ /dev/null @@ -1,58 +0,0 @@ -package zingg.common.core; - -import java.io.Serializable; - -import zingg.common.client.ZinggClientException; -import zingg.common.client.license.IZinggLicense; -import zingg.common.core.util.BlockingTreeUtil; -import zingg.common.core.util.DSUtil; -import zingg.common.core.util.GraphUtil; -import zingg.common.core.util.HashUtil; -import zingg.common.core.util.ModelUtil; -import zingg.common.core.util.PipeUtilBase; - -public interface Context extends Serializable { - - public HashUtil getHashUtil() ; - public void setHashUtil(HashUtil t) ; - public GraphUtil getGraphUtil() ; - - public void setGraphUtil(GraphUtil t) ; - - public void setModelUtil(ModelUtil t); - public void setBlockingTreeUtil(BlockingTreeUtil t) ; - - public ModelUtil getModelUtil(); - - public void setPipeUtil(PipeUtilBase pipeUtil); - public void setDSUtil(DSUtil pipeUtil); - public DSUtil getDSUtil() ; - public PipeUtilBase getPipeUtil(); - public BlockingTreeUtil getBlockingTreeUtil() ; - - public void init(IZinggLicense license) - throws ZinggClientException; - - public void cleanup(); - - /**convenience method to set all utils - * especially useful when you dont want to create the connection/spark context etc - * */ - public void setUtils(); - - public S getSession(); - - public void setSession(S session); - - - //public void initHashFns() throws ZinggClientException; - - - - - - } - - - - diff --git a/common/core/src/main/java/zingg/common/core/block/Block.java b/common/core/src/main/java/zingg/common/core/block/Block.java index 35bde6b54..0fdd3665b 100644 --- a/common/core/src/main/java/zingg/common/core/block/Block.java +++ b/common/core/src/main/java/zingg/common/core/block/Block.java @@ -13,10 +13,13 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; import zingg.common.client.util.ListMap; +import zingg.common.core.feature.FeatureFactory; import zingg.common.core.hash.HashFunction; public abstract class Block implements Serializable { + private static final long serialVersionUID = 1L; + public static final Log LOG = 
LogFactory.getLog(Block.class); protected ZFrame dupes; @@ -66,16 +69,13 @@ public void setDupes(ZFrame dupes) { /** * @return the types * - * public Class[] getTypes() { return types; } */ /** * @param types - * the types to set - * - * public void setTypes(Class[] types) { this.types = types; } + * the types to set * - * /** + * * @return the maxSize */ public long getMaxSize() { @@ -84,7 +84,7 @@ public long getMaxSize() { /** * @param maxSize - * the maxSize to set + * the maxSize to set */ public void setMaxSize(long maxSize) { this.maxSize = maxSize; @@ -102,10 +102,13 @@ protected void setFunctionsMap(ListMap> m) { this.functionsMap = m; } + protected Canopy getCanopy(){ + return new Canopy(); + } public CanopygetNodeFromCurrent(Canopynode, HashFunction function, FieldDefinition context) { - Canopytrial = new Canopy(); + Canopytrial = getCanopy(); trial = node.copyTo(trial); // node.training, node.dupeN, function, context); trial.function = function; @@ -113,23 +116,28 @@ protected void setFunctionsMap(ListMap> m) { return trial; } - public abstract T getDataTypeFromString(String t); + public void estimateElimCount(Canopy c, long elimCount) { + c.estimateElimCount(); + } public CanopygetBestNode(Tree> tree, Canopyparent, Canopynode, List fieldsOfInterest) throws Exception { long least = Long.MAX_VALUE; int maxElimination = 0; Canopybest = null; - for (FieldDefinition field : fieldsOfInterest) { - LOG.debug("Trying for " + field + " with data type " + field.getDataType() + " and real dt " - + getDataTypeFromString(field.getDataType())); + if (LOG.isDebugEnabled()){ + LOG.debug("Trying for " + field + " with data type " + field.getDataType() + " and real dt " + + getFeatureFactory().getDataTypeFromString(field.getDataType())); + } //Class type = FieldClass.getFieldClassClass(field.getFieldClass()); FieldDefinition context = field; if (least ==0) break;//how much better can it get? 
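// --- review sketch, not part of the patch: with the abstract getDataTypeFromString()
// removed from Block, the candidate hash functions for a field are now resolved
// through the engine's FeatureFactory, as the lines below do. In isolation the
// lookup is roughly (generic parameters assumed; this extract drops them):
//
//   T dataType = getFeatureFactory().getDataTypeFromString(field.getDataType());
//   List<HashFunction<D, R, C, T>> candidates = functionsMap.get(dataType);
// ---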
// applicable functions - List> functions = functionsMap.get(getDataTypeFromString(field.getDataType())); - LOG.debug("functions are " + functions); + List> functions = functionsMap.get(getFeatureFactory().getDataTypeFromString(field.getDataType())); + if (LOG.isDebugEnabled()){ + LOG.debug("functions are " + functions); + } if (functions != null) { @@ -140,11 +148,13 @@ protected void setFunctionsMap(ListMap> m) { //!childless.contains(function, field.fieldName) ) { - LOG.debug("Evaluating field " + field.fieldName + if (LOG.isDebugEnabled()){ + LOG.debug("Evaluating field " + field.fieldName + " and function " + function + " for " + field.dataType); + } Canopytrial = getNodeFromCurrent(node, function, context); - trial.estimateElimCount(); + estimateElimCount(trial, least); long elimCount = trial.getElimCount(); @@ -178,7 +188,9 @@ protected void setFunctionsMap(ListMap> m) { }*/ } else { - LOG.debug("No child " + function); + if (LOG.isDebugEnabled()){ + LOG.debug("No child " + function); + } //childless.add(function, field.fieldName); } @@ -392,7 +404,7 @@ public void printTree(Tree> tree, } } - + public abstract FeatureFactory getFeatureFactory(); } diff --git a/common/core/src/main/java/zingg/common/core/block/Canopy.java b/common/core/src/main/java/zingg/common/core/block/Canopy.java index 25f0d4124..09451c56d 100644 --- a/common/core/src/main/java/zingg/common/core/block/Canopy.java +++ b/common/core/src/main/java/zingg/common/core/block/Canopy.java @@ -20,19 +20,19 @@ public class Canopy implements Serializable { public static final Log LOG = LogFactory.getLog(Canopy.class); // created by function edge leading from parent to this node - HashFunction function; + protected HashFunction function; // aplied on field - FieldDefinition context; + protected FieldDefinition context; // list of duplicates passed from parent - List dupeN; + protected List dupeN; // number of duplicates eliminated after function applied on fn context - long elimCount; + protected long elimCount; // hash of canopy - Object hash; + protected Object hash; // training set - List training; + protected List training; // duplicates remaining after function is applied - List dupeRemaining; + protected List dupeRemaining; public Canopy() { } diff --git a/common/core/src/main/java/zingg/common/core/context/Context.java b/common/core/src/main/java/zingg/common/core/context/Context.java new file mode 100644 index 000000000..410e3ae3d --- /dev/null +++ b/common/core/src/main/java/zingg/common/core/context/Context.java @@ -0,0 +1,89 @@ +package zingg.common.core.context; + +import java.io.Serializable; + +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.DSUtil; +import zingg.common.client.util.PipeUtilBase; +import zingg.common.core.util.BlockingTreeUtil; +import zingg.common.core.util.GraphUtil; +import zingg.common.core.util.HashUtil; +import zingg.common.core.util.ModelUtil; + +public abstract class Context implements Serializable { + protected S session; + protected PipeUtilBase pipeUtil; + protected HashUtil hashUtil; + protected DSUtil dsUtil; + protected GraphUtil graphUtil; + protected ModelUtil modelUtil; + protected BlockingTreeUtil blockingTreeUtil; + + public static final String hashFunctionFile = "hashFunctions.json"; + + public HashUtil getHashUtil() { + return this.hashUtil; + } + public void setHashUtil(HashUtil t) { + this.hashUtil = t; + } + public GraphUtil getGraphUtil() { + return this.graphUtil; + } + + public void setGraphUtil(GraphUtil t) { + this.graphUtil = t; + } 
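+	// (reviewer note, not part of the original patch: the utils below are held as
+	//  plain protected fields with concrete accessors, so engine-specific contexts
+	//  only assign them in setUtils() rather than implementing an interface's getters)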
+ + public void setModelUtil(ModelUtil t){ + this.modelUtil = t; + } + public void setBlockingTreeUtil(BlockingTreeUtil t) { + this.blockingTreeUtil = t; + } + + public ModelUtil getModelUtil(){ + return this.modelUtil; + } + + public void setPipeUtil(PipeUtilBase pipeUtil){ + this.pipeUtil = pipeUtil; + } + public void setDSUtil(DSUtil d){ + this.dsUtil = d; + } + public DSUtil getDSUtil() { + return this.dsUtil; + } + public PipeUtilBase getPipeUtil(){ + return this.pipeUtil; + } + public BlockingTreeUtil getBlockingTreeUtil() { + return this.blockingTreeUtil; + } + + public abstract void init(S session) + throws ZinggClientException; + + public abstract void cleanup(); + + /**convenience method to set all utils + * especially useful when you dont want to create the connection/spark context etc + * */ + public abstract void setUtils(); + + public S getSession(){ + return session; + } + + public void setSession(S session){ + this.session = session; + } + + + + } + + + + diff --git a/common/core/src/main/java/zingg/common/core/documenter/DataColDocumenter.java b/common/core/src/main/java/zingg/common/core/documenter/DataColDocumenter.java index c227f5187..b69c32c80 100644 --- a/common/core/src/main/java/zingg/common/core/documenter/DataColDocumenter.java +++ b/common/core/src/main/java/zingg/common/core/documenter/DataColDocumenter.java @@ -6,7 +6,7 @@ import zingg.common.client.IArguments; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.core.Context; +import zingg.common.core.context.Context; public abstract class DataColDocumenter extends DocumenterBase { protected static String name = "zingg.DataColDocumenter"; diff --git a/common/core/src/main/java/zingg/common/core/documenter/DataDocumenter.java b/common/core/src/main/java/zingg/common/core/documenter/DataDocumenter.java index 71737064d..0d88b1424 100644 --- a/common/core/src/main/java/zingg/common/core/documenter/DataDocumenter.java +++ b/common/core/src/main/java/zingg/common/core/documenter/DataDocumenter.java @@ -12,7 +12,7 @@ import zingg.common.client.IArguments; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.core.Context; +import zingg.common.core.context.Context; public abstract class DataDocumenter extends DocumenterBase { protected static String name = "zingg.DataDocumenter"; diff --git a/common/core/src/main/java/zingg/common/core/documenter/DocumenterBase.java b/common/core/src/main/java/zingg/common/core/documenter/DocumenterBase.java index 0f891c839..59858bd0f 100644 --- a/common/core/src/main/java/zingg/common/core/documenter/DocumenterBase.java +++ b/common/core/src/main/java/zingg/common/core/documenter/DocumenterBase.java @@ -12,7 +12,7 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; import zingg.common.client.util.ColName; -import zingg.common.core.Context; +import zingg.common.core.context.Context; import zingg.common.core.executor.ZinggBase; public abstract class DocumenterBase extends ZinggBase{ diff --git a/common/core/src/main/java/zingg/common/core/documenter/ModelColDocumenter.java b/common/core/src/main/java/zingg/common/core/documenter/ModelColDocumenter.java index 41d215e63..1bdfb2942 100644 --- a/common/core/src/main/java/zingg/common/core/documenter/ModelColDocumenter.java +++ b/common/core/src/main/java/zingg/common/core/documenter/ModelColDocumenter.java @@ -9,7 +9,7 @@ import zingg.common.client.IArguments; import zingg.common.client.ZFrame; 
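// --- review sketch, not part of the patch: with Context (moved above to
// zingg.common.core.context and turned into an abstract class) owning the util
// fields, a concrete engine context only supplies a session and wires the utils.
// Every name below (InMemoryContext, InMemorySession, InMemoryPipeUtil,
// InMemoryDSUtil) is a hypothetical stand-in, and the type parameters are assumed,
// since the extract drops them:

public class InMemoryContext extends Context<InMemorySession, Object, Object, Object, Object> {
	@Override
	public void init(InMemorySession session) throws ZinggClientException {
		setSession(session);
		setUtils();
	}

	@Override
	public void setUtils() {
		setPipeUtil(new InMemoryPipeUtil(getSession()));
		setDSUtil(new InMemoryDSUtil(getPipeUtil()));
	}

	@Override
	public void cleanup() {
		setSession(null);
	}
}
// ---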
import zingg.common.client.ZinggClientException; -import zingg.common.core.Context; +import zingg.common.core.context.Context; public abstract class ModelColDocumenter extends DocumenterBase { protected static String name = "zingg.ModelColDocumenter"; diff --git a/common/core/src/main/java/zingg/common/core/documenter/ModelDocumenter.java b/common/core/src/main/java/zingg/common/core/documenter/ModelDocumenter.java index 75363e71c..67c0a7ef5 100644 --- a/common/core/src/main/java/zingg/common/core/documenter/ModelDocumenter.java +++ b/common/core/src/main/java/zingg/common/core/documenter/ModelDocumenter.java @@ -10,12 +10,14 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import zingg.common.client.FieldDefUtil; +import zingg.common.client.FieldDefinition; import zingg.common.client.IArguments; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; -import zingg.common.core.Context; +import zingg.common.core.context.Context; public abstract class ModelDocumenter extends DocumenterBase { @@ -30,10 +32,13 @@ public abstract class ModelDocumenter extends DocumenterBase modelColDoc; protected ZFrame markedRecords; protected ZFrame unmarkedRecords; + + protected FieldDefUtil fieldDefUtil; public ModelDocumenter(Context context, IArguments args) { super(context, args); markedRecords = getDSUtil().emptyDataFrame(); + fieldDefUtil = new FieldDefUtil(); } public void process() throws ZinggClientException { @@ -45,8 +50,9 @@ protected void createModelDocument() throws ZinggClientException { try { LOG.info("Model document generation starts"); - markedRecords = getMarkedRecords().sortAscending(ColName.CLUSTER_COLUMN); - unmarkedRecords = getUnmarkedRecords().sortAscending(ColName.CLUSTER_COLUMN); + // drop columns which are don't use if show concise is true + markedRecords = filterForConcise(getMarkedRecords().sortAscending(ColName.CLUSTER_COLUMN)); + unmarkedRecords = filterForConcise(getUnmarkedRecords().sortAscending(ColName.CLUSTER_COLUMN)); Map root = populateTemplateData(); writeModelDocument(root); @@ -82,8 +88,7 @@ protected Map populateTemplateData() { } else { // fields required to generate basic document - List columnList = args.getFieldDefinition().stream().map(fd -> fd.getFieldName()) - .collect(Collectors.toList()); + List columnList = getColumnList(); root.put(TemplateFields.NUM_COLUMNS, columnList.size()); root.put(TemplateFields.COLUMNS, columnList.toArray()); root.put(TemplateFields.CLUSTERS, Collections.emptyList()); @@ -94,6 +99,31 @@ protected Map populateTemplateData() { return root; } + protected ZFrame filterForConcise(ZFrame df) { + if (args.getShowConcise()) { + List dontUseFields = getFieldNames( + (List) fieldDefUtil.getFieldDefinitionDontUse(args.getFieldDefinition())); + if(!dontUseFields.isEmpty()) { + df = df.drop(dontUseFields.toArray(new String[dontUseFields.size()])); + } + } + return df; + } + + protected List getColumnList() { + List fieldList = args.getFieldDefinition(); + //drop columns which are don't use if show concise is true + if (args.getShowConcise()) { + fieldList = fieldDefUtil.getFieldDefinitionToUse(args.getFieldDefinition()); + } + return getFieldNames(fieldList); + } + + protected List getFieldNames(List fieldList) { + return fieldList.stream().map(fd -> fd.getFieldName()) + .collect(Collectors.toList()); + } + private void putSummaryCounts(Map root) { // Get the count if not empty ZFrame 
markedRecordsPairSummary = markedRecords.groupByCount(ColName.MATCH_FLAG_COL, PAIR_WISE_COUNT); diff --git a/common/core/src/main/java/zingg/common/core/executor/Documenter.java b/common/core/src/main/java/zingg/common/core/executor/Documenter.java index 6e80b8aa7..2841720e5 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Documenter.java +++ b/common/core/src/main/java/zingg/common/core/executor/Documenter.java @@ -4,7 +4,7 @@ import org.apache.commons.logging.LogFactory; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.options.ZinggOptions; import zingg.common.core.documenter.DataDocumenter; import zingg.common.core.documenter.ModelDocumenter; @@ -14,7 +14,7 @@ public abstract class Documenter extends ZinggBase { public static final Log LOG = LogFactory.getLog(Documenter.class); public Documenter() { - setZinggOptions(ZinggOptions.GENERATE_DOCS); + setZinggOption(ZinggOptions.GENERATE_DOCS); } public void execute() throws ZinggClientException { diff --git a/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java b/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java index e4e43109a..b8eb3eff0 100644 --- a/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java +++ b/common/core/src/main/java/zingg/common/core/executor/FindAndLabeller.java @@ -5,8 +5,7 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; public abstract class FindAndLabeller extends ZinggBase { private static final long serialVersionUID = 1L; @@ -17,14 +16,14 @@ public abstract class FindAndLabeller extends ZinggBase labeller; public FindAndLabeller() { - setZinggOptions(ZinggOptions.FIND_AND_LABEL); + setZinggOption(ZinggOptions.FIND_AND_LABEL); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - finder.init(args, license); - labeller.init(args, license); - super.init(args, license); + public void init(IArguments args, S s) throws ZinggClientException { + finder.init(args,s); + labeller.init(args,s); + super.init(args,s); } @Override diff --git a/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java b/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java index 0c6024621..d5bd5970d 100644 --- a/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java +++ b/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java @@ -6,14 +6,12 @@ import org.apache.commons.logging.LogFactory; import zingg.common.client.ClientOptions; -import zingg.common.client.IArguments; import zingg.common.client.ILabelDataViewHelper; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; -import zingg.common.core.Context; +import zingg.common.core.context.Context; import zingg.common.core.util.LabelMatchType; public class LabelDataViewHelper extends ZinggBase implements ILabelDataViewHelper { @@ -21,9 +19,8 @@ public class LabelDataViewHelper extends ZinggBase imp private static final long serialVersionUID = 1L; public static final Log LOG = LogFactory.getLog(LabelDataViewHelper.class); - public LabelDataViewHelper(Context context, 
ZinggOptions zinggOptions, ClientOptions clientOptions) { + public LabelDataViewHelper(Context context, ClientOptions clientOptions) { setContext(context); - setZinggOptions(zinggOptions); setClientOptions(clientOptions); setName(this.getClass().getName()); } @@ -40,11 +37,11 @@ public List getClusterIds(ZFrame lines) { } - @Override - public List getDisplayColumns(ZFrame lines, IArguments args) { - return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); - } - +// @Override +// public List getDisplayColumns(ZFrame lines, IArguments args) { +// return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); +// } +// @Override public ZFrame getCurrentPair(ZFrame lines, int index, List clusterIds, ZFrame clusterLines) { @@ -127,5 +124,7 @@ public void execute() throws ZinggClientException { public ILabelDataViewHelper getLabelDataViewHelper() throws UnsupportedOperationException { return this; } + + } diff --git a/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java b/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java index 4e3365783..cb1fbe6e3 100644 --- a/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java +++ b/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java @@ -1,6 +1,5 @@ package zingg.common.core.executor; -import java.util.List; import java.util.Scanner; import org.apache.commons.logging.Log; @@ -8,7 +7,8 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.cols.ZidAndFieldDefSelector; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.pipe.Pipe; import zingg.common.client.util.ColName; import zingg.common.core.util.LabelMatchType; @@ -19,7 +19,7 @@ public abstract class LabelUpdater extends Labeller { public static final Log LOG = LogFactory.getLog(LabelUpdater.class); public LabelUpdater() { - setZinggOptions(ZinggOptions.UPDATE_LABEL); + setZinggOption(ZinggOptions.UPDATE_LABEL); } public void execute() throws ZinggClientException { @@ -125,14 +125,14 @@ protected ZFrame getUpdatedRecords(ZFrame updatedRecords, int } protected int getUserInput(ZFrame lines,ZFrame currentPair,String cluster_id) { - - List displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); - +// List displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise()); int matchFlag = currentPair.getAsInt(currentPair.head(),ColName.MATCH_FLAG_COL); String preMsg = String.format("\n\tThe record pairs belonging to the input cluster id %s are:", cluster_id); String matchType = LabelMatchType.get(matchFlag).msg; String postMsg = String.format("\tThe above pair is labeled as %s\n", matchType); - int selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg); +// int selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg); + int selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), preMsg, postMsg); getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT); getTrainingDataModel().updateLabellerStat(matchFlag, -1*INCREMENT); getLabelDataViewHelper().printMarkedRecordsStat( @@ -154,4 +154,4 @@ protected Pipe getOutputPipe() { } 
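// --- review sketch, not part of the patch: the labelling CLIs above now derive
// their display projection from ZidAndFieldDefSelector instead of
// DSUtil.getFieldDefColumns. The call shape, lifted from the hunks in this diff
// (generic parameters assumed):

ZidAndFieldDefSelector selector =
		new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise());
ZFrame<D, R, C> pair = currentPair.select(selector.getCols());
int selectedOption = displayRecordsAndGetUserInput(pair, preMsg, postMsg);
// ---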
protected abstract Pipe setSaveModeOnPipe(Pipe p); -} \ No newline at end of file +} diff --git a/common/core/src/main/java/zingg/common/core/executor/Labeller.java b/common/core/src/main/java/zingg/common/core/executor/Labeller.java index 7c9575c25..3c496445f 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Labeller.java +++ b/common/core/src/main/java/zingg/common/core/executor/Labeller.java @@ -10,7 +10,8 @@ import zingg.common.client.ITrainingDataModel; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.cols.ZidAndFieldDefSelector; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; public abstract class Labeller extends ZinggBase { @@ -24,7 +25,7 @@ public abstract class Labeller extends ZinggBase { protected ILabelDataViewHelper labelDataViewHelper; public Labeller() { - setZinggOptions(ZinggOptions.LABEL); + setZinggOption(ZinggOptions.LABEL); } public void execute() throws ZinggClientException { @@ -79,7 +80,8 @@ public ZFrame processRecordsCli(ZFrame lines) throws ZinggClientE ); lines = lines.cache(); - List displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args); +// List displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args); + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise()); //have to introduce as snowframe can not handle row.getAs with column //name and row and lines are out of order for the code to work properly //snow getAsString expects row to have same struc as dataframe which is @@ -104,7 +106,8 @@ public ZFrame processRecordsCli(ZFrame lines) throws ZinggClientE msg2 = getLabelDataViewHelper().getMsg2(prediction, score); //String msgHeader = msg1 + msg2; - selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2); +// selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2); + selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), msg1, msg2); getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT); getLabelDataViewHelper().printMarkedRecordsStat( getTrainingDataModel().getPositivePairsCount(), @@ -158,7 +161,7 @@ int readCliInput() { @Override public ITrainingDataModel getTrainingDataModel() { if (trainingDataModel==null) { - this.trainingDataModel = new TrainingDataModel(getContext(), getZinggOptions(), getClientOptions()); + this.trainingDataModel = new TrainingDataModel(getContext(), getClientOptions()); } return trainingDataModel; } @@ -170,7 +173,7 @@ public void setTrainingDataModel(ITrainingDataModel trainingDataMode @Override public ILabelDataViewHelper getLabelDataViewHelper() { if(labelDataViewHelper==null) { - labelDataViewHelper = new LabelDataViewHelper(getContext(), getZinggOptions(), getClientOptions()); + labelDataViewHelper = new LabelDataViewHelper(getContext(), getClientOptions()); } return labelDataViewHelper; } diff --git a/common/core/src/main/java/zingg/common/core/executor/Linker.java b/common/core/src/main/java/zingg/common/core/executor/Linker.java index 797bb59bc..c271a2161 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Linker.java +++ b/common/core/src/main/java/zingg/common/core/executor/Linker.java @@ -5,36 +5,46 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; 
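// --- review sketch, not part of the patch: every executor constructor in this
// diff now registers exactly one option via setZinggOption(...) from the relocated
// zingg.common.client.options package. A hypothetical new executor would follow
// the same pattern:

public abstract class Deduper<S, D, R, C, T> extends ZinggBase<S, D, R, C, T> {
	private static final long serialVersionUID = 1L;

	public Deduper() {
		setZinggOption(ZinggOptions.MATCH);
	}
}
// ---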
-import zingg.common.client.ZinggOptions; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; -import zingg.common.client.util.ColValues; +import zingg.common.core.filter.PredictionFilter; +import zingg.common.core.pairs.SelfPairBuilderSourceSensitive; public abstract class Linker extends Matcher { + private static final long serialVersionUID = 1L; protected static String name = "zingg.Linker"; public static final Log LOG = LogFactory.getLog(Linker.class); public Linker() { - setZinggOptions(ZinggOptions.LINK); + setZinggOption(ZinggOptions.LINK); } - - protected ZFrame getBlocks(ZFrame blocked, ZFrame bAll) throws Exception{ - // THIS LOG IS NEEDED FOR PLAN CALCULATION USING COUNT, DO NOT REMOVE - LOG.info("in getBlocks, blocked count is " + blocked.count()); - return getDSUtil().joinWithItselfSourceSensitive(blocked, ColName.HASH_COL, args).cache(); - } - - protected ZFrame selectColsFromBlocked(ZFrame blocked) { + + @Override + public ZFrame selectColsFromBlocked(ZFrame blocked) { return blocked; } + + @Override + protected ZFrame getActualDupes(ZFrame blocked, ZFrame testData) throws Exception, ZinggClientException{ + PredictionFilter predictionFilter = new PredictionFilter(); + SelfPairBuilderSourceSensitive iPairBuilder = getPairBuilderSourceSensitive(); + return getActualDupes(blocked, testData,predictionFilter, iPairBuilder, null); + } + protected SelfPairBuilderSourceSensitive getPairBuilderSourceSensitive() { + return new SelfPairBuilderSourceSensitive (getDSUtil(),args); + } + + @Override public void writeOutput(ZFrame sampleOrginal, ZFrame dupes) throws ZinggClientException { try { // input dupes are pairs /// pick ones according to the threshold by user - ZFrame dupesActual = getDupesActualForGraph(dupes); + PredictionFilter predictionFilter = new PredictionFilter(); + ZFrame dupesActual = predictionFilter.filter(dupes); // all clusters consolidated in one place if (args.getOutput() != null) { @@ -46,19 +56,11 @@ public void writeOutput(ZFrame sampleOrginal, ZFrame dupes) throws ZFramedupes2 = getDSUtil().alignLinked(dupesActual, args); dupes2 = getDSUtil().postprocessLinked(dupes2, sampleOrginal); LOG.debug("uncertain output schema is " + dupes2.showSchema()); - getPipeUtil().write(dupes2, args, args.getOutput()); + getPipeUtil().write(dupes2, args.getOutput()); } } catch (Exception e) { e.printStackTrace(); } } - protected ZFrame getDupesActualForGraph(ZFrame dupes) { - ZFrame dupesActual = dupes - .filter(dupes.equalTo(ColName.PREDICTION_COL, ColValues.IS_MATCH_PREDICTION)); - return dupesActual; - } - - - } diff --git a/common/core/src/main/java/zingg/common/core/executor/Matcher.java b/common/core/src/main/java/zingg/common/core/executor/Matcher.java index 95e6df68b..88a16cd10 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Matcher.java +++ b/common/core/src/main/java/zingg/common/core/executor/Matcher.java @@ -8,12 +8,17 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.cols.PredictionColsSelector; +import zingg.common.client.cols.ZidAndFieldDefSelector; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; -import zingg.common.client.util.ColValues; import zingg.common.core.block.Canopy; import zingg.common.core.block.Tree; +import zingg.common.core.filter.IFilter; +import zingg.common.core.filter.PredictionFilter; import zingg.common.core.model.Model; +import 
zingg.common.core.pairs.IPairBuilder; +import zingg.common.core.pairs.SelfPairBuilder; import zingg.common.core.preprocess.StopWordsRemover; import zingg.common.core.util.Analytics; import zingg.common.core.util.Metric; @@ -24,55 +29,32 @@ public abstract class Matcher extends ZinggBase{ protected static String name = "zingg.Matcher"; public static final Log LOG = LogFactory.getLog(Matcher.class); - public Matcher() { - setZinggOptions(ZinggOptions.MATCH); + setZinggOption(ZinggOptions.MATCH); } - protected ZFrame getTestData() throws ZinggClientException{ + public ZFrame getTestData() throws ZinggClientException{ ZFrame data = getPipeUtil().read(true, true, args.getNumPartitions(), true, args.getData()); return data; } - protected ZFrame getFieldDefColumnsDS(ZFrame testDataOriginal) { - return getDSUtil().getFieldDefColumnsDS(testDataOriginal, args, true); + public ZFrame getFieldDefColumnsDS(ZFrame testDataOriginal) { + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition()); + return testDataOriginal.select(zidAndFieldDefSelector.getCols()); +// return getDSUtil().getFieldDefColumnsDS(testDataOriginal, args, true); } - protected ZFrame getBlocked( ZFrame testData) throws Exception, ZinggClientException{ + public ZFrame getBlocked( ZFrame testData) throws Exception, ZinggClientException{ LOG.debug("Blocking model file location is " + args.getBlockFile()); Tree> tree = getBlockingTreeUtil().readBlockingTree(args); ZFrame blocked = getBlockingTreeUtil().getBlockHashes(testData, tree); - ZFrame blocked1 = blocked.repartition(args.getNumPartitions(), blocked.col(ColName.HASH_COL)); //.cache(); + ZFrame blocked1 = blocked.repartition(args.getNumPartitions(), blocked.col(ColName.HASH_COL)).cache(); return blocked1; } - - - protected ZFrame getBlocks(ZFrameblocked) throws Exception{ - return getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache(); - } - - protected ZFrame getBlocks(ZFrameblocked, ZFramebAll) throws Exception{ - ZFramejoinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache(); - /*ZFramejoinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL) - .selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid"); - */ - //joinH.show(); - joinH = joinH.filter(joinH.gt(ColName.ID_COL)); - LOG.warn("Num comparisons " + joinH.count()); - joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.ID_COL)); - bAll = bAll.repartition(args.getNumPartitions(), bAll.col(ColName.ID_COL)); - joinH = joinH.joinOnCol(bAll, ColName.ID_COL); - LOG.warn("Joining with actual values"); - //joinH.show(); - bAll = getDSUtil().getPrefixedColumnsDS(bAll); - //bAll.show(); - joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.COL_PREFIX + ColName.ID_COL)); - joinH = joinH.joinOnCol(bAll, ColName.COL_PREFIX + ColName.ID_COL); - LOG.warn("Joining again with actual values"); - //joinH.show(); - return joinH; + public ZFrame getPairs(ZFrameblocked, ZFramebAll, IPairBuilder iPairBuilder) throws Exception{ + return iPairBuilder.getPairs(blocked, bAll); } protected abstract Model getModel() throws ZinggClientException; @@ -94,11 +76,22 @@ protected ZFrame predictOnBlocks(ZFrameblocks) throws Exception, Z } protected ZFrame getActualDupes(ZFrame blocked, ZFrame testData) throws Exception, ZinggClientException{ - ZFrame blocks = getBlocks(selectColsFromBlocked(blocked), testData); - ZFramedupesActual = predictOnBlocks(blocks); - return getDupesActualForGraph(dupesActual); + PredictionFilter 
predictionFilter = new PredictionFilter(); + SelfPairBuilder iPairBuilder = new SelfPairBuilder (getDSUtil(),args); + return getActualDupes(blocked, testData,predictionFilter, iPairBuilder,new PredictionColsSelector()); } + protected ZFrame getActualDupes(ZFrame blocked, ZFrame testData, + IFilter predictionFilter, IPairBuilder iPairBuilder, PredictionColsSelector colsSelector) throws Exception, ZinggClientException{ + ZFrame blocks = getPairs(selectColsFromBlocked(blocked), testData, iPairBuilder); + ZFramedupesActual = predictOnBlocks(blocks); + ZFrame filteredData = predictionFilter.filter(dupesActual); + if(colsSelector!=null) { + filteredData = filteredData.select(colsSelector.getCols()); + } + return filteredData; + } + @Override public void execute() throws ZinggClientException { try { @@ -149,7 +142,7 @@ public void writeOutput( ZFrame blocked, ZFrame dupesActual) th //all clusters consolidated in one place if (args.getOutput() != null) { ZFrame graphWithScores = getOutput(blocked, dupesActual); - getPipeUtil().write(graphWithScores, args, args.getOutput()); + getPipeUtil().write(graphWithScores, args.getOutput()); } } catch(Exception e) { @@ -160,7 +153,7 @@ public void writeOutput( ZFrame blocked, ZFrame dupesActual) th - protected ZFrame getOutput(ZFrame blocked, ZFrame dupesActual) throws Exception { + protected ZFrame getOutput(ZFrame blocked, ZFrame dupesActual) throws ZinggClientException, Exception { //-1 is initial suggestion, 1 is add, 0 is deletion, 2 is unsure /*blocked = blocked.drop(ColName.HASH_COL); blocked = blocked.drop(ColName.SOURCE_COL); @@ -269,25 +262,7 @@ protected ZFrame getGraphWithScores(ZFrame graph, ZFrame getDupesActualForGraph(ZFramedupes) { - dupes = selectColsFromDupes(dupes); - LOG.debug("dupes al"); - if (LOG.isDebugEnabled()) dupes.show(); - return dupes.filter(dupes.equalTo(ColName.PREDICTION_COL,ColValues.IS_MATCH_PREDICTION)); - } - - protected ZFrame selectColsFromDupes(ZFramedupesActual) { - List cols = new ArrayList(); - cols.add(dupesActual.col(ColName.ID_COL)); - cols.add(dupesActual.col(ColName.COL_PREFIX + ColName.ID_COL)); - cols.add(dupesActual.col(ColName.PREDICTION_COL)); - cols.add(dupesActual.col(ColName.SCORE_COL)); - ZFrame dupesActual1 = dupesActual.select(cols); //.cache(); - return dupesActual1; - } - protected abstract StopWordsRemover getStopWords(); - } diff --git a/common/core/src/main/java/zingg/common/core/executor/Recommender.java b/common/core/src/main/java/zingg/common/core/executor/Recommender.java index 7119a1182..cc870c41a 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Recommender.java +++ b/common/core/src/main/java/zingg/common/core/executor/Recommender.java @@ -4,7 +4,7 @@ import org.apache.commons.logging.LogFactory; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.options.ZinggOptions; import zingg.common.core.recommender.StopWordsRecommender; public abstract class Recommender extends ZinggBase { @@ -13,7 +13,7 @@ public abstract class Recommender extends ZinggBase { public static final Log LOG = LogFactory.getLog(Recommender.class); public Recommender() { - setZinggOptions(ZinggOptions.RECOMMEND); + setZinggOption(ZinggOptions.RECOMMEND); } public void execute() throws ZinggClientException { diff --git a/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java b/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java index e6521ec21..b4fdfc97e 100644 --- 
a/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java +++ b/common/core/src/main/java/zingg/common/core/executor/TrainMatcher.java @@ -5,8 +5,7 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; public abstract class TrainMatcher extends ZinggBase{ @@ -18,15 +17,15 @@ public abstract class TrainMatcher extends ZinggBase{ protected Matcher matcher; public TrainMatcher() { - setZinggOptions(ZinggOptions.TRAIN_MATCH); + setZinggOption(ZinggOptions.TRAIN_MATCH); } @Override - public void init(IArguments args, IZinggLicense license) + public void init(IArguments args, S s) throws ZinggClientException { - trainer.init(args, license); - matcher.init(args, license); - super.init(args, license); + trainer.init(args,s); + matcher.init(args,s); + super.init(args,s); } @Override diff --git a/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java b/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java index 018db64cc..3c2919688 100644 --- a/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java +++ b/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java @@ -1,11 +1,14 @@ package zingg.common.core.executor; +import java.util.Arrays; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.cols.ZidAndFieldDefSelector; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.pipe.Pipe; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; @@ -22,7 +25,7 @@ public abstract class TrainingDataFinder extends ZinggBase public TrainingDataFinder() { - setZinggOptions(ZinggOptions.FIND_TRAINING_DATA); + setZinggOption(ZinggOptions.FIND_TRAINING_DATA); } public ZFrame getTraining() throws ZinggClientException { @@ -79,7 +82,7 @@ public void execute() throws ZinggClientException { if (negPairs!= null) negPairs = negPairs.cache(); //create random samples for blocking ZFrame sampleOrginal = data.sample(false, args.getLabelDataSampleSize()).repartition(args.getNumPartitions()).cache(); - sampleOrginal = getDSUtil().getFieldDefColumnsDS(sampleOrginal, args, true); + sampleOrginal = getFieldDefColumnsDS(sampleOrginal); LOG.info("Preprocessing DS for stopWords"); ZFrame sample = getStopWords().preprocessForStopWords(sampleOrginal); @@ -155,7 +158,7 @@ public void writeUncertain(ZFrame dupesActual, ZFrame sampleOrgina dupes1 = getDSUtil().postprocess(dupes1, sampleOrginal); ZFrame dupes2 = dupes1.orderBy(ColName.CLUSTER_COLUMN); //LOG.debug("uncertain output schema is " + dupes2.schema()); - getPipeUtil().write(dupes2 , args, getUnmarkedLocation()); + getPipeUtil().write(dupes2 , getUnmarkedLocation()); //PipeUtil.write(jdbc, massageForJdbc(dupes2.cache()) , args, ctx); } @@ -188,7 +191,7 @@ public ZFrame getPositiveSamples(ZFrame data) throws Exception { } ZFrame posSample = data.sample(false, args.getLabelDataSampleSize()); //select only those columns which are mentioned in the field definitions - posSample = getDSUtil().getFieldDefColumnsDS(posSample, args, true); + posSample = getFieldDefColumnsDS(posSample); if (LOG.isDebugEnabled()) { LOG.debug("Sampled " + posSample.count()); } @@ -202,8 +205,13 @@ 
public ZFrame getPositiveSamples(ZFrame data) throws Exception { return posPairs; } + protected ZFrame getFieldDefColumnsDS(ZFrame data) { + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition()); + String[] cols = zidAndFieldDefSelector.getCols(); + return data.select(cols); + //return getDSUtil().getFieldDefColumnsDS(data, args, true); + } + protected abstract StopWordsRemover getStopWords(); - - } diff --git a/common/core/src/main/java/zingg/common/core/executor/TrainingDataModel.java b/common/core/src/main/java/zingg/common/core/executor/TrainingDataModel.java index c11f75fc0..6cde5bce1 100644 --- a/common/core/src/main/java/zingg/common/core/executor/TrainingDataModel.java +++ b/common/core/src/main/java/zingg/common/core/executor/TrainingDataModel.java @@ -8,11 +8,11 @@ import zingg.common.client.ITrainingDataModel; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.pipe.Pipe; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; -import zingg.common.core.Context; +import zingg.common.core.context.Context; public class TrainingDataModel extends ZinggBase implements ITrainingDataModel{ @@ -21,9 +21,8 @@ public class TrainingDataModel extends ZinggBase imple private long positivePairsCount, negativePairsCount, notSurePairsCount; private long totalCount; - public TrainingDataModel(Context context, ZinggOptions zinggOptions, ClientOptions clientOptions) { + public TrainingDataModel(Context context, ClientOptions clientOptions) { setContext(context); - setZinggOptions(zinggOptions); setClientOptions(clientOptions); setName(this.getClass().getName()); } @@ -79,7 +78,7 @@ public void writeLabelledOutput(ZFrame records, IArguments args, Pipe p) LOG.warn("No labelled records"); return; } - getPipeUtil().write(records, args,p); + getPipeUtil().write(records, p); } public Pipe getOutputPipe(IArguments args) { @@ -120,4 +119,5 @@ public long getTotalCount() { + } diff --git a/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java b/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java index efbd2fada..b9f07c32f 100644 --- a/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java +++ b/common/core/src/main/java/zingg/common/core/executor/ZinggBase.java @@ -13,19 +13,19 @@ import zingg.common.client.MatchType; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOption; +import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; -import zingg.common.core.Context; +import zingg.common.client.util.DSUtil; +import zingg.common.client.util.PipeUtilBase; +import zingg.common.core.context.Context; import zingg.common.core.util.Analytics; import zingg.common.core.util.BlockingTreeUtil; -import zingg.common.core.util.DSUtil; import zingg.common.core.util.GraphUtil; import zingg.common.core.util.HashUtil; import zingg.common.core.util.Metric; import zingg.common.core.util.ModelUtil; -import zingg.common.core.util.PipeUtilBase; public abstract class ZinggBase implements Serializable, IZingg { @@ -34,7 +34,7 @@ public abstract class ZinggBase implements Serializable, IZingg context; protected String name; - 
protected ZinggOptions zinggOptions; + protected ZinggOption zinggOption; protected long startTime; protected ClientOptions clientOptions; @@ -62,13 +62,12 @@ public ZinggBase() { } - - public void init(IArguments args, IZinggLicense license) + @Override + public void init(IArguments args, S session) throws ZinggClientException { startTime = System.currentTimeMillis(); this.args = args; - - } + } public void setSession(S s) { @@ -85,8 +84,10 @@ public void postMetrics() { collectMetrics); Analytics.track(Metric.DATA_FORMAT, getPipeUtil().getPipesAsString(args.getData()), collectMetrics); Analytics.track(Metric.OUTPUT_FORMAT, getPipeUtil().getPipesAsString(args.getOutput()), collectMetrics); + + Analytics.track(Metric.MODEL_ID, args.getModelId(), collectMetrics); - Analytics.track(Metric.ZINGG_VERSION, "0.4.0", collectMetrics); + Analytics.track(Metric.ZINGG_VERSION, "0.4.1-SNAPSHOT", collectMetrics); Analytics.trackEnvProp(Metric.DATABRICKS_RUNTIME_VERSION, collectMetrics); Analytics.trackEnvProp(Metric.DB_INSTANCE_TYPE, collectMetrics); Analytics.trackEnvProp(Metric.JAVA_HOME, collectMetrics); @@ -96,8 +97,8 @@ public void postMetrics() { //Analytics.trackEnvProp(Metric.USER_NAME, collectMetrics); //Analytics.trackEnvProp(Metric.USER_HOME, collectMetrics); Analytics.trackDomain(Metric.DOMAIN, collectMetrics); - Analytics.track(Metric.ZINGG_VERSION, "0.4.0", collectMetrics); - Analytics.postEvent(zinggOptions.getValue(), collectMetrics); + Analytics.track(Metric.ZINGG_VERSION, "0.4.1-SNAPSHOT", collectMetrics); + Analytics.postEvent(zinggOption.getName(), collectMetrics); } public IArguments getArgs() { @@ -121,17 +122,20 @@ public void setContext(Context source) { public void setName(String name) { this.name = name; } - public void setZinggOptions(ZinggOptions zinggOptions) { - this.zinggOptions = zinggOptions; + + public void setZinggOption(ZinggOption zinggOptions) { + this.zinggOption = zinggOptions; } + public String getName() { return name; } + /* public ZinggOptions getZinggOptions() { return zinggOptions; - } + }*/ public ZFrame getMarkedRecords() { try { diff --git a/common/core/src/main/java/zingg/common/core/feature/DateFeature.java b/common/core/src/main/java/zingg/common/core/feature/DateFeature.java index 7809c3b6f..230d81972 100644 --- a/common/core/src/main/java/zingg/common/core/feature/DateFeature.java +++ b/common/core/src/main/java/zingg/common/core/feature/DateFeature.java @@ -4,10 +4,14 @@ import zingg.common.client.FieldDefinition; import zingg.common.client.MatchType; +import zingg.common.core.similarity.function.CheckNullFunction; import zingg.common.core.similarity.function.DateSimilarityFunction; +import zingg.common.core.similarity.function.SimilarityFunctionExact; public class DateFeature extends BaseFeature { + private static final long serialVersionUID = 1L; + public DateFeature() { } @@ -28,6 +32,12 @@ public void init(FieldDefinition f) { if (f.getMatchType().contains(MatchType.FUZZY)) { addSimFunction(new DateSimilarityFunction()); } + if (f.getMatchType().contains(MatchType.EXACT)) { + addSimFunction(new SimilarityFunctionExact("DateSimilarityFunctionExact")); + } + if (f.getMatchType().contains(MatchType.NULL_OR_BLANK)) { + addSimFunction(new CheckNullFunction("CheckNullFunctionDate")); + } } } diff --git a/common/core/src/main/java/zingg/common/core/feature/IntFeature.java b/common/core/src/main/java/zingg/common/core/feature/IntFeature.java index a178ba5ea..a28fa2833 100644 --- a/common/core/src/main/java/zingg/common/core/feature/IntFeature.java +++ 
b/common/core/src/main/java/zingg/common/core/feature/IntFeature.java
@@ -2,9 +2,13 @@
 import zingg.common.client.FieldDefinition;
 import zingg.common.client.MatchType;
+import zingg.common.core.similarity.function.CheckNullFunction;
 import zingg.common.core.similarity.function.IntegerSimilarityFunction;
+import zingg.common.core.similarity.function.SimilarityFunctionExact;
 
 public class IntFeature extends BaseFeature<Integer> {
+	private static final long serialVersionUID = 1L;
+
 	public IntFeature() {
 	}
@@ -14,6 +18,12 @@ public void init(FieldDefinition newParam) {
 		if (newParam.getMatchType().contains(MatchType.FUZZY)) {
 			addSimFunction(new IntegerSimilarityFunction());
 		}
+		if (newParam.getMatchType().contains(MatchType.EXACT)) {
+			addSimFunction(new SimilarityFunctionExact<Integer>("IntegerSimilarityFunctionExact"));
+		}
+		if (newParam.getMatchType().contains(MatchType.NULL_OR_BLANK)) {
+			addSimFunction(new CheckNullFunction<Integer>("CheckNullFunctionInt"));
+		}
 	}
 }
diff --git a/common/core/src/main/java/zingg/common/core/feature/LongFeature.java b/common/core/src/main/java/zingg/common/core/feature/LongFeature.java
index 8c3a3c5b4..81bf7261a 100644
--- a/common/core/src/main/java/zingg/common/core/feature/LongFeature.java
+++ b/common/core/src/main/java/zingg/common/core/feature/LongFeature.java
@@ -2,7 +2,9 @@
 import zingg.common.client.FieldDefinition;
 import zingg.common.client.MatchType;
+import zingg.common.core.similarity.function.CheckNullFunction;
 import zingg.common.core.similarity.function.LongSimilarityFunction;
+import zingg.common.core.similarity.function.SimilarityFunctionExact;
 
 public class LongFeature extends BaseFeature<Long> {
 	private static final long serialVersionUID = 1L;
@@ -16,6 +18,12 @@ public void init(FieldDefinition newParam) {
 		if (newParam.getMatchType().contains(MatchType.FUZZY)) {
 			addSimFunction(new LongSimilarityFunction());
 		}
+		if (newParam.getMatchType().contains(MatchType.EXACT)) {
+			addSimFunction(new SimilarityFunctionExact<Long>("LongSimilarityFunctionExact"));
+		}
+		if (newParam.getMatchType().contains(MatchType.NULL_OR_BLANK)) {
+			addSimFunction(new CheckNullFunction<Long>("CheckNullFunctionLong"));
+		}
 	}
 }
diff --git a/common/core/src/main/java/zingg/common/core/filter/IFilter.java b/common/core/src/main/java/zingg/common/core/filter/IFilter.java
new file mode 100644
index 000000000..70d6b8eec
--- /dev/null
+++ b/common/core/src/main/java/zingg/common/core/filter/IFilter.java
@@ -0,0 +1,9 @@
+package zingg.common.core.filter;
+
+import zingg.common.client.ZFrame;
+
+public interface IFilter<D, R, C> {
+
+	public ZFrame<D, R, C> filter(ZFrame<D, R, C> df);
+
+}
diff --git a/common/core/src/main/java/zingg/common/core/filter/PredictionFilter.java b/common/core/src/main/java/zingg/common/core/filter/PredictionFilter.java
new file mode 100644
index 000000000..8affb1f76
--- /dev/null
+++ b/common/core/src/main/java/zingg/common/core/filter/PredictionFilter.java
@@ -0,0 +1,29 @@
+package zingg.common.core.filter;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import zingg.common.client.ZFrame;
+import zingg.common.client.util.ColName;
+import zingg.common.client.util.ColValues;
+
+public class PredictionFilter<D, R, C> implements IFilter<D, R, C> {
+
+	public static final Log LOG = LogFactory.getLog(PredictionFilter.class);
+
+	public PredictionFilter() {
+		super();
+	}
+
+	@Override
+	public ZFrame<D, R, C> filter(ZFrame<D, R, C> dupes) {
+		dupes = filterMatches(dupes);
+		return dupes;
+	}
+
+	protected ZFrame<D, R, C> filterMatches(ZFrame<D, R, C> dupes) {
+		dupes =
dupes.filter(dupes.equalTo(ColName.PREDICTION_COL,ColValues.IS_MATCH_PREDICTION)); + return dupes; + } + +} diff --git a/common/core/src/main/java/zingg/common/core/hash/FirstChars.java b/common/core/src/main/java/zingg/common/core/hash/FirstChars.java index 116b67cc9..78ad3042d 100644 --- a/common/core/src/main/java/zingg/common/core/hash/FirstChars.java +++ b/common/core/src/main/java/zingg/common/core/hash/FirstChars.java @@ -32,7 +32,7 @@ public String call(String field) { r = field.trim().substring(0, endIndex); } } - LOG.debug("Applying " + this.getName() + " on " + field + " and returning " + r); + //LOG.debug("Applying " + this.getName() + " on " + field + " and returning " + r); return r; } diff --git a/common/core/src/main/java/zingg/common/core/model/Model.java b/common/core/src/main/java/zingg/common/core/model/Model.java index ef086514b..461a4103d 100644 --- a/common/core/src/main/java/zingg/common/core/model/Model.java +++ b/common/core/src/main/java/zingg/common/core/model/Model.java @@ -4,31 +4,59 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import zingg.common.client.FieldDefinition; + +import zingg.common.client.Arguments; import zingg.common.client.ZFrame; -import zingg.common.core.feature.Feature; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.ColName; public abstract class Model implements Serializable { public static final Log LOG = LogFactory.getLog(Model.class); - //private Map featurers; + protected List columnsAdded = new ArrayList(); + protected S session; + + public void setSession(S s){ + this.session = s; + } - public Model() { + public S getSession(){ + return session; + } + + public Model(S s){ + this.session = s; } - public abstract void register(S spark) ; + public abstract void register() ; + + + protected String getColumnName(String fieldName, String fnName, int count) { + return ColName.SIM_COL + count; + } + + + public List getColumnsAdded() { + return columnsAdded; + } + + + public void setColumnsAdded(List columnsAdded) { + this.columnsAdded = columnsAdded; + } + public static double[] getGrid(double begin, double end, double jump, boolean isMultiple) { List alphaList = new ArrayList(); if (isMultiple) { for (double alpha =begin; alpha <= end; alpha *= jump) { alphaList.add(alpha); } + } else { for (double alpha =begin; alpha <= end; alpha += jump) { @@ -42,18 +70,31 @@ public static double[] getGrid(double begin, double end, double jump, boolean is return retArr; } - public abstract void fit(ZFrame pos, ZFrame neg); + public abstract void fit(ZFrame pos, ZFrame neg) throws ZinggClientException; public abstract void load(String path); + + protected abstract ZFrame fitCore(ZFrame pos, ZFrame neg); + + public abstract ZFrame predict(ZFrame data) throws ZinggClientException; + public abstract ZFrame predict(ZFrame data, boolean isDrop) throws ZinggClientException ; - public abstract ZFrame predict(ZFrame data); - - public abstract ZFrame predict(ZFrame data, boolean isDrop) ; + //this will do the prediction but not drop the columns + protected abstract ZFrame predictCore(ZFrame data); public abstract void save(String path) throws IOException; public abstract ZFrame transform(ZFrame input); + + public ZFrame dropFeatureCols(ZFrame predictWithFeatures, boolean isDrop){ + if (isDrop) { + ZFrame returnDS = predictWithFeatures.drop(columnsAdded.toArray(new 
String[columnsAdded.size()])); + //LOG.debug("Return schema after dropping additional columns is " + returnDS.schema()); + return returnDS; //new SparkFrame(returnDS); + } + return predictWithFeatures; + } } diff --git a/common/core/src/main/java/zingg/common/core/pairs/IPairBuilder.java b/common/core/src/main/java/zingg/common/core/pairs/IPairBuilder.java new file mode 100644 index 000000000..235483818 --- /dev/null +++ b/common/core/src/main/java/zingg/common/core/pairs/IPairBuilder.java @@ -0,0 +1,9 @@ +package zingg.common.core.pairs; + +import zingg.common.client.ZFrame; + +public interface IPairBuilder { + + public ZFrame getPairs(ZFrameblocked, ZFramebAll) throws Exception; + +} diff --git a/common/core/src/main/java/zingg/common/core/pairs/SelfPairBuilder.java b/common/core/src/main/java/zingg/common/core/pairs/SelfPairBuilder.java new file mode 100644 index 000000000..2e9e261db --- /dev/null +++ b/common/core/src/main/java/zingg/common/core/pairs/SelfPairBuilder.java @@ -0,0 +1,55 @@ +package zingg.common.core.pairs; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.IArguments; +import zingg.common.client.ZFrame; +import zingg.common.client.util.ColName; +import zingg.common.client.util.DSUtil; + +public class SelfPairBuilder implements IPairBuilder { + + protected DSUtil dsUtil; + public static final Log LOG = LogFactory.getLog(SelfPairBuilder.class); + protected IArguments args; + + public SelfPairBuilder(DSUtil dsUtil, IArguments args) { + this.dsUtil = dsUtil; + this.args = args; + } + + @Override + public ZFrame getPairs(ZFrameblocked, ZFramebAll) throws Exception { + ZFramejoinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache(); + /*ZFramejoinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL) + .selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid"); + */ + //joinH.show(); + joinH = joinH.filter(joinH.gt(ColName.ID_COL)); + if (LOG.isDebugEnabled()) LOG.debug("Num comparisons " + joinH.count()); + joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.ID_COL)); + bAll = bAll.repartition(args.getNumPartitions(), bAll.col(ColName.ID_COL)); + joinH = joinH.joinOnCol(bAll, ColName.ID_COL); + LOG.warn("Joining with actual values"); + //joinH.show(); + bAll = getDSUtil().getPrefixedColumnsDS(bAll); + //bAll.show(); + joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.COL_PREFIX + ColName.ID_COL)); + joinH = joinH.joinOnCol(bAll, ColName.COL_PREFIX + ColName.ID_COL); + LOG.warn("Joining again with actual values"); + //joinH.show(); + return joinH; + } + + public DSUtil getDSUtil() { + return dsUtil; + } + + public void setDSUtil(DSUtil dsUtil) { + this.dsUtil = dsUtil; + } + + + +} diff --git a/common/core/src/main/java/zingg/common/core/pairs/SelfPairBuilderSourceSensitive.java b/common/core/src/main/java/zingg/common/core/pairs/SelfPairBuilderSourceSensitive.java new file mode 100644 index 000000000..293eb162c --- /dev/null +++ b/common/core/src/main/java/zingg/common/core/pairs/SelfPairBuilderSourceSensitive.java @@ -0,0 +1,26 @@ +package zingg.common.core.pairs; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.IArguments; +import zingg.common.client.ZFrame; +import zingg.common.client.util.ColName; +import zingg.common.client.util.DSUtil; + +public class SelfPairBuilderSourceSensitive extends SelfPairBuilder { + + public static final Log LOG = 
LogFactory.getLog(SelfPairBuilderSourceSensitive.class); + + public SelfPairBuilderSourceSensitive(DSUtil dsUtil, IArguments args) { + super(dsUtil, args); + } + + @Override + public ZFrame getPairs(ZFrame blocked, ZFrame bAll) throws Exception{ + // THIS LOG IS NEEDED FOR PLAN CALCULATION USING COUNT, DO NOT REMOVE + LOG.info("in getBlocks, blocked count is " + blocked.count()); + return getDSUtil().joinWithItselfSourceSensitive(blocked, ColName.HASH_COL, args).cache(); + } + +} diff --git a/common/core/src/main/java/zingg/common/core/preprocess/StopWords.java b/common/core/src/main/java/zingg/common/core/preprocess/StopWords.java index ea42b7401..8e1511489 100644 --- a/common/core/src/main/java/zingg/common/core/preprocess/StopWords.java +++ b/common/core/src/main/java/zingg/common/core/preprocess/StopWords.java @@ -7,7 +7,7 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; import zingg.common.client.util.ColName; -import zingg.common.core.util.PipeUtilBase; +import zingg.common.client.util.PipeUtilBase; public class StopWords { diff --git a/common/core/src/main/java/zingg/common/core/preprocess/StopWordsRemover.java b/common/core/src/main/java/zingg/common/core/preprocess/StopWordsRemover.java index b45c6d250..9742426c4 100644 --- a/common/core/src/main/java/zingg/common/core/preprocess/StopWordsRemover.java +++ b/common/core/src/main/java/zingg/common/core/preprocess/StopWordsRemover.java @@ -13,8 +13,8 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; import zingg.common.client.util.ColName; -import zingg.common.core.Context; -import zingg.common.core.util.PipeUtilBase; +import zingg.common.client.util.PipeUtilBase; +import zingg.common.core.context.Context; public abstract class StopWordsRemover implements Serializable{ @@ -65,7 +65,7 @@ protected String getStopWordColumnName(ZFrame stopWords) { } protected List getWordList(ZFrame stopWords, String stopWordColumn) { - return stopWords.select(stopWordColumn).collectAsListOfStrings(); + return stopWords.select(stopWordColumn).collectFirstColumn(); } /** diff --git a/common/core/src/main/java/zingg/common/core/recommender/StopWordsRecommender.java b/common/core/src/main/java/zingg/common/core/recommender/StopWordsRecommender.java index 9e29c6ba7..b09ac7556 100644 --- a/common/core/src/main/java/zingg/common/core/recommender/StopWordsRecommender.java +++ b/common/core/src/main/java/zingg/common/core/recommender/StopWordsRecommender.java @@ -10,7 +10,7 @@ import zingg.common.client.ZinggClientException; import zingg.common.client.ZinggClientException; import zingg.common.client.util.ColName; -import zingg.common.core.Context; +import zingg.common.core.context.Context; public abstract class StopWordsRecommender { private static final String REGEX_WHITESPACE = "\\s+"; @@ -49,9 +49,9 @@ public void createStopWordsDocuments(ZFrame data, String fieldName) throw if(Arrays.asList(data.columns()).contains(args.getColumn())) { String filenameCSV = args.getStopWordsDir() + fieldName; data = findStopWords(data, fieldName); - context.getPipeUtil().write(data, args, context.getPipeUtil().getStopWordsPipe(args, filenameCSV)); + context.getPipeUtil().write(data, context.getPipeUtil().getStopWordsPipe(args, filenameCSV)); } else { - LOG.info("An invalid column name - " + args.getColumn() + " entered. Please provide valid column name."); + LOG.info("An invalid column name - " + args.getColumn() + " entered. 
Please provide a valid column name as per the config (column names are case sensitive)."); + } + } else { + LOG.info("Please provide '--column ' option at command line to generate stop words for that column."); diff --git a/common/core/src/main/java/zingg/common/core/similarity/function/CheckNullFunction.java b/common/core/src/main/java/zingg/common/core/similarity/function/CheckNullFunction.java new file mode 100644 index 000000000..9a5ffc7f4 --- /dev/null +++ b/common/core/src/main/java/zingg/common/core/similarity/function/CheckNullFunction.java @@ -0,0 +1,26 @@ +package zingg.common.core.similarity.function; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public class CheckNullFunction extends SimFunction { + + private static final long serialVersionUID = 1L; + public static final Log LOG = LogFactory + .getLog(CheckNullFunction.class); + + public CheckNullFunction(String name) { + super(name); + } + + @Override + public Double call(T first, T second) { + if (first != null && second != null) { + return 1d; + } + return 0d; + } + + + +} diff --git a/common/core/src/main/java/zingg/common/core/similarity/function/SimilarityFunctionExact.java b/common/core/src/main/java/zingg/common/core/similarity/function/SimilarityFunctionExact.java new file mode 100644 index 000000000..af1100eec --- /dev/null +++ b/common/core/src/main/java/zingg/common/core/similarity/function/SimilarityFunctionExact.java @@ -0,0 +1,21 @@ +package zingg.common.core.similarity.function; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public class SimilarityFunctionExact extends SimFunction { + private static final long serialVersionUID = 1L; + public static final Log LOG = LogFactory + .getLog(SimilarityFunctionExact.class); + + public SimilarityFunctionExact(String name) { + super(name); + } + + @Override + public Double call(T first, T second) { + if (first == null || second == null) return 1d; + double score = first.equals(second) ? 
1d : 0d; + return score; + } +} diff --git a/common/core/src/main/java/zingg/common/core/util/BlockingTreeUtil.java b/common/core/src/main/java/zingg/common/core/util/BlockingTreeUtil.java index d2bb54eb9..bab1f04fa 100644 --- a/common/core/src/main/java/zingg/common/core/util/BlockingTreeUtil.java +++ b/common/core/src/main/java/zingg/common/core/util/BlockingTreeUtil.java @@ -12,6 +12,7 @@ import zingg.common.client.ZinggClientException; import zingg.common.client.ZFrame; import zingg.common.client.util.ListMap; +import zingg.common.client.util.PipeUtilBase; import zingg.common.client.util.Util; import zingg.common.core.block.Block; import zingg.common.core.block.Canopy; @@ -90,7 +91,7 @@ public Tree> createBlockingTreeFromSample(ZFrame testData, public void writeBlockingTree(Tree> blockingTree, IArguments args) throws Exception, ZinggClientException { byte[] byteArray = Util.convertObjectIntoByteArray(blockingTree); PipeUtilBase pu = getPipeUtil(); - pu.write(getTreeDF(byteArray), args, pu.getBlockingTreePipe(args)); + pu.write(getTreeDF(byteArray), pu.getBlockingTreePipe(args)); } public abstract ZFrame getTreeDF(byte[] tree) ; diff --git a/common/core/src/main/java/zingg/common/core/util/GraphUtil.java b/common/core/src/main/java/zingg/common/core/util/GraphUtil.java index d91b59bbd..69d72db30 100644 --- a/common/core/src/main/java/zingg/common/core/util/GraphUtil.java +++ b/common/core/src/main/java/zingg/common/core/util/GraphUtil.java @@ -1,10 +1,11 @@ package zingg.common.core.util; import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; public interface GraphUtil { - public ZFrame buildGraph(ZFrame vertices, ZFrame edges) ; + public ZFrame buildGraph(ZFrame vertices, ZFrame edges) throws ZinggClientException ; /* diff --git a/common/core/src/main/java/zingg/common/core/util/ModelUtil.java b/common/core/src/main/java/zingg/common/core/util/ModelUtil.java index ed8d0951a..15b1c20b8 100644 --- a/common/core/src/main/java/zingg/common/core/util/ModelUtil.java +++ b/common/core/src/main/java/zingg/common/core/util/ModelUtil.java @@ -22,6 +22,10 @@ public abstract class ModelUtil { public static final Log LOG = LogFactory.getLog(ModelUtil.class); protected Map> featurers; protected S session; + + public ModelUtil(S s) { + this.session = s; + } public abstract FeatureFactory getFeatureFactory(); @@ -72,7 +76,7 @@ public Model createModel(ZFrame positives, + negLabeledPointsWithLabel.count()); } Model model = getModel(isLabel, args); - model.register(session); + model.register(); model.fit(posLabeledPointsWithLabel, negLabeledPointsWithLabel); return model; } diff --git a/common/core/src/test/java/zingg/common/core/block/TestBlockBase.java b/common/core/src/test/java/zingg/common/core/block/TestBlockBase.java new file mode 100644 index 000000000..ecceb6201 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/block/TestBlockBase.java @@ -0,0 +1,100 @@ +package zingg.common.core.block; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.junit.jupiter.api.Test; + +import zingg.common.client.ArgumentsUtil; +import zingg.common.client.FieldDefinition; +import zingg.common.client.IArguments; +import zingg.common.client.MatchType; +import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.DFObjectUtil; +import zingg.common.core.util.BlockingTreeUtil; +import zingg.common.core.util.HashUtil; 
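+// Reviewer sketch (illustrative only; the subclass and field names below are
+// hypothetical, not part of this patch): a concrete engine test binds the three
+// utilities and inherits testTree(), e.g.
+//   class TestSparkBlock extends TestBlockBase {
+//       TestSparkBlock() { super(sparkDFObjectUtil, sparkHashUtil, sparkBlockingTreeUtil); }
+//   }
+// testTree() builds a blocking tree from the sample events and asserts that the
+// root canopy picked the "identityInteger" hash, since the unique "year" values
+// make it the strongest splitter in this data.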
+import zingg.common.core.model.Event; +import zingg.common.core.model.EventPair; +import zingg.common.core.data.EventTestData; + +public abstract class TestBlockBase { + + public ArgumentsUtil argumentsUtil = new ArgumentsUtil(); + public final DFObjectUtil dfObjectUtil; + public final HashUtil hashUtil; + public final BlockingTreeUtil blockingTreeUtil; + + public TestBlockBase(DFObjectUtil dfObjectUtil, HashUtil hashUtil, BlockingTreeUtil blockingTreeUtil) { + this.dfObjectUtil = dfObjectUtil; + this.hashUtil = hashUtil; + this.blockingTreeUtil = blockingTreeUtil; + } + + @Test + public void testTree() throws Throwable { + + // form tree + ZFrame zFrameEvent = dfObjectUtil.getDFFromObjectList(EventTestData.createSampleEventData(), Event.class); + ZFrame zFrameEventCluster = dfObjectUtil.getDFFromObjectList(EventTestData.createSampleClusterEventData(), EventPair.class); + IArguments args = getArguments(); + + Tree> blockingTree = blockingTreeUtil.createBlockingTreeFromSample(zFrameEvent, zFrameEventCluster, 0.5, -1, + args, hashUtil.getHashFunctionList()); + + // primary deciding is unique year so identityInteger should have been picked + Canopy head = blockingTree.getHead(); + assertEquals("identityInteger", head.getFunction().getName()); + blockingTree.toString(); + } + + private IArguments getArguments() throws ZinggClientException { + String configFilePath = Objects.requireNonNull(getClass().getResource("../../../../testFebrl/config.json")).getFile(); + + IArguments args = argumentsUtil.createArgumentsFromJSON(configFilePath, "trainMatch"); + + List fdList = getFieldDefList(); + + args.setFieldDefinition(fdList); + return args; + } + + private List getFieldDefList() { + List fdList = new ArrayList(4); + + FieldDefinition idFD = new FieldDefinition(); + idFD.setDataType("integer"); + idFD.setFieldName("id"); + ArrayList matchTypelistId = new ArrayList(); + matchTypelistId.add(MatchType.DONT_USE); + idFD.setMatchType(matchTypelistId); + fdList.add(idFD); + + ArrayList matchTypelistFuzzy = new ArrayList(); + matchTypelistFuzzy.add(MatchType.FUZZY); + + + FieldDefinition yearFD = new FieldDefinition(); + yearFD.setDataType("integer"); + yearFD.setFieldName("year"); + yearFD.setMatchType(matchTypelistFuzzy); + fdList.add(yearFD); + + FieldDefinition eventFD = new FieldDefinition(); + eventFD.setDataType("string"); + eventFD.setFieldName("event"); + eventFD.setMatchType(matchTypelistFuzzy); + fdList.add(eventFD); + + FieldDefinition commentFD = new FieldDefinition(); + commentFD.setDataType("string"); + commentFD.setFieldName("comment"); + commentFD.setMatchType(matchTypelistFuzzy); + fdList.add(commentFD); + return fdList; + } + +} diff --git a/common/core/src/test/java/zingg/block/TestTree.java b/common/core/src/test/java/zingg/common/core/block/TestTree.java similarity index 79% rename from common/core/src/test/java/zingg/block/TestTree.java rename to common/core/src/test/java/zingg/common/core/block/TestTree.java index 81d5044b6..93898c105 100644 --- a/common/core/src/test/java/zingg/block/TestTree.java +++ b/common/core/src/test/java/zingg/common/core/block/TestTree.java @@ -1,11 +1,7 @@ -package zingg.block; +package zingg.common.core.block; import org.junit.jupiter.api.*; -import zingg.common.core.block.Tree; - -import static org.junit.jupiter.api.Assertions.*; - public class TestTree { @Test diff --git a/common/core/src/test/java/zingg/common/core/data/EventTestData.java b/common/core/src/test/java/zingg/common/core/data/EventTestData.java new file mode 100644 index 
000000000..9531b6772 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/data/EventTestData.java @@ -0,0 +1,264 @@ +package zingg.common.core.data; + +import zingg.common.core.model.Event; +import zingg.common.core.model.EventPair; +import zingg.common.core.model.Statement; +import zingg.common.core.model.PostStopWordProcess; +import zingg.common.core.model.PriorStopWordProcess; + +import java.util.ArrayList; +import java.util.List; + +public class EventTestData { + public static List createSampleEventData() { + + int row_id = 1; + List sample = new ArrayList(); + sample.add(new Event(row_id++, 1942, "quit India", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disob", "India")); + sample.add(new Event(row_id++, 1942, "quit India", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidience", "India")); + sample.add(new Event(row_id++, 1942, "quit Hindustan", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JW", "Amritsar")); + sample.add(new Event(row_id++, 1930, "Civil Dis", "India")); + sample.add(new Event(row_id++, 1942, "quit Nation", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add((new Event(row_id++, 1942, "quit N", "Mahatma"))); + sample.add((new Event(row_id++, 1919, "JallianWal", "Punjb"))); + sample.add(new Event(row_id++, 1942, "quit ", "Mahatm")); + sample.add(new Event(row_id++, 1942, "quit Ntn", "Mahama")); + sample.add(new Event(row_id++, 1942, "quit Natin", "Mahaatma")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disob", "India")); + sample.add(new Event(row_id++, 1942, "quit India", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidience", "India")); + sample.add(new Event(row_id++, 1942, "Quit Bharat", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidence", "India")); + sample.add(new Event(row_id++, 1942, "quit Hindustan", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JW", "Amritsar")); + sample.add(new Event(row_id++, 1930, "Civil Dis", "India")); + sample.add(new Event(row_id++, 1942, "quit Nation", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add(new Event(row_id++, 1942, "quit N", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add(new Event(row_id++, 1942, "quit ", "Mahatm")); + sample.add(new Event(row_id++, 1942, "quit Ntn", "Mahama")); + sample.add(new Event(row_id++, 1942, "quit Natin", "Mahaatma")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disob", "India")); + sample.add(new Event(row_id++, 1942, "quit India", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidience", "India")); + sample.add(new Event(row_id++, 1942, "Quit Bharat", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidence", "India")); + sample.add(new Event(row_id++, 1942, "quit Hindustan", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JW", "Amritsar")); + sample.add(new 
Event(row_id++, 1930, "Civil Dis", "India")); + sample.add(new Event(row_id++, 1942, "quit Nation", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add(new Event(row_id++, 1942, "quit N", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add(new Event(row_id++, 1942, "quit ", "Mahatm")); + sample.add(new Event(row_id++, 1942, "quit Ntn", "Mahama")); + sample.add(new Event(row_id++, 1942, "quit Natin", "Mahaatma")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disob", "India")); + sample.add(new Event(row_id++, 1942, "quit India", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidience", "India")); + sample.add(new Event(row_id++, 1942, "Quit Bharat", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JallianWala", "Punjab")); + sample.add(new Event(row_id++, 1930, "Civil Disobidence", "India")); + sample.add(new Event(row_id++, 1942, "quit Hindustan", "Mahatma Gandhi")); + sample.add(new Event(row_id++, 1919, "JW", "Amritsar")); + sample.add(new Event(row_id++, 1930, "Civil Dis", "India")); + sample.add(new Event(row_id++, 1942, "quit Nation", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add(new Event(row_id++, 1942, "quit N", "Mahatma")); + sample.add(new Event(row_id++, 1919, "JallianWal", "Punjb")); + sample.add(new Event(row_id++, 1942, "quit ", "Mahatm")); + sample.add(new Event(row_id++, 1942, "quit Ntn", "Mahama")); + sample.add(new Event(row_id, 1942, "quit Natin", "Mahaatma")); + + return sample; + } + + public static List createSampleClusterEventData() { + + int row_id = 1; + List sample = new ArrayList(); + sample.add(new EventPair(row_id++, 1942, "quit Nation", "Mahatma",1942, "quit Nation", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma", 1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma",1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit ", "Mahatm", 1942, "quit ", "Mahatm", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Natin", "Mahaatma", 1942, "quit Natin", "Mahaatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma", 1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit ", "Mahatm", 1942, "quit ", "Mahatm", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Natin", "Mahaatma", 1942, "quit Natin", "Mahaatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma", 1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 
1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit ", "Mahatm", 1942, "quit ", "Mahatm", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Natin", "Mahaatma", 1942, "quit Natin", "Mahaatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma", 1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit ", "Mahatm", 1942, "quit ", "Mahatm", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Natin", "Mahaatma", 1942, "quit Natin", "Mahaatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma", 1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit ", "Mahatm", 1942, "quit ", "Mahatm", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Natin", "Mahaatma", 1942, "quit Natin", "Mahaatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit N", "Mahatma", 1942, "quit N", "Mahatma", 1L)); + sample.add(new EventPair(row_id++, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + sample.add(new EventPair(row_id++, 1942, "quit ", "Mahatm", 1942, "quit ", "Mahatm", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Ntn", "Mahama", 1942, "quit Ntn", "Mahama", 1L)); + sample.add(new EventPair(row_id++, 1942, "quit Natin", "Mahaatma", 1942, "quit Natin", "Mahaatma", 1L)); + sample.add(new EventPair(row_id, 1919, "JallianWal", "Punjb", 1919, "JallianWal", "Punjb", 2L)); + + return sample; + } + + public static List getData1Original() { + + List sample = new ArrayList(); + sample.add(new Statement("The zingg is a Spark application")); + sample.add(new Statement("It is very popular in data Science")); + sample.add(new Statement("It is written in Java and Scala")); + sample.add(new Statement("Best of luck to zingg")); + + return sample; + } + + public static List getData1Expected() { + + List sample = new ArrayList(); + sample.add(new Statement("zingg spark application")); + sample.add(new Statement("very popular in data science")); + sample.add(new Statement("written in java and scala")); + sample.add(new Statement("best luck to zingg")); + + return sample; + } + + public static List getData2Original() { + + List sample = new ArrayList(); + sample.add(new PriorStopWordProcess("10", "The zingg is a spark application", "two", + "Yes. 
a good application", "test")); + sample.add(new PriorStopWordProcess("20", "It is very popular in Data Science", "Three", "true indeed", + "test")); + sample.add(new PriorStopWordProcess("30", "It is written in java and scala", "four", "", "test")); + sample.add(new PriorStopWordProcess("40", "Best of luck to zingg Mobile/T-Mobile", "Five", "thank you", "test")); + + return sample; + } + + public static List getData2Expected() { + + List sample = new ArrayList(); + sample.add(new PriorStopWordProcess("10", "zingg spark application", "two", "Yes. a good application", "test")); + sample.add(new PriorStopWordProcess("20", "very popular data science", "Three", "true indeed", "test")); + sample.add(new PriorStopWordProcess("30", "written java scala", "four", "", "test")); + sample.add(new PriorStopWordProcess("40", "best luck to zingg ", "Five", "thank you", "test")); + + return sample; + } + + public static List getData3Original() { + + List sample = new ArrayList(); + sample.add(new PriorStopWordProcess("10", "The zingg is a spark application", "two", + "Yes. a good application", "test")); + sample.add(new PriorStopWordProcess("20", "It is very popular in Data Science", "Three", "true indeed", + "test")); + sample.add(new PriorStopWordProcess("30", "It is written in java and scala", "four", "", "test")); + sample.add(new PriorStopWordProcess("40", "Best of luck to zingg Mobile/T-Mobile", "Five", "thank you", "test")); + + return sample; + } + + public static List getData3Expected() { + + List sample = new ArrayList(); + sample.add(new PriorStopWordProcess("10", "zingg spark application", "two", "Yes. a good application", "test")); + sample.add(new PriorStopWordProcess("20", "very popular data science", "Three", "true indeed", "test")); + sample.add(new PriorStopWordProcess("30", "written java scala", "four", "", "test")); + sample.add(new PriorStopWordProcess("40", "best luck to zingg ", "Five", "thank you", "test")); + + return sample; + } + + public static List getData4original() { + + List sample = new ArrayList(); + sample.add(new PriorStopWordProcess("10", "The zingg is a spark application", "two", + "Yes. a good application", "test")); + sample.add(new PriorStopWordProcess("20", "It is very popular in data science", "Three", "true indeed", + "test")); + sample.add(new PriorStopWordProcess("30", "It is written in java and scala", "four", "", "test")); + sample.add(new PriorStopWordProcess("40", "Best of luck to zingg", "Five", "thank you", "test")); + + return sample; + } + + public static List getData4Expected() { + + List sample = new ArrayList(); + sample.add(new PostStopWordProcess("1648811730857:10", "10", "1.0", "0.555555", "-1", + "The zingg spark application", "two", "Yes. good application", "test")); + sample.add(new PostStopWordProcess("1648811730857:20", "20", "1.0", "1.0", "-1", + "It very popular data science", "Three", "true indeed", "test")); + sample.add(new PostStopWordProcess("1648811730857:30", "30", "1.0", "0.999995", "-1", + "It written java scala", "four", "", "test")); + sample.add(new PostStopWordProcess("1648811730857:40", "40", "1.0", "1.0", "-1", "Best luck zingg", "Five", + "thank", "test")); + + return sample; + } + + public static List getData5Original() { + + List sample = new ArrayList(); + sample.add(new PriorStopWordProcess("10", "The zingg is a spark application", "two", + "Yes. 
a good application", "test")); + sample.add(new PriorStopWordProcess("20", "It is very popular in data science", "Three", "true indeed", + "test")); + sample.add(new PriorStopWordProcess("30", "It is written in java and scala", "four", "", "test")); + sample.add(new PriorStopWordProcess("40", "Best of luck to zingg", "Five", "thank you", "test")); + + return sample; + } + + public static List getData5Actual() { + + List sample = new ArrayList(); + sample.add(new PostStopWordProcess("1648811730857:10", "10", "1.0", "0.555555", "-1", + "The zingg spark application", "two", "Yes. good application", "test")); + sample.add(new PostStopWordProcess("1648811730857:20", "20", "1.0", "1.0", "-1", + "It very popular data science", "Three", "true indeed", "test")); + sample.add(new PostStopWordProcess("1648811730857:30", "30", "1.0", "0.999995", "-1", + "It written java scala", "four", "", "test")); + sample.add(new PostStopWordProcess("1648811730857:40", "40", "1.0", "1.0", "-1", "Best luck zingg", "Five", + "thank", "test")); + + return sample; + } +} diff --git a/common/core/src/test/java/zingg/common/core/executor/ExecutorTester.java b/common/core/src/test/java/zingg/common/core/executor/ExecutorTester.java new file mode 100644 index 000000000..8addea3f8 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/ExecutorTester.java @@ -0,0 +1,24 @@ +package zingg.common.core.executor; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.ZinggClientException; + +public abstract class ExecutorTester { + + public static final Log LOG = LogFactory.getLog(ExecutorTester.class); + + public ZinggBase executor; + + public ExecutorTester(ZinggBase executor) { + this.executor = executor; + } + + public void execute() throws ZinggClientException { + executor.execute(); + } + + public abstract void validateResults() throws ZinggClientException; + +} diff --git a/common/core/src/test/java/zingg/common/core/executor/JunitLabeller.java b/common/core/src/test/java/zingg/common/core/executor/JunitLabeller.java new file mode 100644 index 000000000..ecdba92f4 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/JunitLabeller.java @@ -0,0 +1,60 @@ +package zingg.common.core.executor; + +import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; +import zingg.common.client.options.ZinggOptions; +import zingg.common.client.util.ColName; +import zingg.common.client.util.ColValues; +import zingg.common.core.context.Context; + +public class JunitLabeller extends Labeller { + + private static final long serialVersionUID = 1L; + + public JunitLabeller(Context context) { + setZinggOption(ZinggOptions.LABEL); + setContext(context); + } + + @Override + public ZFrame processRecordsCli(ZFrame lines) + throws ZinggClientException { + + // now get a list of all those rows which have same cluster and match due to fname => mark match + ZFrame lines2 = getDSUtil().getPrefixedColumnsDS(lines); + + // construct AND condition + C clusterCond = getJoinCondForCol(lines, lines2, ColName.CLUSTER_COLUMN,true); + C fnameCond = getJoinCondForCol(lines, lines2, "FNAME",true); + C idCond = getJoinCondForCol(lines, lines2, "ID",false); + C filterCond = lines2.and(lines2.and(clusterCond,idCond),fnameCond); + + ZFrame filtered = lines.joinOnCol(lines2, filterCond).cache(); + + ZFrame matches = filtered.select(ColName.CLUSTER_COLUMN).distinct().withColumn(ColName.MATCH_FLAG_COL, ColValues.IS_MATCH_PREDICTION).cache(); + + 
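+ // Clusters with at least one fname-based match are flagged IS_MATCH_PREDICTION
+ // above; every cluster id left over from the except() below is flagged
+ // IS_NOT_A_MATCH_PREDICTION, so each cluster ends up with exactly one flag.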
ZFrame nonMatches = lines.select(ColName.CLUSTER_COLUMN).except(matches.select(ColName.CLUSTER_COLUMN)).distinct().withColumn(ColName.MATCH_FLAG_COL, ColValues.IS_NOT_A_MATCH_PREDICTION).cache(); + + ZFrame all = matches.unionAll(nonMatches); + + ZFrame linesMatched = lines; + linesMatched = linesMatched.drop(ColName.MATCH_FLAG_COL); + linesMatched = linesMatched.joinOnCol(all, ColName.CLUSTER_COLUMN); + linesMatched = linesMatched.select(lines.columns()); // make same order + + return linesMatched; + } + + private C getJoinCondForCol(ZFrame df1, ZFrame dfToJoin,String colName, boolean equal) { + C column = df1.col(colName); + C columnWithPrefix = dfToJoin.col(ColName.COL_PREFIX + colName); + C equalTo = df1.equalTo(column,columnWithPrefix); + if (equal) { + return equalTo; + } else { + return df1.not(equalTo); + } + } + + +} diff --git a/common/core/src/test/java/zingg/common/core/executor/LabellerTester.java b/common/core/src/test/java/zingg/common/core/executor/LabellerTester.java new file mode 100644 index 000000000..d522a26b6 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/LabellerTester.java @@ -0,0 +1,36 @@ +package zingg.common.core.executor; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.ColName; + +public class LabellerTester extends ExecutorTester { + + public static final Log LOG = LogFactory.getLog(LabellerTester.class); + + public LabellerTester(Labeller executor) { + super(executor); + } + + @Override + public void validateResults() throws ZinggClientException { + // check that marked data has more than one match row and more than one non-match row + ZFrame dfMarked = executor.getContext().getPipeUtil(). 
+ read(false, false, executor.getContext().getPipeUtil().getTrainingDataMarkedPipe(executor.getArgs())); + + C matchCond = dfMarked.equalTo(ColName.MATCH_FLAG_COL, 1); + C notMatchCond = dfMarked.equalTo(ColName.MATCH_FLAG_COL, 0); + + long matchCount = dfMarked.filter(matchCond).count(); + assertTrue(matchCount > 1); + long unmatchCount = dfMarked.filter(notMatchCond).count(); + assertTrue(unmatchCount > 1); + LOG.info("matchCount : "+ matchCount + ", unmatchCount : " + unmatchCount); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/executor/MatcherTester.java b/common/core/src/test/java/zingg/common/core/executor/MatcherTester.java new file mode 100644 index 000000000..24500fe3f --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/MatcherTester.java @@ -0,0 +1,78 @@ +package zingg.common.core.executor; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.ColName; + +public class MatcherTester extends ExecutorTester { + + public static final Log LOG = LogFactory.getLog(MatcherTester.class); + + public MatcherTester(Matcher executor) { + super(executor); + } + + @Override + public void validateResults() throws ZinggClientException { + assessAccuracy(); + } + + public String getClusterColName() { + return ColName.CLUSTER_COLUMN; + } + + protected void assessAccuracy() throws ZinggClientException { + ZFrame df = getOutputData(); + + df = df.withColumn("fnameId",df.concat(df.col("fname"), df.col("id"))); + df = df.select("fnameId", getClusterColName()); + df = df.withColumn("dupeFnameId",df.substr(df.col("fnameId"),0,8)).cache(); + ZFrame df1 = df.withColumnRenamed("fnameId", "fnameId1").withColumnRenamed("dupeFnameId", "dupeFnameId1") + .withColumnRenamed(getClusterColName(), getClusterColName() + "1").cache(); + + + ZFrame gold = joinAndFilter("dupeFnameId", df, df1).cache(); + ZFrame result = joinAndFilter(getClusterColName(), df, df1).cache(); + + ZFrame fn = gold.except(result); + ZFrame tp = gold.intersect(result); + ZFrame fp = result.except(gold); + + long fnCount = fn.count(); + long tpCount = tp.count(); + long fpCount = fp.count(); + double score1 = tpCount*1.0d/(tpCount+fpCount); + double score2 = tpCount*1.0d/(tpCount+fnCount); + + LOG.info("False negative " + fnCount); + LOG.info("True positive " + tpCount); + LOG.info("False positive " + fpCount); + LOG.info("precision " + score1); + LOG.info("recall " + tpCount + " denom " + (tpCount+fnCount) + " overall " + score2); + + System.out.println("precision score1 " + score1); + + System.out.println("recall score2 " + score2); + + assertTrue(0.8 <= score1); + assertTrue(0.8 <= score2); + } + + public ZFrame getOutputData() throws ZinggClientException { + ZFrame output = executor.getContext().getPipeUtil().read(false, false, executor.getArgs().getOutput()[0]); + return output; + } + + protected ZFrame joinAndFilter(String colName, ZFrame df, ZFrame df1){ + C col1 = df.col(colName); + C col2 = df1.col(colName+"1"); + ZFrame joined = df.joinOnCol(df1, df.equalTo(col1, col2)); + return joined.filter(joined.gt(joined.col("fnameId"), joined.col("fnameId1"))); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java b/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java new file mode 100644 index 000000000..6de3c9813 --- 
/dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/TestExecutorsGeneric.java @@ -0,0 +1,106 @@ +package zingg.common.core.executor; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; + +import zingg.common.client.ArgumentsUtil; +import zingg.common.client.IArguments; +import zingg.common.client.ZinggClientException; + +public abstract class TestExecutorsGeneric { + + public static final Log LOG = LogFactory.getLog(TestExecutorsGeneric.class); + + protected IArguments args; + + + protected S session; + + public TestExecutorsGeneric() { + + } + + public TestExecutorsGeneric(S s) throws ZinggClientException, IOException { + init(s); + } + + public void init(S s) throws ZinggClientException, IOException { + this.session = s; + // set up args + setupArgs(); + } + + public String setupArgs() throws ZinggClientException, IOException { + String configFile = getClass().getClassLoader().getResource(getConfigFile()).getFile(); + args = new ArgumentsUtil().createArgumentsFromJSON( + configFile, + "findTrainingData"); + return configFile; + } + + public abstract String getConfigFile(); + + + @Test + public void testExecutors() throws ZinggClientException { + List> executorTesterList = new ArrayList>(); + + TrainingDataFinder trainingDataFinder = getTrainingDataFinder(); + trainingDataFinder.init(args,session); + TrainingDataFinderTester tdft = new TrainingDataFinderTester(trainingDataFinder); + executorTesterList.add(tdft); + + Labeller labeller = getLabeller(); + labeller.init(args,session); + LabellerTester lt = new LabellerTester(labeller); + executorTesterList.add(lt); + + // training and labelling needed twice to get sufficient data + TrainingDataFinder trainingDataFinder2 = getTrainingDataFinder(); + trainingDataFinder2.init(args,session); + TrainingDataFinderTester tdft2 = new TrainingDataFinderTester(trainingDataFinder2); + executorTesterList.add(tdft2); + + Labeller labeller2 = getLabeller(); + labeller2.init(args,session); + LabellerTester lt2 = new LabellerTester(labeller2); + executorTesterList.add(lt2); + + Trainer trainer = getTrainer(); + trainer.init(args,session); + TrainerTester tt = getTrainerTester(trainer); + executorTesterList.add(tt); + + Matcher matcher = getMatcher(); + matcher.init(args,session); + MatcherTester mt = new MatcherTester(matcher); + executorTesterList.add(mt); + + testExecutors(executorTesterList); + } + + protected abstract TrainerTester getTrainerTester(Trainer trainer); + + + public void testExecutors(List> executorTesterList) throws ZinggClientException { + for (ExecutorTester executorTester : executorTesterList) { + executorTester.execute(); + executorTester.validateResults(); + } + } + + public abstract void tearDown(); + + protected abstract TrainingDataFinder getTrainingDataFinder() throws ZinggClientException; + + protected abstract Labeller getLabeller() throws ZinggClientException; + + protected abstract Trainer getTrainer() throws ZinggClientException; + + protected abstract Matcher getMatcher() throws ZinggClientException; + +} diff --git a/common/core/src/test/java/zingg/common/core/executor/TrainerTester.java b/common/core/src/test/java/zingg/common/core/executor/TrainerTester.java new file mode 100644 index 000000000..b5f0cbbd9 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/TrainerTester.java @@ -0,0 +1,19 @@ +package zingg.common.core.executor; + 
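+// Note: every tester follows the execute-then-validate contract driven from
+// TestExecutorsGeneric.testExecutors, i.e.
+//   for (ExecutorTester t : executorTesterList) { t.execute(); t.validateResults(); }
+// TrainerTester stays abstract and only carries the IArguments handle that
+// concrete subclasses need when validating the trained model.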
+import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.IArguments; + +public abstract class TrainerTester extends ExecutorTester { + + public static final Log LOG = LogFactory.getLog(TrainerTester.class); + + protected IArguments args; + + public TrainerTester(Trainer executor,IArguments args) { + super(executor); + this.args = args; + } + +} diff --git a/common/core/src/test/java/zingg/common/core/executor/TrainingDataFinderTester.java b/common/core/src/test/java/zingg/common/core/executor/TrainingDataFinderTester.java new file mode 100644 index 000000000..945be8ed0 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/executor/TrainingDataFinderTester.java @@ -0,0 +1,29 @@ +package zingg.common.core.executor; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; + +public class TrainingDataFinderTester extends ExecutorTester { + + public static final Log LOG = LogFactory.getLog(TrainingDataFinderTester.class); + + public TrainingDataFinderTester(TrainingDataFinder executor) { + super(executor); + } + + @Override + public void validateResults() throws ZinggClientException { + // check that unmarked data has at least 10 rows + ZFrame df = executor.getContext().getPipeUtil().read(false, false, executor.getContext().getPipeUtil().getTrainingDataUnmarkedPipe(executor.getArgs())); + + long trainingDataCount = df.count(); + assertTrue(trainingDataCount > 10); + LOG.info("trainingDataCount : "+ trainingDataCount); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/model/Event.java b/common/core/src/test/java/zingg/common/core/model/Event.java new file mode 100644 index 000000000..d4ef977bc --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/Event.java @@ -0,0 +1,8 @@ +package zingg.common.core.model; + +public class Event extends EventBase{ + + public Event(Integer id, Integer year, String event, String comment) { + super(id, year, event, comment); + } +} diff --git a/common/core/src/test/java/zingg/common/core/model/EventBase.java b/common/core/src/test/java/zingg/common/core/model/EventBase.java new file mode 100644 index 000000000..018fdf486 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/EventBase.java @@ -0,0 +1,15 @@ +package zingg.common.core.model; + +public class EventBase { + public final Integer id; + public final Integer year; + public final String event; + public final String comment; + + public EventBase(Integer id, Integer year, String event, String comment) { + this.id = id; + this.year = year; + this.event = event; + this.comment = comment; + } +} diff --git a/common/core/src/test/java/zingg/common/core/model/EventPair.java b/common/core/src/test/java/zingg/common/core/model/EventPair.java new file mode 100644 index 000000000..97be67d1f --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/EventPair.java @@ -0,0 +1,16 @@ +package zingg.common.core.model; + +public class EventPair extends EventBase{ + public final Integer z_year; + public final String z_event; + public final String z_comment; + public final Long z_zid; + + public EventPair(Integer id, Integer year, String event, String comment, Integer z_year, String z_event, String z_comment, Long z_zid) { + super(id, year, event, comment); + this.z_year = z_year; + this.z_event = z_event; + this.z_comment 
= z_comment; + this.z_zid = z_zid; + } +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/common/core/model/PostStopWordProcess.java b/common/core/src/test/java/zingg/common/core/model/PostStopWordProcess.java new file mode 100644 index 000000000..c137fa559 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/PostStopWordProcess.java @@ -0,0 +1,26 @@ +package zingg.common.core.model; + +public class PostStopWordProcess { + public final String z_cluster; + public final String z_zid; + public final String z_prediction; + public final String z_score; + public final String z_isMatch; + public final String field1; + public final String field2; + public final String field3; + public final String z_zsource; + + public PostStopWordProcess(String z_cluster, String z_zid, String z_prediction, String z_score, String z_isMatch, + String field1, String field2, String field3, String z_zsource) { + this.z_cluster = z_cluster; + this.z_zid = z_zid; + this.z_prediction = z_prediction; + this.z_score = z_score; + this.z_isMatch = z_isMatch; + this.field1 = field1; + this.field2 = field2; + this.field3 = field3; + this.z_zsource = z_zsource; + } +} diff --git a/common/core/src/test/java/zingg/common/core/model/PriorStopWordProcess.java b/common/core/src/test/java/zingg/common/core/model/PriorStopWordProcess.java new file mode 100644 index 000000000..502a8ef2b --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/PriorStopWordProcess.java @@ -0,0 +1,17 @@ +package zingg.common.core.model; + +public class PriorStopWordProcess { + public final String z_zid; + public final String field1; + public final String field2; + public final String field3; + public final String z_zsource; + + public PriorStopWordProcess(String z_zid, String field1, String field2, String field3, String z_zsource) { + this.z_zid = z_zid; + this.field1 = field1; + this.field2 = field2; + this.field3 = field3; + this.z_zsource = z_zsource; + } +} diff --git a/common/core/src/test/java/zingg/common/core/model/Statement.java b/common/core/src/test/java/zingg/common/core/model/Statement.java new file mode 100644 index 000000000..1fabf51ef --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/Statement.java @@ -0,0 +1,9 @@ +package zingg.common.core.model; + +public class Statement { + public final String statement; + + public Statement(String statement) { + this.statement = statement; + } +} diff --git a/common/core/src/test/java/zingg/common/core/model/TestModel.java b/common/core/src/test/java/zingg/common/core/model/TestModel.java new file mode 100644 index 000000000..4396897a5 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/model/TestModel.java @@ -0,0 +1,22 @@ +package zingg.common.core.model; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import org.junit.jupiter.api.Test; + +public class TestModel { + + @Test + public void testGetGrid() { + double[] result = Model.getGrid(1.0, 10.0, 2.0, false); + double[] expected = {1.0, 3.0, 5.0, 7.0, 9.0}; + assertArrayEquals(expected, result, 0.0); + } + + @Test + public void testGetGridForMultiples() { + double[] result = Model.getGrid(1.0, 10.0, 2.0, true); + double[] expected = {1.0, 2.0, 4.0, 8.0}; + assertArrayEquals(expected, result, 0.0); + } +} diff --git a/common/core/src/test/java/zingg/common/core/preprocess/TestStopWordsBase.java b/common/core/src/test/java/zingg/common/core/preprocess/TestStopWordsBase.java new file mode 100644 index 000000000..aff0fd439 --- /dev/null +++ 
b/common/core/src/test/java/zingg/common/core/preprocess/TestStopWordsBase.java @@ -0,0 +1,115 @@ +package zingg.common.core.preprocess; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import zingg.common.client.ZFrame; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.ColName; +import zingg.common.client.util.DFObjectUtil; +import zingg.common.core.context.Context; +import zingg.common.core.data.EventTestData; +import zingg.common.core.model.Statement; +import zingg.common.core.model.PostStopWordProcess; +import zingg.common.core.model.PriorStopWordProcess; +import zingg.common.core.util.StopWordRemoverUtility; + +public abstract class TestStopWordsBase { + + public static final Log LOG = LogFactory.getLog(TestStopWordsBase.class); + private final DFObjectUtil dfObjectUtil; + private final StopWordRemoverUtility stopWordRemoverUtility; + private final Context context; + + + public TestStopWordsBase(DFObjectUtil dfObjectUtil, StopWordRemoverUtility stopWordRemoverUtility, Context context) { + this.dfObjectUtil = dfObjectUtil; + this.stopWordRemoverUtility = stopWordRemoverUtility; + this.context = context; + } + + @DisplayName ("Test Stop Words removal from Single column dataset") + @Test + public void testStopWordsSingleColumn() throws ZinggClientException, Exception { + + List> stopWordsRemovers = getStopWordsRemovers(); + String stopWords = "\\b(a|an|the|is|It|of|yes|no|I|has|have|you)\\b\\s?".toLowerCase(); + + ZFrame zFrameOriginal = dfObjectUtil.getDFFromObjectList(EventTestData.getData1Original(), Statement.class); + ZFrame zFrameExpected = dfObjectUtil.getDFFromObjectList(EventTestData.getData1Expected(), Statement.class); + + StopWordsRemover stopWordsRemover = stopWordsRemovers.get(0); + + stopWordsRemover.preprocessForStopWords(zFrameOriginal); + ZFrame newZFrame = stopWordsRemover.removeStopWordsFromDF(zFrameOriginal,"statement",stopWords); + + assertTrue(zFrameExpected.except(newZFrame).isEmpty()); + assertTrue(newZFrame.except(zFrameExpected).isEmpty()); + } + + @Test + public void testRemoveStopWordsFromDataset() throws ZinggClientException, Exception { + + List> stopWordsRemovers = getStopWordsRemovers(); + ZFrame zFrameOriginal = dfObjectUtil.getDFFromObjectList(EventTestData.getData2Original(), PriorStopWordProcess.class); + ZFrame zFrameExpected = dfObjectUtil.getDFFromObjectList(EventTestData.getData2Expected(), PriorStopWordProcess.class); + + StopWordsRemover stopWordsRemover = stopWordsRemovers.get(1); + ZFrame newZFrame = stopWordsRemover.preprocessForStopWords(zFrameOriginal); + + assertTrue(zFrameExpected.except(newZFrame).isEmpty()); + assertTrue(newZFrame.except(zFrameExpected).isEmpty()); + } + + @Test + public void testStopWordColumnMissingFromStopWordFile() throws ZinggClientException, Exception { + + List> stopWordsRemovers = getStopWordsRemovers(); + + ZFrame zFrameOriginal = dfObjectUtil.getDFFromObjectList(EventTestData.getData3Original(), PriorStopWordProcess.class); + ZFrame zFrameExpected = dfObjectUtil.getDFFromObjectList(EventTestData.getData3Expected(), PriorStopWordProcess.class); + + StopWordsRemover stopWordsRemover = stopWordsRemovers.get(2); + ZFrame newZFrame = stopWordsRemover.preprocessForStopWords(zFrameOriginal); + + assertTrue(zFrameExpected.except(newZFrame).isEmpty()); + 
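+ // Two-sided check: A.except(B) and B.except(A) both empty means the two
+ // frames hold the same set of rows, independent of row order or partitioning.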
assertTrue(newZFrame.except(zFrameExpected).isEmpty()); + } + + + @Test + public void testForOriginalDataAfterPostProcess() throws Exception { + + ZFrame zFrameOriginal = dfObjectUtil.getDFFromObjectList(EventTestData.getData4original(), PriorStopWordProcess.class); + ZFrame zFrameExpected = dfObjectUtil.getDFFromObjectList(EventTestData.getData4Expected(), PostStopWordProcess.class); + + ZFrame newZFrame = context.getDSUtil().postprocess(zFrameExpected, zFrameOriginal); + + assertTrue(newZFrame.select(ColName.ID_COL, "field1", "field2", "field3", ColName.SOURCE_COL).except(zFrameOriginal).isEmpty()); + assertTrue(zFrameOriginal.except(newZFrame.select(ColName.ID_COL, "field1", "field2", "field3", ColName.SOURCE_COL)).isEmpty()); + } + + @Test + public void testOriginalDataAfterPostProcessLinked() throws Exception { + + ZFrame zFrameOriginal = dfObjectUtil.getDFFromObjectList(EventTestData.getData5Original(), PriorStopWordProcess.class); + ZFrame zFrameExpected = dfObjectUtil.getDFFromObjectList(EventTestData.getData5Actual(), PostStopWordProcess.class); + + ZFrame newZFrame = context.getDSUtil().postprocessLinked(zFrameExpected, zFrameOriginal); + + assertTrue(newZFrame.select("field1", "field2", "field3").except(zFrameOriginal.select("field1", "field2", "field3")).isEmpty()); + assertTrue(zFrameOriginal.select("field1", "field2", "field3").except(newZFrame.select("field1", "field2", "field3")).isEmpty()); + } + + private List> getStopWordsRemovers() throws ZinggClientException { + stopWordRemoverUtility.buildStopWordRemovers(); + return stopWordRemoverUtility.getStopWordsRemovers(); + } + +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionDate.java b/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionDate.java new file mode 100644 index 000000000..c886b76ed --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionDate.java @@ -0,0 +1,34 @@ +package zingg.common.core.similarity.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Date; + +import org.junit.jupiter.api.Test; +public class TestCheckNullFunctionDate { + + @Test + public void testFirstNull() { + assertEquals(0d, simFunc().call(null, new Date(2))); + } + + @Test + public void testSecondNull() { + assertEquals(0d, simFunc().call(new Date(1), null)); + } + + @Test + public void testBothNull() { + assertEquals(0d, simFunc().call(null, null)); + } + + @Test + public void testBothNotNull() { + assertEquals(1d, simFunc().call(new Date(1), new Date(2))); + } + + protected CheckNullFunction simFunc() { + return new CheckNullFunction("CheckNullFunctionDate"); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionInt.java b/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionInt.java new file mode 100644 index 000000000..144fc5fa6 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionInt.java @@ -0,0 +1,34 @@ +package zingg.common.core.similarity.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class TestCheckNullFunctionInt { + + + @Test + public void testFirstNull() { + assertEquals(0d, simFunc().call(null, 2)); + } + + @Test + public void testSecondNull() { + assertEquals(0d, simFunc().call(1, null)); + } + + @Test + public void 
testBothNull() { + assertEquals(0d, simFunc().call(null, null)); + } + + @Test + public void testBothNotNull() { + assertEquals(1d, simFunc().call(1, 2)); + } + + protected CheckNullFunction simFunc() { + return new CheckNullFunction("CheckNullFunctionInt"); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionLong.java b/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionLong.java new file mode 100644 index 000000000..a7712d074 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/similarity/function/TestCheckNullFunctionLong.java @@ -0,0 +1,34 @@ +package zingg.common.core.similarity.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class TestCheckNullFunctionLong { + + + @Test + public void testFirstNull() { + assertEquals(0d, simFunc().call(null, 2l)); + } + + @Test + public void testSecondNull() { + assertEquals(0d, simFunc().call(1l, null)); + } + + @Test + public void testBothNull() { + assertEquals(0d, simFunc().call(null, null)); + } + + @Test + public void testBothNotNull() { + assertEquals(1d, simFunc().call(1l, 2l)); + } + + protected CheckNullFunction simFunc() { + return new CheckNullFunction("CheckNullFunctionLong"); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/similarity/function/TestDateSimilarityFunctionExact.java b/common/core/src/test/java/zingg/common/core/similarity/function/TestDateSimilarityFunctionExact.java new file mode 100644 index 000000000..56f815291 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/similarity/function/TestDateSimilarityFunctionExact.java @@ -0,0 +1,42 @@ +package zingg.common.core.similarity.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Date; + +import org.junit.jupiter.api.Test; + +public class TestDateSimilarityFunctionExact { + + + @Test + public void testFirstNull() { + assertEquals(1d, simFunc().call(null, new Date(2))); + } + + + @Test + public void testSecondNull() { + assertEquals(1d, simFunc().call(new Date(1), null)); + } + + @Test + public void testBothNull() { + assertEquals(1d, simFunc().call(null, null)); + } + + @Test + public void testNotEqual() { + assertEquals(0d, simFunc().call(new Date(101), new Date(102))); + } + + @Test + public void testEqual() { + assertEquals(1d, simFunc().call(new Date(101), new Date(101))); + } + + protected SimilarityFunctionExact simFunc() { + return new SimilarityFunctionExact("DateSimilarityFunctionExact"); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/similarity/function/TestIntegerSimilarityFunctionExact.java b/common/core/src/test/java/zingg/common/core/similarity/function/TestIntegerSimilarityFunctionExact.java new file mode 100644 index 000000000..37e1b415b --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/similarity/function/TestIntegerSimilarityFunctionExact.java @@ -0,0 +1,38 @@ +package zingg.common.core.similarity.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class TestIntegerSimilarityFunctionExact { + + @Test + public void testFirstNull() { + assertEquals(1d, simFunc().call(null, 2)); + } + + @Test + public void testSecondNull() { + assertEquals(1d, simFunc().call(1, null)); + } + + @Test + public void testBothNull() { + assertEquals(1d, simFunc().call(null, null)); + } + + @Test + public void testNotEqual() { + 
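+ // SimilarityFunctionExact scores 1d on equals() and 0d otherwise; its
+ // deliberate null leniency (any null also scores 1d) is pinned down by the
+ // testFirstNull/testSecondNull/testBothNull cases above.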
assertEquals(0d, simFunc().call(101, 102)); + } + + @Test + public void testEqual() { + assertEquals(1d, simFunc().call(101, 101)); + } + + protected SimilarityFunctionExact simFunc() { + return new SimilarityFunctionExact("IntegerSimilarityFunctionExact"); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/similarity/function/TestLongSimilarityFunctionExact.java b/common/core/src/test/java/zingg/common/core/similarity/function/TestLongSimilarityFunctionExact.java new file mode 100644 index 000000000..ee8808259 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/similarity/function/TestLongSimilarityFunctionExact.java @@ -0,0 +1,39 @@ +package zingg.common.core.similarity.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class TestLongSimilarityFunctionExact { + + + @Test + public void testFirstNull() { + assertEquals(1d, simFunc().call(null, 2l)); + } + + @Test + public void testSecondNull() { + assertEquals(1d, simFunc().call(1l, null)); + } + + @Test + public void testBothNull() { + assertEquals(1d, simFunc().call(null, null)); + } + + @Test + public void testNotEqual() { + assertEquals(0d, simFunc().call(101l, 102l)); + } + + @Test + public void testEqual() { + assertEquals(1d, simFunc().call(101l, 101l)); + } + + protected SimilarityFunctionExact simFunc() { + return new SimilarityFunctionExact("LongSimilarityFunctionExact"); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/sink/TestTableOutput.java b/common/core/src/test/java/zingg/common/core/sink/TestTableOutput.java new file mode 100644 index 000000000..97f2fc34e --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/sink/TestTableOutput.java @@ -0,0 +1,45 @@ +package zingg.common.core.sink; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.sink.TableOutput; + + +public class TestTableOutput { + + private TableOutput getInstance() { + return new TableOutput(3, 234456L, 87654L, "Company X"); + } + + @Test + public void testGetMethods() { + String ans = "Company X"; + TableOutput value = getInstance(); + assertEquals(3, value.getJobId()); + assertEquals(234456L, value.getTimestamp()); + assertEquals(87654L, value.getClusterId()); + assertEquals(ans, value.getRecord()); + } + + @Test + public void testSetMethods() { + TableOutput value = getInstance(); + int newJobId = 5; + long newTimestamp = 778899L; + long newClusterId = 9876L; + String newRecord = "Company Y"; + + value.setJobId(newJobId); + value.setTimestamp(newTimestamp); + value.setClusterId(newClusterId); + value.setRecord(newRecord); + + assertEquals(5, value.getJobId()); + assertEquals(778899L, value.getTimestamp()); + assertEquals(9876L, value.getClusterId()); + assertEquals(newRecord, value.getRecord()); + } + +} diff --git a/common/core/src/test/java/zingg/common/core/util/CsvReader.java b/common/core/src/test/java/zingg/common/core/util/CsvReader.java new file mode 100644 index 000000000..c700d6fe2 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/util/CsvReader.java @@ -0,0 +1,28 @@ +package zingg.common.core.util; + +import java.io.File; +import java.io.FileNotFoundException; +import java.util.ArrayList; +import java.util.List; +import java.util.Scanner; + +public class CsvReader { + protected List records; + IFromCsv creator; + + public CsvReader(IFromCsv creator){ + records = new ArrayList(); + this.creator = creator; + } + + public List 
getRecords(String file, boolean skipHeader) throws FileNotFoundException{ + int lineno = 0; + try (Scanner scanner = new Scanner(new File(file))) { + while (scanner.hasNextLine()) { + String line = scanner.nextLine(); + lineno++; + // honour the skipHeader flag by dropping the first line of the file + if (skipHeader && lineno == 1) continue; + records.add(creator.fromCsv(line)); + } + } + return records; + } + +} diff --git a/common/core/src/test/java/zingg/common/core/util/IFromCsv.java b/common/core/src/test/java/zingg/common/core/util/IFromCsv.java new file mode 100644 index 000000000..574da836b --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/util/IFromCsv.java @@ -0,0 +1,7 @@ +package zingg.common.core.util; + +public interface IFromCsv { + + C fromCsv(String s); + +} diff --git a/common/infra/src/main/java/zingg/common/infra/util/PojoToArrayConverter.java b/common/core/src/test/java/zingg/common/core/util/PojoToArrayConverter.java similarity index 96% rename from common/infra/src/main/java/zingg/common/infra/util/PojoToArrayConverter.java rename to common/core/src/test/java/zingg/common/core/util/PojoToArrayConverter.java index a3e04b4b0..e1a0ccf80 100644 --- a/common/infra/src/main/java/zingg/common/infra/util/PojoToArrayConverter.java +++ b/common/core/src/test/java/zingg/common/core/util/PojoToArrayConverter.java @@ -1,4 +1,4 @@ -package zingg.common.infraForTest.util; +package zingg.common.core.util; import java.lang.reflect.*; import java.security.NoSuchAlgorithmException; diff --git a/common/core/src/test/java/zingg/common/core/util/StopWordRemoverUtility.java b/common/core/src/test/java/zingg/common/core/util/StopWordRemoverUtility.java new file mode 100644 index 000000000..611c36700 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/util/StopWordRemoverUtility.java @@ -0,0 +1,65 @@ +package zingg.common.core.util; + +import zingg.common.client.Arguments; +import zingg.common.client.FieldDefinition; +import zingg.common.client.IArguments; +import zingg.common.client.MatchType; +import zingg.common.client.ZinggClientException; +import zingg.common.core.preprocess.StopWordsRemover; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public abstract class StopWordRemoverUtility { + + protected final List> stopWordsRemovers; + + public StopWordRemoverUtility() throws ZinggClientException { + this.stopWordsRemovers = new ArrayList>(); + } + + public void buildStopWordRemovers() throws ZinggClientException { + + //add first stopWordRemover + List fdList = new ArrayList(4); + ArrayList matchTypelistFuzzy = new ArrayList(); + matchTypelistFuzzy.add(MatchType.FUZZY); + FieldDefinition eventFD = new FieldDefinition(); + eventFD.setDataType("string"); + eventFD.setFieldName("statement"); + eventFD.setMatchType(matchTypelistFuzzy); + fdList.add(eventFD); + IArguments stmtArgs = new Arguments(); + stmtArgs.setFieldDefinition(fdList); + addStopWordRemover(stmtArgs); + + //add second stopWordRemover + String stopWordsFileName1 = Objects.requireNonNull( + StopWordRemoverUtility.class.getResource("../../../../preProcess/stopWords.csv")).getFile(); + FieldDefinition fieldDefinition1 = new FieldDefinition(); + fieldDefinition1.setStopWords(stopWordsFileName1); + fieldDefinition1.setFieldName("field1"); + List fieldDefinitionList1 = List.of(fieldDefinition1); + stmtArgs = new Arguments(); + stmtArgs.setFieldDefinition(fieldDefinitionList1); + addStopWordRemover(stmtArgs); + + //add third stopWordRemover + String stopWordsFileName2 = Objects.requireNonNull( + StopWordRemoverUtility.class.getResource("../../../../preProcess/stopWordsWithoutHeader.csv")).getFile(); + FieldDefinition 
fieldDefinition2 = new FieldDefinition(); + fieldDefinition2.setStopWords(stopWordsFileName2); + fieldDefinition2.setFieldName("field1"); + List fieldDefinitionList2 = List.of(fieldDefinition2); + stmtArgs = new Arguments(); + stmtArgs.setFieldDefinition(fieldDefinitionList2); + addStopWordRemover(stmtArgs); + } + + public List> getStopWordsRemovers() { + return this.stopWordsRemovers; + } + + public abstract void addStopWordRemover(IArguments iArguments); +} diff --git a/common/core/src/test/java/zingg/common/core/zFrame/TestZFrameBase.java b/common/core/src/test/java/zingg/common/core/zFrame/TestZFrameBase.java new file mode 100644 index 000000000..b4bbbb2d9 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/TestZFrameBase.java @@ -0,0 +1,644 @@ +package zingg.common.core.zFrame; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import zingg.common.client.ZFrame; +import zingg.common.client.util.ColName; +import zingg.common.client.util.DFObjectUtil; +import zingg.common.core.zFrame.model.ClusterZScore; +import zingg.common.core.zFrame.model.InputWithZidAndSource; +import zingg.common.core.zFrame.model.PairPartOne; +import zingg.common.core.zFrame.model.PairPartTwo; +import zingg.common.core.zFrame.model.Person; +import zingg.common.core.zFrame.model.PersonMixed; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static zingg.common.core.zFrame.data.TestData.createEmptySampleData; +import static zingg.common.core.zFrame.data.TestData.createSampleDataCluster; +import static zingg.common.core.zFrame.data.TestData.createSampleDataClusterWithNull; +import static zingg.common.core.zFrame.data.TestData.createSampleDataInput; +import static zingg.common.core.zFrame.data.TestData.createSampleDataList; +import static zingg.common.core.zFrame.data.TestData.createSampleDataListDistinct; +import static zingg.common.core.zFrame.data.TestData.createSampleDataListWithDistinctSurnameAndPostcode; +import static zingg.common.core.zFrame.data.TestData.createSampleDataListWithMixedDataType; +import static zingg.common.core.zFrame.data.TestData.createSampleDataZScore; + +public abstract class TestZFrameBase { + + public static final Log LOG = LogFactory.getLog(TestZFrameBase.class); + public static final String NEW_COLUMN = "newColumn"; + public static final String STR_RECID = "recid"; + private final DFObjectUtil dfObjectUtil; + + public TestZFrameBase(DFObjectUtil dfObjectUtil) { + this.dfObjectUtil = dfObjectUtil; + } + + @Test + public void testAliasOfZFrame() throws Exception { + List sampleDataSet = createSampleDataList(); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + String aliasName = "AnotherName"; + zFrame.as(aliasName); + assertTrueCheckingExceptOutput(zFrame.as(aliasName), zFrame, "Dataframe and its alias are not same"); + } + + + @Test + public void testCreateZFrameAndGetDF() throws Exception { + List sampleDataSet = createSampleDataList(); + + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + //assert rows + List rows = zFrame.collectAsList(); + List fields = List.of(Person.class.getDeclaredFields()); + for (int idx = 0; idx < 
sampleDataSet.size(); idx++) { + R row = rows.get(idx); + for (Field column : fields) { + String columnName = column.getName(); + assertEquals(column.get(sampleDataSet.get(idx)).toString(), zFrame.getAsString(row, columnName), + "value in ZFrame and sample input is not same"); + } + } + } + + @Test + public void testColumnsNamesAndCount() throws Exception { + List sampleDataSet = createSampleDataList(); + + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + //assert on fields + List fieldsInTestData = new ArrayList(); + List fieldsInZFrame = new ArrayList(); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> fieldsInTestData.add(fieldS.getName())); + Arrays.stream(zFrame.fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + assertEquals(fieldsInTestData, fieldsInZFrame, + "Columns of sample data and zFrame are not equal"); + } + + @Test + public void testSelectWithSingleColumnName() throws Exception { + List sampleDataSet = createSampleDataList(); //List + + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + String colName = "recid"; + List rows = zFrame.select(colName).collectAsList(); + for (int idx = 0; idx < sampleDataSet.size(); idx++) { + R row = rows.get(idx); + assertEquals(sampleDataSet.get(idx).recid, zFrame.getAsString(row, colName), + "value in ZFrame and sample input is not same"); + } + } + + /* + list of string can not be cast to list of C + zFrame select does not have an interface method for List + */ + @Disabled + @Test + public void testSelectWithColumnList() throws Exception { + List sampleDataSet = createSampleDataList(); //List + + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + List columnList = (List) Arrays.asList("recid", "surname", "postcode"); + List rows = zFrame.select(columnList).collectAsList(); + + for (int idx = 0; idx < sampleDataSet.size(); idx++) { + R row = rows.get(idx); + Assertions.assertEquals(zFrame.getAsString(row, "recid"), sampleDataSet.get(idx).recid, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "surname"), sampleDataSet.get(idx).surname, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "postcode"), sampleDataSet.get(idx).postcode, + "value from zFrame and sampleData doesn't match"); + } + } + + /* + string can not be cast to C + zFrame doesn't have an interface method for C[] + */ + @Disabled + @Test + public void testSelectWithColumnArray() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + C[] columnArray = (C[]) new Object[3]; + columnArray[0] = (C) "recid"; + columnArray[1] = (C) "surname"; + columnArray[2] = (C) "postcode"; + + List rows = zFrame.select(columnArray).collectAsList(); + + for (int idx = 0; idx < sampleDataSet.size(); idx++) { + R row = rows.get(idx); + Assertions.assertEquals(zFrame.getAsString(row, "recid"), sampleDataSet.get(idx).recid, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "surname"), sampleDataSet.get(idx).surname, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "postcode"), sampleDataSet.get(idx).postcode, + "value from zFrame and sampleData doesn't match"); + } + } + + @Test + public void testSelectWithMultipleColumnNamesAsString() 
throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + List rows = zFrame.select("recid", "surname", "postcode").collectAsList(); + + for (int idx = 0; idx < sampleDataSet.size(); idx++) { + R row = rows.get(idx); + Assertions.assertEquals(zFrame.getAsString(row, "recid"), sampleDataSet.get(idx).recid, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "surname"), sampleDataSet.get(idx).surname, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "postcode"), sampleDataSet.get(idx).postcode, + "value from zFrame and sampleData doesn't match"); + } + } + + @Test + public void testSelectExprByPassingColumnStringsAsInSQLStatement() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + List rows = zFrame.selectExpr("recid as RecordId", "surname as FamilyName", + "postcode as Pin").collectAsList(); + + for (int idx = 0; idx < sampleDataSet.size(); idx++) { + R row = rows.get(idx); + Assertions.assertEquals(zFrame.getAsString(row, "RecordId"), sampleDataSet.get(idx).recid, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "FamilyName"), sampleDataSet.get(idx).surname, + "value from zFrame and sampleData doesn't match"); + Assertions.assertEquals(zFrame.getAsString(row, "Pin"), sampleDataSet.get(idx).postcode, + "value from zFrame and sampleData doesn't match"); + } + } + + @Test + public void testDropSingleColumn() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + List fieldsInZFrame = new ArrayList(); + List fieldsInTestData = new ArrayList(); + Arrays.stream(zFrame.drop("recid").fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> fieldsInTestData.add(fieldS.getName())); + fieldsInTestData.remove("recid"); + + assertEquals(fieldsInTestData, fieldsInZFrame, "Fields in zFrame and sample data doesn't match"); + } + + @Test + public void testDropColumnsAsStringArray() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + List fieldsInZFrame = new ArrayList(); + List fieldsInTestData = new ArrayList(); + Arrays.stream(zFrame.drop("recid", "surname", "postcode").fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> fieldsInTestData.add(fieldS.getName())); + fieldsInTestData.remove("recid"); + fieldsInTestData.remove("surname"); + fieldsInTestData.remove("postcode"); + + assertEquals(fieldsInTestData, fieldsInZFrame, + "Fields in zFrame and sample data doesn't match"); + } + + @Test + public void testLimit() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + int len = 5; + List rows = zFrame.limit(len).collectAsList(); + + assertEquals(rows.size(), len, "Size is not equal"); + + //assert on rows + List fields = List.of(Person.class.getDeclaredFields()); + for (int idx = 0; idx < len; idx++) { + R row = rows.get(idx); + for 
(Field column : fields) { + String columnName = column.getName(); + assertEquals(column.get(sampleDataSet.get(idx)).toString(), zFrame.getAsString(row, columnName), + "value in ZFrame and sample input is not same"); + } + } + } + + @Test + public void testHead() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + R row = zFrame.head(); + List fields = List.of(Person.class.getDeclaredFields()); + for (Field column : fields) { + String columnName = column.getName(); + assertEquals(column.get(sampleDataSet.get(0)).toString(), zFrame.getAsString(row, columnName), + "value in ZFrame and sample input is not same"); + } + } + + @Test + public void testGetAsInt() throws Exception { + List sampleDataSet = createSampleDataListWithMixedDataType(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, PersonMixed.class); + + R row = zFrame.head(); + assertTrue(zFrame.getAsInt(row, "recid") == sampleDataSet.get(0).recid, + "row.getAsInt(col) hasn't returned correct int value"); + } + + @Test + public void testGetAsString() throws Exception { + List sampleDataSet = createSampleDataListWithMixedDataType(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, PersonMixed.class); + + R row = zFrame.head(); + assertEquals(zFrame.getAsString(row, "surname"), sampleDataSet.get(0).surname, + "row.getAsString(col) hasn't returned correct string value"); + } + + @Test + public void testGetAsDouble() throws Exception { + List sampleDataSet = createSampleDataListWithMixedDataType(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, PersonMixed.class); + + R row = zFrame.head(); + assertEquals(zFrame.getAsDouble(row, "cost"), sampleDataSet.get(0).cost, + "row.getAsDouble(col) hasn't returned correct double value"); + } + + @Test + public void testWithColumnForIntegerValue() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + + String newCol = NEW_COLUMN; + int newColVal = 36; + ZFrame zFrameWithAddedColumn = zFrame.withColumn(newCol, newColVal); + + List fieldsInTestData = new ArrayList(); + List fieldsInZFrame = new ArrayList(); + Arrays.stream(zFrameWithAddedColumn.fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> fieldsInTestData.add(fieldS.getName())); + fieldsInTestData.add(newCol); + + //Assert on columns + assertEquals(fieldsInTestData, fieldsInZFrame, + "Columns of sample data and zFrame are not equal"); + + //Assert on first row + R row = zFrameWithAddedColumn.head(); + Assertions.assertEquals(zFrame.getAsInt(row, newCol), Integer.valueOf(newColVal), + "value of added column is not as expected"); + } + + @Test + public void testWithColumnForDoubleValue() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + String newCol = NEW_COLUMN; + double newColVal = 3.14; + ZFrame zFrameWithAddedColumn = zFrame.withColumn(newCol, newColVal); + + List fieldsInTestData = new ArrayList(); + List fieldsInZFrame = new ArrayList(); + Arrays.stream(zFrameWithAddedColumn.fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> 
fieldsInTestData.add(fieldS.getName())); + fieldsInTestData.add(newCol); + + //Assert on columns + assertEquals(fieldsInTestData, fieldsInZFrame, + "Columns of sample data and zFrame are not equal"); + + //Assert on first row + R row = zFrameWithAddedColumn.head(); + Assertions.assertEquals(zFrame.getAsDouble(row, newCol), Double.valueOf(newColVal), + "value of added column is not as expected"); + } + + @Test + public void testWithColumnForStringValue() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + String newCol = NEW_COLUMN; + String newColVal = "zingg"; + ZFrame zFrameWithAddedColumn = zFrame.withColumn(newCol, newColVal); + + List fieldsInTestData = new ArrayList(); + List fieldsInZFrame = new ArrayList(); + Arrays.stream(zFrameWithAddedColumn.fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> fieldsInTestData.add(fieldS.getName())); + fieldsInTestData.add(newCol); + + //Assert on columns + assertEquals(fieldsInTestData, fieldsInZFrame, + "Columns of sample data and zFrame are not equal"); + + //Assert on first row + R row = zFrameWithAddedColumn.head(); + Assertions.assertEquals(zFrame.getAsString(row, newCol), newColVal, + "value of added column is not as expected"); + } + + @Test + public void testWithColumnForAnotherColumn() throws Exception { + List sampleDataSet = createSampleDataList(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, Person.class); + String oldCol = STR_RECID; + String newCol = NEW_COLUMN; + ZFrame zFrameWithAddedColumn = zFrame.withColumn(newCol, zFrame.col(oldCol)); + + List fieldsInTestData = new ArrayList(); + List fieldsInZFrame = new ArrayList(); + Arrays.stream(zFrameWithAddedColumn.fields()).iterator().forEachRemaining(fieldZ -> fieldsInZFrame.add(fieldZ.getName())); + Arrays.stream(Person.class.getFields()).sequential().forEach(fieldS -> fieldsInTestData.add(fieldS.getName())); + fieldsInTestData.add(newCol); + + //Assert on columns + assertEquals(fieldsInTestData, fieldsInZFrame, + "Columns of sample data and zFrame are not equal"); + + //Assert on first row + R row = zFrameWithAddedColumn.head(); + assertEquals(Optional.of(zFrameWithAddedColumn.getAsString(row, newCol)), Optional.of(zFrameWithAddedColumn.getAsString(row, oldCol)), + "value of added column is not as expected"); + } + + @Test + public void testGetMaxVal() throws Exception { + List sampleDataSet = createSampleDataZScore(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, ClusterZScore.class); + + assertEquals("400", zFrame.getMaxVal(ColName.CLUSTER_COLUMN), + "Max value is not as expected"); + } + + @Test + public void testGroupByMinMax() throws Exception { + List sampleDataSet = createSampleDataZScore(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, ClusterZScore.class); + + ZFrame groupByDF = zFrame.groupByMinMaxScore(zFrame.col(ColName.ID_COL)); + + List assertionRows = groupByDF.collectAsList(); + for (R row : assertionRows) { + if (groupByDF.getAsLong(row, "z_zid") == 1.0) { + assertEquals(1001.0, groupByDF.getAsDouble(row, "z_minScore"), + "z_minScore is not as expected"); + assertEquals(2002.0, groupByDF.getAsDouble(row, "z_maxScore"), + "z_maxScore is not as expected"); + } + } + } + + @Test + public void testGroupByMinMax2() throws Exception { + List sampleDataSet = 
createSampleDataZScore(); //List + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSet, ClusterZScore.class); + + ZFrame groupByDF = zFrame.groupByMinMaxScore(zFrame.col(ColName.CLUSTER_COLUMN)); + + List assertionRows = groupByDF.collectAsList(); + for (R row : assertionRows) { + if ("100".equals(groupByDF.getAsString(row, "z_cluster"))) { + assertEquals(900.0, groupByDF.getAsDouble(row, "z_minScore"), + "z_minScore is not as expected"); + assertEquals(9002.0, groupByDF.getAsDouble(row, "z_maxScore"), + "z_maxScore is not as expected"); + } + } + } + + @Test + public void testRightJoinMultiCol() throws Exception { + List sampleDataSetInput = createSampleDataInput(); //List + ZFrame zFrameInput = dfObjectUtil.getDFFromObjectList(sampleDataSetInput, InputWithZidAndSource.class); + List sampleDataSetCluster = createSampleDataCluster(); //List + ZFrame zFrameCluster = dfObjectUtil.getDFFromObjectList(sampleDataSetCluster, PairPartOne.class); + + ZFrame joinedData = zFrameCluster.join(zFrameInput, ColName.ID_COL, ColName.SOURCE_COL, ZFrame.RIGHT_JOIN); + assertEquals(10, joinedData.count()); + } + + @Test + public void testFilterInCond() throws Exception { + List sampleDataSetInput = createSampleDataInput(); //List + ZFrame zFrameInput = dfObjectUtil.getDFFromObjectList(sampleDataSetInput, InputWithZidAndSource.class); + List sampleDataSetCluster = createSampleDataClusterWithNull(); //List + ZFrame zFrameCluster = dfObjectUtil.getDFFromObjectList(sampleDataSetCluster, PairPartTwo.class); + ZFrame filteredData = zFrameInput.filterInCond(ColName.ID_COL, zFrameCluster, ColName.COL_PREFIX + ColName.ID_COL); + assertEquals(5, filteredData.count()); + } + + @Test + public void testFilterNotNullCond() throws Exception { + List sampleDataSetCluster = createSampleDataClusterWithNull(); //List + ZFrame zFrameCluster = dfObjectUtil.getDFFromObjectList(sampleDataSetCluster, PairPartTwo.class); + + ZFrame filteredData = zFrameCluster.filterNotNullCond(ColName.SOURCE_COL); + assertEquals(3, filteredData.count()); + } + + @Test + public void testFilterNullCond() throws Exception { + List sampleDataSetCluster = createSampleDataClusterWithNull(); //List + ZFrame zFrameCluster = dfObjectUtil.getDFFromObjectList(sampleDataSetCluster, PairPartTwo.class); + + ZFrame filteredData = zFrameCluster.filterNullCond(ColName.SOURCE_COL); + assertEquals(2, filteredData.count()); + } + + @Test + public void testDropDuplicatesConsideringGivenColumnsAsStringArray() throws Exception { + List sampleData = createSampleDataList(); + List sampleDataWithDistinctSurnameAndPostCode = createSampleDataListWithDistinctSurnameAndPostcode(); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleData, Person.class); + + String[] columnArray = new String[]{"surname", "postcode"}; + ZFrame zFrameDeDuplicated = zFrame.dropDuplicates(columnArray); + + List rows = zFrameDeDuplicated.collectAsList(); + + List fields = List.of(Person.class.getDeclaredFields()); + int matchedCount = 0; + for (Person schema : sampleDataWithDistinctSurnameAndPostCode) { + for (R row : rows) { + boolean rowMatched = true; + for (Field column : fields) { + String columnName = column.getName(); + if (!column.get(schema).toString(). 
+ equals(zFrame.getAsString(row, columnName))) { + rowMatched = false; + break; + } + } + if (rowMatched) { + matchedCount++; + break; + } + } + } + + + assertEquals(rows.size(), matchedCount, + "rows count is not as expected"); + assertEquals(sampleDataWithDistinctSurnameAndPostCode.size(), matchedCount, + "rows count is not as expected"); + } + + @Test + public void testDropDuplicatesConsideringGivenIndividualColumnsAsString() throws Exception { + List sampleDataSetCluster = createSampleDataList(); + List sampleDataWithDistinctSurnameAndPostCode = createSampleDataListWithDistinctSurnameAndPostcode(); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleDataSetCluster, Person.class); + ZFrame zFrameDeDuplicated = zFrame.dropDuplicates("surname", "postcode"); + + List rows = zFrameDeDuplicated.collectAsList(); + List fields = List.of(Person.class.getDeclaredFields()); + int matchedCount = 0; + for (Person person : sampleDataWithDistinctSurnameAndPostCode) { + for (R row : rows) { + boolean rowMatched = true; + for (Field column : fields) { + String columnName = column.getName(); + if (!column.get(person).toString(). + equals(zFrame.getAsString(row, columnName))) { + rowMatched = false; + break; + } + } + if (rowMatched) { + matchedCount++; + break; + } + } + } + + + assertEquals(rows.size(), matchedCount, + "rows count is not as expected"); + assertEquals(sampleDataWithDistinctSurnameAndPostCode.size(), matchedCount, + "rows count is not as expected"); + } + + @Test + public void testSortDescending() throws Exception { + List sampleData = createSampleDataListWithMixedDataType(); + sampleData.sort((a, b) -> a.recid > b.recid ? -1 : 1); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleData, PersonMixed.class); + + String col = STR_RECID; + ZFrame zFrameSortedDesc = zFrame.sortDescending(col); + List rows = zFrameSortedDesc.collectAsList(); + + List fields = List.of(PersonMixed.class.getDeclaredFields()); + for (int idx = 0; idx < sampleData.size(); idx++) { + R row = rows.get(idx); + for (Field column : fields) { + String columnName = column.getName(); + if (column.getType() == String.class) { + assertEquals(column.get(sampleData.get(idx)), zFrameSortedDesc.getAsString(row, columnName), + "value in ZFrame and sample input is not same"); + } else if (column.getType() == Integer.class) { + assertEquals(column.get(sampleData.get(idx)), zFrameSortedDesc.getAsInt(row, columnName), + "value in ZFrame and sample input is not same"); + } else if (column.getType() == Double.class) { + assertEquals(column.get(sampleData.get(idx)), zFrameSortedDesc.getAsDouble(row, columnName), + "value in ZFrame and sample input is not same"); + } else if (column.getType() == Long.class) { + assertEquals(column.get(sampleData.get(idx)), zFrameSortedDesc.getAsLong(row, columnName), + "value in ZFrame and sample input is not same"); + } else { + throw new Exception("Not a valid data type"); + } + } + } + } + + @Test + public void testSortAscending() throws Exception { + List sampleData = createSampleDataListWithMixedDataType(); + sampleData.sort((a, b) -> a.recid < b.recid ? 
-1 : 1); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleData, PersonMixed.class); + + String col = STR_RECID; + ZFrame zFrameSortedAsc = zFrame.sortAscending(col); + List rows = zFrameSortedAsc.collectAsList(); + + List fields = List.of(PersonMixed.class.getDeclaredFields()); + for (int idx = 0; idx < sampleData.size(); idx++) { + R row = rows.get(idx); + for (Field column : fields) { + String columnName = column.getName(); + if (column.getType() == String.class) { + assertEquals(column.get(sampleData.get(idx)).toString(), zFrame.getAsString(row, columnName), + "value in ZFrame and sample input is not same"); + } else if (column.getType() == Integer.class) { + assertEquals(column.get(sampleData.get(idx)), zFrame.getAsInt(row, columnName), + "value in ZFrame and sample input is not same"); + } else if (column.getType() == Double.class) { + assertEquals(column.get(sampleData.get(idx)), zFrame.getAsDouble(row, columnName), + "value in ZFrame and sample input is not same"); + } else if (column.getType() == Long.class) { + assertEquals(column.get(sampleData.get(idx)), zFrame.getAsLong(row, columnName), + "value in ZFrame and sample input is not same"); + } else { + throw new Exception("Not a valid data type"); + } + } + } + } + + @Test + public void testIsEmpty() throws Exception { + List emptySampleData = createEmptySampleData(); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(emptySampleData, Person.class); + + assertTrue(zFrame.isEmpty(), "zFrame is not empty"); + } + + @Test + public void testDistinct() throws Exception { + List sampleData = createSampleDataList(); + List sampleDataDistinct = createSampleDataListDistinct(); + ZFrame zFrame = dfObjectUtil.getDFFromObjectList(sampleData, Person.class); + + List rows = zFrame.distinct().collectAsList(); + + List fields = List.of(Person.class.getDeclaredFields()); + for (int idx = 0; idx < sampleDataDistinct.size(); idx++) { + R row = rows.get(idx); + for (Field column : fields) { + String columnName = column.getName(); + assertEquals(column.get(sampleDataDistinct.get(idx)).toString(), zFrame.getAsString(row, columnName), + "value in ZFrame and sample input is not same"); + } + } + } + + protected void assertTrueCheckingExceptOutput(ZFrame sf1, ZFrame sf2, String message) { + assertTrue(sf1.except(sf2).isEmpty(), message); + } +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/common/core/zFrame/data/TestData.java b/common/core/src/test/java/zingg/common/core/zFrame/data/TestData.java new file mode 100644 index 000000000..3bf63f254 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/data/TestData.java @@ -0,0 +1,149 @@ +package zingg.common.core.zFrame.data; + + +import zingg.common.core.zFrame.model.ClusterZScore; +import zingg.common.core.zFrame.model.InputWithZidAndSource; +import zingg.common.core.zFrame.model.PairPartOne; +import zingg.common.core.zFrame.model.PairPartTwo; +import zingg.common.core.zFrame.model.Person; +import zingg.common.core.zFrame.model.PersonMixed; + +import java.util.ArrayList; +import java.util.List; + +public class TestData { + + //sample data classes to be used for testing + public static List createEmptySampleData() { + + return new ArrayList(); + } + + public static List createSampleDataList() { + List sample = new ArrayList(); + sample.add(new Person("07317257", "erjc", "henson", "hendersonville", "2873g")); + sample.add(new Person("07317257", "erjc", "henson", "hendersonville", "2873g")); + sample.add(new Person("03102490", "jhon", "kozak", 
"henders0nville", "28792")); + sample.add(new Person("02890805", "david", "pisczek", "durham", "27717")); + sample.add(new Person("04437063", "e5in", "bbrown", "greenville", "27858")); + sample.add(new Person("03211564", "susan", "jones", "greenjboro", "274o7")); + + sample.add(new Person("04155808", "jerome", "wilkins", "battleborn", "2780g")); + sample.add(new Person("05723231", "clarinw", "pastoreus", "elizabeth city", "27909")); + sample.add(new Person("06087743", "william", "craven", "greenshoro", "27405")); + sample.add(new Person("00538491", "marh", "jackdon", "greensboro", "27406")); + sample.add(new Person("01306702", "vonnell", "palmer", "siler sity", "273q4")); + + return sample; + } + + public static List createSampleDataListDistinct() { + List sample = new ArrayList(); + sample.add(new Person("07317257", "erjc", "henson", "hendersonville", "2873g")); + sample.add(new Person("03102490", "jhon", "kozak", "henders0nville", "28792")); + sample.add(new Person("02890805", "david", "pisczek", "durham", "27717")); + sample.add(new Person("04437063", "e5in", "bbrown", "greenville", "27858")); + sample.add(new Person("03211564", "susan", "jones", "greenjboro", "274o7")); + + sample.add(new Person("04155808", "jerome", "wilkins", "battleborn", "2780g")); + sample.add(new Person("05723231", "clarinw", "pastoreus", "elizabeth city", "27909")); + sample.add(new Person("06087743", "william", "craven", "greenshoro", "27405")); + sample.add(new Person("00538491", "marh", "jackdon", "greensboro", "27406")); + sample.add(new Person("01306702", "vonnell", "palmer", "siler sity", "273q4")); + + return sample; + } + + public static List createSampleDataListWithDistinctSurnameAndPostcode() { + List sample = new ArrayList(); + sample.add(new Person("07317257", "erjc", "henson", "hendersonville", "2873g")); + sample.add(new Person("03102490", "jhon", "kozak", "henders0nville", "28792")); + sample.add(new Person("02890805", "david", "pisczek", "durham", "27717")); + sample.add(new Person("04437063", "e5in", "bbrown", "greenville", "27858")); + sample.add(new Person("03211564", "susan", "jones", "greenjboro", "274o7")); + + sample.add(new Person("04155808", "jerome", "wilkins", "battleborn", "2780g")); + sample.add(new Person("05723231", "clarinw", "pastoreus", "elizabeth city", "27909")); + sample.add(new Person("06087743", "william", "craven", "greenshoro", "27405")); + sample.add(new Person("00538491", "marh", "jackdon", "greensboro", "27406")); + sample.add(new Person("01306702", "vonnell", "palmer", "siler sity", "273q4")); + + return sample; + } + + public static List createSampleDataListWithMixedDataType() { + List sample = new ArrayList(); + sample.add(new PersonMixed(7317257, "erjc", "henson", 10.021, 2873)); + sample.add(new PersonMixed(3102490, "jhon", "kozak", 3.2434, 28792)); + sample.add(new PersonMixed(2890805, "david", "pisczek", 5436.0232, 27717)); + sample.add(new PersonMixed(4437063, "e5in", "bbrown", 67.0, 27858)); + sample.add(new PersonMixed(3211564, "susan", "jones", 7343.2324, 2747)); + + sample.add(new PersonMixed(4155808, "jerome", "wilkins", 50.34, 2780)); + sample.add(new PersonMixed(5723231, "clarinw", "pastoreus", 87.2323, 27909)); + sample.add(new PersonMixed(6087743, "william", "craven", 834.123, 27405)); + sample.add(new PersonMixed(538491, "marh", "jackdon", 123.123, 27406)); + sample.add(new PersonMixed(1306702, "vonnell", "palmer", 83.123, 2734)); + + return sample; + } + + public static List createSampleDataZScore() { + + List sample = new ArrayList(); + 
sample.add(new ClusterZScore(0L, "100", 900.0)); + sample.add(new ClusterZScore(1L, "100", 1001.0)); + sample.add(new ClusterZScore(1L, "100", 1002.0)); + sample.add(new ClusterZScore(1L, "100", 2001.0)); + sample.add(new ClusterZScore(1L, "100", 2002.0)); + sample.add(new ClusterZScore(11L, "100", 9002.0)); + sample.add(new ClusterZScore(3L, "300", 3001.0)); + sample.add(new ClusterZScore(3L, "300", 3002.0)); + sample.add(new ClusterZScore(3L, "400", 4001.0)); + sample.add(new ClusterZScore(4L, "400", 4002.0)); + + return sample; + } + + public static List createSampleDataCluster() { + + List sample = new ArrayList(); + sample.add(new PairPartOne(1L, "100", 1001.0, "b")); + sample.add(new PairPartOne(2L, "100", 1002.0, "a")); + sample.add(new PairPartOne(3L, "100", 2001.0, "b")); + sample.add(new PairPartOne(4L, "900", 2002.0, "c")); + sample.add(new PairPartOne(5L, "111", 9002.0, "c")); + + return sample; + } + + public static List createSampleDataClusterWithNull() { + + List sample = new ArrayList(); + sample.add(new PairPartTwo(1L, "100", 1001.0, "b")); + sample.add(new PairPartTwo(2L, "100", 1002.0, "a")); + sample.add(new PairPartTwo(3L, "100", 2001.0, null)); + sample.add(new PairPartTwo(4L, "900", 2002.0, "c")); + sample.add(new PairPartTwo(5L, "111", 9002.0, null)); + + return sample; + } + + public static List createSampleDataInput() { + + List sample = new ArrayList(); + sample.add(new InputWithZidAndSource(1L, "fname1", "b")); + sample.add(new InputWithZidAndSource(2L, "fname", "a")); + sample.add(new InputWithZidAndSource(3L, "fna", "b")); + sample.add((new InputWithZidAndSource(4L, "x", "c"))); + sample.add(new InputWithZidAndSource(5L, "y", "c")); + sample.add(new InputWithZidAndSource(11L, "new1", "b")); + sample.add(new InputWithZidAndSource(22L, "new12", "a")); + sample.add(new InputWithZidAndSource(33L, "new13", "b")); + sample.add(new InputWithZidAndSource(44L, "new14", "c")); + sample.add(new InputWithZidAndSource(55L, "new15", "c")); + + return sample; + } + +} diff --git a/common/core/src/test/java/zingg/common/core/zFrame/model/ClusterZScore.java b/common/core/src/test/java/zingg/common/core/zFrame/model/ClusterZScore.java new file mode 100644 index 000000000..e10788395 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/model/ClusterZScore.java @@ -0,0 +1,13 @@ +package zingg.common.core.zFrame.model; + +public class ClusterZScore { + public final Long z_zid; + public final String z_cluster; + public final Double z_score; + + public ClusterZScore(Long z_zid, String z_cluster, Double z_score) { + this.z_zid = z_zid; + this.z_cluster = z_cluster; + this.z_score = z_score; + } +} diff --git a/common/core/src/test/java/zingg/common/core/zFrame/model/InputWithZidAndSource.java b/common/core/src/test/java/zingg/common/core/zFrame/model/InputWithZidAndSource.java new file mode 100644 index 000000000..a370e9643 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/model/InputWithZidAndSource.java @@ -0,0 +1,13 @@ +package zingg.common.core.zFrame.model; + +public class InputWithZidAndSource { + public final Long z_zid; + public final String fname; + public final String z_zsource; + + public InputWithZidAndSource(Long z_zid, String fname, String z_zsource) { + this.z_zid = z_zid; + this.fname = fname; + this.z_zsource = z_zsource; + } +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/common/core/zFrame/model/PairPartOne.java b/common/core/src/test/java/zingg/common/core/zFrame/model/PairPartOne.java new file mode 
100644 index 000000000..3f4ef6adc --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/model/PairPartOne.java @@ -0,0 +1,15 @@ +package zingg.common.core.zFrame.model; + +public class PairPartOne { + public final Long z_zid; + public final String z_cluster; + public final Double z_score; + public final String z_zsource; + + public PairPartOne(Long z_zid, String z_cluster, Double z_score, String z_zsource) { + this.z_zid = z_zid; + this.z_cluster = z_cluster; + this.z_score = z_score; + this.z_zsource = z_zsource; + } +} diff --git a/common/core/src/test/java/zingg/common/core/zFrame/model/PairPartTwo.java b/common/core/src/test/java/zingg/common/core/zFrame/model/PairPartTwo.java new file mode 100644 index 000000000..44e1ccd17 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/model/PairPartTwo.java @@ -0,0 +1,15 @@ +package zingg.common.core.zFrame.model; + +public class PairPartTwo { + public final Long z_z_zid; + public final String z_cluster; + public final Double z_score; + public final String z_zsource; + + public PairPartTwo(Long z_z_zid, String z_cluster, Double z_score, String z_zsource) { + this.z_z_zid = z_z_zid; + this.z_cluster = z_cluster; + this.z_score = z_score; + this.z_zsource = z_zsource; + } +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/common/core/zFrame/model/Person.java b/common/core/src/test/java/zingg/common/core/zFrame/model/Person.java new file mode 100644 index 000000000..d1ea21612 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/model/Person.java @@ -0,0 +1,17 @@ +package zingg.common.core.zFrame.model; + +public class Person { + public final String recid; + public final String givenname; + public final String surname; + public final String suburb; + public final String postcode; + + public Person(String recid, String givenname, String surname, String suburb, String postcode) { + this.recid = recid; + this.givenname = givenname; + this.surname = surname; + this.suburb = suburb; + this.postcode = postcode; + } +} diff --git a/common/core/src/test/java/zingg/common/core/zFrame/model/PersonMixed.java b/common/core/src/test/java/zingg/common/core/zFrame/model/PersonMixed.java new file mode 100644 index 000000000..a200c4f49 --- /dev/null +++ b/common/core/src/test/java/zingg/common/core/zFrame/model/PersonMixed.java @@ -0,0 +1,17 @@ +package zingg.common.core.zFrame.model; + +public class PersonMixed { + public final Integer recid; + public final String givenname; + public final String surname; + public final Double cost; + public final Integer postcode; + + public PersonMixed(Integer recid, String givenname, String surname, Double cost, Integer postcode) { + this.recid = recid; + this.givenname = givenname; + this.surname = surname; + this.cost = cost; + this.postcode = postcode; + } +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/hash/TestHashFnFromConf.java b/common/core/src/test/java/zingg/hash/TestHashFnFromConf.java new file mode 100644 index 000000000..c99179c5a --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestHashFnFromConf.java @@ -0,0 +1,25 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.HashFnFromConf; + +@JsonInclude(Include.NON_NULL) +public class TestHashFnFromConf { + @Test + public void
testHashFnFromConf() { + HashFnFromConf hashFnFromConf = new HashFnFromConf(); + hashFnFromConf.setName("Micheal"); + assertEquals("Micheal", hashFnFromConf.getName()); + } + + @Test + public void testHashFnFromConf1() { + HashFnFromConf hashFnFromConf = new HashFnFromConf(); + hashFnFromConf.setName(null); + assertEquals(null, hashFnFromConf.getName()); + } +} \ No newline at end of file diff --git a/common/core/src/test/java/zingg/hash/TestHashFunction.java b/common/core/src/test/java/zingg/hash/TestHashFunction.java new file mode 100644 index 000000000..1e46142a8 --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestHashFunction.java @@ -0,0 +1,150 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.client.ZFrame; +import zingg.common.core.hash.HashFunction; + + +public class TestHashFunction { + @Test + public void testGetName() { + HashFunction hashFunction = new HashFunction("initialName") { + @Override + public ZFrame apply(ZFrame ds, String column, String newColumn) { + return null; + } + + @Override + public Object getAs(Integer integer, String column) { + return null; + } + + @Override + public Object getAs(String s, Integer integer, String column) { + return null; + } + + @Override + public Object apply(Integer integer, String column) { + return null; + } + + @Override + public Object apply(String s, Integer integer, String column) { + return null; + } + }; + + String expectedName = "hashFunction"; + hashFunction.setName(expectedName); + assertEquals(expectedName, hashFunction.getName()); + } + @Test + public void testGetReturnType() { + HashFunction hashFunction = new HashFunction("Name", 999L, 888L) { + @Override + public ZFrame apply(ZFrame ds, String column, String newColumn) { + return null; + } + + @Override + public Object getAs(Integer integer, String column) { + return null; + } + + @Override + public Object getAs(String s, Integer integer, String column) { + return null; + } + + @Override + public Object apply(Integer integer, String column) { + return null; + } + + @Override + public Object apply(String s, Integer integer, String column) { + return null; + } + }; + + long returnType = 9999L; + hashFunction.setReturnType(returnType); + assertEquals(returnType, hashFunction.getReturnType()); + + long dataType = 888L; + hashFunction.setDataType(dataType); + assertEquals(dataType, hashFunction.getDataType()); + } + + @Test + public void testIsUdf() { + HashFunction hashFunction = new HashFunction("Name", 999L, 888L, true) { + @Override + public ZFrame apply(ZFrame ds, String column, String newColumn) { + return null; + } + + @Override + public Object getAs(Integer integer, String column) { + return null; + } + + @Override + public Object getAs(String s, Integer integer, String column) { + return null; + } + + @Override + public Object apply(Integer integer, String column) { + return null; + } + + @Override + public Object apply(String s, Integer integer, String column) { + return null; + } + }; + + Boolean isUdf = false; + hashFunction.setUdf(isUdf); + assertEquals(false, hashFunction.isUdf()); + } + + @Test + public void testGetAs() { + HashFunction hashFunction = new HashFunction() { + @Override + public ZFrame apply(ZFrame ds, String column, String newColumn) { + return null; + } + + @Override + public Object getAs(Integer integer, String column) { + return null; + } + + @Override + public Object getAs(String s, Integer integer, String column) { + return null; + } + 
+ @Override + public Object apply(Integer integer, String column) { + return null; + } + + @Override + public Object apply(String s, Integer integer, String column) { + return null; + } + }; + Integer value = 10; + String column = "inputColumn"; + assertEquals(null, hashFunction.getAs(value, column)); + } + +} diff --git a/common/core/src/test/java/zingg/hash/TestIdentityLong.java b/common/core/src/test/java/zingg/hash/TestIdentityLong.java new file mode 100644 index 000000000..03615df9a --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestIdentityLong.java @@ -0,0 +1,27 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.IdentityLong; + +public class TestIdentityLong { + + @Test + public void testIdentityLong() { + IdentityLong value = getInstance(); + assertEquals(12345L, value.call(12345L)); + } + + @Test + public void testNullValue() { + IdentityLong value = getInstance(); + assertEquals(null, value.call(null)); + } + + private IdentityLong getInstance() { + return new IdentityLong(); + } + +} diff --git a/common/core/src/test/java/zingg/hash/TestLessThanZeroFloat.java b/common/core/src/test/java/zingg/hash/TestLessThanZeroFloat.java new file mode 100644 index 000000000..63bdb5bf3 --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestLessThanZeroFloat.java @@ -0,0 +1,40 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.LessThanZeroFloat; + +public class TestLessThanZeroFloat { + + @Test + public void testLessThanZeroFloatForValueZero() { + LessThanZeroFloat value = getInstance(); + assertFalse(value.call(0.0f)); + } + + @Test + public void testLessThanZeroFloatForValueNull() { + LessThanZeroFloat value = getInstance(); + assertFalse(value.call(null)); + } + + @Test + public void testLessThanZeroFloatNegativeValue() { + LessThanZeroFloat value = getInstance(); + assertTrue(value.call(-5435.45f)); + } + + @Test + public void testLessThanZeroFloatPositiveValue() { + LessThanZeroFloat value = getInstance(); + assertFalse(value.call(876.457f)); + } + + private LessThanZeroFloat getInstance() { + LessThanZeroFloat value = new LessThanZeroFloat(); + return value; + } +} diff --git a/common/core/src/test/java/zingg/hash/TestLessThanZeroLong.java b/common/core/src/test/java/zingg/hash/TestLessThanZeroLong.java new file mode 100644 index 000000000..44d161752 --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestLessThanZeroLong.java @@ -0,0 +1,40 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.LessThanZeroLong; + +public class TestLessThanZeroLong { + + @Test + public void testLessThanZeroLongForValueZero() { + LessThanZeroLong value = getInstance(); + assertFalse(value.call(0L)); + } + + @Test + public void testLessThanZeroLongForValueNull() { + LessThanZeroLong value = getInstance(); + assertFalse(value.call(null)); + } + + @Test + public void testLessThanZeroLongNegativeValue() { + LessThanZeroLong value = getInstance(); + assertTrue(value.call(-543545L)); + } + + @Test + public void testLessThanZeroLongPositiveValue() { + LessThanZeroLong value = getInstance(); + assertFalse(value.call(876457L)); + } + + private LessThanZeroLong 
getInstance() { + LessThanZeroLong value = new LessThanZeroLong(); + return value; + } +} diff --git a/common/core/src/test/java/zingg/hash/TestRangeBetween0And10Float.java b/common/core/src/test/java/zingg/hash/TestRangeBetween0And10Float.java new file mode 100644 index 000000000..2b2dfe246 --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestRangeBetween0And10Float.java @@ -0,0 +1,72 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.RangeFloat; + +public class TestRangeBetween0And10Float { + + private RangeFloat getInstance() { + return new RangeFloat(0,10); + } + + @Test + public void testRangeForValueZero() { + RangeFloat value = getInstance(); + assertEquals(1, value.call(0f)); + } + + @Test + public void testRangeForNegativeValue() { + Float input = -100f; + RangeFloat value = getInstance(); + assertEquals(0, value.call(input)); + } + + @Test + public void testRangeForVeryHighValue() { + Float input = 99999f; + RangeFloat value = getInstance(); + assertEquals(0, value.call(input)); + } + + @Test + public void testRangeForValue8() { + RangeFloat value = getInstance(); + assertEquals(1, value.call(8f)); + } + + @Test + public void testRangeForValue65() { + RangeFloat value = getInstance(); + assertEquals(0, value.call(65f)); + } + + @Test + public void testRangeForValue867() { + RangeFloat value = getInstance(); + assertEquals(0, value.call(867f)); + } + @Test + public void testRangeForValue8637() { + RangeFloat value = getInstance(); + assertEquals(0, value.call(8637f)); + } + @Test + public void testRangeForNull() { + RangeFloat value = getInstance(); + assertEquals(0, value.call(null)); + } + @Test + public void testRangeForUpperLimit() { + RangeFloat value = getInstance(); + assertEquals(10, value.getUpperLimit()); + } + @Test + public void testRangeForLowerLimit() { + RangeFloat value = getInstance(); + assertEquals(0, value.getLowerLimit()); + } +} diff --git a/common/core/src/test/java/zingg/hash/TestRangeBetween100And1000Long.java b/common/core/src/test/java/zingg/hash/TestRangeBetween100And1000Long.java new file mode 100644 index 000000000..4c3b9cfea --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestRangeBetween100And1000Long.java @@ -0,0 +1,71 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.RangeLong; + +public class TestRangeBetween100And1000Long { + + private RangeLong getInstance() { + return new RangeLong(100L,1000L); + } + + @Test + public void testRangeForValueZero() { + RangeLong value = getInstance(); + assertEquals(0, value.call(0L)); + } + + @Test + public void testRangeForNegativeValue() { + RangeLong value = getInstance(); + assertEquals(0, value.call(-100L)); + } + + @Test + public void testRangeForVeryHighValue() { + RangeLong value = getInstance(); + assertEquals(0, value.call(999999L)); + } + + @Test + public void testRangeForValue8() { + RangeLong value = getInstance(); + assertEquals(0, value.call(8L)); + } + + @Test + public void testRangeForValue65() { + RangeLong value = getInstance(); + assertEquals(0, value.call(65L)); + } + + @Test + public void testRangeForValue867() { + RangeLong value = getInstance(); + assertEquals(1, value.call(867L)); + } + @Test + public void testRangeForValue8637() { + RangeLong value = getInstance(); + assertEquals(0, value.call(8637L)); + } + @Test + public void testRangeForNull() { + 
RangeLong value = getInstance(); + assertEquals(0, value.call(null)); + } + @Test + public void testRangeForUpperLimit() { + RangeLong value = getInstance(); + assertEquals(1000, value.getUpperLimit()); + } + @Test + public void testRangeForLowerLimit() { + RangeLong value = getInstance(); + assertEquals(100, value.getLowerLimit()); + } + +} diff --git a/common/core/src/test/java/zingg/hash/TestTrimLastDigitsFloat.java b/common/core/src/test/java/zingg/hash/TestTrimLastDigitsFloat.java new file mode 100644 index 000000000..2676dfde4 --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestTrimLastDigitsFloat.java @@ -0,0 +1,75 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.TrimLastDigitsFloat; + +public class TestTrimLastDigitsFloat { + + @Test + public void testTrimLast1DigitFloat() { + TrimLastDigitsFloat value = getInstance(1); + assertEquals(54353f, value.call(543534.677f)); + } + + @Test + public void testTrimLast2DigitsFloat() { + TrimLastDigitsFloat value = getInstance(2); + assertEquals(5435f, value.call(543534.677f)); + } + + @Test + public void testTrimLast3DigitsFloat() { + TrimLastDigitsFloat value = getInstance(3); + assertEquals(543f, value.call(543534.677f)); + } + + @Test + public void testTrimLast1DigitNegativeFloat() { + TrimLastDigitsFloat value = getInstance(1); + assertEquals(-54354f, value.call(-543534.677f)); + } + + @Test + public void testTrimLast2DigitsNegativeFloat() { + TrimLastDigitsFloat value = getInstance(2); + assertEquals(-5436f, value.call(-543534.677f)); + } + + @Test + public void testTrimLast3DigitsNegativeFloat() { + TrimLastDigitsFloat value = getInstance(3); + assertEquals(-544f, value.call(-543534.677f)); + } + + @Test + public void testTrimLast3DigitsFloatNaNValue() { + TrimLastDigitsFloat value = getInstance(3); + assertEquals(Float.NaN, value.call(Float.NaN)); + } + + @Test + public void testTrimLast3DigitsFloatNullValue() { + TrimLastDigitsFloat value = getInstance(3); + assertEquals(null, value.call(null)); + } + + @Test + public void testTrimLast3DigitsNegativeFloatNaNValue() { + TrimLastDigitsFloat value = getInstance(3); + assertEquals(Float.NaN, value.call(Float.NaN)); + } + + @Test + public void testTrimLast3DigitsNegativeFloatNullValue() { + TrimLastDigitsFloat value = getInstance(3); + assertEquals(null, value.call(null)); + } + + private TrimLastDigitsFloat getInstance(int num) { + return new TrimLastDigitsFloat(num); + } + +} diff --git a/common/core/src/test/java/zingg/hash/TestTrimLastDigitsLong.java b/common/core/src/test/java/zingg/hash/TestTrimLastDigitsLong.java new file mode 100644 index 000000000..a8aefc628 --- /dev/null +++ b/common/core/src/test/java/zingg/hash/TestTrimLastDigitsLong.java @@ -0,0 +1,39 @@ +package zingg.hash; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import zingg.common.core.hash.TrimLastDigitsLong; + +public class TestTrimLastDigitsLong { + + @Test + public void testTrimLast1Digit() { + TrimLastDigitsLong value = getInstance(1); + assertEquals(54353L, value.call(543534L)); + } + + @Test + public void testTrimLast2DigitsInt() { + TrimLastDigitsLong value = getInstance(2); + assertEquals(5435L, value.call(543534L)); + } + + @Test + public void testTrimLast3DigitsInt() { + TrimLastDigitsLong value = getInstance(3); + assertEquals(543L, value.call(543534L)); + } + + @Test + public void testTrimLast3DigitsIntNullValue() { + 
TrimLastDigitsLong value = getInstance(3); + assertEquals(null, value.call(null)); + } + + private TrimLastDigitsLong getInstance(int num) { + return new TrimLastDigitsLong(num); + } + +} \ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 8a849e159..b610e9f13 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,7 @@ description: Hope you find us useful :-) This is the latest documentation for Zingg. Release wise documentation can be accessed through: +* [v0.4.1 ](https://docs.zingg.ai/zingg0.4.1/) * [v0.4.0 ](https://docs.zingg.ai/zingg0.4.0/) * [v0.3.4 ](https://docs.zingg.ai/zingg0.3.4/) * [v0.3.3](https://docs.zingg.ai/zingg0.3.3/) diff --git a/docs/connectors/amazons3.md b/docs/connectors/amazons3.md deleted file mode 100644 index 5ee47236f..000000000 --- a/docs/connectors/amazons3.md +++ /dev/null @@ -1,25 +0,0 @@ -# S3 - -1. Set a bucket e.g. zingg28032023 and a folder inside it e.g. zingg - -2. Create aws access key and export via env vars (ensure that the user with below keys has read/write access to above): - -export AWS_ACCESS_KEY_ID= -export AWS_SECRET_ACCESS_KEY= - -(if mfa is enabled AWS_SESSION_TOKEN env var would also be needed ) - -3. Download hadoop-aws-3.1.0.jar and aws-java-sdk-bundle-1.11.271.jar via maven - -4. Set above in zingg.conf : -spark.jars=//hadoop-aws-3.1.0.jar,//aws-java-sdk-bundle-1.11.271.jar - -5. Run using: - - ./scripts/zingg.sh --phase findTrainingData --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg - ./scripts/zingg.sh --phase label --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg - ./scripts/zingg.sh --phase train --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg - ./scripts/zingg.sh --phase match --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg - -6. Models etc. would get saved in -Amazon S3 > Buckets > zingg28032023 >zingg > 100 diff --git a/docs/connectors/aws-s3.md b/docs/connectors/aws-s3.md index f4be12b78..b263139bc 100644 --- a/docs/connectors/aws-s3.md +++ b/docs/connectors/aws-s3.md @@ -1,2 +1,25 @@ # AWS S3 +1. Set a bucket e.g. zingg28032023 and a folder inside it e.g. zingg + +2. Create aws access key and export via env vars (ensure that the user with below keys has read/write access to above): + +export AWS_ACCESS_KEY_ID= +export AWS_SECRET_ACCESS_KEY= + +(if mfa is enabled AWS_SESSION_TOKEN env var would also be needed ) + +3. Download hadoop-aws-3.1.0.jar and aws-java-sdk-bundle-1.11.271.jar via maven + +4. Set above in zingg.conf : +spark.jars=//hadoop-aws-3.1.0.jar,//aws-java-sdk-bundle-1.11.271.jar + +5. Run using: + + ./scripts/zingg.sh --phase findTrainingData --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg + ./scripts/zingg.sh --phase label --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg + ./scripts/zingg.sh --phase train --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg + ./scripts/zingg.sh --phase match --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg + +6. Models etc. 
would get saved in +Amazon S3 > Buckets > zingg28032023 >zingg > 100 diff --git a/docs/dataSourcesAndSinks/amazonS3.md b/docs/dataSourcesAndSinks/amazonS3.md index 5ee47236f..7ccf1f728 100644 --- a/docs/dataSourcesAndSinks/amazonS3.md +++ b/docs/dataSourcesAndSinks/amazonS3.md @@ -1,25 +1,30 @@ # S3 -1. Set a bucket e.g. zingg28032023 and a folder inside it e.g. zingg +Zingg can use AWS S3 as a source and sink -2. Create aws access key and export via env vars (ensure that the user with below keys has read/write access to above): +## Steps to run zingg on S3 -export AWS_ACCESS_KEY_ID= -export AWS_SECRET_ACCESS_KEY= +* Set a bucket e.g. zingg28032023 and a folder inside it e.g. zingg -(if mfa is enabled AWS_SESSION_TOKEN env var would also be needed ) +* Create aws access key and export via env vars (ensure that the user with below keys has read/write access to above) + export AWS_ACCESS_KEY_ID= + export AWS_SECRET_ACCESS_KEY= + (if mfa is enabled AWS_SESSION_TOKEN env var would also be needed ) -3. Download hadoop-aws-3.1.0.jar and aws-java-sdk-bundle-1.11.271.jar via maven +* Download hadoop-aws-3.1.0.jar and aws-java-sdk-bundle-1.11.271.jar via maven -4. Set above in zingg.conf : -spark.jars=//hadoop-aws-3.1.0.jar,//aws-java-sdk-bundle-1.11.271.jar +* Set above in zingg.conf + spark.jars=//hadoop-aws-3.1.0.jar,//aws-java-sdk-bundle-1.11.271.jar -5. Run using: +* Run using below commands +```bash ./scripts/zingg.sh --phase findTrainingData --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg ./scripts/zingg.sh --phase label --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg ./scripts/zingg.sh --phase train --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg ./scripts/zingg.sh --phase match --properties-file config/zingg.conf --conf examples/febrl/config.json --zinggDir s3a://zingg28032023/zingg + ``` -6. Models etc. would get saved in -Amazon S3 > Buckets > zingg28032023 >zingg > 100 + ## Model location + Models etc. would get saved in + Amazon S3 > Buckets > zingg28032023 >zingg > 100 diff --git a/docs/faq.md b/docs/faq.md index 1663a624f..9f30aaa9e 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -22,11 +22,13 @@ Very much! Zingg uses Spark and ML under the hood so that you don't have to worr ## Is Zingg an MDM? -No, Zingg is not an MDM. An MDM is the system of record, it has its own store where linked and mastered records are saved. Zingg enables MDM but is not a system of record. You can build an MDM in a data store of your choice using Zingg however. +An MDM is the system of record, it has its own store where linked and mastered records are saved. Zingg Community Version is not a complete MDM but it can be sed to build an MDM. You can build an MDM in a data store of your choice using Zingg Community Version. Zingg Enterprise Version is a lakehouse/warehouse native MDM. ## Is Zingg a CDP ? -No, Zingg is not a CDP, as it does not stream events or customer data through different channels. Zingg does overlap with the CDPs identity resolution and building customer 360 views. Here is an [article](https://hightouch.com/blog/warehouse-identity-resolution/) describing how you can build your own CDP on the warehouse with Zingg. +No, Zingg is not a CDP, as it does not stream events or customer data through different channels. However, if you want to base your customer platform off your warehouse or datalake, Zing gis a great fit. 
diff --git a/docs/faq.md b/docs/faq.md
index 1663a624f..9f30aaa9e 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -22,11 +22,13 @@ Very much! Zingg uses Spark and ML under the hood so that you don't have to worr

 ## Is Zingg an MDM?

-No, Zingg is not an MDM. An MDM is the system of record, it has its own store where linked and mastered records are saved. Zingg enables MDM but is not a system of record. You can build an MDM in a data store of your choice using Zingg however.
+An MDM is the system of record; it has its own store where linked and mastered records are saved. Zingg Community Version is not a complete MDM, but it can be used to build an MDM. You can build an MDM in a data store of your choice using Zingg Community Version. Zingg Enterprise Version is a lakehouse/warehouse native MDM.

 ## Is Zingg a CDP ?

-No, Zingg is not a CDP, as it does not stream events or customer data through different channels. Zingg does overlap with the CDPs identity resolution and building customer 360 views. Here is an [article](https://hightouch.com/blog/warehouse-identity-resolution/) describing how you can build your own CDP on the warehouse with Zingg.
+No, Zingg is not a CDP, as it does not stream events or customer data through different channels. However, if you want to base your customer platform off your warehouse or datalake, Zingg is a great fit. You can leverage existing ETL, observability and other tools which are part of your data stack and use Zingg for identity.
+Zingg Community Version can be used to build a composable CDP by running identity resolution natively on the warehouse and datalake and building customer 360 views. Zingg's identity resolution is far more powerful than what is offered by any out of the box CDP.
+Zingg Enterprise's probabilistic and deterministic matching take this even further. Here is an [article](https://hightouch.com/blog/warehouse-identity-resolution/) describing how you can build your own CDP on the warehouse with Zingg.

 ## I can do Entity Resolution using a graph database like TigerGraph/Neo4J, why do I need Zingg ?
diff --git a/docs/generatingdocumentation.md b/docs/generatingdocumentation.md
index 11fb09921..90f7e21d9 100644
--- a/docs/generatingdocumentation.md
+++ b/docs/generatingdocumentation.md
@@ -3,7 +3,7 @@ Zingg generates readable documentation about the training data, including those marked as matches, as well as non-matches. The documentation is written to the zinggDir/modelId folder and can be built using the following
 ```
-./scripts/zingg.sh --phase generateDocs --conf
+./scripts/zingg.sh --phase generateDocs --conf
 ```
 The generated documentation file can be viewed in a browser and looks as below.
diff --git a/docs/settingUpZingg.md b/docs/settingUpZingg.md
index 9f1f3982a..9bfef4453 100644
--- a/docs/settingUpZingg.md
+++ b/docs/settingUpZingg.md
@@ -8,7 +8,7 @@ sudo apt update

 ****

-_**Step 0 : Install Ubuntu on WSL2 on Windows**_
+**Step 0 : Install Ubuntu on WSL2 on Windows**

 * Install wsl: Type the following command in **Windows PowerShell**.
 ```
@@ -24,31 +24,31 @@

 ****

-_**Step 1 : Clone the Zingg Repository**_
+**Step 1 : Clone the Zingg Repository**

 * Install and SetUp Git: **sudo apt install git**
 * Verify : **git --version**
 * Set up Git by following the [tutorial](https://www.digitalocean.com/community/tutorials/how-to-install-git-on-ubuntu-20-04).
 * Clone the Zingg Repository: **git clone https://github.com/zinggAI/zingg.git**

-_**Note :-**_ It is suggested to fork the repository to your account and then clone the repository.
+**Note :-** It is suggested to fork the repository to your account and then clone the repository.

 ****

-_**Step 2 : Install JDK 1.8 (Java Development Kit)**_
+**Step 2 : Install JDK 11 (Java Development Kit)**

-* Follow this [tutorial](https://linuxize.com/post/install-java-on-ubuntu-20-04/) to install Java8 JDK1.8 in Ubuntu.
+* Follow this [tutorial](https://linuxize.com/post/install-java-on-ubuntu-20-04/) to install Java 11 (JDK 11) in Ubuntu.
 * For example:
 ```
-sudo apt install openjdk-8-jdk openjdk-8-jre
+sudo apt install openjdk-11-jdk openjdk-11-jre
 javac -version
 java -version
 ```

 ****

-_**Step 3 : Install Apache Spark -**_
+**Step 3 : Install Apache Spark**

 * Download Apache Spark - from the [Apache Spark Official Website](https://spark.apache.org/downloads.html).
 * Install downloaded Apache Spark - on your Ubuntu by following [this tutorial](https://computingforgeeks.com/how-to-install-apache-spark-on-ubuntu-debian/).
@@ -63,11 +63,11 @@ sudo mv spark-3.5.0-bin-hadoop3 /opt/spark

 Make sure that spark version you have installed is compatible with java you have installed, and Zingg is supporting those versions.

-_**Note :-**_ Zingg supports Spark 3.5 and the corresponding Java version.
+**Note :-** Zingg supports Spark 3.5 and the corresponding Java version.
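Before moving on, it can help to confirm that the installed Spark and Java actually line up with what Zingg expects. A quick sanity check from pyspark, assuming pyspark 3.5 is on the path:

```python
# Hedged sketch: verify that the session runs Spark 3.5.x on Java 11.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
print("Spark:", spark.version)  # expect 3.5.x
# query the JVM that backs the session for its Java version; expect 11.x
print("Java:", spark.sparkContext._jvm.java.lang.System.getProperty("java.version"))
spark.stop()
```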
****

-_**Step 4 : Install Apache Maven**_
+**Step 4 : Install Apache Maven**

 * Install the latest maven package.
@@ -79,66 +79,97 @@ rm -rf apache-maven-3.8.8-bin.tar.gz
 cd apache-maven-3.8.8/
 cd bin
 ./mvn --version
+
+Make sure that `mvn --version` also displays the correct Java version (Java 11):
+Apache Maven 3.8.7
+Maven home: /usr/share/maven
+Java version: 11.0.23, vendor: Ubuntu, runtime: /usr/lib/jvm/java-11-openjdk-amd64
 ```

 ****

-_**Step 5 : Update Env Variables**_
+**Step 5 : Update Env Variables**

-Open .bashrc and add env variables at end of file
+* Open .bashrc and add the env variables at the end of the file
 ```
 vim ~/.bashrc
-
 export SPARK_HOME=/opt/spark
 export SPARK_MASTER=local[\*]
 export MAVEN_HOME=/home/ubuntu/apache-maven-3.8.8
-export PATH=$PATH:$SPARK_HOME/bin:$SPARK_HOME/sbin:$MAVEN_HOME/bin
-export ZINGG_HOME=/zingg/assembly/target
-export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
+export ZINGG_HOME=/assembly/target
+export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
+export PATH=$PATH:$SPARK_HOME/bin:$SPARK_HOME/sbin:$JAVA_HOME/bin
+
+```
+\ is the directory where you cloned the Zingg repository. Similarly, if you have installed Spark in a different directory, you can set **SPARK\_HOME** accordingly.

-Save/exit and do source .bashrc so that they reflect
+**Note :-** Skip exporting MAVEN_HOME if multiple Maven versions are not required

+* Save/exit and do source .bashrc so that the changes take effect
+```
 source ~/.bashrc
+```

-Verify:
+* Verify:
+```
 echo $PATH
 mvn --version
-```
-where \ will be a directory where you clone the repository of the Zingg. Similarly, if you have installed spark on a different directory you can set **SPARK\_HOME** accordingly.
+```

-_**Note :-**_ If you have already set up **JAVA\_HOME** and **SPARK\_HOME** in the steps before you don't need to do this again.
+**Note :-** If you have already set up **JAVA\_HOME** and **SPARK\_HOME** in the steps before, you don't need to do this again.

 ****

-_**Step 6 : Compile the Zingg Repository**_
+**Step 6 : Compile the Zingg Repository**

-* Run the following to Compile the Zingg Repository -
+* Ensure you are on the main branch
 ```
 git branch
-(Ensure you are on main branch)
+
+```
+
+* Run the following to compile the Zingg repository
+```
 mvn initialize
-* Run the following to Compile the Zingg Repository - **mvn initialize** and
-* **mvn clean compile package -Dspark=sparkVer**
+mvn clean compile package -Dspark=sparkVer
 ```

-_**Note :-**_ Replace the **sparkVer** with the version of spark you installed, For example, **-Dspark=3.5** and if still facing error, include **-Dmaven.test.skip=true** with the above command.
+* Run the following to compile while skipping tests
+```
+mvn initialize
+mvn clean compile package -Dspark=sparkVer -Dmaven.test.skip=true
+```
+**Note :-** Replace **sparkVer** with the version of Spark you installed. For example, **-Dspark=3.5**. If you still face errors, exclude tests while compiling.

-_**Note :-**_ substitute 3.3 with profile of the spark version you have installed. This is based on profiles specified in pom.xml
+
+**Note :-** Substitute 3.3 with the profile of the Spark version you have installed. This is based on profiles specified in pom.xml
****

-_**Step 7 : If had any issue with 'SPARK\_LOCAL\_IP'**_
+**Step 7 : If you had any issue with 'SPARK\_LOCAL\_IP'**
+
+* Install **net-tools**
+```
+sudo apt-get install -y net-tools
+```
+
+* Run this command in the terminal to get the IP address
+```
+ifconfig
+```

-* Install **net-tools** using **sudo apt-get install -y net-tools**
-* Run command in the terminal **ifconfig**, find the **IP address** and paste the same in **/opt/hosts** IP address of your Pc-Name
+* Paste the IP address in **/opt/hosts** against your Pc-Name

 ****

-_**Step 8 : Run Zingg to Find Training Data**_
+**Step 8 : Run Zingg to Find Training Data**

-* Run this Script in terminal opened in zingg clones directory - **./scripts/zingg.sh --phase findTrainingData --conf examples/febrl/config.json**
+* Run this script in a terminal opened in the cloned zingg directory -
+```
+./scripts/zingg.sh --phase findTrainingData --conf examples/febrl/config.json
+```

 ****

-**If everything is right, it should show Zingg Icon.**
+**If everything is right, it should show the Zingg banner.**
diff --git a/docs/setup/match.md b/docs/setup/match.md
index 050d95fa4..0e1fd35d7 100644
--- a/docs/setup/match.md
+++ b/docs/setup/match.md
@@ -14,4 +14,4 @@ As can be seen in the image below, matching records are given the same z_cluster

 ![Match results](/assets/match.gif)

-If records across multiple sources have to be matched, the [link phase](./link.md) should be used.
\ No newline at end of file
+If records across multiple sources have to be matched, the [link phase](./link.md) should be used.
diff --git a/docs/stepbystep/configuration/field-definitions.md b/docs/stepbystep/configuration/field-definitions.md
index d2bd5c291..6c0376983 100644
--- a/docs/stepbystep/configuration/field-definitions.md
+++ b/docs/stepbystep/configuration/field-definitions.md
@@ -32,12 +32,12 @@ Type of the column - string, integer, double, etc.

 | Match Type | Description | Can be applied to |
 | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------- |
-| FUZZY | Broad matches with typos, abbreviations, and other variations. | string, integer, double, date |
-| EXACT | No tolerance with variations, Preferable for country codes, pin codes, and other categorical variables where you expect no variations. | string |
+| FUZZY | Broad matches with typos, abbreviations, and other variations. | string, integer, long, double, date |
+| EXACT | No tolerance with variations, Preferable for country codes, pin codes, and other categorical variables where you expect no variations. | string, integer, long, date |
 | DONT\_USE | Appears in the output but no computation is done on these. Helpful for fields like ids that are required in the output. DONT\_USE fields are not shown to the user while labeling, if [showConcise](field-definitions.md#showconcise) is set to true. | any |
 | EMAIL | Matches only the id part of the email before the @ character | any |
 | PINCODE | Matches pin codes like xxxxx-xxxx with xxxxx | string |
-| NULL\_OR\_BLANK | By default Zingg treats nulls as matches, but if we add this to a field which has other match type like FUZZY, Zingg will build a feature for null values and learn | string |
+| NULL\_OR\_BLANK | By default Zingg treats nulls as matches, but if we add this to a field which has other match type like FUZZY, Zingg will build a feature for null values and learn | string, integer, long, date |
 | TEXT | Compares words overlap between two strings. Good for descriptive fields without much typos | string |
 | NUMERIC | extracts numbers from strings and compares how many of them are same across both strings, for example apartment numbers. | string |
 | NUMERIC\_WITH\_UNITS | extracts product codes or numbers with units, for example 16gb from strings and compares how many are same across both strings | string |
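To make the widened type support concrete, here is a hedged sketch of fieldDefinition entries in the config.json style; the fname entry mirrors examples/febrl/config.json elsewhere in this change, while the ssn field is hypothetical:

```python
# Illustrative fieldDefinition entries expressed as Python dicts
# (the "ssn" field and its types are an assumption for illustration):
field_definition = [
    {
        "fieldName": "fname",
        "fields": "fname",
        "dataType": "string",
        "matchType": "fuzzy",  # broad matching across typos and abbreviations
    },
    {
        "fieldName": "ssn",    # hypothetical long-typed field
        "fields": "ssn",
        "dataType": "long",
        # EXACT and NULL_OR_BLANK now also apply to integer/long/date fields
        "matchType": "exact,null_or_blank",
    },
]
```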
diff --git a/docs/stepbystep/installation/compiling-from-source.md b/docs/stepbystep/installation/compiling-from-source.md
index 1de4f32fa..d59f15e3b 100644
--- a/docs/stepbystep/installation/compiling-from-source.md
+++ b/docs/stepbystep/installation/compiling-from-source.md
@@ -7,7 +7,7 @@ description: For a different Spark version or compiling latest code
 If you need to compile the latest code or build for a different Spark version, you can clone this repo and

 * Install maven
-* Install JDK 1.8
+* Install JDK 11
 * Set JAVA\_HOME to JDK base directory
 * Run the following: `mvn initialize` and then `mvn clean compile package`
diff --git a/docs/stepbystep/installation/docker/README.md b/docs/stepbystep/installation/docker/README.md
index 13a9f1b30..afce0e9ac 100644
--- a/docs/stepbystep/installation/docker/README.md
+++ b/docs/stepbystep/installation/docker/README.md
@@ -9,8 +9,12 @@ description: From pre-built Docker image with all dependencies included
 The easiest way to get started is to pull the Docker image with the last release of Zingg.

 ```
-docker pull zingg/zingg:0.4.0
-docker run -it zingg/zingg:0.4.0 bash
+docker pull zingg/zingg:0.4.1-SNAPSHOT
+docker run -it zingg/zingg:0.4.1-SNAPSHOT bash
+```
+In case of a permission denied error, try mapping the container's /tmp to the user machine's /tmp
+```
+docker run -v /tmp:/tmp -it zingg/zingg:0.4.1-SNAPSHOT bash
 ```

 To know more about Docker, please refer to the official [docker documentation](https://docs.docker.com/).
diff --git a/docs/stepbystep/installation/docker/file-read-write-permissions.md b/docs/stepbystep/installation/docker/file-read-write-permissions.md
index 0c5ec43f0..a422f3b6a 100644
--- a/docs/stepbystep/installation/docker/file-read-write-permissions.md
+++ b/docs/stepbystep/installation/docker/file-read-write-permissions.md
@@ -9,5 +9,5 @@ A docker image is preferred to run with a non-root user.
By default, the Zingg c ``` $ id uid=1000(abc) gid=1000(abc) groups=1000(abc) -$ docker run -u -it zingg/zingg:0.4.0 bash +$ docker run -u -it zingg/zingg:0.4.1-SNAPSHOT bash ``` diff --git a/docs/stepbystep/installation/docker/sharing-custom-data-and-config-files.md b/docs/stepbystep/installation/docker/sharing-custom-data-and-config-files.md index c85a615a7..c81fd3c95 100644 --- a/docs/stepbystep/installation/docker/sharing-custom-data-and-config-files.md +++ b/docs/stepbystep/installation/docker/sharing-custom-data-and-config-files.md @@ -7,7 +7,7 @@ description: Using custom data to save data files on host machine However, note that once the docker container is stopped, all the work done in that session is lost. If we want to use custom data or persist the generated model or data files, we have to use **Volumes** or **Bind mount** to share files between the two. ``` -docker run -v : -it zingg/zingg:0.4.0 bash +docker run -v : -it zingg/zingg:0.4.1-SNAPSHOT bash ``` The **\** directory from host will get mounted inside container at **\**. Any file written inside this directory will persist on the host machine and can be reused in a new container instance later. diff --git a/docs/stepbystep/installation/installing-from-release/README.md b/docs/stepbystep/installation/installing-from-release/README.md index 66180845a..41df9e958 100644 --- a/docs/stepbystep/installation/installing-from-release/README.md +++ b/docs/stepbystep/installation/installing-from-release/README.md @@ -8,7 +8,7 @@ Zingg is prebuilt for common Spark versions so that you can use those directly. ## Prerequisites -A) Java JDK - version "1.8.0\_131" +A) Java JDK - version "11.0.23" B) Apache Spark - version spark-3.5.0-bin-hadoop3 diff --git a/docs/stepbystep/installation/installing-from-release/installing-zingg.md b/docs/stepbystep/installation/installing-from-release/installing-zingg.md index 79bf1fc70..3308f6a47 100644 --- a/docs/stepbystep/installation/installing-from-release/installing-zingg.md +++ b/docs/stepbystep/installation/installing-from-release/installing-zingg.md @@ -6,13 +6,13 @@ description: Downloading and setting things up Download the tar zingg-version.tar.gz from the [Zingg releases page](https://github.com/zinggAI/zingg/releases) to a folder of your choice and run the following: -> gzip -d zingg-0.4.0.tar.gz ; tar xvf zingg-0.4.0.tar +> gzip -d zingg-0.4.1-SNAPSHOT.tar.gz ; tar xvf zingg-0.4.1-SNAPSHOT.tar -This will create a folder zingg-0.4.0 under the chosen folder. +This will create a folder zingg-0.4.1-SNAPSHOT under the chosen folder. Move the above folder to zingg. 
-> mv zingg-0.4.0 \~/zingg +> mv zingg-0.4.1-SNAPSHOT \~/zingg > export ZINGG\_HOME=path to zingg diff --git a/examples/febrl120k/config120k.json b/examples/febrl120k/config120k.json index 235738655..e16f979d7 100644 --- a/examples/febrl120k/config120k.json +++ b/examples/febrl120k/config120k.json @@ -2,7 +2,7 @@ "fieldDefinition":[ { "fieldName" : "fname", - "matchType" : "email", + "matchType" : "fuzzy", "fields" : "fname", "dataType": "string" }, diff --git a/log4j2.properties b/log4j2.properties index 6a7dbc16a..f007411fd 100644 --- a/log4j2.properties +++ b/log4j2.properties @@ -49,3 +49,7 @@ logger.zingg.name = zingg logger.zingg.level = info logger.zingg_analytics.name = zingg.common.core.util.Analytics logger.zingg_analytics.level = off +logger.codegen.name = org.apache.spark.sql.catalyst.expressions +logger.codegen.level = OFF +logger.codehaus.name = org.codehaus +logger.codehaus.level = OFF diff --git a/models/100/model/block/zingg.block/.part-00000-dd2a5cbc-5ce2-4e52-bbde-6d1f12a78595-c000.snappy.parquet.crc b/models/100/model/block/zingg.block/.part-00000-dd2a5cbc-5ce2-4e52-bbde-6d1f12a78595-c000.snappy.parquet.crc new file mode 100644 index 000000000..b29cacf2b Binary files /dev/null and b/models/100/model/block/zingg.block/.part-00000-dd2a5cbc-5ce2-4e52-bbde-6d1f12a78595-c000.snappy.parquet.crc differ diff --git a/models/100/model/block/zingg.block/part-00000-dd2a5cbc-5ce2-4e52-bbde-6d1f12a78595-c000.snappy.parquet b/models/100/model/block/zingg.block/part-00000-dd2a5cbc-5ce2-4e52-bbde-6d1f12a78595-c000.snappy.parquet new file mode 100644 index 000000000..29c166347 Binary files /dev/null and b/models/100/model/block/zingg.block/part-00000-dd2a5cbc-5ce2-4e52-bbde-6d1f12a78595-c000.snappy.parquet differ diff --git a/models/100/model/classifier/best.model/bestModel/metadata/.part-00000.crc b/models/100/model/classifier/best.model/bestModel/metadata/.part-00000.crc index 279e53973..4b35f364c 100644 Binary files a/models/100/model/classifier/best.model/bestModel/metadata/.part-00000.crc and b/models/100/model/classifier/best.model/bestModel/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/metadata/part-00000 b/models/100/model/classifier/best.model/bestModel/metadata/part-00000 index 555df45c8..2925f2a56 100644 --- a/models/100/model/classifier/best.model/bestModel/metadata/part-00000 +++ b/models/100/model/classifier/best.model/bestModel/metadata/part-00000 @@ -1 +1,9 @@ +<<<<<<< HEAD +<<<<<<< HEAD +{"class":"org.apache.spark.ml.PipelineModel","timestamp":1700828463622,"sparkVersion":"3.3.2","uid":"pipeline_222983df78ea","paramMap":{"stageUids":["vecAssembler_af13a6e17960","poly_6c75dcbeb3d8","logreg_ebb48ef03274"]},"defaultParamMap":{}} +======= +{"class":"org.apache.spark.ml.PipelineModel","timestamp":1701962009518,"sparkVersion":"3.5.0","uid":"pipeline_7a94093bd54d","paramMap":{"stageUids":["vecAssembler_2d811a17d67b","poly_8172f8362a50","logreg_48335424eea0"]},"defaultParamMap":{}} +>>>>>>> 0.4.1 +======= {"class":"org.apache.spark.ml.PipelineModel","timestamp":1704103113892,"sparkVersion":"3.5.0","uid":"pipeline_0e5a2963bd3a","paramMap":{"stageUids":["vecAssembler_a54372e1f087","poly_ae0630949c53","logreg_3a0c5548d511"]},"defaultParamMap":{}} +>>>>>>> 0.4.1 diff --git a/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/._SUCCESS.crc b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 
Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/.part-00000.crc b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/.part-00000.crc new file mode 100644 index 000000000..597920eeb Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/_SUCCESS b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/part-00000 b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/part-00000 new file mode 100644 index 000000000..f02ea04e1 --- /dev/null +++ b/models/100/model/classifier/best.model/bestModel/stages/0_vecAssembler_af13a6e17960/metadata/part-00000 @@ -0,0 +1 @@ +{"class":"org.apache.spark.ml.feature.VectorAssembler","timestamp":1700828463719,"sparkVersion":"3.3.2","uid":"vecAssembler_af13a6e17960","paramMap":{"inputCols":["z_sim0","z_sim1","z_sim2","z_sim3","z_sim4","z_sim5","z_sim6","z_sim7","z_sim8","z_sim9","z_sim10","z_sim11","z_sim12","z_sim13","z_sim14","z_sim15","z_sim16","z_sim17","z_sim18","z_sim19"],"outputCol":"z_featurevector"},"defaultParamMap":{"handleInvalid":"error","outputCol":"vecAssembler_af13a6e17960__output"}} diff --git a/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/._SUCCESS.crc b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/.part-00000.crc b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/.part-00000.crc new file mode 100644 index 000000000..26c11214c Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/_SUCCESS b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/part-00000 b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/part-00000 new file mode 100644 index 000000000..72fca43f0 --- /dev/null +++ b/models/100/model/classifier/best.model/bestModel/stages/1_poly_6c75dcbeb3d8/metadata/part-00000 @@ -0,0 +1 @@ +{"class":"org.apache.spark.ml.feature.PolynomialExpansion","timestamp":1700828463830,"sparkVersion":"3.3.2","uid":"poly_6c75dcbeb3d8","paramMap":{"outputCol":"z_feature","inputCol":"z_featurevector","degree":3},"defaultParamMap":{"outputCol":"poly_6c75dcbeb3d8__output","degree":2}} diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/._SUCCESS.crc 
b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/.part-00000-be6a3806-a5f2-413d-83a5-a023c61fe823-c000.snappy.parquet.crc b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/.part-00000-be6a3806-a5f2-413d-83a5-a023c61fe823-c000.snappy.parquet.crc new file mode 100644 index 000000000..2abc57f6e Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/.part-00000-be6a3806-a5f2-413d-83a5-a023c61fe823-c000.snappy.parquet.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/_SUCCESS b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/part-00000-be6a3806-a5f2-413d-83a5-a023c61fe823-c000.snappy.parquet b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/part-00000-be6a3806-a5f2-413d-83a5-a023c61fe823-c000.snappy.parquet new file mode 100644 index 000000000..e6f265613 Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/data/part-00000-be6a3806-a5f2-413d-83a5-a023c61fe823-c000.snappy.parquet differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/._SUCCESS.crc b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/.part-00000.crc b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/.part-00000.crc new file mode 100644 index 000000000..d0912c8c1 Binary files /dev/null and b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/_SUCCESS b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/part-00000 b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/part-00000 new file mode 100644 index 000000000..6b674e1e4 --- /dev/null +++ b/models/100/model/classifier/best.model/bestModel/stages/2_logreg_ebb48ef03274/metadata/part-00000 @@ -0,0 +1 @@ 
+{"class":"org.apache.spark.ml.classification.LogisticRegressionModel","timestamp":1700828463929,"sparkVersion":"3.3.2","uid":"logreg_ebb48ef03274","paramMap":{"regParam":1.0E-4,"maxIter":100,"threshold":0.4,"featuresCol":"z_feature","predictionCol":"z_prediction","labelCol":"z_isMatch","fitIntercept":true,"probabilityCol":"z_probability"},"defaultParamMap":{"regParam":0.0,"maxIter":100,"threshold":0.5,"family":"auto","featuresCol":"features","predictionCol":"prediction","standardization":true,"labelCol":"label","fitIntercept":true,"probabilityCol":"probability","rawPredictionCol":"rawPrediction","maxBlockSizeInMB":0.0,"elasticNetParam":0.0,"aggregationDepth":2,"tol":1.0E-6}} diff --git a/models/100/model/classifier/best.model/estimator/metadata/.part-00000.crc b/models/100/model/classifier/best.model/estimator/metadata/.part-00000.crc index 6d197441c..9c43b4e28 100644 Binary files a/models/100/model/classifier/best.model/estimator/metadata/.part-00000.crc and b/models/100/model/classifier/best.model/estimator/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/estimator/metadata/part-00000 b/models/100/model/classifier/best.model/estimator/metadata/part-00000 index 67adc8d2f..e0bbdfeac 100644 --- a/models/100/model/classifier/best.model/estimator/metadata/part-00000 +++ b/models/100/model/classifier/best.model/estimator/metadata/part-00000 @@ -1 +1,9 @@ +<<<<<<< HEAD +<<<<<<< HEAD +{"class":"org.apache.spark.ml.Pipeline","timestamp":1700828463006,"sparkVersion":"3.3.2","uid":"pipeline_222983df78ea","paramMap":{"stageUids":["vecAssembler_af13a6e17960","poly_6c75dcbeb3d8","logreg_ebb48ef03274"]},"defaultParamMap":{}} +======= +{"class":"org.apache.spark.ml.Pipeline","timestamp":1701962008875,"sparkVersion":"3.5.0","uid":"pipeline_7a94093bd54d","paramMap":{"stageUids":["vecAssembler_2d811a17d67b","poly_8172f8362a50","logreg_48335424eea0"]},"defaultParamMap":{}} +>>>>>>> 0.4.1 +======= {"class":"org.apache.spark.ml.Pipeline","timestamp":1704103113279,"sparkVersion":"3.5.0","uid":"pipeline_0e5a2963bd3a","paramMap":{"stageUids":["vecAssembler_a54372e1f087","poly_ae0630949c53","logreg_3a0c5548d511"]},"defaultParamMap":{}} +>>>>>>> 0.4.1 diff --git a/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/._SUCCESS.crc b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 Binary files /dev/null and b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/.part-00000.crc b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/.part-00000.crc new file mode 100644 index 000000000..5df4e78e9 Binary files /dev/null and b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/_SUCCESS b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/part-00000 b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/part-00000 new file mode 
100644 index 000000000..13c941ae9 --- /dev/null +++ b/models/100/model/classifier/best.model/estimator/stages/0_vecAssembler_af13a6e17960/metadata/part-00000 @@ -0,0 +1 @@ +{"class":"org.apache.spark.ml.feature.VectorAssembler","timestamp":1700828463176,"sparkVersion":"3.3.2","uid":"vecAssembler_af13a6e17960","paramMap":{"inputCols":["z_sim0","z_sim1","z_sim2","z_sim3","z_sim4","z_sim5","z_sim6","z_sim7","z_sim8","z_sim9","z_sim10","z_sim11","z_sim12","z_sim13","z_sim14","z_sim15","z_sim16","z_sim17","z_sim18","z_sim19"],"outputCol":"z_featurevector"},"defaultParamMap":{"handleInvalid":"error","outputCol":"vecAssembler_af13a6e17960__output"}} diff --git a/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/._SUCCESS.crc b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 Binary files /dev/null and b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/.part-00000.crc b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/.part-00000.crc new file mode 100644 index 000000000..e1a72b969 Binary files /dev/null and b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/_SUCCESS b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/part-00000 b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/part-00000 new file mode 100644 index 000000000..bfd94d63b --- /dev/null +++ b/models/100/model/classifier/best.model/estimator/stages/1_poly_6c75dcbeb3d8/metadata/part-00000 @@ -0,0 +1 @@ +{"class":"org.apache.spark.ml.feature.PolynomialExpansion","timestamp":1700828463350,"sparkVersion":"3.3.2","uid":"poly_6c75dcbeb3d8","paramMap":{"outputCol":"z_feature","degree":3,"inputCol":"z_featurevector"},"defaultParamMap":{"outputCol":"poly_6c75dcbeb3d8__output","degree":2}} diff --git a/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/._SUCCESS.crc b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/._SUCCESS.crc new file mode 100644 index 000000000..3b7b04493 Binary files /dev/null and b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/._SUCCESS.crc differ diff --git a/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/.part-00000.crc b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/.part-00000.crc new file mode 100644 index 000000000..ceb655f13 Binary files /dev/null and b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/.part-00000.crc differ diff --git a/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/_SUCCESS b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/_SUCCESS new file mode 100644 index 000000000..e69de29bb diff --git a/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/part-00000 
b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/part-00000
new file mode 100644
index 000000000..817a088a4
--- /dev/null
+++ b/models/100/model/classifier/best.model/estimator/stages/2_logreg_ebb48ef03274/metadata/part-00000
@@ -0,0 +1 @@
+{"class":"org.apache.spark.ml.classification.LogisticRegression","timestamp":1700828463513,"sparkVersion":"3.3.2","uid":"logreg_ebb48ef03274","paramMap":{"maxIter":100,"featuresCol":"z_feature","predictionCol":"z_prediction","labelCol":"z_isMatch","fitIntercept":true,"probabilityCol":"z_probability"},"defaultParamMap":{"regParam":0.0,"maxIter":100,"threshold":0.5,"family":"auto","featuresCol":"features","predictionCol":"prediction","standardization":true,"labelCol":"label","fitIntercept":true,"probabilityCol":"probability","rawPredictionCol":"rawPrediction","maxBlockSizeInMB":0.0,"elasticNetParam":0.0,"aggregationDepth":2,"tol":1.0E-6}}
diff --git a/pom.xml b/pom.xml
index d4370e1e2..f1d519e9c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -55,7 +55,7 @@
- 3.5.0
+ 3.5.2
 2.12.10
 3.5
 2.12
@@ -64,11 +64,11 @@
- 0.4.0
+ 0.4.1-SNAPSHOT
 false
 false
- 8
- 8
+ 11
+ 11
 UTF-8
 2.10
 2.5.2
@@ -89,7 +89,11 @@
 SparkPackagesRepo
 https://repos.spark-packages.org/
-
+
+
+ Apache Snapshots
+ https://repository.apache.org/snapshots/
+
@@ -152,6 +156,18 @@
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+ ${maven-jar-plugin.version}
+
+
+
+ test-jar
+
+
+
+
 maven-compiler-plugin
 ${maven-compiler-plugin.version}
@@ -247,7 +263,6 @@
-
-
+
diff --git a/protobuf/connect_plugins.proto b/protobuf/connect_plugins.proto
new file mode 100644
index 000000000..94382746b
--- /dev/null
+++ b/protobuf/connect_plugins.proto
@@ -0,0 +1,93 @@
+syntax = 'proto3';
+
+option java_multiple_files = true;
+option java_package = "zingg.spark.connect.proto";
+
+message SubmitZinggJob {
+    Arguments arguments = 1;
+    ClientOptions cli_options = 2;
+    // The next message is a serialized LogicalPlan
+    optional bytes in_memory_data = 3;
+}
+
+enum MatchType {
+    MT_FUZZY = 0;
+    MT_EXACT = 1;
+    MT_DONT_USE = 2;
+    MT_EMAIL = 3;
+    MT_PINCODE = 4;
+    MT_NULL_OR_BLANK = 5;
+    MT_TEXT = 6;
+    MT_NUMERIC = 7;
+    MT_NUMERIC_WITH_UNITS = 8;
+    MT_ONLY_ALPHABETS_EXACT = 9;
+    MT_ONLY_ALPHABETS_FUZZY = 10;
+}
+
+enum DataFormat {
+    DF_CSV = 0;
+    DF_PARQUET = 1;
+    DF_JSON = 2;
+    DF_TEXT = 3;
+    DF_XLS = 4;
+    DF_AVRO = 5;
+    DF_JDBC = 6;
+    DF_CASSANDRA = 7;
+    DF_SNOWFLAKE = 8;
+    DF_ELASTIC = 9;
+    DF_EXASOL = 10;
+    DF_BIGQUERY = 11;
+    DF_INMEMORY = 12;
+}
+
+message FieldDefinition {
+    repeated MatchType match_type = 1;
+    string data_type = 2;
+    string field_name = 3;
+    string fields = 4;
+    optional string stop_words = 5;
+    optional string abbreviations = 6;
+}
+
+message Pipe {
+    string name = 1;
+    DataFormat format = 2;
+    map<string, string> props = 3;
+    optional string schema_field = 4;
+    optional string mode = 5;
+}
+
+message Arguments {
+    repeated Pipe output = 1;
+    repeated Pipe data = 2;
+    string zingg_dir = 3;
+    repeated Pipe training_samples = 4;
+    repeated FieldDefinition field_definition = 5;
+    int32 num_partitions = 6;
+    float label_data_sample_size = 7;
+    string model_id = 8;
+    float threshold = 9;
+    int32 job_id = 10;
+    bool collect_metrics = 11;
+    bool show_concise = 12;
+    float stop_words_cutoff = 13;
+    int64 block_size = 14;
+    optional string column = 15;
+}
+
+message ClientOptions {
+    optional string phase = 1;
+    optional string license = 2;
+    optional string email = 3;
+    optional string conf = 4;
+    optional string preprocess = 5;
+    optional string job_id = 6;
+    optional string format = 7;
+    optional string zingg_dir = 8;
+    optional string model_id = 9;
+    optional string collect_metrics = 10;
+    optional string show_concise = 11;
+    optional string location = 12;
+    optional string column = 13;
+    optional string remote = 14;
+}
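As a quick illustration of how these messages compose, here is a hedged Python sketch; the generated module path (zingg_v2.proto.connect_plugins_pb2) is an assumption about where the protobuf toolchain places the classes, while the field names follow the .proto above:

```python
# Hedged sketch: building the Spark Connect plugin messages from Python.
# Paths/ids are illustrative; the import location is an assumption.
from zingg_v2.proto.connect_plugins_pb2 import (
    MT_FUZZY,
    Arguments,
    FieldDefinition,
    SubmitZinggJob,
)

fname = FieldDefinition(
    match_type=[MT_FUZZY],  # repeated enum: a field may carry several match types
    data_type="string",
    field_name="fname",
    fields="fname",
)
args = Arguments(
    field_definition=[fname],
    zingg_dir="/tmp/zingg",
    model_id="100",
    num_partitions=4,
    label_data_sample_size=0.5,
)
job = SubmitZinggJob(arguments=args)
print(job)
```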
diff --git a/python/PKG-INFO b/python/PKG-INFO
index fd12b2c4a..dfff445f2 100644
--- a/python/PKG-INFO
+++ b/python/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: zingg
-Version: 0.4.0
+Version: 0.4.1
 Summary: Zingg.ai Entity Resolution
 Home-page: www.zingg.ai
 Author: Zingg.AI
diff --git a/python/docs/Makefile b/python/docs/Makefile
index d4bb2cbb9..9847005dd 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -17,4 +17,6 @@ help:
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
+	export ZINGG_DRY_RUN=1
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+	unset ZINGG_DRY_RUN
diff --git a/python/docs/conf.py b/python/docs/conf.py
index 6be49c6eb..0b6880647 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -23,7 +23,7 @@ author = 'Zingg.AI'

 # The full version, including alpha/beta/rc tags
-release = '0.4.0'
+release = '0.4.1-SNAPSHOT'

 # -- General configuration ---------------------------------------------------
diff --git a/python/pyproject.toml b/python/pyproject.toml
new file mode 100644
index 000000000..6ce24f6ce
--- /dev/null
+++ b/python/pyproject.toml
@@ -0,0 +1,21 @@
+[build-system]
+requires = ["setuptools >= 61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "zingg"
+dynamic = ["version"]
+dependencies = [
+    "pandas",
+    "seaborn",
+    "matplotlib",
+    "sphinx",
+    "sphinx-rtd-theme",
+    "pyspark>=3.5",
+    "pydantic",
+]
+readme = "README.md"
+requires-python = ">=3.11"
+
+[tool.ruff]
+line-length = 110
diff --git a/python/requirements.txt b/python/requirements.txt
index 281558fd4..5cad85992 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -2,7 +2,7 @@ pandas
 seaborn
 matplotlib
 sphinx
-sphinx_rtd_theme
-pyspark
+sphinx-rtd-theme
+pyspark[connect]>=3.5.2
+pydantic
 numpy
-
diff --git a/python/test_spark_connect.py b/python/test_spark_connect.py
new file mode 100644
index 000000000..1fdf00606
--- /dev/null
+++ b/python/test_spark_connect.py
@@ -0,0 +1,11 @@
+from zingg_v2.client import Zingg, Arguments, ClientOptions
+from pyspark.sql.connect.session import SparkSession
+
+
+if __name__ == "__main__":
+    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
+    print(hasattr(spark, "_jvm"))
+    opts = ClientOptions(None)
+    args = Arguments.createArgumentsFromJSON(fileName="../examples/febrl/config.json", phase="peekModel")
+    zingg = Zingg(args=args, options=opts)
+    zingg.execute()
diff --git a/python/version.py b/python/version.py
index 405432f9a..774acba9d 100644
--- a/python/version.py
+++ b/python/version.py
@@ -1,4 +1,4 @@
 #!/usr/bin/env python3

-__version__: str = "0.4.0"
\ No newline at end of file
+__version__: str = "0.4.1"
\ No newline at end of file
diff --git a/python/zingg/client.py b/python/zingg/client.py
index 3c748a978..7061002fd 100644
--- a/python/zingg/client.py
+++ b/python/zingg/client.py
@@ -4,108 +4,182 @@
 This module is the main entry point of the Zingg Python API
 """
-import logging
 import argparse
-import pandas as pd
-from pyspark.sql import DataFrame
+from __future__ import annotations
 import argparse
+import logging
+import os
+from typing import Any
-from pyspark import SparkConf, SparkContext, SQLContext
+import pandas as pd
+from pyspark import SparkContext, SQLContext
+from pyspark.sql import DataFrame, SparkSession
-from py4j.java_collections import SetConverter, MapConverter, ListConverter
-from pyspark.sql import SparkSession
-import os

 LOG = logging.getLogger("zingg")

 _spark_ctxt = None
 _sqlContext = None
 _spark = None
-_zingg_jar = 'zingg-0.4.0.jar'
+_zingg_jar = 'zingg-0.4.1-SNAPSHOT.jar'

 def initSparkClient():
     global _spark_ctxt
     global _sqlContext
-    global _spark
+    global _spark
     _spark_ctxt = SparkContext.getOrCreate()
     _sqlContext = SQLContext(_spark_ctxt)
     _spark = SparkSession.builder.getOrCreate()
     return 1

+
 def initDataBricksConectClient():
     global _spark_ctxt
     global _sqlContext
-    global _spark
-    jar_path = os.getenv('ZINGG_HOME')+'/'+_zingg_jar
-    _spark = SparkSession.builder.config('spark.jars', jar_path).getOrCreate()
+    global _spark
+    jar_path = os.getenv("ZINGG_HOME") + "/" + _zingg_jar
+    _spark = SparkSession.builder.config("spark.jars", jar_path).getOrCreate()
     _spark_ctxt = _spark.sparkContext
     _sqlContext = SQLContext(_spark_ctxt)
     return 1

+
 def initClient():
     global _spark_ctxt
     global _sqlContext
-    global _spark
+    global _spark
     if _spark_ctxt is None:
-        DATABRICKS_CONNECT = os.getenv('DATABRICKS_CONNECT')
-        if DATABRICKS_CONNECT=='Y' or DATABRICKS_CONNECT=='y':
+        DATABRICKS_CONNECT = os.getenv("DATABRICKS_CONNECT")
+        if DATABRICKS_CONNECT == "Y" or DATABRICKS_CONNECT == "y":
             return initDataBricksConectClient()
         else:
             return initSparkClient()
     else:
         return 1

+
 def getSparkContext():
     if _spark_ctxt is None:
         initClient()
     return _spark_ctxt

+
 def getSparkSession():
     if _spark is None:
         initClient()
     return _spark

+
 def getSqlContext():
     if _sqlContext is None:
         initClient()
     return _sqlContext

+
 def getJVM():
+    # TODO: Document this environ variable
+    is_dry_run = os.environ.get("ZINGG_DRY_RUN", 0)
+    if is_dry_run:
+
+        class Dummy:
+            """Dummy class for handling JVM magic without actually starting Java"""
+
+            def __init__(self, attrs: dict[str, Any]):
+                for k, v in attrs.items():
+                    self.__setattr__(k, v)
+
+        # TODO: replace this magic by Context-like implementation
+        return Dummy(
+            {
+                "org": Dummy({"apache": Dummy({"spark": Dummy({"sql": Dummy({"types": Dummy({"StructType": None})})})})}),
+                "zingg": Dummy(
+                    {
+                        "common": Dummy(
+                            {
+                                "client": Dummy(
+                                    {
+                                        "util": Dummy({"ColName": None}),
+                                        "MatchType": None,
+                                        "ClientOptions": Dummy(
+                                            {
+                                                "PHASE": None,
+                                                "CONF": None,
+                                                "LICENSE": None,
+                                                "EMAIL": None,
+                                                "LOCATION": None,
+                                                "REMOTE": None,
+                                                "ZINGG_DIR": None,
+                                                "MODEL_ID": None,
+                                                "COLUMN": None,
+                                            }
+                                        ),
+                                        "ZinggOptions": None,
+                                        "pipe": Dummy(
+                                            {
+                                                "FilePipe": None,
+                                            }
+                                        ),
+                                    }
+                                ),
+                                "core": Dummy({"util": Dummy({"LabelMatchType": None})}),
+                            }
+                        ),
+                        "spark": Dummy(
+                            {
+                                "client": Dummy(
+                                    {
+                                        "pipe": Dummy(
+                                            {
+                                                "SparkPipe": None,
+                                            }
+                                        )
+                                    }
+                                )
+                            }
+                        ),
+                    }
+                ),
+            }
+        )
     return getSparkContext()._jvm

+
 def getGateway():
     return getSparkContext()._gateway

+
 ColName = getJVM().zingg.common.client.util.ColName
 MatchType = getJVM().zingg.common.client.MatchType
 ClientOptions = getJVM().zingg.common.client.ClientOptions
 ZinggOptions = getJVM().zingg.common.client.ZinggOptions
 LabelMatchType = getJVM().zingg.common.core.util.LabelMatchType
-UpdateLabelMode = 'Overwrite'
+UpdateLabelMode = "Overwrite"

+
 def getDfFromDs(data):
-    """ Method to convert spark dataset to dataframe
+    """Method to convert spark dataset to dataframe

     :param data: provide spark dataset
     :type data: DataSet
     :return: converted spark dataframe
-    :rtype: DataFrame
+    :rtype: DataFrame
     """
     return DataFrame(data.df(), getSqlContext())
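The ZINGG_DRY_RUN switch introduced in getJVM above (and exported by the docs Makefile earlier in this change) is what lets the module be imported without a live JVM, e.g. during a Sphinx build. A short sketch of how it behaves:

```python
# Hedged sketch: with ZINGG_DRY_RUN set before import, getJVM() returns the
# Dummy attribute tree instead of touching a real SparkContext/JVM.
import os

os.environ["ZINGG_DRY_RUN"] = "1"  # must be set before the import below
from zingg.client import getJVM

jvm = getJVM()
print(jvm.zingg.common.client.MatchType)  # None: a Dummy leaf, not a JVM handle
```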
+
 def getPandasDfFromDs(data):
-    """ Method to convert spark dataset to pandas dataframe
+    """Method to convert spark dataset to pandas dataframe

     :param data: provide spark dataset
     :type data: DataSet
     :return: converted pandas dataframe
-    :rtype: DataFrame
+    :rtype: DataFrame
     """
     df = getDfFromDs(data)
     return pd.DataFrame(df.collect(), columns=df.columns)


 class Zingg:
-    """ This class is the main point of interface with the Zingg matching product. Construct a client to Zingg using provided arguments and spark master. If running locally, set the master to local.
+    """This class is the main point of interface with the Zingg matching product. Construct a client to Zingg using provided arguments and spark master. If running locally, set the master to local.

     :param args: arguments for training and matching
     :type args: Arguments
@@ -118,25 +192,25 @@ def __init__(self, args, options):
         self.inpArgs = args
         self.inpOptions = options
         self.client = getJVM().zingg.spark.client.SparkClient(args.getArgs(), options.getClientOptions())
-
+
     def init(self):
-        """ Method to initialize zingg client by reading internal configurations and functions """
+        """Method to initialize zingg client by reading internal configurations and functions"""
         self.client.init()

     def execute(self):
-        """ Method to execute this class object """
+        """Method to execute this class object"""
         self.client.execute()
-
+
     def initAndExecute(self):
-        """ Method to run both init and execute methods consecutively """
+        """Method to run both init and execute methods consecutively"""
         self.client.init()
-        DATABRICKS_CONNECT = os.getenv('DATABRICKS_CONNECT')
-        if DATABRICKS_CONNECT=='Y' or DATABRICKS_CONNECT=='y':
+        DATABRICKS_CONNECT = os.getenv("DATABRICKS_CONNECT")
+        if DATABRICKS_CONNECT == "Y" or DATABRICKS_CONNECT == "y":
             options = self.client.getOptions()
             inpPhase = options.get(ClientOptions.PHASE).getValue()
-            if (inpPhase==ZinggOptions.LABEL.getValue()):
+            if inpPhase == ZinggOptions.LABEL.getValue():
                 self.executeLabel()
-            elif (inpPhase==ZinggOptions.UPDATE_LABEL.getValue()):
+            elif inpPhase == ZinggOptions.UPDATE_LABEL.getValue():
                 self.executeLabelUpdate()
             else:
                 self.client.execute()
@@ -144,43 +218,48 @@ def initAndExecute(self):
             self.client.execute()

     def executeLabel(self):
-        """ Method to run label phase """
+        """Method to run label phase"""
         self.client.getTrainingDataModel().setMarkedRecordsStat(self.getMarkedRecords())
         unmarkedRecords = self.getUnmarkedRecords()
-        updatedRecords = self.processRecordsCli(unmarkedRecords,self.inpArgs)
-        self.writeLabelledOutput(updatedRecords,self.inpArgs)
+        updatedRecords = self.processRecordsCli(unmarkedRecords, self.inpArgs)
+        self.writeLabelledOutput(updatedRecords, self.inpArgs)

     def executeLabelUpdate(self):
-        """ Method to run label update phase """
-        self.processRecordsCliLabelUpdate(self.getMarkedRecords(),self.inpArgs)
+        """Method to run label update phase"""
+        self.processRecordsCliLabelUpdate(self.getMarkedRecords(), self.inpArgs)

     def getMarkedRecords(self):
-        """ Method to get marked record dataset from the inputpipe
+        """Method to get marked record dataset from the input pipe

         :return: spark dataset containing marked records
-        :rtype: Dataset
+        :rtype: Dataset
         """
         return self.client.getMarkedRecords()

     def getUnmarkedRecords(self):
-        """ Method to get unmarked record dataset from the inputpipe
+        """Method to get unmarked record dataset from the input pipe

         :return: spark dataset containing unmarked records
-        :rtype: Dataset
+        :rtype: Dataset
         """
         return self.client.getUnmarkedRecords()
-    def processRecordsCli(self,unmarkedRecords,args):
-        """ Method to get user input on unmarked records
+    def processRecordsCli(self, unmarkedRecords, args):
+        """Method to get user input on unmarked records

         :return: spark dataset containing updated records
-        :rtype: Dataset
+        :rtype: Dataset
         """
         trainingDataModel = self.client.getTrainingDataModel()
         labelDataViewHelper = self.client.getLabelDataViewHelper()

         if unmarkedRecords is not None and unmarkedRecords.count() > 0:
-            labelDataViewHelper.printMarkedRecordsStat(trainingDataModel.getPositivePairsCount(),trainingDataModel.getNegativePairsCount(),trainingDataModel.getNotSurePairsCount(),trainingDataModel.getTotalCount())
+            labelDataViewHelper.printMarkedRecordsStat(
+                trainingDataModel.getPositivePairsCount(),
+                trainingDataModel.getNegativePairsCount(),
+                trainingDataModel.getNotSurePairsCount(),
+                trainingDataModel.getTotalCount(),
+            )
             unmarkedRecords = unmarkedRecords.cache()
             displayCols = labelDataViewHelper.getDisplayColumns(unmarkedRecords, args.getArgs())
             clusterIdZFrame = labelDataViewHelper.getClusterIdsFrame(unmarkedRecords)
@@ -195,37 +274,53 @@ def processRecordsCli(self,unmarkedRecords,args):
                 msg1 = labelDataViewHelper.getMsg1(index, totalPairs)
                 msg2 = labelDataViewHelper.getMsg2(prediction, score)

-                labelDataViewHelper.displayRecords(labelDataViewHelper.getDSUtil().select(currentPair, displayCols), msg1, msg2)
+                labelDataViewHelper.displayRecords(
+                    labelDataViewHelper.getDSUtil().select(currentPair, displayCols),
+                    msg1,
+                    msg2,
+                )
                 selected_option = input()
-                while int(selected_option) not in [0,1,2,9]:
-                    print('Please enter valid option')
+                while int(selected_option) not in [0, 1, 2, 9]:
+                    print("Please enter a valid option")
                     selected_option = input("Enter choice: ")
                 if int(selected_option) == 9:
                     print("User has quit in the middle. Updating the records.")
-                    break
+                    break
                 trainingDataModel.updateLabellerStat(int(selected_option), 1)
-                labelDataViewHelper.printMarkedRecordsStat(trainingDataModel.getPositivePairsCount(),trainingDataModel.getNegativePairsCount(),trainingDataModel.getNotSurePairsCount(),trainingDataModel.getTotalCount())
-                updatedRecords = trainingDataModel.updateRecords(int(selected_option), currentPair, updatedRecords)
+                labelDataViewHelper.printMarkedRecordsStat(
+                    trainingDataModel.getPositivePairsCount(),
+                    trainingDataModel.getNegativePairsCount(),
+                    trainingDataModel.getNotSurePairsCount(),
+                    trainingDataModel.getTotalCount(),
+                )
+                updatedRecords = trainingDataModel.updateRecords(int(selected_option), currentPair, updatedRecords)
             print("Processing finished.")
             return updatedRecords
         else:
-            print("It seems there are no unmarked records at this moment. Please run findTrainingData job to build some pairs to be labelled and then run this labeler.")
+            print(
+                "It seems there are no unmarked records at this moment. Please run findTrainingData job to build some pairs to be labelled and then run this labeler."
+            )
             return None
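The option codes the loop above accepts are worth spelling out; 9 is handled explicitly as quit, and 0/1/2 follow the usual Zingg labeling convention (stated here as an assumption rather than from this diff):

```python
# Reference sketch for the CLI labeler's accepted inputs:
LABEL_OPTIONS = {
    0: "not a match",
    1: "match",
    2: "not sure",
    9: "quit and update the records labelled so far",
}
```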
-
-    def processRecordsCliLabelUpdate(self,lines,args):
+    def processRecordsCliLabelUpdate(self, lines, args):
        trainingDataModel = self.client.getTrainingDataModel()
        labelDataViewHelper = self.client.getLabelDataViewHelper()
-        if (lines is not None and lines.count() > 0):
+        if lines is not None and lines.count() > 0:
            trainingDataModel.setMarkedRecordsStat(lines)
-            labelDataViewHelper.printMarkedRecordsStat(trainingDataModel.getPositivePairsCount(),trainingDataModel.getNegativePairsCount(),trainingDataModel.getNotSurePairsCount(),trainingDataModel.getTotalCount())
+            labelDataViewHelper.printMarkedRecordsStat(
+                trainingDataModel.getPositivePairsCount(),
+                trainingDataModel.getNegativePairsCount(),
+                trainingDataModel.getNotSurePairsCount(),
+                trainingDataModel.getTotalCount(),
+            )
            displayCols = labelDataViewHelper.getDSUtil().getFieldDefColumns(lines, args.getArgs(), False, args.getArgs().getShowConcise())
            updatedRecords = None
            recordsToUpdate = lines
            selectedOption = -1

-            while (str(selectedOption) != '9'):
+            while str(selectedOption) != "9":
                cluster_id = input("\n\tPlease enter the cluster id (or 9 to exit): ")
-                if str(cluster_id) == '9':
+                if str(cluster_id) == "9":
                    print("User has exited in the middle. Updating the records.")
                    break
                currentPair = lines.filter(lines.equalTo(ColName.CLUSTER_COLUMN, cluster_id))
@@ -233,23 +328,32 @@ def processRecordsCliLabelUpdate(self,lines,args):
                    print("\tInvalid cluster id. Enter '9' to exit")
                    continue

-                matchFlag = currentPair.getAsInt(currentPair.head(),ColName.MATCH_FLAG_COL)
-                preMsg = "\n\tThe record pairs belonging to the input cluster id "+cluster_id+" are:"
-                postMsg = "\tThe above pair is labeled as "+str(matchFlag)+"\n"
-                labelDataViewHelper.displayRecords(labelDataViewHelper.getDSUtil().select(currentPair, displayCols), preMsg, postMsg)
+                matchFlag = currentPair.getAsInt(currentPair.head(), ColName.MATCH_FLAG_COL)
+                preMsg = "\n\tThe record pairs belonging to the input cluster id " + cluster_id + " are:"
+                postMsg = "\tThe above pair is labeled as " + str(matchFlag) + "\n"
+                labelDataViewHelper.displayRecords(
+                    labelDataViewHelper.getDSUtil().select(currentPair, displayCols),
+                    preMsg,
+                    postMsg,
+                )
                selectedOption = input()
                trainingDataModel.updateLabellerStat(int(selectedOption), 1)
                trainingDataModel.updateLabellerStat(matchFlag, -1)
-                labelDataViewHelper.printMarkedRecordsStat(trainingDataModel.getPositivePairsCount(),trainingDataModel.getNegativePairsCount(),trainingDataModel.getNotSurePairsCount(),trainingDataModel.getTotalCount())
-
-                if (str(selectedOption) == '9'):
+                labelDataViewHelper.printMarkedRecordsStat(
+                    trainingDataModel.getPositivePairsCount(),
+                    trainingDataModel.getNegativePairsCount(),
+                    trainingDataModel.getNotSurePairsCount(),
+                    trainingDataModel.getTotalCount(),
+                )
+
+                if str(selectedOption) == "9":
                    print("User has quit in the middle. Updating the records.")
                    break

-                recordsToUpdate = recordsToUpdate.filter(recordsToUpdate.notEqual(ColName.CLUSTER_COLUMN,cluster_id))
+                recordsToUpdate = recordsToUpdate.filter(recordsToUpdate.notEqual(ColName.CLUSTER_COLUMN, cluster_id))

-                if (updatedRecords is not None):
-                    updatedRecords = updatedRecords.filter(updatedRecords.notEqual(ColName.CLUSTER_COLUMN,cluster_id))
+                if updatedRecords is not None:
+                    updatedRecords = updatedRecords.filter(updatedRecords.notEqual(ColName.CLUSTER_COLUMN, cluster_id))

                updatedRecords = trainingDataModel.updateRecords(int(selectedOption), currentPair, updatedRecords)
@@ -259,32 +363,32 @@ def processRecordsCliLabelUpdate(self,lines,args):
            outPipe = trainingDataModel.getOutputPipe(args.getArgs())
            outPipe.setMode(UpdateLabelMode)

-            trainingDataModel.writeLabelledOutput(updatedRecords,args.getArgs(),outPipe)
+            trainingDataModel.writeLabelledOutput(updatedRecords, args.getArgs(), outPipe)
            print("Processing finished.")
            return updatedRecords
        else:
            print("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data.")
            return None
-
-    def writeLabelledOutput(self,updatedRecords,args):
-        """ Method to write updated records after user input
-        """
+    def writeLabelledOutput(self, updatedRecords, args):
+        """Method to write updated records after user input"""
        trainingDataModel = self.client.getTrainingDataModel()
        if updatedRecords is not None:
-            trainingDataModel.writeLabelledOutput(updatedRecords,args.getArgs())
+            trainingDataModel.writeLabelledOutput(updatedRecords, args.getArgs())

-    def writeLabelledOutputFromPandas(self,candidate_pairs_pd,args):
-        """ Method to write updated records (as pandas df) after user input
-        """
+    def writeLabelledOutputFromPandas(self, candidate_pairs_pd, args):
+        """Method to write updated records (as pandas df) after user input"""
        markedRecordsAsDS = (getSparkSession().createDataFrame(candidate_pairs_pd))._jdf
        # pandas df gives z_isMatch as long so it needs to be cast
-        markedRecordsAsDS = markedRecordsAsDS.withColumn(ColName.MATCH_FLAG_COL,markedRecordsAsDS.col(ColName.MATCH_FLAG_COL).cast("int"))
+        markedRecordsAsDS = markedRecordsAsDS.withColumn(
+            ColName.MATCH_FLAG_COL,
+            markedRecordsAsDS.col(ColName.MATCH_FLAG_COL).cast("int"),
+        )
        updatedRecords = getJVM().zingg.spark.client.SparkFrame(markedRecordsAsDS)
-        self.writeLabelledOutput(updatedRecords,args)
+        self.writeLabelledOutput(updatedRecords, args)
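A short usage sketch for writeLabelledOutputFromPandas; it assumes `zingg` and `args` are the configured Zingg/Arguments objects and that candidate_pairs_pd came from getPandasDfFromDs, so it already carries the z_isMatch column being cast above:

```python
# Hedged sketch: persist labels edited in a pandas DataFrame back through Zingg.
candidate_pairs_pd.loc[0, "z_isMatch"] = 1  # mark the first pair as a match
zingg.writeLabelledOutputFromPandas(candidate_pairs_pd, args)
```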
     def setArguments(self, args):
-        """ Method to set Arguments
+        """Method to set Arguments

         :param args: provide arguments for this class object
         :type args: Arguments
@@ -292,7 +396,7 @@ def setArguments(self, args):
         self.client.setArguments()

     def getArguments(self):
-        """ Method to get atguments of this class object
+        """Method to get arguments of this class object

         :return: The pointer containing address of the Arguments object of this class object
         :rtype: pointer(Arguments)
@@ -300,7 +404,7 @@ def getArguments(self):
         return self.client.getArguments()

     def getOptions(self):
-        """ Method to get client options of this class object
+        """Method to get client options of this class object

         :return: The pointer containing the address of the ClientOptions object of this class object
         :rtype: pointer(ClientOptions)
@@ -308,56 +412,55 @@ def getOptions(self):
         return self.client.getOptions()

     def setOptions(self, options):
-        """ Method to set atguments of this class object
+        """Method to set client options of this class object

         :param options: provide client options for this class object
         :type options: ClientOptions
         :return: The pointer containing address of the ClientOptions object of this class object
-        :rtype: pointer(ClientOptions)
+        :rtype: pointer(ClientOptions)
         """
         return self.client.setOptions(options)

     def getMarkedRecordsStat(self, markedRecords, value):
-        """ Method to get No. of records that is marked
+        """Method to get No. of records that are marked

         :param markedRecords: spark dataset containing marked records
         :type markedRecords: Dataset
         :param value: flag value to check if markedRecord is initially matched or not
         :type value: long
         :return: The no. of marked records
-        :rtype: int
+        :rtype: int
         """
         return self.client.getMarkedRecordsStat(markedRecords, value)

     def getMatchedMarkedRecordsStat(self):
-        """ Method to get No. of records that are marked and matched
+        """Method to get No. of records that are marked and matched

         :return: The no. of matched marked records
-        :rtype: int
+        :rtype: int
         """
         return self.client.getMatchedMarkedRecordsStat(self.getMarkedRecords())

     def getUnmatchedMarkedRecordsStat(self):
-        """ Method to get No. of records that are marked and unmatched
+        """Method to get No. of records that are marked and unmatched

         :return: The no. of unmatched marked records
-        :rtype: int
+        :rtype: int
         """
         return self.client.getUnmatchedMarkedRecordsStat(self.getMarkedRecords())

     def getUnsureMarkedRecordsStat(self):
-        """ Method to get No. of records that are marked and Not Sure if its matched or not
+        """Method to get No. of records that are marked and Not Sure if it is matched or not

         :return: The no. of Not Sure marked records
-        :rtype: int
+        :rtype: int
         """
         return self.client.getUnsureMarkedRecordsStat(self.getMarkedRecords())
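A quick usage sketch for the statistics helpers above, assuming `zingg` is a constructed client whose findTrainingData/label phases have already produced marked records:

```python
# Inspect how the labelled pairs break down before training a model.
print("matched:", zingg.getMatchedMarkedRecordsStat())
print("unmatched:", zingg.getUnmatchedMarkedRecordsStat())
print("not sure:", zingg.getUnsureMarkedRecordsStat())
```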
 class Arguments:
-    """ This class helps supply match arguments to Zingg. There are 3 basic steps in any match process.
+    """This class helps supply match arguments to Zingg. There are 3 basic steps in any match process.

    :Defining: specifying information about data location, fields, and our notion of similarity.
    :training: making Zingg learn the matching rules
@@ -382,7 +485,7 @@ def __init__(self):
        self.args = getJVM().zingg.common.client.Arguments()

    def setFieldDefinition(self, fieldDef):
-        """ Method convert python objects to java FieldDefinition objects and set the field definitions associated with this client
+        """Method to convert python objects to java FieldDefinition objects and set the field definitions associated with this client

        :param fieldDef: python FieldDefinition object list
        :type fieldDef: List(FieldDefinition)
@@ -393,16 +496,16 @@ def setFieldDefinition(self, fieldDef):
        self.args.setFieldDefinition(javaFieldDef)

    def getArgs(self):
-        """ Method to get pointer address of this class
+        """Method to get pointer address of this class

        :return: The pointer containing the address of this class object
        :rtype: pointer(Arguments)
-        
+
        """
        return self.args

    def setArgs(self, argumentsObj):
-        """ Method to set this class object
+        """Method to set this class object

        :param argumentsObj: Argument object to set this object
        :type argumentsObj: pointer(Arguments)
@@ -410,7 +513,7 @@ def setArgs(self, argumentsObj):
        self.args = argumentsObj

    def setData(self, *pipes):
-        """ Method to set the file path of the file to be matched.
+        """Method to set the file path of the file to be matched.

        :param pipes: input data pipes separated by comma e.g. (pipe1,pipe2,..)
        :type pipes: Pipe[]
@@ -421,7 +524,7 @@ def setData(self, *pipes):
        self.args.setData(dataPipe)

    def setOutput(self, *pipes):
-        """ Method to set the output directory where the match result will be saved
+        """Method to set the output directory where the match result will be saved

        :param pipes: output data pipes separated by comma e.g. (pipe1,pipe2,..)
        :type pipes: Pipe[]
@@ -430,33 +533,33 @@ def setOutput(self, *pipes):
        for idx, pipe in enumerate(pipes):
            outputPipe[idx] = pipe.getPipe()
        self.args.setOutput(outputPipe)
-    
+
    def getZinggBaseModelDir(self):
        return self.args.getZinggBaseModelDir()

    def getZinggModelDir(self):
        return self.args.getZinggModelDir()
-    
+
    def getZinggBaseTrainingDataDir(self):
-        """ Method to get the location of the folder where Zingg
-        saves the training data found by findTrainingData
+        """Method to get the location of the folder where Zingg
+        saves the training data found by findTrainingData
        """
        return self.args.getZinggBaseTrainingDataDir()

    def getZinggTrainingDataUnmarkedDir(self):
-        """ Method to get the location of the folder where Zingg
-        saves the training data found by findTrainingData
+        """Method to get the location of the folder where Zingg
+        saves the training data found by findTrainingData
        """
        return self.args.getZinggTrainingDataUnmarkedDir()
-    
+
    def getZinggTrainingDataMarkedDir(self):
-        """ Method to get the location of the folder where Zingg
-        saves the marked training data labeled by the user
+        """Method to get the location of the folder where Zingg
+        saves the marked training data labeled by the user
        """
        return self.args.getZinggTrainingDataMarkedDir()
-    
+
    def setTrainingSamples(self, *pipes):
-        """ Method to set existing training samples to be matched.
+        """Method to set existing training samples to be matched.

        :param pipes: input training data pipes separated by comma e.g. (pipe1,pipe2,..)
        :type pipes: Pipe[]
@@ -467,19 +570,18 @@ def setTrainingSamples(self, *pipes):
        self.args.setTrainingSamples(dataPipe)

    def setModelId(self, id):
-        """ Method to set the output directory where the match output will be saved
+        """Method to set the id of the model to be used

-        :param id: model id value 
+        :param id: model id value
        :type id: String
        """
        self.args.setModelId(id)
-    
+
    def getModelId(self):
        return self.args.getModelId()
-
    def setZinggDir(self, f):
-        """ Method to set the location for Zingg to save its internal computations and models. Please set it to a place where the program has to write access.
+        """Method to set the location for Zingg to save its internal computations and models. Please set it to a place where the program has write access.

        :param f: Zingg directory name of the models
        :type f: String
@@ -487,7 +589,7 @@ def setZinggDir(self, f):
        self.args.setZinggDir(f)

    def setNumPartitions(self, numPartitions):
-        """ Method to set NumPartitions parameter value
+        """Method to set numPartitions parameter value,
        the number of partitions into which Zingg splits the data for processing

        :param numPartitions: number of partitions for given data pipes
@@ -496,7 +598,7 @@ def setNumPartitions(self, numPartitions):
        self.args.setNumPartitions(numPartitions)

    def setLabelDataSampleSize(self, labelDataSampleSize):
-        """ Method to set labelDataSampleSize parameter value
+        """Method to set labelDataSampleSize parameter value
        Set the fraction of the complete data set to be used for seeding the labeled data.
        Labelling is costly and we want a fast approximate way of looking at a small sample of the records and identifying expected matches and non-matches

        :param labelDataSampleSize: value between 0.0 and 1.0 denoting portion of dataset to use in generating seed samples
@@ -505,7 +607,7 @@ def setLabelDataSampleSize(self, labelDataSampleSize):
        self.args.setLabelDataSampleSize(labelDataSampleSize)

    def writeArgumentsToJSON(self, fileName):
-        """ Method to write JSON file from the object of this class
+        """Method to write a JSON file from the object of this class

        :param fileName: The CONF parameter value of ClientOption object or file address of json file
        :type fileName: String
@@ -513,16 +615,16 @@ def writeArgumentsToJSON(self, fileName):
        getJVM().zingg.common.client.ArgumentsUtil().writeArgumentsToJSON(fileName, self.args)

    def setStopWordsCutoff(self, stopWordsCutoff):
-        """ Method to set stopWordsCutoff parameter value
+        """Method to set stopWordsCutoff parameter value
        By default, Zingg extracts 10% of the high frequency unique words from a dataset. If the user wants a different selection, they should set the stopWordsCutoff property

        :param stopWordsCutoff: The stop words cutoff parameter value of ClientOption object or file address of json file
        :type stopWordsCutoff: float
        """
        self.args.setStopWordsCutoff(stopWordsCutoff)
-    
+
    def setColumn(self, column):
-        """ Method to set stopWordsCutoff parameter value
+        """Method to set the column parameter value,
        the column for which stop word recommendations are generated.
        :param column: The column name for which stop words are recommended
        :type column: String
@@ -532,8 +634,8 @@ def setColumn(self, column):

    @staticmethod
    def createArgumentsFromJSON(fileName, phase):
-        """ Method to create an object of this class from the JSON file and phase parameter value.
-        
+        """Method to create an object of this class from the JSON file and phase parameter value.
+
        :param fileName: The CONF parameter value of ClientOption object
        :type fileName: String
        :param phase: The PHASE parameter value of ClientOption object
@@ -544,11 +646,10 @@ def createArgumentsFromJSON(fileName, phase):
        obj = Arguments()
        obj.args = getJVM().zingg.common.client.ArgumentsUtil().createArgumentsFromJSON(fileName, phase)
        return obj
-    
-    
+
    def writeArgumentsToJSONString(self):
-        """ Method to create an object of this class from the JSON file and phase parameter value.
-        
+        """Method to write the arguments of this class object to a JSON string.
+
        :return: The JSON string of the arguments
        :rtype: String
@@ -557,30 +658,27 @@ def writeArgumentsToJSONString(self):
        """
        return getJVM().zingg.common.client.ArgumentsUtil().writeArgumentstoJSONString(self.args)
-    
+
    @staticmethod
    def createArgumentsFromJSONString(jsonArgs, phase):
        obj = Arguments()
        obj.args = getJVM().zingg.common.client.ArgumentsUtil().createArgumentsFromJSONString(jsonArgs, phase)
        return obj
-    
-    
+
    def copyArgs(self, phase):
        argsString = self.writeArgumentsToJSONString()
        return self.createArgumentsFromJSONString(argsString, phase)
-    
-    
-
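+# Illustrative sketch of wiring Arguments together (assumes zingg.pipes
+# provides CsvPipe and that the example paths and field names exist in your
+# setup; they are placeholders, not values introduced by this patch):
+#
+#   from zingg.pipes import CsvPipe
+#
+#   args = Arguments()
+#   args.setFieldDefinition([FieldDefinition("fname", "string", MatchType.FUZZY)])
+#   args.setData(CsvPipe("input", "examples/febrl/test.csv"))
+#   args.setOutput(CsvPipe("output", "/tmp/zinggOutput"))
+#   args.setModelId("100")
+#   args.setZinggDir("/tmp/zingg")
+#   args.setNumPartitions(4)
+#   args.setLabelDataSampleSize(0.5)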
 class ClientOptions:
-    """ Class that contains Client options for Zingg object
+    """Class that contains Client options for Zingg object

    :param phase: trainMatch, train, match, link, findAndLabel, findTrainingData, recommend etc
    :type phase: String
    :param args: Parse a list of Zingg command line options parameter values e.g. "--location" etc. optional argument for initializing this class.
    :type args: List(String) or None
    """
-    PHASE = getJVM().zingg.common.client.ClientOptions.PHASE
+
+    PHASE = getJVM().zingg.common.client.ClientOptions.PHASE
    """:PHASE: phase parameter for this class"""
    CONF = getJVM().zingg.common.client.ClientOptions.CONF
    """:CONF: conf parameter for this class"""
@@ -601,28 +699,27 @@ class ClientOptions:

    def __init__(self, argsSent=None):
        print(argsSent)
-        if(argsSent == None):
+        if argsSent is None:
            args = []
        else:
            args = argsSent.copy()
-        if (not (self.PHASE in args)):
+        if self.PHASE not in args:
            args.append(self.PHASE)
            args.append("peekModel")
-        if (not (self.LICENSE in args)):
+        if self.LICENSE not in args:
            args.append(self.LICENSE)
            args.append("zinggLic.txt")
-        if (not (self.EMAIL in args)):
+        if self.EMAIL not in args:
            args.append(self.EMAIL)
            args.append("zingg@zingg.ai")
-        if (not (self.CONF in args)):
+        if self.CONF not in args:
            args.append(self.CONF)
            args.append("dummyConf.json")
-        print("arguments for client options are ", args) 
+        print("arguments for client options are ", args)
        self.co = getJVM().zingg.common.client.ClientOptions(args)
-    
-    
+
    def getClientOptions(self):
-        """ Method to get pointer address of this class
+        """Method to get pointer address of this class

        :return: The pointer containing address of this class object
        :rtype: pointer(ClientOptions)
@@ -630,17 +727,17 @@ def getClientOptions(self):
        return self.co

    def getOptionValue(self, option):
-        """ Method to get value for the key option
+        """Method to get value for the key option

        :param option: key for getting the value
        :type option: String
-        :return: The value which is mapped for given key 
-        :rtype: String 
+        :return: The value which is mapped for the given key
+        :rtype: String
        """
        return self.co.getOptionValue(option)

    def setOptionValue(self, option, value):
-        """ Method to map option key to the given value
+        """Method to map option key to the given value

        :param option: key that is mapped with value
        :type option: String
@@ -650,53 +747,53 @@ def setOptionValue(self, option, value):
        self.co.get(option).setValue(value)

    def getPhase(self):
-        """ Method to get PHASE value
+        """Method to get PHASE value

        :return: The PHASE parameter value
-        :rtype: String 
+        :rtype: String
        """
        return self.co.get(ClientOptions.PHASE).getValue()

    def setPhase(self, newValue):
-        """ Method to set PHASE value
+        """Method to set PHASE value

        :param newValue: name of the phase
        :type newValue: String
        :return: The pointer containing address of this class object after setting the phase
-        :rtype: pointer(ClientOptions) 
+        :rtype: pointer(ClientOptions)
        """
        self.co.get(ClientOptions.PHASE).setValue(newValue)

    def getConf(self):
-        """ Method to get CONF value
+        """Method to get CONF value

        :return: The CONF parameter value
-        :rtype: String 
+        :rtype: String
        """
        return self.co.get(ClientOptions.CONF).getValue()

    def hasLocation(self):
-        """ Method to check if this class has LOCATION parameter set as None or not
+        """Method to check if this class has the LOCATION parameter set or not

        :return: The boolean value if LOCATION parameter is present or not
-        :rtype: Bool 
+        :rtype: Bool
        """
-        if(self.co.get(ClientOptions.LOCATION)==None):
+        if self.co.get(ClientOptions.LOCATION) is None:
            return False
        else:
            return True

    def getLocation(self):
-        """ Method to get LOCATION value
+        """Method to get LOCATION value

        :return: The LOCATION parameter value
-        :rtype: String 
+        :rtype: String
        """
        return self.co.get(ClientOptions.LOCATION).getValue()
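+# Illustrative sketch (values are placeholders, not defaults from this patch):
+# building options for a label run and reading them back.
+#
+#   options = ClientOptions(["--phase", "label", "--conf", "config.json"])
+#   options.getPhase()       # "label"
+#   options.setPhase("match")
+#   options.getConf()        # "config.json"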
 class FieldDefinition:
-    """ This class defines each field that we use in matching We can use 
-    this to configure the properties of each field we use for matching in Zingg.
+    """This class defines each field that we use in matching. We can use this to configure the properties of each field we use for matching in Zingg.

    :param name: name of the field
    :type name: String
@@ -712,9 +809,9 @@ def __init__(self, name, dataType, *matchType):
        self.fd.setDataType(self.stringify(dataType))
        self.fd.setMatchType(matchType)
        self.fd.setFields(name)
-    
+
    def setStopWords(self, stopWords):
-        """ Method to add stopwords to this class object
+        """Method to add stop words to this class object

        :param stopWords: The location of the csv file containing the stop words
        :type stopWords: String
@@ -722,7 +819,7 @@ def setStopWords(self, stopWords):
        self.fd.setStopWords(stopWords)

    def getFieldDefinition(self):
-        """ Method to get pointer address of this class
+        """Method to get pointer address of this class

        :return: The pointer containing the address of this class object
        :rtype: pointer(FieldDefinition)
@@ -731,31 +828,33 @@ def getFieldDefinition(self):

    # should be stringify'ed before it is set in fd object
    def stringify(self, str):
-        """ Method to stringify'ed the dataType before it is set in FieldDefinition object
-        
+        """Method to stringify the dataType before it is set in the FieldDefinition object
+
        :param str: dataType of the FieldDefinition
        :type str: String
        :return: The stringified value of the dataType
        :rtype: String
        """
-        
+
        return str
-
-
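+# Illustrative sketch: one fuzzy-matched field and one field excluded from
+# matching; the column names and stop words path are assumptions for the
+# example.
+#
+#   fname = FieldDefinition("fname", "string", MatchType.FUZZY)
+#   ssn = FieldDefinition("ssn", "string", MatchType.DONT_USE)
+#   fname.setStopWords("models/100/stopWords/fname.csv")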
assessModel") + mandatoryOptions.add_argument( + "--conf", + required=True, + help="JSON configuration with data input output locations and field definitions", + ) args, remaining_args = parser.parse_known_args(argv) LOG.debug("args: ", args) diff --git a/python/zingg/databricks.py b/python/zingg/databricks.py index 14c21b28e..54404f989 100644 --- a/python/zingg/databricks.py +++ b/python/zingg/databricks.py @@ -29,7 +29,7 @@ "job_cluster_key": "_cluster", "libraries": [ { - "whl":"dbfs:/FileStore/py/zingg-0.4.0-py2.py3-none-any.whl" + "whl":"dbfs:/FileStore/py/zingg-0.4.1-py2.py3-none-any.whl" }, { "pypi": { @@ -37,7 +37,7 @@ } }, { - "jar": "dbfs:/FileStore/jars/zingg_0_4_0.jar" + "jar": "dbfs:/FileStore/jars/zingg_0_4_1_SNAPSHOT.jar" } ], "timeout_seconds": 0, diff --git a/python/zingg_v2/__init__.py b/python/zingg_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/zingg_v2/client.py b/python/zingg_v2/client.py new file mode 100644 index 000000000..32edda143 --- /dev/null +++ b/python/zingg_v2/client.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +import json +import os +import warnings +from dataclasses import fields +from typing import Optional, Sequence, Union + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame +from pyspark.sql.connect.session import SparkSession as ConnectSession + +from zingg_v2 import models as models_v2 +from zingg_v2.connect import ZinggJob +from zingg_v2.errors import ZinggArgumentsValidationError, ZinggSparkConnectEmptySession +from zingg_v2.pipes import Pipe + + +class Zingg: + def __init__(self, args: Arguments, options: ClientOptions) -> None: + self.args = args + self.options = options + if os.environ["ZINGG_SPARK_CONNECT"]: + self.spark = ConnectSession.getActiveSession() + if self.spark is None: + _err_msg = "SparkConnect mode was choosen but spark session was not created!" + _err_msg += "\nYou have to initialize SparkConnectSession before creating Zingg!" + raise ZinggSparkConnectEmptySession(_err_msg) + else: + self.spark = SparkSession.getActiveSession() + if self.spark is None: + _warn_msg = "Spark Session is not initialized in the current thread!" + _warn_msg += " It is strongly reccomend to init SparkSession manually!" + warnings.warn(_warn_msg) + self.spark = SparkSession.builder.getOrCreate() + + def execute(self) -> Zingg: + # TODO: implement it + # java_args: arguments in form of string + # that is pairs of --key value + java_args = self.options.getClientOptions() + + # java_job_definition is JSON definition of Zingg Job + java_job_definition = self.args.writeArgumentsToJSONString() + + spark_connect = not hasattr(self.spark, "_jvm") + + if spark_connect: + _log_msg = "Submitting a Zingg Job\n" + _log_msg += f"Arguments: {java_args}\n\n" + _log_msg += java_job_definition + _log_msg += "\n\n" + print(java_job_definition) + df = ConnectDataFrame.withPlan( + ZinggJob(zingg_args=java_args, zingg_job=java_job_definition), self.spark + ) + output = df.collect()[0].asDict() + status: str = output["status"] + new_args: str = output["newArgs"] + + else: + # There are errors that should be fixed! :TODO + # TODO: Put that logic into Java by creating an entry point for Python API? 
+            j_options = self.spark._jvm.zingg.common.client.ClientOptions(java_args)
+            j_args = self.spark._jvm.zingg.common.client.ArgumentsUtil().createArgumentsFromJSONString(
+                java_job_definition,
+                self.options.getPhase(),
+            )
+            client = self.spark._jvm.zingg.spark.client.SparkClient(
+                j_args,
+                j_options,
+                self.spark._jsparkSession,
+            )
+            client.init()
+            client.execute()
+            client.postMetrics()
+
+            status = "SUCCESS"
+            new_args: str = self.spark._jvm.zingg.common.client.ArgumentsUtil().writeArgumentstoJSONString(
+                client.getArguments()
+            )
+
+        print(f"Zingg Job output status: {status}")
+
+        return Zingg(
+            Arguments.createArgumentsFromJSONString(new_args, self.options.getPhase()),
+            self.options.make_copy(),
+        )
+
+    def executeLabel(self) -> None:
+        raise NotImplementedError()
+
+    def executeLabelUpdate(self) -> None:
+        raise NotImplementedError()
+
+    def getMarkedRecords(self) -> Union[DataFrame, ConnectDataFrame]:
+        marked_path = self.args.getZinggTrainingDataMarkedDir()
+        marked = self.spark.read.parquet(marked_path)
+        return marked
+
+    def getUnmarkedRecords(self) -> Union[DataFrame, ConnectDataFrame]:
+        unmarked_path = self.args.getZinggTrainingDataUnmarkedDir()
+        unmarked = self.spark.read.parquet(unmarked_path)
+        return unmarked
+
+    def processRecordsCli(
+        self, unmarkedRecords: Union[DataFrame, ConnectDataFrame], args: Arguments
+    ) -> Union[DataFrame, ConnectDataFrame]:
+        raise NotImplementedError()
+
+    def processRecordsCliLabelUpdate(self, lines, args):
+        raise NotImplementedError()
+
+    def writeLabelledOutput(self, updatedRecords, args):
+        raise NotImplementedError()
+
+    def writeLabelledOutputFromPandas(self, candidate_pairs_pd, args):
+        raise NotImplementedError()
+
+    def setArguments(self, args: Arguments) -> None:
+        self.args = args
+
+    def getArguments(self) -> Arguments:
+        return self.args
+
+    def getOptions(self) -> ClientOptions:
+        return self.options
+
+    def setOptions(self, options: ClientOptions) -> None:
+        self.options = options
+
+    def getMarkedRecordsStat(self, markedRecords, value):
+        raise NotImplementedError()
+
+    def getMatchedMarkedRecordsStat(self):
+        raise NotImplementedError()
+
+    def getUnmatchedMarkedRecordsStat(self):
+        raise NotImplementedError()
+
+    def getUnsureMarkedRecordsStat(self):
+        raise NotImplementedError()
+
+
+class FieldDefinition:
+    def __init__(self, name: str, dataType: str, *matchType: Union[str, models_v2.MatchType]) -> None:
+        match_types = []
+        for mt in matchType:
+            if not isinstance(mt, models_v2.MatchType):
+                mt = models_v2.MatchType(mt)
+            # collect the (possibly converted) match type
+            match_types.append(mt)
+
+        self._model_v2 = models_v2.FieldDefinition(
+            fieldName=name, fields=name, dataType=dataType, matchType=match_types
+        )
+
+    def setStopWords(self, stopWords: str) -> None:
+        self._model_v2.stopWords = stopWords
+
+    def getFieldDefinition(self) -> str:
+        return self._model_v2.model_dump_json()
+
+    def to_v2(self) -> models_v2.FieldDefinition:
+        return self._model_v2
+
+
+class ClientOptions:
+    def __init__(self, argsSent: Optional[Sequence[str]]) -> None:
+        if argsSent is None:
+            self._opt_v2 = models_v2.ClientOptions()
+        else:
+            # pair CLI flags with their values:
+            # ["--phase", "match", "--conf", "x.json"] -> {"phase": "match", "conf": "x.json"}
+            args = list(argsSent)
+            self._opt_v2 = models_v2.ClientOptions(
+                **{k.removeprefix("--"): v for k, v in zip(args[::2], args[1::2])}
+            )
+        print("arguments for client options are ", self._opt_v2.to_java_args())
+
+    def getClientOptions(self) -> str:
+        return " ".join(self._opt_v2.to_java_args())
+
+    def getOptionValue(self, option: str) -> str:
+        if option.startswith("--"):
+            option = option[2:]
+
+        if not hasattr(self._opt_v2, option):
+            _msg = "Wrong option; possible options are: "
+            _msg += ", ".join(f.name for f in fields(self._opt_v2))
".join(f.name for f in fields(self._opt_v2)) + raise KeyError(_msg) + + return getattr(self._opt_v2, option) + + def setOptionValue(self, option: str, value: str) -> None: + if option.startswith("--"): + option = option[2:] + + if not hasattr(self._opt_v2, option): + _msg = "Wrong option; possible options are: " + _msg += ", ".join(f.name for f in fields(self._opt_v2)) + raise KeyError(_msg) + + setattr(self._opt_v2, option, value) + + def getPhase(self) -> str: + return self._opt_v2.phase + + def setPhase(self, newValue: str) -> None: + self._opt_v2.phase = newValue + + def getConf(self) -> str: + return self._opt_v2.conf + + def hasLocation(self) -> bool: + return self._opt_v2.location is None + + def getLocation(self) -> Optional[str]: + return self._opt_v2.location + + def to_v2(self) -> models_v2.ClientOptions: + return self._opt_v2 + + def make_copy(self) -> ClientOptions: + return ClientOptions(self._opt_v2.to_java_args()) + + +class Arguments: + def __init__(self): + self._args_v2 = models_v2.Arguments() + + @staticmethod + def _from_v2(arguments_v2: models_v2.Arguments) -> "Arguments": + new_obj = Arguments() + new_obj._args_v2 = arguments_v2 + return new_obj + + def setFieldDefinition(self, fieldDef: list[FieldDefinition]) -> None: + self._args_v2.fieldDefinition = [fd.to_v2() for fd in fieldDef] + + def setData(self, *pipes: Pipe) -> None: + self._args_v2.data = [pp.to_v2() for pp in pipes] + + def setOutput(self, *pipes: Pipe) -> None: + self._args_v2.output = [pp.to_v2() for pp in pipes] + + def getZinggBaseModelDir(self) -> str: + if isinstance(self._args_v2.modelId, int): + model_id = str(self._args_v2.modelId) + else: + model_id = self._args_v2.modelId + + return os.path.join( + self._args_v2.zinggDir, + model_id, + ) + + def getZinggModelDir(self) -> str: + return os.path.join(self.getZinggBaseModelDir(), "model") + + def getZinggBaseTrainingDataDir(self): + return os.path.join( + self.getZinggBaseModelDir(), + "trainingData", + ) + + def getZinggTrainingDataUnmarkedDir(self) -> str: + return os.path.join( + self.getZinggBaseTrainingDataDir(), + "unmarked", + ) + + def getZinggTrainingDataMarkedDir(self) -> str: + return os.path.join( + self.getZinggBaseTrainingDataDir(), + "marked", + ) + + def setTrainingSamples(self, *pipes: Pipe) -> None: + self._args_v2.trainingSamples = [pp.to_v2() for pp in pipes] + + def setModelId(self, id: str) -> None: + self._args_v2.modelId = id + + def getModelId(self): + return self._args_v2.modelId + + def setZinggDir(self, f: str) -> None: + self._args_v2.zinggDir = f + + def setNumPartitions(self, numPartitions: int) -> None: + self._args_v2.numPartitions = numPartitions + + def setLabelDataSampleSize(self, labelDataSampleSize: float) -> None: + self._args_v2.labelDataSampleSize = labelDataSampleSize + + def writeArgumentsToJSON(self, fileName: str) -> None: + with open(fileName, "w") as f_: + json.dump( + self._args_v2.model_dump_json(), + f_, + ) + + def setStopWordsCutoff(self, stopWordsCutoff: float) -> None: + self._args_v2.stopWordsCutoff = stopWordsCutoff + + def setColumn(self, column: str): + self._args_v2.column = column + + @staticmethod + def createArgumentsFromJSON(fileName: str, phase: str) -> "Arguments": + with open(fileName, "r") as f_: + json_string = json.load(f_) + + return Arguments.createArgumentsFromJSONString(json_string, phase) + + def writeArgumentsToJSONString(self) -> str: + return self._args_v2.model_dump_json() + + @staticmethod + def createArgumentsFromJSONString(jsonArgs: str, phase: str): + args_v2 = 
+        args_v2 = models_v2.Arguments.model_validate_json(jsonArgs)
+
+        if not args_v2.validate_phase(phase):
+            raise ZinggArgumentsValidationError("Wrong args for the given phase")
+
+        return Arguments._from_v2(args_v2)
+
+    def copyArgs(self, phase):
+        argsString = self.writeArgumentsToJSONString()
+        return self.createArgumentsFromJSONString(argsString, phase)
+
+    def to_v2(self) -> models_v2.Arguments:
+        return self._args_v2
diff --git a/python/zingg_v2/connect.py b/python/zingg_v2/connect.py
new file mode 100644
index 000000000..69071e0f3
--- /dev/null
+++ b/python/zingg_v2/connect.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pyspark.sql.connect import proto
+from pyspark.sql.connect.plan import LogicalPlan
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.client import SparkConnectClient
+
+from zingg_v2.proto.connect_plugins_pb2 import SubmitZinggJob
+
+
+class ZinggJob(LogicalPlan):
+    def __init__(self, zingg_args: str, zingg_job: str) -> None:
+        super().__init__(None)
+        self._args = zingg_args
+        self._job_json = zingg_job
+
+    def plan(self, session: SparkConnectClient) -> proto.Relation:
+        plan = self._create_proto_relation()
+        zingg_submit = SubmitZinggJob(args=self._args, options=self._job_json)
+        plan.extension.Pack(zingg_submit)
+
+        return plan
diff --git a/python/zingg_v2/errors.py b/python/zingg_v2/errors.py
new file mode 100644
index 000000000..4ba7794ff
--- /dev/null
+++ b/python/zingg_v2/errors.py
@@ -0,0 +1,8 @@
+class ZinggArgumentsValidationError(ValueError):
+    pass
+
+class ZinggParameterIsNotSet(ValueError):
+    pass
+
+class ZinggSparkConnectEmptySession(ValueError):
+    pass
diff --git a/python/zingg_v2/models.py b/python/zingg_v2/models.py
new file mode 100644
index 000000000..2b7e9519d
--- /dev/null
+++ b/python/zingg_v2/models.py
@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+import itertools
+from dataclasses import asdict, dataclass
+from enum import StrEnum, auto
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class MatchType(StrEnum):
+    FUZZY = auto()
+    EXACT = auto()
+    DONT_USE = auto()
+    EMAIL = auto()
+    PINCODE = auto()
+    NULL_OR_BLANK = auto()
+    TEXT = auto()
+    NUMERIC = auto()
+    NUMERIC_WITH_UNITS = auto()
+    ONLY_ALPHABETS_EXACT = auto()
+    ONLY_ALPHABETS_FUZZY = auto()
+
+
+class DataFormat(StrEnum):
+    CSV = auto()
+    PARQUET = auto()
+    JSON = auto()
+    TEXT = auto()
+    XLS = "com.crealytics.spark.excel"
+    AVRO = auto()
+    JDBC = auto()
+    CASSANDRA = "org.apache.spark.sql.cassandra"
+    SNOWFLAKE = "net.snowflake.spark.snowflake"
+    ELASTIC = "org.elasticsearch.spark.sql"
+    EXACOL = "com.exasol.spark"
+    BIGQUERY = auto()
+    INMEMORY = auto()
+
+
+class FieldDefinition(BaseModel):
+    matchType: Union[MatchType, list[MatchType]]
+    dataType: str
+    fieldName: str
+    fields: str
+    stopWords: Optional[str] = None
+    abbreviations: Optional[str] = None
+
+
+class Pipe(BaseModel):
+    name: str
+    format: DataFormat
+    props: dict[str, Any] = {}
+    # "schema" is a built-in attribute of BaseModel;
+    # that is why we need this alias:
+    schema_field: Optional[str] = Field(default=None, alias="schema")
+    mode: Optional[str] = None
+
+
+class Arguments(BaseModel):
+    output: Optional[list[Pipe]] = None
+    data: Optional[list[Pipe]] = None
+    zinggDir: str = "/tmp/zingg"
+    trainingSamples: Optional[list[Pipe]] = None
+    fieldDefinition: Optional[list[FieldDefinition]] = None
+    numPartitions: int = 10
+    labelDataSampleSize: float = 0.01
+    modelId: Union[str, int] = "1"
+    threshold: float = 0.5
+    jobId: int = 1
+    collectMetrics: bool = True
+    showConcise: bool = False
+    stopWordsCutoff: float = 0.1
+    blockSize: int = 100
+    column: Optional[str] = None
+
+    @field_validator("numPartitions")
+    @classmethod
+    def validate_num_partitions(cls, v: int) -> int:
+        # raise ValueError here; pydantic wraps it into a ValidationError
+        if (v != -1) and (v <= 0):
+            _err_msg = "Number of partitions must be greater than 0 for user-specified partitioning or equal to -1 for system-decided partitioning"
+            raise ValueError(_err_msg)
+
+        return v
+
+    @field_validator("labelDataSampleSize", "stopWordsCutoff")
+    @classmethod
+    def validate_relative_size(cls, v: float) -> float:
+        if (v > 1) or (v < 0):
+            _err_msg = "Value should be between 0 and 1"
+            raise ValueError(_err_msg)
+
+        return v
+
+    def validate_phase(self, phase: str) -> bool:
+        is_valid = True
+        if phase in ["train", "match", "trainMatch", "link"]:
+            is_valid &= self.trainingSamples is not None
+            is_valid &= self.data is not None
+            is_valid &= self.numPartitions is not None
+            is_valid &= self.fieldDefinition is not None
+
+        elif phase in ["seed", "seedDB"]:
+            is_valid &= self.data is not None
+            is_valid &= self.numPartitions is not None
+            is_valid &= self.fieldDefinition is not None
+
+        elif phase != "WEB":
+            is_valid &= self.data is not None
+            is_valid &= self.numPartitions is not None
+
+        return is_valid
+
+
+@dataclass
+class ClientOptions:
+    phase: str = "peekModel"
+    license: str = "zinggLic.txt"
+    email: str = "zingg@zingg.ai"
+    conf: str = "dummyConf.json"
+    preprocess: Optional[str] = None
+    jobId: Optional[str] = None
+    format: Optional[str] = None
+    zinggDir: Optional[str] = None
+    modelId: Optional[str] = None
+    collectMetrics: Optional[str] = None
+    showConcise: Optional[str] = None
+    location: Optional[str] = None
+    column: Optional[str] = None
+    remote: Optional[str] = None
+
+    def to_java_args(self) -> list[str]:
+        return list(
+            itertools.chain.from_iterable(
+                [[f"--{key}", value] for key, value in asdict(self).items() if value is not None]
+            )
+        )
diff --git a/python/zingg_v2/pipes.py b/python/zingg_v2/pipes.py
new file mode 100644
index 000000000..60135fe1e
--- /dev/null
+++ b/python/zingg_v2/pipes.py
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+import json
+import warnings
+from typing import Optional, Union
+
+from pandas import DataFrame as PDataFrame
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import StructType
+
+from zingg_v2 import models as models_v2
+
+
+class Pipe:
+    def __init__(self, name: str, format: Union[str, models_v2.DataFormat]) -> None:
+        if not isinstance(format, models_v2.DataFormat):
+            format = models_v2.DataFormat(format)
+        self._pipe_v2 = models_v2.Pipe(name=name, format=format)
+        # default so that consumers can safely test "pipe.schema is None"
+        self.schema: Optional[str] = None
+
+    def getPipe(self) -> str:
+        return self.toString()
+
+    def addProperty(self, name: str, value: str) -> None:
+        self._pipe_v2.props[name] = value
+
+    def setSchema(self, schema: str) -> None:
+        self.schema = schema
+        self._pipe_v2.schema_field = schema
+
+    def toString(self) -> str:
+        # model_dump_json already returns JSON; no extra json.dumps needed
+        return self._pipe_v2.model_dump_json()
+
+    def to_v2(self) -> models_v2.Pipe:
+        return self._pipe_v2
+
+
+class CsvPipe(Pipe):
+    def __init__(self, name: str, location: Optional[str] = None, schema: Optional[str] = None) -> None:
+        super().__init__(name, models_v2.DataFormat.CSV)
+        if schema is not None:
+            self.setSchema(schema)
+        if location is not None:
+            self.addProperty("location", location)
+
+    def setDelimiter(self, delimiter: str) -> None:
+        self.addProperty("delimiter", delimiter)
+
+    def setLocation(self, location: str) -> None:
self.addProperty("location", location) + + def setHeader(self, header: str) -> None: + self.addProperty("header", header) + + +class BigQueryPipe(Pipe): + def __init__(self, name: str) -> None: + super().__init__(name, models_v2.DataFormat.BIGQUERY) + + def setCredentialFile(self, credentials_file: str) -> None: + self.addProperty("credentialsFile", credentials_file) + + def setTable(self, table: str) -> None: + self.addProperty("table", table) + + def setTemporaryGcsBucket(self, bucket: str) -> None: + self.addProperty("temporaryGcsBucket", bucket) + + def setViewsEnabled(self, isEnabled: bool) -> None: + self.addProperty("viewsEnabled", "true" if isEnabled else "false") + + +class SnowflakePipe(Pipe): + def __init__(self, name: str) -> None: + super().__init__(name, models_v2.DataFormat.SNOWFLAKE) + self.addProperty("application", "zinggai_zingg") + + def setUrl(self, url: str) -> None: + self.addProperty("sfUrl", url) + + def setUser(self, user: str) -> None: + self.addProperty("sfUser", user) + + def setPassword(self, passwd: str) -> None: + self.addProperty("sfPassword", passwd) + + def setDatabase(self, db: str) -> None: + self.addProperty("sfDatabase", db) + + def setSFSchema(self, schema: str) -> None: + self.addProperty("sfSchema", schema) + + def setWarehouse(self, warehouse: str) -> None: + self.addProperty("sfWarehouse", warehouse) + + def setDbTable(self, dbtable: str) -> None: + self.addProperty("dbtable", dbtable) + + +class InMemoryPipe(Pipe): + def __init__(self, name: str, df: Optional[Union[DataFrame, PDataFrame]] = None) -> None: + super().__init__(name, models_v2.DataFormat.INMEMORY) + self.df: Optional[DataFrame] = None + if df is not None: + self.setDataset(df) + + def setDataset(self, df: Union[DataFrame, PDataFrame]) -> None: + if isinstance(df, PDataFrame): + spark = SparkSession.getActiveSession() + if spark is None: + warnings.warn("No active Session Found!") + spark = SparkSession.builder.getOrCreate() + + if self.schema is None: + df = spark.createDataFrame(df) + else: + df = spark.createDataFrame(df, schema=StructType.fromJson(json.loads(self.schema))) + + self.df = df + + def getDataset(self) -> DataFrame: + if self.df is None: + raise ValueError("DataFrame is not set!") + + return self.df diff --git a/python/zingg_v2/proto/connect_plugins_pb2.py b/python/zingg_v2/proto/connect_plugins_pb2.py new file mode 100644 index 000000000..18cdd505a --- /dev/null +++ b/python/zingg_v2/proto/connect_plugins_pb2.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: connect_plugins.proto +# Protobuf Python Version: 4.25.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x63onnect_plugins.proto\"\xa9\x01\n\x0eSubmitZinggJob\x12(\n\targumnets\x18\x01 \x01(\x0b\x32\n.ArgumentsR\targumnets\x12/\n\x0b\x63li_options\x18\x02 \x01(\x0b\x32\x0e.ClientOptionsR\ncliOptions\x12)\n\x0ein_memory_date\x18\x03 \x01(\x0cH\x00R\x0cinMemoryDate\x88\x01\x01\x42\x11\n\x0f_in_memory_date\"\x80\x02\n\x0f\x46ieldDefinition\x12)\n\nmatch_type\x18\x01 \x03(\x0e\x32\n.MatchTypeR\tmatchType\x12\x1b\n\tdata_type\x18\x02 \x01(\tR\x08\x64\x61taType\x12\x1d\n\nfield_name\x18\x03 \x01(\tR\tfieldName\x12\x16\n\x06\x66ields\x18\x04 \x01(\tR\x06\x66ields\x12\"\n\nstop_words\x18\x05 \x01(\tH\x00R\tstopWords\x88\x01\x01\x12)\n\rabbreviations\x18\x06 \x01(\tH\x01R\rabbreviations\x88\x01\x01\x42\r\n\x0b_stop_wordsB\x10\n\x0e_abbreviations\"\xfc\x01\n\x04Pipe\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12#\n\x06\x66ormat\x18\x02 \x01(\x0e\x32\x0b.DataFormatR\x06\x66ormat\x12&\n\x05props\x18\x03 \x03(\x0b\x32\x10.Pipe.PropsEntryR\x05props\x12&\n\x0cschema_field\x18\x04 \x01(\tH\x00R\x0bschemaField\x88\x01\x01\x12\x17\n\x04mode\x18\x05 \x01(\tH\x01R\x04mode\x88\x01\x01\x1a\x38\n\nPropsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0f\n\r_schema_fieldB\x07\n\x05_mode\"\xbe\x04\n\tArguments\x12\x1d\n\x06output\x18\x01 \x03(\x0b\x32\x05.PipeR\x06output\x12\x19\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32\x05.PipeR\x04\x64\x61ta\x12\x1b\n\tzingg_dir\x18\x03 \x01(\tR\x08zinggDir\x12\x30\n\x10training_samples\x18\x04 \x03(\x0b\x32\x05.PipeR\x0ftrainingSamples\x12=\n\x11\x66iield_definition\x18\x05 \x03(\x0b\x32\x10.FieldDefinitionR\x10\x66iieldDefinition\x12%\n\x0enum_partitions\x18\x06 \x01(\x05R\rnumPartitions\x12\x33\n\x16label_data_sample_size\x18\x07 \x01(\x02R\x13labelDataSampleSize\x12\x19\n\x08model_id\x18\x08 \x01(\tR\x07modelId\x12\x1c\n\tthreshold\x18\t \x01(\x02R\tthreshold\x12\x15\n\x06job_id\x18\n \x01(\x05R\x05jobId\x12\'\n\x0f\x63ollect_metrics\x18\x0b \x01(\x08R\x0e\x63ollectMetrics\x12!\n\x0cshow_concise\x18\x0c \x01(\x08R\x0bshowConcise\x12*\n\x11stop_words_cutoff\x18\r \x01(\x02R\x0fstopWordsCutoff\x12\x1d\n\nblock_size\x18\x0e \x01(\x03R\tblockSize\x12\x1b\n\x06\x63olumn\x18\x0f \x01(\tH\x00R\x06\x63olumn\x88\x01\x01\x42\t\n\x07_column\"\xff\x04\n\rClientOptions\x12\x19\n\x05phase\x18\x01 \x01(\tH\x00R\x05phase\x88\x01\x01\x12\x1d\n\x07license\x18\x02 \x01(\tH\x01R\x07license\x88\x01\x01\x12\x19\n\x05\x65mail\x18\x03 \x01(\tH\x02R\x05\x65mail\x88\x01\x01\x12\x17\n\x04\x63onf\x18\x04 \x01(\tH\x03R\x04\x63onf\x88\x01\x01\x12#\n\npreprocess\x18\x05 \x01(\tH\x04R\npreprocess\x88\x01\x01\x12\x1a\n\x06job_id\x18\x06 \x01(\tH\x05R\x05jobId\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x07 \x01(\tH\x06R\x06\x66ormat\x88\x01\x01\x12 \n\tzingg_dir\x18\x08 \x01(\tH\x07R\x08zinggDir\x88\x01\x01\x12\x1e\n\x08model_id\x18\t \x01(\tH\x08R\x07modelId\x88\x01\x01\x12,\n\x0f\x63ollect_metrics\x18\n \x01(\tH\tR\x0e\x63ollectMetrics\x88\x01\x01\x12&\n\x0cshow_concise\x18\x0b \x01(\tH\nR\x0bshowConcise\x88\x01\x01\x12\x1f\n\x08location\x18\x0c 
\x01(\tH\x0bR\x08location\x88\x01\x01\x12\x1b\n\x06\x63olumn\x18\r \x01(\tH\x0cR\x06\x63olumn\x88\x01\x01\x12\x1b\n\x06remote\x18\x0e \x01(\tH\rR\x06remote\x88\x01\x01\x42\x08\n\x06_phaseB\n\n\x08_licenseB\x08\n\x06_emailB\x07\n\x05_confB\r\n\x0b_preprocessB\t\n\x07_job_idB\t\n\x07_formatB\x0c\n\n_zingg_dirB\x0b\n\t_model_idB\x12\n\x10_collect_metricsB\x0f\n\r_show_conciseB\x0b\n\t_locationB\t\n\x07_columnB\t\n\x07_remote*\xde\x01\n\tMatchType\x12\x0c\n\x08MT_FUZZY\x10\x00\x12\x0c\n\x08MT_EXACT\x10\x01\x12\x0f\n\x0bMT_DONT_USE\x10\x02\x12\x0c\n\x08MT_EMAIL\x10\x03\x12\x0e\n\nMT_PINCODE\x10\x04\x12\x14\n\x10MT_NULL_OR_BLANK\x10\x05\x12\x0b\n\x07MT_TEXT\x10\x06\x12\x0e\n\nMT_NUMERIC\x10\x07\x12\x19\n\x15MT_NUMERIC_WITH_UNITS\x10\x08\x12\x1b\n\x17MT_ONLY_ALPHABETS_EXACT\x10\t\x12\x1b\n\x17MT_ONLY_ALPHABETS_FUZZY\x10\n*\xcc\x01\n\nDataFormat\x12\n\n\x06\x44\x46_CSV\x10\x00\x12\x0e\n\nDF_PARQUET\x10\x01\x12\x0b\n\x07\x44\x46_JSON\x10\x02\x12\x0b\n\x07\x44\x46_TEXT\x10\x03\x12\n\n\x06\x44\x46_XLS\x10\x04\x12\x0b\n\x07\x44\x46_AVRO\x10\x05\x12\x0b\n\x07\x44\x46_JDBC\x10\x06\x12\x10\n\x0c\x44\x46_CASSANDRA\x10\x07\x12\x10\n\x0c\x44\x46_SNOWFLAKE\x10\x08\x12\x0e\n\nDF_ELASTIC\x10\t\x12\r\n\tDF_EXACOL\x10\n\x12\x0e\n\nDF_BIGQUEY\x10\x0b\x12\x0f\n\x0b\x44\x46_INMEMORY\x10\x0c\x42\x1d\n\x19zingg.spark.connect.protoP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'connect_plugins_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\031zingg.spark.connect.protoP\001' + _globals['_PIPE_PROPSENTRY']._options = None + _globals['_PIPE_PROPSENTRY']._serialized_options = b'8\001' + _globals['_MATCHTYPE']._serialized_start=1931 + _globals['_MATCHTYPE']._serialized_end=2153 + _globals['_DATAFORMAT']._serialized_start=2156 + _globals['_DATAFORMAT']._serialized_end=2360 + _globals['_SUBMITZINGGJOB']._serialized_start=26 + _globals['_SUBMITZINGGJOB']._serialized_end=195 + _globals['_FIELDDEFINITION']._serialized_start=198 + _globals['_FIELDDEFINITION']._serialized_end=454 + _globals['_PIPE']._serialized_start=457 + _globals['_PIPE']._serialized_end=709 + _globals['_PIPE_PROPSENTRY']._serialized_start=627 + _globals['_PIPE_PROPSENTRY']._serialized_end=683 + _globals['_ARGUMENTS']._serialized_start=712 + _globals['_ARGUMENTS']._serialized_end=1286 + _globals['_CLIENTOPTIONS']._serialized_start=1289 + _globals['_CLIENTOPTIONS']._serialized_end=1928 +# @@protoc_insertion_point(module_scope) diff --git a/python/zingg_v2/proto/connect_plugins_pb2.pyi b/python/zingg_v2/proto/connect_plugins_pb2.pyi new file mode 100644 index 000000000..b42895889 --- /dev/null +++ b/python/zingg_v2/proto/connect_plugins_pb2.pyi @@ -0,0 +1,346 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _MatchType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _MatchTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_MatchType.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + MT_FUZZY: _MatchType.ValueType # 0 + MT_EXACT: _MatchType.ValueType # 1 + MT_DONT_USE: _MatchType.ValueType # 2 + MT_EMAIL: _MatchType.ValueType # 3 + MT_PINCODE: _MatchType.ValueType # 4 + MT_NULL_OR_BLANK: _MatchType.ValueType # 5 + MT_TEXT: _MatchType.ValueType # 6 + MT_NUMERIC: _MatchType.ValueType # 7 + MT_NUMERIC_WITH_UNITS: _MatchType.ValueType # 8 + MT_ONLY_ALPHABETS_EXACT: _MatchType.ValueType # 9 + MT_ONLY_ALPHABETS_FUZZY: _MatchType.ValueType # 10 + +class MatchType(_MatchType, metaclass=_MatchTypeEnumTypeWrapper): ... + +MT_FUZZY: MatchType.ValueType # 0 +MT_EXACT: MatchType.ValueType # 1 +MT_DONT_USE: MatchType.ValueType # 2 +MT_EMAIL: MatchType.ValueType # 3 +MT_PINCODE: MatchType.ValueType # 4 +MT_NULL_OR_BLANK: MatchType.ValueType # 5 +MT_TEXT: MatchType.ValueType # 6 +MT_NUMERIC: MatchType.ValueType # 7 +MT_NUMERIC_WITH_UNITS: MatchType.ValueType # 8 +MT_ONLY_ALPHABETS_EXACT: MatchType.ValueType # 9 +MT_ONLY_ALPHABETS_FUZZY: MatchType.ValueType # 10 +global___MatchType = MatchType + +class _DataFormat: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _DataFormatEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_DataFormat.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + DF_CSV: _DataFormat.ValueType # 0 + DF_PARQUET: _DataFormat.ValueType # 1 + DF_JSON: _DataFormat.ValueType # 2 + DF_TEXT: _DataFormat.ValueType # 3 + DF_XLS: _DataFormat.ValueType # 4 + DF_AVRO: _DataFormat.ValueType # 5 + DF_JDBC: _DataFormat.ValueType # 6 + DF_CASSANDRA: _DataFormat.ValueType # 7 + DF_SNOWFLAKE: _DataFormat.ValueType # 8 + DF_ELASTIC: _DataFormat.ValueType # 9 + DF_EXACOL: _DataFormat.ValueType # 10 + DF_BIGQUEY: _DataFormat.ValueType # 11 + DF_INMEMORY: _DataFormat.ValueType # 12 + +class DataFormat(_DataFormat, metaclass=_DataFormatEnumTypeWrapper): ... + +DF_CSV: DataFormat.ValueType # 0 +DF_PARQUET: DataFormat.ValueType # 1 +DF_JSON: DataFormat.ValueType # 2 +DF_TEXT: DataFormat.ValueType # 3 +DF_XLS: DataFormat.ValueType # 4 +DF_AVRO: DataFormat.ValueType # 5 +DF_JDBC: DataFormat.ValueType # 6 +DF_CASSANDRA: DataFormat.ValueType # 7 +DF_SNOWFLAKE: DataFormat.ValueType # 8 +DF_ELASTIC: DataFormat.ValueType # 9 +DF_EXACOL: DataFormat.ValueType # 10 +DF_BIGQUEY: DataFormat.ValueType # 11 +DF_INMEMORY: DataFormat.ValueType # 12 +global___DataFormat = DataFormat + +@typing_extensions.final +class SubmitZinggJob(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ARGUMNETS_FIELD_NUMBER: builtins.int + CLI_OPTIONS_FIELD_NUMBER: builtins.int + IN_MEMORY_DATE_FIELD_NUMBER: builtins.int + @property + def argumnets(self) -> global___Arguments: ... + @property + def cli_options(self) -> global___ClientOptions: ... 
+ in_memory_date: builtins.bytes + """The next message is a serialized LogicalPlan""" + def __init__( + self, + *, + argumnets: global___Arguments | None = ..., + cli_options: global___ClientOptions | None = ..., + in_memory_date: builtins.bytes | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_in_memory_date", b"_in_memory_date", "argumnets", b"argumnets", "cli_options", b"cli_options", "in_memory_date", b"in_memory_date"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_in_memory_date", b"_in_memory_date", "argumnets", b"argumnets", "cli_options", b"cli_options", "in_memory_date", b"in_memory_date"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["_in_memory_date", b"_in_memory_date"]) -> typing_extensions.Literal["in_memory_date"] | None: ... + +global___SubmitZinggJob = SubmitZinggJob + +@typing_extensions.final +class FieldDefinition(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MATCH_TYPE_FIELD_NUMBER: builtins.int + DATA_TYPE_FIELD_NUMBER: builtins.int + FIELD_NAME_FIELD_NUMBER: builtins.int + FIELDS_FIELD_NUMBER: builtins.int + STOP_WORDS_FIELD_NUMBER: builtins.int + ABBREVIATIONS_FIELD_NUMBER: builtins.int + @property + def match_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___MatchType.ValueType]: ... + data_type: builtins.str + field_name: builtins.str + fields: builtins.str + stop_words: builtins.str + abbreviations: builtins.str + def __init__( + self, + *, + match_type: collections.abc.Iterable[global___MatchType.ValueType] | None = ..., + data_type: builtins.str = ..., + field_name: builtins.str = ..., + fields: builtins.str = ..., + stop_words: builtins.str | None = ..., + abbreviations: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_abbreviations", b"_abbreviations", "_stop_words", b"_stop_words", "abbreviations", b"abbreviations", "stop_words", b"stop_words"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_abbreviations", b"_abbreviations", "_stop_words", b"_stop_words", "abbreviations", b"abbreviations", "data_type", b"data_type", "field_name", b"field_name", "fields", b"fields", "match_type", b"match_type", "stop_words", b"stop_words"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_abbreviations", b"_abbreviations"]) -> typing_extensions.Literal["abbreviations"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_stop_words", b"_stop_words"]) -> typing_extensions.Literal["stop_words"] | None: ... + +global___FieldDefinition = FieldDefinition + +@typing_extensions.final +class Pipe(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing_extensions.final + class PropsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]) -> None: ... 
+ + NAME_FIELD_NUMBER: builtins.int + FORMAT_FIELD_NUMBER: builtins.int + PROPS_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_FIELD_NUMBER: builtins.int + MODE_FIELD_NUMBER: builtins.int + name: builtins.str + format: global___DataFormat.ValueType + @property + def props(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... + schema_field: builtins.str + mode: builtins.str + def __init__( + self, + *, + name: builtins.str = ..., + format: global___DataFormat.ValueType = ..., + props: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + schema_field: builtins.str | None = ..., + mode: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_mode", b"_mode", "_schema_field", b"_schema_field", "mode", b"mode", "schema_field", b"schema_field"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_mode", b"_mode", "_schema_field", b"_schema_field", "format", b"format", "mode", b"mode", "name", b"name", "props", b"props", "schema_field", b"schema_field"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_mode", b"_mode"]) -> typing_extensions.Literal["mode"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_schema_field", b"_schema_field"]) -> typing_extensions.Literal["schema_field"] | None: ... + +global___Pipe = Pipe + +@typing_extensions.final +class Arguments(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + OUTPUT_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + ZINGG_DIR_FIELD_NUMBER: builtins.int + TRAINING_SAMPLES_FIELD_NUMBER: builtins.int + FIIELD_DEFINITION_FIELD_NUMBER: builtins.int + NUM_PARTITIONS_FIELD_NUMBER: builtins.int + LABEL_DATA_SAMPLE_SIZE_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + THRESHOLD_FIELD_NUMBER: builtins.int + JOB_ID_FIELD_NUMBER: builtins.int + COLLECT_METRICS_FIELD_NUMBER: builtins.int + SHOW_CONCISE_FIELD_NUMBER: builtins.int + STOP_WORDS_CUTOFF_FIELD_NUMBER: builtins.int + BLOCK_SIZE_FIELD_NUMBER: builtins.int + COLUMN_FIELD_NUMBER: builtins.int + @property + def output(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Pipe]: ... + @property + def data(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Pipe]: ... + zingg_dir: builtins.str + @property + def training_samples(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Pipe]: ... + @property + def fiield_definition(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___FieldDefinition]: ... 
+ num_partitions: builtins.int + label_data_sample_size: builtins.float + model_id: builtins.str + threshold: builtins.float + job_id: builtins.int + collect_metrics: builtins.bool + show_concise: builtins.bool + stop_words_cutoff: builtins.float + block_size: builtins.int + column: builtins.str + def __init__( + self, + *, + output: collections.abc.Iterable[global___Pipe] | None = ..., + data: collections.abc.Iterable[global___Pipe] | None = ..., + zingg_dir: builtins.str = ..., + training_samples: collections.abc.Iterable[global___Pipe] | None = ..., + fiield_definition: collections.abc.Iterable[global___FieldDefinition] | None = ..., + num_partitions: builtins.int = ..., + label_data_sample_size: builtins.float = ..., + model_id: builtins.str = ..., + threshold: builtins.float = ..., + job_id: builtins.int = ..., + collect_metrics: builtins.bool = ..., + show_concise: builtins.bool = ..., + stop_words_cutoff: builtins.float = ..., + block_size: builtins.int = ..., + column: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_column", b"_column", "column", b"column"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_column", b"_column", "block_size", b"block_size", "collect_metrics", b"collect_metrics", "column", b"column", "data", b"data", "fiield_definition", b"fiield_definition", "job_id", b"job_id", "label_data_sample_size", b"label_data_sample_size", "model_id", b"model_id", "num_partitions", b"num_partitions", "output", b"output", "show_concise", b"show_concise", "stop_words_cutoff", b"stop_words_cutoff", "threshold", b"threshold", "training_samples", b"training_samples", "zingg_dir", b"zingg_dir"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["_column", b"_column"]) -> typing_extensions.Literal["column"] | None: ... + +global___Arguments = Arguments + +@typing_extensions.final +class ClientOptions(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + PHASE_FIELD_NUMBER: builtins.int + LICENSE_FIELD_NUMBER: builtins.int + EMAIL_FIELD_NUMBER: builtins.int + CONF_FIELD_NUMBER: builtins.int + PREPROCESS_FIELD_NUMBER: builtins.int + JOB_ID_FIELD_NUMBER: builtins.int + FORMAT_FIELD_NUMBER: builtins.int + ZINGG_DIR_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + COLLECT_METRICS_FIELD_NUMBER: builtins.int + SHOW_CONCISE_FIELD_NUMBER: builtins.int + LOCATION_FIELD_NUMBER: builtins.int + COLUMN_FIELD_NUMBER: builtins.int + REMOTE_FIELD_NUMBER: builtins.int + phase: builtins.str + license: builtins.str + email: builtins.str + conf: builtins.str + preprocess: builtins.str + job_id: builtins.str + format: builtins.str + zingg_dir: builtins.str + model_id: builtins.str + collect_metrics: builtins.str + show_concise: builtins.str + location: builtins.str + column: builtins.str + remote: builtins.str + def __init__( + self, + *, + phase: builtins.str | None = ..., + license: builtins.str | None = ..., + email: builtins.str | None = ..., + conf: builtins.str | None = ..., + preprocess: builtins.str | None = ..., + job_id: builtins.str | None = ..., + format: builtins.str | None = ..., + zingg_dir: builtins.str | None = ..., + model_id: builtins.str | None = ..., + collect_metrics: builtins.str | None = ..., + show_concise: builtins.str | None = ..., + location: builtins.str | None = ..., + column: builtins.str | None = ..., + remote: builtins.str | None = ..., + ) -> None: ... 
+ def HasField(self, field_name: typing_extensions.Literal["_collect_metrics", b"_collect_metrics", "_column", b"_column", "_conf", b"_conf", "_email", b"_email", "_format", b"_format", "_job_id", b"_job_id", "_license", b"_license", "_location", b"_location", "_model_id", b"_model_id", "_phase", b"_phase", "_preprocess", b"_preprocess", "_remote", b"_remote", "_show_concise", b"_show_concise", "_zingg_dir", b"_zingg_dir", "collect_metrics", b"collect_metrics", "column", b"column", "conf", b"conf", "email", b"email", "format", b"format", "job_id", b"job_id", "license", b"license", "location", b"location", "model_id", b"model_id", "phase", b"phase", "preprocess", b"preprocess", "remote", b"remote", "show_concise", b"show_concise", "zingg_dir", b"zingg_dir"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_collect_metrics", b"_collect_metrics", "_column", b"_column", "_conf", b"_conf", "_email", b"_email", "_format", b"_format", "_job_id", b"_job_id", "_license", b"_license", "_location", b"_location", "_model_id", b"_model_id", "_phase", b"_phase", "_preprocess", b"_preprocess", "_remote", b"_remote", "_show_concise", b"_show_concise", "_zingg_dir", b"_zingg_dir", "collect_metrics", b"collect_metrics", "column", b"column", "conf", b"conf", "email", b"email", "format", b"format", "job_id", b"job_id", "license", b"license", "location", b"location", "model_id", b"model_id", "phase", b"phase", "preprocess", b"preprocess", "remote", b"remote", "show_concise", b"show_concise", "zingg_dir", b"zingg_dir"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_collect_metrics", b"_collect_metrics"]) -> typing_extensions.Literal["collect_metrics"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_column", b"_column"]) -> typing_extensions.Literal["column"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_conf", b"_conf"]) -> typing_extensions.Literal["conf"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_email", b"_email"]) -> typing_extensions.Literal["email"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_format", b"_format"]) -> typing_extensions.Literal["format"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_job_id", b"_job_id"]) -> typing_extensions.Literal["job_id"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_license", b"_license"]) -> typing_extensions.Literal["license"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_location", b"_location"]) -> typing_extensions.Literal["location"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_model_id", b"_model_id"]) -> typing_extensions.Literal["model_id"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_phase", b"_phase"]) -> typing_extensions.Literal["phase"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_preprocess", b"_preprocess"]) -> typing_extensions.Literal["preprocess"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_remote", b"_remote"]) -> typing_extensions.Literal["remote"] | None: ... 
+ @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_show_concise", b"_show_concise"]) -> typing_extensions.Literal["show_concise"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_zingg_dir", b"_zingg_dir"]) -> typing_extensions.Literal["zingg_dir"] | None: ... + +global___ClientOptions = ClientOptions diff --git a/python/zingg_v2/proto/connect_plugins_pb2_grpc.py b/python/zingg_v2/proto/connect_plugins_pb2_grpc.py new file mode 100644 index 000000000..2daafffeb --- /dev/null +++ b/python/zingg_v2/proto/connect_plugins_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/ruleset.xml b/ruleset.xml new file mode 100644 index 000000000..71d1534a3 --- /dev/null +++ b/ruleset.xml @@ -0,0 +1,10 @@ + + + lafaspot PMD rules. + + 1 + + 1 + + + diff --git a/scripts/get-spark-connect-local.sh b/scripts/get-spark-connect-local.sh new file mode 100644 index 000000000..43e878102 --- /dev/null +++ b/scripts/get-spark-connect-local.sh @@ -0,0 +1,4 @@ +#!/usr/bin/bash + +wget -q https://archive.apache.org/dist/spark/spark-3.5.1/spark-3.5.1-bin-hadoop3.tgz +tar -xvf spark-3.5.1-bin-hadoop3.tgz diff --git a/scripts/run-spark-connect-local.sh b/scripts/run-spark-connect-local.sh new file mode 100644 index 000000000..87572659a --- /dev/null +++ b/scripts/run-spark-connect-local.sh @@ -0,0 +1,7 @@ +#!/usr/bin/bash + + spark-3.5.1-bin-hadoop3/sbin/start-connect-server.sh --wait \ + --verbose \ + --jars assembly/target/zingg-0.4.0.jar \ + --conf spark.connect.extensions.relation.classes=zingg.spark.connect.ZinggConnectPlugin \ + --packages org.apache.spark:spark-connect_2.12:3.5.1 diff --git a/scripts/zingg.sh b/scripts/zingg.sh index 53b841ad2..2ac64f852 100755 --- a/scripts/zingg.sh +++ b/scripts/zingg.sh @@ -1,6 +1,6 @@ #!/bin/bash #ZINGG_HOME=./assembly/target -ZINGG_JARS=$ZINGG_HOME/zingg-0.4.0.jar +ZINGG_JARS=$ZINGG_HOME/zingg-0.4.1-SNAPSHOT.jar EMAIL=zingg@zingg.ai LICENSE=zinggLicense.txt log4j_setting="-Dlog4j2.configurationFile=file:log4j2.properties" diff --git a/spark/client/pom.xml b/spark/client/pom.xml index 418abd6e8..3b86a0a93 100644 --- a/spark/client/pom.xml +++ b/spark/client/pom.xml @@ -8,8 +8,8 @@ zingg-spark-client jar - 2.12.6 - 2.12.6.1 + 2.15.2 + 2.15.2 @@ -39,14 +39,47 @@ jackson-annotations ${fasterxml.jackson.version} + + zingg + zingg-common-client + ${zingg.version} + tests + test-jar + test + + + net.alchim31.maven + scala-maven-plugin + 4.9.2 + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + ${scala.version} + + org.apache.maven.plugins maven-javadoc-plugin - 2.9.1 + 3.5.0 ${basedir}/src/main/java/zingg/client diff --git a/spark/client/src/main/java/zingg/spark/client/SparkClient.java b/spark/client/src/main/java/zingg/spark/client/SparkClient.java index 211ea279f..ae61414c4 100644 --- a/spark/client/src/main/java/zingg/spark/client/SparkClient.java +++ b/spark/client/src/main/java/zingg/spark/client/SparkClient.java @@ -1,42 +1,39 @@ package zingg.spark.client; -import java.io.Serializable; - import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; import 
zingg.common.client.Client; import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; -import zingg.common.client.IZinggFactory; import zingg.common.client.ZinggClientException; -import zingg.common.client.license.IZinggLicense; - +import zingg.common.client.util.PipeUtilBase; +import zingg.spark.client.util.SparkPipeUtil; /** * This is the main point of interface with the Zingg matching product. * * @author sgoyal * */ -public class SparkClient extends Client, Row, Column, DataType> implements Serializable { +public class SparkClient extends Client, Row, Column, DataType> { private static final long serialVersionUID = 1L; + protected static final String zFactoryClassName = "zingg.spark.core.executor.SparkZFactory"; public SparkClient(IArguments args, ClientOptions options) throws ZinggClientException { - super(args, options); + super(args, options, zFactoryClassName); } - - public SparkClient(IArguments args, ClientOptions options, ZSparkSession s) throws ZinggClientException { - super(args, options, s); - } + + public SparkClient(IArguments args, ClientOptions options, SparkSession s) throws ZinggClientException { - this(args, options, new ZSparkSession(s,null)); + super(args, options, s, zFactoryClassName); } + public SparkClient() { /*SparkSession session = SparkSession @@ -45,18 +42,15 @@ public SparkClient() { .getOrCreate(); JavaSparkContext ctx = JavaSparkContext.fromSparkContext(session.sparkContext()); JavaSparkContext.jarOfClass(IZingg.class); + */ + super(zFactoryClassName); } - @Override - public IZinggFactory getZinggFactory() throws InstantiationException, IllegalAccessException, ClassNotFoundException{ - return (IZinggFactory) Class.forName("zingg.spark.core.executor.SparkZFactory").newInstance(); - } - @Override - public Client, Row, Column, DataType> getClient(IArguments args, + public Client, Row, Column, DataType> getClient(IArguments args, ClientOptions options) throws ZinggClientException { // TODO Auto-generated method stub SparkClient client = null; @@ -77,9 +71,29 @@ public static void main(String... 
args) { } @Override - protected IZinggLicense getLicense(String license) throws ZinggClientException { - return null; + public SparkSession getSession() { + if (session!=null) { + return session; + } else { + SparkSession s = SparkSession + .builder() + .appName("Zingg") + .getOrCreate(); + setSession(s); + return s; + } + } + @Override + public PipeUtilBase, Row, Column> getPipeUtil() { + if (pipeUtil!=null) { + return pipeUtil; + } else { + PipeUtilBase, Row, Column> p = new SparkPipeUtil(session); + setPipeUtil(p); + return p; + } + } } \ No newline at end of file diff --git a/spark/client/src/main/java/zingg/spark/client/SparkFrame.java b/spark/client/src/main/java/zingg/spark/client/SparkFrame.java index 18957ac67..7add25352 100644 --- a/spark/client/src/main/java/zingg/spark/client/SparkFrame.java +++ b/spark/client/src/main/java/zingg/spark/client/SparkFrame.java @@ -87,7 +87,7 @@ public List collectAsList() { return df.collectAsList(); } - public List collectAsListOfStrings() { + public List collectFirstColumn() { return df.as(Encoders.STRING()).collectAsList(); } @@ -219,6 +219,14 @@ public ZFrame, Row, Column> repartition(int nul, Column c){ return new SparkFrame(df.repartition(nul, c)); } + public ZFrame, Row, Column> repartition(int num,scala.collection.Seq partitionExprs){ + return new SparkFrame(df.repartition(num, partitionExprs)); + } + + public ZFrame, Row, Column> repartition(scala.collection.Seq partitionExprs){ + return new SparkFrame(df.repartition(partitionExprs)); + } + @Override public Column gt(String c) { return gt(this,c); @@ -448,7 +456,20 @@ public ZFrame, Row, Column> groupByCount(String groupByCol1, String } + @Override + public ZFrame, Row, Column> intersect(ZFrame, Row, Column> other) { + return new SparkFrame(df.intersect(other.df())); + } + @Override + public Column substr(Column col, int startPos, int len) { + return col.substr(startPos, len); + } + + @Override + public Column gt(Column column1, Column column2) { + return column1.gt(column2); + } } diff --git a/spark/client/src/main/java/zingg/spark/client/ZSparkSession.java b/spark/client/src/main/java/zingg/spark/client/ZSparkSession.java deleted file mode 100644 index 756b7e574..000000000 --- a/spark/client/src/main/java/zingg/spark/client/ZSparkSession.java +++ /dev/null @@ -1,39 +0,0 @@ -package zingg.spark.client; -import org.apache.spark.sql.SparkSession; - -import zingg.common.client.ZSession; -import zingg.common.client.license.IZinggLicense; - -public class ZSparkSession implements ZSession { - - private SparkSession session; - - private IZinggLicense license; - - public ZSparkSession(SparkSession session, IZinggLicense license) { - super(); - this.session = session; - this.license = license; - } - - @Override - public SparkSession getSession() { - return session; - } - - @Override - public void setSession(SparkSession session) { - this.session = session; - } - - @Override - public IZinggLicense getLicense() { - return license; - } - - @Override - public void setLicense(IZinggLicense license) { - this.license = license; - } - -} diff --git a/spark/client/src/main/java/zingg/spark/client/util/RowsFromObjectList.java b/spark/client/src/main/java/zingg/spark/client/util/RowsFromObjectList.java new file mode 100644 index 000000000..cb1a635be --- /dev/null +++ b/spark/client/src/main/java/zingg/spark/client/util/RowsFromObjectList.java @@ -0,0 +1,18 @@ +package zingg.spark.client.util; + +import java.util.List; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; + +import 
zingg.common.client.util.PojoToArrayConverter; + +public class RowsFromObjectList { + + public static Row[] getRows(List t) throws Exception{ + Row[] rows = new Row[t.size()]; + for (int i=0; i < t.size(); ++i){ + rows[i] = RowFactory.create(PojoToArrayConverter.getObjectArray(t.get(i))); + } + return rows; + } +} diff --git a/spark/client/src/main/java/zingg/spark/client/util/SparkDFObjectUtil.java b/spark/client/src/main/java/zingg/spark/client/util/SparkDFObjectUtil.java new file mode 100644 index 000000000..f5d185ae3 --- /dev/null +++ b/spark/client/src/main/java/zingg/spark/client/util/SparkDFObjectUtil.java @@ -0,0 +1,35 @@ +package zingg.spark.client.util; + +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructType; + +import zingg.common.client.ZFrame; +import zingg.common.client.util.DFObjectUtil; +import zingg.common.client.util.IWithSession; +import zingg.spark.client.SparkFrame; + +public class SparkDFObjectUtil extends DFObjectUtil, Row, Column> { + + + public SparkDFObjectUtil(IWithSession withSparkSession) { + super(withSparkSession); + } + + @Override + public ZFrame, Row, Column> getDFFromObjectList(List objList, Class objClass) throws Exception { + if(objList == null || objClass == null) return null; + + SparkStructTypeFromPojoClass stpc = new SparkStructTypeFromPojoClass(); + + List rows = Arrays.asList(RowsFromObjectList.getRows(objList)); + StructType structType = stpc.getStructType(objClass); + return new SparkFrame(iWithSession.getSession().createDataFrame(rows, structType)); + } + +} diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkDFReader.java b/spark/client/src/main/java/zingg/spark/client/util/SparkDFReader.java similarity index 84% rename from spark/core/src/main/java/zingg/spark/core/util/SparkDFReader.java rename to spark/client/src/main/java/zingg/spark/client/util/SparkDFReader.java index 127d29e2a..cf67ff1f6 100644 --- a/spark/core/src/main/java/zingg/spark/core/util/SparkDFReader.java +++ b/spark/client/src/main/java/zingg/spark/client/util/SparkDFReader.java @@ -1,4 +1,4 @@ -package zingg.spark.core.util; +package zingg.spark.client.util; import org.apache.spark.sql.Column; import org.apache.spark.sql.DataFrameReader; @@ -8,18 +8,18 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.core.util.DFReader; +import zingg.common.client.util.DFReader; import zingg.spark.client.SparkFrame; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; public class SparkDFReader implements DFReader, Row, Column> { - private ZSparkSession session; + private SparkSession session; private DataFrameReader reader; - public SparkDFReader(ZSparkSession s) { + public SparkDFReader(SparkSession s) { this.session = s; - this.reader = s.getSession().read(); + this.reader = s.read(); } public DFReader, Row, Column> getReader() { diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkDFWriter.java b/spark/client/src/main/java/zingg/spark/client/util/SparkDFWriter.java similarity index 92% rename from spark/core/src/main/java/zingg/spark/core/util/SparkDFWriter.java rename to spark/client/src/main/java/zingg/spark/client/util/SparkDFWriter.java index 714c022f0..023a90fb6 100644 --- a/spark/core/src/main/java/zingg/spark/core/util/SparkDFWriter.java +++ 
b/spark/client/src/main/java/zingg/spark/client/util/SparkDFWriter.java
@@ -1,4 +1,4 @@
-package zingg.spark.core.util;
+package zingg.spark.client.util;
 
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.DataFrameWriter;
@@ -7,7 +7,7 @@ import org.apache.spark.sql.SaveMode;
 
 import zingg.common.client.ZFrame;
-import zingg.common.core.util.DFWriter;
+import zingg.common.client.util.DFWriter;
 
 public class SparkDFWriter implements DFWriter<Dataset<Row>, Row, Column>{
 	private DataFrameWriter writer;
diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkDSUtil.java b/spark/client/src/main/java/zingg/spark/client/util/SparkDSUtil.java
similarity index 66%
rename from spark/core/src/main/java/zingg/spark/core/util/SparkDSUtil.java
rename to spark/client/src/main/java/zingg/spark/client/util/SparkDSUtil.java
index 1ed2cf265..ec7df7128 100644
--- a/spark/core/src/main/java/zingg/spark/core/util/SparkDSUtil.java
+++ b/spark/client/src/main/java/zingg/spark/client/util/SparkDSUtil.java
@@ -1,4 +1,4 @@
-package zingg.spark.core.util;
+package zingg.spark.client.util;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -7,14 +7,14 @@ import org.apache.spark.sql.Row;
 
 import zingg.common.client.ZFrame;
-import zingg.common.core.util.DSUtil;
+import zingg.common.client.util.DSUtil;
 import zingg.scala.DFUtil;
 import zingg.spark.client.SparkFrame;
-import zingg.spark.client.ZSparkSession;
+import org.apache.spark.sql.SparkSession;
 
-public class SparkDSUtil extends DSUtil<ZSparkSession, Dataset<Row>, Row, Column>{
+public class SparkDSUtil extends DSUtil<SparkSession, Dataset<Row>, Row, Column>{
 
-	public SparkDSUtil(ZSparkSession s) {
+	public SparkDSUtil(SparkSession s) {
 		super(s);
 		//TODO Auto-generated constructor stub
 	}
@@ -28,16 +28,16 @@ public SparkDSUtil(ZSparkSession s) {
 
 	@Override
 	public ZFrame<Dataset<Row>, Row, Column> addClusterRowNumber(ZFrame<Dataset<Row>, Row, Column> ds) {
-		ZSparkSession zSparkSession = getSession();
-		return new SparkFrame(DFUtil.addClusterRowNumber(((Dataset<Row>)ds.df()), zSparkSession.getSession()));
+		SparkSession sparkSession = getSession();
+		return new SparkFrame(DFUtil.addClusterRowNumber(((Dataset<Row>)ds.df()), sparkSession));
 	}
 
 	@Override
 	public ZFrame<Dataset<Row>, Row, Column> addRowNumber(ZFrame<Dataset<Row>, Row, Column> ds) {
-		ZSparkSession zSparkSession = getSession();
-		return new SparkFrame(DFUtil.addRowNumber(((Dataset<Row>)ds.df()), zSparkSession.getSession()));
+		SparkSession sparkSession = getSession();
+		return new SparkFrame(DFUtil.addRowNumber(((Dataset<Row>)ds.df()), sparkSession));
 	}
diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkPipeUtil.java b/spark/client/src/main/java/zingg/spark/client/util/SparkPipeUtil.java
similarity index 77%
rename from spark/core/src/main/java/zingg/spark/core/util/SparkPipeUtil.java
rename to spark/client/src/main/java/zingg/spark/client/util/SparkPipeUtil.java
index 999f549c9..51530e7d3 100644
--- a/spark/core/src/main/java/zingg/spark/core/util/SparkPipeUtil.java
+++ b/spark/client/src/main/java/zingg/spark/client/util/SparkPipeUtil.java
@@ -1,4 +1,4 @@
-package zingg.spark.core.util;
+package zingg.spark.client.util;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -10,33 +10,33 @@ import zingg.common.client.ZFrame;
 //import zingg.common.client.pipe.InMemoryPipe;
 import zingg.common.client.pipe.Pipe;
-import zingg.common.core.util.DFReader;
-import zingg.common.core.util.DFWriter;
-import zingg.common.core.util.PipeUtil;
+import zingg.common.client.util.DFReader;
+import zingg.common.client.util.DFWriter;
+import 
zingg.common.client.util.PipeUtil; import zingg.spark.client.SparkFrame; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; //import com.datastax.spark.connector.cql.*; //import org.elasticsearch.spark.sql.api.java.JavaEsSparkSQL; //import zingg.scala.DFUtil; -public class SparkPipeUtil extends PipeUtil, Row, Column>{ +public class SparkPipeUtil extends PipeUtil, Row, Column>{ public final Log LOG = LogFactory.getLog(SparkPipeUtil.class); //private SparkDFReader reader; - public SparkPipeUtil(ZSparkSession spark) { + public SparkPipeUtil(SparkSession spark) { super(spark); } - public ZSparkSession getSession(){ + public SparkSession getSession(){ return this.session; } - public void setSession(ZSparkSession session){ + public void setSession(SparkSession session){ this.session = session; } diff --git a/spark/client/src/main/java/zingg/spark/client/util/SparkStructTypeFromPojoClass.java b/spark/client/src/main/java/zingg/spark/client/util/SparkStructTypeFromPojoClass.java new file mode 100644 index 000000000..3032907f4 --- /dev/null +++ b/spark/client/src/main/java/zingg/spark/client/util/SparkStructTypeFromPojoClass.java @@ -0,0 +1,48 @@ +package zingg.spark.client.util; + +import java.lang.reflect.Field; +import java.security.NoSuchAlgorithmException; +import java.util.List; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import zingg.common.client.util.StructTypeFromPojoClass; + +public class SparkStructTypeFromPojoClass extends StructTypeFromPojoClass { + + public StructType getStructType(Class objClass) + throws NoSuchAlgorithmException, IllegalArgumentException, IllegalAccessException { + List structFields = getFields(objClass); + return new StructType(structFields.toArray(new StructField[structFields.size()])); + } + + public StructField getStructField(Field field) { + field.setAccessible(true); + return new StructField(field.getName(), getSFType(field.getType()), true, Metadata.empty()); + } + + public DataType getSFType(Class t) { + if (t.getCanonicalName().contains("String")) { + return DataTypes.StringType; + } else if (t.getCanonicalName().contains("Integer")) { + return DataTypes.IntegerType; + } else if (t.getCanonicalName().contains("Long")) { + return DataTypes.LongType; + } else if (t.getCanonicalName().contains("Float")) { + return DataTypes.FloatType; + } else if (t.getCanonicalName().contains("Double")) { + return DataTypes.DoubleType; + } else if (t.getCanonicalName().contains("Date")) { + return DataTypes.DateType; + } else if (t.getCanonicalName().contains("Timestamp")) { + return DataTypes.TimestampType; + } + + return null; + } + +} \ No newline at end of file diff --git a/spark/client/src/main/java/zingg/spark/connect/ZinggConnectPlugin.java b/spark/client/src/main/java/zingg/spark/connect/ZinggConnectPlugin.java new file mode 100644 index 000000000..5a60df46c --- /dev/null +++ b/spark/client/src/main/java/zingg/spark/connect/ZinggConnectPlugin.java @@ -0,0 +1,222 @@ +package zingg.spark.connect; + +import java.util.List; +import java.util.Map; + +import com.google.protobuf.Any; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import 
org.apache.spark.sql.connect.planner.SparkConnectPlanner; +import org.apache.spark.sql.connect.plugin.RelationPlugin; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import scala.Option; +import zingg.common.client.Arguments; +import zingg.common.client.ArgumentsUtil; +import zingg.common.client.ClientOptions; +import zingg.common.client.FieldDefinition; +import zingg.common.client.IArguments; +import zingg.common.client.pipe.Pipe; +import zingg.spark.client.pipe.SparkPipe; +import zingg.spark.connect.proto.DataFormat; +import zingg.spark.connect.proto.MatchType; +import zingg.spark.connect.proto.SubmitZinggJob; + +public class ZinggConnectPlugin implements RelationPlugin { + private SparkPipe parsePipe(zingg.spark.connect.proto.Pipe protoPipe) { + SparkPipe sparkPipe = new SparkPipe(); + sparkPipe.setName(protoPipe.getName()); + + // Parse DataFormat from proto + DataFormat dataFormatProto = protoPipe.getFormat(); + if (dataFormatProto == DataFormat.DF_AVRO) { + sparkPipe.setFormat(Pipe.FORMAT_AVRO); + } else if (dataFormatProto == DataFormat.DF_BIGQUEY) { + sparkPipe.setFormat(Pipe.FORMAT_BIGQUERY); + } else if (dataFormatProto == DataFormat.DF_CASSANDRA) { + sparkPipe.setFormat(Pipe.FORMAT_CASSANDRA); + } else if (dataFormatProto == DataFormat.DF_CSV) { + sparkPipe.setFormat(Pipe.FORMAT_CSV); + } else if (dataFormatProto == DataFormat.DF_ELASTIC) { + sparkPipe.setFormat(Pipe.FORMAT_ELASTIC); + } else if (dataFormatProto == DataFormat.DF_EXACOL) { + sparkPipe.setFormat(Pipe.FORMAT_EXASOL); + } else if (dataFormatProto == DataFormat.DF_INMEMORY) { + sparkPipe.setFormat(Pipe.FORMAT_INMEMORY); + } else if (dataFormatProto == DataFormat.DF_JDBC) { + sparkPipe.setFormat(Pipe.FORMAT_JDBC); + } else if (dataFormatProto == DataFormat.DF_JSON) { + sparkPipe.setFormat(Pipe.FORMAT_JSON); + } else if (dataFormatProto == DataFormat.DF_PARQUET) { + sparkPipe.setFormat(Pipe.FORMAT_PARQUET); + } else if (dataFormatProto == DataFormat.DF_SNOWFLAKE) { + sparkPipe.setFormat(Pipe.FORMAT_SNOWFLAKE); + } else if (dataFormatProto == DataFormat.DF_TEXT) { + sparkPipe.setFormat(Pipe.FORMAT_TEXT); + } else if (dataFormatProto == DataFormat.DF_XLS) { + sparkPipe.setFormat(Pipe.FORMAT_XLS); + } else { + throw new RuntimeException(String.format("Unknown format %s", dataFormatProto.name())); + } + + // Parse tags + for (Map.Entry kv : protoPipe.getPropsMap().entrySet()) { + sparkPipe.setProp(kv.getKey(), kv.getValue()); + } + + if (protoPipe.hasSchemaField()) { + sparkPipe.setSchema(protoPipe.getSchemaField()); + } + + if (protoPipe.hasMode()) { + sparkPipe.setMode(protoPipe.getMode()); + } + + return sparkPipe; + } + + private SparkPipe[] parsePipes(List protoPipes) { + return protoPipes.stream().map(protoPipe -> parsePipe(protoPipe)).toArray(SparkPipe[]::new); + } + + private FieldDefinition parseFieldDefinition(zingg.spark.connect.proto.FieldDefinition fieldDefinitionProto) { + FieldDefinition fieldDefinition = new FieldDefinition(); + fieldDefinition.setMatchType(fieldDefinitionProto.getMatchTypeList().stream().map(mt -> { + if (mt == MatchType.MT_FUZZY) { + return zingg.common.client.MatchType.FUZZY; + } else if (mt == MatchType.MT_EXACT) { + return zingg.common.client.MatchType.EXACT; + } else if (mt == MatchType.MT_DONT_USE) { + return zingg.common.client.MatchType.DONT_USE; + } else if (mt == MatchType.MT_EMAIL) { + return zingg.common.client.MatchType.EMAIL; + } else if (mt == MatchType.MT_PINCODE) { + return 
zingg.common.client.MatchType.PINCODE;
+            } else if (mt == MatchType.MT_NULL_OR_BLANK) {
+                return zingg.common.client.MatchType.NULL_OR_BLANK;
+            } else if (mt == MatchType.MT_TEXT) {
+                return zingg.common.client.MatchType.TEXT;
+            } else if (mt == MatchType.MT_NUMERIC) {
+                return zingg.common.client.MatchType.NUMERIC;
+            } else if (mt == MatchType.MT_NUMERIC_WITH_UNITS) {
+                return zingg.common.client.MatchType.NUMERIC_WITH_UNITS;
+            } else if (mt == MatchType.MT_ONLY_ALPHABETS_EXACT) {
+                return zingg.common.client.MatchType.ONLY_ALPHABETS_EXACT;
+            } else if (mt == MatchType.MT_ONLY_ALPHABETS_FUZZY) {
+                return zingg.common.client.MatchType.ONLY_ALPHABETS_FUZZY;
+            } else {
+                throw new RuntimeException(String.format("Unknown type %s", mt.name()));
+            }
+        }).toList());
+
+        fieldDefinition.setDataType(fieldDefinitionProto.getDataType());
+        fieldDefinition.setFieldName(fieldDefinitionProto.getFieldName());
+        fieldDefinition.setFields(fieldDefinitionProto.getFields());
+        if (fieldDefinitionProto.hasStopWords()) {
+            fieldDefinition.setStopWords(fieldDefinitionProto.getStopWords());
+        }
+        if (fieldDefinitionProto.hasAbbreviations()) {
+            fieldDefinition.setAbbreviations(fieldDefinitionProto.getAbbreviations());
+        }
+
+        return fieldDefinition;
+    }
+
+    // 3.5.2 behaviour
+    // Because of shading rules this method may be marked as wrongly overridden
+    @Override
+    public Option<LogicalPlan> transform(Any relation, SparkConnectPlanner planner) {
+        if (relation.is(SubmitZinggJob.class)) {
+            SubmitZinggJob zinggJobProto = relation.unpack(SubmitZinggJob.class);
+            // It is expected that the session exists!
+            SparkSession spark = planner.sessionHolder().session();
+            IArguments arguments = new Arguments();
+            // Parse arguments
+
+            // Output pipes
+            arguments.setOutput(parsePipes(zinggJobProto.getArgumnets().getOutputList()));
+            // Data pipes
+            arguments.setData(parsePipes(zinggJobProto.getArgumnets().getDataList()));
+            // Training samples
+            arguments.setTrainingSamples(parsePipes(zinggJobProto.getArgumnets().getTrainingSamplesList()));
+            // Field definitions
+            arguments.setFieldDefinition(zinggJobProto.getArgumnets().getFiieldDefinitionList().stream()
+                    .map(fd -> parseFieldDefinition(fd)).toList());
+
+            // Arguments
+            arguments.setZinggDir(zinggJobProto.getArgumnets().getZinggDir());
+            arguments.setNumPartitions(zinggJobProto.getArgumnets().getNumPartitions());
+            arguments.setLabelDataSampleSize(zinggJobProto.getArgumnets().getLabelDataSampleSize());
+            arguments.setModelId(zinggJobProto.getArgumnets().getModelId());
+            arguments.setThreshold(zinggJobProto.getArgumnets().getThreshold());
+            arguments.setJobId(zinggJobProto.getArgumnets().getJobId());
+            arguments.setCollectMetrics(zinggJobProto.getArgumnets().getCollectMetrics());
+            arguments.setShowConcise(zinggJobProto.getArgumnets().getShowConcise());
+            arguments.setStopWordsCutoff(zinggJobProto.getArgumnets().getStopWordsCutoff());
+            arguments.setBlockSize(zinggJobProto.getArgumnets().getBlockSize());
+            if (zinggJobProto.getArgumnets().hasColumn()) {
+                arguments.setColumn(zinggJobProto.getArgumnets().getColumn());
+            }
+
+            // Options
+            zingg.spark.connect.proto.ClientOptions clientOptionsProto = zinggJobProto.getCliOptions();
+            ClientOptions clientOptions = new ClientOptions();
+
+            if (clientOptionsProto.hasPhase()) {
+                clientOptions.setOptionValue(ClientOptions.PHASE, clientOptionsProto.getPhase());
+            }
+            if (clientOptionsProto.hasLicense()) {
+                clientOptions.setOptionValue(ClientOptions.LICENSE, clientOptionsProto.getLicense());
+            }
+            if (clientOptionsProto.hasEmail()) {
+                clientOptions.setOptionValue(ClientOptions.EMAIL, clientOptionsProto.getEmail());
+            }
+            if (clientOptionsProto.hasConf()) {
+                clientOptions.setOptionValue(ClientOptions.CONF, clientOptionsProto.getConf());
+            }
+            if (clientOptionsProto.hasPreprocess()) {
+                clientOptions.setOptionValue(ClientOptions.PREPROCESS, clientOptionsProto.getPreprocess());
+            }
+            if (clientOptionsProto.hasJobId()) {
+                clientOptions.setOptionValue(ClientOptions.JOBID, clientOptionsProto.getJobId());
+            }
+            if (clientOptionsProto.hasFormat()) {
+                clientOptions.setOptionValue(ClientOptions.FORMAT, clientOptionsProto.getFormat());
+            }
+            if (clientOptionsProto.hasZinggDir()) {
+                clientOptions.setOptionValue(ClientOptions.ZINGG_DIR, clientOptionsProto.getZinggDir());
+            }
+            if (clientOptionsProto.hasModelId()) {
+                clientOptions.setOptionValue(ClientOptions.MODEL_ID, clientOptionsProto.getModelId());
+            }
+            if (clientOptionsProto.hasCollectMetrics()) {
+                clientOptions.setOptionValue(ClientOptions.COLLECT_METRICS, clientOptionsProto.getCollectMetrics());
+            }
+            if (clientOptionsProto.hasShowConcise()) {
+                clientOptions.setOptionValue(ClientOptions.SHOW_CONCISE, clientOptionsProto.getShowConcise());
+            }
+            if (clientOptionsProto.hasLocation()) {
+                clientOptions.setOptionValue(ClientOptions.LOCATION, clientOptionsProto.getLocation());
+            }
+            if (clientOptionsProto.hasColumn()) {
+                clientOptions.setOptionValue(ClientOptions.COLUMN, clientOptionsProto.getColumn());
+            }
+            if (clientOptionsProto.hasRemote()) {
+                clientOptions.setOptionValue(ClientOptions.REMOTE, clientOptionsProto.getRemote());
+            }
+
+            Dataset<Row> outDF = spark.createDataFrame(
+                    List.of(RowFactory.create(new ArgumentsUtil().writeArgumentstoJSONString(arguments),
+                            String.join(" ", clientOptions.getCommandLineArgs()))),
+                    new StructType(new StructField[] { DataTypes.createStructField("args", DataTypes.StringType, false),
+                            DataTypes.createStructField("cliopts", DataTypes.StringType, false) }));
+            return Option.apply(outDF.logicalPlan());
+        }
+        // The relation is not a SubmitZinggJob; defer to other plugins.
+        return Option.empty();
+    }
+}
diff --git a/spark/client/src/main/java/zingg/spark/connect/proto/ConnectPlugins.java b/spark/client/src/main/java/zingg/spark/connect/proto/ConnectPlugins.java
new file mode 100644
index 000000000..07d04e315
--- /dev/null
+++ b/spark/client/src/main/java/zingg/spark/connect/proto/ConnectPlugins.java
@@ -0,0 +1,161 @@
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: connect_plugins.proto + +// Protobuf Java Version: 3.25.3 +package zingg.spark.connect.proto; + +public final class ConnectPlugins { + private ConnectPlugins() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistryLite registry) { + } + + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions( + (com.google.protobuf.ExtensionRegistryLite) registry); + } + static final com.google.protobuf.Descriptors.Descriptor + internal_static_SubmitZinggJob_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_SubmitZinggJob_fieldAccessorTable; + static final com.google.protobuf.Descriptors.Descriptor + internal_static_FieldDefinition_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_FieldDefinition_fieldAccessorTable; + static final com.google.protobuf.Descriptors.Descriptor + internal_static_Pipe_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_Pipe_fieldAccessorTable; + static final com.google.protobuf.Descriptors.Descriptor + internal_static_Pipe_PropsEntry_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_Pipe_PropsEntry_fieldAccessorTable; + static final com.google.protobuf.Descriptors.Descriptor + internal_static_Arguments_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_Arguments_fieldAccessorTable; + static final com.google.protobuf.Descriptors.Descriptor + internal_static_ClientOptions_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_ClientOptions_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\025connect_plugins.proto\"\251\001\n\016SubmitZinggJ" + + "ob\022(\n\targumnets\030\001 \001(\0132\n.ArgumentsR\targum" + + "nets\022/\n\013cli_options\030\002 \001(\0132\016.ClientOption" + + "sR\ncliOptions\022)\n\016in_memory_date\030\003 \001(\014H\000R" + + "\014inMemoryDate\210\001\001B\021\n\017_in_memory_date\"\200\002\n\017" + + "FieldDefinition\022)\n\nmatch_type\030\001 \003(\0162\n.Ma" + + "tchTypeR\tmatchType\022\033\n\tdata_type\030\002 \001(\tR\010d" + + "ataType\022\035\n\nfield_name\030\003 \001(\tR\tfieldName\022\026" + + "\n\006fields\030\004 \001(\tR\006fields\022\"\n\nstop_words\030\005 \001" + + "(\tH\000R\tstopWords\210\001\001\022)\n\rabbreviations\030\006 \001(" + + "\tH\001R\rabbreviations\210\001\001B\r\n\013_stop_wordsB\020\n\016" + + "_abbreviations\"\374\001\n\004Pipe\022\022\n\004name\030\001 \001(\tR\004n" + + "ame\022#\n\006format\030\002 \001(\0162\013.DataFormatR\006format" + + "\022&\n\005props\030\003 \003(\0132\020.Pipe.PropsEntryR\005props" + + "\022&\n\014schema_field\030\004 \001(\tH\000R\013schemaField\210\001\001" + + "\022\027\n\004mode\030\005 \001(\tH\001R\004mode\210\001\001\0328\n\nPropsEntry\022" + + "\020\n\003key\030\001 \001(\tR\003key\022\024\n\005value\030\002 \001(\tR\005value:" + + "\0028\001B\017\n\r_schema_fieldB\007\n\005_mode\"\276\004\n\tArgume" + + "nts\022\035\n\006output\030\001 \003(\0132\005.PipeR\006output\022\031\n\004da" + + "ta\030\002 \003(\0132\005.PipeR\004data\022\033\n\tzingg_dir\030\003 
\001(\t" + + "R\010zinggDir\0220\n\020training_samples\030\004 \003(\0132\005.P" + + "ipeR\017trainingSamples\022=\n\021fiield_definitio" + + "n\030\005 \003(\0132\020.FieldDefinitionR\020fiieldDefinit" + + "ion\022%\n\016num_partitions\030\006 \001(\005R\rnumPartitio" + + "ns\0223\n\026label_data_sample_size\030\007 \001(\002R\023labe" + + "lDataSampleSize\022\031\n\010model_id\030\010 \001(\tR\007model" + + "Id\022\034\n\tthreshold\030\t \001(\002R\tthreshold\022\025\n\006job_" + + "id\030\n \001(\005R\005jobId\022\'\n\017collect_metrics\030\013 \001(\010" + + "R\016collectMetrics\022!\n\014show_concise\030\014 \001(\010R\013" + + "showConcise\022*\n\021stop_words_cutoff\030\r \001(\002R\017" + + "stopWordsCutoff\022\035\n\nblock_size\030\016 \001(\003R\tblo" + + "ckSize\022\033\n\006column\030\017 \001(\tH\000R\006column\210\001\001B\t\n\007_" + + "column\"\377\004\n\rClientOptions\022\031\n\005phase\030\001 \001(\tH" + + "\000R\005phase\210\001\001\022\035\n\007license\030\002 \001(\tH\001R\007license\210" + + "\001\001\022\031\n\005email\030\003 \001(\tH\002R\005email\210\001\001\022\027\n\004conf\030\004 " + + "\001(\tH\003R\004conf\210\001\001\022#\n\npreprocess\030\005 \001(\tH\004R\npr" + + "eprocess\210\001\001\022\032\n\006job_id\030\006 \001(\tH\005R\005jobId\210\001\001\022" + + "\033\n\006format\030\007 \001(\tH\006R\006format\210\001\001\022 \n\tzingg_di" + + "r\030\010 \001(\tH\007R\010zinggDir\210\001\001\022\036\n\010model_id\030\t \001(\t" + + "H\010R\007modelId\210\001\001\022,\n\017collect_metrics\030\n \001(\tH" + + "\tR\016collectMetrics\210\001\001\022&\n\014show_concise\030\013 \001" + + "(\tH\nR\013showConcise\210\001\001\022\037\n\010location\030\014 \001(\tH\013" + + "R\010location\210\001\001\022\033\n\006column\030\r \001(\tH\014R\006column\210" + + "\001\001\022\033\n\006remote\030\016 \001(\tH\rR\006remote\210\001\001B\010\n\006_phas" + + "eB\n\n\010_licenseB\010\n\006_emailB\007\n\005_confB\r\n\013_pre" + + "processB\t\n\007_job_idB\t\n\007_formatB\014\n\n_zingg_" + + "dirB\013\n\t_model_idB\022\n\020_collect_metricsB\017\n\r" + + "_show_conciseB\013\n\t_locationB\t\n\007_columnB\t\n" + + "\007_remote*\336\001\n\tMatchType\022\014\n\010MT_FUZZY\020\000\022\014\n\010" + + "MT_EXACT\020\001\022\017\n\013MT_DONT_USE\020\002\022\014\n\010MT_EMAIL\020" + + "\003\022\016\n\nMT_PINCODE\020\004\022\024\n\020MT_NULL_OR_BLANK\020\005\022" + + "\013\n\007MT_TEXT\020\006\022\016\n\nMT_NUMERIC\020\007\022\031\n\025MT_NUMER" + + "IC_WITH_UNITS\020\010\022\033\n\027MT_ONLY_ALPHABETS_EXA" + + "CT\020\t\022\033\n\027MT_ONLY_ALPHABETS_FUZZY\020\n*\314\001\n\nDa" + + "taFormat\022\n\n\006DF_CSV\020\000\022\016\n\nDF_PARQUET\020\001\022\013\n\007" + + "DF_JSON\020\002\022\013\n\007DF_TEXT\020\003\022\n\n\006DF_XLS\020\004\022\013\n\007DF" + + "_AVRO\020\005\022\013\n\007DF_JDBC\020\006\022\020\n\014DF_CASSANDRA\020\007\022\020" + + "\n\014DF_SNOWFLAKE\020\010\022\016\n\nDF_ELASTIC\020\t\022\r\n\tDF_E" + + "XACOL\020\n\022\016\n\nDF_BIGQUEY\020\013\022\017\n\013DF_INMEMORY\020\014" + + "B\035\n\031zingg.spark.connect.protoP\001b\006proto3" + }; + descriptor = com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + }); + internal_static_SubmitZinggJob_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_SubmitZinggJob_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_SubmitZinggJob_descriptor, + new 
java.lang.String[] { "Argumnets", "CliOptions", "InMemoryDate", }); + internal_static_FieldDefinition_descriptor = + getDescriptor().getMessageTypes().get(1); + internal_static_FieldDefinition_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_FieldDefinition_descriptor, + new java.lang.String[] { "MatchType", "DataType", "FieldName", "Fields", "StopWords", "Abbreviations", }); + internal_static_Pipe_descriptor = + getDescriptor().getMessageTypes().get(2); + internal_static_Pipe_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_Pipe_descriptor, + new java.lang.String[] { "Name", "Format", "Props", "SchemaField", "Mode", }); + internal_static_Pipe_PropsEntry_descriptor = + internal_static_Pipe_descriptor.getNestedTypes().get(0); + internal_static_Pipe_PropsEntry_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_Pipe_PropsEntry_descriptor, + new java.lang.String[] { "Key", "Value", }); + internal_static_Arguments_descriptor = + getDescriptor().getMessageTypes().get(3); + internal_static_Arguments_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_Arguments_descriptor, + new java.lang.String[] { "Output", "Data", "ZinggDir", "TrainingSamples", "FiieldDefinition", "NumPartitions", "LabelDataSampleSize", "ModelId", "Threshold", "JobId", "CollectMetrics", "ShowConcise", "StopWordsCutoff", "BlockSize", "Column", }); + internal_static_ClientOptions_descriptor = + getDescriptor().getMessageTypes().get(4); + internal_static_ClientOptions_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_ClientOptions_descriptor, + new java.lang.String[] { "Phase", "License", "Email", "Conf", "Preprocess", "JobId", "Format", "ZinggDir", "ModelId", "CollectMetrics", "ShowConcise", "Location", "Column", "Remote", }); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/spark/client/src/main/java/zingg/spark/connect/proto/SubmitZinggJob.java b/spark/client/src/main/java/zingg/spark/connect/proto/SubmitZinggJob.java new file mode 100644 index 000000000..fdf0377fc --- /dev/null +++ b/spark/client/src/main/java/zingg/spark/connect/proto/SubmitZinggJob.java @@ -0,0 +1,897 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: connect_plugins.proto + +// Protobuf Java Version: 3.25.3 +package zingg.spark.connect.proto; + +/** + * Protobuf type {@code SubmitZinggJob} + */ +public final class SubmitZinggJob extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:SubmitZinggJob) + SubmitZinggJobOrBuilder { +private static final long serialVersionUID = 0L; + // Use SubmitZinggJob.newBuilder() to construct. 
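  // A minimal construction sketch, not part of the generated source; the setter
  // names (including the "argumnets" spelling) follow connect_plugins.proto, and
  // the values are placeholders:
  //
  //   SubmitZinggJob job = SubmitZinggJob.newBuilder()
  //       .setArgumnets(Arguments.newBuilder().setModelId("100").build())
  //       .setCliOptions(ClientOptions.newBuilder().setPhase("match").build())
  //       .build();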
+ private SubmitZinggJob(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private SubmitZinggJob() { + inMemoryDate_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SubmitZinggJob(); + } + + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return zingg.spark.connect.proto.ConnectPlugins.internal_static_SubmitZinggJob_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return zingg.spark.connect.proto.ConnectPlugins.internal_static_SubmitZinggJob_fieldAccessorTable + .ensureFieldAccessorsInitialized( + zingg.spark.connect.proto.SubmitZinggJob.class, zingg.spark.connect.proto.SubmitZinggJob.Builder.class); + } + + private int bitField0_; + public static final int ARGUMNETS_FIELD_NUMBER = 1; + private zingg.spark.connect.proto.Arguments argumnets_; + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + * @return Whether the argumnets field is set. + */ + @java.lang.Override + public boolean hasArgumnets() { + return ((bitField0_ & 0x00000001) != 0); + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + * @return The argumnets. + */ + @java.lang.Override + public zingg.spark.connect.proto.Arguments getArgumnets() { + return argumnets_ == null ? zingg.spark.connect.proto.Arguments.getDefaultInstance() : argumnets_; + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + @java.lang.Override + public zingg.spark.connect.proto.ArgumentsOrBuilder getArgumnetsOrBuilder() { + return argumnets_ == null ? zingg.spark.connect.proto.Arguments.getDefaultInstance() : argumnets_; + } + + public static final int CLI_OPTIONS_FIELD_NUMBER = 2; + private zingg.spark.connect.proto.ClientOptions cliOptions_; + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + * @return Whether the cliOptions field is set. + */ + @java.lang.Override + public boolean hasCliOptions() { + return ((bitField0_ & 0x00000002) != 0); + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + * @return The cliOptions. + */ + @java.lang.Override + public zingg.spark.connect.proto.ClientOptions getCliOptions() { + return cliOptions_ == null ? zingg.spark.connect.proto.ClientOptions.getDefaultInstance() : cliOptions_; + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + @java.lang.Override + public zingg.spark.connect.proto.ClientOptionsOrBuilder getCliOptionsOrBuilder() { + return cliOptions_ == null ? zingg.spark.connect.proto.ClientOptions.getDefaultInstance() : cliOptions_; + } + + public static final int IN_MEMORY_DATE_FIELD_NUMBER = 3; + private com.google.protobuf.ByteString inMemoryDate_ = com.google.protobuf.ByteString.EMPTY; + /** + *
<pre>
+   * The next message is a serialized LogicalPlan
+   * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return Whether the inMemoryDate field is set. + */ + @java.lang.Override + public boolean hasInMemoryDate() { + return ((bitField0_ & 0x00000004) != 0); + } + /** + *
<pre>
+   * The next message is a serialized LogicalPlan
+   * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return The inMemoryDate. + */ + @java.lang.Override + public com.google.protobuf.ByteString getInMemoryDate() { + return inMemoryDate_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) != 0)) { + output.writeMessage(1, getArgumnets()); + } + if (((bitField0_ & 0x00000002) != 0)) { + output.writeMessage(2, getCliOptions()); + } + if (((bitField0_ & 0x00000004) != 0)) { + output.writeBytes(3, inMemoryDate_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) != 0)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(1, getArgumnets()); + } + if (((bitField0_ & 0x00000002) != 0)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, getCliOptions()); + } + if (((bitField0_ & 0x00000004) != 0)) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(3, inMemoryDate_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof zingg.spark.connect.proto.SubmitZinggJob)) { + return super.equals(obj); + } + zingg.spark.connect.proto.SubmitZinggJob other = (zingg.spark.connect.proto.SubmitZinggJob) obj; + + if (hasArgumnets() != other.hasArgumnets()) return false; + if (hasArgumnets()) { + if (!getArgumnets() + .equals(other.getArgumnets())) return false; + } + if (hasCliOptions() != other.hasCliOptions()) return false; + if (hasCliOptions()) { + if (!getCliOptions() + .equals(other.getCliOptions())) return false; + } + if (hasInMemoryDate() != other.hasInMemoryDate()) return false; + if (hasInMemoryDate()) { + if (!getInMemoryDate() + .equals(other.getInMemoryDate())) return false; + } + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasArgumnets()) { + hash = (37 * hash) + ARGUMNETS_FIELD_NUMBER; + hash = (53 * hash) + getArgumnets().hashCode(); + } + if (hasCliOptions()) { + hash = (37 * hash) + CLI_OPTIONS_FIELD_NUMBER; + hash = (53 * hash) + getCliOptions().hashCode(); + } + if (hasInMemoryDate()) { + hash = (37 * hash) + IN_MEMORY_DATE_FIELD_NUMBER; + hash = (53 * hash) + getInMemoryDate().hashCode(); + } + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + 
return PARSER.parseFrom(data, extensionRegistry); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + public static zingg.spark.connect.proto.SubmitZinggJob parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + + public static zingg.spark.connect.proto.SubmitZinggJob parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static zingg.spark.connect.proto.SubmitZinggJob parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(zingg.spark.connect.proto.SubmitZinggJob prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code SubmitZinggJob} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:SubmitZinggJob) + zingg.spark.connect.proto.SubmitZinggJobOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return zingg.spark.connect.proto.ConnectPlugins.internal_static_SubmitZinggJob_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return zingg.spark.connect.proto.ConnectPlugins.internal_static_SubmitZinggJob_fieldAccessorTable + .ensureFieldAccessorsInitialized( + zingg.spark.connect.proto.SubmitZinggJob.class, zingg.spark.connect.proto.SubmitZinggJob.Builder.class); + } + + // Construct using zingg.spark.connect.proto.SubmitZinggJob.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + getArgumnetsFieldBuilder(); + getCliOptionsFieldBuilder(); + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + bitField0_ = 0; + argumnets_ = null; + if (argumnetsBuilder_ != null) { + argumnetsBuilder_.dispose(); + argumnetsBuilder_ = null; + } + cliOptions_ = null; + if (cliOptionsBuilder_ != null) { + cliOptionsBuilder_.dispose(); + cliOptionsBuilder_ = null; + } + inMemoryDate_ = com.google.protobuf.ByteString.EMPTY; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return zingg.spark.connect.proto.ConnectPlugins.internal_static_SubmitZinggJob_descriptor; + } + + @java.lang.Override + public zingg.spark.connect.proto.SubmitZinggJob getDefaultInstanceForType() { + return zingg.spark.connect.proto.SubmitZinggJob.getDefaultInstance(); + } + + @java.lang.Override + public zingg.spark.connect.proto.SubmitZinggJob build() { + zingg.spark.connect.proto.SubmitZinggJob result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public zingg.spark.connect.proto.SubmitZinggJob buildPartial() { + zingg.spark.connect.proto.SubmitZinggJob result = new zingg.spark.connect.proto.SubmitZinggJob(this); + if (bitField0_ != 0) { buildPartial0(result); } + onBuilt(); + return result; + } + + private void buildPartial0(zingg.spark.connect.proto.SubmitZinggJob result) { + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) != 0)) { + result.argumnets_ = argumnetsBuilder_ == null + ? argumnets_ + : argumnetsBuilder_.build(); + to_bitField0_ |= 0x00000001; + } + if (((from_bitField0_ & 0x00000002) != 0)) { + result.cliOptions_ = cliOptionsBuilder_ == null + ? 
cliOptions_ + : cliOptionsBuilder_.build(); + to_bitField0_ |= 0x00000002; + } + if (((from_bitField0_ & 0x00000004) != 0)) { + result.inMemoryDate_ = inMemoryDate_; + to_bitField0_ |= 0x00000004; + } + result.bitField0_ |= to_bitField0_; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof zingg.spark.connect.proto.SubmitZinggJob) { + return mergeFrom((zingg.spark.connect.proto.SubmitZinggJob)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(zingg.spark.connect.proto.SubmitZinggJob other) { + if (other == zingg.spark.connect.proto.SubmitZinggJob.getDefaultInstance()) return this; + if (other.hasArgumnets()) { + mergeArgumnets(other.getArgumnets()); + } + if (other.hasCliOptions()) { + mergeCliOptions(other.getCliOptions()); + } + if (other.hasInMemoryDate()) { + setInMemoryDate(other.getInMemoryDate()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + input.readMessage( + getArgumnetsFieldBuilder().getBuilder(), + extensionRegistry); + bitField0_ |= 0x00000001; + break; + } // case 10 + case 18: { + input.readMessage( + getCliOptionsFieldBuilder().getBuilder(), + extensionRegistry); + bitField0_ |= 0x00000002; + break; + } // case 18 + case 26: { + inMemoryDate_ = input.readBytes(); + bitField0_ |= 0x00000004; + break; + } // case 26 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + private int bitField0_; + + private zingg.spark.connect.proto.Arguments argumnets_; + private com.google.protobuf.SingleFieldBuilderV3< + zingg.spark.connect.proto.Arguments, zingg.spark.connect.proto.Arguments.Builder, zingg.spark.connect.proto.ArgumentsOrBuilder> argumnetsBuilder_; + /** + * .Arguments argumnets = 1 [json_name = 
"argumnets"]; + * @return Whether the argumnets field is set. + */ + public boolean hasArgumnets() { + return ((bitField0_ & 0x00000001) != 0); + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + * @return The argumnets. + */ + public zingg.spark.connect.proto.Arguments getArgumnets() { + if (argumnetsBuilder_ == null) { + return argumnets_ == null ? zingg.spark.connect.proto.Arguments.getDefaultInstance() : argumnets_; + } else { + return argumnetsBuilder_.getMessage(); + } + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + public Builder setArgumnets(zingg.spark.connect.proto.Arguments value) { + if (argumnetsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + argumnets_ = value; + } else { + argumnetsBuilder_.setMessage(value); + } + bitField0_ |= 0x00000001; + onChanged(); + return this; + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + public Builder setArgumnets( + zingg.spark.connect.proto.Arguments.Builder builderForValue) { + if (argumnetsBuilder_ == null) { + argumnets_ = builderForValue.build(); + } else { + argumnetsBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000001; + onChanged(); + return this; + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + public Builder mergeArgumnets(zingg.spark.connect.proto.Arguments value) { + if (argumnetsBuilder_ == null) { + if (((bitField0_ & 0x00000001) != 0) && + argumnets_ != null && + argumnets_ != zingg.spark.connect.proto.Arguments.getDefaultInstance()) { + getArgumnetsBuilder().mergeFrom(value); + } else { + argumnets_ = value; + } + } else { + argumnetsBuilder_.mergeFrom(value); + } + if (argumnets_ != null) { + bitField0_ |= 0x00000001; + onChanged(); + } + return this; + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + public Builder clearArgumnets() { + bitField0_ = (bitField0_ & ~0x00000001); + argumnets_ = null; + if (argumnetsBuilder_ != null) { + argumnetsBuilder_.dispose(); + argumnetsBuilder_ = null; + } + onChanged(); + return this; + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + public zingg.spark.connect.proto.Arguments.Builder getArgumnetsBuilder() { + bitField0_ |= 0x00000001; + onChanged(); + return getArgumnetsFieldBuilder().getBuilder(); + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + public zingg.spark.connect.proto.ArgumentsOrBuilder getArgumnetsOrBuilder() { + if (argumnetsBuilder_ != null) { + return argumnetsBuilder_.getMessageOrBuilder(); + } else { + return argumnets_ == null ? 
+ zingg.spark.connect.proto.Arguments.getDefaultInstance() : argumnets_; + } + } + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + private com.google.protobuf.SingleFieldBuilderV3< + zingg.spark.connect.proto.Arguments, zingg.spark.connect.proto.Arguments.Builder, zingg.spark.connect.proto.ArgumentsOrBuilder> + getArgumnetsFieldBuilder() { + if (argumnetsBuilder_ == null) { + argumnetsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + zingg.spark.connect.proto.Arguments, zingg.spark.connect.proto.Arguments.Builder, zingg.spark.connect.proto.ArgumentsOrBuilder>( + getArgumnets(), + getParentForChildren(), + isClean()); + argumnets_ = null; + } + return argumnetsBuilder_; + } + + private zingg.spark.connect.proto.ClientOptions cliOptions_; + private com.google.protobuf.SingleFieldBuilderV3< + zingg.spark.connect.proto.ClientOptions, zingg.spark.connect.proto.ClientOptions.Builder, zingg.spark.connect.proto.ClientOptionsOrBuilder> cliOptionsBuilder_; + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + * @return Whether the cliOptions field is set. + */ + public boolean hasCliOptions() { + return ((bitField0_ & 0x00000002) != 0); + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + * @return The cliOptions. + */ + public zingg.spark.connect.proto.ClientOptions getCliOptions() { + if (cliOptionsBuilder_ == null) { + return cliOptions_ == null ? zingg.spark.connect.proto.ClientOptions.getDefaultInstance() : cliOptions_; + } else { + return cliOptionsBuilder_.getMessage(); + } + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + public Builder setCliOptions(zingg.spark.connect.proto.ClientOptions value) { + if (cliOptionsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + cliOptions_ = value; + } else { + cliOptionsBuilder_.setMessage(value); + } + bitField0_ |= 0x00000002; + onChanged(); + return this; + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + public Builder setCliOptions( + zingg.spark.connect.proto.ClientOptions.Builder builderForValue) { + if (cliOptionsBuilder_ == null) { + cliOptions_ = builderForValue.build(); + } else { + cliOptionsBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000002; + onChanged(); + return this; + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + public Builder mergeCliOptions(zingg.spark.connect.proto.ClientOptions value) { + if (cliOptionsBuilder_ == null) { + if (((bitField0_ & 0x00000002) != 0) && + cliOptions_ != null && + cliOptions_ != zingg.spark.connect.proto.ClientOptions.getDefaultInstance()) { + getCliOptionsBuilder().mergeFrom(value); + } else { + cliOptions_ = value; + } + } else { + cliOptionsBuilder_.mergeFrom(value); + } + if (cliOptions_ != null) { + bitField0_ |= 0x00000002; + onChanged(); + } + return this; + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + public Builder clearCliOptions() { + bitField0_ = (bitField0_ & ~0x00000002); + cliOptions_ = null; + if (cliOptionsBuilder_ != null) { + cliOptionsBuilder_.dispose(); + cliOptionsBuilder_ = null; + } + onChanged(); + return this; + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + public zingg.spark.connect.proto.ClientOptions.Builder getCliOptionsBuilder() { + bitField0_ |= 0x00000002; + onChanged(); + return getCliOptionsFieldBuilder().getBuilder(); + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + 
*/ + public zingg.spark.connect.proto.ClientOptionsOrBuilder getCliOptionsOrBuilder() { + if (cliOptionsBuilder_ != null) { + return cliOptionsBuilder_.getMessageOrBuilder(); + } else { + return cliOptions_ == null ? + zingg.spark.connect.proto.ClientOptions.getDefaultInstance() : cliOptions_; + } + } + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + private com.google.protobuf.SingleFieldBuilderV3< + zingg.spark.connect.proto.ClientOptions, zingg.spark.connect.proto.ClientOptions.Builder, zingg.spark.connect.proto.ClientOptionsOrBuilder> + getCliOptionsFieldBuilder() { + if (cliOptionsBuilder_ == null) { + cliOptionsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + zingg.spark.connect.proto.ClientOptions, zingg.spark.connect.proto.ClientOptions.Builder, zingg.spark.connect.proto.ClientOptionsOrBuilder>( + getCliOptions(), + getParentForChildren(), + isClean()); + cliOptions_ = null; + } + return cliOptionsBuilder_; + } + + private com.google.protobuf.ByteString inMemoryDate_ = com.google.protobuf.ByteString.EMPTY; + /** + *
+     * <pre>
+     * The next message is a serialized LogicalPlan
+     * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return Whether the inMemoryDate field is set. + */ + @java.lang.Override + public boolean hasInMemoryDate() { + return ((bitField0_ & 0x00000004) != 0); + } + /** + *
+     * <pre>
+     * The next message is a serialized LogicalPlan
+     * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return The inMemoryDate. + */ + @java.lang.Override + public com.google.protobuf.ByteString getInMemoryDate() { + return inMemoryDate_; + } + /** + *
+     * <pre>
+     * The next message is a serialized LogicalPlan
+     * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @param value The inMemoryDate to set. + * @return This builder for chaining. + */ + public Builder setInMemoryDate(com.google.protobuf.ByteString value) { + if (value == null) { throw new NullPointerException(); } + inMemoryDate_ = value; + bitField0_ |= 0x00000004; + onChanged(); + return this; + } + /** + *
+     * <pre>
+     * The next message is a serialized LogicalPlan
+     * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return This builder for chaining. + */ + public Builder clearInMemoryDate() { + bitField0_ = (bitField0_ & ~0x00000004); + inMemoryDate_ = getDefaultInstance().getInMemoryDate(); + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:SubmitZinggJob) + } + + // @@protoc_insertion_point(class_scope:SubmitZinggJob) + private static final zingg.spark.connect.proto.SubmitZinggJob DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new zingg.spark.connect.proto.SubmitZinggJob(); + } + + public static zingg.spark.connect.proto.SubmitZinggJob getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public SubmitZinggJob parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public zingg.spark.connect.proto.SubmitZinggJob getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + +} + diff --git a/spark/client/src/main/java/zingg/spark/connect/proto/SubmitZinggJobOrBuilder.java b/spark/client/src/main/java/zingg/spark/connect/proto/SubmitZinggJobOrBuilder.java new file mode 100644 index 000000000..18777abf4 --- /dev/null +++ b/spark/client/src/main/java/zingg/spark/connect/proto/SubmitZinggJobOrBuilder.java @@ -0,0 +1,59 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: connect_plugins.proto + +// Protobuf Java Version: 3.25.3 +package zingg.spark.connect.proto; + +public interface SubmitZinggJobOrBuilder extends + // @@protoc_insertion_point(interface_extends:SubmitZinggJob) + com.google.protobuf.MessageOrBuilder { + + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + * @return Whether the argumnets field is set. + */ + boolean hasArgumnets(); + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + * @return The argumnets. + */ + zingg.spark.connect.proto.Arguments getArgumnets(); + /** + * .Arguments argumnets = 1 [json_name = "argumnets"]; + */ + zingg.spark.connect.proto.ArgumentsOrBuilder getArgumnetsOrBuilder(); + + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + * @return Whether the cliOptions field is set. 
+ */ + boolean hasCliOptions(); + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + * @return The cliOptions. + */ + zingg.spark.connect.proto.ClientOptions getCliOptions(); + /** + * .ClientOptions cli_options = 2 [json_name = "cliOptions"]; + */ + zingg.spark.connect.proto.ClientOptionsOrBuilder getCliOptionsOrBuilder(); + + /** + *
+   * <pre>
+   * The next message is a serialized LogicalPlan
+   * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return Whether the inMemoryDate field is set. + */ + boolean hasInMemoryDate(); + /** + *
+   * <pre>
+   * The next message is a serialized LogicalPlan
+   * </pre>
+ * + * optional bytes in_memory_date = 3 [json_name = "inMemoryDate"]; + * @return The inMemoryDate. + */ + com.google.protobuf.ByteString getInMemoryDate(); +} diff --git a/spark/core/src/main/scala/reifier/scala/DFUtil.scala b/spark/client/src/main/scala/reifier/scala/DFUtil.scala similarity index 100% rename from spark/core/src/main/scala/reifier/scala/DFUtil.scala rename to spark/client/src/main/scala/reifier/scala/DFUtil.scala diff --git a/spark/core/src/main/scala/reifier/scala/MyPolyExpansion.scala b/spark/client/src/main/scala/reifier/scala/MyPolyExpansion.scala similarity index 100% rename from spark/core/src/main/scala/reifier/scala/MyPolyExpansion.scala rename to spark/client/src/main/scala/reifier/scala/MyPolyExpansion.scala diff --git a/spark/core/src/main/scala/reifier/scala/TypeTags.scala b/spark/client/src/main/scala/reifier/scala/TypeTags.scala similarity index 100% rename from spark/core/src/main/scala/reifier/scala/TypeTags.scala rename to spark/client/src/main/scala/reifier/scala/TypeTags.scala diff --git a/spark/client/src/test/java/zingg/client/TestSparkFrame.java b/spark/client/src/test/java/zingg/client/TestSparkFrame.java deleted file mode 100644 index 3cfb3adce..000000000 --- a/spark/client/src/test/java/zingg/client/TestSparkFrame.java +++ /dev/null @@ -1,324 +0,0 @@ -package zingg.client; - -import static org.apache.spark.sql.functions.col; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.Arrays; -import java.util.Date; -import java.util.List; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.spark.sql.Column; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.functions; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.junit.jupiter.api.Test; - -import scala.collection.JavaConverters; -import zingg.common.client.ZFrame; -import zingg.common.client.ZinggClientException; -import zingg.common.client.util.ColName; -import zingg.spark.client.SparkFrame; -import static org.junit.jupiter.api.Assertions.assertEquals; -public class TestSparkFrame extends TestSparkFrameBase { - public static final Log LOG = LogFactory.getLog(TestSparkFrame.class); - - public static final String NEW_COLUMN = "newColumn"; - - @Test - public void testCreateSparkDataFrameAndGetDF() { - SparkFrame sf = new SparkFrame(createSampleDataset()); - Dataset df = sf.df(); - assertTrue(df.except(createSampleDataset()).isEmpty(), "Two datasets are not equal"); - } - - @Test - public void testColumnsNamesandCount() { - SparkFrame sf = new SparkFrame(createSampleDataset()); - assertTrue(Arrays.equals(sf.columns(), createSampleDataset().columns()), - "Columns of SparkFrame and the dataset are not equal"); - } - - @Test - public void testAliasOfSparkFrame() { - SparkFrame sf = new SparkFrame(createSampleDataset()); - String aliasName = "AnotherName"; - sf.as(aliasName); - assertTrueCheckingExceptOutput(sf.as(aliasName), sf, "Dataframe and its alias are not same"); - } - - @Test - public void testSelectWithSingleColumnName() { - Dataset df = createSampleDataset(); - ZFrame, Row, Column> sf = new SparkFrame(df); - String colName = "recid"; - ZFrame, Row, Column> sf2 = sf.select(colName); - SparkFrame sf3 = new SparkFrame(df.select(colName)); - assertTrueCheckingExceptOutput(sf2, sf3, 
"SparkFrame.select(colName) does not have expected value"); - } - - @Test - public void testSelectWithColumnList() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - List columnList = Arrays.asList(col("recid"), col("surname"), col("postcode")); - ZFrame, Row, Column> sf2 = sf.select(columnList); - SparkFrame sf3 = new SparkFrame( - df.select(JavaConverters.asScalaIteratorConverter(columnList.iterator()).asScala().toSeq())); - assertTrueCheckingExceptOutput(sf2, sf3, "SparkFrame.select(columnList) does not have expected value"); - } - - @Test - public void testSelectWithColumnArray() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - Column[] columnArray = new Column[] {col("recid"), col("surname"), col("postcode")}; - ZFrame, Row, Column> sf2 = sf.select(columnArray); - SparkFrame sf3 = new SparkFrame(df.select(columnArray)); - assertTrueCheckingExceptOutput(sf2, sf3, "SparkFrame.select(columnArray) value does not match with standard select output"); - } - - @Test - public void testSelectWithMultipleColumnNamesAsString() { - Dataset df = createSampleDataset(); - ZFrame, Row, Column> sf = new SparkFrame(df); - ZFrame, Row, Column> sf2 = sf.select("recid", "surname", "postcode"); - SparkFrame sf3 = new SparkFrame(df.select("recid", "surname", "postcode")); - assertTrueCheckingExceptOutput(sf2, sf3, "SparkFrame.select(str1, str2, ...) value does not match with standard select output"); - } - - @Test - public void testSelectExprByPassingColumnStringsAsInSQLStatement() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - ZFrame, Row, Column> sf2 = sf.selectExpr("recid as RecordId", "surname as FamilyName", "postcode as Pin"); - SparkFrame sf3 = new SparkFrame(df.selectExpr("recid", "surname", "postcode")); - assertTrueCheckingExceptOutput(sf2, sf3, "SparkFrame.selectExpr(str1, str2, ...) value does not match with standard selectExpr output"); - } - - @Test - public void testDistinct() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - SparkFrame sf2 = new SparkFrame(df.distinct()); - assertTrueCheckingExceptOutput(sf.distinct(), sf2, "SparkFrame.distict() does not match with standard distict() output"); - } - - @Test - public void testDropSingleColumn() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - ZFrame, Row, Column> sf2 = new SparkFrame(df.drop("recid")); - assertTrueCheckingExceptOutput(sf2, sf.drop("recid"), "SparkFrame.drop(str) does not match with standard drop() output"); - } - - @Test - public void testDropColumnsAsStringArray() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - ZFrame, Row, Column> sf2 = new SparkFrame(df.drop("recid", "surname", "postcode")); - assertTrueCheckingExceptOutput(sf2, sf.drop("recid", "surname", "postcode"), "SparkFrame.drop(str...) does not match with standard drop(str...) 
output"); - } - - @Test - public void testLimit() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - int len = 5; - ZFrame, Row, Column> sf2 = sf.limit(len); - assertTrue(sf2.count() == len); - assertTrueCheckingExceptOutput(sf2, sf.limit(len), "SparkFrame.limit(len) does not match with standard limit(len) output"); - } - - @Test - public void testDropDuplicatesConsideringGivenColumnsAsStringArray() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - String[] columnArray = new String[] {"surname", "postcode"}; - ZFrame, Row, Column> sf2 = new SparkFrame(df.dropDuplicates(columnArray)); - assertTrueCheckingExceptOutput(sf2, sf.dropDuplicates(columnArray), "SparkFrame.dropDuplicates(str[]) does not match with standard dropDuplicates(str[]) output"); - } - - @Test - public void testDropDuplicatesConsideringGivenIndividualColumnsAsString() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - ZFrame, Row, Column> sf2 = new SparkFrame(df.dropDuplicates("surname", "postcode")); - assertTrueCheckingExceptOutput(sf2, sf.dropDuplicates("surname"), "SparkFrame.dropDuplicates(col1, col2) does not match with standard dropDuplicates(col1, col2) output"); - } - - @Test - public void testHead() { - Dataset df = createSampleDataset(); - SparkFrame sf = new SparkFrame(df); - Row row = sf.head(); - assertTrue(row.equals(df.head()), "Top Row is not the expected one"); - } - - @Test - public void testIsEmpty() { - if (spark==null) { - setUpSpark(); - } - Dataset df = spark.emptyDataFrame(); - SparkFrame sf = new SparkFrame(df); - assertTrue(sf.isEmpty(), "DataFrame is not empty"); - } - - @Test - public void testGetAsInt() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - Row row = sf.head(); - LOG.debug("Value: " + row.getAs("recid")); - assertTrue(sf.getAsInt(row, "recid") == (int) row.getAs("recid"), "row.getAsInt(col) hasn't returned correct int value"); - } - @Test - public void testGetAsString() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - Row row = sf.head(); - LOG.debug("Value: " + row.getAs("surname")); - assertTrue(sf.getAsString(row, "surname").equals(row.getAs("surname")), "row.getAsString(col) hasn't returned correct string value"); - } - @Test - public void testGetAsDouble() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - Row row = sf.head(); - LOG.debug("Value: " + row.getAs("cost")); - assertTrue(sf.getAsDouble(row, "cost") == (double) row.getAs("cost"), "row.getAsDouble(col) hasn't returned correct double value"); - } - @Test - public void testSortDescending() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - String col = STR_RECID; - ZFrame,Row,Column> sf2 = sf.sortDescending(col); - assertTrueCheckingExceptOutput(sf2, df.sort(functions.desc(col)), "SparkFrame.sortDescending() output is not as expected"); - } - - @Test - public void testSortAscending() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - String col = STR_RECID; - ZFrame,Row,Column> sf2 = sf.sortAscending(col); - assertTrueCheckingExceptOutput(sf2, df.sort(functions.asc(col)), "SparkFrame.sortAscending() output is not as expected"); - } - - @Test - public void testWithColumnforIntegerValue() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - 
String newCol = NEW_COLUMN; - int newColVal = 36; - ZFrame,Row,Column> sf2 = sf.withColumn(newCol, newColVal); - assertTrueCheckingExceptOutput(sf2, df.withColumn(newCol, functions.lit(newColVal)), "SparkFrame.withColumn(c, int) output is not as expected"); - } - - @Test - public void testWithColumnforDoubleValue() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - String newCol = NEW_COLUMN; - double newColVal = 3.14; - ZFrame,Row,Column> sf2 = sf.withColumn(newCol, newColVal); - assertTrueCheckingExceptOutput(sf2, df.withColumn(newCol, functions.lit(newColVal)), "SparkFrame.withColumn(c, double) output is not as expected"); - } - - @Test - public void testWithColumnforStringValue() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - String newCol = NEW_COLUMN; - String newColVal = "zingg"; - ZFrame,Row,Column> sf2 = sf.withColumn(newCol, newColVal); - assertTrueCheckingExceptOutput(sf2, df.withColumn(newCol, functions.lit(newColVal)), "SparkFrame.withColumn(c, String) output is not as expected"); - } - - @Test - public void testWithColumnforAnotherColumn() { - Dataset df = createSampleDatasetHavingMixedDataTypes(); - SparkFrame sf = new SparkFrame(df); - String oldCol = STR_RECID; - String newCol = NEW_COLUMN; - ZFrame,Row,Column> sf2 = sf.withColumn(newCol, col(oldCol)); - assertTrueCheckingExceptOutput(sf2, df.withColumn(newCol, col(oldCol)), "SparkFrame.withColumn(c, Column) output is not as expected"); - } - - @Test - public void testGetMaxVal(){ - SparkFrame zScoreDF = getZScoreDF(); - assertEquals(400,zScoreDF.getMaxVal(ColName.CLUSTER_COLUMN)); - } - - @Test - public void testGroupByMinMax(){ - SparkFrame zScoreDF = getZScoreDF(); - ZFrame, Row, Column> groupByDF = zScoreDF.groupByMinMaxScore(zScoreDF.col(ColName.ID_COL)); - - Dataset assertionDF = groupByDF.df(); - List assertionRows = assertionDF.collectAsList(); - for (Row row : assertionRows) { - if(row.getInt(0)==1) { - assertEquals(1001,row.getInt(1)); - assertEquals(2002,row.getInt(2)); - } - } - } - - @Test - public void testGroupByMinMax2(){ - SparkFrame zScoreDF = getZScoreDF(); - ZFrame, Row, Column> groupByDF = zScoreDF.groupByMinMaxScore(zScoreDF.col(ColName.CLUSTER_COLUMN)); - - Dataset assertionDF = groupByDF.df(); - List assertionRows = assertionDF.collectAsList(); - for (Row row : assertionRows) { - if(row.getInt(0)==100) { - assertEquals(900,row.getInt(1)); - assertEquals(9002,row.getInt(2)); - } - } - } - - @Test - public void testRightJoinMultiCol(){ - ZFrame, Row, Column> inpData = getInputData(); - ZFrame, Row, Column> clusterData = getClusterData(); - ZFrame, Row, Column> joinedData = clusterData.join(inpData,ColName.ID_COL,ColName.SOURCE_COL,ZFrame.RIGHT_JOIN); - assertEquals(10,joinedData.count()); - } - - @Test - public void testFilterInCond(){ - SparkFrame inpData = getInputData(); - SparkFrame clusterData = getClusterDataWithNull(); - ZFrame, Row, Column> filteredData = inpData.filterInCond(ColName.ID_COL, clusterData, ColName.COL_PREFIX+ ColName.ID_COL); - assertEquals(5,filteredData.count()); - } - - @Test - public void testFilterNotNullCond(){ - SparkFrame clusterData = getClusterDataWithNull(); - ZFrame, Row, Column> filteredData = clusterData.filterNotNullCond(ColName.SOURCE_COL); - assertEquals(3,filteredData.count()); - } - - @Test - public void testFilterNullCond(){ - SparkFrame clusterData = getClusterDataWithNull(); - ZFrame, Row, Column> filteredData = 
clusterData.filterNullCond(ColName.SOURCE_COL); - assertEquals(2,filteredData.count()); - } - - - -} \ No newline at end of file diff --git a/spark/client/src/test/java/zingg/client/TestSparkFrameBase.java b/spark/client/src/test/java/zingg/client/TestSparkFrameBase.java deleted file mode 100644 index dcc75bd95..000000000 --- a/spark/client/src/test/java/zingg/client/TestSparkFrameBase.java +++ /dev/null @@ -1,215 +0,0 @@ -package zingg.client; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.Arrays; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Column; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; - -import zingg.common.client.Arguments; -import zingg.common.client.IArguments; -import zingg.common.client.ZFrame; -import zingg.common.client.util.ColName; -import zingg.spark.client.SparkFrame; - -public class TestSparkFrameBase { - - public static IArguments args; - public static JavaSparkContext ctx; - public static SparkSession spark; - - public static final Log LOG = LogFactory.getLog(TestSparkFrameBase.class); - - public static final String STR_RECID = "recid"; - public static final String STR_GIVENNAME = "givenname"; - public static final String STR_SURNAME = "surname"; - public static final String STR_COST = "cost"; - public static final String STR_POSTCODE = "postcode"; - public static final String STR_SUBURB = "suburb"; - - @BeforeAll - public static void setup() { - setUpSpark(); - } - - protected static void setUpSpark() { - try { - spark = SparkSession - .builder() - .master("local[*]") - .appName("Zingg" + "Junit") - .getOrCreate(); - ctx = new JavaSparkContext(spark.sparkContext()); - JavaSparkContext.jarOfClass(TestSparkFrameBase.class); - args = new Arguments(); - } catch (Throwable e) { - if (LOG.isDebugEnabled()) - e.printStackTrace(); - LOG.info("Problem in spark env setup"); - } - } - - @AfterAll - public static void teardown() { - if (ctx != null) { - ctx.stop(); - ctx = null; - } - if (spark != null) { - spark.stop(); - spark = null; - } - } - - public Dataset createSampleDataset() { - - if (spark==null) { - setUpSpark(); - } - - StructType schemaOfSample = new StructType(new StructField[] { - new StructField("recid", DataTypes.StringType, false, Metadata.empty()), - new StructField("givenname", DataTypes.StringType, false, Metadata.empty()), - new StructField("surname", DataTypes.StringType, false, Metadata.empty()), - new StructField("suburb", DataTypes.StringType, false, Metadata.empty()), - new StructField("postcode", DataTypes.StringType, false, Metadata.empty()) - }); - - Dataset sample = spark.createDataFrame(Arrays.asList( - RowFactory.create("07317257", "erjc", "henson", "hendersonville", "2873g"), - RowFactory.create("03102490", "jhon", "kozak", "henders0nville", "28792"), - RowFactory.create("02890805", "david", "pisczek", "durham", "27717"), - RowFactory.create("04437063", "e5in", "bbrown", "greenville", "27858"), - RowFactory.create("03211564", "susan", "jones", "greenjboro", "274o7"), - RowFactory.create("04155808", "jerome", 
"wilkins", "battleborn", "2780g"), - RowFactory.create("05723231", "clarinw", "pastoreus", "elizabeth city", "27909"), - RowFactory.create("06087743", "william", "craven", "greenshoro", "27405"), - RowFactory.create("00538491", "marh", "jackdon", "greensboro", "27406"), - RowFactory.create("01306702", "vonnell", "palmer", "siler sity", "273q4")), schemaOfSample); - - return sample; - } - - public Dataset createSampleDatasetHavingMixedDataTypes() { - if (spark==null) { - setUpSpark(); - } - - StructType schemaOfSample = new StructType(new StructField[] { - new StructField(STR_RECID, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(STR_GIVENNAME, DataTypes.StringType, false, Metadata.empty()), - new StructField(STR_SURNAME, DataTypes.StringType, false, Metadata.empty()), - new StructField(STR_COST, DataTypes.DoubleType, false, Metadata.empty()), - new StructField(STR_POSTCODE, DataTypes.IntegerType, false, Metadata.empty()) - }); - - Dataset sample = spark.createDataFrame(Arrays.asList( - RowFactory.create(7317, "erjc", "henson", 0.54, 2873), - RowFactory.create(3102, "jhon", "kozak", 99.009, 28792), - RowFactory.create(2890, "david", "pisczek", 58.456, 27717), - RowFactory.create(4437, "e5in", "bbrown", 128.45, 27858) - ), schemaOfSample); - - return sample; - } - - protected SparkFrame getZScoreDF() { - Row[] rows = { - RowFactory.create( 0,100,900), - RowFactory.create( 1,100,1001), - RowFactory.create( 1,100,1002), - RowFactory.create( 1,100,2001), - RowFactory.create( 1,100,2002), - RowFactory.create( 11,100,9002), - RowFactory.create( 3,300,3001), - RowFactory.create( 3,300,3002), - RowFactory.create( 3,400,4001), - RowFactory.create( 4,400,4002) - }; - StructType schema = new StructType(new StructField[] { - new StructField(ColName.ID_COL, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.CLUSTER_COLUMN, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.SCORE_COL, DataTypes.IntegerType, false, Metadata.empty())}); - SparkFrame df = new SparkFrame(spark.createDataFrame(Arrays.asList(rows), schema)); - return df; - } - - protected SparkFrame getInputData() { - Row[] rows = { - RowFactory.create( 1,"fname1","b"), - RowFactory.create( 2,"fname","a"), - RowFactory.create( 3,"fna","b"), - RowFactory.create( 4,"x","c"), - RowFactory.create( 5,"y","c"), - RowFactory.create( 11,"new1","b"), - RowFactory.create( 22,"new12","a"), - RowFactory.create( 33,"new13","b"), - RowFactory.create( 44,"new14","c"), - RowFactory.create( 55,"new15","c") - }; - StructType schema = new StructType(new StructField[] { - new StructField(ColName.ID_COL, DataTypes.IntegerType, false, Metadata.empty()), - new StructField("fname", DataTypes.StringType, false, Metadata.empty()), - new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())}); - SparkFrame df = new SparkFrame(spark.createDataFrame(Arrays.asList(rows), schema)); - return df; - } - - - protected SparkFrame getClusterData() { - Row[] rows = { - RowFactory.create( 1,100,1001,"b"), - RowFactory.create( 2,100,1002,"a"), - RowFactory.create( 3,100,2001,"b"), - RowFactory.create( 4,900,2002,"c"), - RowFactory.create( 5,111,9002,"c") - }; - StructType schema = new StructType(new StructField[] { - new StructField(ColName.ID_COL, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.CLUSTER_COLUMN, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.SCORE_COL, DataTypes.IntegerType, false, Metadata.empty()), - new 
StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())}); - SparkFrame df = new SparkFrame(spark.createDataFrame(Arrays.asList(rows), schema)); - return df; - } - - protected SparkFrame getClusterDataWithNull() { - Row[] rows = { - RowFactory.create( 1,100,1001,"b"), - RowFactory.create( 2,100,1002,"a"), - RowFactory.create( 3,100,2001,null), - RowFactory.create( 4,900,2002,"c"), - RowFactory.create( 5,111,9002,null) - }; - StructType schema = new StructType(new StructField[] { - new StructField(ColName.COL_PREFIX+ ColName.ID_COL, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.CLUSTER_COLUMN, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.SCORE_COL, DataTypes.IntegerType, false, Metadata.empty()), - new StructField(ColName.SOURCE_COL, DataTypes.StringType, true, Metadata.empty())}); - SparkFrame df = new SparkFrame(spark.createDataFrame(Arrays.asList(rows), schema)); - return df; - } - - protected void assertTrueCheckingExceptOutput(ZFrame, Row, Column> sf1, ZFrame, Row, Column> sf2, String message) { - assertTrue(sf1.except(sf2).isEmpty(), message); - } - - - protected void assertTrueCheckingExceptOutput(ZFrame, Row, Column> sf1, Dataset df2, String message) { - SparkFrame sf2 = new SparkFrame(df2); - assertTrue(sf1.except(sf2).isEmpty(), message); - } -} \ No newline at end of file diff --git a/spark/core/pom.xml b/spark/core/pom.xml index 3034dd70b..82aa0f55d 100644 --- a/spark/core/pom.xml +++ b/spark/core/pom.xml @@ -24,36 +24,49 @@ zingg-common-client ${zingg.version} + + zingg + zingg-common-core + tests + test-jar + ${zingg.version} + test + + + org.junit.jupiter + junit-jupiter-engine + 5.8.1 + test + + + org.junit.jupiter + junit-jupiter-api + 5.8.1 + test + + + org.junit.jupiter + junit-jupiter-params + 5.8.1 + test + - - net.alchim31.maven - scala-maven-plugin - 3.2.2 + + org.apache.maven.plugins + maven-jar-plugin + 2.3.2 - scala-compile-first - process-resources - add-source - compile - - - - scala-test-compile - process-test-resources - - testCompile + test-jar - - ${scala.version} - - + diff --git a/spark/core/src/main/java/zingg/spark/core/block/SparkBlock.java b/spark/core/src/main/java/zingg/spark/core/block/SparkBlock.java index 6b8568c42..3cff3a304 100644 --- a/spark/core/src/main/java/zingg/spark/core/block/SparkBlock.java +++ b/spark/core/src/main/java/zingg/spark/core/block/SparkBlock.java @@ -8,7 +8,9 @@ import zingg.common.client.ZFrame; import zingg.common.client.util.ListMap; import zingg.common.core.block.Block; +import zingg.common.core.feature.FeatureFactory; import zingg.common.core.hash.HashFunction; +import zingg.spark.core.feature.SparkFeatureFactory; public class SparkBlock extends Block, Row, Column, DataType> { @@ -22,11 +24,10 @@ public SparkBlock(ZFrame, Row, Column> training, ZFrame, Row, Column, DataType>> functionsMap, long maxSize) { super(training, dupes, functionsMap, maxSize); } - - + @Override - public DataType getDataTypeFromString(String t) { - return DataType.fromDDL(t); - } + public FeatureFactory getFeatureFactory() { + return new SparkFeatureFactory(); + } } diff --git a/spark/core/src/main/java/zingg/spark/core/block/SparkBlockFunction.java b/spark/core/src/main/java/zingg/spark/core/block/SparkBlockFunction.java index 0e0dce39e..c80a0348c 100644 --- a/spark/core/src/main/java/zingg/spark/core/block/SparkBlockFunction.java +++ b/spark/core/src/main/java/zingg/spark/core/block/SparkBlockFunction.java @@ -7,7 +7,7 @@ import org.apache.spark.sql.Row; import 
org.apache.spark.sql.RowFactory; -import scala.collection.JavaConversions; +import scala.jdk.CollectionConverters; import scala.collection.Seq; import zingg.common.core.block.BlockFunction; import zingg.common.core.block.Canopy; @@ -23,7 +23,7 @@ public SparkBlockFunction(Tree> tree) { @Override public List getListFromRow(Row r) { Seq sObj = r.toSeq(); - List seqList = JavaConversions.seqAsJavaList(sObj); + List seqList = CollectionConverters.SeqHasAsJava(sObj).asJava(); //the abstract list returned here does not support adding a new element, //so an ugly way is to create a new list altogether (!!) //see in perf - maybe just iterate over all the row elements and add the last one? diff --git a/spark/core/src/main/java/zingg/spark/core/context/ZinggSparkContext.java b/spark/core/src/main/java/zingg/spark/core/context/ZinggSparkContext.java new file mode 100644 index 000000000..c9fcaac34 --- /dev/null +++ b/spark/core/src/main/java/zingg/spark/core/context/ZinggSparkContext.java @@ -0,0 +1,96 @@ +package zingg.spark.core.context; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; + +import zingg.common.client.IZingg; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.DSUtil; +import zingg.common.client.util.PipeUtilBase; +// +import zingg.common.core.context.Context; +import zingg.common.core.util.BlockingTreeUtil; +import zingg.common.core.util.GraphUtil; +import zingg.common.core.util.HashUtil; +import zingg.common.core.util.ModelUtil; +import zingg.spark.client.util.SparkDSUtil; +import zingg.spark.client.util.SparkPipeUtil; +import zingg.spark.core.util.SparkBlockingTreeUtil; +import zingg.spark.core.util.SparkGraphUtil; +import zingg.spark.core.util.SparkHashUtil; +import zingg.spark.core.util.SparkModelUtil; + + +public class ZinggSparkContext extends Context, Row,Column,DataType>{ + + + private static final long serialVersionUID = 1L; + protected JavaSparkContext ctx; + public static final Log LOG = LogFactory.getLog(ZinggSparkContext.class); + + + + @Override + public void init(SparkSession session) + throws ZinggClientException { + try{ +// if (session==null) { +// session = SparkSession +// .builder() +// .appName("Zingg") +// .getOrCreate(); +// +// //session = new SparkSession(spark, license); +// } + this.session = session; + if (ctx==null) { + ctx = JavaSparkContext.fromSparkContext(session.sparkContext()); + JavaSparkContext.jarOfClass(IZingg.class); + LOG.debug("Context " + ctx.toString()); + //initHashFns(); + ctx.setCheckpointDir("/tmp/checkpoint"); + setUtils(); + } + } + catch(Throwable e) { + if (LOG.isDebugEnabled()) e.printStackTrace(); + throw new ZinggClientException(e.getMessage()); + } + } + + @Override + public void cleanup() { + try { + if (ctx != null) { + ctx.stop(); + } + if (session!=null) { + session.stop(); + } + ctx = null; + session = null; + } catch (Exception e) { + // ignore any exception in cleanup + e.printStackTrace(); + } + } + + @Override + public void setUtils() { + LOG.debug("Session passed to utils is " + session); + setPipeUtil(new SparkPipeUtil(session)); + setDSUtil(new SparkDSUtil(session)); + setHashUtil(new SparkHashUtil(session)); + setGraphUtil(new SparkGraphUtil()); + setModelUtil(new SparkModelUtil(session)); + 
setBlockingTreeUtil(new SparkBlockingTreeUtil(session, getPipeUtil())); + } + + + } \ No newline at end of file diff --git a/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataColDocumenter.java b/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataColDocumenter.java index 146c51e08..758f5e35a 100644 --- a/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataColDocumenter.java +++ b/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataColDocumenter.java @@ -7,20 +7,20 @@ import freemarker.template.Version; import zingg.common.client.IArguments; -import zingg.common.core.Context; +import zingg.common.core.context.Context; import zingg.common.core.documenter.DataColDocumenter; import zingg.common.core.documenter.RowWrapper; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; /** * Spark specific implementation of DataColDocumenter * */ -public class SparkDataColDocumenter extends DataColDocumenter, Row, Column,DataType> { +public class SparkDataColDocumenter extends DataColDocumenter, Row, Column,DataType> { private static final long serialVersionUID = 1L; - public SparkDataColDocumenter(Context, Row, Column,DataType> context, IArguments args) { + public SparkDataColDocumenter(Context, Row, Column,DataType> context, IArguments args) { super(context, args); } diff --git a/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataDocumenter.java b/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataDocumenter.java index c591b99e4..add2e882e 100644 --- a/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataDocumenter.java +++ b/spark/core/src/main/java/zingg/spark/core/documenter/SparkDataDocumenter.java @@ -6,21 +6,22 @@ import org.apache.spark.sql.types.DataType; import freemarker.template.Version; +import zingg.common.client.Arguments; +import zingg.common.core.context.Context; import zingg.common.client.IArguments; -import zingg.common.core.Context; import zingg.common.core.documenter.DataDocumenter; import zingg.common.core.documenter.RowWrapper; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; /** * Spark specific implementation of DataDocumenter * */ -public class SparkDataDocumenter extends DataDocumenter, Row, Column,DataType> { +public class SparkDataDocumenter extends DataDocumenter, Row, Column,DataType> { private static final long serialVersionUID = 1L; - public SparkDataDocumenter(Context, Row, Column,DataType> context, IArguments args) { + public SparkDataDocumenter(Context, Row, Column,DataType> context, IArguments args) { super(context, args); } diff --git a/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelColDocumenter.java b/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelColDocumenter.java index 53b4b1829..2990b9b12 100644 --- a/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelColDocumenter.java +++ b/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelColDocumenter.java @@ -6,22 +6,23 @@ import org.apache.spark.sql.types.DataType; import freemarker.template.Version; +import zingg.common.client.Arguments; +import zingg.common.core.context.Context; import zingg.common.client.IArguments; -import zingg.common.core.Context; import zingg.common.core.documenter.ModelColDocumenter; import zingg.common.core.documenter.RowWrapper; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; /** * Spark specific implementation of ModelColDocumenter * * */ -public class 
SparkModelColDocumenter extends ModelColDocumenter, Row, Column,DataType> { +public class SparkModelColDocumenter extends ModelColDocumenter, Row, Column,DataType> { private static final long serialVersionUID = 1L; - public SparkModelColDocumenter(Context, Row, Column,DataType> context, IArguments args) { + public SparkModelColDocumenter(Context, Row, Column,DataType> context, IArguments args) { super(context, args); } diff --git a/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelDocumenter.java b/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelDocumenter.java index 70a4f07aa..0124e117a 100644 --- a/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelDocumenter.java +++ b/spark/core/src/main/java/zingg/spark/core/documenter/SparkModelDocumenter.java @@ -7,20 +7,20 @@ import freemarker.template.Version; import zingg.common.client.IArguments; -import zingg.common.core.Context; +import zingg.common.core.context.Context; import zingg.common.core.documenter.ModelDocumenter; import zingg.common.core.documenter.RowWrapper; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; /** * Spark specific implementation of ModelDocumenter * */ -public class SparkModelDocumenter extends ModelDocumenter, Row, Column,DataType> { +public class SparkModelDocumenter extends ModelDocumenter, Row, Column,DataType> { private static final long serialVersionUID = 1L; - public SparkModelDocumenter(Context, Row, Column,DataType> context, IArguments args) { + public SparkModelDocumenter(Context, Row, Column,DataType> context, IArguments args) { super(context, args); super.modelColDoc = new SparkModelColDocumenter(context,args); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java index 2a45904f6..98e452c90 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkDocumenter.java @@ -9,41 +9,46 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + import zingg.common.core.documenter.DataDocumenter; import zingg.common.core.documenter.ModelDocumenter; import zingg.common.core.executor.Documenter; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.documenter.SparkDataDocumenter; import zingg.spark.core.documenter.SparkModelDocumenter; +import zingg.spark.core.context.ZinggSparkContext; -public class SparkDocumenter extends Documenter, Row, Column,DataType> { +public class SparkDocumenter extends Documenter, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkDocumenter"; public static final Log LOG = LogFactory.getLog(SparkDocumenter.class); public SparkDocumenter() { - setZinggOptions(ZinggOptions.GENERATE_DOCS); - setContext(new ZinggSparkContext()); + this(new ZinggSparkContext()); } + public SparkDocumenter(ZinggSparkContext sparkContext) { + setZinggOption(ZinggOptions.GENERATE_DOCS); + setContext(sparkContext); + } + @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws 
ZinggClientException { + super.init(args,s); + getContext().init(s); } @Override - protected ModelDocumenter, Row, Column, DataType> getModelDocumenter() { + public ModelDocumenter, Row, Column, DataType> getModelDocumenter() { return new SparkModelDocumenter(getContext(),getArgs()); } @Override - protected DataDocumenter, Row, Column, DataType> getDataDocumenter() { + public DataDocumenter, Row, Column, DataType> getDataDocumenter() { return new SparkDataDocumenter(getContext(),getArgs()); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java index 3b395f85d..0c0aeb550 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkFindAndLabeller.java @@ -7,32 +7,37 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.SparkSession; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + import zingg.common.core.executor.FindAndLabeller; -import zingg.spark.client.ZSparkSession; +import zingg.spark.core.context.ZinggSparkContext; + -public class SparkFindAndLabeller extends FindAndLabeller, Row, Column,DataType> { +public class SparkFindAndLabeller extends FindAndLabeller, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkFindAndLabeller"; public static final Log LOG = LogFactory.getLog(SparkFindAndLabeller.class); public SparkFindAndLabeller() { - setZinggOptions(ZinggOptions.FIND_AND_LABEL); - ZinggSparkContext sparkContext = new ZinggSparkContext(); + this(new ZinggSparkContext()); + } + + public SparkFindAndLabeller(ZinggSparkContext sparkContext) { + setZinggOption(ZinggOptions.FIND_AND_LABEL); setContext(sparkContext); finder = new SparkTrainingDataFinder(sparkContext); labeller = new SparkLabeller(sparkContext); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java index bba1779c6..33dcbd706 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabelUpdater.java @@ -10,11 +10,12 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + import zingg.common.client.pipe.Pipe; import zingg.common.core.executor.LabelUpdater; -import zingg.spark.client.ZSparkSession; +import zingg.spark.core.context.ZinggSparkContext; +import org.apache.spark.sql.SparkSession; /** @@ -22,25 +23,28 @@ * * */ -public class SparkLabelUpdater extends LabelUpdater, Row, Column,DataType> { +public class SparkLabelUpdater extends LabelUpdater, Row, Column,DataType> { private static final long 
serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkLabelUpdater"; public static final Log LOG = LogFactory.getLog(SparkLabelUpdater.class); public SparkLabelUpdater() { - setZinggOptions(ZinggOptions.UPDATE_LABEL); - setContext(new ZinggSparkContext()); + this(new ZinggSparkContext()); } + public SparkLabelUpdater(ZinggSparkContext sparkContext) { + setZinggOption(ZinggOptions.UPDATE_LABEL); + setContext(sparkContext); + } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } - protected Pipe setSaveModeOnPipe(Pipe p) { + public Pipe setSaveModeOnPipe(Pipe p) { p.setMode(SaveMode.Overwrite.toString()); return p; } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java index 90c6d1585..e8aa8f6ec 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkLabeller.java @@ -6,20 +6,22 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.SparkSession; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + +import zingg.spark.core.context.ZinggSparkContext; import zingg.common.core.executor.Labeller; -import zingg.spark.client.ZSparkSession; + /** * Spark specific implementation of Labeller * * */ -public class SparkLabeller extends Labeller, Row, Column,DataType> { +public class SparkLabeller extends Labeller, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkLabeller"; @@ -30,14 +32,14 @@ public SparkLabeller() { } public SparkLabeller(ZinggSparkContext sparkContext) { - setZinggOptions(ZinggOptions.LABEL); + setZinggOption(ZinggOptions.LABEL); setContext(sparkContext); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java index 3033f0813..85f442314 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkLinker.java @@ -5,45 +5,49 @@ import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataType; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; import zingg.common.core.executor.Linker; import zingg.common.core.model.Model; import zingg.common.core.preprocess.StopWordsRemover; -import zingg.spark.client.ZSparkSession; +import 
zingg.spark.core.context.ZinggSparkContext; import zingg.spark.core.preprocess.SparkStopWordsRemover; -public class SparkLinker extends Linker<ZSparkSession, Dataset<Row>, Row, Column,DataType> { +public class SparkLinker extends Linker<SparkSession, Dataset<Row>, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkLinker"; public static final Log LOG = LogFactory.getLog(SparkLinker.class); public SparkLinker() { - setZinggOptions(ZinggOptions.LINK); - setContext(new ZinggSparkContext()); + this(new ZinggSparkContext()); } + public SparkLinker(ZinggSparkContext sparkContext) { + setZinggOption(ZinggOptions.LINK); + setContext(sparkContext); + } + @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } @Override - protected Model getModel() throws ZinggClientException { + public Model getModel() throws ZinggClientException { Model model = getModelUtil().loadModel(false, args); - model.register(getContext().getSession()); + model.register(); return model; } @Override - protected StopWordsRemover<ZSparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { + public StopWordsRemover<SparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { return new SparkStopWordsRemover(getContext(),getArgs()); }
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java index 5bbbe0401..6cb0bc1cd 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java @@ -10,12 +10,12 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; +import zingg.spark.core.context.ZinggSparkContext; import zingg.common.core.executor.Matcher; import zingg.common.core.model.Model; import zingg.common.core.preprocess.StopWordsRemover; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.preprocess.SparkStopWordsRemover; /** * * */ -public class SparkMatcher extends Matcher<ZSparkSession, Dataset<Row>,Row,Column,DataType>{ +public class SparkMatcher extends Matcher<SparkSession, Dataset<Row>,Row,Column,DataType>{ private static final long serialVersionUID = 1L; @@ -35,26 +35,26 @@ public SparkMatcher() { } public SparkMatcher(ZinggSparkContext sparkContext) { - setZinggOptions(ZinggOptions.MATCH); + setZinggOption(ZinggOptions.MATCH); setContext(sparkContext); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } @Override - protected Model getModel() throws ZinggClientException { + public Model getModel() throws ZinggClientException { Model model = getModelUtil().loadModel(false, args); - model.register(getContext().getSession()); + model.register(); return model; } @Override - protected StopWordsRemover<ZSparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { + public StopWordsRemover<SparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { return new SparkStopWordsRemover(getContext(),getArgs()); }
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java index 9cf793ad9..115390b85 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkPeekModel.java @@ -10,35 +10,37 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.SparkSession; import zingg.common.client.ClientOptions; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + import zingg.common.core.executor.ZinggBase; -import zingg.spark.client.ZSparkSession; +import zingg.spark.core.context.ZinggSparkContext; + -public class SparkPeekModel extends ZinggBase<ZSparkSession, Dataset<Row>, Row, Column, DataType>{ +public class SparkPeekModel extends ZinggBase<SparkSession, Dataset<Row>, Row, Column, DataType>{ private static final long serialVersionUID = 1L; protected static String name = "zingg.spark.core.executor.SparkPeekModel"; public static final Log LOG = LogFactory.getLog(SparkPeekModel.class); public SparkPeekModel() { - setZinggOptions(ZinggOptions.PEEK_MODEL); + setZinggOption(ZinggOptions.PEEK_MODEL); setContext(new ZinggSparkContext()); } @Override - public void init(IArguments args, IZinggLicense license) + public void init(IArguments args, SparkSession s) throws ZinggClientException { - super.init(args, license); + super.init(args,s); getContext().setUtils(); //we will not init here as we want py to drive //the spark session etc - //getContext().init(license); + getContext().init(s); } @Override
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java index a34676143..cf608a6e9 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkRecommender.java @@ -9,11 +9,12 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + import zingg.common.core.executor.Recommender; import zingg.common.core.recommender.StopWordsRecommender; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; +import zingg.spark.core.context.ZinggSparkContext; import zingg.spark.core.recommender.SparkStopWordsRecommender; @@ -21,26 +22,30 @@ * Spark specific implementation of Recommender * */ -public class SparkRecommender extends Recommender<ZSparkSession, Dataset<Row>, Row, Column,DataType> { +public class SparkRecommender extends Recommender<SparkSession, Dataset<Row>, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkRecommender"; public static final Log LOG = LogFactory.getLog(SparkRecommender.class); public SparkRecommender() { - setZinggOptions(ZinggOptions.RECOMMEND); - setContext(new ZinggSparkContext()); + this(new ZinggSparkContext()); } + public SparkRecommender(ZinggSparkContext sparkContext) { + setZinggOption(ZinggOptions.RECOMMEND); + setContext(sparkContext); + } + @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } @Override - public StopWordsRecommender<ZSparkSession, Dataset<Row>, Row, Column, DataType> getStopWordsRecommender() { - StopWordsRecommender<ZSparkSession, Dataset<Row>, Row, Column, DataType> stopWordsRecommender = new SparkStopWordsRecommender(getContext(),args); + public StopWordsRecommender<SparkSession, Dataset<Row>, Row, Column, DataType> getStopWordsRecommender() { + StopWordsRecommender<SparkSession, Dataset<Row>, Row, Column, DataType> stopWordsRecommender = new SparkStopWordsRecommender(getContext(),args); return stopWordsRecommender; }
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java index b24db07cb..699af83bf 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainMatcher.java @@ -9,29 +9,34 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; + import zingg.common.core.executor.TrainMatcher; -import zingg.spark.client.ZSparkSession; +import zingg.spark.core.context.ZinggSparkContext; +import org.apache.spark.sql.SparkSession; -public class SparkTrainMatcher extends TrainMatcher<ZSparkSession, Dataset<Row>, Row, Column,DataType> { +public class SparkTrainMatcher extends TrainMatcher<SparkSession, Dataset<Row>, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkTrainMatcher"; public static final Log LOG = LogFactory.getLog(SparkTrainMatcher.class); public SparkTrainMatcher() { - setZinggOptions(ZinggOptions.TRAIN_MATCH); - ZinggSparkContext sparkContext = new ZinggSparkContext(); + this(new ZinggSparkContext()); + } + + + public SparkTrainMatcher(ZinggSparkContext sparkContext) { + setZinggOption(ZinggOptions.TRAIN_MATCH); setContext(sparkContext); trainer = new SparkTrainer(sparkContext); matcher = new SparkMatcher(sparkContext); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } }
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java index bd2124b4d..e23c5b043 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainer.java @@ -2,22 +2,24 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.SparkSession; import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; +import zingg.spark.core.context.ZinggSparkContext; import zingg.common.core.executor.Trainer; import zingg.common.core.preprocess.StopWordsRemover; -import zingg.spark.client.ZSparkSession; + import zingg.spark.core.preprocess.SparkStopWordsRemover; -public class SparkTrainer extends Trainer<ZSparkSession, Dataset<Row>, Row, Column,DataType> { +public class SparkTrainer extends Trainer<SparkSession, Dataset<Row>, Row, Column,DataType> { public static String name = "zingg.spark.core.executor.SparkTrainer"; private static final long serialVersionUID = 1L; @@ -28,18 +30,18 @@ public SparkTrainer() { } public SparkTrainer(ZinggSparkContext sparkContext) { - setZinggOptions(ZinggOptions.TRAIN); + setZinggOption(ZinggOptions.TRAIN); setContext(sparkContext); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } @Override - protected StopWordsRemover<ZSparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { + public StopWordsRemover<SparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { return new SparkStopWordsRemover(getContext(),getArgs()); }
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java index 9c0816128..012effdab 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkTrainingDataFinder.java @@ -9,14 +9,14 @@ import zingg.common.client.IArguments; import zingg.common.client.ZinggClientException; -import zingg.common.client.ZinggOptions; -import zingg.common.client.license.IZinggLicense; +import zingg.common.client.options.ZinggOptions; +import zingg.spark.core.context.ZinggSparkContext; import zingg.common.core.executor.TrainingDataFinder; import zingg.common.core.preprocess.StopWordsRemover; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.preprocess.SparkStopWordsRemover; -public class SparkTrainingDataFinder extends TrainingDataFinder<ZSparkSession, Dataset<Row>, Row, Column,DataType> { +public class SparkTrainingDataFinder extends TrainingDataFinder<SparkSession, Dataset<Row>, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.core.executor.SparkTrainingDataFinder"; @@ -27,18 +27,18 @@ public SparkTrainingDataFinder() { } public SparkTrainingDataFinder(ZinggSparkContext sparkContext) { - setZinggOptions(ZinggOptions.FIND_TRAINING_DATA); + super(); setContext(sparkContext); } @Override - public void init(IArguments args, IZinggLicense license) throws ZinggClientException { - super.init(args, license); - getContext().init(license); + public void init(IArguments args, SparkSession s) throws ZinggClientException { + super.init(args,s); + getContext().init(s); } @Override - protected StopWordsRemover<ZSparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { + public StopWordsRemover<SparkSession, Dataset<Row>, Row, Column, DataType> getStopWords() { return new SparkStopWordsRemover(getContext(),getArgs()); }
diff --git a/spark/core/src/main/java/zingg/spark/core/executor/SparkZFactory.java b/spark/core/src/main/java/zingg/spark/core/executor/SparkZFactory.java index a64570f45..5e9079796 100644 --- a/spark/core/src/main/java/zingg/spark/core/executor/SparkZFactory.java +++ b/spark/core/src/main/java/zingg/spark/core/executor/SparkZFactory.java @@ -4,7 +4,8 @@ import zingg.common.client.IZingg; import zingg.common.client.IZinggFactory; -import zingg.common.client.ZinggOptions; +import zingg.common.client.options.ZinggOption; +import zingg.common.client.options.ZinggOptions; import zingg.spark.core.executor.SparkDocumenter; import zingg.spark.core.executor.SparkFindAndLabeller; 
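Taken together, the executor diffs above swap the license-based init for one that receives the SparkSession directly, and route the no-arg constructors through a shared ZinggSparkContext that tests can inject. A minimal driver sketch under those assumptions (the empty Arguments and the execute() call come from the surrounding executor API, not from this patch):

import org.apache.spark.sql.SparkSession;
import zingg.common.client.Arguments;
import zingg.common.client.IArguments;
import zingg.spark.core.executor.SparkTrainer;

public class TrainerDriverSketch {
    public static void main(String[] argv) throws Exception {
        // the caller now owns the session; the executor's context reuses it
        SparkSession spark = SparkSession.builder().appName("Zingg").getOrCreate();
        SparkTrainer trainer = new SparkTrainer(); // delegates to this(new ZinggSparkContext())
        IArguments args = new Arguments();         // populate field definitions and pipes before a real run
        trainer.init(args, spark);                 // replaces init(args, license)
        trainer.execute();
    }
}

The SparkZFactory change in this hunk follows the same narrowing: lookups are keyed by a single ZinggOption rather than the ZinggOptions aggregate.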
import zingg.spark.core.executor.SparkLabelUpdater; @@ -20,7 +21,7 @@ public class SparkZFactory implements IZinggFactory{ public SparkZFactory() {} - public static HashMap zinggers = new HashMap(); + public static HashMap zinggers = new HashMap(); static { zinggers.put(ZinggOptions.TRAIN, SparkTrainer.name); @@ -36,7 +37,7 @@ public SparkZFactory() {} zinggers.put(ZinggOptions.PEEK_MODEL, SparkPeekModel.name); } - public IZingg get(ZinggOptions z) throws InstantiationException, IllegalAccessException, ClassNotFoundException { + public IZingg get(ZinggOption z) throws InstantiationException, IllegalAccessException, ClassNotFoundException { return (IZingg) Class.forName(zinggers.get(z)).newInstance(); } diff --git a/spark/core/src/main/java/zingg/spark/core/executor/ZinggSparkContext.java b/spark/core/src/main/java/zingg/spark/core/executor/ZinggSparkContext.java deleted file mode 100644 index bf28e5fb3..000000000 --- a/spark/core/src/main/java/zingg/spark/core/executor/ZinggSparkContext.java +++ /dev/null @@ -1,192 +0,0 @@ -package zingg.spark.core.executor; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Column; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.types.DataType; - -import zingg.common.client.IZingg; -import zingg.common.client.ZinggClientException; -import zingg.common.client.license.IZinggLicense; -import zingg.common.core.Context; -import zingg.common.core.util.BlockingTreeUtil; -import zingg.common.core.util.DSUtil; -import zingg.common.core.util.GraphUtil; -import zingg.common.core.util.HashUtil; -import zingg.common.core.util.ModelUtil; -import zingg.common.core.util.PipeUtilBase; -import zingg.spark.client.ZSparkSession; -import zingg.spark.core.util.SparkBlockingTreeUtil; -import zingg.spark.core.util.SparkDSUtil; -import zingg.spark.core.util.SparkGraphUtil; -import zingg.spark.core.util.SparkHashUtil; -import zingg.spark.core.util.SparkModelUtil; -import zingg.spark.core.util.SparkPipeUtil; - - -public class ZinggSparkContext implements Context, Row,Column,DataType>{ - - - private static final long serialVersionUID = 1L; - protected JavaSparkContext ctx; - protected ZSparkSession zSession; - protected PipeUtilBase, Row, Column> pipeUtil; - protected HashUtil, Row, Column, DataType> hashUtil; - protected DSUtil, Row, Column> dsUtil; - protected GraphUtil, Row, Column> graphUtil; - protected ModelUtil, Row, Column> modelUtil; - protected BlockingTreeUtil, Row, Column, DataType> blockingTreeUtil; - - public static final String hashFunctionFile = "hashFunctions.json"; - - - public static final Log LOG = LogFactory.getLog(ZinggSparkContext.class); - - - public ZSparkSession getSession() { - return zSession; - } - - public void setSession(ZSparkSession spark) { - LOG.debug("Session passed to context is " + spark); - this.zSession = spark; - } - - - - @Override - public void init(IZinggLicense license) - throws ZinggClientException { - try{ - if (zSession==null || zSession.getSession() == null) { - SparkSession spark = SparkSession - .builder() - .appName("Zingg") - .getOrCreate(); - - zSession = new ZSparkSession(spark, license); - } - if (ctx==null) { - ctx = JavaSparkContext.fromSparkContext(zSession.getSession().sparkContext()); - JavaSparkContext.jarOfClass(IZingg.class); - LOG.debug("Context " + ctx.toString()); - //initHashFns(); - 
ctx.setCheckpointDir("/tmp/checkpoint"); - setUtils(); - } - } - catch(Throwable e) { - if (LOG.isDebugEnabled()) e.printStackTrace(); - throw new ZinggClientException(e.getMessage()); - } - } - - @Override - public void cleanup() { - try { - if (ctx != null) { - ctx.stop(); - } - if (zSession!=null && zSession.getSession() != null) { - zSession.getSession().stop(); - } - ctx = null; - zSession = null; - } catch (Exception e) { - // ignore any exception in cleanup - e.printStackTrace(); - } - } - - @Override - public void setUtils() { - LOG.debug("Session passed to utils is " + zSession.getSession()); - setPipeUtil(new SparkPipeUtil(zSession)); - setDSUtil(new SparkDSUtil(zSession)); - setHashUtil(new SparkHashUtil(zSession)); - setGraphUtil(new SparkGraphUtil()); - setModelUtil(new SparkModelUtil(zSession)); - setBlockingTreeUtil(new SparkBlockingTreeUtil(zSession, getPipeUtil())); - } - - /** - public void initHashFns() throws ZinggClientException { - try { - //functions = Util.getFunctionList(this.functionFile); - hashFunctions = getHashUtil().getHashFunctionList(hashFunctionFile, getContext()); - } catch (Exception e) { - if (LOG.isDebugEnabled()) e.printStackTrace(); - throw new ZinggClientException("Unable to initialize base functions"); - } - } - */ - - - - public void setHashUtil(HashUtil, Row, Column, DataType> t) { - this.hashUtil = t; - } - - public void setGraphUtil(GraphUtil, Row, Column> t) { - this.graphUtil = t; - } - - - - public void setPipeUtil(PipeUtilBase, Row, Column> pipeUtil) { - this.pipeUtil = pipeUtil; - } - - - public void setDSUtil(DSUtil, Row, Column> pipeUtil) { - this.dsUtil = pipeUtil; - } - - public void setBlockingTreeUtil(BlockingTreeUtil, Row, Column, DataType> d) { - this.blockingTreeUtil = d; - } - - public void setModelUtil(ModelUtil, Row, Column> t) { - this.modelUtil = t; - } - - public ModelUtil, Row, Column> getModelUtil() { - return modelUtil; - } - - /* @Override - public void setSession(SparkSession session) { - this.spark = session; - } - */ - - @Override - public HashUtil, Row, Column, DataType> getHashUtil() { - return hashUtil; - } - - @Override - public GraphUtil, Row, Column> getGraphUtil() { - return graphUtil; - } - - @Override - public DSUtil, Row, Column> getDSUtil() { - return dsUtil; - } - - @Override - public PipeUtilBase, Row, Column> getPipeUtil() { - return pipeUtil; - } - - @Override - public BlockingTreeUtil, Row, Column, DataType> getBlockingTreeUtil() { - return blockingTreeUtil; - } - - } \ No newline at end of file diff --git a/spark/core/src/main/java/zingg/spark/core/model/SparkLabelModel.java b/spark/core/src/main/java/zingg/spark/core/model/SparkLabelModel.java index 18f563050..d7c5f32ef 100644 --- a/spark/core/src/main/java/zingg/spark/core/model/SparkLabelModel.java +++ b/spark/core/src/main/java/zingg/spark/core/model/SparkLabelModel.java @@ -2,6 +2,7 @@ import java.util.Map; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataType; import zingg.common.client.FieldDefinition; @@ -11,8 +12,8 @@ public class SparkLabelModel extends SparkModel{ private static final long serialVersionUID = 1L; - public SparkLabelModel(Map> f) { - super(f); + public SparkLabelModel(SparkSession s, Map> f) { + super(s,f); } } diff --git a/spark/core/src/main/java/zingg/spark/core/model/SparkModel.java b/spark/core/src/main/java/zingg/spark/core/model/SparkModel.java index 4a453a8ac..b00d23a22 100644 --- a/spark/core/src/main/java/zingg/spark/core/model/SparkModel.java +++ 
b/spark/core/src/main/java/zingg/spark/core/model/SparkModel.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.SparkSession; import zingg.common.client.FieldDefinition; import zingg.common.client.ZFrame; @@ -30,11 +31,11 @@ import zingg.common.core.model.Model; import zingg.common.core.similarity.function.SimFunction; import zingg.spark.client.SparkFrame; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.similarity.SparkSimFunction; import zingg.spark.core.similarity.SparkTransformer; -public class SparkModel extends Model, Row, Column>{ +public class SparkModel extends Model, Row, Column>{ public static final Log LOG = LogFactory.getLog(SparkModel.class); //private Map featurers; @@ -43,19 +44,20 @@ public class SparkModel extends Model, Row LogisticRegression lr; Transformer transformer; BinaryClassificationEvaluator binaryClassificationEvaluator; - List columnsAdded; + VectorValueExtractor vve; - public SparkModel(Map> f) { + public SparkModel(SparkSession s, Map> f) { + super(s); featureCreators = new ArrayList(); pipelineStage = new ArrayList (); - columnsAdded = new ArrayList (); int count = 0; for (FieldDefinition fd : f.keySet()) { Feature fea = f.get(fd); List sfList = fea.getSimFunctions(); for (SimFunction sf : sfList) { - String outputCol = ColName.SIM_COL + count; + + String outputCol = getColumnName(fd.fieldName, sf.getName(), count); columnsAdded.add(outputCol); SparkTransformer st = new SparkTransformer(fd.fieldName, new SparkSimFunction(sf), outputCol); count++; @@ -92,9 +94,11 @@ public SparkModel(Map> f) { columnsAdded.add(ColName.RAW_PREDICTION); } - - public void fit(ZFrame,Row,Column> pos, ZFrame,Row,Column> neg) { + fitCore(pos, neg); + } + + protected ZFrame,Row,Column> fitCore(ZFrame,Row,Column> pos, ZFrame,Row,Column> neg) { //transform ZFrame,Row,Column> input = transform(pos.union(neg)).coalesce(1).cache(); //if (LOG.isDebugEnabled()) input.write().csv("/tmp/input/" + System.currentTimeMillis()); @@ -119,6 +123,7 @@ public void fit(ZFrame,Row,Column> pos, ZFrame,Row,Col CrossValidatorModel cvModel = cv.fit(input.df()); transformer = cvModel; LOG.debug("threshold after fitting is " + lr.getThreshold()); + return input; } @@ -126,13 +131,20 @@ public void load(String path) { transformer = CrossValidatorModel.load(path); } - public ZFrame,Row,Column> predict(ZFrame,Row,Column> data) { return predict(data, true); } @Override public ZFrame,Row,Column> predict(ZFrame,Row,Column> data, boolean isDrop) { + return dropFeatureCols(predictCore(data), isDrop); + } + + + + + @Override + protected ZFrame,Row,Column> predictCore(ZFrame,Row,Column> data) { //create features LOG.info("threshold while predicting is " + lr.getThreshold()); //lr.setThreshold(0.95); @@ -142,11 +154,7 @@ public ZFrame,Row,Column> predict(ZFrame,Row,Column> d //LOG.debug(predictWithFeatures.schema()); predictWithFeatures = vve.transform(predictWithFeatures); //LOG.debug("Original schema is " + predictWithFeatures.schema()); - if (isDrop) { - Dataset returnDS = predictWithFeatures.drop(columnsAdded.toArray(new String[columnsAdded.size()])); - //LOG.debug("Return schema after dropping additional columns is " + returnDS.schema()); - return new SparkFrame(returnDS); - } + LOG.debug("Return schema is " + predictWithFeatures.schema()); return new SparkFrame(predictWithFeatures); @@ -170,13 +178,13 @@ public ZFrame,Row,Column> 
transform(ZFrame,Row,Column> @Override - public void register(ZSparkSession spark) { + public void register() { if (featureCreators != null) { for (SparkTransformer bsf: featureCreators) { - bsf.register(spark); + bsf.register(session); } } - vve.register(spark); + vve.register(session); } diff --git a/spark/core/src/main/java/zingg/spark/core/model/VectorValueExtractor.java b/spark/core/src/main/java/zingg/spark/core/model/VectorValueExtractor.java index f6ab7486a..e842386c5 100644 --- a/spark/core/src/main/java/zingg/spark/core/model/VectorValueExtractor.java +++ b/spark/core/src/main/java/zingg/spark/core/model/VectorValueExtractor.java @@ -7,8 +7,9 @@ import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.types.DataTypes; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.similarity.SparkBaseTransformer; +import zingg.spark.core.util.SparkFnRegistrar; public class VectorValueExtractor extends SparkBaseTransformer implements UDF1{ @@ -25,8 +26,9 @@ public Double call(Vector v) { } @Override - public void register(ZSparkSession spark) { - spark.getSession().udf().register(uid, (UDF1) this, DataTypes.DoubleType); + public void register(SparkSession spark) { + + SparkFnRegistrar.registerUDF1(spark, uid, this, DataTypes.DoubleType); } /*@Override diff --git a/spark/core/src/main/java/zingg/spark/core/preprocess/SparkStopWordsRemover.java b/spark/core/src/main/java/zingg/spark/core/preprocess/SparkStopWordsRemover.java index d20c3fa38..860e66b7e 100644 --- a/spark/core/src/main/java/zingg/spark/core/preprocess/SparkStopWordsRemover.java +++ b/spark/core/src/main/java/zingg/spark/core/preprocess/SparkStopWordsRemover.java @@ -15,12 +15,13 @@ import zingg.common.client.IArguments; import zingg.common.client.ZFrame; -import zingg.common.core.Context; +import zingg.common.core.context.Context; import zingg.common.core.preprocess.StopWordsRemover; import zingg.spark.client.SparkFrame; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; +import zingg.spark.core.util.SparkFnRegistrar; -public class SparkStopWordsRemover extends StopWordsRemover,Row,Column,DataType> implements Serializable { +public class SparkStopWordsRemover extends StopWordsRemover,Row,Column,DataType> implements Serializable { private static final long serialVersionUID = 1L; protected static String name = "zingg.spark.preprocess.SparkStopWordsRemover"; @@ -28,7 +29,7 @@ public class SparkStopWordsRemover extends StopWordsRemover, Row, Column,DataType> context,IArguments args) { + public SparkStopWordsRemover(Context, Row, Column,DataType> context, IArguments args) { super(context,args); this.udfName = registerUDF(); } @@ -45,8 +46,9 @@ protected String registerUDF() { // Each field will have different pattern String udfName = removeStopWordsUDF.getName(); // register the UDF - ZSparkSession zSession = getContext().getSession(); - zSession.getSession().udf().register(udfName, removeStopWordsUDF, DataTypes.StringType); + SparkSession zSession = getContext().getSession(); + + SparkFnRegistrar.registerUDF2(zSession, udfName, removeStopWordsUDF, DataTypes.StringType); return udfName; } diff --git a/spark/core/src/main/java/zingg/spark/core/recommender/SparkStopWordsRecommender.java b/spark/core/src/main/java/zingg/spark/core/recommender/SparkStopWordsRecommender.java index 32873f035..c6842a465 100644 --- a/spark/core/src/main/java/zingg/spark/core/recommender/SparkStopWordsRecommender.java +++ 
b/spark/core/src/main/java/zingg/spark/core/recommender/SparkStopWordsRecommender.java @@ -7,10 +7,11 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import zingg.common.client.Arguments; import zingg.common.client.IArguments; -import zingg.common.core.Context; +import zingg.common.core.context.Context; import zingg.common.core.recommender.StopWordsRecommender; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; /** @@ -18,13 +19,13 @@ * * */ -public class SparkStopWordsRecommender extends StopWordsRecommender, Row, Column,DataType> { +public class SparkStopWordsRecommender extends StopWordsRecommender, Row, Column,DataType> { private static final long serialVersionUID = 1L; public static String name = "zingg.spark.SparkStopWordsRecommender"; public static final Log LOG = LogFactory.getLog(SparkStopWordsRecommender.class); - public SparkStopWordsRecommender(Context, Row, Column,DataType> context,IArguments args) { + public SparkStopWordsRecommender(Context, Row, Column,DataType> context, IArguments args) { super(context,args); } diff --git a/spark/core/src/main/java/zingg/spark/core/similarity/SparkBaseTransformer.java b/spark/core/src/main/java/zingg/spark/core/similarity/SparkBaseTransformer.java index 5773bad9b..b7dd56a1e 100644 --- a/spark/core/src/main/java/zingg/spark/core/similarity/SparkBaseTransformer.java +++ b/spark/core/src/main/java/zingg/spark/core/similarity/SparkBaseTransformer.java @@ -14,7 +14,7 @@ import org.apache.spark.sql.types.StructType; import zingg.common.client.util.ColName; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; public abstract class SparkBaseTransformer extends Transformer implements HasInputCol, HasOutputCol { @@ -113,6 +113,6 @@ public Param outputCol() { - public abstract void register(ZSparkSession spark); + public abstract void register(SparkSession spark); } diff --git a/spark/core/src/main/java/zingg/spark/core/similarity/SparkTransformer.java b/spark/core/src/main/java/zingg/spark/core/similarity/SparkTransformer.java index f24533306..f477067d6 100644 --- a/spark/core/src/main/java/zingg/spark/core/similarity/SparkTransformer.java +++ b/spark/core/src/main/java/zingg/spark/core/similarity/SparkTransformer.java @@ -2,10 +2,10 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.sql.types.DataTypes; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; +import zingg.spark.core.util.SparkFnRegistrar; public class SparkTransformer extends SparkBaseTransformer { @@ -25,8 +25,9 @@ public SparkTransformer(String inputCol, SparkSimFunction function, String outpu - public void register(ZSparkSession spark) { - spark.getSession().udf().register(getUid(), (UDF2) function, DataTypes.DoubleType); + public void register(SparkSession spark) { + + SparkFnRegistrar.registerUDF2(spark, getUid(), function, DataTypes.DoubleType); } diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkBlockingTreeUtil.java b/spark/core/src/main/java/zingg/spark/core/util/SparkBlockingTreeUtil.java index 76017a827..984e07b83 100644 --- a/spark/core/src/main/java/zingg/spark/core/util/SparkBlockingTreeUtil.java +++ b/spark/core/src/main/java/zingg/spark/core/util/SparkBlockingTreeUtil.java @@ -20,23 +20,23 @@ import zingg.common.client.ZFrame; import zingg.common.client.util.ColName; import zingg.common.client.util.ListMap; +import 
zingg.common.client.util.PipeUtilBase; import zingg.common.core.block.Block; import zingg.common.core.block.Canopy; import zingg.common.core.block.Tree; import zingg.common.core.hash.HashFunction; import zingg.common.core.util.BlockingTreeUtil; -import zingg.common.core.util.PipeUtilBase; import zingg.spark.client.SparkFrame; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.block.SparkBlock; import zingg.spark.core.block.SparkBlockFunction; -public class SparkBlockingTreeUtil extends BlockingTreeUtil, Row, Column, DataType>{ +public class SparkBlockingTreeUtil extends BlockingTreeUtil, Row, Column, DataType>{ public static final Log LOG = LogFactory.getLog(SparkBlockingTreeUtil.class); - protected ZSparkSession spark; + protected SparkSession spark; - public SparkBlockingTreeUtil(ZSparkSession s, PipeUtilBase pipeUtil) { + public SparkBlockingTreeUtil(SparkSession s, PipeUtilBase pipeUtil) { this.spark = s; setPipeUtil(pipeUtil); } @@ -63,7 +63,7 @@ public ZFrame, Row, Column> getTreeDF(byte[] blockingTree){ StructType schema = DataTypes.createStructType(new StructField[] { DataTypes.createStructField("BlockingTree", DataTypes.BinaryType, false) }); List objList = new ArrayList(); objList.add(RowFactory.create(blockingTree)); - Dataset df = spark.getSession().sqlContext().createDataFrame(objList, schema).toDF().coalesce(1); + Dataset df = spark.sqlContext().createDataFrame(objList, schema).toDF().coalesce(1); return new SparkFrame(df); } diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkFnRegistrar.java b/spark/core/src/main/java/zingg/spark/core/util/SparkFnRegistrar.java new file mode 100644 index 000000000..792d13b06 --- /dev/null +++ b/spark/core/src/main/java/zingg/spark/core/util/SparkFnRegistrar.java @@ -0,0 +1,25 @@ +package zingg.spark.core.util; + +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.api.java.UDF2; +import org.apache.spark.sql.types.DataType; + +public class SparkFnRegistrar { + + public static void registerUDF1(SparkSession sparkSession, String functionName, UDF1 udf1, DataType dataType) { + + //only register udf1 if it is not registered already + if (!sparkSession.catalog().functionExists(functionName)) { + sparkSession.udf().register(functionName, udf1, dataType); + } + } + + public static void registerUDF2(SparkSession sparkSession, String functionName, UDF2 udf2, DataType dataType) { + + //only register udf2 if it is not registered already + if (!sparkSession.catalog().functionExists(functionName)) { + sparkSession.udf().register(functionName, udf2, dataType); + } + } +} diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkGraphUtil.java b/spark/core/src/main/java/zingg/spark/core/util/SparkGraphUtil.java index 44a8ac240..8a885c751 100644 --- a/spark/core/src/main/java/zingg/spark/core/util/SparkGraphUtil.java +++ b/spark/core/src/main/java/zingg/spark/core/util/SparkGraphUtil.java @@ -20,6 +20,7 @@ public ZFrame, Row, Column> buildGraph(ZFrame, Row, Co // we need to transform the input here by using stop words //rename id field which is a common field in data to another field as it //clashes with graphframes :-( + vOrig = vOrig.cache(); Dataset vertices = vOrig.df(); Dataset edges = ed.df(); vertices = vertices.withColumnRenamed(ColName.ID_EXTERNAL_ORIG_COL, ColName.ID_EXTERNAL_COL); diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkHashUtil.java 
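SparkFnRegistrar, added above, makes UDF registration idempotent: it consults the session catalog and registers only when the name is absent, which matters now that several components share one live SparkSession. A small usage sketch (the function name and UDF body are illustrative, not from this patch):

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import zingg.spark.core.util.SparkFnRegistrar;

public class RegistrarUsageSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").appName("udfDemo").getOrCreate();
        UDF1<String, String> upper = s -> s == null ? null : s.toUpperCase();
        SparkFnRegistrar.registerUDF1(spark, "toUpper", upper, DataTypes.StringType);
        // second call finds "toUpper" in the catalog and skips re-registration
        SparkFnRegistrar.registerUDF1(spark, "toUpper", upper, DataTypes.StringType);
        spark.sql("SELECT toUpper('zingg')").show();
    }
}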
b/spark/core/src/main/java/zingg/spark/core/util/SparkHashUtil.java index c128dc888..6096f9ecc 100644 --- a/spark/core/src/main/java/zingg/spark/core/util/SparkHashUtil.java +++ b/spark/core/src/main/java/zingg/spark/core/util/SparkHashUtil.java @@ -9,19 +9,20 @@ import zingg.common.core.hash.HashFnFromConf; import zingg.common.core.hash.HashFunction; import zingg.common.core.util.BaseHashUtil; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.hash.SparkHashFunctionRegistry; -public class SparkHashUtil extends BaseHashUtil, Row, Column,DataType>{ +public class SparkHashUtil extends BaseHashUtil, Row, Column,DataType>{ - public SparkHashUtil(ZSparkSession spark) { + public SparkHashUtil(SparkSession spark) { super(spark); } public HashFunction, Row, Column,DataType> registerHashFunction(HashFnFromConf scriptArg) { HashFunction, Row, Column,DataType> fn = new SparkHashFunctionRegistry().getFunction(scriptArg.getName()); - getSessionObj().getSession().udf().register(fn.getName(), (UDF1) fn, fn.getReturnType()); + + SparkFnRegistrar.registerUDF1(getSessionObj(), fn.getName(), (UDF1) fn, fn.getReturnType()); return fn; } diff --git a/spark/core/src/main/java/zingg/spark/core/util/SparkModelUtil.java b/spark/core/src/main/java/zingg/spark/core/util/SparkModelUtil.java index 2ad472636..ad6dcd886 100644 --- a/spark/core/src/main/java/zingg/spark/core/util/SparkModelUtil.java +++ b/spark/core/src/main/java/zingg/spark/core/util/SparkModelUtil.java @@ -11,36 +11,36 @@ import zingg.common.core.feature.FeatureFactory; import zingg.common.core.model.Model; import zingg.common.core.util.ModelUtil; -import zingg.spark.client.ZSparkSession; +import org.apache.spark.sql.SparkSession; import zingg.spark.core.feature.SparkFeatureFactory; import zingg.spark.core.model.SparkLabelModel; import zingg.spark.core.model.SparkModel; -public class SparkModelUtil extends ModelUtil, Row, Column> { +public class SparkModelUtil extends ModelUtil, Row, Column> { public static final Log LOG = LogFactory.getLog(SparkModelUtil.class); - public SparkModelUtil(ZSparkSession s) { - this.session = s; + public SparkModelUtil(SparkSession s) { + super(s); } - public Model, Row, Column> getModel(boolean isLabel, IArguments args) throws ZinggClientException{ - Model, Row, Column> model = null; + public Model, Row, Column> getModel(boolean isLabel, IArguments args) throws ZinggClientException{ + Model, Row, Column> model = null; if (isLabel) { - model = new SparkLabelModel(getFeaturers(args)); + model = new SparkLabelModel(session, getFeaturers(args)); } else { - model = new SparkModel(getFeaturers(args)); + model = new SparkModel(session, getFeaturers(args)); } return model; } @Override - public Model, Row, Column> loadModel(boolean isLabel, + public Model, Row, Column> loadModel(boolean isLabel, IArguments args) throws ZinggClientException { - Model, Row, Column> model = getModel(isLabel, args); + Model, Row, Column> model = getModel(isLabel, args); model.load(args.getModel()); return model; } diff --git a/spark/core/src/test/java/zingg/TestFebrlDataset.java b/spark/core/src/test/java/zingg/TestFebrlDataset.java index bf28e8537..a7ef49128 100644 --- a/spark/core/src/test/java/zingg/TestFebrlDataset.java +++ b/spark/core/src/test/java/zingg/TestFebrlDataset.java @@ -50,7 +50,7 @@ public void setUp() throws Exception, ZinggClientException{ public void testModelAccuracy(){ TrainMatcher tm = new SparkTrainMatcher(); try { - tm.init(args, null); + tm.init(args,spark); // 
tm.setSpark(spark); // tm.setCtx(ctx); tm.setArgs(args); diff --git a/spark/core/src/test/java/zingg/TestImageType.java b/spark/core/src/test/java/zingg/TestImageType.java index d96dd7313..bb005a7b2 100644 --- a/spark/core/src/test/java/zingg/TestImageType.java +++ b/spark/core/src/test/java/zingg/TestImageType.java @@ -19,6 +19,7 @@ import zingg.common.core.similarity.function.ArrayDoubleSimilarityFunction; import zingg.spark.core.executor.ZinggSparkTester; +import zingg.spark.core.util.SparkFnRegistrar; public class TestImageType extends ZinggSparkTester{ @@ -90,7 +91,7 @@ public void testUDFArray() { df.printSchema(); // register ArrayDoubleSimilarityFunction as a UDF TestUDFDoubleArr testUDFDoubleArr = new TestUDFDoubleArr(); - spark.udf().register("testUDFDoubleArr", testUDFDoubleArr, DataTypes.DoubleType); + SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleArr", testUDFDoubleArr, DataTypes.DoubleType); // call the UDF from select clause of DF df = df.withColumn("cosine", callUDF("testUDFDoubleArr", df.col("image_embedding"), df.col("image_embedding"))); @@ -116,8 +117,8 @@ public void testUDFList() { // register ArrayDoubleSimilarityFunction as a UDF TestUDFDoubleList testUDFDoubleList = new TestUDFDoubleList(); - spark.udf().register("testUDFDoubleList", testUDFDoubleList, DataTypes.DoubleType); - + SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleList", testUDFDoubleList, DataTypes.DoubleType); + // call the UDF from select clause of DF df = df.withColumn("cosine", callUDF("testUDFDoubleList",df.col("image_embedding"),df.col("image_embedding"))); // see if error is reproduced @@ -142,8 +143,8 @@ public void testUDFSeq() { // register ArrayDoubleSimilarityFunction as a UDF TestUDFDoubleSeq testUDFDoubleSeq = new TestUDFDoubleSeq(); - spark.udf().register("testUDFDoubleSeq", testUDFDoubleSeq, DataTypes.DoubleType); - + SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleSeq", testUDFDoubleSeq, DataTypes.DoubleType); + // call the UDF from select clause of DF df = df.withColumn("cosine", callUDF("testUDFDoubleSeq",df.col("image_embedding"),df.col("image_embedding"))); // see if error is reproduced @@ -168,8 +169,8 @@ public void testUDFWrappedArr() { // register ArrayDoubleSimilarityFunction as a UDF TestUDFDoubleWrappedArr testUDFDoubleWrappedArr = new TestUDFDoubleWrappedArr(); - spark.udf().register("testUDFDoubleWrappedArr", testUDFDoubleWrappedArr, DataTypes.DoubleType); - + SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleWrappedArr", testUDFDoubleWrappedArr, DataTypes.DoubleType); + // call the UDF from select clause of DF df = df.withColumn("cosine", callUDF("testUDFDoubleWrappedArr",df.col("image_embedding"),df.col("image_embedding"))); // see if error is reproduced @@ -197,8 +198,8 @@ public void testUDFObj() { // register ArrayDoubleSimilarityFunction as a UDF TestUDFDoubleObj testUDFDoubleObj = new TestUDFDoubleObj(); - spark.udf().register("testUDFDoubleObj", testUDFDoubleObj, DataTypes.DoubleType); - + SparkFnRegistrar.registerUDF2(spark, "testUDFDoubleObj", testUDFDoubleObj, DataTypes.DoubleType); + // call the UDF from select clause of DF df = df.withColumn("cosine", callUDF("testUDFDoubleObj",df.col("image_embedding"),df.col("image_embedding"))); // see if error is reproduced diff --git a/spark/core/src/test/java/zingg/TestSparkBase.java b/spark/core/src/test/java/zingg/TestSparkBase.java new file mode 100644 index 000000000..e04782700 --- /dev/null +++ b/spark/core/src/test/java/zingg/TestSparkBase.java @@ -0,0 +1,46 @@ +package zingg; + +import 
org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.api.extension.ParameterResolver; +import zingg.spark.core.executor.ZinggSparkTester; + +public class TestSparkBase extends ZinggSparkTester implements BeforeAllCallback, AfterAllCallback, ParameterResolver{ + + public SparkSession sparkSession; + + static boolean isSetUp; + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return parameterContext.getParameter().getType() + .equals(SparkSession.class); + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return sparkSession; + } + + @Override + public void afterAll(ExtensionContext context) { + + } + + @Override + public void beforeAll(ExtensionContext context) { + if (!isSetUp || sparkSession == null) { + super.setup(); + sparkSession = ZinggSparkTester.spark; + } + isSetUp = true; + } + + +} diff --git a/spark/core/src/test/java/zingg/TestUDFDoubleWrappedArr.java b/spark/core/src/test/java/zingg/TestUDFDoubleWrappedArr.java index 345362fb1..cd19368ee 100644 --- a/spark/core/src/test/java/zingg/TestUDFDoubleWrappedArr.java +++ b/spark/core/src/test/java/zingg/TestUDFDoubleWrappedArr.java @@ -2,15 +2,15 @@ import org.apache.spark.sql.api.java.UDF2; -import scala.collection.mutable.WrappedArray; +import scala.collection.mutable.ArraySeq; import zingg.common.core.similarity.function.ArrayDoubleSimilarityFunction; -public class TestUDFDoubleWrappedArr implements UDF2,WrappedArray, Double>{ +public class TestUDFDoubleWrappedArr implements UDF2,ArraySeq, Double>{ private static final long serialVersionUID = 1L; @Override - public Double call(WrappedArray t1, WrappedArray t2) throws Exception { + public Double call(ArraySeq t1, ArraySeq t2) throws Exception { System.out.println("TestUDFDoubleWrappedArr class" +t1.getClass()); Double[] t1Arr = new Double[t1.length()]; diff --git a/spark/core/src/test/java/zingg/block/TestBlock.java b/spark/core/src/test/java/zingg/block/TestBlock.java deleted file mode 100644 index cdcd09ba9..000000000 --- a/spark/core/src/test/java/zingg/block/TestBlock.java +++ /dev/null @@ -1,255 +0,0 @@ -package zingg.block; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.sql.Column; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.junit.jupiter.api.Test; - -import zingg.common.client.ArgumentsUtil; -import zingg.common.client.FieldDefinition; -import zingg.common.client.IArguments; -import zingg.common.client.MatchType; -import zingg.common.client.ZFrame; -import zingg.common.client.ZinggClientException; -import zingg.common.core.block.Canopy; -import zingg.common.core.block.Tree; -import zingg.spark.client.SparkFrame; -import zingg.spark.core.executor.ZinggSparkTester; 
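TestSparkBase, added above, is a JUnit 5 extension that lazily builds one shared SparkSession in beforeAll and hands it to test constructors through ParameterResolver; the deleted TestBlock here (and TestStopWords further below) are rewritten against it as TestSparkBlock and TestSparkStopWords. A standalone sketch of the same pattern, with a plain String standing in for the SparkSession:

import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;

public class SharedResourceExtension implements BeforeAllCallback, ParameterResolver {

    static String resource; // stands in for the lazily created, shared SparkSession

    @Override
    public void beforeAll(ExtensionContext context) {
        if (resource == null) {
            resource = "expensive-shared-resource"; // created once, reused by every test class
        }
    }

    @Override
    public boolean supportsParameter(ParameterContext pc, ExtensionContext ec) throws ParameterResolutionException {
        return pc.getParameter().getType().equals(String.class);
    }

    @Override
    public Object resolveParameter(ParameterContext pc, ExtensionContext ec) throws ParameterResolutionException {
        return resource; // injected into the constructor of any extended test class
    }
}

A test class opts in with @ExtendWith(SharedResourceExtension.class) and declares the resource as a constructor parameter, exactly as TestSparkBlock does with TestSparkBase further below.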
-import zingg.spark.core.util.SparkBlockingTreeUtil; -import zingg.spark.core.util.SparkHashUtil; - -public class TestBlock extends ZinggSparkTester { - - @Test - public void testTree() throws Throwable { - - ZFrame, Row, Column> testData = getTestData(); - - ZFrame, Row, Column> posDf = getPosData(); - - IArguments args = getArguments(); - - // form tree - SparkBlockingTreeUtil blockingTreeUtil = new SparkBlockingTreeUtil(zSession, zsCTX.getPipeUtil()); - SparkHashUtil hashUtil = new SparkHashUtil(zSession); - - Tree> blockingTree = blockingTreeUtil.createBlockingTreeFromSample(testData, posDf, 0.5, -1, - args, hashUtil.getHashFunctionList()); - - // primary deciding is unique year so identityInteger should have been picked - Canopy head = blockingTree.getHead(); - assertEquals("identityInteger", head.getFunction().getName()); - - } - - StructType testDataSchema = new StructType(new StructField[] { - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("year", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("event", DataTypes.StringType, false, Metadata.empty()), - new StructField("comment", DataTypes.StringType, false, Metadata.empty())} - ); - - StructType schemaPos = new StructType(new StructField[] { - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("year", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("event", DataTypes.StringType, false, Metadata.empty()), - new StructField("comment", DataTypes.StringType, false, Metadata.empty()), - new StructField("z_year", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("z_event", DataTypes.StringType, false, Metadata.empty()), - new StructField("z_comment", DataTypes.StringType, false, Metadata.empty()), - new StructField("z_zid", DataTypes.StringType, false, Metadata.empty())} - ); - - - - - private IArguments getArguments() throws ZinggClientException { - String configFilePath = getClass().getResource("../../testFebrl/config.json").getFile(); - - IArguments args = argsUtil.createArgumentsFromJSON(configFilePath, "trainMatch"); - - List fdList = getFieldDefList(); - - args.setFieldDefinition(fdList); - return args; - } - - private List getFieldDefList() { - List fdList = new ArrayList(4); - - FieldDefinition idFD = new FieldDefinition(); - idFD.setDataType("integer"); - idFD.setFieldName("id"); - ArrayList matchTypelistId = new ArrayList(); - matchTypelistId.add(MatchType.DONT_USE); - idFD.setMatchType(matchTypelistId); - fdList.add(idFD); - - ArrayList matchTypelistFuzzy = new ArrayList(); - matchTypelistFuzzy.add(MatchType.FUZZY); - - - FieldDefinition yearFD = new FieldDefinition(); - yearFD.setDataType("integer"); - yearFD.setFieldName("year"); - yearFD.setMatchType(matchTypelistFuzzy); - fdList.add(yearFD); - - FieldDefinition eventFD = new FieldDefinition(); - eventFD.setDataType("string"); - eventFD.setFieldName("event"); - eventFD.setMatchType(matchTypelistFuzzy); - fdList.add(eventFD); - - FieldDefinition commentFD = new FieldDefinition(); - commentFD.setDataType("string"); - commentFD.setFieldName("comment"); - commentFD.setMatchType(matchTypelistFuzzy); - fdList.add(commentFD); - return fdList; - } - - public SparkFrame getTestData() { - int row_id = 1; - // Create a DataFrame containing test data - Row[] data = { - RowFactory.create(row_id++, new Integer(1942), "quit India", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, 
new Integer(1930), "Civil Disob", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit India", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidience", "India"), - RowFactory.create(row_id++, new Integer(1942), "Quit Bharat", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidence", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit Hindustan", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JW", "Amritsar"), - RowFactory.create(row_id++, new Integer(1930), "Civil Dis", "India") , - RowFactory.create(row_id++, new Integer(1942), "quit Nation", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disob", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit India", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidience", "India"), - RowFactory.create(row_id++, new Integer(1942), "Quit Bharat", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidence", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit Hindustan", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JW", "Amritsar"), - RowFactory.create(row_id++, new Integer(1930), "Civil Dis", "India") , - RowFactory.create(row_id++, new Integer(1942), "quit Nation", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disob", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit India", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidience", "India"), - RowFactory.create(row_id++, new Integer(1942), "Quit Bharat", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidence", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit Hindustan", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JW", "Amritsar"), - RowFactory.create(row_id++, new Integer(1930), "Civil Dis", "India") , - RowFactory.create(row_id++, new Integer(1942), "quit Nation", "Mahatma"), - 
RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disob", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit India", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidience", "India"), - RowFactory.create(row_id++, new Integer(1942), "Quit Bharat", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JallianWala", "Punjab"), - RowFactory.create(row_id++, new Integer(1930), "Civil Disobidence", "India"), - RowFactory.create(row_id++, new Integer(1942), "quit Hindustan", "Mahatma Gandhi"), - RowFactory.create(row_id++, new Integer(1919), "JW", "Amritsar"), - RowFactory.create(row_id++, new Integer(1930), "Civil Dis", "India") , - RowFactory.create(row_id++, new Integer(1942), "quit Nation", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma") - }; - - return new SparkFrame( - spark.createDataFrame(Arrays.asList(data), - testDataSchema)); - - } - - private SparkFrame getPosData() { - int row_id = 1000; - // Create positive matching data - Row[] posData = { - RowFactory.create(row_id++, new Integer(1942), "quit Nation", "Mahatma",new Integer(1942), "quit Nation", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma",new Integer(1942), "quit N", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm",new Integer(1942), "quit ", "Mahatm", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama",new Integer(1942), "quit Ntn", "Mahama", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma",new Integer(1942), "quit Natin", "Mahaatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma",new Integer(1942), "quit N", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm",new Integer(1942), "quit ", "Mahatm", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama",new Integer(1942), "quit Ntn", "Mahama", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma",new Integer(1942), "quit Natin", "Mahaatma", "1"), - 
RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma",new Integer(1942), "quit N", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm",new Integer(1942), "quit ", "Mahatm", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama",new Integer(1942), "quit Ntn", "Mahama", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma",new Integer(1942), "quit Natin", "Mahaatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma",new Integer(1942), "quit N", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm",new Integer(1942), "quit ", "Mahatm", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama",new Integer(1942), "quit Ntn", "Mahama", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma",new Integer(1942), "quit Natin", "Mahaatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma",new Integer(1942), "quit N", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm",new Integer(1942), "quit ", "Mahatm", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama",new Integer(1942), "quit Ntn", "Mahama", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma",new Integer(1942), "quit Natin", "Mahaatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit N", "Mahatma",new Integer(1942), "quit N", "Mahatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2"), - RowFactory.create(row_id++, new Integer(1942), "quit ", "Mahatm",new Integer(1942), "quit ", "Mahatm", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Ntn", "Mahama",new Integer(1942), "quit Ntn", "Mahama", "1"), - RowFactory.create(row_id++, new Integer(1942), "quit Natin", "Mahaatma",new Integer(1942), "quit Natin", "Mahaatma", "1"), - RowFactory.create(row_id++, new Integer(1919), "JallianWal", "Punjb",new Integer(1919), "JallianWal", "Punjb", "2") - }; - return new SparkFrame(spark.createDataFrame(Arrays.asList(posData), schemaPos)); - } - - -} diff --git a/spark/core/src/test/java/zingg/common/core/block/TestSparkBlock.java b/spark/core/src/test/java/zingg/common/core/block/TestSparkBlock.java new file mode 100644 index 000000000..0dcd25502 --- /dev/null +++ b/spark/core/src/test/java/zingg/common/core/block/TestSparkBlock.java @@ -0,0 +1,29 @@ +package zingg.common.core.block; + +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import 
org.apache.spark.sql.types.DataType; +import org.junit.jupiter.api.extension.ExtendWith; +import zingg.TestSparkBase; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.IWithSession; +import zingg.common.client.util.WithSession; +import zingg.spark.client.util.SparkDFObjectUtil; +import zingg.spark.core.context.ZinggSparkContext; +import zingg.spark.core.util.SparkBlockingTreeUtil; +import zingg.spark.core.util.SparkHashUtil; + +@ExtendWith(TestSparkBase.class) +public class TestSparkBlock extends TestBlockBase, Row, Column, DataType> { + + public static ZinggSparkContext zsCTX = new ZinggSparkContext(); + public static IWithSession iWithSession = new WithSession(); + + public TestSparkBlock(SparkSession sparkSession) throws ZinggClientException { + super(new SparkDFObjectUtil(iWithSession), new SparkHashUtil(sparkSession), new SparkBlockingTreeUtil(sparkSession, zsCTX.getPipeUtil())); + iWithSession.setSession(sparkSession); + zsCTX.init(sparkSession); + } +} diff --git a/spark/core/src/test/java/zingg/common/core/preprocess/TestSparkStopWords.java b/spark/core/src/test/java/zingg/common/core/preprocess/TestSparkStopWords.java new file mode 100644 index 000000000..4887e3c09 --- /dev/null +++ b/spark/core/src/test/java/zingg/common/core/preprocess/TestSparkStopWords.java @@ -0,0 +1,33 @@ +package zingg.common.core.preprocess; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.extension.ExtendWith; +import zingg.TestSparkBase; +import zingg.common.client.ZinggClientException; +import zingg.common.client.util.IWithSession; +import zingg.common.client.util.WithSession; +import zingg.common.core.util.SparkStopWordRemoverUtility; +import zingg.spark.client.util.SparkDFObjectUtil; +import zingg.spark.core.context.ZinggSparkContext; + +@ExtendWith(TestSparkBase.class) +public class TestSparkStopWords extends TestStopWordsBase, Row, Column, DataType> { + + public static IWithSession iWithSession = new WithSession(); + public static ZinggSparkContext zsCTX = new ZinggSparkContext(); + + public TestSparkStopWords(SparkSession sparkSession) throws ZinggClientException { + super(new SparkDFObjectUtil(iWithSession), new SparkStopWordRemoverUtility(zsCTX), zsCTX); + iWithSession.setSession(sparkSession); + zsCTX.init(sparkSession); + } +} diff --git a/spark/core/src/test/java/zingg/common/core/preprocess/TestStopWords.java b/spark/core/src/test/java/zingg/common/core/preprocess/TestStopWords.java deleted file mode 100644 index c9ace5f3f..000000000 --- a/spark/core/src/test/java/zingg/common/core/preprocess/TestStopWords.java +++ /dev/null @@ -1,283 +0,0 @@ -package zingg.common.core.preprocess; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import 
diff --git a/spark/core/src/test/java/zingg/common/core/preprocess/TestStopWords.java b/spark/core/src/test/java/zingg/common/core/preprocess/TestStopWords.java
deleted file mode 100644
index c9ace5f3f..000000000
--- a/spark/core/src/test/java/zingg/common/core/preprocess/TestStopWords.java
+++ /dev/null
@@ -1,283 +0,0 @@
-package zingg.common.core.preprocess;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-import org.junit.jupiter.api.DisplayName;
-import org.junit.jupiter.api.Test;
-
-import zingg.common.client.Arguments;
-import zingg.common.client.FieldDefinition;
-import zingg.common.client.IArguments;
-import zingg.common.client.MatchType;
-import zingg.common.client.ZinggClientException;
-import zingg.common.client.util.ColName;
-import zingg.spark.client.SparkFrame;
-import zingg.spark.core.executor.ZinggSparkTester;
-import zingg.spark.core.preprocess.SparkStopWordsRemover;
-
-public class TestStopWords extends ZinggSparkTester{
-
-	public static final Log LOG = LogFactory.getLog(TestStopWords.class);
-
-	@DisplayName ("Test Stop Words removal from Single column dataset")
-	@Test
-	public void testStopWordsSingleColumn() throws ZinggClientException {
-
-		StructType schema = new StructType(new StructField[] {
-				new StructField("statement", DataTypes.StringType, false, Metadata.empty())
-		});
-
-		Dataset<Row> datasetOriginal = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("The zingg is a Spark application"),
-						RowFactory.create("It is very popular in data Science"),
-						RowFactory.create("It is written in Java and Scala"),
-						RowFactory.create("Best of luck to zingg")),
-				schema);
-
-		String stopWords = "\\b(a|an|the|is|It|of|yes|no|I|has|have|you)\\b\\s?".toLowerCase();
-
-		Dataset<Row> datasetExpected = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("zingg spark application"),
-						RowFactory.create("very popular in data science"),
-						RowFactory.create("written in java and scala"),
-						RowFactory.create("best luck to zingg")),
-				schema);
-
-		List<FieldDefinition> fdList = new ArrayList<FieldDefinition>(4);
-
-		ArrayList<MatchType> matchTypelistFuzzy = new ArrayList<MatchType>();
-		matchTypelistFuzzy.add(MatchType.FUZZY);
-
-		FieldDefinition eventFD = new FieldDefinition();
-		eventFD.setDataType("string");
-		eventFD.setFieldName("statement");
-		eventFD.setMatchType(matchTypelistFuzzy);
-		fdList.add(eventFD);
-
-		IArguments stmtArgs = new Arguments();
-		stmtArgs.setFieldDefinition(fdList);
-
-		StopWordsRemover stopWordsObj = new SparkStopWordsRemover(zsCTX,stmtArgs);
-
-		stopWordsObj.preprocessForStopWords(new SparkFrame(datasetOriginal));
-		System.out.println("datasetOriginal.show() : ");
-		datasetOriginal.show();
-		SparkFrame datasetWithoutStopWords = (SparkFrame)stopWordsObj.removeStopWordsFromDF(new SparkFrame(datasetOriginal),"statement",stopWords);
-		System.out.println("datasetWithoutStopWords.show() : ");
-		datasetWithoutStopWords.show();
-
-		assertTrue(datasetExpected.except(datasetWithoutStopWords.df()).isEmpty());
-		assertTrue(datasetWithoutStopWords.df().except(datasetExpected).isEmpty());
-	}
-
-	@Test
-	public void testRemoveStopWordsFromDataset() throws ZinggClientException {
-		StructType schemaOriginal = new StructType(new StructField[] {
-				new StructField(ColName.ID_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field1", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field2", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field3", DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())
-		});
-
-		Dataset<Row> original = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("10", "The zingg is a spark application", "two",
-								"Yes. a good application", "test"),
-						RowFactory.create("20", "It is very popular in Data Science", "Three", "true indeed",
-								"test"),
-						RowFactory.create("30", "It is written in java and scala", "four", "", "test"),
-						RowFactory.create("40", "Best of luck to zingg Mobile/T-Mobile", "Five", "thank you", "test")),
-				schemaOriginal);
-
-		Dataset<Row> datasetExpected = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("10", "zingg spark application", "two", "Yes. a good application", "test"),
-						RowFactory.create("20", "very popular data science", "Three", "true indeed", "test"),
-						RowFactory.create("30", "written java scala", "four", "", "test"),
-						RowFactory.create("40", "best luck to zingg ", "Five", "thank you", "test")),
-				schemaOriginal);
-		String stopWordsFileName = getClass().getResource("../../../../preProcess/stopWords.csv").getFile();
-		FieldDefinition fd = new FieldDefinition();
-		fd.setStopWords(stopWordsFileName);
-		fd.setFieldName("field1");
-
-		List<FieldDefinition> fieldDefinitionList = Arrays.asList(fd);
-		args.setFieldDefinition(fieldDefinitionList);
-
-		SparkStopWordsRemover stopWordsObj = new SparkStopWordsRemover(zsCTX,args);
-
-		Dataset<Row> newDataSet = ((SparkFrame)(stopWordsObj.preprocessForStopWords(new SparkFrame(original)))).df();
-		assertTrue(datasetExpected.except(newDataSet).isEmpty());
-		assertTrue(newDataSet.except(datasetExpected).isEmpty());
-	}
-
-	@Test
-	public void testStopWordColumnMissingFromStopWordFile() throws ZinggClientException {
-		StructType schemaOriginal = new StructType(new StructField[] {
-				new StructField(ColName.ID_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field1", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field2", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field3", DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())
-		});
-
-		Dataset<Row> original = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("10", "The zingg is a spark application", "two",
-								"Yes. a good application", "test"),
-						RowFactory.create("20", "It is very popular in Data Science", "Three", "true indeed",
-								"test"),
-						RowFactory.create("30", "It is written in java and scala", "four", "", "test"),
-						RowFactory.create("40", "Best of luck to zingg Mobile/T-Mobile", "Five", "thank you", "test")),
-				schemaOriginal);
-
-		Dataset<Row> datasetExpected = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("10", "zingg spark application", "two", "Yes. a good application", "test"),
-						RowFactory.create("20", "very popular data science", "Three", "true indeed", "test"),
-						RowFactory.create("30", "written java scala", "four", "", "test"),
-						RowFactory.create("40", "best luck to zingg ", "Five", "thank you", "test")),
-				schemaOriginal);
-		String stopWordsFileName = getClass().getResource("../../../../preProcess/stopWordsWithoutHeader.csv").getFile();
-		FieldDefinition fd = new FieldDefinition();
-		fd.setStopWords(stopWordsFileName);
-		fd.setFieldName("field1");
-
-		List<FieldDefinition> fieldDefinitionList = Arrays.asList(fd);
-		args.setFieldDefinition(fieldDefinitionList);
-
-		SparkStopWordsRemover stopWordsObj = new SparkStopWordsRemover(zsCTX,args);
-
-		System.out.println("testStopWordColumnMissingFromStopWordFile : orginal ");
-		original.show(200);
-		Dataset<Row> newDataSet = ((SparkFrame)(stopWordsObj.preprocessForStopWords(new SparkFrame(original)))).df();
-		System.out.println("testStopWordColumnMissingFromStopWordFile : newDataSet ");
-		newDataSet.show(200);
-		System.out.println("testStopWordColumnMissingFromStopWordFile : datasetExpected ");
-		datasetExpected.show(200);
-		assertTrue(datasetExpected.except(newDataSet).isEmpty());
-		assertTrue(newDataSet.except(datasetExpected).isEmpty());
-	}
-
-
-	@Test
-	public void testForOriginalDataAfterPostprocess() {
-		StructType schemaActual = new StructType(new StructField[] {
-				new StructField(ColName.CLUSTER_COLUMN, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.ID_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.PREDICTION_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SCORE_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.MATCH_FLAG_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field1", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field2", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field3", DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())
-		});
-
-		StructType schemaOriginal = new StructType(new StructField[] {
-				new StructField(ColName.ID_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field1", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field2", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field3", DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())
-		});
-
-		Dataset<Row> original = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("10", "The zingg is a spark application", "two",
-								"Yes. a good application", "test"),
-						RowFactory.create("20", "It is very popular in data science", "Three", "true indeed",
-								"test"),
-						RowFactory.create("30", "It is written in java and scala", "four", "", "test"),
-						RowFactory.create("40", "Best of luck to zingg", "Five", "thank you", "test")),
-				schemaOriginal);
-
-		Dataset<Row> actual = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("1648811730857:10", "10", "1.0", "0.555555", "-1",
-								"The zingg spark application", "two", "Yes. good application", "test"),
-						RowFactory.create("1648811730857:20", "20", "1.0", "1.0", "-1",
-								"It very popular data science", "Three", "true indeed", "test"),
-						RowFactory.create("1648811730857:30", "30", "1.0", "0.999995", "-1",
-								"It written java scala", "four", "", "test"),
-						RowFactory.create("1648811730857:40", "40", "1.0", "1.0", "-1", "Best luck zingg", "Five",
-								"thank", "test")),
-				schemaActual);
-
-		Dataset<Row> newDataset = ((SparkFrame)(zsCTX.getDSUtil().postprocess(new SparkFrame(actual), new SparkFrame(original)))).df();
-		assertTrue(newDataset.select(ColName.ID_COL, "field1", "field2", "field3", ColName.SOURCE_COL).except(original).isEmpty());
-		assertTrue(original.except(newDataset.select(ColName.ID_COL, "field1", "field2", "field3", ColName.SOURCE_COL)).isEmpty());
-	}
-
-	@Test
-	public void testOriginalDataAfterPostprocessLinked() {
-		StructType schemaActual = new StructType(new StructField[] {
-				new StructField(ColName.CLUSTER_COLUMN, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.ID_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.PREDICTION_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SCORE_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.MATCH_FLAG_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field1", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field2", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field3", DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())
-		});
-
-		StructType schemaOriginal = new StructType(new StructField[] {
-				new StructField(ColName.ID_COL, DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field1", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field2", DataTypes.StringType, false, Metadata.empty()),
-				new StructField("field3", DataTypes.StringType, false, Metadata.empty()),
-				new StructField(ColName.SOURCE_COL, DataTypes.StringType, false, Metadata.empty())
-		});
-
-		Dataset<Row> original = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("10", "The zingg is a spark application", "two",
-								"Yes. a good application", "test"),
-						RowFactory.create("20", "It is very popular in data science", "Three", "true indeed",
-								"test"),
-						RowFactory.create("30", "It is written in java and scala", "four", "", "test"),
-						RowFactory.create("40", "Best of luck to zingg", "Five", "thank you", "test")),
-				schemaOriginal);
-
-		Dataset<Row> actual = spark.createDataFrame(
-				Arrays.asList(
-						RowFactory.create("1648811730857:10", "10", "1.0", "0.555555", "-1",
-								"The zingg spark application", "two", "Yes. good application", "test"),
-						RowFactory.create("1648811730857:20", "20", "1.0", "1.0", "-1",
-								"It very popular data science", "Three", "true indeed", "test"),
-						RowFactory.create("1648811730857:30", "30", "1.0", "0.999995", "-1",
-								"It written java scala", "four", "", "test"),
-						RowFactory.create("1648811730857:40", "40", "1.0", "1.0", "-1", "Best luck zingg", "Five",
-								"thank", "test")),
-				schemaActual);
-
-		System.out.println("testOriginalDataAfterPostprocessLinked original :");
-		original.show(200);
-
-		Dataset<Row> newDataset = ((SparkFrame)(zsCTX.getDSUtil().postprocessLinked(new SparkFrame(actual), new SparkFrame(original)))).df();
-
-		System.out.println("testOriginalDataAfterPostprocessLinked newDataset :");
-		newDataset.show(200);
-
-		assertTrue(newDataset.select("field1", "field2", "field3").except(original.select("field1", "field2", "field3")).isEmpty());
-		assertTrue(original.select("field1", "field2", "field3").except(newDataset.select("field1", "field2", "field3")).isEmpty());
-	}
-}
\ No newline at end of file
diff --git a/spark/core/src/test/java/zingg/common/core/sparkFrame/TestSparkFrame.java b/spark/core/src/test/java/zingg/common/core/sparkFrame/TestSparkFrame.java
new file mode 100644
index 000000000..4bddaa44a
--- /dev/null
+++ b/spark/core/src/test/java/zingg/common/core/sparkFrame/TestSparkFrame.java
@@ -0,0 +1,22 @@
+package zingg.common.core.sparkFrame;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.extension.ExtendWith;
+import zingg.TestSparkBase;
+import zingg.common.client.util.IWithSession;
+import zingg.common.client.util.WithSession;
+import zingg.common.core.zFrame.TestZFrameBase;
+import zingg.spark.client.util.SparkDFObjectUtil;
+
+@ExtendWith(TestSparkBase.class)
+public class TestSparkFrame extends TestZFrameBase<SparkSession, Dataset<Row>, Row, Column> {
+	public static IWithSession<SparkSession> iWithSession = new WithSession<SparkSession>();
+
+	public TestSparkFrame(SparkSession sparkSession) {
+		super(new SparkDFObjectUtil(iWithSession));
+		iWithSession.setSession(sparkSession);
+	}
+}
\ No newline at end of file
diff --git a/spark/core/src/test/java/zingg/common/core/util/SparkStopWordRemoverUtility.java b/spark/core/src/test/java/zingg/common/core/util/SparkStopWordRemoverUtility.java
new file mode 100644
index 000000000..32af2bbbd
--- /dev/null
+++ b/spark/core/src/test/java/zingg/common/core/util/SparkStopWordRemoverUtility.java
@@ -0,0 +1,26 @@
+package zingg.common.core.util;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataType;
+import zingg.common.client.IArguments;
+import zingg.common.client.ZinggClientException;
+import zingg.common.core.context.Context;
+import zingg.spark.core.preprocess.SparkStopWordsRemover;
+
+public class SparkStopWordRemoverUtility extends StopWordRemoverUtility<SparkSession, Dataset<Row>, Row, Column, DataType> {
+
+	private final Context<SparkSession, Dataset<Row>, Row, Column, DataType> context;
+
+	public SparkStopWordRemoverUtility(Context<SparkSession, Dataset<Row>, Row, Column, DataType> context) throws ZinggClientException {
+		super();
+		this.context = context;
+	}
+
+	@Override
+	public void addStopWordRemover(IArguments iArguments) {
+		super.stopWordsRemovers.add(new SparkStopWordsRemover(context, iArguments));
+	}
+}
diff --git a/spark/core/src/test/java/zingg/spark/core/executor/JunitSparkLabeller.java b/spark/core/src/test/java/zingg/spark/core/executor/JunitSparkLabeller.java
new file mode 100644
index 000000000..ba1ed9372
--- /dev/null
+++ b/spark/core/src/test/java/zingg/spark/core/executor/JunitSparkLabeller.java
@@ -0,0 +1,44 @@
+package zingg.spark.core.executor;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataType;
+
+import zingg.common.client.IArguments;
+import zingg.common.client.ZFrame;
+import zingg.common.client.ZinggClientException;
+import zingg.common.client.options.ZinggOptions;
+import zingg.common.core.executor.JunitLabeller;
+import zingg.spark.core.context.ZinggSparkContext;
+
+public class JunitSparkLabeller extends SparkLabeller {
+
+	private static final long serialVersionUID = 1L;
+
+	JunitLabeller<SparkSession,Dataset<Row>,Row,Column,DataType> junitLabeller;
+
+	public JunitSparkLabeller() {
+		this(new ZinggSparkContext());
+	}
+
+	public JunitSparkLabeller(ZinggSparkContext sparkContext) {
+		setZinggOption(ZinggOptions.LABEL);
+		setContext(sparkContext);
+		junitLabeller = new JunitLabeller<SparkSession,Dataset<Row>,Row,Column,DataType>(sparkContext);
+	}
+
+	@Override
+	public void setArgs(IArguments args) {
+		super.setArgs(args);
+		junitLabeller.setArgs(args);
+	}
+
+	@Override
+	public ZFrame<Dataset<Row>,Row,Column> processRecordsCli(ZFrame<Dataset<Row>,Row,Column> lines)
+			throws ZinggClientException {
+		return junitLabeller.processRecordsCli(lines);
+	}
+}
+
diff --git a/spark/core/src/test/java/zingg/spark/core/executor/SparkTrainerTester.java b/spark/core/src/test/java/zingg/spark/core/executor/SparkTrainerTester.java
new file mode 100644
index 000000000..db1e45f09
--- /dev/null
+++ b/spark/core/src/test/java/zingg/spark/core/executor/SparkTrainerTester.java
@@ -0,0 +1,37 @@
+package zingg.spark.core.executor;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.File;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataType;
+
+import zingg.common.client.IArguments;
+import zingg.common.client.ZinggClientException;
+import zingg.common.core.executor.Trainer;
+import zingg.common.core.executor.TrainerTester;
+
+public class SparkTrainerTester extends TrainerTester<SparkSession,Dataset<Row>,Row,Column,DataType> {
+
+	public static final Log LOG = LogFactory.getLog(SparkTrainerTester.class);
+
+	public SparkTrainerTester(Trainer<SparkSession,Dataset<Row>,Row,Column,DataType> executor,IArguments args) {
+		super(executor,args);
+	}
+
+	@Override
+	public void validateResults() throws ZinggClientException {
+		// check that model is created
+		LOG.info("Zingg Model Dir : "+args.getZinggModelDir());
+
+		File modelDir = new File(args.getZinggModelDir());
+		assertTrue(modelDir.exists(),"check if model has been created");
+	}
+
+}
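JunitSparkLabeller above swaps interactive console labelling for the programmatic JunitLabeller, whose source is outside this hunk. Presumably it answers from ground truth rather than prompting a human; in the bundled test.csv that truth is encoded in the record ids, where rec-X-dup-N and rec-X-org belong to the same cluster. A toy stand-in for that idea, using plain Java types instead of Zingg's ZFrame (clusterKey and isMatch are hypothetical helpers, not Zingg API):

import java.util.Arrays;

public class GroundTruthDemo {
    // hypothetical helper: derive the true cluster key from ids like
    // "rec-1022-dup-3" / "rec-1022-org" in test.csv
    static String clusterKey(String recId) {
        String[] parts = recId.split("-");
        return parts[0] + "-" + parts[1];          // e.g. "rec-1022"
    }

    // a pair is a true match when both records share the cluster key,
    // which is how an automated labeller can answer instead of a human
    static boolean isMatch(String idA, String idB) {
        return clusterKey(idA).equals(clusterKey(idB));
    }

    public static void main(String[] args) {
        System.out.println(isMatch("rec-1022-dup-3", "rec-1022-org")); // true
        System.out.println(isMatch("rec-1022-org", "rec-1035-org"));   // false
    }
}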
diff --git a/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutors.java b/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutors.java
new file mode 100644
index 000000000..e948393d9
--- /dev/null
+++ b/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutors.java
@@ -0,0 +1,94 @@
+package zingg.spark.core.executor;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataType;
+import org.junit.jupiter.api.AfterEach;
+
+import zingg.common.client.ZinggClientException;
+import zingg.common.core.executor.Labeller;
+import zingg.common.core.executor.TestExecutorsGeneric;
+import zingg.common.core.executor.Trainer;
+import zingg.common.core.executor.TrainerTester;
+import zingg.spark.core.context.ZinggSparkContext;
+
+public class TestSparkExecutors extends TestExecutorsGeneric<SparkSession,Dataset<Row>,Row,Column,DataType> {
+	protected static final String CONFIG_FILE = "zingg/spark/core/executor/configSparkIntTest.json";
+
+	protected static final String TEST_DATA_FILE = "zingg/spark/core/executor/test.csv";
+
+	public static final Log LOG = LogFactory.getLog(TestSparkExecutors.class);
+
+	protected ZinggSparkContext ctx;
+
+
+	public TestSparkExecutors() throws IOException, ZinggClientException {
+		SparkSession spark = SparkSession
+				.builder()
+				.master("local[*]")
+				.appName("Zingg" + "Junit")
+				.getOrCreate();
+		this.ctx = new ZinggSparkContext();
+		this.ctx.setSession(spark);
+		this.ctx.setUtils();
+		init(spark);
+	}
+
+	@Override
+	public String getConfigFile() {
+		return CONFIG_FILE;
+	}
+
+	@Override
+	protected SparkTrainingDataFinder getTrainingDataFinder() throws ZinggClientException {
+		SparkTrainingDataFinder stdf = new SparkTrainingDataFinder(ctx);
+		return stdf;
+	}
+	@Override
+	protected Labeller<SparkSession,Dataset<Row>,Row,Column,DataType> getLabeller() throws ZinggClientException {
+		JunitSparkLabeller jlbl = new JunitSparkLabeller(ctx);
+		return jlbl;
+	}
+	@Override
+	protected SparkTrainer getTrainer() throws ZinggClientException {
+		SparkTrainer st = new SparkTrainer(ctx);
+		return st;
+	}
+	@Override
+	protected SparkMatcher getMatcher() throws ZinggClientException {
+		SparkMatcher sm = new SparkMatcher(ctx);
+		return sm;
+	}
+
+
+	@Override
+	public String setupArgs() throws ZinggClientException, IOException {
+		String configFile = super.setupArgs();
+		String testFile = getClass().getClassLoader().getResource(TEST_DATA_FILE).getFile();
+		// correct the location of test data
+		args.getData()[0].setProp("location", testFile);
+		return configFile;
+	}
+
+	@Override
+	protected SparkTrainerTester getTrainerTester(Trainer<SparkSession,Dataset<Row>,Row,Column,DataType> trainer) {
+		return new SparkTrainerTester(trainer,args);
+	}
+
+	@Override
+	@AfterEach
+	public void tearDown() {
+		// just rename, would be removed automatically as it's in /tmp
+		File dir = new File(args.getZinggDir());
+		File newDir = new File(dir.getParent() + "/zingg_junit_" + System.currentTimeMillis());
+		dir.renameTo(newDir);
+	}
+
+}
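The factory methods in TestSparkExecutors imply the phase order that TestExecutorsGeneric (not shown in this diff) drives: findTrainingData, label, train, match, with the trained model checked by a TrainerTester. A rough sketch of that loop under those assumptions; the real driver also wires args into each executor via setArgs(...) and may differ in detail:

	// Rough sketch, not the actual TestExecutorsGeneric source: assumes each
	// executor exposes Zingg's execute() contract and that args are already set.
	public void runAllPhases() throws ZinggClientException, IOException {
		setupArgs();                                       // loads configSparkIntTest.json, patches test.csv location

		getTrainingDataFinder().execute();                 // findTrainingData: sample candidate pairs
		getLabeller().execute();                           // label: JunitSparkLabeller answers programmatically
		getTrainer().execute();                            // train: writes the model under zinggDir/modelId
		getMatcher().execute();                            // match: applies the trained model

		getTrainerTester(getTrainer()).validateResults();  // asserts the model directory exists
	}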
diff --git a/spark/core/src/test/java/zingg/spark/core/executor/ZinggSparkTester.java b/spark/core/src/test/java/zingg/spark/core/executor/ZinggSparkTester.java
index adf26d3ac..447ac780d 100644
--- a/spark/core/src/test/java/zingg/spark/core/executor/ZinggSparkTester.java
+++ b/spark/core/src/test/java/zingg/spark/core/executor/ZinggSparkTester.java
@@ -20,13 +20,8 @@
 import zingg.common.client.ArgumentsUtil;
 import zingg.common.client.IArguments;
 import zingg.common.client.IZingg;
-import zingg.spark.client.ZSparkSession;
-import zingg.spark.core.util.SparkBlockingTreeUtil;
-import zingg.spark.core.util.SparkDSUtil;
-import zingg.spark.core.util.SparkGraphUtil;
-import zingg.spark.core.util.SparkHashUtil;
-import zingg.spark.core.util.SparkModelUtil;
-import zingg.spark.core.util.SparkPipeUtil;
+
+import zingg.spark.core.context.ZinggSparkContext;
 
 public class ZinggSparkTester {
@@ -34,7 +29,6 @@ public class ZinggSparkTester {
 	public static JavaSparkContext ctx;
 	public static SparkSession spark;
 	public static ZinggSparkContext zsCTX;
-	public static ZSparkSession zSession;
 	public ArgumentsUtil argsUtil = new ArgumentsUtil();
 	public static final Log LOG = LogFactory.getLog(ZinggSparkTester.class);
@@ -47,24 +41,14 @@ public static void setup() {
 			spark = SparkSession
 					.builder()
 					.master("local[*]")
-					.appName("Zingg" + "Junit")
+					.appName("ZinggJunit")
+					.config("spark.debug.maxToStringFields", 100)
 					.getOrCreate();
 			ctx = new JavaSparkContext(spark.sparkContext());
 			JavaSparkContext.jarOfClass(IZingg.class);
 			args = new Arguments();
 			zsCTX = new ZinggSparkContext();
-			zsCTX.ctx = ctx;
-			zSession = new ZSparkSession(spark, null);
-			zsCTX.zSession = zSession;
-
-			ctx.setCheckpointDir("/tmp/checkpoint");
-			zsCTX.setPipeUtil(new SparkPipeUtil(zSession));
-			zsCTX.setDSUtil(new SparkDSUtil(zSession));
-			zsCTX.setHashUtil(new SparkHashUtil(zSession));
-			zsCTX.setGraphUtil(new SparkGraphUtil());
-			zsCTX.setModelUtil(new SparkModelUtil(zSession));
-			zsCTX.setBlockingTreeUtil(new SparkBlockingTreeUtil(zSession, zsCTX.getPipeUtil()));
-
+			zsCTX.init(spark);
 		} catch (Throwable e) {
 			if (LOG.isDebugEnabled())
 				e.printStackTrace();
@@ -100,10 +84,5 @@ public Dataset<Row> createDFWithDoubles(int numRows, int numCols) {
 
 		return spark.createDataFrame(nums, structType);
-
-
-
-
-
 	}
 }
diff --git a/spark/core/src/test/resources/zingg/spark/core/executor/configSparkIntTest.json b/spark/core/src/test/resources/zingg/spark/core/executor/configSparkIntTest.json
new file mode 100644
index 000000000..0ef68c004
--- /dev/null
+++ b/spark/core/src/test/resources/zingg/spark/core/executor/configSparkIntTest.json
@@ -0,0 +1,95 @@
+{
+	"fieldDefinition":[
+	{
+		"fieldName" : "id",
+		"matchType" : "dont_use",
+		"fields" : "id",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "fname",
+		"matchType" : "fuzzy",
+		"fields" : "fname",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "lname",
+		"matchType" : "fuzzy",
+		"fields" : "lname",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "stNo",
+		"matchType": "fuzzy",
+		"fields" : "stNo",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "add1",
+		"matchType": "fuzzy",
+		"fields" : "add1",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "add2",
+		"matchType": "fuzzy",
+		"fields" : "add2",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "city",
+		"matchType": "fuzzy",
+		"fields" : "city",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "areacode",
+		"matchType": "fuzzy",
+		"fields" : "areacode",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "state",
+		"matchType": "fuzzy",
+		"fields" : "state",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "dob",
+		"matchType": "fuzzy",
+		"fields" : "dob",
+		"dataType": "string"
+	},
+	{
+		"fieldName" : "ssn",
+		"matchType": "fuzzy",
+		"fields" : "ssn",
+		"dataType": "string"
+	}
+	],
+	"output" : [{
+		"name":"output",
+		"format":"csv",
+		"props": {
+			"location": "/tmp/junit_integration_spark/zinggOutput",
+			"delimiter": ",",
+			"header":true
+		}
+	}],
+	"data" : [{
+		"name":"test",
+		"format":"csv",
+		"props": {
+			"location": "test.csv",
+			"delimiter": ",",
+			"header":false
+		},
+		"schema": "id string, fname string, lname string, stNo string, add1 string, add2 string, city string, state string, areacode string, dob string, ssn string"
+	}
+	],
+	"labelDataSampleSize" : 0.5,
+	"numPartitions":4,
+	"modelId": "junit_integration_spark",
+	"zinggDir": "/tmp/junit_integration_spark"
+
+}
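configSparkIntTest.json is the same shape of config that setupArgs() resolves from the classpath at runtime. Loaded programmatically it becomes an IArguments instance; a small sketch using ArgumentsUtil, assuming the createArgumentsFromJSON(file, phase) signature used elsewhere in the code base (treat the exact signature and phase name as assumptions):

import zingg.common.client.ArgumentsUtil;
import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;

public class LoadConfigDemo {
    public static void main(String[] args) throws ZinggClientException {
        // assumes ArgumentsUtil#createArgumentsFromJSON(file, phase);
        // adjust if the client API differs
        IArguments loaded = new ArgumentsUtil().createArgumentsFromJSON(
                "spark/core/src/test/resources/zingg/spark/core/executor/configSparkIntTest.json",
                "findTrainingData");

        System.out.println("modelId: " + loaded.getModelId());
        System.out.println("label sample size: " + loaded.getLabelDataSampleSize());
    }
}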
diff --git a/spark/core/src/test/resources/zingg/spark/core/executor/test.csv b/spark/core/src/test/resources/zingg/spark/core/executor/test.csv
new file mode 100644
index 000000000..4473ef4c2
--- /dev/null
+++ b/spark/core/src/test/resources/zingg/spark/core/executor/test.csv
@@ -0,0 +1,63 @@
+rec-1020-org, blake, ryan,4, starling place, berkeley vlge, marsden,5412, nsw,19271027,2402765
+rec-1021-dup-0, thomas, georgze,1, mcmanus place, , north turarmurra,3130, sa,19630225,5460534
+rec-1021-org, thomas, george,1, mcmanus place, stoney creek, north turramurra,3130, sa,19630225,5460534
+rec-1022-dup-1, Érik, Guay,840, mountview, fowles treet, burlei gh heads,2803, sa,19830807,2932837
+rec-1022-dup-2, Érik, Guay,840, fowles street, moun tvjiew, burleigh heads,2830, ss, ,2932837
+rec-1022-dup-3, jackson, christo,840, fowles street, mou ntveiw, burleig heads,2830, sa,19830807,2932837
+rec-1022-dup-4, jackson, eglinton,840, fowles street, mountv iew, burleigh heads,2830, sa,19830807,2932837
+rec-1022-org, jackson, eglinton,840, fowles street, mountview, burleigh heads,2830, sa,19830807,2932837
+rec-1023-org, gianni, matson,701, willis street, boonooloo, clifton,3101, vic,19410111,2540080
+rec-1024-org, takeisha, freeborn,6, suttor street, the groves street, wentworth falls,4615, vic,19620206,8111362
+rec-1025-org, emiily, britten,8, kitchener street, hilltop hostel rowethorpe, lake heights,2463, qld,19491021,9588775
+rec-1026-dup-0, xani, green, , phill ip avenue, , armidale,5108, nsw,19390410,9201057
+rec-1026-dup-1, xani, green,2, phillip avenue, abbey green, armidale,5108, nsw,19390410,9201857
+rec-1026-org, xani, green,2, phillip avenue, abbey green, armidale,5108, nsw,19390410,9201057
+rec-1027-org, nathan, smallacombe,20, guthridge crescent, red cross units, sandy bay,6056, sa,19241223,7522263
+rec-1028-dup-0, , ,24, , woorinyan, riverwood,3749, qld,19180205,9341716
+rec-1028-dup-1, , eglinton,24, curriecrescent, woorinyan, riverwood,3749, qld,19180205,1909717
+rec-1028-org, , eglinton,24, currie crescent, woorinyan, riverwood,3749, qld,19180205,9341716
+rec-1029-dup-0, kylee, stepehndon,81, rose scott circuit, cordobak anor, ashfield,4226, vic,19461101,4783085
+rec-1029-dup-1, sachin, stephenson,81, rose scott circuit, cordoba manor, ashfi eld,4226, vic,19461101,4783085
+rec-1029-dup-2, annalise, stephenson,81, rose scott circuit, cordoba manor, ashfoeld,4226, vic,19461101,4783085
+rec-1029-dup-3, kykee, turale,81, rose scott circuit, , ashfield,4226, vic,19461101,4783085
+rec-1029-dup-4, kylee, stephenson,81, cordoba manor, rose scott circuit, ashfield,4226, vic,19461101,4783085
+rec-1029-org, kylee, stephenson,81, rose scott circuit, cordoba manor, ashfield,4226, vic,19461101,4783085
+rec-103-dup-0, benjamin, koerbin,15, wybel anah, violet grover place, mill park,2446, nsw,19210210,3808808
+rec-103-org, briony, koerbin,146, violet grover place, wybelanah, mill park,2446, nsw,19210210,3808808
+rec-1030-org, emma, crossman,53, mcdowall place, kellhaven, tara,5608, vic,19391027,3561186
+rec-1031-org, samantha, sabieray,68, quandong street, wattle brae, gorokan,4019, wa,19590807,2863290
+rec-1032-dup-0, brooklyn, naar-cafentas,210, duffy street, tourist psrk, berwick,2481, nsw, ,3624304
+rec-1032-org, brooklyn, naar-cafentas,210, duffy street, tourist park, berwick,2481, nsw,19840802,3624304
+rec-1033-dup-0, keziah, painter,18, ainsli e avenue, sec 1, torquay,3205, vic,19191031,7801066
+rec-1033-org, keziah, painter,18, ainslie avenue, sec 1, torquay,3205, vic,19191031,7801066
+rec-1034-dup-0, erin, maynard,24, , wariala, little river,2777, vic,19970430,7429462
+rec-1034-dup-1, erin, maynard,51, wilshire street, warialda, little irver,2777, vic,19970430,1815999
+rec-1034-dup-2, hayley, maynard,14, wilshire street, , little river,2777, vic,19970430,7429462
+rec-1034-org, erin, maynard,14, wilshire street, warialda, little river,2777, vic,19970430,7429462
+rec-1035-dup-0, jaiden, rollins,48, tulgeywood, rossarden street, balwyn north,2224, nt,19280722,7626396
+rec-1035-dup-1, jaiden, rollins,95, rossarden street, tulgewyood, balwyn north,2224, nt,19280722,7626396
+rec-1035-dup-2, jaiden, rolilns,48, swinden street, tulgeywood, balwyn north,2224, nt,19280722,7626396
+rec-1035-dup-3, jaiden, rolli ns,48, tulgeywomod, rossarden street, balwyn north,2224, nf,19280722,7626396
+rec-1035-org, jaiden, rollins,48, rossarden street, tulgeywood, balwyn north,2224, nt,19280722,7626396
+rec-1036-dup-0, , held,24, lampard circuit, emerald garden, golden bay,2447, vic,19510806,3710651
+rec-1036-dup-1, sarsha, held,42, lampard circuit, , golden bay,2447, vic,19510806,3710651
+rec-1036-org, amber, held,24, lampard circuit, emerald garden, golden bay,2447, vic,19510806,3710651
+rec-1037-org, connor, beckwith,10, heard street, , mill park,5031, nsw,19081103,2209091
+rec-1038-org, danny, campbell,95, totterdell street, moama, shellharbour,2209, vic,19951105,9554924
+rec-1039-dup-0, angus, roas,62, gormansto crescent, mlc centre, kiruwah,3350, sa,19250817,2655081
+rec-1039-org, angus, rosa,62, gormanston crescent, mlc centre, kirwan,3350, sa,19250817,2655081
+rec-104-dup-0, benjaminl, carbone,18, arthella, wattle s treet, orange,3550, vic,19050820,3677127
+rec-104-org, benjamin, carbone,18, wattle street, arthella, orange,3550, vic,19050820,3677127
+rec-1040-dup-0, matilda, mestrov, , housecicuit, retirement village, taringa,3820, qld,19801119,2536135
+rec-1040-dup-1, matilda, mestrv,5, house circuit, retirement village, taringa,3802, qld,19801119,2563135
+rec-1040-dup-2, matilda, mestrov,5, house circuit, retiremen tvillage, taringa,3820, ,19801119,2563135
+rec-1040-org, matilda, mestrov,5, house circuit, retirement village, taringa,3820, qld,19801119,2563135
+rec-1041-dup-0, tyler, frojd, , burramurra avenue, kmart p plaza, san rmeo,3670, sa,19800916,7812219
+rec-1041-org, tyler, froud,8, burramurra avenue, kmart p plaza, san remo,3670, sa,19800916,7812219
+rec-1042-dup-0, kiandra, ,2, gatliff place, rustenburg sth, girgarre,3995, qld,19801125,3328205
+rec-1042-dup-1, kiandra, cowle,2, gatliff place, rustenubr g sth, girgarre,3995, qld,19801125,3328205
+rec-1042-org, kiandra, cowle,2, gatliff place, rustenburg sth, girgarre,3995, qld,19801125,3328205
+rec-1043-org, giorgia, frahn,62, handasyde street, ramano estate locn 1, tallebudgera,4506, vic,19670206,9724789
+rec-1044-dup-0, nicole, shadbolt,46, schlich s treet, simpson army barracks, toowoomba,3000, wa,19030926,8190756
+rec-1044-dup-2, nicole, carbone,46, schlich street, simpson arm ybarracks, toowong,3000, wa,19030926,8190756
+rec-1044-org, nicole, carbone,46, schlich street, simpson army barracks, toowoomba,3000, wa,19030926,8190756
diff --git a/spark/pom.xml b/spark/pom.xml
index 2ea784073..5e9a89a7f 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -13,6 +13,18 @@
 		<module>client</module>
 	</modules>
 	<dependencies>
+		<dependency>
+			<groupId>org.apache.spark</groupId>
+			<artifactId>spark-connect_${scala.binary.version}</artifactId>
+			<version>${spark.version}</version>
+			<exclusions>
+				<exclusion>
+					<groupId>com.google.guava</groupId>
+					<artifactId>guava</artifactId>
+				</exclusion>
+			</exclusions>
+			<scope>provided</scope>
+		</dependency>
 		<dependency>
 			<groupId>org.apache.spark</groupId>
 			<artifactId>spark-mllib_${scala.binary.version}</artifactId>
@@ -42,8 +54,18 @@
 		<dependency>
 			<groupId>graphframes</groupId>
 			<artifactId>graphframes</artifactId>
 			<version>${graphframes.version}</version>
+			<exclusions>
+				<exclusion>
+					<groupId>org.slf4j</groupId>
+					<artifactId>slf4j-api</artifactId>
+				</exclusion>
+			</exclusions>
 		</dependency>
+		<dependency>
+			<groupId>com.google.protobuf</groupId>
+			<artifactId>protobuf-java</artifactId>
+			<version>3.25.1</version>
+			<scope>compile</scope>
+		</dependency>
-
-
 	</dependencies>
diff --git a/test/note.txt b/test/note.txt
index 4503641df..686bb0cdc 100644
--- a/test/note.txt
+++ b/test/note.txt
@@ -5,7 +5,7 @@
 To run it:
 1. cd test/
 2. pyspark < testInfraOwnGateway.py
 (or)
-2. /opt/spark-3.2.4-bin-hadoop3.2/bin/spark-submit --jars ../common/client/target/zingg-common-client-0.4.0-SNAPSHOT.jar testInfraOwnGateway.py
+2. /opt/spark-3.2.4-bin-hadoop3.2/bin/spark-submit --jars ../common/client/target/zingg-common-client-0.4.1-SNAPSHOT.jar testInfraOwnGateway.py
 
 If faced version mismatch issue:
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 93ffcec82..0bd396f5b 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 
 # Set the paths to your JAR files and Spark binaries
-SPARK_HOME="/opt/spark-3.2.4-bin-hadoop3.2"
-PY4J_JAR="../common/client/target/zingg-common-client-0.4.0.jar"
+SPARK_HOME="/opt/spark-3.5.0-bin-hadoop3"
+PY4J_JAR="../common/client/target/zingg-common-client-0.4.1-SNAPSHOT.jar"
 
 # Run Spark with the required JAR files and your test script
 $SPARK_HOME/bin/spark-submit --jars $PY4J_JAR testInfra.py