about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
author    Joseph K. Bradley <joseph@databricks.com>  2015-02-19 12:46:27 -0800
committer Xiangrui Meng <meng@databricks.com>        2015-02-19 12:46:27 -0800
commit    a5fed34355b403188ad50b567ab62ee54597b493 (patch)
tree      bc8805bc728ebed7e9b010e285ad5c82e42bb83b /mllib
parent    8ca3418e1b3e2687e75a08c185d17045a97279fb (diff)
download  spark-a5fed34355b403188ad50b567ab62ee54597b493.tar.gz
spark-a5fed34355b403188ad50b567ab62ee54597b493.tar.bz2
spark-a5fed34355b403188ad50b567ab62ee54597b493.zip
[SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml
For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema be public instead of private to ml. This would be nice to include in Spark 1.3 CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #4682 from jkbradley/SPARK-5902 and squashes the following commits: 6f02357 [Joseph K. Bradley] Made transformSchema public 0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala                  | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala    |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala  |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala        |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala     |  4
5 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5607ed21af..5bbcd2e080 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {
/**
+ * :: DeveloperAPI ::
+ *
* Derives the output schema from the input schema and parameters.
+ * The schema describes the columns and types of the data.
+ *
+ * @param schema Input schema to this stage
+ * @param paramMap Parameters passed to this stage
+ * @return Output schema from this stage
*/
- private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+ @DeveloperApi
+ def transformSchema(schema: StructType, paramMap: ParamMap): StructType
/**
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
new PipelineModel(this, map, transformers.toArray)
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ddbd648d64..1142aa4f8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
model
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 7daeff980f..dfb89cc8d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
}
@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8d70e4347c..c2ec716f08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -188,7 +188,7 @@ class ALSModel private[ml] (
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
}
- override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
model
}
- override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b07a68269c..2eb1dac56f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
cvModel
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(estimator).transformSchema(schema, paramMap)
}
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
bestModel.transform(dataset, paramMap)
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
bestModel.transformSchema(schema, paramMap)
}
}