author    Yanbo Liang <ybliang8@gmail.com>    2016-07-29 04:40:20 -0700
committer Sean Owen <sowen@cloudera.com>      2016-07-29 04:40:20 -0700
commit    0557a45452f6e73877e5ec972110825ce8f3fbc5 (patch)
tree      28b18541ba9bfc1217041a08a2210c3d5835c757 /mllib/src/main
parent    d1d5069aa3744d46abd3889abab5f15e9067382a (diff)
[SPARK-16750][ML] Fix GaussianMixture training failed due to feature column type mistake
## What changes were proposed in this pull request?

ML ```GaussianMixture``` training failed due to a feature column type mistake: the feature column type should be ```ml.linalg.VectorUDT```, but ```mllib.linalg.VectorUDT``` was used by mistake. See [SPARK-16750](https://issues.apache.org/jira/browse/SPARK-16750) for how to reproduce this bug.

Why did the existing unit tests not catch this error? Because some estimators/transformers omit the initial call to ```transformSchema(dataset.schema)``` during ```fit``` or ```transform```, so schema validation never ran. This PR also adds that call to all estimators/transformers that were missing it.

## How was this patch tested?

No new tests; this should pass the existing ones.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14378 from yanboliang/spark-16750.
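For context, a minimal sketch of the failure mode, assuming a local ```SparkSession``` (the data and the ```Pipeline``` call sequence are illustrative, not the JIRA's exact reproduction):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.GaussianMixture
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("spark-16750").getOrCreate()
import spark.implicits._

// Features use the new ml.linalg vectors, which is what the spark.ml API expects.
val dataset = Seq(
  Vectors.dense(0.0, 0.1),
  Vectors.dense(0.2, 0.3),
  Vectors.dense(9.0, 9.1)
).map(Tuple1.apply).toDF("features")

// Before this patch, schema validation (triggered here by Pipeline.fit, which
// calls transformSchema on every stage up front) checked "features" against the
// old mllib.linalg.VectorUDT and so rejected these ml.linalg vectors; after the
// patch it validates against ml.linalg.VectorUDT and the fit succeeds.
val pipeline = new Pipeline().setStages(Array(new GaussianMixture().setK(2)))
val model = pipeline.fit(dataset)
```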
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala        2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala        8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala                 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala               1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala              1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala       3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala                  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala            1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala     3
10 files changed, 19 insertions, 7 deletions
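Every hunk below applies the same recipe: call ```transformSchema(dataset.schema, logging = true)``` as the first statement of ```fit```/```transform``` so that column-type validation actually runs. A minimal sketch of that pattern on a hypothetical custom transformer (the class, column names, and the ```SQLDataTypes.VectorType``` check are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// A hypothetical transformer illustrating the fix pattern; not part of this patch.
class NormSquared(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("normSquared"))

  override def transformSchema(schema: StructType): StructType = {
    // Validate against the new ml.linalg vector type (SQLDataTypes.VectorType),
    // i.e. the type this patch switches GaussianMixture's check to.
    require(schema("features").dataType == SQLDataTypes.VectorType,
      s"Column 'features' must be a vector but was ${schema("features").dataType}.")
    StructType(schema.fields :+ StructField("normSq", DoubleType, nullable = false))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    // The recurring fix: run schema validation (with logging) before any work.
    transformSchema(dataset.schema, logging = true)
    val normSq = udf { v: Vector => v.toArray.map(x => x * x).sum }
    dataset.withColumn("normSq", normSq(col("features")))
  }

  override def copy(extra: ParamMap): NormSquared = defaultCopy(extra)
}
```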
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index afb1080b9b..a97bd0fb16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -99,6 +99,7 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -222,6 +223,7 @@ class BisectingKMeans @Since("2.0.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
+ transformSchema(dataset.schema, logging = true)
val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 81749055c7..69f060ad77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -30,7 +30,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
- Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
+ Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
@@ -61,9 +61,9 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT)
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
- SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT)
+ SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
}
}
@@ -95,6 +95,7 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val predUDF = udf((vector: Vector) => predict(vector))
val probUDF = udf((vector: Vector) => predictProbability(vector))
dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
@@ -317,6 +318,7 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
+ transformSchema(dataset.schema, logging = true)
val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 9fb7d6a9a2..6c46be7196 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -120,6 +120,7 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -304,6 +305,7 @@ class KMeans @Since("1.5.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
+ transformSchema(dataset.schema, logging = true)
val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 7b11f86279..96d0bdee9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -68,6 +68,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 9ed8d83324..068f11a2a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -170,6 +170,7 @@ class MinMaxScalerModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray
val minArray = originalMin.toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 9a636bd8a5..558a7bbf0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -97,7 +97,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(inputCol))
val inputFields = schema.fields
require(inputFields.forall(_.name != $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
@@ -108,6 +108,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("2.0.0")
override def fit(dataset: Dataset[_]): Bucketizer = {
+ transformSchema(dataset.schema, logging = true)
val splits = dataset.stat.approxQuantile($(inputCol),
(0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
splits(0) = Double.NegativeInfinity
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index c95dacfce8..2ee899bcca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -112,6 +112,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): RFormulaModel = {
+ transformSchema(dataset.schema, logging = true)
require(isDefined(formula), "Formula must be defined first.")
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 289037640f..259be2679c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -63,6 +63,7 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val tableName = Identifiable.randomUID(uid)
dataset.createOrReplaceTempView(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2b9912657f..d4ae59deff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -196,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
@Since("2.0.0")
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
- validateAndTransformSchema(dataset.schema, fitting = true)
+ transformSchema(dataset.schema, logging = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
@@ -326,7 +326,7 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema)
+ transformSchema(dataset.schema, logging = true)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
if (hasQuantilesCol) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 35396446ed..cd7b4f2a9c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -164,7 +164,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("2.0.0")
override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
- validateAndTransformSchema(dataset.schema, fitting = true)
+ transformSchema(dataset.schema, logging = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -234,6 +234,7 @@ class IsotonicRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val predict = dataset.schema($(featuresCol)).dataType match {
case DoubleType =>
udf { feature: Double => oldModel.predict(feature) }