aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-11-01 10:46:36 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-11-01 10:46:36 -0700
commit8ac09108fcf3fb62a812333a5b386b566a9d98ec (patch)
tree0865d49e98f34517e0e8a22ed19a5d766703c6ea /mllib/src/main
parent0cba535af3c65618f342fa2d7db9647f5e6f6f1b (diff)
downloadspark-8ac09108fcf3fb62a812333a5b386b566a9d98ec.tar.gz
spark-8ac09108fcf3fb62a812333a5b386b566a9d98ec.tar.bz2
spark-8ac09108fcf3fb62a812333a5b386b566a9d98ec.zip
[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit
## What changes were proposed in this pull request? 1, move cast to `Predictor` 2, and then, remove unnecessary cast ## How was this patch tested? existing tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #15414 from zhengruifeng/move_cast.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Predictor.scala12
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala2
7 files changed, 16 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e29d7f48a1..aa92edde7a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
/**
* :: DeveloperApi ::
- * Abstraction for prediction problems (regression and classification).
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast it to DoubleType in [[fit()]].
*
* @tparam FeaturesType Type of features.
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -87,7 +88,12 @@ abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
- copyValues(train(dataset).setParent(this))
+
+ // Cast LabelCol to DoubleType and keep the metadata.
+ val labelMeta = dataset.schema($(labelCol)).metadata
+ val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+ copyValues(train(casted).setParent(this))
}
override def copy(extra: ParamMap): Learner
@@ -121,7 +127,7 @@ abstract class Predictor[
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d1b21b16f2..a3da3067e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -71,7 +71,7 @@ abstract class Classifier[
* and put it in an RDD with strong types.
*
* @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
- * and features ([[Vector]]). Labels are cast to [[DoubleType]].
+ * and features ([[Vector]]).
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
* @throws SparkException if any label is not an integer >= 0
@@ -79,7 +79,7 @@ abstract class Classifier[
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 8bffe0cda0..f8f164e8c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
val oldDataset: RDD[LabeledPoint] =
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given" +
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fdaae04c4..c4651054fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 994ed993c9..b03a07a6bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
+ val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
.map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
}.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
seqOp = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 33cb25c8c7..8656ecf609 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 519f3bdec8..ae876b3839 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}