aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Xusen Yin <yinxusen@gmail.com> 2015-08-05 17:07:55 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-08-05 17:07:55 -0700
commit a018b85716fd510ae95a3c66d676bbdb90f8d4e7 (patch)
tree 9e03081c9fb8947e1798cc7a285d68e8caf14ce9 /mllib/src/test
parent 9c878923db6634effed98c99bf24dd263bb7c6ad (diff)
download spark-a018b85716fd510ae95a3c66d676bbdb90f8d4e7.tar.gz
spark-a018b85716fd510ae95a3c66d676bbdb90f8d4e7.tar.bz2
spark-a018b85716fd510ae95a3c66d676bbdb90f8d4e7.zip
[SPARK-5895] [ML] Add VectorSlicer - updated
Add VectorSlicer transformer to spark.ml, with features specified as either indices or names. Transfers feature attributes for selected features.

Updated version of [https://github.com/apache/spark/pull/5731]

CC: yinxusen This updates your PR. You'll still be the primary author of this PR.

CC: mengxr

Author: Xusen Yin <yinxusen@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits:

b16e86e [Joseph K. Bradley] fixed scala style
71c65d2 [Joseph K. Bradley] fix import order
86e9739 [Joseph K. Bradley] cleanups per code review
9d8d6f1 [Joseph K. Bradley] style fix
83bc2e9 [Joseph K. Bradley] Updated VectorSlicer
98c6939 [Xusen Yin] fix style error
ecbf2d3 [Xusen Yin] change interfaces and params
f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895
e4781f2 [Xusen Yin] fix commit error
fd154d7 [Xusen Yin] add test suite of vector slicer
17171f8 [Xusen Yin] fix slicer
9ab9747 [Xusen Yin] add vector slicer
aa5a0bf [Xusen Yin] add vector slicer
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala | 109
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 7
2 files changed, 116 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
new file mode 100644
index 0000000000..a6c2fba836
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Test suite for [[VectorSlicer]]. Covers the default parameter state, the index/name
+ * validity helpers on the companion object, and slicing a vector column when the features
+ * are selected by indices, by names, and by a mix of both.
+ */
+class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ // A freshly constructed slicer is expected to have no indices and no names selected,
+ // and validateParams() is expected to reject that state with IllegalArgumentException.
+ test("params") {
+ val slicer = new VectorSlicer
+ ParamsSuite.checkParams(slicer)
+ assert(slicer.getIndices.length === 0)
+ assert(slicer.getNames.length === 0)
+ withClue("VectorSlicer should not have any features selected by default") {
+ intercept[IllegalArgumentException] {
+ slicer.validateParams()
+ }
+ }
+ }
+
+ // Pins the expected behavior of the validIndices / validNames helpers imported from
+ // the VectorSlicer companion object.
+ test("feature validity checks") {
+ import VectorSlicer._
+ // Expected valid: unsorted but distinct non-negative indices, and the empty array.
+ assert(validIndices(Array(0, 1, 8, 2)))
+ assert(validIndices(Array.empty[Int]))
+ // Expected invalid: a negative index, and a repeated index.
+ assert(!validIndices(Array(-1)))
+ assert(!validIndices(Array(1, 2, 1)))
+
+ // Expected valid: distinct non-empty names, and the empty array.
+ assert(validNames(Array("a", "b")))
+ assert(validNames(Array.empty[String]))
+ // Expected invalid: an empty-string name, and a repeated name.
+ assert(!validNames(Array("", "b")))
+ assert(!validNames(Array("a", "b", "a")))
+ }
+
+ // End-to-end check: slices features 1 and 4 out of 5-element vectors and compares both
+ // the resulting values and the resulting attribute metadata against hand-written
+ // expectations.
+ test("Test vector slicer") {
+ val sqlContext = new SQLContext(sc)
+
+ // Input rows: a mix of sparse and dense 5-element vectors, including all-zero ones.
+ val data = Array(
+ Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
+ Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
+ Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
+ Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
+ Vectors.sparse(5, Seq())
+ )
+
+ // Expected output after selecting indices 1 and 4 from the row at the same position
+ // in data; each expected vector is written in the same sparse/dense form as its input.
+ val expected = Array(
+ Vectors.sparse(2, Seq((0, 2.3))),
+ Vectors.dense(2.3, 1.0),
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(-1.1, 3.3),
+ Vectors.sparse(2, Seq())
+ )
+
+ // Attribute metadata for the input column: five attributes named f0..f4, so that
+ // features can be selected by name as well as by index.
+ val defaultAttr = NumericAttribute.defaultAttr
+ val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
+ val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
+
+ // Attribute metadata the sliced column is expected to carry: only f1 and f4.
+ val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
+ val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
+
+ // Each Row pairs an input vector with its expected slice. The schema is built from the
+ // two attribute groups; presumably toStructField() names each column after its group
+ // ("features", "expected"), which the select on "expected" below relies on.
+ val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
+ val df = sqlContext.createDataFrame(rdd,
+ StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
+
+ val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
+
+ // Asserts that, for the transformed DataFrame, the "result" column equals the
+ // "expected" column row by row, and that both columns carry the same number of
+ // attributes with pairwise-equal attributes.
+ def validateResults(df: DataFrame): Unit = {
+ df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
+ assert(vec1 === vec2)
+ }
+ val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
+ val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+ assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
+ resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
+ assert(a === b)
+ }
+ }
+
+ // Selection by indices only.
+ vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
+ validateResults(vectorSlicer.transform(df))
+
+ // Mixed selection: index 1 plus name "f4"; expected to match the indices-only result.
+ vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
+ validateResults(vectorSlicer.transform(df))
+
+ // Selection by names only; expected to match the indices-only result.
+ vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
+ validateResults(vectorSlicer.transform(df))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 1c37ea5123..6508ddeba4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging {
val sv1c = sv1.compressed.asInstanceOf[DenseVector]
assert(sv1 === sv1c)
}
+
+ // Pins the expected result of SparseVector.slice: the output has one position per selected
+ // index, in the order the indices were given, and only nonzero selections are stored.
+ test("SparseVector.slice") {
+ // Size-5 vector with stored values at indices 1, 2 and 4; indices 0 and 3 are implicit zeros.
+ val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
+ // Selecting (0, 2): index 0 is zero, index 2 holds 2.2, so 2.2 lands at output position 1.
+ assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
+ // Same indices in reverse order: 2.2 is now expected at output position 0.
+ assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
+ // Selecting (2, 0, 3, 4): only positions 0 (value 2.2) and 3 (value 4.4) are nonzero.
+ assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
+ }
}