aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
authorBenFradet <benjamin.fradet@gmail.com>2016-05-13 09:08:04 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-05-13 09:08:04 +0200
commit31f1aebbeb77b4eb1080f22c9bece7fafd8022f8 (patch)
tree95723ad90476594c4eeb1458fe61cc6e0007eaa3 /mllib/src/main/scala
parent5b849766ab080c91864ed06ebbfd82ad978d5e4c (diff)
downloadspark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.tar.gz
spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.tar.bz2
spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.zip
[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label
## What changes were proposed in this pull request? Made ChiSqSelector and RFormula accept all numeric types for label ## How was this patch tested? Unit tests Author: BenFradet <benjamin.fradet@gmail.com> Closes #12467 from BenFradet/SPARK-13961.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala4
2 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index cfecae7e0b..29f55a7f71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
transformSchema(dataset.schema, logging = true)
- val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
+ val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
@@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String)
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 5219680be2..a2f3d44132 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -256,8 +256,8 @@ class RFormulaModel private[feature](
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
- !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
- "Label column already exists and is not of type DoubleType.")
+ !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
+ "Label column already exists and is not of type NumericType.")
}
@Since("2.0.0")