aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2015-08-21 16:30:12 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-21 16:30:19 -0700
commitfbf9a6ed55e99c986a1283fb20dd719a397231f7 (patch)
tree726e5f74bce4365cec2284da45628680737d584e /mllib/src
parentcb61c7b4e3d06efdbb61795603c71b3a989d6df1 (diff)
downloadspark-fbf9a6ed55e99c986a1283fb20dd719a397231f7.tar.gz
spark-fbf9a6ed55e99c986a1283fb20dd719a397231f7.tar.bz2
spark-fbf9a6ed55e99c986a1283fb20dd719a397231f7.zip
[SPARK-9893] User guide with Java test suite for VectorSlicer
Add user guide for `VectorSlicer`, with Java test suite and Python version VectorSlicer. Note that Python version does not support selecting by names now. Author: Xusen Yin <yinxusen@gmail.com> Closes #8267 from yinxusen/SPARK-9893. (cherry picked from commit 630a994e6a9785d1704f8e7fb604f32f5dea24f8) Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java85
1 files changed, 85 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
new file mode 100644
index 0000000000..56988b9fb2
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.attribute.Attribute;
+import org.apache.spark.ml.attribute.AttributeGroup;
+import org.apache.spark.ml.attribute.NumericAttribute;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructType;
+
+
+public class JavaVectorSlicerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void vectorSlice() {
+ Attribute[] attrs = new Attribute[]{
+ NumericAttribute.defaultAttr().withName("f1"),
+ NumericAttribute.defaultAttr().withName("f2"),
+ NumericAttribute.defaultAttr().withName("f3")
+ };
+ AttributeGroup group = new AttributeGroup("userFeatures", attrs);
+
+ JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+ RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
+ RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
+ ));
+
+ DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+
+ VectorSlicer vectorSlicer = new VectorSlicer()
+ .setInputCol("userFeatures").setOutputCol("features");
+
+ vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
+
+ DataFrame output = vectorSlicer.transform(dataset);
+
+ for (Row r : output.select("userFeatures", "features").take(2)) {
+ Vector features = r.getAs(1);
+ Assert.assertEquals(features.size(), 2);
+ }
+ }
+}