diff options
author | Xusen Yin <yinxusen@gmail.com> | 2015-08-21 16:30:12 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-08-21 16:30:19 -0700 |
commit | fbf9a6ed55e99c986a1283fb20dd719a397231f7 (patch) | |
tree | 726e5f74bce4365cec2284da45628680737d584e /mllib/src | |
parent | cb61c7b4e3d06efdbb61795603c71b3a989d6df1 (diff) | |
download | spark-fbf9a6ed55e99c986a1283fb20dd719a397231f7.tar.gz spark-fbf9a6ed55e99c986a1283fb20dd719a397231f7.tar.bz2 spark-fbf9a6ed55e99c986a1283fb20dd719a397231f7.zip |
[SPARK-9893] User guide with Java test suite for VectorSlicer
Add user guide for `VectorSlicer`, with Java test suite and Python version VectorSlicer.
Note that Python version does not support selecting by names now.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #8267 from yinxusen/SPARK-9893.
(cherry picked from commit 630a994e6a9785d1704f8e7fb604f32f5dea24f8)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java new file mode 100644 index 0000000000..56988b9fb2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructType; + + +public class JavaVectorSlicerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void vectorSlice() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + + DataFrame output = vectorSlicer.transform(dataset); + + for (Row r : output.select("userFeatures", "features").take(2)) { + Vector features = r.getAs(1); + Assert.assertEquals(features.size(), 2); + } + } +} |