aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Eric Liang <ekl@databricks.com> 2015-09-17 14:09:06 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-09-17 14:09:06 -0700
commit: 4fbf3328692e876f39ea78494510f9d9c5a53f15 (patch)
tree: 7e9c0c83edd393a905453e3fbae8e8c87d8b41f3 /mllib
parent: f1c911552cf5d0d60831c79c1881016293aec66c (diff)
download: spark-4fbf3328692e876f39ea78494510f9d9c5a53f15.tar.gz
spark-4fbf3328692e876f39ea78494510f9d9c5a53f15.tar.bz2
spark-4fbf3328692e876f39ea78494510f9d9c5a53f15.zip
[SPARK-9698] [ML] Add RInteraction transformer for supporting R-style feature interactions
This is a pre-req for supporting the ":" operator in the RFormula feature transformer. Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang <ekl@databricks.com> Closes #7987 from ericl/interaction.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala 278
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala 165
2 files changed, 443 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
new file mode 100644
index 0000000000..9194763fb3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.Transformer
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Implements the feature interaction transform. This transformer takes in Double and Vector type
+ * columns and outputs a flattened vector of their feature interactions. To handle interaction,
+ * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is
+ * produced.
+ *
+ * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be
+ * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal
+ * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`.
+ */
+@Experimental
+class Interaction(override val uid: String) extends Transformer
+  with HasInputCols with HasOutputCol {
+
+  def this() = this(Identifiable.randomUID("interaction"))
+
+  /** @group setParam */
+  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  // optimistic schema; does not contain any ML attributes
+  override def transformSchema(schema: StructType): StructType = {
+    validateParams()
+    StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
+  }
+
+  /**
+   * Appends the interaction vector column to `dataset`.
+   *
+   * Double and Vector input columns are used as they are; other numeric columns and boolean
+   * columns are cast to double. The inputs are packed into a struct and passed to a UDF that
+   * expands the cross-product of their (one-hot encoded) values into a single vector.
+   *
+   * @param dataset The input data, containing every column named in `inputCols`.
+   * @return All columns of `dataset` plus the output vector column, which carries ML attributes.
+   * @throws SparkException if an input column has an unsupported type, or lacks the ML attributes
+   *                        needed to size its encoding.
+   */
+  override def transform(dataset: DataFrame): DataFrame = {
+    validateParams()
+    val inputFeatures = $(inputCols).map(c => dataset.schema(c))
+    val featureEncoders = getFeatureEncoders(inputFeatures)
+    val featureAttrs = getFeatureAttrs(inputFeatures)
+
+    // A `val` rather than a `def`, so the UDF is constructed exactly once.
+    val interactFunc = udf { row: Row =>
+      // Start from a single unit-valued output; every input column multiplies into it.
+      var indices = ArrayBuilder.make[Int]
+      var values = ArrayBuilder.make[Double]
+      var size = 1
+      indices += 0
+      values += 1.0
+      // Columns are folded in from last to first. Each step places the outputs accumulated so
+      // far at stride `prevSize`, so the first input column ends up varying slowest.
+      var featureIndex = row.length - 1
+      while (featureIndex >= 0) {
+        val prevIndices = indices.result()
+        val prevValues = values.result()
+        val prevSize = size
+        val currentEncoder = featureEncoders(featureIndex)
+        indices = ArrayBuilder.make[Int]
+        values = ArrayBuilder.make[Double]
+        size *= currentEncoder.outputSize
+        currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => {
+          var j = 0
+          while (j < prevIndices.length) {
+            indices += prevIndices(j) + i * prevSize
+            values += prevValues(j) * a
+            j += 1
+          }
+        })
+        featureIndex -= 1
+      }
+      Vectors.sparse(size, indices.result(), values.result()).compressed
+    }
+
+    val featureCols = inputFeatures.map { f =>
+      f.dataType match {
+        case DoubleType => dataset(f.name)
+        case _: VectorUDT => dataset(f.name)
+        case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType)
+        case _ => unsupportedType(f)
+      }
+    }
+    dataset.select(
+      col("*"),
+      interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata()))
+  }
+
+  /**
+   * Fails with a descriptive error for an input column whose type cannot be interacted, instead
+   * of letting a non-exhaustive pattern match surface as a `scala.MatchError`.
+   *
+   * @param f The offending input column.
+   */
+  private def unsupportedType(f: StructField): Nothing = {
+    throw new SparkException(
+      s"Column ${f.name} of type ${f.dataType} is not supported for interaction; " +
+        "expected a numeric, boolean or vector column.")
+  }
+
+  /**
+   * Creates a feature encoder for each input column, which supports efficient iteration over
+   * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]].
+   *
+   * @param features The input feature columns to create encoders for.
+   */
+  private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = {
+    def getNumFeatures(attr: Attribute): Int = {
+      attr match {
+        case nominal: NominalAttribute =>
+          math.max(1, nominal.getNumValues.getOrElse(
+            throw new SparkException("Nominal features must have attr numValues defined.")))
+        case _ =>
+          1  // numeric feature
+      }
+    }
+    features.map { f =>
+      val numFeatures = f.dataType match {
+        case _: NumericType | BooleanType =>
+          Array(getNumFeatures(Attribute.fromStructField(f)))
+        case _: VectorUDT =>
+          val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse(
+            throw new SparkException("Vector attributes must be defined for interaction."))
+          attrs.map(getNumFeatures).toArray
+        case _ =>
+          unsupportedType(f)
+      }
+      new FeatureEncoder(numFeatures)
+    }.toArray
+  }
+
+  /**
+   * Generates ML attributes for the output vector of all feature interactions. We make a best
+   * effort to generate reasonable names for output features, based on the concatenation of the
+   * interacting feature names and values delimited with `_`. When no feature name is specified,
+   * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction
+   * between the numeric `foo` feature and a nominal third feature from column `bar`).
+   *
+   * @param features The input feature columns to the Interaction transformer.
+   */
+  private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = {
+    var featureAttrs: Seq[Attribute] = Nil
+    // Walk the columns last-to-first so the combined names line up with the output ordering
+    // produced in `transform`, where the first input column varies slowest.
+    features.reverse.foreach { f =>
+      val encodedAttrs = f.dataType match {
+        case _: NumericType | BooleanType =>
+          val attr = Attribute.fromStructField(f)
+          encodedFeatureAttrs(Seq(attr), None)
+        case _: VectorUDT =>
+          val group = AttributeGroup.fromStructField(f)
+          val attrs = group.attributes.getOrElse(
+            throw new SparkException("Vector attributes must be defined for interaction."))
+          encodedFeatureAttrs(attrs, Some(group.name))
+        case _ =>
+          unsupportedType(f)
+      }
+      if (featureAttrs.isEmpty) {
+        featureAttrs = encodedAttrs
+      } else {
+        featureAttrs = encodedAttrs.flatMap { head =>
+          featureAttrs.map { tail =>
+            // Both names are always present: every attribute reaching this point was built
+            // with `withName`, either by `encodedFeatureAttrs` or by a previous iteration.
+            NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get)
+          }
+        }
+      }
+    }
+    new AttributeGroup($(outputCol), featureAttrs.toArray)
+  }
+
+  /**
+   * Generates the output ML attributes for a single input feature. Each output feature name has
+   * up to three parts: the group name, feature name, and category name (for nominal features),
+   * each separated by an underscore.
+   *
+   * @param inputAttrs The attributes of the input feature.
+   * @param groupName Optional name of the input feature group (for Vector type features).
+   */
+  private def encodedFeatureAttrs(
+      inputAttrs: Seq[Attribute],
+      groupName: Option[String]): Seq[Attribute] = {
+
+    def format(
+        index: Int,
+        attrName: Option[String],
+        categoryName: Option[String]): String = {
+      val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName)
+      parts.flatten.mkString("_")
+    }
+
+    inputAttrs.zipWithIndex.flatMap {
+      case (nominal: NominalAttribute, i) =>
+        // Prefer the declared category labels; otherwise label categories by their index.
+        val categories: Seq[String] = nominal.values match {
+          case Some(labels) =>
+            labels.toSeq
+          case None =>
+            val numValues = nominal.getNumValues.getOrElse(
+              throw new SparkException("Nominal features must have attr numValues defined."))
+            Seq.tabulate(numValues)(_.toString)
+        }
+        categories.map { c =>
+          BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(c)))
+        }
+      case (a: Attribute, i) =>
+        Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None)))
+    }
+  }
+
+  override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
+
+  override def validateParams(): Unit = {
+    require(get(inputCols).isDefined, "Input cols must be defined first.")
+    require(get(outputCol).isDefined, "Output col must be defined first.")
+    require($(inputCols).length > 0, "Input cols must have non-zero length.")
+    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
+  }
+}
+
+/**
+ * This class performs on-the-fly one-hot encoding of features as you iterate over them. To
+ * indicate which input features should be one-hot encoded, an array of the feature counts
+ * must be passed in ahead of time.
+ *
+ * Instances are serializable because they are captured by the UDF closure built in
+ * `Interaction.transform`.
+ *
+ * @param numFeatures Array of feature counts for each input feature. For nominal features this
+ *                    count is equal to the number of categories. For numeric features the count
+ *                    should be set to 1.
+ */
+private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable {
+  assert(numFeatures.forall(_ > 0), "Features counts must all be positive.")
+
+  /** The size of the output vector. */
+  val outputSize = numFeatures.sum
+
+  /** Precomputed offsets for the location of each output feature. */
+  private val outputOffsets = {
+    val arr = new Array[Int](numFeatures.length)
+    var i = 1
+    while (i < arr.length) {
+      arr(i) = arr(i - 1) + numFeatures(i - 1)
+      i += 1
+    }
+    arr
+  }
+
+  /**
+   * Given an input row of features, invokes the specified function for every non-zero output.
+   * Callbacks are issued in increasing order of output index.
+   *
+   * @param value The row value to encode, either a Double or Vector.
+   * @param f The callback to invoke on each non-zero (index, value) output pair.
+   * @throws SparkException if `value` is null or is neither a Double nor a Vector.
+   */
+  def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match {
+    case d: Double =>
+      assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.")
+      val numOutputCols = numFeatures.head
+      if (numOutputCols > 1) {
+        assert(
+          d >= 0.0 && d == d.toInt && d < numOutputCols,
+          s"Values from column must be indices, but got $d.")
+        f(d.toInt, 1.0)
+      } else {
+        f(0, d)
+      }
+    case vec: Vector =>
+      assert(numFeatures.length == vec.size,
+        s"Vector column size was ${vec.size}, expected ${numFeatures.length}")
+
+      // `foreachActive` only visits stored entries. In a sparse vector an unstored entry is an
+      // implicit 0.0, which for a nominal feature means "category 0" and must still produce a
+      // one-hot output. Without this, the same logical vector would encode differently in its
+      // dense and sparse forms.
+      def emitImplicitZeros(from: Int, until: Int): Unit = {
+        var k = from
+        while (k < until) {
+          if (numFeatures(k) > 1) {
+            f(outputOffsets(k), 1.0)
+          }
+          k += 1
+        }
+      }
+
+      // Index of the first position not yet handled; assumes active entries are visited in
+      // increasing index order.
+      var next = 0
+      vec.foreachActive { (i, v) =>
+        emitImplicitZeros(next, i)
+        val numOutputCols = numFeatures(i)
+        if (numOutputCols > 1) {
+          assert(
+            v >= 0.0 && v == v.toInt && v < numOutputCols,
+            s"Values from column must be indices, but got $v.")
+          f(outputOffsets(i) + v.toInt, 1.0)
+        } else {
+          f(outputOffsets(i), v)
+        }
+        next = i + 1
+      }
+      emitImplicitZeros(next, vec.size)
+    case null =>
+      throw new SparkException("Values to interact cannot be null.")
+    case o =>
+      throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
new file mode 100644
index 0000000000..2beb62ca08
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.functions.col
+
+/**
+ * Tests for the [[Interaction]] transformer and its [[FeatureEncoder]] helper: param checks,
+ * encoder output for numeric and nominal inputs, and the values and ML attribute names of the
+ * interaction column.
+ */
+class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("params") {
+    ParamsSuite.checkParams(new Interaction())
+  }
+
+  test("feature encoder") {
+    // Collects every (index, value) pair reported by the encoder into a vector.
+    def encode(cardinalities: Array[Int], value: Any): Vector = {
+      val idxBuilder = ArrayBuilder.make[Int]
+      val valBuilder = ArrayBuilder.make[Double]
+      val enc = new FeatureEncoder(cardinalities)
+      enc.foreachNonzeroOutput(value, (idx, x) => {
+        idxBuilder += idx
+        valBuilder += x
+      })
+      Vectors.sparse(enc.outputSize, idxBuilder.result(), valBuilder.result()).compressed
+    }
+
+    // Valid inputs.
+    assert(encode(Array(1), 2.2) === Vectors.dense(2.2))
+    assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0))
+    assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2))
+    assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2))
+    assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2))
+    assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0))
+
+    // Unsupported or null values.
+    intercept[SparkException] { encode(Array(1), "foo") }
+    intercept[SparkException] { encode(Array(1), null) }
+
+    // Non-index nominal values, size mismatches and out-of-range indices.
+    intercept[AssertionError] { encode(Array(2), 2.2) }
+    intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) }
+    intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) }
+    intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) }
+    intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) }
+  }
+
+  test("numeric interaction") {
+    val inputRows = Seq(
+      (2, Vectors.dense(3.0, 4.0)),
+      (1, Vectors.dense(1.0, 5.0)))
+    val raw = sqlContext.createDataFrame(inputRows).toDF("a", "b")
+    val bGroup = new AttributeGroup(
+      "b",
+      Array[Attribute](
+        NumericAttribute.defaultAttr.withName("foo"),
+        NumericAttribute.defaultAttr.withName("bar")))
+    val aMeta = NumericAttribute.defaultAttr.toMetadata()
+    val input = raw.select(
+      col("a").as("a", aMeta),
+      col("b").as("b", bGroup.toMetadata()))
+
+    val interaction = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
+    val output = interaction.transform(input)
+
+    val expectedRows = Seq(
+      (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
+      (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)))
+    val expected = sqlContext.createDataFrame(expectedRows).toDF("a", "b", "features")
+    assert(output.collect() === expected.collect())
+
+    val actualAttrs = AttributeGroup.fromStructField(output.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a:b_foo"), Some(1)),
+        new NumericAttribute(Some("a:b_bar"), Some(2))))
+    assert(actualAttrs === expectedAttrs)
+  }
+
+  test("nominal interaction") {
+    val inputRows = Seq(
+      (2, Vectors.dense(3.0, 4.0)),
+      (1, Vectors.dense(1.0, 5.0)))
+    val raw = sqlContext.createDataFrame(inputRows).toDF("a", "b")
+    val bGroup = new AttributeGroup(
+      "b",
+      Array[Attribute](
+        NumericAttribute.defaultAttr.withName("foo"),
+        NumericAttribute.defaultAttr.withName("bar")))
+    val aMeta =
+      NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()
+    val input = raw.select(
+      col("a").as("a", aMeta),
+      col("b").as("b", bGroup.toMetadata()))
+
+    val interaction = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
+    val output = interaction.transform(input)
+
+    val expectedRows = Seq(
+      (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
+      (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)))
+    val expected = sqlContext.createDataFrame(expectedRows).toDF("a", "b", "features")
+    assert(output.collect() === expected.collect())
+
+    val actualAttrs = AttributeGroup.fromStructField(output.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a_up:b_foo"), Some(1)),
+        new NumericAttribute(Some("a_up:b_bar"), Some(2)),
+        new NumericAttribute(Some("a_down:b_foo"), Some(3)),
+        new NumericAttribute(Some("a_down:b_bar"), Some(4)),
+        new NumericAttribute(Some("a_left:b_foo"), Some(5)),
+        new NumericAttribute(Some("a_left:b_bar"), Some(6))))
+    assert(actualAttrs === expectedAttrs)
+  }
+
+  test("default attr names") {
+    val inputRows = Seq(
+      (2, Vectors.dense(0.0, 4.0), 1.0),
+      (1, Vectors.dense(1.0, 5.0), 10.0))
+    val raw = sqlContext.createDataFrame(inputRows).toDF("a", "b", "c")
+    val bGroup = new AttributeGroup(
+      "b",
+      Array[Attribute](
+        NominalAttribute.defaultAttr.withNumValues(2),
+        NumericAttribute.defaultAttr))
+    val aMeta = NominalAttribute.defaultAttr.withNumValues(3).toMetadata()
+    val cMeta = NumericAttribute.defaultAttr.toMetadata()
+    val input = raw.select(
+      col("a").as("a", aMeta),
+      col("b").as("b", bGroup.toMetadata()),
+      col("c").as("c", cMeta))
+
+    val interaction =
+      new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features")
+    val output = interaction.transform(input)
+
+    val expectedRows = Seq(
+      (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)),
+      (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)))
+    val expected = sqlContext.createDataFrame(expectedRows).toDF("a", "b", "c", "features")
+    assert(output.collect() === expected.collect())
+
+    val actualAttrs = AttributeGroup.fromStructField(output.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)),
+        new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)),
+        new NumericAttribute(Some("a_0:b_1:c"), Some(3)),
+        new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)),
+        new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)),
+        new NumericAttribute(Some("a_1:b_1:c"), Some(6)),
+        new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)),
+        new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)),
+        new NumericAttribute(Some("a_2:b_1:c"), Some(9))))
+    assert(actualAttrs === expectedAttrs)
+  }
+}