From a4b27162f2d7cb501f71d818581c8a2471bb7cf6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 12 Mar 2015 16:34:56 -0700 Subject: [SPARK-4588] ML Attributes This continues the work in #4460 from srowen . The design doc is published on the JIRA page with some minor changes. Short description of ML attributes: https://github.com/apache/spark/pull/4925/files?diff=unified#diff-95e7f5060429f189460b44a3f8731a35R24 More details can be found in the design doc. srowen Could you help review this PR? There are many lines but most of them are boilerplate code. Author: Xiangrui Meng Author: Sean Owen Closes #4925 from mengxr/SPARK-4588-new and squashes the following commits: 71d1bd0 [Xiangrui Meng] add JavaDoc for package ml.attribute 617be40 [Xiangrui Meng] remove final; rename cardinality to numValues 393ffdc [Xiangrui Meng] forgot to include Java attribute group tests b1aceef [Xiangrui Meng] more tests e7ab467 [Xiangrui Meng] update ML attribute impl 7c944da [Sean Owen] Add FeatureType hierarchy and categorical cardinality 2a21d6d [Sean Owen] Initial draft of FeatureAttributes class --- .../apache/spark/ml/attribute/AttributeGroup.scala | 234 ++++++++++ .../apache/spark/ml/attribute/AttributeKeys.scala | 37 ++ .../apache/spark/ml/attribute/AttributeType.scala | 61 +++ .../org/apache/spark/ml/attribute/attributes.scala | 512 +++++++++++++++++++++ .../apache/spark/ml/attribute/package-info.java | 41 ++ .../org/apache/spark/ml/attribute/package.scala | 44 ++ .../ml/attribute/JavaAttributeGroupSuite.java | 45 ++ .../spark/ml/attribute/JavaAttributeSuite.java | 55 +++ .../spark/ml/attribute/AttributeGroupSuite.scala | 65 +++ .../apache/spark/ml/attribute/AttributeSuite.scala | 212 +++++++++ project/SparkBuild.scala | 3 +- 11 files changed, 1308 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java create mode 100644 mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala new file mode 100644 index 0000000000..970e6ad551 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} + +/** + * Attributes that describe a vector ML column. + * + * @param name name of the attribute group (the ML column name) + * @param numAttributes optional number of attributes. At most one of `numAttributes` and `attrs` + * can be defined. + * @param attrs optional array of attributes. Attribute will be copied with their corresponding + * indices in the array. + */ +class AttributeGroup private ( + val name: String, + val numAttributes: Option[Int], + attrs: Option[Array[Attribute]]) extends Serializable { + + require(name.nonEmpty, "Cannot have an empty string for name.") + require(!(numAttributes.isDefined && attrs.isDefined), + "Cannot have both numAttributes and attrs defined.") + + /** + * Creates an attribute group without attribute info. + * @param name name of the attribute group + */ + def this(name: String) = this(name, None, None) + + /** + * Creates an attribute group knowing only the number of attributes. + * @param name name of the attribute group + * @param numAttributes number of attributes + */ + def this(name: String, numAttributes: Int) = this(name, Some(numAttributes), None) + + /** + * Creates an attribute group with attributes. + * @param name name of the attribute group + * @param attrs array of attributes. Attributes will be copied with their corresponding indices in + * the array. + */ + def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs)) + + /** + * Optional array of attributes. At most one of `numAttributes` and `attributes` can be defined. + */ + val attributes: Option[Array[Attribute]] = attrs.map(_.view.zipWithIndex.map { case (attr, i) => + attr.withIndex(i) + }.toArray) + + private lazy val nameToIndex: Map[String, Int] = { + attributes.map(_.view.flatMap { attr => + attr.name.map(_ -> attr.index.get) + }.toMap).getOrElse(Map.empty) + } + + /** Size of the attribute group. Returns -1 if the size is unknown. */ + def size: Int = { + if (numAttributes.isDefined) { + numAttributes.get + } else if (attributes.isDefined) { + attributes.get.length + } else { + -1 + } + } + + /** Test whether this attribute group contains a specific attribute. */ + def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName) + + /** Index of an attribute specified by name. */ + def indexOf(attrName: String): Int = nameToIndex(attrName) + + /** Gets an attribute by its name. */ + def apply(attrName: String): Attribute = { + attributes.get(indexOf(attrName)) + } + + /** Gets an attribute by its name. */ + def getAttr(attrName: String): Attribute = this(attrName) + + /** Gets an attribute by its index. */ + def apply(attrIndex: Int): Attribute = attributes.get(attrIndex) + + /** Gets an attribute by its index. */ + def getAttr(attrIndex: Int): Attribute = this(attrIndex) + + /** Converts to metadata without name. 
*/ + private[attribute] def toMetadata: Metadata = { + import AttributeKeys._ + val bldr = new MetadataBuilder() + if (attributes.isDefined) { + val numericMetadata = ArrayBuffer.empty[Metadata] + val nominalMetadata = ArrayBuffer.empty[Metadata] + val binaryMetadata = ArrayBuffer.empty[Metadata] + attributes.get.foreach { + case numeric: NumericAttribute => + // Skip default numeric attributes. + if (numeric.withoutIndex != NumericAttribute.defaultAttr) { + numericMetadata += numeric.toMetadata(withType = false) + } + case nominal: NominalAttribute => + nominalMetadata += nominal.toMetadata(withType = false) + case binary: BinaryAttribute => + binaryMetadata += binary.toMetadata(withType = false) + } + val attrBldr = new MetadataBuilder + if (numericMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Numeric.name, numericMetadata.toArray) + } + if (nominalMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Nominal.name, nominalMetadata.toArray) + } + if (binaryMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Binary.name, binaryMetadata.toArray) + } + bldr.putMetadata(ATTRIBUTES, attrBldr.build()) + bldr.putLong(NUM_ATTRIBUTES, attributes.get.length) + } else if (numAttributes.isDefined) { + bldr.putLong(NUM_ATTRIBUTES, numAttributes.get) + } + bldr.build() + } + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + val newMetadata = new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadata) + .build() + StructField(name, new VectorUDT, nullable = false, newMetadata) + } + + /** Converts to a StructField. */ + def toStructField(): StructField = toStructField(Metadata.empty) + + override def equals(other: Any): Boolean = { + other match { + case o: AttributeGroup => + (name == o.name) && + (numAttributes == o.numAttributes) && + (attributes.map(_.toSeq) == o.attributes.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + numAttributes.hashCode + sum = 37 * sum + attributes.map(_.toSeq).hashCode + sum + } +} + +/** Factory methods to create attribute groups. */ +object AttributeGroup { + + import AttributeKeys._ + + /** Creates an attribute group from a [[Metadata]] instance with name. 
*/ + private[attribute] def fromMetadata(metadata: Metadata, name: String): AttributeGroup = { + import org.apache.spark.ml.attribute.AttributeType._ + if (metadata.contains(ATTRIBUTES)) { + val numAttrs = metadata.getLong(NUM_ATTRIBUTES).toInt + val attributes = new Array[Attribute](numAttrs) + val attrMetadata = metadata.getMetadata(ATTRIBUTES) + if (attrMetadata.contains(Numeric.name)) { + attrMetadata.getMetadataArray(Numeric.name) + .map(NumericAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + if (attrMetadata.contains(Nominal.name)) { + attrMetadata.getMetadataArray(Nominal.name) + .map(NominalAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + if (attrMetadata.contains(Binary.name)) { + attrMetadata.getMetadataArray(Binary.name) + .map(BinaryAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + var i = 0 + while (i < numAttrs) { + if (attributes(i) == null) { + attributes(i) = NumericAttribute.defaultAttr + } + i += 1 + } + new AttributeGroup(name, attributes) + } else if (metadata.contains(NUM_ATTRIBUTES)) { + new AttributeGroup(name, metadata.getLong(NUM_ATTRIBUTES).toInt) + } else { + new AttributeGroup(name) + } + } + + /** Creates an attribute group from a [[StructField]] instance. */ + def fromStructField(field: StructField): AttributeGroup = { + require(field.dataType == new VectorUDT) + if (field.metadata.contains(ML_ATTR)) { + fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name) + } else { + new AttributeGroup(field.name) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala new file mode 100644 index 0000000000..f714f7becc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +/** + * Keys used to store attributes. 
+ */ +private[attribute] object AttributeKeys { + val ML_ATTR: String = "ml_attr" + val TYPE: String = "type" + val NAME: String = "name" + val INDEX: String = "idx" + val MIN: String = "min" + val MAX: String = "max" + val STD: String = "std" + val SPARSITY: String = "sparsity" + val ORDINAL: String = "ord" + val VALUES: String = "vals" + val NUM_VALUES: String = "num_vals" + val ATTRIBUTES: String = "attrs" + val NUM_ATTRIBUTES: String = "num_attrs" +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala new file mode 100644 index 0000000000..65e7e43d5a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +/** + * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], + * and [[AttributeType$#Binary]]. + */ +sealed abstract class AttributeType(val name: String) + +object AttributeType { + + /** Numeric type. */ + val Numeric: AttributeType = { + case object Numeric extends AttributeType("numeric") + Numeric + } + + /** Nominal type. */ + val Nominal: AttributeType = { + case object Nominal extends AttributeType("nominal") + Nominal + } + + /** Binary type. */ + val Binary: AttributeType = { + case object Binary extends AttributeType("binary") + Binary + } + + /** + * Gets the [[AttributeType]] object from its name. + * @param name attribute type name: "numeric", "nominal", or "binary" + */ + def fromName(name: String): AttributeType = { + if (name == Numeric.name) { + Numeric + } else if (name == Nominal.name) { + Nominal + } else if (name == Binary.name) { + Binary + } else { + throw new IllegalArgumentException(s"Cannot recognize type $name.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala new file mode 100644 index 0000000000..00b7566aab --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -0,0 +1,512 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import scala.annotation.varargs + +import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} + +/** + * Abstract class for ML attributes. + */ +sealed abstract class Attribute extends Serializable { + + name.foreach { n => + require(n.nonEmpty, "Cannot have an empty string for name.") + } + index.foreach { i => + require(i >= 0, s"Index cannot be negative but got $i") + } + + /** Attribute type. */ + def attrType: AttributeType + + /** Name of the attribute. None if it is not set. */ + def name: Option[String] + + /** Copy with a new name. */ + def withName(name: String): Attribute + + /** Copy without the name. */ + def withoutName: Attribute + + /** Index of the attribute. None if it is not set. */ + def index: Option[Int] + + /** Copy with a new index. */ + def withIndex(index: Int): Attribute + + /** Copy without the index. */ + def withoutIndex: Attribute + + /** + * Tests whether this attribute is numeric, true for [[NumericAttribute]] and [[BinaryAttribute]]. + */ + def isNumeric: Boolean + + /** + * Tests whether this attribute is nominal, true for [[NominalAttribute]] and [[BinaryAttribute]]. + */ + def isNominal: Boolean + + /** + * Converts this attribute to [[Metadata]]. + * @param withType whether to include the type info + */ + private[attribute] def toMetadata(withType: Boolean): Metadata + + /** + * Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to + * save space, because numeric type is the default attribute type. For nominal and binary + * attributes, the type info is included. + */ + private[attribute] def toMetadata(): Metadata = { + if (attrType == AttributeType.Numeric) { + toMetadata(withType = false) + } else { + toMetadata(withType = true) + } + } + + /** + * Converts to a [[StructField]] with some existing metadata. + * @param existingMetadata existing metadata to carry over + */ + def toStructField(existingMetadata: Metadata): StructField = { + val newMetadata = new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata()) + .build() + StructField(name.get, DoubleType, nullable = false, newMetadata) + } + + /** Converts to a [[StructField]]. */ + def toStructField(): StructField = toStructField(Metadata.empty) + + override def toString: String = toMetadata(withType = true).toString +} + +/** Trait for ML attribute factories. */ +private[attribute] trait AttributeFactory { + + /** + * Creates an [[Attribute]] from a [[Metadata]] instance. + */ + private[attribute] def fromMetadata(metadata: Metadata): Attribute + + /** + * Creates an [[Attribute]] from a [[StructField]] instance. 
+ */ + def fromStructField(field: StructField): Attribute = { + require(field.dataType == DoubleType) + fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name) + } +} + +object Attribute extends AttributeFactory { + + private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val attrType = if (metadata.contains(TYPE)) { + metadata.getString(TYPE) + } else { + AttributeType.Numeric.name + } + getFactory(attrType).fromMetadata(metadata) + } + + /** Gets the attribute factory given the attribute type name. */ + private def getFactory(attrType: String): AttributeFactory = { + if (attrType == AttributeType.Numeric.name) { + NumericAttribute + } else if (attrType == AttributeType.Nominal.name) { + NominalAttribute + } else if (attrType == AttributeType.Binary.name) { + BinaryAttribute + } else { + throw new IllegalArgumentException(s"Cannot recognize type $attrType.") + } + } +} + + +/** + * A numeric attribute with optional summary statistics. + * @param name optional name + * @param index optional index + * @param min optional min value + * @param max optional max value + * @param std optional standard deviation + * @param sparsity optional sparsity (ratio of zeros) + */ +class NumericAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val min: Option[Double] = None, + val max: Option[Double] = None, + val std: Option[Double] = None, + val sparsity: Option[Double] = None) extends Attribute { + + std.foreach { s => + require(s >= 0.0, s"Standard deviation cannot be negative but got $s.") + } + sparsity.foreach { s => + require(s >= 0.0 && s <= 1.0, s"Sparsity must be in [0, 1] but got $s.") + } + + override def attrType: AttributeType = AttributeType.Numeric + + override def withName(name: String): NumericAttribute = copy(name = Some(name)) + override def withoutName: NumericAttribute = copy(name = None) + + override def withIndex(index: Int): NumericAttribute = copy(index = Some(index)) + override def withoutIndex: NumericAttribute = copy(index = None) + + /** Copy with a new min value. */ + def withMin(min: Double): NumericAttribute = copy(min = Some(min)) + + /** Copy without the min value. */ + def withoutMin: NumericAttribute = copy(min = None) + + + /** Copy with a new max value. */ + def withMax(max: Double): NumericAttribute = copy(max = Some(max)) + + /** Copy without the max value. */ + def withoutMax: NumericAttribute = copy(max = None) + + /** Copy with a new standard deviation. */ + def withStd(std: Double): NumericAttribute = copy(std = Some(std)) + + /** Copy without the standard deviation. */ + def withoutStd: NumericAttribute = copy(std = None) + + /** Copy with a new sparsity. */ + def withSparsity(sparsity: Double): NumericAttribute = copy(sparsity = Some(sparsity)) + + /** Copy without the sparsity. */ + def withoutSparsity: NumericAttribute = copy(sparsity = None) + + /** Copy without summary statistics. */ + def withoutSummary: NumericAttribute = copy(min = None, max = None, std = None, sparsity = None) + + override def isNumeric: Boolean = true + + override def isNominal: Boolean = false + + /** Convert this attribute to metadata. 
*/ + private[attribute] override def toMetadata(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder() + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + min.foreach(bldr.putDouble(MIN, _)) + max.foreach(bldr.putDouble(MAX, _)) + std.foreach(bldr.putDouble(STD, _)) + sparsity.foreach(bldr.putDouble(SPARSITY, _)) + bldr.build() + } + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + min: Option[Double] = min, + max: Option[Double] = max, + std: Option[Double] = std, + sparsity: Option[Double] = sparsity): NumericAttribute = { + new NumericAttribute(name, index, min, max, std, sparsity) + } + + override def equals(other: Any): Boolean = { + other match { + case o: NumericAttribute => + (name == o.name) && + (index == o.index) && + (min == o.min) && + (max == o.max) && + (std == o.std) && + (sparsity == o.sparsity) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + min.hashCode + sum = 37 * sum + max.hashCode + sum = 37 * sum + std.hashCode + sum = 37 * sum + sparsity.hashCode + sum + } +} + +/** + * Factory methods for numeric attributes. + */ +object NumericAttribute extends AttributeFactory { + + /** The default numeric attribute. */ + val defaultAttr: NumericAttribute = new NumericAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): NumericAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val min = if (metadata.contains(MIN)) Some(metadata.getDouble(MIN)) else None + val max = if (metadata.contains(MAX)) Some(metadata.getDouble(MAX)) else None + val std = if (metadata.contains(STD)) Some(metadata.getDouble(STD)) else None + val sparsity = if (metadata.contains(SPARSITY)) Some(metadata.getDouble(SPARSITY)) else None + new NumericAttribute(name, index, min, max, std, sparsity) + } +} + +/** + * A nominal attribute. + * @param name optional name + * @param index optional index + * @param isOrdinal whether this attribute is ordinal (optional) + * @param numValues optional number of values. At most one of `numValues` and `values` can be + * defined. + * @param values optional values. At most one of `numValues` and `values` can be defined. + */ +class NominalAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val isOrdinal: Option[Boolean] = None, + val numValues: Option[Int] = None, + val values: Option[Array[String]] = None) extends Attribute { + + numValues.foreach { n => + require(n >= 0, s"numValues cannot be negative but got $n.") + } + require(!(numValues.isDefined && values.isDefined), + "Cannot have both numValues and values defined.") + + override def attrType: AttributeType = AttributeType.Nominal + + override def isNumeric: Boolean = false + + override def isNominal: Boolean = true + + private lazy val valueToIndex: Map[String, Int] = { + values.map(_.zipWithIndex.toMap).getOrElse(Map.empty) + } + + /** Index of a specific value. 
*/ + def indexOf(value: String): Int = { + valueToIndex(value) + } + + /** Tests whether this attribute contains a specific value. */ + def hasValue(value: String): Boolean = valueToIndex.contains(value) + + /** Gets a value given its index. */ + def getValue(index: Int): String = values.get(index) + + override def withName(name: String): NominalAttribute = copy(name = Some(name)) + override def withoutName: NominalAttribute = copy(name = None) + + override def withIndex(index: Int): NominalAttribute = copy(index = Some(index)) + override def withoutIndex: NominalAttribute = copy(index = None) + + /** Copy with new values and empty `numValues`. */ + def withValues(values: Array[String]): NominalAttribute = { + copy(numValues = None, values = Some(values)) + } + + /** Copy with new values and empty `numValues`. */ + @varargs + def withValues(first: String, others: String*): NominalAttribute = { + copy(numValues = None, values = Some((first +: others).toArray)) + } + + /** Copy without the values. */ + def withoutValues: NominalAttribute = { + copy(values = None) + } + + /** Copy with a new `numValues` and empty `values`. */ + def withNumValues(numValues: Int): NominalAttribute = { + copy(numValues = Some(numValues), values = None) + } + + /** Copy without the `numValues`. */ + def withoutNumValues: NominalAttribute = copy(numValues = None) + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + isOrdinal: Option[Boolean] = isOrdinal, + numValues: Option[Int] = numValues, + values: Option[Array[String]] = values): NominalAttribute = { + new NominalAttribute(name, index, isOrdinal, numValues, values) + } + + private[attribute] override def toMetadata(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder() + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + isOrdinal.foreach(bldr.putBoolean(ORDINAL, _)) + numValues.foreach(bldr.putLong(NUM_VALUES, _)) + values.foreach(v => bldr.putStringArray(VALUES, v)) + bldr.build() + } + + override def equals(other: Any): Boolean = { + other match { + case o: NominalAttribute => + (name == o.name) && + (index == o.index) && + (isOrdinal == o.isOrdinal) && + (numValues == o.numValues) && + (values.map(_.toSeq) == o.values.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + isOrdinal.hashCode + sum = 37 * sum + numValues.hashCode + sum = 37 * sum + values.map(_.toSeq).hashCode + sum + } +} + +/** Factory methods for nominal attributes. */ +object NominalAttribute extends AttributeFactory { + + /** The default nominal attribute. 
*/ + final val defaultAttr: NominalAttribute = new NominalAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): NominalAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val isOrdinal = if (metadata.contains(ORDINAL)) Some(metadata.getBoolean(ORDINAL)) else None + val numValues = + if (metadata.contains(NUM_VALUES)) Some(metadata.getLong(NUM_VALUES).toInt) else None + val values = + if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None + new NominalAttribute(name, index, isOrdinal, numValues, values) + } +} + +/** + * A binary attribute. + * @param name optional name + * @param index optional index + * @param values optionla values. If set, its size must be 2. + */ +class BinaryAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val values: Option[Array[String]] = None) + extends Attribute { + + values.foreach { v => + require(v.length == 2, s"Number of values must be 2 for a binary attribute but got ${v.toSeq}.") + } + + override def attrType: AttributeType = AttributeType.Binary + + override def isNumeric: Boolean = true + + override def isNominal: Boolean = true + + override def withName(name: String): BinaryAttribute = copy(name = Some(name)) + override def withoutName: BinaryAttribute = copy(name = None) + + override def withIndex(index: Int): BinaryAttribute = copy(index = Some(index)) + override def withoutIndex: BinaryAttribute = copy(index = None) + + /** + * Copy with new values. + * @param negative name for negative + * @param positive name for positive + */ + def withValues(negative: String, positive: String): BinaryAttribute = + copy(values = Some(Array(negative, positive))) + + /** Copy without the values. */ + def withoutValues: BinaryAttribute = copy(values = None) + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + values: Option[Array[String]] = values): BinaryAttribute = { + new BinaryAttribute(name, index, values) + } + + private[attribute] override def toMetadata(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + values.foreach(v => bldr.putStringArray(VALUES, v)) + bldr.build() + } + + override def equals(other: Any): Boolean = { + other match { + case o: BinaryAttribute => + (name == o.name) && + (index == o.index) && + (values.map(_.toSeq) == o.values.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + values.map(_.toSeq).hashCode + sum + } +} + +/** Factory methods for binary attributes. */ +object BinaryAttribute extends AttributeFactory { + + /** The default binary attribute. 
*/ + final val defaultAttr: BinaryAttribute = new BinaryAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): BinaryAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val values = + if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None + new BinaryAttribute(name, index, values) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java new file mode 100644 index 0000000000..e3474f3c1d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The content here should be in sync with `package.scala`. + +/** + *

<h1>ML attributes</h1>

+ * + * The ML pipeline API uses {@link org.apache.spark.sql.DataFrame}s as ML datasets. + * Each dataset consists of typed columns, e.g., string, double, vector, etc. + * However, knowing only the column type may not be sufficient to handle the data properly. + * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, + * which cannot be treated as numeric values in ML algorithms, and, for another instance, we may + * want to know the names and types of features stored in a vector column. + * ML attributes are used to provide additional information to describe columns in a dataset. + * + *

<h2>ML columns</h2>

+ * + * A column with ML attributes attached is called an ML column. + * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column + * of double values or a vector column. + * Columns of other types must be encoded into ML columns using transformers. + * We use {@link org.apache.spark.ml.attribute.Attribute} to describe a scalar ML column, and + * {@link org.apache.spark.ml.attribute.AttributeGroup} to describe a vector ML column. + * ML attributes are stored in the metadata field of the column schema. + */ +package org.apache.spark.ml.attribute; diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala new file mode 100644 index 0000000000..7ac21d7d56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} + +/** + * ==ML attributes== + * + * The ML pipeline API uses [[DataFrame]]s as ML datasets. + * Each dataset consists of typed columns, e.g., string, double, vector, etc. + * However, knowing only the column type may not be sufficient to handle the data properly. + * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, + * which cannot be treated as numeric values in ML algorithms, and, for another instance, we may + * want to know the names and types of features stored in a vector column. + * ML attributes are used to provide additional information to describe columns in a dataset. + * + * ===ML columns=== + * + * A column with ML attributes attached is called an ML column. + * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column + * of double values or a vector column. + * Columns of other types must be encoded into ML columns using transformers. + * We use [[Attribute]] to describe a scalar ML column, and [[AttributeGroup]] to describe a vector + * ML column. + * ML attributes are stored in the metadata field of the column schema. + */ +package object attribute diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java new file mode 100644 index 0000000000..38eb58673a --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaAttributeGroupSuite { + + @Test + public void testAttributeGroup() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr(), + NominalAttribute.defaultAttr(), + BinaryAttribute.defaultAttr().withIndex(0), + NumericAttribute.defaultAttr().withName("age").withSparsity(0.8), + NominalAttribute.defaultAttr().withName("size").withValues("small", "medium", "large"), + BinaryAttribute.defaultAttr().withName("clicked").withValues("no", "yes"), + NumericAttribute.defaultAttr(), + NumericAttribute.defaultAttr() + }; + AttributeGroup group = new AttributeGroup("user", attrs); + Assert.assertEquals(8, group.size()); + Assert.assertEquals("user", group.name()); + Assert.assertEquals(NumericAttribute.defaultAttr().withIndex(0), group.getAttr(0)); + Assert.assertEquals(3, group.indexOf("age")); + Assert.assertFalse(group.hasAttr("abc")); + Assert.assertEquals(group, AttributeGroup.fromStructField(group.toStructField())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java new file mode 100644 index 0000000000..b74bbed231 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.attribute; + +import org.junit.Test; +import org.junit.Assert; + +public class JavaAttributeSuite { + + @Test + public void testAttributeType() { + AttributeType numericType = AttributeType.Numeric(); + AttributeType nominalType = AttributeType.Nominal(); + AttributeType binaryType = AttributeType.Binary(); + Assert.assertEquals(numericType, NumericAttribute.defaultAttr().attrType()); + Assert.assertEquals(nominalType, NominalAttribute.defaultAttr().attrType()); + Assert.assertEquals(binaryType, BinaryAttribute.defaultAttr().attrType()); + } + + @Test + public void testNumericAttribute() { + NumericAttribute attr = NumericAttribute.defaultAttr() + .withName("age").withIndex(0).withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.4); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } + + @Test + public void testNominalAttribute() { + NominalAttribute attr = NominalAttribute.defaultAttr() + .withName("size").withIndex(1).withValues("small", "medium", "large"); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } + + @Test + public void testBinaryAttribute() { + BinaryAttribute attr = BinaryAttribute.defaultAttr() + .withName("clicked").withIndex(2).withValues("no", "yes"); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala new file mode 100644 index 0000000000..3fb6e2ec46 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.attribute + +import org.scalatest.FunSuite + +class AttributeGroupSuite extends FunSuite { + + test("attribute group") { + val attrs = Array( + NumericAttribute.defaultAttr, + NominalAttribute.defaultAttr, + BinaryAttribute.defaultAttr.withIndex(0), + NumericAttribute.defaultAttr.withName("age").withSparsity(0.8), + NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large"), + BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"), + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr) + val group = new AttributeGroup("user", attrs) + assert(group.size === 8) + assert(group.name === "user") + assert(group(0) === NumericAttribute.defaultAttr.withIndex(0)) + assert(group(2) === BinaryAttribute.defaultAttr.withIndex(2)) + assert(group.indexOf("age") === 3) + assert(group.indexOf("size") === 4) + assert(group.indexOf("clicked") === 5) + assert(!group.hasAttr("abc")) + intercept[NoSuchElementException] { + group("abc") + } + assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField())) + } + + test("attribute group without attributes") { + val group0 = new AttributeGroup("user", 10) + assert(group0.name === "user") + assert(group0.numAttributes === Some(10)) + assert(group0.size === 10) + assert(group0.attributes.isEmpty) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + + val group1 = new AttributeGroup("item") + assert(group1.name === "item") + assert(group1.numAttributes.isEmpty) + assert(group1.attributes.isEmpty) + assert(group1.size === -1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala new file mode 100644 index 0000000000..6ec35b0365 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.attribute + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata} + +class AttributeSuite extends FunSuite { + + test("default numeric attribute") { + val attr: NumericAttribute = NumericAttribute.defaultAttr + val metadata = Metadata.fromJson("{}") + val metadataWithType = Metadata.fromJson("""{"type":"numeric"}""") + assert(attr.attrType === AttributeType.Numeric) + assert(attr.isNumeric) + assert(!attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.min.isEmpty) + assert(attr.max.isEmpty) + assert(attr.std.isEmpty) + assert(attr.sparsity.isEmpty) + assert(attr.toMetadata() === metadata) + assert(attr.toMetadata(withType = false) === metadata) + assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === Attribute.fromMetadata(metadataWithType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized numeric attribute") { + val name = "age" + val index = 0 + val metadata = Metadata.fromJson("""{"name":"age","idx":0}""") + val metadataWithType = Metadata.fromJson("""{"type":"numeric","name":"age","idx":0}""") + val attr: NumericAttribute = NumericAttribute.defaultAttr + .withName(name) + .withIndex(index) + assert(attr.attrType == AttributeType.Numeric) + assert(attr.isNumeric) + assert(!attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.toMetadata() === metadata) + assert(attr.toMetadata(withType = false) === metadata) + assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === Attribute.fromMetadata(metadataWithType)) + val field = attr.toStructField() + assert(field.dataType === DoubleType) + assert(!field.nullable) + assert(attr.withoutIndex === Attribute.fromStructField(field)) + val existingMetadata = new MetadataBuilder() + .putString("name", "test") + .build() + assert(attr.toStructField(existingMetadata).metadata.getString("name") === "test") + + val attr2 = + attr.withoutName.withoutIndex.withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.3) + assert(attr2.name.isEmpty) + assert(attr2.index.isEmpty) + assert(attr2.min === Some(0.0)) + assert(attr2.max === Some(1.0)) + assert(attr2.std === Some(0.5)) + assert(attr2.sparsity === Some(0.3)) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) + } + + test("bad numeric attributes") { + val attr = NumericAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + intercept[IllegalArgumentException](attr.withStd(-0.1)) + intercept[IllegalArgumentException](attr.withSparsity(-0.5)) + intercept[IllegalArgumentException](attr.withSparsity(1.5)) + } + + test("default nominal attribute") { + val attr: NominalAttribute = NominalAttribute.defaultAttr + val metadata = Metadata.fromJson("""{"type":"nominal"}""") + val metadataWithoutType = Metadata.fromJson("{}") + assert(attr.attrType === AttributeType.Nominal) + assert(!attr.isNumeric) + assert(attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.values.isEmpty) + assert(attr.numValues.isEmpty) + assert(attr.isOrdinal.isEmpty) + assert(attr.toMetadata() === metadata) + assert(attr.toMetadata(withType = true) === metadata) + assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr === 
Attribute.fromMetadata(metadata)) + assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized nominal attribute") { + val name = "size" + val index = 1 + val values = Array("small", "medium", "large") + val metadata = Metadata.fromJson( + """{"type":"nominal","name":"size","idx":1,"vals":["small","medium","large"]}""") + val metadataWithoutType = Metadata.fromJson( + """{"name":"size","idx":1,"vals":["small","medium","large"]}""") + val attr: NominalAttribute = NominalAttribute.defaultAttr + .withName(name) + .withIndex(index) + .withValues(values) + assert(attr.attrType === AttributeType.Nominal) + assert(!attr.isNumeric) + assert(attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.values === Some(values)) + assert(attr.indexOf("medium") === 1) + assert(attr.getValue(1) === "medium") + assert(attr.toMetadata() === metadata) + assert(attr.toMetadata(withType = true) === metadata) + assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) + assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) + + val attr2 = attr.withoutName.withoutIndex.withValues(attr.values.get :+ "x-large") + assert(attr2.name.isEmpty) + assert(attr2.index.isEmpty) + assert(attr2.values.get === Array("small", "medium", "large", "x-large")) + assert(attr2.indexOf("x-large") === 3) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) + assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false))) + } + + test("bad nominal attributes") { + val attr = NominalAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + intercept[IllegalArgumentException](attr.withNumValues(-1)) + } + + test("default binary attribute") { + val attr = BinaryAttribute.defaultAttr + val metadata = Metadata.fromJson("""{"type":"binary"}""") + val metadataWithoutType = Metadata.fromJson("{}") + assert(attr.attrType === AttributeType.Binary) + assert(attr.isNumeric) + assert(attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.values.isEmpty) + assert(attr.toMetadata() === metadata) + assert(attr.toMetadata(withType = true) === metadata) + assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized binary attribute") { + val name = "clicked" + val index = 2 + val values = Array("no", "yes") + val metadata = Metadata.fromJson( + """{"type":"binary","name":"clicked","idx":2,"vals":["no","yes"]}""") + val metadataWithoutType = Metadata.fromJson( + """{"name":"clicked","idx":2,"vals":["no","yes"]}""") + val attr = BinaryAttribute.defaultAttr + .withName(name) + .withIndex(index) + .withValues(values(0), values(1)) + assert(attr.attrType === AttributeType.Binary) + assert(attr.isNumeric) + assert(attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.values.get === values) + assert(attr.toMetadata() === metadata) + assert(attr.toMetadata(withType = true) === metadata) + assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr === 
Attribute.fromMetadata(metadata)) + assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) + assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) + } + + test("bad binary attributes") { + val attr = BinaryAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 35e748f26b..4a06b9821b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -408,7 +408,8 @@ object Unidoc { "mllib.tree.impurity", "mllib.tree.model", "mllib.util", "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", - "ml", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", "ml.tuning" + "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", + "ml.tuning" ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" -- cgit v1.2.3
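
A minimal usage sketch of the API added by this patch, for reviewers: it only uses classes introduced above (Attribute, AttributeGroup, NumericAttribute, NominalAttribute, BinaryAttribute); the wrapping object name is illustrative and not part of the change.

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}

object AttributeUsageSketch {
  def main(args: Array[String]): Unit = {
    // Describe a scalar ML column: a nominal feature with three known values.
    val sizeAttr = NominalAttribute.defaultAttr
      .withName("size")
      .withValues("small", "medium", "large")
    // The attribute round-trips through the column's StructField metadata.
    val sizeField = sizeAttr.toStructField()
    assert(Attribute.fromStructField(sizeField) == sizeAttr)

    // Describe a vector ML column: one attribute per slot of the vector.
    val userAttrs: Array[Attribute] = Array(
      NumericAttribute.defaultAttr.withName("age"),
      BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"))
    val userGroup = new AttributeGroup("user", userAttrs)
    // Attributes are copied with their array positions as indices.
    assert(userGroup("clicked").index == Some(1))
    // The group also round-trips through StructField metadata.
    val userField = userGroup.toStructField()
    assert(AttributeGroup.fromStructField(userField) == userGroup)
  }
}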