From 49f5b0ae4c31e4b7369104a14e562e1546aa7736 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 23 Jan 2017 17:24:53 -0800 Subject: [SPARK-17747][ML] WeightCol support non-double numeric datatypes ## What changes were proposed in this pull request? 1, add test for `WeightCol` in `MLTestingUtils.checkNumericTypes` 2, move datatype cast to `Predict.fit`, and supply algos' `train()` with casted dataframe ## How was this patch tested? local tests in spark-shell and unit tests Author: Zheng RuiFeng Closes #15314 from zhengruifeng/weightCol_support_int. --- .../main/scala/org/apache/spark/ml/Predictor.scala | 32 ++++++++++--- .../spark/ml/regression/IsotonicRegression.scala | 9 ++-- .../scala/org/apache/spark/ml/PredictorSuite.scala | 26 +++++++---- .../classification/LogisticRegressionSuite.scala | 2 +- .../spark/ml/classification/NaiveBayesSuite.scala | 6 +-- .../GeneralizedLinearRegressionSuite.scala | 2 +- .../ml/regression/IsotonicRegressionSuite.scala | 2 +- .../ml/regression/LinearRegressionSuite.scala | 2 +- .../org/apache/spark/ml/util/MLTestingUtils.scala | 52 +++++++++++++++++----- 9 files changed, 95 insertions(+), 38 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 4b43a3aa5b..215f9d86f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -40,7 +40,7 @@ private[ml] trait PredictorParams extends Params * @param schema input schema * @param fitting whether this is in fitting * @param featuresDataType SQL DataType for FeaturesType. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * E.g., [[VectorUDT]] for vector features. * @return output schema */ protected def validateAndTransformSchema( @@ -51,6 +51,14 @@ private[ml] trait PredictorParams extends Params SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { SchemaUtils.checkNumericType(schema, $(labelCol)) + + this match { + case p: HasWeightCol => + if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { + SchemaUtils.checkNumericType(schema, $(p.weightCol)) + } + case _ => + } } SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } @@ -59,10 +67,12 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: * Abstraction for prediction problems (regression and classification). It accepts all NumericType - * labels and will automatically cast it to DoubleType in `fit()`. + * labels and will automatically cast it to DoubleType in `fit()`. If this predictor supports + * weights, it accepts all NumericType weights, which will be automatically casted to DoubleType + * in `fit()`. * * @tparam FeaturesType Type of features. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * E.g., [[VectorUDT]] for vector features. * @tparam Learner Specialization of this class. If you subclass this type, use this type * parameter to specify the concrete type. * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type @@ -91,7 +101,19 @@ abstract class Predictor[ // Cast LabelCol to DoubleType and keep the metadata. val labelMeta = dataset.schema($(labelCol)).metadata - val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + // Cast WeightCol to DoubleType and keep the metadata. + val casted = this match { + case p: HasWeightCol => + if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { + val weightMeta = dataset.schema($(p.weightCol)).metadata + labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta) + } else { + labelCasted + } + case _ => labelCasted + } copyValues(train(casted).setParent(this)) } @@ -138,7 +160,7 @@ abstract class Predictor[ * Abstraction for a model for prediction tasks (regression and classification). * * @tparam FeaturesType Type of features. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * E.g., [[VectorUDT]] for vector features. * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type * parameter to specify the concrete type for the corresponding model. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 1ed9d3c809..90e77bc76e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -86,11 +86,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { col($(featuresCol)) } - val w = if (hasWeightCol) { - col($(weightCol)) - } else { - lit(1.0) - } + val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0) + dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) @@ -109,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures if (fitting) { SchemaUtils.checkNumericType(schema, $(labelCol)) if (hasWeightCol) { - SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(weightCol)) } else { logInfo("The weight column is not defined. Treat all instance weights as 1.0.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala index 03e0c536a9..ec45e32d41 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -30,24 +31,28 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { import PredictorSuite._ - test("should support all NumericType labels and not support other types") { + test("should support all NumericType labels and weights, and not support other types") { val df = spark.createDataFrame(Seq( - (0, Vectors.dense(0, 2, 3)), - (1, Vectors.dense(0, 3, 9)), - (0, Vectors.dense(0, 2, 6)) - )).toDF("label", "features") + (0, 1, Vectors.dense(0, 2, 3)), + (1, 2, Vectors.dense(0, 3, 9)), + (0, 3, Vectors.dense(0, 2, 6)) + )).toDF("label", "weight", "features") val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - val predictor = new MockPredictor() + val predictor = new MockPredictor().setWeightCol("weight") types.foreach { t => - predictor.fit(df.select(col("label").cast(t), col("features"))) + predictor.fit(df.select(col("label").cast(t), col("weight").cast(t), col("features"))) } intercept[IllegalArgumentException] { - predictor.fit(df.select(col("label").cast(StringType), col("features"))) + predictor.fit(df.select(col("label").cast(StringType), col("weight"), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label"), col("weight").cast(StringType), col("features"))) } } } @@ -55,12 +60,15 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { object PredictorSuite { class MockPredictor(override val uid: String) - extends Predictor[Vector, MockPredictor, MockPredictionModel] { + extends Predictor[Vector, MockPredictor, MockPredictionModel] with HasWeightCol { def this() = this(Identifiable.randomUID("mockpredictor")) + def setWeightCol(value: String): this.type = set(weightCol, value) + override def train(dataset: Dataset[_]): MockPredictionModel = { require(dataset.schema("label").dataType == DoubleType) + require(dataset.schema("weight").dataType == DoubleType) new MockPredictionModel(uid) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index c14dcbd552..43547a4aaf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2066,7 +2066,7 @@ class LogisticRegressionSuite checkModelData) } - test("should support all NumericType labels and not support other types") { + test("should support all NumericType labels and weights, and not support other types") { val lr = new LogisticRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( lr, spark) { (expected, actual) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 2a69ef1c3e..37d7991fe8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -283,7 +283,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) } - test("should support all NumericType labels and not support other types") { + test("should support all NumericType labels and weights, and not support other types") { val nb = new NaiveBayes() MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( nb, spark) { (expected, actual) => @@ -324,8 +324,8 @@ object NaiveBayesSuite { sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) - val _pi = pi.map(math.pow(math.E, _)) - val _theta = theta.map(row => row.map(math.pow(math.E, _))) + val _pi = pi.map(math.exp) + val _theta = theta.map(row => row.map(math.exp)) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index e3c278777c..828b95e544 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1086,7 +1086,7 @@ class GeneralizedLinearRegressionSuite GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) } - test("should support all NumericType labels and not support other types") { + test("should support all NumericType labels and weights, and not support other types") { val glr = new GeneralizedLinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[ GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index c2c79476e8..8cbb2acad2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -181,7 +181,7 @@ class IsotonicRegressionSuite checkModelData) } - test("should support all NumericType labels and not support other types") { + test("should support all NumericType labels and weights, and not support other types") { val ir = new IsotonicRegression() MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( ir, spark, isClassification = false) { (expected, actual) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index e05d0c9411..584a1b272f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -988,7 +988,7 @@ class LinearRegressionSuite checkModelData) } - test("should support all NumericType labels and not support other types") { + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index d219c42818..f1ed568d5e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -47,18 +47,44 @@ object MLTestingUtils extends SparkFunSuite { } else { genRegressionDFWithNumericLabelCol(spark) } - val expected = estimator.fit(dfs(DoubleType)) - val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t))) + + val finalEstimator = estimator match { + case weighted: Estimator[M] with HasWeightCol => + weighted.set(weighted.weightCol, "weight") + weighted + case _ => estimator + } + + val expected = finalEstimator.fit(dfs(DoubleType)) + + val actuals = dfs.keys.filter(_ != DoubleType).map { t => + finalEstimator.fit(dfs(t)) + } + actuals.foreach(actual => check(expected, actual)) val dfWithStringLabels = spark.createDataFrame(Seq( - ("0", Vectors.dense(0, 2, 3), 0.0) - )).toDF("label", "features", "censor") + ("0", 1, Vectors.dense(0, 2, 3), 0.0) + )).toDF("label", "weight", "features", "censor") val thrown = intercept[IllegalArgumentException] { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( "Column label must be of type NumericType but was actually of type StringType")) + + estimator match { + case weighted: Estimator[M] with HasWeightCol => + val dfWithStringWeights = spark.createDataFrame(Seq( + (0, "1", Vectors.dense(0, 2, 3), 0.0) + )).toDF("label", "weight", "features", "censor") + weighted.set(weighted.weightCol, "weight") + val thrown = intercept[IllegalArgumentException] { + weighted.fit(dfWithStringWeights) + } + assert(thrown.getMessage.contains( + "Column weight must be of type NumericType but was actually of type StringType")) + case _ => + } } def checkNumericTypesALS( @@ -75,7 +101,7 @@ object MLTestingUtils extends SparkFunSuite { actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_)) + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) val cols = Seq(col(column).cast(StringType)) ++ others val strDF = baseDF.select(cols: _*) val thrown = intercept[IllegalArgumentException] { @@ -104,7 +130,8 @@ object MLTestingUtils extends SparkFunSuite { def genClassifDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", - featuresColName: String = "features"): Map[NumericType, DataFrame] = { + featuresColName: String = "features", + weightColName: String = "weight"): Map[NumericType, DataFrame] = { val df = spark.createDataFrame(Seq( (0, Vectors.dense(0, 2, 3)), (1, Vectors.dense(0, 3, 1)), @@ -118,12 +145,14 @@ object MLTestingUtils extends SparkFunSuite { types.map { t => val castDF = df.select(col(labelColName).cast(t), col(featuresColName)) t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName) + .withColumn(weightColName, round(rand(seed = 42)).cast(t)) }.toMap } def genRegressionDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", + weightColName: String = "weight", featuresColName: String = "features", censorColName: String = "censor"): Map[NumericType, DataFrame] = { val df = spark.createDataFrame(Seq( @@ -137,10 +166,11 @@ object MLTestingUtils extends SparkFunSuite { val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) types.map { t => - val castDF = df.select(col(labelColName).cast(t), col(featuresColName)) - t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName) - .withColumn(censorColName, lit(0.0)) - }.toMap + val castDF = df.select(col(labelColName).cast(t), col(featuresColName)) + t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName) + .withColumn(censorColName, lit(0.0)) + .withColumn(weightColName, round(rand(seed = 42)).cast(t)) + }.toMap } def genRatingsDFWithNumericCols( @@ -154,7 +184,7 @@ object MLTestingUtils extends SparkFunSuite { (4, 50, 5.0) )).toDF("user", "item", "rating") - val others = df.columns.toSeq.diff(Seq(column)).map(col(_)) + val others = df.columns.toSeq.diff(Seq(column)).map(col) val types: Seq[NumericType] = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) types.map { t => -- cgit v1.2.3