diff options
author | BenFradet <benjamin.fradet@gmail.com> | 2016-05-13 09:08:04 +0200 |
---|---|---|
committer | Nick Pentreath <nick.pentreath@gmail.com> | 2016-05-13 09:08:04 +0200 |
commit | 31f1aebbeb77b4eb1080f22c9bece7fafd8022f8 (patch) | |
tree | 95723ad90476594c4eeb1458fe61cc6e0007eaa3 /mllib/src/main/scala | |
parent | 5b849766ab080c91864ed06ebbfd82ad978d5e4c (diff) | |
download | spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.tar.gz spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.tar.bz2 spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.zip |
[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label
## What changes were proposed in this pull request?
Made ChiSqSelector and RFormula accept all numeric types for the label column.
## How was this patch tested?
Unit tests
Author: BenFradet <benjamin.fradet@gmail.com>
Closes #12467 from BenFradet/SPARK-13961.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala | 4 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4 |
2 files changed, 4 insertions, 4 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index cfecae7e0b..29f55a7f71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String)
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
     transformSchema(dataset.schema, logging = true)
-    val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
+    val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
         LabeledPoint(label, features)
     }
@@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String)
 
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(labelCol))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 5219680be2..a2f3d44132 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -256,8 +256,8 @@ class RFormulaModel private[feature](
     val columnNames = schema.map(_.name)
     require(!columnNames.contains($(featuresCol)), "Features column already exists.")
     require(
-      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
-      "Label column already exists and is not of type DoubleType.")
+      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
+      "Label column already exists and is not of type NumericType.")
   }
 
   @Since("2.0.0")
```