From d91967e159f416924bbd7f0db25156588d4bd7b1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 23 Sep 2015 22:49:08 -0700 Subject: [SPARK-10763] [ML] [JAVA] [TEST] Update Java MLLIB/ML tests to use simplified dataframe construction As introduced in https://issues.apache.org/jira/browse/SPARK-10630 we now have an easier way to create dataframes from local Java lists. Lets update the tests to use those. Author: Holden Karau Closes #8886 from holdenk/SPARK-10763-update-java-mllib-ml-tests-to-use-simplified-dataframe-construction. --- .../spark/ml/classification/JavaNaiveBayesSuite.java | 8 ++++---- .../org/apache/spark/ml/feature/JavaBucketizerSuite.java | 14 +++++++------- .../java/org/apache/spark/ml/feature/JavaDCTSuite.java | 11 +++++------ .../org/apache/spark/ml/feature/JavaHashingTFSuite.java | 7 ++++--- .../spark/ml/feature/JavaPolynomialExpansionSuite.java | 5 +++-- .../apache/spark/ml/feature/JavaStopWordsRemoverSuite.java | 7 ++++--- .../apache/spark/ml/feature/JavaStringIndexerSuite.java | 7 ++++--- .../apache/spark/ml/feature/JavaVectorAssemblerSuite.java | 3 +-- .../org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 7 ++++--- .../org/apache/spark/ml/feature/JavaWord2VecSuite.java | 12 ++++++------ 10 files changed, 42 insertions(+), 39 deletions(-) (limited to 'mllib') diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 075a62c493..f5f690eabd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Before; @@ -75,21 +76,20 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void testNaiveBayes() { - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)), RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)), - RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)) - )); + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(jrdd, schema); + DataFrame dataset = jsql.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 47d68de599..8a1e5ef015 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -55,16 +55,16 @@ public class JavaBucketizerSuite { public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) - )); StructType schema = new StructType(new StructField[] { new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + DataFrame dataset = jsql.createDataFrame( + Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2)), + schema); Bucketizer bucketizer = new Bucketizer() .setInputCol("feature") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 0f6ec64d97..39da47381b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -57,12 +57,11 @@ public class JavaDCTSuite { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 2D, 3D, 4D}; - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(input)) - )); - DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ - new StructField("vec", (new VectorUDT()), false, Metadata.empty()) - })); + DataFrame dataset = jsql.createDataFrame( + Arrays.asList(RowFactory.create(Vectors.dense(input))), + new StructType(new StructField[]{ + new StructField("vec", (new VectorUDT()), false, Metadata.empty()) + })); double[] expectedResult = input.clone(); (new DoubleDCT_1D(input.length)).forward(expectedResult, true); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 03dd5369bd..d12332c2a0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -55,17 +56,17 @@ public class JavaHashingTFSuite { @Test public void hashingTF() { - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0.0, "Hi I heard about Spark"), RowFactory.create(0.0, "I wish Java could use case classes"), RowFactory.create(1.0, "Logistic regression models are neat") - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + DataFrame sentenceData = jsql.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 834fedbb59..bf8eefd719 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -60,7 +61,7 @@ public class JavaPolynomialExpansionSuite { .setOutputCol("polyFeatures") .setDegree(3); - JavaRDD data = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create( Vectors.dense(-2.0, 2.3), Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) @@ -70,7 +71,7 @@ public class JavaPolynomialExpansionSuite { Vectors.dense(0.6, -1.1), Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331) ) - )); + ); StructType schema = new StructType(new StructField[] { new StructField("features", new VectorUDT(), false, Metadata.empty()), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index 76cdd0fae8..848d9f8aa9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Before; @@ -58,14 +59,14 @@ public class JavaStopWordsRemoverSuite { .setInputCol("raw") .setOutputCol("filtered"); - JavaRDD rdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) - )); + ); StructType schema = new StructType(new StructField[] { new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(rdd, schema); + DataFrame dataset = jsql.createDataFrame(data, schema); remover.transform(dataset).collect(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 35b18c5308..6b2c48ef1c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -56,9 +57,9 @@ public class JavaStringIndexerSuite { createStructField("id", IntegerType, false), createStructField("label", StringType, false) }); - JavaRDD rdd = jsc.parallelize( - Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + List data = Arrays.asList( + c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); + DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index b7c564caad..e283777570 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -65,8 +65,7 @@ public class JavaVectorAssemblerSuite { Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[] {"x", "y", "z", "n"}) .setOutputCol("features"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index f953361427..00174e6a68 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -63,12 +64,12 @@ public class JavaVectorSlicerSuite { }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) - )); + ); - DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 70f5ad9432..0c0c1c4d12 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -51,15 +51,15 @@ public class JavaWord2VecSuite { @Test public void testJavaWord2Vec() { - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), - RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), - RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) - )); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + DataFrame documentDF = sqlContext.createDataFrame( + Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))), + schema); Word2Vec word2Vec = new Word2Vec() .setInputCol("text") -- cgit v1.2.3