aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Mike Dusenberry <dusenberrymw@gmail.com> 2015-06-21 18:25:36 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-06-21 18:25:36 -0700
commit 47c1d5629373566df9d12fdc4ceb22f38b869482 (patch)
tree 34885825c9706d48e4a943c30ac1ee52aa529b94 /mllib
parent a1894422ad6b3335c84c73ba9466da6677d893cb (diff)
download spark-47c1d5629373566df9d12fdc4ceb22f38b869482.tar.gz
spark-47c1d5629373566df9d12fdc4ceb22f38b869482.tar.bz2
spark-47c1d5629373566df9d12fdc4ceb22f38b869482.zip
[SPARK-7426] [MLLIB] [ML] Updated Attribute.fromStructField to allow any NumericType.
Updated `Attribute.fromStructField` to allow any `NumericType`, rather than just `DoubleType`, and added unit tests for a few of the other NumericTypes.
Author: Mike Dusenberry <dusenberrymw@gmail.com>
Closes #6540 from dusenberrymw/SPARK-7426_AttributeFactory.fromStructField_Should_Allow_NumericTypes and squashes the following commits:
87fecb3 [Mike Dusenberry] Updated Attribute.fromStructField to allow any NumericType, rather than just DoubleType, and added unit tests for a few of the other NumericTypes.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala 5
2 files changed, 7 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index ce43a450da..e479f16902 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
import scala.annotation.varargs
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
+import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
/**
* :: DeveloperApi ::
@@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
* Creates an [[Attribute]] from a [[StructField]] instance.
*/
def fromStructField(field: StructField): Attribute = {
- require(field.dataType == DoubleType)
+ require(field.dataType.isInstanceOf[NumericType])
val metadata = field.metadata
val mlAttr = AttributeKeys.ML_ATTR
if (metadata.contains(mlAttr)) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 72b575d022..c5fd2f9d5a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
val fldWithMeta = new StructField("x", DoubleType, false, metadata)
assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+ // Attribute.fromStructField should accept any NumericType, not just DoubleType
+ val longFldWithMeta = new StructField("x", LongType, false, metadata)
+ assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
+ val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
+ assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
}