author     Holden Karau <holden@pigscanfly.ca>        2015-09-23 22:49:08 -0700
committer  Xiangrui Meng <meng@databricks.com>        2015-09-23 22:49:08 -0700
commit     d91967e159f416924bbd7f0db25156588d4bd7b1 (patch)
tree       e94afae9b844bd0a6411786eef830fa4f18ab118 /mllib
parent     758c9d25e92417f8c06328c3af7ea2ef0212c79f (diff)
[SPARK-10763] [ML] [JAVA] [TEST] Update Java MLLIB/ML tests to use simplified dataframe construction
As introduced in https://issues.apache.org/jira/browse/SPARK-10630, we now have an easier way to create DataFrames from local Java lists. Let's update the tests to use it.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8886 from holdenk/SPARK-10763-update-java-mllib-ml-tests-to-use-simplified-dataframe-construction.
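For reference, a minimal self-contained sketch of the before/after pattern this commit applies across the test suites. The class name, local master setting, and one-column schema here are illustrative only; the two createDataFrame overloads are the ones the diff actually swaps between.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class CreateDataFrameSketch {
  public static void main(String[] args) {
    // Illustrative context, mirroring the jsc/jsql fields the suites set up.
    JavaSparkContext jsc = new JavaSparkContext("local", "sketch");
    SQLContext jsql = new SQLContext(jsc);

    StructType schema = new StructType(new StructField[]{
      new StructField("label", DataTypes.DoubleType, false, Metadata.empty())
    });

    // Before: local data took a detour through a JavaRDD.
    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(RowFactory.create(0.0)));
    DataFrame viaRdd = jsql.createDataFrame(jrdd, schema);

    // After SPARK-10630: pass the local java.util.List<Row> directly.
    List<Row> data = Arrays.asList(RowFactory.create(0.0));
    DataFrame viaList = jsql.createDataFrame(data, schema);

    viaRdd.show();
    viaList.show();
    jsc.stop();
  }
}

Besides being shorter, the list-based overload drops an unneeded parallelize/JavaRDD round-trip from tests that only ever hold a handful of local rows.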
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java     |  8
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java            | 14
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java                   | 11
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java             |  7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java   |  5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java      |  7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java         |  7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java       |  3
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java          |  7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java              | 12
10 files changed, 42 insertions, 39 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 075a62c493..f5f690eabd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
@@ -75,21 +76,20 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void testNaiveBayes() {
- JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+ List<Row> data = Arrays.asList(
RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
- RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
- ));
+ RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)));
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(jrdd, schema);
+ DataFrame dataset = jsql.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 47d68de599..8a1e5ef015 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -55,16 +55,16 @@ public class JavaBucketizerSuite {
public void bucketizerTest() {
double[] splits = {-0.5, 0.0, 0.5};
- JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
- RowFactory.create(-0.5),
- RowFactory.create(-0.3),
- RowFactory.create(0.0),
- RowFactory.create(0.2)
- ));
StructType schema = new StructType(new StructField[] {
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ DataFrame dataset = jsql.createDataFrame(
+ Arrays.asList(
+ RowFactory.create(-0.5),
+ RowFactory.create(-0.3),
+ RowFactory.create(0.0),
+ RowFactory.create(0.2)),
+ schema);
Bucketizer bucketizer = new Bucketizer()
.setInputCol("feature")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 0f6ec64d97..39da47381b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -57,12 +57,11 @@ public class JavaDCTSuite {
@Test
public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D};
- JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
- RowFactory.create(Vectors.dense(input))
- ));
- DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
- new StructField("vec", (new VectorUDT()), false, Metadata.empty())
- }));
+ DataFrame dataset = jsql.createDataFrame(
+ Arrays.asList(RowFactory.create(Vectors.dense(input))),
+ new StructType(new StructField[]{
+ new StructField("vec", (new VectorUDT()), false, Metadata.empty())
+ }));
double[] expectedResult = input.clone();
(new DoubleDCT_1D(input.length)).forward(expectedResult, true);
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 03dd5369bd..d12332c2a0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -55,17 +56,17 @@ public class JavaHashingTFSuite {
@Test
public void hashingTF() {
- JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+ List<Row> data = Arrays.asList(
RowFactory.create(0.0, "Hi I heard about Spark"),
RowFactory.create(0.0, "I wish Java could use case classes"),
RowFactory.create(1.0, "Logistic regression models are neat")
- ));
+ );
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
+ DataFrame sentenceData = jsql.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index 834fedbb59..bf8eefd719 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -60,7 +61,7 @@ public class JavaPolynomialExpansionSuite {
.setOutputCol("polyFeatures")
.setDegree(3);
- JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+ List<Row> data = Arrays.asList(
RowFactory.create(
Vectors.dense(-2.0, 2.3),
Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
@@ -70,7 +71,7 @@ public class JavaPolynomialExpansionSuite {
Vectors.dense(0.6, -1.1),
Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331)
)
- ));
+ );
StructType schema = new StructType(new StructField[] {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 76cdd0fae8..848d9f8aa9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
@@ -58,14 +59,14 @@ public class JavaStopWordsRemoverSuite {
.setInputCol("raw")
.setOutputCol("filtered");
- JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
+ List<Row> data = Arrays.asList(
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
- ));
+ );
StructType schema = new StructType(new StructField[] {
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(rdd, schema);
+ DataFrame dataset = jsql.createDataFrame(data, schema);
remover.transform(dataset).collect();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 35b18c5308..6b2c48ef1c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -56,9 +57,9 @@ public class JavaStringIndexerSuite {
createStructField("id", IntegerType, false),
createStructField("label", StringType, false)
});
- JavaRDD<Row> rdd = jsc.parallelize(
- Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
- DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ List<Row> data = Arrays.asList(
+ c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"));
+ DataFrame dataset = sqlContext.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index b7c564caad..e283777570 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -65,8 +65,7 @@ public class JavaVectorAssemblerSuite {
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
- JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
- DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"})
.setOutputCol("features");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index f953361427..00174e6a68 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -63,12 +64,12 @@ public class JavaVectorSlicerSuite {
};
AttributeGroup group = new AttributeGroup("userFeatures", attrs);
- JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+ List<Row> data = Arrays.asList(
RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
- ));
+ );
- DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+ DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index 70f5ad9432..0c0c1c4d12 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -51,15 +51,15 @@ public class JavaWord2VecSuite {
@Test
public void testJavaWord2Vec() {
- JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
- RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
- RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
- RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))
- ));
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
+ DataFrame documentDF = sqlContext.createDataFrame(
+ Arrays.asList(
+ RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
+ RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
+ RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))),
+ schema);
Word2Vec word2Vec = new Word2Vec()
.setInputCol("text")