aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-07-07 08:19:17 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-07 08:19:17 -0700
commitd73bc08d98e803889ef6215eab81d7bb0049e941 (patch)
tree2334c2725d513099299761907749e68c6a4e2284 /mllib
parentdcbd85b70f026fbc0b7e77fcc364513581007c8d (diff)
downloadspark-d73bc08d98e803889ef6215eab81d7bb0049e941.tar.gz
spark-d73bc08d98e803889ef6215eab81d7bb0049e941.tar.bz2
spark-d73bc08d98e803889ef6215eab81d7bb0049e941.zip
[SPARK-8788] [ML] Add Java unit test for PCA transformer
Add Java unit test for PCA transformer Author: Yanbo Liang <ybliang8@gmail.com> Closes #7184 from yanboliang/spark-8788 and squashes the following commits: 9d1a2af [Yanbo Liang] address comments b34451f [Yanbo Liang] Add Java unit test for PCA transformer
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java114
1 files changed, 114 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
new file mode 100644
index 0000000000..5cf43fec6f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaPCASuite implements Serializable {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaPCASuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ public static class VectorPair implements Serializable {
+ private Vector features = Vectors.dense(0.0);
+ private Vector expected = Vectors.dense(0.0);
+
+ public void setFeatures(Vector features) {
+ this.features = features;
+ }
+
+ public Vector getFeatures() {
+ return this.features;
+ }
+
+ public void setExpected(Vector expected) {
+ this.expected = expected;
+ }
+
+ public Vector getExpected() {
+ return this.expected;
+ }
+ }
+
+ @Test
+ public void testPCA() {
+ List<Vector> points = Lists.newArrayList(
+ Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ );
+ JavaRDD<Vector> dataRDD = jsc.parallelize(points, 2);
+
+ RowMatrix mat = new RowMatrix(dataRDD.rdd());
+ Matrix pc = mat.computePrincipalComponents(3);
+ JavaRDD<Vector> expected = mat.multiply(pc).rows().toJavaRDD();
+
+ JavaRDD<VectorPair> featuresExpected = dataRDD.zip(expected).map(
+ new Function<Tuple2<Vector, Vector>, VectorPair>() {
+ public VectorPair call(Tuple2<Vector, Vector> pair) {
+ VectorPair featuresExpected = new VectorPair();
+ featuresExpected.setFeatures(pair._1());
+ featuresExpected.setExpected(pair._2());
+ return featuresExpected;
+ }
+ }
+ );
+
+ DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+ PCAModel pca = new PCA()
+ .setInputCol("features")
+ .setOutputCol("pca_features")
+ .setK(3)
+ .fit(df);
+ List<Row> result = pca.transform(df).select("pca_features", "expected").toJavaRDD().collect();
+ for (Row r : result) {
+ Assert.assertEquals(r.get(1), r.get(0));
+ }
+ }
+}