From 971b95b0c9002bd541bcbe0da54a9967ba22588f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 13 Apr 2015 21:18:05 -0700 Subject: [SPARK-5957][ML] better handling of parameters The design doc was posted on the JIRA page. Python changes will be in a follow-up PR. jkbradley 1. Use codegen for shared params. 1. Move shared params to package `ml.param.shared`. 1. Set default values in `Params` instead of in `Param`. 1. Add a few methods to `Params` and `ParamMap`. 1. Move schema handling to `SchemaUtils` from `Params`. - [x] check visibility of the methods added Author: Xiangrui Meng Closes #5431 from mengxr/SPARK-5957 and squashes the following commits: d19236d [Xiangrui Meng] fix test 26ae2d7 [Xiangrui Meng] re-gen code and mark clear protected 38b78c7 [Xiangrui Meng] update Param.toString and remove Params.explain() 409e2d5 [Xiangrui Meng] address comments 2d637bd [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957 eec2264 [Xiangrui Meng] make get* public in Params 4090d95 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957 4fee9e7 [Xiangrui Meng] re-gen shared params 2737c2d [Xiangrui Meng] rename SharedParamCodeGen to SharedParamsCodeGen e938f81 [Xiangrui Meng] update code to set default parameter values 28ed322 [Xiangrui Meng] merge master 55be1f3 [Xiangrui Meng] merge master d63b5cc [Xiangrui Meng] fix examples 29b004c [Xiangrui Meng] update ParamsSuite 94fd98e [Xiangrui Meng] fix explain params 48d0e84 [Xiangrui Meng] add remove and update explainParams 4ac6348 [Xiangrui Meng] move schema utils to SchemaUtils add a few methods to Params 0d9594e [Xiangrui Meng] add getOrElse to ParamMap eeeffe8 [Xiangrui Meng] map ++ paramMap => extractValues 0d3fc5b [Xiangrui Meng] setDefault after param a9dbf59 [Xiangrui Meng] minor updates d9302b8 [Xiangrui Meng] generate default values 1c72579 [Xiangrui Meng] pass test compile abb7a3b [Xiangrui Meng] update default values handling dcab97a [Xiangrui Meng] add codegen for shared params --- .../main/scala/org/apache/spark/ml/Estimator.scala | 2 +- .../main/scala/org/apache/spark/ml/Pipeline.scala | 10 +- .../scala/org/apache/spark/ml/Transformer.scala | 5 +- .../spark/ml/classification/Classifier.scala | 17 +- .../ml/classification/LogisticRegression.scala | 18 +- .../classification/ProbabilisticClassifier.scala | 11 +- .../evaluation/BinaryClassificationEvaluator.scala | 15 +- .../org/apache/spark/ml/feature/HashingTF.scala | 6 +- .../org/apache/spark/ml/feature/Normalizer.scala | 7 +- .../apache/spark/ml/feature/StandardScaler.scala | 9 +- .../apache/spark/ml/feature/StringIndexer.scala | 10 +- .../org/apache/spark/ml/feature/Tokenizer.scala | 16 +- .../apache/spark/ml/feature/VectorAssembler.scala | 7 +- .../apache/spark/ml/feature/VectorIndexer.scala | 25 +- .../apache/spark/ml/impl/estimator/Predictor.scala | 16 +- .../scala/org/apache/spark/ml/param/params.scala | 236 ++++++++++++------- .../ml/param/shared/SharedParamsCodeGen.scala | 169 ++++++++++++++ .../spark/ml/param/shared/sharedParams.scala | 259 +++++++++++++++++++++ .../org/apache/spark/ml/param/sharedParams.scala | 173 -------------- .../org/apache/spark/ml/recommendation/ALS.scala | 49 ++-- .../spark/ml/regression/LinearRegression.scala | 8 +- .../apache/spark/ml/tuning/CrossValidator.scala | 18 +- .../org/apache/spark/ml/util/SchemaUtils.scala | 61 +++++ .../org/apache/spark/ml/param/ParamsSuite.scala | 47 +++- .../org/apache/spark/ml/param/TestParams.scala | 12 +- 25 files changed, 815 insertions(+), 391 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index eff7ef925d..d6b3503ebd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { */ @varargs def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + val map = ParamMap(paramPairs: _*) fit(dataset, map) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a455341a1f..8eddf79cdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -84,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] { /** param for pipeline stages */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + def getStages: Array[PipelineStage] = getOrDefault(stages) /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an @@ -101,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] { */ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) // Search for the last estimator. var indexOfLastEstimator = -1 @@ -138,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -177,14 +177,14 @@ class PipelineModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 9a5848684b..7fb87fe452 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,6 +22,7 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -86,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { @@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.withColumn(map(outputCol), callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index c5fc89f935..29339c98f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) } } @@ -67,8 +69,7 @@ private[spark] abstract class Classifier[ with ClassifierParams { /** @group setParam */ - def setRawPredictionCol(value: String): E = - set(rawPredictionCol, value).asInstanceOf[E] + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] // TODO: defaultEvaluator (follow-up PR) } @@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 34625745dd..cc8b0721cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -31,8 +31,10 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold { + setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5) +} /** * :: AlphaComponent :: @@ -45,10 +47,6 @@ class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams { - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) - /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -100,8 +98,6 @@ class LogisticRegressionModel private[ml] ( extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { - setThreshold(0.5) - /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -123,7 +119,7 @@ class LogisticRegressionModel private[ml] ( // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -184,7 +180,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - if (score(features) > paramMap(threshold)) 1 else 0 + if (score(features) > getThreshold) 1 else 0 } override protected def predictProbabilities(features: Vector): Vector = { @@ -199,7 +195,7 @@ class LogisticRegressionModel private[ml] ( override protected def copy(): LogisticRegressionModel = { val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(this.extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index bd8caac855..10404548cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * Params for probabilistic classification. */ @@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) } } @@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 2360f4479f..c865eb9fe0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType - /** * :: AlphaComponent :: * @@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params * @group param */ val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ - def getMetricName: String = get(metricName) + def getMetricName: String = getOrDefault(metricName) /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) @@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) + setDefault(metricName -> "areaUnderROC") + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema - checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index fc4e12773c..b20f2fc49a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { * number of features * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + val numFeatures = new IntParam(this, "numFeatures", "number of features") /** @group getParam */ - def getNumFeatures: Int = get(numFeatures) + def getNumFeatures: Int = getOrDefault(numFeatures) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) + setDefault(numFeatures -> (1 << 18)) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 05f91dc910..decaeb0da6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -35,14 +35,16 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { * Normalization in L^p^ space, p = 2 by default. * @group param */ - val p = new DoubleParam(this, "p", "the p norm value", Some(2)) + val p = new DoubleParam(this, "p", "the p norm value") /** @group getParam */ - def getP: Double = get(p) + def getP: Double = getOrDefault(p) /** @group setParam */ def setP(value: Double): this.type = set(p, value) + setDefault(p -> 2.0) + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { val normalizer = new feature.Normalizer(paramMap(p)) normalizer.transform @@ -50,4 +52,3 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { override protected def outputDataType: DataType = new VectorUDT() } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1142aa4f8e..1b102619b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -47,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) @@ -56,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -86,13 +87,13 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 61e6742e88..4d960df357 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -22,6 +22,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StringType, StructType} @@ -34,8 +36,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - checkInputColumn(schema, map(inputCol), StringType) + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), StringType) val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -64,7 +66,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase // TODO: handle unseen labels override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) @@ -105,7 +107,7 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 68401e3695..376a004858 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -56,39 +56,39 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize * param for minimum token length, default is one to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") /** @group setParam */ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ - def getMinTokenLength: Int = get(minTokenLength) + def getMinTokenLength: Int = getOrDefault(minTokenLength) /** * param sets regex as splitting on gaps (true) or matching tokens (false) * @group param */ - val gaps: BooleanParam = new BooleanParam( - this, "gaps", "Set regex to match gaps or tokens", Some(false)) + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") /** @group setParam */ def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ - def getGaps: Boolean = get(gaps) + def getGaps: Boolean = getOrDefault(gaps) /** * param sets regex pattern used by tokenizer * @group param */ - val pattern: Param[String] = new Param( - this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") /** @group setParam */ def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ - def getPattern: String = get(pattern) + def getPattern: String = getOrDefault(pattern) + + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => val re = paramMap(pattern).r diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index d1b8f7e6e9..e567e069e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -22,7 +22,8 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -44,7 +45,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } @@ -61,7 +62,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputColNames = map(inputCols) val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 8760960e19..452faa06e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, Attribute, AttributeGroup} -import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions.callUDF @@ -40,11 +42,12 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ val maxCategories = new IntParam(this, "maxCategories", "Threshold for the number of values a categorical feature can take." + - " If a feature is found to have > maxCategories values, then it is declared continuous.", - Some(20)) + " If a feature is found to have > maxCategories values, then it is declared continuous.") /** @group getParam */ - def getMaxCategories: Int = get(maxCategories) + def getMaxCategories: Int = getOrDefault(maxCategories) + + setDefault(maxCategories -> 20) } /** @@ -101,7 +104,7 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val firstRow = dataset.select(map(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size @@ -120,12 +123,12 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val dataType = new VectorUDT require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") - checkInputColumn(schema, map(inputCol), dataType) - addOutputColumn(schema, map(outputCol), dataType) + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.appendColumn(schema, map(outputCol), dataType) } } @@ -320,7 +323,7 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val newField = prepOutputField(dataset.schema, map) val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) // For now, just check the first row of inputCol for vector length. @@ -334,13 +337,13 @@ class VectorIndexerModel private[ml] ( } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val dataType = new VectorUDT require(map.contains(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") require(map.contains(outputCol), s"VectorIndexerModel requires output column parameter: $outputCol") - checkInputColumn(schema, map(inputCol), dataType) + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index dfb89cc8d4..195333a5cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -53,14 +55,14 @@ private[spark] trait PredictorParams extends Params paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - checkInputColumn(schema, map(featuresCol), featuresDataType) + SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) if (fitting) { // TODO: Allow other numeric types - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) } - addOutputColumn(schema, map(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) } } @@ -98,7 +100,7 @@ private[spark] abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val model = train(dataset, map) Params.inheritValues(map, this, model) // copy params to model model @@ -141,7 +143,7 @@ private[spark] abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) @@ -201,7 +203,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 7d5178d0ab..849c60433c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable -import java.lang.reflect.Modifier - import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable -import org.apache.spark.sql.types.{DataType, StructField, StructType} - /** * :: AlphaComponent :: @@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * @tparam T param value type */ @AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) - extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { /** * Creates a param pair with the given value (for Java). @@ -55,58 +49,55 @@ class Param[T] ( */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** + * Converts this param's name, doc, and optionally its default value and the user-supplied + * value in its parent to string. + */ override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" + val valueStr = if (parent.isDefined(this)) { + val defaultValueStr = parent.getDefault(this).map("default: " + _) + val currentValueStr = parent.get(this).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") } else { - s"$name: $doc" + "(undefined)" } + s"$name: $doc $valueStr" } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) - extends Param[Double](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class DoubleParam(parent: Params, name: String, doc: String) + extends Param[Double](parent, name, doc) { override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) - extends Param[Int](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class IntParam(parent: Params, name: String, doc: String) + extends Param[Int](parent, name, doc) { override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) - extends Param[Float](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class FloatParam(parent: Params, name: String, doc: String) + extends Param[Float](parent, name, doc) { override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) - extends Param[Long](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class LongParam(parent: Params, name: String, doc: String) + extends Param[Long](parent, name, doc) { override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) - extends Param[Boolean](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class BooleanParam(parent: Params, name: String, doc: String) + extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -124,8 +115,11 @@ case class ParamPair[T](param: Param[T], value: T) @AlphaComponent trait Params extends Identifiable with Serializable { - /** Returns all params. */ - def params: Array[Param[_]] = { + /** + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. + */ + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -153,25 +147,29 @@ trait Params extends Identifiable with Serializable { def explainParams(): String = params.mkString("\n") /** Checks whether a param is explicitly set. */ - def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) + final def isSet(param: Param[_]): Boolean = { + shouldOwn(param) paramMap.contains(param) } + /** Checks whether a param is explicitly set or has a default value. */ + final def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) + } + /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + def getParam(paramName: String): Param[Any] = { + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ - protected def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) + protected final def set[T](param: Param[T], value: T): this.type = { + shouldOwn(param) paramMap.put(param.asInstanceOf[Param[Any]], value) this } @@ -179,52 +177,102 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter (by name) in the embedded param map. */ - private[ml] def set(param: String, value: Any): this.type = { + protected final def set(param: String, value: Any): this.type = { set(getParam(param), value) } /** - * Gets the value of a parameter in the embedded param map. + * Optionally returns the user-supplied value of a param. + */ + final def get[T](param: Param[T]): Option[T] = { + shouldOwn(param) + paramMap.get(param) + } + + /** + * Clears the user-supplied value for the input param. + */ + protected final def clear(param: Param[_]): this.type = { + shouldOwn(param) + paramMap.remove(param) + this + } + + /** + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. + */ + final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get + } + + /** + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. + * @param value the default value */ - protected def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - paramMap(param) + protected final def setDefault[T](param: Param[T], value: T): this.type = { + shouldOwn(param) + defaultParamMap.put(param, value) + this } /** - * Internal param map. + * Sets default values for a list of params. + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. */ - protected val paramMap: ParamMap = ParamMap.empty + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } /** - * Check whether the given schema contains an input column. - * @param colName Input column name - * @param dataType Input column DataType + * Gets the default value of a parameter. */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + final def getDefault[T](param: Param[T]): Option[T] = { + shouldOwn(param) + defaultParamMap.get(param) } /** - * Add an output column to the given schema. - * This fails if the given output column already exists. - * @param schema Initial schema (not modified) - * @param colName Output column name. If this column name is an empy String "", this method - * returns the initial schema, unchanged. This allows users to disable output - * columns. - * @param dataType Output column DataType - */ - protected def addOutputColumn( - schema: StructType, - colName: String, - dataType: DataType): StructType = { - if (colName.length == 0) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Output column $colName already exists.") - val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) - StructType(outputFields) + * Tests whether the input param has a default value set. + */ + final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) + } + + /** + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + */ + protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extraParamMap + } + + /** + * [[extractParamMap]] with no extra values. + */ + protected final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty + + /** Validates that the input param belongs to this instance. */ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent.eq(this), s"Param $param does not belong to $this.") } } @@ -261,12 +309,13 @@ private[spark] object Params { * A param to value map. */ @AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). @@ -288,12 +337,17 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) } /** @@ -301,10 +355,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Raises a NoSuchElementException if there is no value associated with the input param. */ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -316,6 +367,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns its value associated previously as an option. + */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ @@ -325,7 +383,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) @@ -337,7 +395,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. @@ -363,7 +421,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Number of param pairs in this set. + * Number of param pairs in this map. */ def size: Int = map.size } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala new file mode 100644 index 0000000000..95d7e64790 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" + * }}} + */ +private[shared] object SharedParamsCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter"), + ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", + "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[Double]("threshold", "threshold in binary classification prediction"), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("inputCols", "input column names"), + ParamDesc[String]("outputCol", "output column name"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval"), + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. */ + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None) { + + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + |""".stripMargin + }.getOrElse("") + + s""" + |/** + | * :: DeveloperApi :: + | * Trait for shared param $name$defaultValueDoc. + | */ + |@DeveloperApi + |trait Has$Name extends Params { + | + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc") + |$setDefault + | /** @group getParam */ + | final def get$Name: $T = getOrDefault($name) + |} + |""".stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.annotation.DeveloperApi + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + | + |// scalastyle:off + |""".stripMargin + + val footer = "// scalastyle:on\n" + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 0000000000..72b08bf276 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + +// scalastyle:off + +/** + * :: DeveloperApi :: + * Trait for shared param regParam. + */ +@DeveloperApi +trait HasRegParam extends Params { + + /** + * Param for regularization parameter. + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ + final def getRegParam: Double = getOrDefault(regParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param maxIter. + */ +@DeveloperApi +trait HasMaxIter extends Params { + + /** + * Param for max number of iterations. + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ + final def getMaxIter: Int = getOrDefault(maxIter) +} + +/** + * :: DeveloperApi :: + * Trait for shared param featuresCol (default: "features"). + */ +@DeveloperApi +trait HasFeaturesCol extends Params { + + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + setDefault(featuresCol, "features") + + /** @group getParam */ + final def getFeaturesCol: String = getOrDefault(featuresCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param labelCol (default: "label"). + */ +@DeveloperApi +trait HasLabelCol extends Params { + + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + setDefault(labelCol, "label") + + /** @group getParam */ + final def getLabelCol: String = getOrDefault(labelCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param predictionCol (default: "prediction"). + */ +@DeveloperApi +trait HasPredictionCol extends Params { + + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + setDefault(predictionCol, "prediction") + + /** @group getParam */ + final def getPredictionCol: String = getOrDefault(predictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param rawPredictionCol (default: "rawPrediction"). + */ +@DeveloperApi +trait HasRawPredictionCol extends Params { + + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + setDefault(rawPredictionCol, "rawPrediction") + + /** @group getParam */ + final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param probabilityCol (default: "probability"). + */ +@DeveloperApi +trait HasProbabilityCol extends Params { + + /** + * Param for column name for predicted class conditional probabilities. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + + setDefault(probabilityCol, "probability") + + /** @group getParam */ + final def getProbabilityCol: String = getOrDefault(probabilityCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param threshold. + */ +@DeveloperApi +trait HasThreshold extends Params { + + /** + * Param for threshold in binary classification prediction. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + + /** @group getParam */ + final def getThreshold: Double = getOrDefault(threshold) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCol extends Params { + + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = getOrDefault(inputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCols. + */ +@DeveloperApi +trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = getOrDefault(inputCols) +} + +/** + * :: DeveloperApi :: + * Trait for shared param outputCol. + */ +@DeveloperApi +trait HasOutputCol extends Params { + + /** + * Param for output column name. + * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + /** @group getParam */ + final def getOutputCol: String = getOrDefault(outputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param checkpointInterval. + */ +@DeveloperApi +trait HasCheckpointInterval extends Params { + + /** + * Param for checkpoint interval. + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) +} + +/** + * :: DeveloperApi :: + * Trait for shared param fitIntercept (default: true). + */ +@DeveloperApi +trait HasFitIntercept extends Params { + + /** + * Param for whether to fit an intercept term. + * @group param + */ + final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term") + + setDefault(fitIntercept, true) + + /** @group getParam */ + final def getFitIntercept: Boolean = getOrDefault(fitIntercept) +} +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index 07e6eb4177..0000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param - -/* NOTE TO DEVELOPERS: - * If you mix these parameter traits into your algorithm, please add a setter method as well - * so that users may use a builder pattern: - * val myLearner = new MyLearner().setParam1(x).setParam2(y)... - */ - -private[ml] trait HasRegParam extends Params { - /** - * param for regularization parameter - * @group param - */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - - /** @group getParam */ - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** - * param for max number of iterations - * @group param - */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - - /** @group getParam */ - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** - * param for features column name - * @group param - */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - - /** @group getParam */ - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** - * param for label column name - * @group param - */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - - /** @group getParam */ - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** - * param for prediction column name - * @group param - */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - - /** @group getParam */ - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasRawPredictionCol extends Params { - /** - * param for raw prediction column name - * @group param - */ - val rawPredictionCol: Param[String] = - new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", - Some("rawPrediction")) - - /** @group getParam */ - def getRawPredictionCol: String = get(rawPredictionCol) -} - -private[ml] trait HasProbabilityCol extends Params { - /** - * param for predicted class conditional probabilities column name - * @group param - */ - val probabilityCol: Param[String] = - new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", - Some("probability")) - - /** @group getParam */ - def getProbabilityCol: String = get(probabilityCol) -} - -private[ml] trait HasFitIntercept extends Params { - /** - * param for fitting the intercept term, defaults to true - * @group param - */ - val fitIntercept: BooleanParam = - new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true)) - - /** @group getParam */ - def getFitIntercept: Boolean = get(fitIntercept) -} - -private[ml] trait HasThreshold extends Params { - /** - * param for threshold in (binary) prediction - * @group param - */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - - /** @group getParam */ - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** - * param for input column name - * @group param - */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - - /** @group getParam */ - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasInputCols extends Params { - /** - * Param for input column names. - */ - val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names") - - /** @group getParam */ - def getInputCols: Array[String] = get(inputCols) -} - -private[ml] trait HasOutputCol extends Params { - /** - * param for output column name - * @group param - */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - - /** @group getParam */ - def getOutputCol: String = get(outputCol) -} - -private[ml] trait HasCheckpointInterval extends Params { - /** - * param for checkpoint interval - * @group param - */ - val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") - - /** @group getParam */ - def getCheckpointInterval: Int = get(checkpointInterval) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 52c9e95d60..bd793beba3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -54,86 +55,88 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for rank of the matrix factorization. * @group param */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + val rank = new IntParam(this, "rank", "rank of the factorization") /** @group getParam */ - def getRank: Int = get(rank) + def getRank: Int = getOrDefault(rank) /** * Param for number of user blocks. * @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") /** @group getParam */ - def getNumUserBlocks: Int = get(numUserBlocks) + def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** * Param for number of item blocks. * @group param */ val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + new IntParam(this, "numItemBlocks", "number of item blocks") /** @group getParam */ - def getNumItemBlocks: Int = get(numItemBlocks) + def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. * @group param */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ - def getImplicitPrefs: Boolean = get(implicitPrefs) + def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** * Param for the alpha parameter in the implicit preference formulation. * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") /** @group getParam */ - def getAlpha: Double = get(alpha) + def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ - def getUserCol: String = get(userCol) + def getUserCol: String = getOrDefault(userCol) /** * Param for the column name for item ids. * @group param */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) + val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ - def getItemCol: String = get(itemCol) + def getItemCol: String = getOrDefault(itemCol) /** * Param for the column name for ratings. * @group param */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ - def getRatingCol: String = get(ratingCol) + def getRatingCol: String = getOrDefault(ratingCol) /** * Param for whether to apply nonnegativity constraints. * @group param */ val nonnegative = new BooleanParam( - this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ - val getNonnegative: Boolean = get(nonnegative) + def getNonnegative: Boolean = getOrDefault(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Validates and transforms the input schema. @@ -142,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -171,7 +174,7 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -283,7 +286,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val ratings = dataset .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) .map { row => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 65f6627a0c..26ca7459c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.sql.DataFrame @@ -41,8 +42,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setRegParam(0.1) - setMaxIter(100) + setDefault(regParam -> 0.1, maxIter -> 100) /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] ( override protected def copy(): LinearRegressionModel = { val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2eb1dac56f..4bb4ed813c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ private[ml] trait CrossValidatorParams extends Params { + /** * param for the estimator to be cross-validated * @group param @@ -38,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params { val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") /** @group getParam */ - def getEstimator: Estimator[_] = get(estimator) + def getEstimator: Estimator[_] = getOrDefault(estimator) /** * param for estimator param maps @@ -48,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params { new Param(this, "estimatorParamMaps", "param maps for the estimator") /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) /** * param for the evaluator for selection @@ -57,17 +58,18 @@ private[ml] trait CrossValidatorParams extends Params { val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") /** @group getParam */ - def getEvaluator: Evaluator = get(evaluator) + def getEvaluator: Evaluator = getOrDefault(evaluator) /** * param for number of folds for cross validation * @group param */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") /** @group getParam */ - def getNumFolds: Int = get(numFolds) + def getNumFolds: Int = getOrDefault(numFolds) + + setDefault(numFolds -> 3) } /** @@ -92,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -130,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) map(estimator).transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 0000000000..0383bf0b38 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * Utils for handling schemas. + */ +@DeveloperApi +object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param colName new column name. If this column name is an empty string "", this method returns + * the input schema unchanged. This allows users to disable output columns. + * @param dataType new column data type + * @return new schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.isEmpty) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce2987612..88ea679eea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,19 +21,25 @@ import org.scalatest.FunSuite class ParamsSuite extends FunSuite { - val solver = new TestParams() - import solver.{inputCol, maxIter} - test("param") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + + solver.setMaxIter(5) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + + assert(inputCol.toString === "inputCol: input column name (undefined)") } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite { } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) @@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { solver.validate() } solver.validate(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { @@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validate() } + + solver.clearMaxIter() + assert(!solver.isSet(maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index ce52f2f230..8f9ab687c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -20,17 +20,21 @@ package org.apache.spark.ml.param /** A subclass of Params for testing. */ class TestParams extends Params { - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) + def getInputCol: String = getOrDefault(inputCol) + + setDefault(maxIter -> 10) override def validate(paramMap: ParamMap): Unit = { - val m = this.paramMap ++ paramMap + val m = extractParamMap(paramMap) require(m(maxIter) >= 0) require(m.contains(inputCol)) } + + def clearMaxIter(): this.type = clear(maxIter) } -- cgit v1.2.3