aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-03-12 16:34:56 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-12 16:34:56 -0700
commita4b27162f2d7cb501f71d818581c8a2471bb7cf6 (patch)
tree84821bd07967770a50e921dd3fef73c596b0b407
parentfb4787c9531be5dd9e512e79ff4ff45d24eb370d (diff)
downloadspark-a4b27162f2d7cb501f71d818581c8a2471bb7cf6.tar.gz
spark-a4b27162f2d7cb501f71d818581c8a2471bb7cf6.tar.bz2
spark-a4b27162f2d7cb501f71d818581c8a2471bb7cf6.zip
[SPARK-4588] ML Attributes
This continues the work in #4460 from srowen . The design doc is published on the JIRA page with some minor changes. Short description of ML attributes: https://github.com/apache/spark/pull/4925/files?diff=unified#diff-95e7f5060429f189460b44a3f8731a35R24 More details can be found in the design doc. srowen Could you help review this PR? There are many lines but most of them are boilerplate code. Author: Xiangrui Meng <meng@databricks.com> Author: Sean Owen <sowen@cloudera.com> Closes #4925 from mengxr/SPARK-4588-new and squashes the following commits: 71d1bd0 [Xiangrui Meng] add JavaDoc for package ml.attribute 617be40 [Xiangrui Meng] remove final; rename cardinality to numValues 393ffdc [Xiangrui Meng] forgot to include Java attribute group tests b1aceef [Xiangrui Meng] more tests e7ab467 [Xiangrui Meng] update ML attribute impl 7c944da [Sean Owen] Add FeatureType hierarchy and categorical cardinality 2a21d6d [Sean Owen] Initial draft of FeatureAttributes class
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala234
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala37
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala61
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala512
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java41
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala44
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java45
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java55
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala65
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala212
-rw-r--r--project/SparkBuild.scala3
11 files changed, 1308 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
new file mode 100644
index 0000000000..970e6ad551
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
+
+/**
+ * Attributes that describe a vector ML column.
+ *
+ * A group is in one of three states: nothing is known about its attributes, only the number of
+ * attributes is known, or the full array of attributes is known. The primary constructor is
+ * private; instances are created through the public auxiliary constructors of this class.
+ *
+ * @param name name of the attribute group (the ML column name). Must not be empty.
+ * @param numAttributes optional number of attributes. At most one of `numAttributes` and `attrs`
+ * can be defined.
+ * @param attrs optional array of attributes. Attributes will be copied with their corresponding
+ * indices in the array.
+ */
+class AttributeGroup private (
+ val name: String,
+ val numAttributes: Option[Int],
+ attrs: Option[Array[Attribute]]) extends Serializable {
+
+ require(name.nonEmpty, "Cannot have an empty string for name.")
+ require(!(numAttributes.isDefined && attrs.isDefined),
+ "Cannot have both numAttributes and attrs defined.")
+
+ /**
+ * Creates an attribute group without attribute info. Both `numAttributes` and `attributes`
+ * are left undefined.
+ * @param name name of the attribute group
+ */
+ def this(name: String) = this(name, None, None)
+
+ /**
+ * Creates an attribute group knowing only the number of attributes. `attributes` is left
+ * undefined.
+ * @param name name of the attribute group
+ * @param numAttributes number of attributes
+ */
+ def this(name: String, numAttributes: Int) = this(name, Some(numAttributes), None)
+
+ /**
+ * Creates an attribute group with attributes. `numAttributes` is left undefined; use `size`
+ * to get the number of attributes.
+ * @param name name of the attribute group
+ * @param attrs array of attributes. Attributes will be copied with their corresponding indices in
+ * the array.
+ */
+ def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs))
+
+ /**
+ * Optional array of attributes. At most one of `numAttributes` and `attributes` can be defined.
+ * Each element is a copy of the corresponding input attribute whose index is set to its
+ * position in the array.
+ */
+ val attributes: Option[Array[Attribute]] = attrs.map(_.view.zipWithIndex.map { case (attr, i) =>
+ attr.withIndex(i)
+ }.toArray)
+
+ private lazy val nameToIndex: Map[String, Int] = {
+ attributes.map(_.view.flatMap { attr =>
+ attr.name.map(_ -> attr.index.get)
+ }.toMap).getOrElse(Map.empty)
+ }
+
+ /**
+ * Size of the attribute group. Returns `numAttributes` if it is defined, otherwise the length of
+ * `attributes` if it is defined, and -1 if the size is unknown.
+ */
+ def size: Int = {
+ if (numAttributes.isDefined) {
+ numAttributes.get
+ } else if (attributes.isDefined) {
+ attributes.get.length
+ } else {
+ -1
+ }
+ }
+
+ /**
+ * Tests whether this attribute group contains an attribute with the given name. Only named
+ * attributes can be found; the result is always false if `attributes` is not defined.
+ */
+ def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName)
+
+ /**
+ * Index of an attribute specified by name. Throws a `NoSuchElementException` if no attribute
+ * in this group has the given name; check with `hasAttr` first.
+ */
+ def indexOf(attrName: String): Int = nameToIndex(attrName)
+
+ /**
+ * Gets an attribute by its name. Throws a `NoSuchElementException` if `attributes` is not
+ * defined or no attribute in this group has the given name.
+ */
+ def apply(attrName: String): Attribute = {
+ attributes.get(indexOf(attrName))
+ }
+
+ /** Gets an attribute by its name. Same as `apply(attrName)`; convenient for Java callers. */
+ def getAttr(attrName: String): Attribute = this(attrName)
+
+ /**
+ * Gets an attribute by its index. Throws a `NoSuchElementException` if `attributes` is not
+ * defined; the index must be a valid position in the attribute array.
+ */
+ def apply(attrIndex: Int): Attribute = attributes.get(attrIndex)
+
+ /** Gets an attribute by its index. Same as `apply(attrIndex)`; convenient for Java callers. */
+ def getAttr(attrIndex: Int): Attribute = this(attrIndex)
+
+ /**
+ * Converts to metadata without name.
+ *
+ * If `attributes` is defined, the per-attribute metadata (without type info) is grouped into
+ * one metadata array per attribute type, keyed by the type name, and stored under the
+ * attributes key together with the number of attributes. Empty groups are omitted. If only
+ * `numAttributes` is defined, only the number of attributes is stored. Otherwise the result
+ * is empty.
+ */
+ private[attribute] def toMetadata: Metadata = {
+ import AttributeKeys._
+ val bldr = new MetadataBuilder()
+ if (attributes.isDefined) {
+ val numericMetadata = ArrayBuffer.empty[Metadata]
+ val nominalMetadata = ArrayBuffer.empty[Metadata]
+ val binaryMetadata = ArrayBuffer.empty[Metadata]
+ attributes.get.foreach {
+ case numeric: NumericAttribute =>
+ // Skip numeric attributes that equal the default once the index is removed.
+ if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
+ numericMetadata += numeric.toMetadata(withType = false)
+ }
+ case nominal: NominalAttribute =>
+ nominalMetadata += nominal.toMetadata(withType = false)
+ case binary: BinaryAttribute =>
+ binaryMetadata += binary.toMetadata(withType = false)
+ }
+ val attrBldr = new MetadataBuilder
+ if (numericMetadata.nonEmpty) {
+ attrBldr.putMetadataArray(AttributeType.Numeric.name, numericMetadata.toArray)
+ }
+ if (nominalMetadata.nonEmpty) {
+ attrBldr.putMetadataArray(AttributeType.Nominal.name, nominalMetadata.toArray)
+ }
+ if (binaryMetadata.nonEmpty) {
+ attrBldr.putMetadataArray(AttributeType.Binary.name, binaryMetadata.toArray)
+ }
+ bldr.putMetadata(ATTRIBUTES, attrBldr.build())
+ bldr.putLong(NUM_ATTRIBUTES, attributes.get.length)
+ } else if (numAttributes.isDefined) {
+ bldr.putLong(NUM_ATTRIBUTES, numAttributes.get)
+ }
+ bldr.build()
+ }
+
+ /**
+ * Converts to a StructField with some existing metadata. The group metadata is added under the
+ * ML attribute key on top of `existingMetadata`. The resulting field is named after this group,
+ * has the vector user-defined type, and is not nullable.
+ */
+ def toStructField(existingMetadata: Metadata): StructField = {
+ val newMetadata = new MetadataBuilder()
+ .withMetadata(existingMetadata)
+ .putMetadata(AttributeKeys.ML_ATTR, toMetadata)
+ .build()
+ StructField(name, new VectorUDT, nullable = false, newMetadata)
+ }
+
+ /** Converts to a StructField that carries only the ML attribute metadata of this group. */
+ def toStructField(): StructField = toStructField(Metadata.empty)
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case o: AttributeGroup =>
+ (name == o.name) &&
+ (numAttributes == o.numAttributes) &&
+ (attributes.map(_.toSeq) == o.attributes.map(_.toSeq))
+ case _ =>
+ false
+ }
+ }
+
+ override def hashCode: Int = {
+ var sum = 17
+ sum = 37 * sum + name.hashCode
+ sum = 37 * sum + numAttributes.hashCode
+ sum = 37 * sum + attributes.map(_.toSeq).hashCode
+ sum
+ }
+}
+
+/** Factory methods to create attribute groups. */
+object AttributeGroup {
+
+  import AttributeKeys._
+
+  /**
+   * Creates an attribute group from a [[Metadata]] instance with name.
+   *
+   * Three layouts are recognized:
+   *  - per-attribute metadata is present: every stored attribute is placed at its own index and
+   *    all remaining positions are filled with the default numeric attribute;
+   *  - only the number of attributes is present: a group of that size without attribute info;
+   *  - neither is present: a group without any attribute info.
+   *
+   * @param metadata metadata of the group, without name
+   * @param name name of the attribute group
+   * @throws IllegalArgumentException if a stored attribute has no index, or an index that does
+   *                                  not fit into the stored number of attributes
+   */
+  private[attribute] def fromMetadata(metadata: Metadata, name: String): AttributeGroup = {
+    import org.apache.spark.ml.attribute.AttributeType._
+    if (metadata.contains(ATTRIBUTES)) {
+      val numAttrs = metadata.getLong(NUM_ATTRIBUTES).toInt
+      require(numAttrs >= 0,
+        s"Number of attributes cannot be negative but got $numAttrs for group $name.")
+      // Positions that are not overwritten below stay as the default numeric attribute.
+      val attributes = Array.fill[Attribute](numAttrs)(NumericAttribute.defaultAttr)
+      val attrMetadata = metadata.getMetadata(ATTRIBUTES)
+
+      // Parses all attributes stored under one type name and puts each at its own index.
+      def load(attrType: AttributeType, parse: Metadata => Attribute): Unit = {
+        if (attrMetadata.contains(attrType.name)) {
+          attrMetadata.getMetadataArray(attrType.name).foreach { m =>
+            val attr = parse(m)
+            val idx = attr.index.getOrElse(
+              throw new IllegalArgumentException(
+                s"Attribute $attr in group $name does not have an index."))
+            require(idx < numAttrs,
+              s"Attribute index $idx is out of range for group $name with $numAttrs attributes.")
+            attributes(idx) = attr
+          }
+        }
+      }
+
+      load(Numeric, NumericAttribute.fromMetadata)
+      load(Nominal, NominalAttribute.fromMetadata)
+      load(Binary, BinaryAttribute.fromMetadata)
+      new AttributeGroup(name, attributes)
+    } else if (metadata.contains(NUM_ATTRIBUTES)) {
+      new AttributeGroup(name, metadata.getLong(NUM_ATTRIBUTES).toInt)
+    } else {
+      new AttributeGroup(name)
+    }
+  }
+
+  /**
+   * Creates an attribute group from a [[StructField]] instance.
+   *
+   * A field without ML attribute metadata yields a group without attribute info.
+   *
+   * @param field a field whose data type is the vector user-defined type
+   * @throws IllegalArgumentException if the field is not a vector column
+   */
+  def fromStructField(field: StructField): AttributeGroup = {
+    require(field.dataType == new VectorUDT,
+      s"Expected a vector column but got ${field.dataType} for field ${field.name}.")
+    if (field.metadata.contains(ML_ATTR)) {
+      fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name)
+    } else {
+      new AttributeGroup(field.name)
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala
new file mode 100644
index 0000000000..f714f7becc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+/**
+ * Keys used to store attributes in column metadata.
+ *
+ * `ML_ATTR` is the key under which [[Attribute]] and [[AttributeGroup]] nest their own metadata
+ * when they are converted to a StructField. The remaining keys name the individual entries
+ * inside that nested metadata.
+ */
+private[attribute] object AttributeKeys {
+ val ML_ATTR: String = "ml_attr"
+ val TYPE: String = "type"
+ val NAME: String = "name"
+ val INDEX: String = "idx"
+ val MIN: String = "min"
+ val MAX: String = "max"
+ val STD: String = "std"
+ val SPARSITY: String = "sparsity"
+ val ORDINAL: String = "ord"
+ val VALUES: String = "vals"
+ val NUM_VALUES: String = "num_vals"
+ val ATTRIBUTES: String = "attrs"
+ val NUM_ATTRIBUTES: String = "num_attrs"
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
new file mode 100644
index 0000000000..65e7e43d5a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+/**
+ * An enum-like type for attribute types. The only instances are
+ * [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], and [[AttributeType$#Binary]].
+ */
+sealed abstract class AttributeType(val name: String)
+
+object AttributeType {
+
+  /** Numeric type, named "numeric". */
+  val Numeric: AttributeType = {
+    case object Numeric extends AttributeType("numeric")
+    Numeric
+  }
+
+  /** Nominal type, named "nominal". */
+  val Nominal: AttributeType = {
+    case object Nominal extends AttributeType("nominal")
+    Nominal
+  }
+
+  /** Binary type, named "binary". */
+  val Binary: AttributeType = {
+    case object Binary extends AttributeType("binary")
+    Binary
+  }
+
+  /**
+   * Looks up the [[AttributeType]] object that carries the given name.
+   * @param name attribute type name: "numeric", "nominal", or "binary"
+   * @throws IllegalArgumentException if the name matches none of the known types
+   */
+  def fromName(name: String): AttributeType = {
+    val knownTypes = Seq(Numeric, Nominal, Binary)
+    knownTypes.find(_.name == name).getOrElse {
+      throw new IllegalArgumentException(s"Cannot recognize type $name.")
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
new file mode 100644
index 0000000000..00b7566aab
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -0,0 +1,512 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import scala.annotation.varargs
+
+import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
+
+/**
+ * Base class of all ML attributes. An attribute carries an optional name and an optional index,
+ * and can be converted to column [[Metadata]] or to a [[StructField]].
+ */
+sealed abstract class Attribute extends Serializable {
+
+  // Validate the optional name and index as soon as an attribute is constructed.
+  for (n <- name) {
+    require(n.nonEmpty, "Cannot have an empty string for name.")
+  }
+  for (i <- index) {
+    require(i >= 0, s"Index cannot be negative but got $i")
+  }
+
+  /** Type of this attribute. */
+  def attrType: AttributeType
+
+  /** Name of this attribute, or None if no name is set. */
+  def name: Option[String]
+
+  /** Returns a copy that carries the given name. */
+  def withName(name: String): Attribute
+
+  /** Returns a copy that carries no name. */
+  def withoutName: Attribute
+
+  /** Index of this attribute, or None if no index is set. */
+  def index: Option[Int]
+
+  /** Returns a copy that carries the given index. */
+  def withIndex(index: Int): Attribute
+
+  /** Returns a copy that carries no index. */
+  def withoutIndex: Attribute
+
+  /**
+   * Whether this attribute is numeric. This holds for [[NumericAttribute]] and
+   * [[BinaryAttribute]].
+   */
+  def isNumeric: Boolean
+
+  /**
+   * Whether this attribute is nominal. This holds for [[NominalAttribute]] and
+   * [[BinaryAttribute]].
+   */
+  def isNominal: Boolean
+
+  /**
+   * Converts this attribute to [[Metadata]].
+   * @param withType whether the type info is written to the result
+   */
+  private[attribute] def toMetadata(withType: Boolean): Metadata
+
+  /**
+   * Converts this attribute to [[Metadata]]. Numeric is the default attribute type, so the type
+   * info is left out for numeric attributes to save space. It is written for every other
+   * attribute type.
+   */
+  private[attribute] def toMetadata(): Metadata = {
+    toMetadata(withType = attrType != AttributeType.Numeric)
+  }
+
+  /**
+   * Converts to a [[StructField]] of double type, named after this attribute, that carries the
+   * attribute metadata (without name and index) on top of some existing metadata.
+   * The name of this attribute must be set.
+   * @param existingMetadata existing metadata to carry over
+   */
+  def toStructField(existingMetadata: Metadata): StructField = {
+    val builder = new MetadataBuilder().withMetadata(existingMetadata)
+    val anonymous = withoutName.withoutIndex
+    builder.putMetadata(AttributeKeys.ML_ATTR, anonymous.toMetadata())
+    val fieldMetadata = builder.build()
+    StructField(name.get, DoubleType, nullable = false, fieldMetadata)
+  }
+
+  /** Converts to a [[StructField]] that carries only the attribute metadata. */
+  def toStructField(): StructField = toStructField(Metadata.empty)
+
+  /** String form of the full metadata of this attribute, including its type. */
+  override def toString: String = toMetadata(withType = true).toString
+}
+
+/** Trait for ML attribute factories. */
+private[attribute] trait AttributeFactory {
+
+  /**
+   * Creates an [[Attribute]] from a [[Metadata]] instance.
+   */
+  private[attribute] def fromMetadata(metadata: Metadata): Attribute
+
+  /**
+   * Creates an [[Attribute]] from a [[StructField]] instance. The result is named after the field.
+   *
+   * A field without ML attribute metadata is treated as if that metadata were empty, in the same
+   * way the attribute group factory accepts vector columns without ML attribute metadata.
+   *
+   * @param field a field of double type
+   * @throws IllegalArgumentException if the field is not of double type
+   */
+  def fromStructField(field: StructField): Attribute = {
+    require(field.dataType == DoubleType,
+      s"Expected a column of DoubleType but got ${field.dataType} for field ${field.name}.")
+    val mlAttr = AttributeKeys.ML_ATTR
+    // Plain double columns carry no ML attribute entry; do not fail on the missing key.
+    val attrMetadata =
+      if (field.metadata.contains(mlAttr)) field.metadata.getMetadata(mlAttr) else Metadata.empty
+    fromMetadata(attrMetadata).withName(field.name)
+  }
+}
+
+/** Factory that creates an attribute of whichever type is recorded in the metadata. */
+object Attribute extends AttributeFactory {
+
+  private[attribute] override def fromMetadata(metadata: Metadata): Attribute = {
+    import org.apache.spark.ml.attribute.AttributeKeys._
+    // Metadata without an explicit type entry is read as a numeric attribute.
+    val typeName =
+      if (metadata.contains(TYPE)) metadata.getString(TYPE) else AttributeType.Numeric.name
+    getFactory(typeName).fromMetadata(metadata)
+  }
+
+  /**
+   * Selects the attribute factory that belongs to the given attribute type name.
+   * Throws an `IllegalArgumentException` for an unknown type name.
+   */
+  private def getFactory(attrType: String): AttributeFactory = attrType match {
+    case t if t == AttributeType.Numeric.name => NumericAttribute
+    case t if t == AttributeType.Nominal.name => NominalAttribute
+    case t if t == AttributeType.Binary.name => BinaryAttribute
+    case _ => throw new IllegalArgumentException(s"Cannot recognize type $attrType.")
+  }
+}
+
+
+/**
+ * A numeric attribute with optional summary statistics.
+ * @param name optional name
+ * @param index optional index
+ * @param min optional min value
+ * @param max optional max value
+ * @param std optional standard deviation, which cannot be negative
+ * @param sparsity optional sparsity (ratio of zeros), which must be in [0, 1]
+ */
+class NumericAttribute private[ml] (
+    override val name: Option[String] = None,
+    override val index: Option[Int] = None,
+    val min: Option[Double] = None,
+    val max: Option[Double] = None,
+    val std: Option[Double] = None,
+    val sparsity: Option[Double] = None) extends Attribute {
+
+  for (s <- std) {
+    require(s >= 0.0, s"Standard deviation cannot be negative but got $s.")
+  }
+  for (s <- sparsity) {
+    require(s >= 0.0 && s <= 1.0, s"Sparsity must be in [0, 1] but got $s.")
+  }
+
+  override def attrType: AttributeType = AttributeType.Numeric
+
+  override def isNumeric: Boolean = true
+
+  override def isNominal: Boolean = false
+
+  override def withName(name: String): NumericAttribute = copy(name = Some(name))
+  override def withoutName: NumericAttribute = copy(name = None)
+
+  override def withIndex(index: Int): NumericAttribute = copy(index = Some(index))
+  override def withoutIndex: NumericAttribute = copy(index = None)
+
+  /** Returns a copy with the given min value. */
+  def withMin(min: Double): NumericAttribute = copy(min = Some(min))
+
+  /** Returns a copy with no min value. */
+  def withoutMin: NumericAttribute = copy(min = None)
+
+  /** Returns a copy with the given max value. */
+  def withMax(max: Double): NumericAttribute = copy(max = Some(max))
+
+  /** Returns a copy with no max value. */
+  def withoutMax: NumericAttribute = copy(max = None)
+
+  /** Returns a copy with the given standard deviation. */
+  def withStd(std: Double): NumericAttribute = copy(std = Some(std))
+
+  /** Returns a copy with no standard deviation. */
+  def withoutStd: NumericAttribute = copy(std = None)
+
+  /** Returns a copy with the given sparsity. */
+  def withSparsity(sparsity: Double): NumericAttribute = copy(sparsity = Some(sparsity))
+
+  /** Returns a copy with no sparsity. */
+  def withoutSparsity: NumericAttribute = copy(sparsity = None)
+
+  /** Returns a copy with min, max, standard deviation, and sparsity all removed. */
+  def withoutSummary: NumericAttribute = copy(min = None, max = None, std = None, sparsity = None)
+
+  /** Writes every field that is set into a new metadata instance. */
+  private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+    import org.apache.spark.ml.attribute.AttributeKeys._
+    val builder = new MetadataBuilder()
+    if (withType) {
+      builder.putString(TYPE, attrType.name)
+    }
+    for (n <- name) builder.putString(NAME, n)
+    for (i <- index) builder.putLong(INDEX, i)
+    for (v <- min) builder.putDouble(MIN, v)
+    for (v <- max) builder.putDouble(MAX, v)
+    for (v <- std) builder.putDouble(STD, v)
+    for (v <- sparsity) builder.putDouble(SPARSITY, v)
+    builder.build()
+  }
+
+  /** Builds a new attribute from the fields of this one, replacing those that are passed in. */
+  private def copy(
+      name: Option[String] = name,
+      index: Option[Int] = index,
+      min: Option[Double] = min,
+      max: Option[Double] = max,
+      std: Option[Double] = std,
+      sparsity: Option[Double] = sparsity): NumericAttribute = {
+    new NumericAttribute(name, index, min, max, std, sparsity)
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: NumericAttribute =>
+      name == that.name &&
+        index == that.index &&
+        min == that.min &&
+        max == that.max &&
+        std == that.std &&
+        sparsity == that.sparsity
+    case _ =>
+      false
+  }
+
+  override def hashCode: Int = {
+    // Same 17/37 scheme over the same fields, in the same order, as used by `equals`.
+    Seq(name, index, min, max, std, sparsity).foldLeft(17) { (acc, field) =>
+      37 * acc + field.hashCode
+    }
+  }
+}
+
+/**
+ * Factory methods for numeric attributes.
+ */
+object NumericAttribute extends AttributeFactory {
+
+  /** The default numeric attribute, with no field set. */
+  val defaultAttr: NumericAttribute = new NumericAttribute
+
+  private[attribute] override def fromMetadata(metadata: Metadata): NumericAttribute = {
+    import org.apache.spark.ml.attribute.AttributeKeys._
+    // Reads an entry only if its key is present in the metadata.
+    def optional[T](key: String)(read: => T): Option[T] = {
+      if (metadata.contains(key)) Some(read) else None
+    }
+    def optionalDouble(key: String): Option[Double] = optional(key)(metadata.getDouble(key))
+
+    new NumericAttribute(
+      name = optional(NAME)(metadata.getString(NAME)),
+      index = optional(INDEX)(metadata.getLong(INDEX).toInt),
+      min = optionalDouble(MIN),
+      max = optionalDouble(MAX),
+      std = optionalDouble(STD),
+      sparsity = optionalDouble(SPARSITY))
+  }
+}
+
+/**
+ * A nominal attribute, i.e., an attribute that takes values from a set of categories.
+ * The categories are described either by their number (`numValues`) or by their string
+ * representations (`values`), but not both.
+ * @param name optional name
+ * @param index optional index
+ * @param isOrdinal whether this attribute is ordinal (optional)
+ * @param numValues optional number of values, which cannot be negative. At most one of
+ * `numValues` and `values` can be defined.
+ * @param values optional values. At most one of `numValues` and `values` can be defined.
+ */
+class NominalAttribute private[ml] (
+ override val name: Option[String] = None,
+ override val index: Option[Int] = None,
+ val isOrdinal: Option[Boolean] = None,
+ val numValues: Option[Int] = None,
+ val values: Option[Array[String]] = None) extends Attribute {
+
+ numValues.foreach { n =>
+ require(n >= 0, s"numValues cannot be negative but got $n.")
+ }
+ require(!(numValues.isDefined && values.isDefined),
+ "Cannot have both numValues and values defined.")
+
+ override def attrType: AttributeType = AttributeType.Nominal
+
+ override def isNumeric: Boolean = false
+
+ override def isNominal: Boolean = true
+
+ private lazy val valueToIndex: Map[String, Int] = {
+ values.map(_.zipWithIndex.toMap).getOrElse(Map.empty)
+ }
+
+ /**
+ * Index of a specific value, i.e., its position in `values`. Throws a `NoSuchElementException`
+ * if `values` is not defined or does not contain the given value; check with `hasValue` first.
+ */
+ def indexOf(value: String): Int = {
+ valueToIndex(value)
+ }
+
+ /**
+ * Tests whether this attribute contains a specific value. Always false if `values` is not
+ * defined.
+ */
+ def hasValue(value: String): Boolean = valueToIndex.contains(value)
+
+ /**
+ * Gets a value given its index. Throws a `NoSuchElementException` if `values` is not defined;
+ * the index must be a valid position in `values`.
+ */
+ def getValue(index: Int): String = values.get(index)
+
+ override def withName(name: String): NominalAttribute = copy(name = Some(name))
+ override def withoutName: NominalAttribute = copy(name = None)
+
+ override def withIndex(index: Int): NominalAttribute = copy(index = Some(index))
+ override def withoutIndex: NominalAttribute = copy(index = None)
+
+ /**
+ * Copy with new values and empty `numValues`. The given array is stored as is, without
+ * copying.
+ */
+ def withValues(values: Array[String]): NominalAttribute = {
+ copy(numValues = None, values = Some(values))
+ }
+
+ /**
+ * Copy with new values and empty `numValues`. At least one value must be given; this variant
+ * is also callable from Java with a variable number of arguments.
+ */
+ @varargs
+ def withValues(first: String, others: String*): NominalAttribute = {
+ copy(numValues = None, values = Some((first +: others).toArray))
+ }
+
+ /** Copy without the values. `numValues` is carried over unchanged. */
+ def withoutValues: NominalAttribute = {
+ copy(values = None)
+ }
+
+ /** Copy with a new `numValues` and empty `values`. */
+ def withNumValues(numValues: Int): NominalAttribute = {
+ copy(numValues = Some(numValues), values = None)
+ }
+
+ /** Copy without the `numValues`. `values` is carried over unchanged. */
+ def withoutNumValues: NominalAttribute = copy(numValues = None)
+
+ /**
+ * Creates a copy of this attribute with optional changes. Every parameter defaults to the
+ * current value of the corresponding field.
+ */
+ private def copy(
+ name: Option[String] = name,
+ index: Option[Int] = index,
+ isOrdinal: Option[Boolean] = isOrdinal,
+ numValues: Option[Int] = numValues,
+ values: Option[Array[String]] = values): NominalAttribute = {
+ new NominalAttribute(name, index, isOrdinal, numValues, values)
+ }
+
+ private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ import org.apache.spark.ml.attribute.AttributeKeys._
+ val bldr = new MetadataBuilder()
+ if (withType) bldr.putString(TYPE, attrType.name)
+ name.foreach(bldr.putString(NAME, _))
+ index.foreach(bldr.putLong(INDEX, _))
+ isOrdinal.foreach(bldr.putBoolean(ORDINAL, _))
+ numValues.foreach(bldr.putLong(NUM_VALUES, _))
+ values.foreach(v => bldr.putStringArray(VALUES, v))
+ bldr.build()
+ }
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case o: NominalAttribute =>
+ (name == o.name) &&
+ (index == o.index) &&
+ (isOrdinal == o.isOrdinal) &&
+ (numValues == o.numValues) &&
+ (values.map(_.toSeq) == o.values.map(_.toSeq))
+ case _ =>
+ false
+ }
+ }
+
+ override def hashCode: Int = {
+ var sum = 17
+ sum = 37 * sum + name.hashCode
+ sum = 37 * sum + index.hashCode
+ sum = 37 * sum + isOrdinal.hashCode
+ sum = 37 * sum + numValues.hashCode
+ sum = 37 * sum + values.map(_.toSeq).hashCode
+ sum
+ }
+}
+
+/** Factory methods for nominal attributes. */
+object NominalAttribute extends AttributeFactory {
+
+  /** The default nominal attribute, with no field set. */
+  final val defaultAttr: NominalAttribute = new NominalAttribute
+
+  private[attribute] override def fromMetadata(metadata: Metadata): NominalAttribute = {
+    import org.apache.spark.ml.attribute.AttributeKeys._
+    // Reads an entry only if its key is present in the metadata.
+    def optional[T](key: String)(read: => T): Option[T] = {
+      if (metadata.contains(key)) Some(read) else None
+    }
+
+    new NominalAttribute(
+      name = optional(NAME)(metadata.getString(NAME)),
+      index = optional(INDEX)(metadata.getLong(INDEX).toInt),
+      isOrdinal = optional(ORDINAL)(metadata.getBoolean(ORDINAL)),
+      numValues = optional(NUM_VALUES)(metadata.getLong(NUM_VALUES).toInt),
+      values = optional(VALUES)(metadata.getStringArray(VALUES)))
+  }
+}
+
+/**
+ * A binary attribute.
+ * @param name optional name
+ * @param index optional index
+ * @param values optional values. If set, its size must be 2.
+ */
+class BinaryAttribute private[ml] (
+    override val name: Option[String] = None,
+    override val index: Option[Int] = None,
+    val values: Option[Array[String]] = None)
+  extends Attribute {
+
+  for (v <- values) {
+    require(v.length == 2, s"Number of values must be 2 for a binary attribute but got ${v.toSeq}.")
+  }
+
+  override def attrType: AttributeType = AttributeType.Binary
+
+  // A binary attribute counts as both numeric and nominal.
+  override def isNumeric: Boolean = true
+
+  override def isNominal: Boolean = true
+
+  override def withName(name: String): BinaryAttribute = copy(name = Some(name))
+  override def withoutName: BinaryAttribute = copy(name = None)
+
+  override def withIndex(index: Int): BinaryAttribute = copy(index = Some(index))
+  override def withoutIndex: BinaryAttribute = copy(index = None)
+
+  /**
+   * Returns a copy with the given pair of values, stored as (negative, positive).
+   * @param negative name for negative
+   * @param positive name for positive
+   */
+  def withValues(negative: String, positive: String): BinaryAttribute = {
+    val pair = Array(negative, positive)
+    copy(values = Some(pair))
+  }
+
+  /** Returns a copy with no values. */
+  def withoutValues: BinaryAttribute = copy(values = None)
+
+  /** Builds a new attribute from the fields of this one, replacing those that are passed in. */
+  private def copy(
+      name: Option[String] = name,
+      index: Option[Int] = index,
+      values: Option[Array[String]] = values): BinaryAttribute = {
+    new BinaryAttribute(name, index, values)
+  }
+
+  /** Writes every field that is set into a new metadata instance. */
+  private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+    import org.apache.spark.ml.attribute.AttributeKeys._
+    val builder = new MetadataBuilder
+    if (withType) {
+      builder.putString(TYPE, attrType.name)
+    }
+    for (n <- name) builder.putString(NAME, n)
+    for (i <- index) builder.putLong(INDEX, i)
+    for (v <- values) builder.putStringArray(VALUES, v)
+    builder.build()
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: BinaryAttribute =>
+      // Arrays are compared element-wise through their sequence view.
+      name == that.name &&
+        index == that.index &&
+        values.map(_.toSeq) == that.values.map(_.toSeq)
+    case _ =>
+      false
+  }
+
+  override def hashCode: Int = {
+    // Same 17/37 scheme over the same fields, in the same order, as used by `equals`.
+    Seq(name, index, values.map(_.toSeq)).foldLeft(17) { (acc, field) =>
+      37 * acc + field.hashCode
+    }
+  }
+}
+
+/** Factory methods for binary attributes. */
+object BinaryAttribute extends AttributeFactory {
+
+  /** The default binary attribute, with no field set. */
+  final val defaultAttr: BinaryAttribute = new BinaryAttribute
+
+  private[attribute] override def fromMetadata(metadata: Metadata): BinaryAttribute = {
+    import org.apache.spark.ml.attribute.AttributeKeys._
+    // Reads an entry only if its key is present in the metadata.
+    def optional[T](key: String)(read: => T): Option[T] = {
+      if (metadata.contains(key)) Some(read) else None
+    }
+
+    new BinaryAttribute(
+      name = optional(NAME)(metadata.getString(NAME)),
+      index = optional(INDEX)(metadata.getLong(INDEX).toInt),
+      values = optional(VALUES)(metadata.getStringArray(VALUES)))
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java
new file mode 100644
index 0000000000..e3474f3c1d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// The content here should be in sync with `package.scala`.
+
+/**
+ * <h2>ML attributes</h2>
+ *
+ * The ML pipeline API uses {@link org.apache.spark.sql.DataFrame}s as ML datasets.
+ * Each dataset consists of typed columns, e.g., string, double, vector, etc.
+ * However, knowing only the column type may not be sufficient to handle the data properly.
+ * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices,
+ * which cannot be treated as numeric values in ML algorithms. As another example, we may
+ * want to know the names and types of features stored in a vector column.
+ * ML attributes are used to provide additional information to describe columns in a dataset.
+ *
+ * <h3>ML columns</h3>
+ *
+ * A column with ML attributes attached is called an ML column.
+ * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column
+ * of double values or a vector column.
+ * Columns of other types must be encoded into ML columns using transformers.
+ * We use {@link org.apache.spark.ml.attribute.Attribute} to describe a scalar ML column, and
+ * {@link org.apache.spark.ml.attribute.AttributeGroup} to describe a vector ML column.
+ * ML attributes are stored in the metadata field of the column schema.
+ */
+package org.apache.spark.ml.attribute;
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala
new file mode 100644
index 0000000000..7ac21d7d56
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
+
+/**
+ * ==ML attributes==
+ *
+ * The ML pipeline API uses [[DataFrame]]s as ML datasets.
+ * Each dataset consists of typed columns, e.g., string, double, vector, etc.
+ * However, knowing only the column type may not be sufficient to handle the data properly.
+ * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices,
+ * which cannot be treated as numeric values in ML algorithms. As another example, we may
+ * want to know the names and types of features stored in a vector column.
+ * ML attributes are used to provide additional information to describe columns in a dataset.
+ *
+ * ===ML columns===
+ *
+ * A column with ML attributes attached is called an ML column.
+ * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column
+ * of double values or a vector column.
+ * Columns of other types must be encoded into ML columns using transformers.
+ * We use [[Attribute]] to describe a scalar ML column, and [[AttributeGroup]] to describe a vector
+ * ML column.
+ * ML attributes are stored in the metadata field of the column schema.
+ */
+package object attribute
diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java
new file mode 100644
index 0000000000..38eb58673a
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Exercises {@link AttributeGroup} through its Java-facing API. */
+public class JavaAttributeGroupSuite {
+
+  /**
+   * Builds a group of eight attributes and checks its size, name, per-slot lookup,
+   * name-to-index lookup, and that it survives a round trip through a struct field.
+   */
+  @Test
+  public void testAttributeGroup() {
+    final Attribute[] members = {
+      NumericAttribute.defaultAttr(),
+      NominalAttribute.defaultAttr(),
+      BinaryAttribute.defaultAttr().withIndex(0),
+      NumericAttribute.defaultAttr().withName("age").withSparsity(0.8),
+      NominalAttribute.defaultAttr().withName("size").withValues("small", "medium", "large"),
+      BinaryAttribute.defaultAttr().withName("clicked").withValues("no", "yes"),
+      NumericAttribute.defaultAttr(),
+      NumericAttribute.defaultAttr()
+    };
+    final AttributeGroup users = new AttributeGroup("user", members);
+
+    Assert.assertEquals(8, users.size());
+    Assert.assertEquals("user", users.name());
+
+    // The attribute in slot 0 was added without an index; it is expected back with index 0.
+    Assert.assertEquals(NumericAttribute.defaultAttr().withIndex(0), users.getAttr(0));
+
+    Assert.assertEquals(3, users.indexOf("age"));
+    Assert.assertFalse(users.hasAttr("abc"));
+
+    // Round trip through a struct field must reproduce an equal group.
+    Assert.assertEquals(users, AttributeGroup.fromStructField(users.toStructField()));
+  }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
new file mode 100644
index 0000000000..b74bbed231
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute;
+
+import org.junit.Test;
+import org.junit.Assert;
+
+/** Java-API checks for the numeric, nominal and binary attribute classes. */
+public class JavaAttributeSuite {
+
+  /** Each default attribute reports the matching {@link AttributeType}. */
+  @Test
+  public void testAttributeType() {
+    Assert.assertEquals(AttributeType.Numeric(), NumericAttribute.defaultAttr().attrType());
+    Assert.assertEquals(AttributeType.Nominal(), NominalAttribute.defaultAttr().attrType());
+    Assert.assertEquals(AttributeType.Binary(), BinaryAttribute.defaultAttr().attrType());
+  }
+
+  /** A populated numeric attribute round-trips through a struct field, minus its index. */
+  @Test
+  public void testNumericAttribute() {
+    NumericAttribute age = NumericAttribute.defaultAttr()
+      .withName("age")
+      .withIndex(0)
+      .withMin(0.0)
+      .withMax(1.0)
+      .withStd(0.5)
+      .withSparsity(0.4);
+    Assert.assertEquals(age.withoutIndex(), Attribute.fromStructField(age.toStructField()));
+  }
+
+  /** A populated nominal attribute round-trips through a struct field, minus its index. */
+  @Test
+  public void testNominalAttribute() {
+    NominalAttribute size = NominalAttribute.defaultAttr()
+      .withName("size")
+      .withIndex(1)
+      .withValues("small", "medium", "large");
+    Assert.assertEquals(size.withoutIndex(), Attribute.fromStructField(size.toStructField()));
+  }
+
+  /** A populated binary attribute round-trips through a struct field, minus its index. */
+  @Test
+  public void testBinaryAttribute() {
+    BinaryAttribute clicked = BinaryAttribute.defaultAttr()
+      .withName("clicked")
+      .withIndex(2)
+      .withValues("no", "yes");
+    Assert.assertEquals(
+      clicked.withoutIndex(), Attribute.fromStructField(clicked.toStructField()));
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
new file mode 100644
index 0000000000..3fb6e2ec46
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import org.scalatest.FunSuite
+
+/** Tests for [[AttributeGroup]] construction, lookup, and serialization round trips. */
+class AttributeGroupSuite extends FunSuite {
+
+  test("attribute group") {
+    // A mix of default and customized attributes. Slot 2 is created with index 0, and the
+    // assertions below expect the group to report index 2 for it.
+    val members = Array(
+      NumericAttribute.defaultAttr,
+      NominalAttribute.defaultAttr,
+      BinaryAttribute.defaultAttr.withIndex(0),
+      NumericAttribute.defaultAttr.withName("age").withSparsity(0.8),
+      NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large"),
+      BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"),
+      NumericAttribute.defaultAttr,
+      NumericAttribute.defaultAttr)
+    val users = new AttributeGroup("user", members)
+
+    assert(users.size === 8)
+    assert(users.name === "user")
+    assert(users(0) === NumericAttribute.defaultAttr.withIndex(0))
+    assert(users(2) === BinaryAttribute.defaultAttr.withIndex(2))
+
+    // Named attributes must be found at the position where they were inserted.
+    Seq("age" -> 3, "size" -> 4, "clicked" -> 5).foreach { case (attrName, position) =>
+      assert(users.indexOf(attrName) === position)
+    }
+
+    // Unknown names are reported as absent and fail on direct lookup.
+    assert(!users.hasAttr("abc"))
+    intercept[NoSuchElementException](users("abc"))
+
+    val viaMetadata = AttributeGroup.fromMetadata(users.toMetadata, users.name)
+    assert(users === viaMetadata)
+    val viaStructField = AttributeGroup.fromStructField(users.toStructField())
+    assert(users === viaStructField)
+  }
+
+  test("attribute group without attributes") {
+    // Only the number of attributes is given.
+    val sizedOnly = new AttributeGroup("user", 10)
+    assert(sizedOnly.name === "user")
+    assert(sizedOnly.numAttributes === Some(10))
+    assert(sizedOnly.size === 10)
+    assert(sizedOnly.attributes.isEmpty)
+    val sizedViaMetadata = AttributeGroup.fromMetadata(sizedOnly.toMetadata, sizedOnly.name)
+    assert(sizedOnly === sizedViaMetadata)
+    val sizedViaStructField = AttributeGroup.fromStructField(sizedOnly.toStructField())
+    assert(sizedOnly === sizedViaStructField)
+
+    // Only a name is given; the size is then reported as -1.
+    val nameOnly = new AttributeGroup("item")
+    assert(nameOnly.name === "item")
+    assert(nameOnly.numAttributes.isEmpty)
+    assert(nameOnly.attributes.isEmpty)
+    assert(nameOnly.size === -1)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
new file mode 100644
index 0000000000..6ec35b0365
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata}
+
+/**
+ * Tests for the numeric, nominal and binary attribute classes: default instances,
+ * customized instances, metadata / struct-field round trips, and argument validation.
+ */
+class AttributeSuite extends FunSuite {
+
+  // Numeric is the type assumed when metadata carries no "type" entry, so the default
+  // numeric attribute serializes to empty metadata unless the type is requested.
+  test("default numeric attribute") {
+    val attr: NumericAttribute = NumericAttribute.defaultAttr
+    val metadata = Metadata.fromJson("{}")
+    val metadataWithType = Metadata.fromJson("""{"type":"numeric"}""")
+    assert(attr.attrType === AttributeType.Numeric)
+    assert(attr.isNumeric)
+    assert(!attr.isNominal)
+    assert(attr.name.isEmpty)
+    assert(attr.index.isEmpty)
+    assert(attr.min.isEmpty)
+    assert(attr.max.isEmpty)
+    assert(attr.std.isEmpty)
+    assert(attr.sparsity.isEmpty)
+    assert(attr.toMetadata() === metadata)
+    assert(attr.toMetadata(withType = false) === metadata)
+    assert(attr.toMetadata(withType = true) === metadataWithType)
+    assert(attr === Attribute.fromMetadata(metadata))
+    assert(attr === Attribute.fromMetadata(metadataWithType))
+    // An attribute without a name cannot be turned into a struct field.
+    intercept[NoSuchElementException] {
+      attr.toStructField()
+    }
+  }
+
+  test("customized numeric attribute") {
+    val name = "age"
+    val index = 0
+    val metadata = Metadata.fromJson("""{"name":"age","idx":0}""")
+    val metadataWithType = Metadata.fromJson("""{"type":"numeric","name":"age","idx":0}""")
+    val attr: NumericAttribute = NumericAttribute.defaultAttr
+      .withName(name)
+      .withIndex(index)
+    // Uses === like every other assertion in this suite.
+    assert(attr.attrType === AttributeType.Numeric)
+    assert(attr.isNumeric)
+    assert(!attr.isNominal)
+    assert(attr.name === Some(name))
+    assert(attr.index === Some(index))
+    assert(attr.toMetadata() === metadata)
+    assert(attr.toMetadata(withType = false) === metadata)
+    assert(attr.toMetadata(withType = true) === metadataWithType)
+    assert(attr === Attribute.fromMetadata(metadata))
+    assert(attr === Attribute.fromMetadata(metadataWithType))
+    val field = attr.toStructField()
+    assert(field.dataType === DoubleType)
+    assert(!field.nullable)
+    // The index is not preserved by the struct-field round trip.
+    assert(attr.withoutIndex === Attribute.fromStructField(field))
+    // Pre-existing metadata entries must survive in the produced struct field.
+    val existingMetadata = new MetadataBuilder()
+      .putString("name", "test")
+      .build()
+    assert(attr.toStructField(existingMetadata).metadata.getString("name") === "test")
+
+    val attr2 =
+      attr.withoutName.withoutIndex.withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.3)
+    assert(attr2.name.isEmpty)
+    assert(attr2.index.isEmpty)
+    assert(attr2.min === Some(0.0))
+    assert(attr2.max === Some(1.0))
+    assert(attr2.std === Some(0.5))
+    assert(attr2.sparsity === Some(0.3))
+    assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
+  }
+
+  // Empty names, negative indices, negative std and out-of-[0, 1] sparsity are rejected.
+  test("bad numeric attributes") {
+    val attr = NumericAttribute.defaultAttr
+    intercept[IllegalArgumentException](attr.withName(""))
+    intercept[IllegalArgumentException](attr.withIndex(-1))
+    intercept[IllegalArgumentException](attr.withStd(-0.1))
+    intercept[IllegalArgumentException](attr.withSparsity(-0.5))
+    intercept[IllegalArgumentException](attr.withSparsity(1.5))
+  }
+
+  // Unlike numeric attributes, nominal attributes write their type by default.
+  test("default nominal attribute") {
+    val attr: NominalAttribute = NominalAttribute.defaultAttr
+    val metadata = Metadata.fromJson("""{"type":"nominal"}""")
+    val metadataWithoutType = Metadata.fromJson("{}")
+    assert(attr.attrType === AttributeType.Nominal)
+    assert(!attr.isNumeric)
+    assert(attr.isNominal)
+    assert(attr.name.isEmpty)
+    assert(attr.index.isEmpty)
+    assert(attr.values.isEmpty)
+    assert(attr.numValues.isEmpty)
+    assert(attr.isOrdinal.isEmpty)
+    assert(attr.toMetadata() === metadata)
+    assert(attr.toMetadata(withType = true) === metadata)
+    assert(attr.toMetadata(withType = false) === metadataWithoutType)
+    assert(attr === Attribute.fromMetadata(metadata))
+    // Type-less metadata must be decoded through the nominal factory directly.
+    assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
+    intercept[NoSuchElementException] {
+      attr.toStructField()
+    }
+  }
+
+  test("customized nominal attribute") {
+    val name = "size"
+    val index = 1
+    val values = Array("small", "medium", "large")
+    val metadata = Metadata.fromJson(
+      """{"type":"nominal","name":"size","idx":1,"vals":["small","medium","large"]}""")
+    val metadataWithoutType = Metadata.fromJson(
+      """{"name":"size","idx":1,"vals":["small","medium","large"]}""")
+    val attr: NominalAttribute = NominalAttribute.defaultAttr
+      .withName(name)
+      .withIndex(index)
+      .withValues(values)
+    assert(attr.attrType === AttributeType.Nominal)
+    assert(!attr.isNumeric)
+    assert(attr.isNominal)
+    assert(attr.name === Some(name))
+    assert(attr.index === Some(index))
+    assert(attr.values === Some(values))
+    assert(attr.indexOf("medium") === 1)
+    assert(attr.getValue(1) === "medium")
+    assert(attr.toMetadata() === metadata)
+    assert(attr.toMetadata(withType = true) === metadata)
+    assert(attr.toMetadata(withType = false) === metadataWithoutType)
+    assert(attr === Attribute.fromMetadata(metadata))
+    assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
+    assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
+
+    // Extending the value list yields a new attribute with the extra label at the end.
+    val attr2 = attr.withoutName.withoutIndex.withValues(attr.values.get :+ "x-large")
+    assert(attr2.name.isEmpty)
+    assert(attr2.index.isEmpty)
+    assert(attr2.values.get === Array("small", "medium", "large", "x-large"))
+    assert(attr2.indexOf("x-large") === 3)
+    assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
+    assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false)))
+  }
+
+  // Empty names, negative indices and a negative number of values are rejected.
+  test("bad nominal attributes") {
+    val attr = NominalAttribute.defaultAttr
+    intercept[IllegalArgumentException](attr.withName(""))
+    intercept[IllegalArgumentException](attr.withIndex(-1))
+    intercept[IllegalArgumentException](attr.withNumValues(-1))
+  }
+
+  // A binary attribute reports itself as both numeric and nominal.
+  test("default binary attribute") {
+    val attr = BinaryAttribute.defaultAttr
+    val metadata = Metadata.fromJson("""{"type":"binary"}""")
+    val metadataWithoutType = Metadata.fromJson("{}")
+    assert(attr.attrType === AttributeType.Binary)
+    assert(attr.isNumeric)
+    assert(attr.isNominal)
+    assert(attr.name.isEmpty)
+    assert(attr.index.isEmpty)
+    assert(attr.values.isEmpty)
+    assert(attr.toMetadata() === metadata)
+    assert(attr.toMetadata(withType = true) === metadata)
+    assert(attr.toMetadata(withType = false) === metadataWithoutType)
+    assert(attr === Attribute.fromMetadata(metadata))
+    assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
+    intercept[NoSuchElementException] {
+      attr.toStructField()
+    }
+  }
+
+  test("customized binary attribute") {
+    val name = "clicked"
+    val index = 2
+    val values = Array("no", "yes")
+    val metadata = Metadata.fromJson(
+      """{"type":"binary","name":"clicked","idx":2,"vals":["no","yes"]}""")
+    val metadataWithoutType = Metadata.fromJson(
+      """{"name":"clicked","idx":2,"vals":["no","yes"]}""")
+    val attr = BinaryAttribute.defaultAttr
+      .withName(name)
+      .withIndex(index)
+      .withValues(values(0), values(1))
+    assert(attr.attrType === AttributeType.Binary)
+    assert(attr.isNumeric)
+    assert(attr.isNominal)
+    assert(attr.name === Some(name))
+    assert(attr.index === Some(index))
+    assert(attr.values.get === values)
+    assert(attr.toMetadata() === metadata)
+    assert(attr.toMetadata(withType = true) === metadata)
+    assert(attr.toMetadata(withType = false) === metadataWithoutType)
+    assert(attr === Attribute.fromMetadata(metadata))
+    assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
+    assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
+  }
+
+  // Empty names and negative indices are rejected.
+  test("bad binary attributes") {
+    val attr = BinaryAttribute.defaultAttr
+    intercept[IllegalArgumentException](attr.withName(""))
+    intercept[IllegalArgumentException](attr.withIndex(-1))
+  }
+}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 35e748f26b..4a06b9821b 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -408,7 +408,8 @@ object Unidoc {
"mllib.tree.impurity", "mllib.tree.model", "mllib.util",
"mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation",
"mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss",
- "ml", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", "ml.tuning"
+ "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param",
+ "ml.tuning"
),
"-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"),
"-noqualifier", "java.lang"