author     Xiangrui Meng <meng@databricks.com>          2015-05-28 16:32:51 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-05-28 16:32:51 -0700
commit     7859ab659eecbcf2d8b9a274a4e9e4f5186a528c (patch)
tree       dc968849ed71a2aeb02a8e3d0b969785ef4607d3 /mllib
parent     3e312a5ed0154527c66eeeee0d2cc3bfce0a820e (diff)
[SPARK-7198] [MLLIB] VectorAssembler should output ML attributes
`VectorAssembler` should carry over ML attributes. For unknown attributes, we assume numeric values. This PR handles the following cases:

1. DoubleType with ML attribute: carry over
2. DoubleType without ML attribute: numeric value
3. Scalar type: numeric value
4. VectorType with all ML attributes: carry over and update names
5. VectorType with number of ML attributes: assume all numeric
6. VectorType without ML attributes: check the first row and get the number of attributes

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6452 from mengxr/SPARK-7198 and squashes the following commits:

a9d2469 [Xiangrui Meng] add space
facdb1f [Xiangrui Meng] VectorAssembler should output ML attributes
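
As a quick illustration of the new behavior (a minimal sketch, not part of this patch; it assumes a Spark 1.4-era SQLContext and a DataFrame `df` with a numeric "hour" column and a vector "user" column carrying attributes, as in the test added below):

    import org.apache.spark.ml.attribute.AttributeGroup
    import org.apache.spark.ml.feature.VectorAssembler

    // Assemble a scalar column and a vector column into a single "features" column.
    val assembler = new VectorAssembler()
      .setInputCols(Array("hour", "user"))
      .setOutputCol("features")
    val output = assembler.transform(df)

    // With this patch the output column carries ML attributes, so per-slot names
    // such as "user_salary" can be recovered from the schema metadata.
    val features = AttributeGroup.fromStructField(output.schema("features"))
    features.attributes.foreach(_.foreach(attr => println(attr.name.getOrElse("(unnamed)"))))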
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala       51
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala  37
2 files changed, 83 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 514ffb03c0..229ee27ec5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
@@ -37,7 +38,7 @@ import org.apache.spark.sql.types._
class VectorAssembler(override val uid: String)
  extends Transformer with HasInputCols with HasOutputCol {

-  def this() = this(Identifiable.randomUID("va"))
+  def this() = this(Identifiable.randomUID("vecAssembler"))

  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
@@ -46,19 +47,59 @@ class VectorAssembler(override val uid: String)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
+    // Schema transformation.
+    val schema = dataset.schema
+    lazy val first = dataset.first()
+    val attrs = $(inputCols).flatMap { c =>
+      val field = schema(c)
+      val index = schema.fieldIndex(c)
+      field.dataType match {
+        case DoubleType =>
+          val attr = Attribute.fromStructField(field)
+          // If the input column doesn't have ML attribute, assume numeric.
+          if (attr == UnresolvedAttribute) {
+            Some(NumericAttribute.defaultAttr.withName(c))
+          } else {
+            Some(attr.withName(c))
+          }
+        case _: NumericType | BooleanType =>
+          // If the input column type is a compatible scalar type, assume numeric.
+          Some(NumericAttribute.defaultAttr.withName(c))
+        case _: VectorUDT =>
+          val group = AttributeGroup.fromStructField(field)
+          if (group.attributes.isDefined) {
+            // If attributes are defined, copy them with updated names.
+            group.attributes.get.map { attr =>
+              if (attr.name.isDefined) {
+                // TODO: Define a rigorous naming scheme.
+                attr.withName(c + "_" + attr.name.get)
+              } else {
+                attr
+              }
+            }
+          } else {
+            // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
+            // from metadata, check the first row.
+            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
+            Array.fill(numAttrs)(NumericAttribute.defaultAttr)
+          }
+      }
+    }
+    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+
+    // Data transformation.
    val assembleFunc = udf { r: Row =>
      VectorAssembler.assemble(r.toSeq: _*)
    }
-    val schema = dataset.schema
-    val inputColNames = $(inputCols)
-    val args = inputColNames.map { c =>
+    val args = $(inputCols).map { c =>
      schema(c).dataType match {
        case DoubleType => dataset(c)
        case _: VectorUDT => dataset(c)
        case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
      }
    }
-    dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol)))
+
+    dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index d0cd62c5e4..43534e8992 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.ml.feature
import org.scalatest.FunSuite
import org.apache.spark.SparkException
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
@@ -61,4 +63,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
        assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
    }
  }
+
+  test("ML attributes") {
+    val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
+    val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
+    val user = new AttributeGroup("user", Array(
+      NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
+      NumericAttribute.defaultAttr.withName("salary")))
+    val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
+    val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
+      .select(
+        col("browser").as("browser", browser.toMetadata()),
+        col("hour").as("hour", hour.toMetadata()),
+        col("count"), // "count" is an integer column without ML attribute
+        col("user").as("user", user.toMetadata()),
+        col("ad")) // "ad" is a vector column without ML attribute
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("browser", "hour", "count", "user", "ad"))
+      .setOutputCol("features")
+    val output = assembler.transform(df)
+    val schema = output.schema
+    val features = AttributeGroup.fromStructField(schema("features"))
+    assert(features.size === 7)
+    val browserOut = features.getAttr(0)
+    assert(browserOut === browser.withIndex(0).withName("browser"))
+    val hourOut = features.getAttr(1)
+    assert(hourOut === hour.withIndex(1).withName("hour"))
+    val countOut = features.getAttr(2)
+    assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
+    val userGenderOut = features.getAttr(3)
+    assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
+    val userSalaryOut = features.getAttr(4)
+    assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
+    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
+    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
+  }
}
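
The "count" and "ad" assertions in the test above exercise the fallback paths (cases 3 and 6 in the commit message): columns without ML attributes come out as default numeric attributes, and the length of a metadata-free vector column is read from the first row. A minimal sketch of the same fallback outside the test fixture (hypothetical column names, assuming the suite's sqlContext):

    import org.apache.spark.ml.attribute.AttributeGroup
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.mllib.linalg.Vectors

    // Neither column carries ML attributes: "n" is a plain integer, "vec" a bare vector.
    val plain = sqlContext.createDataFrame(Seq((1, Vectors.dense(2.0, 3.0)))).toDF("n", "vec")
    val out = new VectorAssembler()
      .setInputCols(Array("n", "vec"))
      .setOutputCol("features")
      .transform(plain)
    // Default numeric attributes are assumed; the vector length (2) comes from the
    // first row because no metadata records the number of attributes.
    assert(AttributeGroup.fromStructField(out.schema("features")).size === 3)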