From 31f1aebbeb77b4eb1080f22c9bece7fafd8022f8 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 13 May 2016 09:08:04 +0200 Subject: [SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label ## What changes were proposed in this pull request? Made ChiSqSelector and RFormula accept all numeric types for label ## How was this patch tested? Unit tests Author: BenFradet Closes #12467 from BenFradet/SPARK-13961. --- mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'mllib/src/main/scala') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index cfecae7e0b..29f55a7f71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { + val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String) override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5219680be2..a2f3d44132 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -256,8 +256,8 @@ class RFormulaModel private[feature]( val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( - !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, - "Label column already exists and is not of type DoubleType.") + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], + "Label column already exists and is not of type NumericType.") } @Since("2.0.0") -- cgit v1.2.3