-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala                   16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala      4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala    4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala          4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala       4
5 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5607ed21af..5bbcd2e080 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {
/**
+ * :: DeveloperApi ::
+ *
* Derives the output schema from the input schema and parameters.
+ * The schema describes the columns and types of the data.
+ *
+ * @param schema Input schema to this stage
+ * @param paramMap Parameters passed to this stage
+ * @return Output schema from this stage
*/
- private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+ @DeveloperApi
+ def transformSchema(schema: StructType, paramMap: ParamMap): StructType
/**
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
new PipelineModel(this, map, transformers.toArray)
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
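With transformSchema now public and tagged @DeveloperApi, pipeline stages written outside the org.apache.spark.ml package can override it directly. A minimal sketch of such a user-defined Transformer against the Spark 1.3-era API shown in this diff; the DoubleColumn class and the "value"/"doubled" column names are illustrative only, not part of this change:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical stage: appends a "doubled" column computed from a "value" column.
class DoubleColumn extends Transformer {

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    // Validate the input schema before touching any data.
    transformSchema(dataset.schema, paramMap)
    dataset.withColumn("doubled", col("value") * 2.0)
  }

  // Before this change transformSchema was private[ml], so this override could
  // not be written outside Spark itself.
  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    require(schema.fieldNames.contains("value"), "Input must have a 'value' column")
    StructType(schema.fields :+ StructField("doubled", DoubleType, nullable = false))
  }
}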
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ddbd648d64..1142aa4f8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
model
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 7daeff980f..dfb89cc8d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
}
@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8d70e4347c..c2ec716f08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -188,7 +188,7 @@ class ALSModel private[ml] (
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
}
- override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
model
}
- override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b07a68269c..2eb1dac56f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
cvModel
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(estimator).transformSchema(schema, paramMap)
}
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
bestModel.transform(dataset, paramMap)
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
bestModel.transformSchema(schema, paramMap)
}
}
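Because transformSchema is now callable from user code on every stage, a pipeline's output schema can be validated up front without fitting anything. A rough usage sketch; the tokenizer, hashingTF, and lr stages and the training DataFrame are assumed to exist and are not part of this change:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.param.ParamMap

// tokenizer, hashingTF, lr: assumed, already-configured pipeline stages.
// training: an assumed DataFrame holding the raw input columns.
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// Fails fast on schema problems; before this change the check was private[ml].
val outputSchema = pipeline.transformSchema(training.schema, ParamMap.empty)
outputSchema.printTreeString()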