aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorBenFradet <benjamin.fradet@gmail.com>2016-04-01 18:25:43 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-01 18:25:43 -0700
commit36e8fb8005eccea67a9dea8cf68ec3105aa43351 (patch)
tree98cc84e42e9c5472752e6f73022ed3e2ec2b7778 /mllib/src/main
parentabc6c42c2d76d6aa4a8cac605cb4214e8f8f3328 (diff)
downloadspark-36e8fb8005eccea67a9dea8cf68ec3105aa43351.tar.gz
spark-36e8fb8005eccea67a9dea8cf68ec3105aa43351.tar.bz2
spark-36e8fb8005eccea67a9dea8cf68ec3105aa43351.zip
[SPARK-7425][ML] spark.ml Predictor should support other numeric types for label
Currently, the Predictor abstraction expects the input labelCol type to be DoubleType, but we should support other numeric types. This will involve updating the PredictorParams.validateAndTransformSchema method. Author: BenFradet <benjamin.fradet@gmail.com> Closes #10355 from BenFradet/SPARK-7425.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Predictor.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala24
8 files changed, 52 insertions, 31 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ebe48700f8..d23ae6f794 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params
/**
* Validates and transforms the input schema with the provided param map.
+ *
* @param schema input schema
* @param fitting whether this is in fitting
* @param featuresDataType SQL DataType for FeaturesType.
@@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
- // TODO: Allow other numeric types
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
@@ -121,9 +121,8 @@ abstract class Predictor[
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
- dataset.select($(labelCol), $(featuresCol)).rdd.map {
- case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
+ dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3d1d5b6892..aeb94a6600 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -38,6 +38,7 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -265,7 +266,7 @@ class LogisticRegression @Since("1.2.0") (
LogisticRegressionModel = {
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -361,7 +362,7 @@ class LogisticRegression @Since("1.2.0") (
if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
val vec = optInitialModel.get.coefficients
logWarning(
- s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+ s"Initial coefficients provided $vec did not match the expected size $numFeatures")
}
if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
@@ -522,7 +523,7 @@ class LogisticRegressionModel private[spark] (
(LogisticRegressionModel, String) = {
$(probabilityCol) match {
case "" =>
- val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+ val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
case p => (this, p)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 98b99a3485..263d54ce4d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -295,10 +295,12 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.4.0")
override def fit(dataset: DataFrame): OneVsRestModel = {
+ transformSchema(dataset.schema)
+
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
val computeNumClasses: () => Int = () => {
- val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+ val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head()
// classes are assumed to be numbered from 0,...,maxLabelIndex
maxLabelIndex.toInt + 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d..3278974954 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -103,7 +103,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
@@ -184,10 +184,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* and put it in an RDD with strong types.
*/
protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
- dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
- case Row(features: Vector, label: Double, censor: Double) =>
- AFTPoint(features, label, censor)
- }
+ dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+ .rdd.map {
+ case Row(features: Vector, label: Double, censor: Double) =>
+ AFTPoint(features, label, censor)
+ }
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0e71e8d8e1..a40d3731cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* to be used in the model.
* Supported options: "gaussian", "binomial", "poisson" and "gamma".
* Default is "gaussian".
+ *
* @group param
*/
@Since("2.0.0")
@@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* Param for the name of link function which provides the relationship
* between the linear predictor and the mean of the distribution function.
* Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ *
* @group param
*/
@Since("2.0.0")
@@ -210,9 +212,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
+ val instances: RDD[Instance] =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
}
if (familyObj == Gaussian && linkObj == Identity) {
@@ -698,7 +701,7 @@ class GeneralizedLinearRegressionModel private[ml] (
: (GeneralizedLinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fb733f9a34..bd0b631d89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
lit(1.0)
}
- dataset.select(col($(labelCol)), f, w).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) =>
(label, feature, weight)
}
@@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
schema: StructType,
fitting: Boolean): StructType = {
if (fitting) {
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
if (hasWeightCol) {
SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 5ec02135cc..ba5ad4c072 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -40,6 +40,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
// For low dimensional data, WeightedLeastSquares is more efficiently since the
// training algorithm only requires one pass through the data. (SPARK-10668)
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] (
private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -550,7 +551,7 @@ class LinearRegressionSummary private[regression] (
@transient private val metrics = new RegressionMetrics(
predictions
- .select(predictionCol, labelCol)
+ .select(col(predictionCol), col(labelCol).cast(DoubleType))
.rdd
.map { case Row(pred: Double, label: Double) => (pred, label) },
!model.getFitIntercept)
@@ -653,7 +654,7 @@ class LinearRegressionSummary private[regression] (
col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
}
val sigma2 = rss / degreesOfFreedom
- diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+ diagInvAtWA.map(_ * sigma2).map(math.sqrt)
}
}
@@ -826,7 +827,7 @@ private class LeastSquaresAggregator(
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76021ad8f4..334410c962 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.util
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
/**
@@ -44,10 +44,10 @@ private[spark] object SchemaUtils {
}
/**
- * Check whether the given schema contains a column of one of the require data types.
- * @param colName column name
- * @param dataTypes required column data types
- */
+ * Check whether the given schema contains a column of one of the require data types.
+ * @param colName column name
+ * @param dataTypes required column data types
+ */
def checkColumnTypes(
schema: StructType,
colName: String,
@@ -61,6 +61,20 @@ private[spark] object SchemaUtils {
}
/**
+ * Check whether the given schema contains a column of the numeric data type.
+ * @param colName column name
+ */
+ def checkNumericType(
+ schema: StructType,
+ colName: String,
+ msg: String = ""): Unit = {
+ val actualDataType = schema(colName).dataType
+ val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+ require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " +
+ s"NumericType but was actually of type $actualDataType.$message")
+ }
+
+ /**
* Appends a new column to the input schema. This fails if the given output column already exists.
* @param schema input schema
* @param colName new column name. If this column name is an empty string "", this method returns