diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index fb733f9a34..7a78ecbdf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Extracts (label, feature, weight) from input dataset. */ protected[ml] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + dataset: Dataset[_]): RDD[(Double, Double, Double)] = { val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) val extract = udf { v: Vector => v(idx) } @@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w).rdd.map { + dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) } @@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures schema: StructType, fitting: Boolean): StructType = { if (fitting) { - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) if (hasWeightCol) { SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) } else { @@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - @Since("1.5.0") - override def fit(dataset: DataFrame): IsotonicRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] ( copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } |