aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2015-08-05 17:07:55 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-05 17:07:55 -0700
commita018b85716fd510ae95a3c66d676bbdb90f8d4e7 (patch)
tree9e03081c9fb8947e1798cc7a285d68e8caf14ce9 /mllib/src
parent9c878923db6634effed98c99bf24dd263bb7c6ad (diff)
downloadspark-a018b85716fd510ae95a3c66d676bbdb90f8d4e7.tar.gz
spark-a018b85716fd510ae95a3c66d676bbdb90f8d4e7.tar.bz2
spark-a018b85716fd510ae95a3c66d676bbdb90f8d4e7.zip
[SPARK-5895] [ML] Add VectorSlicer - updated
Add VectorSlicer transformer to spark.ml, with features specified as either indices or names. Transfers feature attributes for selected features. Updated version of [https://github.com/apache/spark/pull/5731] CC: yinxusen This updates your PR. You'll still be the primary author of this PR. CC: mengxr Author: Xusen Yin <yinxusen@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits: b16e86e [Joseph K. Bradley] fixed scala style 71c65d2 [Joseph K. Bradley] fix import order 86e9739 [Joseph K. Bradley] cleanups per code review 9d8d6f1 [Joseph K. Bradley] style fix 83bc2e9 [Joseph K. Bradley] Updated VectorSlicer 98c6939 [Xusen Yin] fix style error ecbf2d3 [Xusen Yin] change interfaces and params f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895 e4781f2 [Xusen Yin] fix commit error fd154d7 [Xusen Yin] add test suite of vector slicer 17171f8 [Xusen Yin] fix slicer 9ab9747 [Xusen Yin] add vector slicer aa5a0bf [Xusen Yin] add vector slicer
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala170
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala24
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala109
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala7
5 files changed, 327 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
new file mode 100644
index 0000000000..772bebeff2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * This class takes a feature vector and outputs a new feature vector with a subarray of the
+ * original features.
+ *
+ * The subset of features can be specified with either indices ([[setIndices()]])
+ * or names ([[setNames()]]). At least one feature must be selected. Duplicate features
+ * are not allowed, so there can be no overlap between selected indices and names.
+ *
+ * The output vector will order features with the selected indices first (in the order given),
+ * followed by the selected names (in the order given).
+ */
+@Experimental
+final class VectorSlicer(override val uid: String)
+ extends Transformer with HasInputCol with HasOutputCol {
+
+ def this() = this(Identifiable.randomUID("vectorSlicer"))
+
+ /**
+ * An array of indices to select features from a vector column.
+ * There can be no overlap with [[names]].
+ * @group param
+ */
+ val indices = new IntArrayParam(this, "indices",
+ "An array of indices to select features from a vector column." +
+ " There can be no overlap with names.", VectorSlicer.validIndices)
+
+ setDefault(indices -> Array.empty[Int])
+
+ /** @group getParam */
+ def getIndices: Array[Int] = $(indices)
+
+ /** @group setParam */
+ def setIndices(value: Array[Int]): this.type = set(indices, value)
+
+ /**
+ * An array of feature names to select features from a vector column.
+ * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
+ * There can be no overlap with [[indices]].
+ * @group param
+ */
+ val names = new StringArrayParam(this, "names",
+ "An array of feature names to select features from a vector column." +
+ " There can be no overlap with indices.", VectorSlicer.validNames)
+
+ setDefault(names -> Array.empty[String])
+
+ /** @group getParam */
+ def getNames: Array[String] = $(names)
+
+ /** @group setParam */
+ def setNames(value: Array[String]): this.type = set(names, value)
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def validateParams(): Unit = {
+ require($(indices).length > 0 || $(names).length > 0,
+ s"VectorSlicer requires that at least one feature be selected.")
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ // Validity checks
+ transformSchema(dataset.schema)
+ val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
+ inputAttr.numAttributes.foreach { numFeatures =>
+ val maxIndex = $(indices).max
+ require(maxIndex < numFeatures,
+ s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
+ }
+
+ // Prepare output attributes
+ val inds = getSelectedFeatureIndices(dataset.schema)
+ val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
+ inds.map(index => attrs(index))
+ }
+ val outputAttr = selectedAttrs match {
+ case Some(attrs) => new AttributeGroup($(outputCol), attrs)
+ case None => new AttributeGroup($(outputCol), inds.length)
+ }
+
+ // Select features
+ val slicer = udf { vec: Vector =>
+ vec match {
+ case features: DenseVector => Vectors.dense(inds.map(features.apply))
+ case features: SparseVector => features.slice(inds)
+ }
+ }
+ dataset.withColumn($(outputCol),
+ slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+ }
+
+ /** Get the feature indices in order: indices, names */
+ private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
+ val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
+ val indFeatures = $(indices)
+ val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
+ lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
+ s" sets of features, but they overlap." +
+ s" indices: ${indFeatures.mkString("[", ",", "]")}." +
+ s" names: " +
+ nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
+ require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
+ indFeatures ++ nameFeatures
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+
+ if (schema.fieldNames.contains($(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
+ }
+ val numFeaturesSelected = $(indices).length + $(names).length
+ val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
+ val outputFields = schema.fields :+ outputAttr.toStructField()
+ StructType(outputFields)
+ }
+
+ override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
+}
+
+private[feature] object VectorSlicer {
+
+ /** Return true if given feature indices are valid */
+ def validIndices(indices: Array[Int]): Boolean = {
+ if (indices.isEmpty) {
+ true
+ } else {
+ indices.length == indices.distinct.length && indices.forall(_ >= 0)
+ }
+ }
+
+ /** Return true if given feature names are valid */
+ def validNames(names: Array[String]): Boolean = {
+ names.forall(_.nonEmpty) && names.length == names.distinct.length
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index 2a1db90f2c..fcb517b5f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.util
import scala.collection.immutable.HashMap
import org.apache.spark.ml.attribute._
+import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.StructField
@@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
}
}
+ /**
+ * Takes a Vector column and a list of feature names, and returns the corresponding list of
+ * feature indices in the column, in order.
+ * @param col Vector column which must have feature names specified via attributes
+ * @param names List of feature names
+ */
+ def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
+ require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
+ + s" to be Vector type, but it was type ${col.dataType} instead.")
+ val inputAttr = AttributeGroup.fromStructField(col)
+ names.map { name =>
+ require(inputAttr.hasAttr(name),
+ s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
+ inputAttr.getAttr(name).index.get
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 96d1f48ba2..86c461fa91 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -766,6 +766,30 @@ class SparseVector(
maxIdx
}
}
+
+ /**
+ * Create a slice of this vector based on the given indices.
+ * @param selectedIndices Unsorted list of indices into the vector.
+ * This does NOT do bound checking.
+ * @return New SparseVector with values in the order specified by the given indices.
+ *
+ * NOTE: The API needs to be discussed before making this public.
+ * Also, if we have a version assuming indices are sorted, we should optimize it.
+ */
+ private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
+ var currentIdx = 0
+ val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
+ val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
+ val i_v = if (iIdx >= 0) {
+ Iterator((currentIdx, this.values(iIdx)))
+ } else {
+ Iterator()
+ }
+ currentIdx += 1
+ i_v
+ }.unzip
+ new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
+ }
}
object SparseVector {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
new file mode 100644
index 0000000000..a6c2fba836
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("params") {
+ val slicer = new VectorSlicer
+ ParamsSuite.checkParams(slicer)
+ assert(slicer.getIndices.length === 0)
+ assert(slicer.getNames.length === 0)
+ withClue("VectorSlicer should not have any features selected by default") {
+ intercept[IllegalArgumentException] {
+ slicer.validateParams()
+ }
+ }
+ }
+
+ test("feature validity checks") {
+ import VectorSlicer._
+ assert(validIndices(Array(0, 1, 8, 2)))
+ assert(validIndices(Array.empty[Int]))
+ assert(!validIndices(Array(-1)))
+ assert(!validIndices(Array(1, 2, 1)))
+
+ assert(validNames(Array("a", "b")))
+ assert(validNames(Array.empty[String]))
+ assert(!validNames(Array("", "b")))
+ assert(!validNames(Array("a", "b", "a")))
+ }
+
+ test("Test vector slicer") {
+ val sqlContext = new SQLContext(sc)
+
+ val data = Array(
+ Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
+ Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
+ Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
+ Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
+ Vectors.sparse(5, Seq())
+ )
+
+ // Expected after selecting indices 1, 4
+ val expected = Array(
+ Vectors.sparse(2, Seq((0, 2.3))),
+ Vectors.dense(2.3, 1.0),
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(-1.1, 3.3),
+ Vectors.sparse(2, Seq())
+ )
+
+ val defaultAttr = NumericAttribute.defaultAttr
+ val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
+ val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
+
+ val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
+ val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
+
+ val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
+ val df = sqlContext.createDataFrame(rdd,
+ StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
+
+ val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
+
+ def validateResults(df: DataFrame): Unit = {
+ df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
+ assert(vec1 === vec2)
+ }
+ val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
+ val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+ assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
+ resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
+ assert(a === b)
+ }
+ }
+
+ vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
+ validateResults(vectorSlicer.transform(df))
+
+ vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
+ validateResults(vectorSlicer.transform(df))
+
+ vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
+ validateResults(vectorSlicer.transform(df))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 1c37ea5123..6508ddeba4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging {
val sv1c = sv1.compressed.asInstanceOf[DenseVector]
assert(sv1 === sv1c)
}
+
+ test("SparseVector.slice") {
+ val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
+ assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
+ assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
+ assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
+ }
}