aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala16
1 files changed, 8 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fb733f9a34..7a78ecbdf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
* Extracts (label, feature, weight) from input dataset.
*/
protected[ml] def extractWeightedLabeledPoints(
- dataset: DataFrame): RDD[(Double, Double, Double)] = {
+ dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
val idx = $(featureIndex)
val extract = udf { v: Vector => v(idx) }
@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
lit(1.0)
}
- dataset.select(col($(labelCol)), f, w).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) =>
(label, feature, weight)
}
@@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
schema: StructType,
fitting: Boolean): StructType = {
if (fitting) {
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
if (hasWeightCol) {
SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
} else {
@@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): IsotonicRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
@@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] (
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
}
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predict = dataset.schema($(featuresCol)).dataType match {
case DoubleType =>
udf { feature: Double => oldModel.predict(feature) }