author     Zheng RuiFeng <ruifengz@foxmail.com>    2017-01-23 17:24:53 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2017-01-23 17:24:53 -0800
commit     49f5b0ae4c31e4b7369104a14e562e1546aa7736 (patch)
tree       b1495ad11208a317fb38be9071ffc93c37df25e1 /mllib
parent     e4974721f33e64604501f673f74052e11920d438 (diff)
[SPARK-17747][ML] WeightCol support non-double numeric datatypes
## What changes were proposed in this pull request?

1. Add a test for `WeightCol` in `MLTestingUtils.checkNumericTypes`.
2. Move the datatype cast into `Predictor.fit`, and supply each algorithm's `train()` with the already-casted DataFrame.

## How was this patch tested?

Local tests in spark-shell and unit tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #15314 from zhengruifeng/weightCol_support_int.
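In practice the patch funnels every weighted `Predictor` through one casting step in `fit()`. Below is a minimal sketch of that step written against the public DataFrame API (the patched code itself uses an internal metadata-preserving `withColumn` overload); the `castToDouble` helper and its names are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Cast one column of any NumericType to DoubleType while preserving its
// metadata -- the invariant Predictor.fit() establishes before calling train().
def castToDouble(dataset: Dataset[_], colName: String): DataFrame = {
  val meta = dataset.schema(colName).metadata
  val cols = dataset.columns.map {
    case c if c == colName => col(c).cast(DoubleType).as(c, meta)
    case c                 => col(c)
  }
  dataset.select(cols: _*)
}

// e.g. castToDouble(castToDouble(df, "label"), "weight")
```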
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 32
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala | 9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala | 26
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 52
9 files changed, 95 insertions, 38 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 4b43a3aa5b..215f9d86f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -40,7 +40,7 @@ private[ml] trait PredictorParams extends Params
* @param schema input schema
* @param fitting whether this is in fitting
* @param featuresDataType SQL DataType for FeaturesType.
- * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * E.g., [[VectorUDT]] for vector features.
* @return output schema
*/
protected def validateAndTransformSchema(
@@ -51,6 +51,14 @@ private[ml] trait PredictorParams extends Params
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
SchemaUtils.checkNumericType(schema, $(labelCol))
+
+      this match {
+        case p: HasWeightCol =>
+          if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+            SchemaUtils.checkNumericType(schema, $(p.weightCol))
+          }
+        case _ =>
+      }
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
@@ -59,10 +67,12 @@ private[ml] trait PredictorParams extends Params
/**
* :: DeveloperApi ::
* Abstraction for prediction problems (regression and classification). It accepts all NumericType
- * labels and will automatically cast it to DoubleType in `fit()`.
+ * labels and will automatically cast it to DoubleType in `fit()`. If this predictor supports
+ * weights, it accepts all NumericType weights, which will be automatically casted to DoubleType
+ * in `fit()`.
*
* @tparam FeaturesType Type of features.
- * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * E.g., [[VectorUDT]] for vector features.
* @tparam Learner Specialization of this class. If you subclass this type, use this type
* parameter to specify the concrete type.
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
@@ -91,7 +101,19 @@ abstract class Predictor[
// Cast LabelCol to DoubleType and keep the metadata.
val labelMeta = dataset.schema($(labelCol)).metadata
- val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+ val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+    // Cast WeightCol to DoubleType and keep the metadata.
+    val casted = this match {
+      case p: HasWeightCol =>
+        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+          val weightMeta = dataset.schema($(p.weightCol)).metadata
+          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
+        } else {
+          labelCasted
+        }
+      case _ => labelCasted
+    }
copyValues(train(casted).setParent(this))
}
@@ -138,7 +160,7 @@ abstract class Predictor[
* Abstraction for a model for prediction tasks (regression and classification).
*
* @tparam FeaturesType Type of features.
- * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * E.g., [[VectorUDT]] for vector features.
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
* parameter to specify the concrete type for the corresponding model.
*/
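With the cast living in `fit()`, callers can pass integer weights directly. A hypothetical usage sketch (the data and column names are made up; `LogisticRegression` is one of the weighted estimators whose suite is updated below):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("intWeights").getOrCreate()

// "weight" is IntegerType here; before this patch a weighted algorithm could
// fail on it at runtime, now fit() casts it to DoubleType before train() runs.
val df = spark.createDataFrame(Seq(
  (0.0, 1, Vectors.dense(0.0, 2.0, 3.0)),
  (1.0, 2, Vectors.dense(0.0, 3.0, 9.0)),
  (0.0, 3, Vectors.dense(0.0, 2.0, 6.0))
)).toDF("label", "weight", "features")

val model = new LogisticRegression()
  .setWeightCol("weight")
  .setMaxIter(1)
  .fit(df)
```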
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 1ed9d3c809..90e77bc76e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -86,11 +86,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
col($(featuresCol))
}
- val w = if (hasWeightCol) {
- col($(weightCol))
- } else {
- lit(1.0)
- }
+ val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0)
+
dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) =>
(label, feature, weight)
@@ -109,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
if (fitting) {
SchemaUtils.checkNumericType(schema, $(labelCol))
if (hasWeightCol) {
- SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(weightCol))
} else {
logInfo("The weight column is not defined. Treat all instance weights as 1.0.")
}
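`IsotonicRegression` does not extend `Predictor`, so it gets the same relaxation locally: the schema check now accepts any NumericType weight, and the cast happens where the (label, feature, weight) triples are extracted. A standalone sketch of that pattern, with `ds`, `extractTriples`, and the column parameters as stand-ins rather than Spark API:

```scala
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Without the up-front cast, an IntegerType weight column would break the
// `weight: Double` pattern match at runtime; casting first keeps the
// extraction safe for any NumericType input.
def extractTriples(ds: Dataset[_], labelCol: String, featureCol: String,
    weightCol: Option[String]) = {
  val w = weightCol.map(c => col(c).cast(DoubleType)).getOrElse(lit(1.0))
  ds.select(col(labelCol).cast(DoubleType), col(featureCol).cast(DoubleType), w)
    .rdd
    .map { case Row(label: Double, feature: Double, weight: Double) =>
      (label, feature, weight)
    }
}
```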
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index 03e0c536a9..ec45e32d41 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -30,24 +31,28 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
import PredictorSuite._
- test("should support all NumericType labels and not support other types") {
+ test("should support all NumericType labels and weights, and not support other types") {
val df = spark.createDataFrame(Seq(
- (0, Vectors.dense(0, 2, 3)),
- (1, Vectors.dense(0, 3, 9)),
- (0, Vectors.dense(0, 2, 6))
- )).toDF("label", "features")
+ (0, 1, Vectors.dense(0, 2, 3)),
+ (1, 2, Vectors.dense(0, 3, 9)),
+ (0, 3, Vectors.dense(0, 2, 6))
+ )).toDF("label", "weight", "features")
val types =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
- val predictor = new MockPredictor()
+ val predictor = new MockPredictor().setWeightCol("weight")
types.foreach { t =>
- predictor.fit(df.select(col("label").cast(t), col("features")))
+ predictor.fit(df.select(col("label").cast(t), col("weight").cast(t), col("features")))
}
intercept[IllegalArgumentException] {
- predictor.fit(df.select(col("label").cast(StringType), col("features")))
+ predictor.fit(df.select(col("label").cast(StringType), col("weight"), col("features")))
+ }
+
+ intercept[IllegalArgumentException] {
+ predictor.fit(df.select(col("label"), col("weight").cast(StringType), col("features")))
}
}
}
@@ -55,12 +60,15 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
object PredictorSuite {
class MockPredictor(override val uid: String)
- extends Predictor[Vector, MockPredictor, MockPredictionModel] {
+ extends Predictor[Vector, MockPredictor, MockPredictionModel] with HasWeightCol {
def this() = this(Identifiable.randomUID("mockpredictor"))
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override def train(dataset: Dataset[_]): MockPredictionModel = {
require(dataset.schema("label").dataType == DoubleType)
+ require(dataset.schema("weight").dataType == DoubleType)
new MockPredictionModel(uid)
}
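The mock is the load-bearing piece of this suite: `train()` runs only after `Predictor.fit()` has done its casting, so plain schema `require`s inside `train()` prove the cast happened for every NumericType tried. A sketch of driving it from inside the suite (ScalaTest's `intercept` in scope, `df` being the three-column frame built above):

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{LongType, StringType}

// Both columns arrive in train() as DoubleType even when fed as LongType,
// so MockPredictor's require checks pass:
val predictor = new MockPredictor().setWeightCol("weight")
predictor.fit(df.select(col("label").cast(LongType), col("weight").cast(LongType), col("features")))

// A StringType weight is rejected during schema validation instead:
intercept[IllegalArgumentException] {
  predictor.fit(df.select(col("label"), col("weight").cast(StringType), col("features")))
}
```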
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index c14dcbd552..43547a4aaf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2066,7 +2066,7 @@ class LogisticRegressionSuite
checkModelData)
}
- test("should support all NumericType labels and not support other types") {
+ test("should support all NumericType labels and weights, and not support other types") {
val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
lr, spark) { (expected, actual) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 2a69ef1c3e..37d7991fe8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -283,7 +283,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
}
- test("should support all NumericType labels and not support other types") {
+ test("should support all NumericType labels and weights, and not support other types") {
val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
nb, spark) { (expected, actual) =>
@@ -324,8 +324,8 @@ object NaiveBayesSuite {
sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
- val _pi = pi.map(math.pow(math.E, _))
- val _theta = theta.map(row => row.map(math.pow(math.E, _)))
+ val _pi = pi.map(math.exp)
+ val _theta = theta.map(row => row.map(math.exp))
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index e3c278777c..828b95e544 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1086,7 +1086,7 @@ class GeneralizedLinearRegressionSuite
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}
- test("should support all NumericType labels and not support other types") {
+ test("should support all NumericType labels and weights, and not support other types") {
val glr = new GeneralizedLinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[
GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index c2c79476e8..8cbb2acad2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -181,7 +181,7 @@ class IsotonicRegressionSuite
checkModelData)
}
- test("should support all NumericType labels and not support other types") {
+ test("should support all NumericType labels and weights, and not support other types") {
val ir = new IsotonicRegression()
MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
ir, spark, isClassification = false) { (expected, actual) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index e05d0c9411..584a1b272f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -988,7 +988,7 @@ class LinearRegressionSuite
checkModelData)
}
- test("should support all NumericType labels and not support other types") {
+ test("should support all NumericType labels and weights, and not support other types") {
for (solver <- Seq("auto", "l-bfgs", "normal")) {
val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d219c42818..f1ed568d5e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -47,18 +47,44 @@ object MLTestingUtils extends SparkFunSuite {
} else {
genRegressionDFWithNumericLabelCol(spark)
}
- val expected = estimator.fit(dfs(DoubleType))
- val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
+
+    val finalEstimator = estimator match {
+      case weighted: Estimator[M] with HasWeightCol =>
+        weighted.set(weighted.weightCol, "weight")
+        weighted
+      case _ => estimator
+    }
+
+ val expected = finalEstimator.fit(dfs(DoubleType))
+
+ val actuals = dfs.keys.filter(_ != DoubleType).map { t =>
+      finalEstimator.fit(dfs(t))
+ }
+
actuals.foreach(actual => check(expected, actual))
val dfWithStringLabels = spark.createDataFrame(Seq(
- ("0", Vectors.dense(0, 2, 3), 0.0)
- )).toDF("label", "features", "censor")
+ ("0", 1, Vectors.dense(0, 2, 3), 0.0)
+ )).toDF("label", "weight", "features", "censor")
val thrown = intercept[IllegalArgumentException] {
estimator.fit(dfWithStringLabels)
}
assert(thrown.getMessage.contains(
"Column label must be of type NumericType but was actually of type StringType"))
+
+    estimator match {
+      case weighted: Estimator[M] with HasWeightCol =>
+ val dfWithStringWeights = spark.createDataFrame(Seq(
+ (0, "1", Vectors.dense(0, 2, 3), 0.0)
+ )).toDF("label", "weight", "features", "censor")
+ weighted.set(weighted.weightCol, "weight")
+ val thrown = intercept[IllegalArgumentException] {
+ weighted.fit(dfWithStringWeights)
+ }
+ assert(thrown.getMessage.contains(
+ "Column weight must be of type NumericType but was actually of type StringType"))
+ case _ =>
+ }
}
def checkNumericTypesALS(
@@ -75,7 +101,7 @@ object MLTestingUtils extends SparkFunSuite {
actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
val baseDF = dfs(baseType)
- val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+ val others = baseDF.columns.toSeq.diff(Seq(column)).map(col)
val cols = Seq(col(column).cast(StringType)) ++ others
val strDF = baseDF.select(cols: _*)
val thrown = intercept[IllegalArgumentException] {
@@ -104,7 +130,8 @@ object MLTestingUtils extends SparkFunSuite {
def genClassifDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",
- featuresColName: String = "features"): Map[NumericType, DataFrame] = {
+ featuresColName: String = "features",
+ weightColName: String = "weight"): Map[NumericType, DataFrame] = {
val df = spark.createDataFrame(Seq(
(0, Vectors.dense(0, 2, 3)),
(1, Vectors.dense(0, 3, 1)),
@@ -118,12 +145,14 @@ object MLTestingUtils extends SparkFunSuite {
types.map { t =>
val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+ .withColumn(weightColName, round(rand(seed = 42)).cast(t))
}.toMap
}
def genRegressionDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",
+ weightColName: String = "weight",
featuresColName: String = "features",
censorColName: String = "censor"): Map[NumericType, DataFrame] = {
val df = spark.createDataFrame(Seq(
@@ -137,10 +166,11 @@ object MLTestingUtils extends SparkFunSuite {
val types =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
types.map { t =>
- val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
- t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
- .withColumn(censorColName, lit(0.0))
- }.toMap
+ val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+ t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
+ .withColumn(censorColName, lit(0.0))
+ .withColumn(weightColName, round(rand(seed = 42)).cast(t))
+ }.toMap
}
def genRatingsDFWithNumericCols(
@@ -154,7 +184,7 @@ object MLTestingUtils extends SparkFunSuite {
(4, 50, 5.0)
)).toDF("user", "item", "rating")
- val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+ val others = df.columns.toSeq.diff(Seq(column)).map(col)
val types: Seq[NumericType] =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
types.map { t =>
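One note on the generators above: the weight column is derived deterministically and cast to the NumericType under test, so each `checkNumericTypes` run exercises the non-double weight path end to end. A minimal sketch, assuming a base `castDF` and a target type `t` as in the utilities:

```scala
import org.apache.spark.sql.functions.{rand, round}

// round(rand(seed = 42)) yields 0.0 or 1.0 per row, deterministically;
// casting to `t` (ShortType, LongType, DecimalType(10, 0), ...) gives a
// weight column in the same NumericType as the label being tested.
val withWeights = castDF.withColumn("weight", round(rand(seed = 42)).cast(t))
```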