author     Xiangrui Meng <meng@databricks.com>    2015-04-13 21:18:05 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-04-13 21:18:05 -0700
commit     971b95b0c9002bd541bcbe0da54a9967ba22588f (patch)
tree       b2a79cf00c1d2290e7e4024df27c0ee9b203c09a /mllib
parent     0ba3fdd5992cf09bd38303ebff34d2ed19e5e09b (diff)
[SPARK-5957][ML] better handling of parameters
The design doc was posted on the JIRA page. Python changes will be in a follow-up PR. jkbradley

1. Use codegen for shared params.
1. Move shared params to package `ml.param.shared`.
1. Set default values in `Params` instead of in `Param`.
1. Add a few methods to `Params` and `ParamMap`.
1. Move schema handling to `SchemaUtils` from `Params`.

- [x] check visibility of the methods added

Author: Xiangrui Meng <meng@databricks.com>

Closes #5431 from mengxr/SPARK-5957 and squashes the following commits:

d19236d [Xiangrui Meng] fix test
26ae2d7 [Xiangrui Meng] re-gen code and mark clear protected
38b78c7 [Xiangrui Meng] update Param.toString and remove Params.explain()
409e2d5 [Xiangrui Meng] address comments
2d637bd [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957
eec2264 [Xiangrui Meng] make get* public in Params
4090d95 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957
4fee9e7 [Xiangrui Meng] re-gen shared params
2737c2d [Xiangrui Meng] rename SharedParamCodeGen to SharedParamsCodeGen
e938f81 [Xiangrui Meng] update code to set default parameter values
28ed322 [Xiangrui Meng] merge master
55be1f3 [Xiangrui Meng] merge master
d63b5cc [Xiangrui Meng] fix examples
29b004c [Xiangrui Meng] update ParamsSuite
94fd98e [Xiangrui Meng] fix explain params
48d0e84 [Xiangrui Meng] add remove and update explainParams
4ac6348 [Xiangrui Meng] move schema utils to SchemaUtils add a few methods to Params
0d9594e [Xiangrui Meng] add getOrElse to ParamMap
eeeffe8 [Xiangrui Meng] map ++ paramMap => extractValues
0d3fc5b [Xiangrui Meng] setDefault after param
a9dbf59 [Xiangrui Meng] minor updates
d9302b8 [Xiangrui Meng] generate default values
1c72579 [Xiangrui Meng] pass test compile
abb7a3b [Xiangrui Meng] update default values handling
dcab97a [Xiangrui Meng] add codegen for shared params
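For illustration only (not part of this commit), the snippet below sketches the parameter-handling pattern the patch moves to: defaults are declared with `setDefault` on the `Params` instance instead of inside each `Param`, and getters read them through `getOrDefault`. The trait `HasNumBuckets` and its param are hypothetical; the sketch assumes the `Params`, `IntParam`, `setDefault`, and `getOrDefault` definitions introduced in `params.scala` below and a package under `org.apache.spark.ml` so that `Params` can be extended.

{{{
package org.apache.spark.ml.param.example

import org.apache.spark.ml.param.{IntParam, Params}

// Hypothetical shared-param trait written in the style of the generated
// ml.param.shared traits: the default lives in the Params instance, not in the Param.
trait HasNumBuckets extends Params {

  /** Param for the number of buckets (illustrative only). */
  final val numBuckets: IntParam = new IntParam(this, "numBuckets", "number of buckets")

  setDefault(numBuckets -> 10)

  /** Returns the user-supplied value if set, otherwise the default (10). */
  final def getNumBuckets: Int = getOrDefault(numBuckets)
}
}}}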
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Estimator.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 236
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 169
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala | 259
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala | 173
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 49
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala | 61
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 47
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 12
25 files changed, 815 insertions(+), 391 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index eff7ef925d..d6b3503ebd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
*/
@varargs
def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
- val map = new ParamMap().put(paramPairs: _*)
+ val map = ParamMap(paramPairs: _*)
fit(dataset, map)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a455341a1f..8eddf79cdf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -84,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] {
/** param for pipeline stages */
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
- def getStages: Array[PipelineStage] = get(stages)
+ def getStages: Array[PipelineStage] = getOrDefault(stages)
/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
@@ -101,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] {
*/
override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val theStages = map(stages)
// Search for the last estimator.
var indexOfLastEstimator = -1
@@ -138,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] {
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
"Cannot have duplicate components in a pipeline.")
@@ -177,14 +177,14 @@ class PipelineModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ val map = fittingParamMap ++ extractParamMap(paramMap)
transformSchema(dataset.schema, map, logging = true)
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ val map = fittingParamMap ++ extractParamMap(paramMap)
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 9a5848684b..7fb87fe452 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,6 +22,7 @@ import scala.annotation.varargs
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -86,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
protected def validateInputType(inputType: DataType): Unit = {}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val inputType = schema(map(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains(map(outputCol))) {
@@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
dataset.withColumn(map(outputCol),
callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index c5fc89f935..29339c98f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,12 +17,14 @@
package org.apache.spark.ml.classification
-import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.param.shared.HasRawPredictionCol
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams
fitting: Boolean,
featuresDataType: DataType): StructType = {
val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
- val map = this.paramMap ++ paramMap
- addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+ val map = extractParamMap(paramMap)
+ SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
}
}
@@ -67,8 +69,7 @@ private[spark] abstract class Classifier[
with ClassifierParams {
/** @group setParam */
- def setRawPredictionCol(value: String): E =
- set(rawPredictionCol, value).asInstanceOf[E]
+ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
// TODO: defaultEvaluator (follow-up PR)
}
@@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
// Prepare model
val tmpModel = if (paramMap.size != 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 34625745dd..cc8b0721cf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
@@ -31,8 +31,10 @@ import org.apache.spark.storage.StorageLevel
* Params for logistic regression.
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
- with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
+ with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold {
+ setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
+}
/**
* :: AlphaComponent ::
@@ -45,10 +47,6 @@ class LogisticRegression
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams {
- setRegParam(0.1)
- setMaxIter(100)
- setThreshold(0.5)
-
/** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
@@ -100,8 +98,6 @@ class LogisticRegressionModel private[ml] (
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams {
- setThreshold(0.5)
-
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
@@ -123,7 +119,7 @@ class LogisticRegressionModel private[ml] (
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
@@ -184,7 +180,7 @@ class LogisticRegressionModel private[ml] (
* The behavior of this can be adjusted using [[threshold]].
*/
override protected def predict(features: Vector): Double = {
- if (score(features) > paramMap(threshold)) 1 else 0
+ if (score(features) > getThreshold) 1 else 0
}
override protected def predictProbabilities(features: Vector): Vector = {
@@ -199,7 +195,7 @@ class LogisticRegressionModel private[ml] (
override protected def copy(): LogisticRegressionModel = {
val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
- Params.inheritValues(this.paramMap, this, m)
+ Params.inheritValues(this.extractParamMap(), this, m)
m
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index bd8caac855..10404548cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,13 +18,14 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
-
/**
* Params for probabilistic classification.
*/
@@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams
fitting: Boolean,
featuresDataType: DataType): StructType = {
val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
- val map = this.paramMap ++ paramMap
- addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
+ val map = extractParamMap(paramMap)
+ SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT)
}
}
@@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
// Prepare model
val tmpModel = if (paramMap.size != 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 2360f4479f..c865eb9fe0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
-
/**
* :: AlphaComponent ::
*
@@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params
* @group param
*/
val metricName: Param[String] = new Param(this, "metricName",
- "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+ "metric name in evaluation (areaUnderROC|areaUnderPR)")
/** @group getParam */
- def getMetricName: String = get(metricName)
+ def getMetricName: String = getOrDefault(metricName)
/** @group setParam */
def setMetricName(value: String): this.type = set(metricName, value)
@@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
+ setDefault(metricName -> "areaUnderROC")
+
override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val schema = dataset.schema
- checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
- checkInputColumn(schema, map(labelCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT)
+ SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)
// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index fc4e12773c..b20f2fc49a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
* number of features
* @group param
*/
- val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
+ val numFeatures = new IntParam(this, "numFeatures", "number of features")
/** @group getParam */
- def getNumFeatures: Int = get(numFeatures)
+ def getNumFeatures: Int = getOrDefault(numFeatures)
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
+ setDefault(numFeatures -> (1 << 18))
+
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 05f91dc910..decaeb0da6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -35,14 +35,16 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
* Normalization in L^p^ space, p = 2 by default.
* @group param
*/
- val p = new DoubleParam(this, "p", "the p norm value", Some(2))
+ val p = new DoubleParam(this, "p", "the p norm value")
/** @group getParam */
- def getP: Double = get(p)
+ def getP: Double = getOrDefault(p)
/** @group setParam */
def setP(value: Double): this.type = set(p, value)
+ setDefault(p -> 2.0)
+
override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
val normalizer = new feature.Normalizer(paramMap(p))
normalizer.transform
@@ -50,4 +52,3 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
override protected def outputDataType: DataType = new VectorUDT()
}
-
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 1142aa4f8e..1b102619b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -47,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(this, map, scaler)
@@ -56,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${map(inputCol)} must be a vector column")
@@ -86,13 +87,13 @@ class StandardScalerModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${map(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 61e6742e88..4d960df357 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -22,6 +22,8 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
@@ -34,8 +36,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
- checkInputColumn(schema, map(inputCol), StringType)
+ val map = extractParamMap(paramMap)
+ SchemaUtils.checkColumnType(schema, map(inputCol), StringType)
val inputFields = schema.fields
val outputColName = map(outputCol)
require(inputFields.forall(_.name != outputColName),
@@ -64,7 +66,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
// TODO: handle unseen labels
override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val model = new StringIndexerModel(this, map, labels)
@@ -105,7 +107,7 @@ class StringIndexerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 68401e3695..376a004858 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -56,39 +56,39 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param for minimum token length, default is one to avoid returning empty strings
* @group param
*/
- val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1))
+ val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")
/** @group setParam */
def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
/** @group getParam */
- def getMinTokenLength: Int = get(minTokenLength)
+ def getMinTokenLength: Int = getOrDefault(minTokenLength)
/**
* param sets regex as splitting on gaps (true) or matching tokens (false)
* @group param
*/
- val gaps: BooleanParam = new BooleanParam(
- this, "gaps", "Set regex to match gaps or tokens", Some(false))
+ val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
/** @group setParam */
def setGaps(value: Boolean): this.type = set(gaps, value)
/** @group getParam */
- def getGaps: Boolean = get(gaps)
+ def getGaps: Boolean = getOrDefault(gaps)
/**
* param sets regex pattern used by tokenizer
* @group param
*/
- val pattern: Param[String] = new Param(
- this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))
+ val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")
/** @group setParam */
def setPattern(value: String): this.type = set(pattern, value)
/** @group getParam */
- def getPattern: String = get(pattern)
+ def getPattern: String = getOrDefault(pattern)
+
+ setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+")
override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>
val re = paramMap(pattern).r
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index d1b8f7e6e9..e567e069e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -22,7 +22,8 @@ import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -44,7 +45,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(r.toSeq: _*)
}
@@ -61,7 +62,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val inputColNames = map(inputCols)
val outputColName = map(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 8760960e19..452faa06e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -18,10 +18,12 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
Attribute, AttributeGroup}
-import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.functions.callUDF
@@ -40,11 +42,12 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
*/
val maxCategories = new IntParam(this, "maxCategories",
"Threshold for the number of values a categorical feature can take." +
- " If a feature is found to have > maxCategories values, then it is declared continuous.",
- Some(20))
+ " If a feature is found to have > maxCategories values, then it is declared continuous.")
/** @group getParam */
- def getMaxCategories: Int = get(maxCategories)
+ def getMaxCategories: Int = getOrDefault(maxCategories)
+
+ setDefault(maxCategories -> 20)
}
/**
@@ -101,7 +104,7 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara
override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val firstRow = dataset.select(map(inputCol)).take(1)
require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
val numFeatures = firstRow(0).getAs[Vector](0).size
@@ -120,12 +123,12 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val dataType = new VectorUDT
require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol")
require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol")
- checkInputColumn(schema, map(inputCol), dataType)
- addOutputColumn(schema, map(outputCol), dataType)
+ SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
+ SchemaUtils.appendColumn(schema, map(outputCol), dataType)
}
}
@@ -320,7 +323,7 @@ class VectorIndexerModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val newField = prepOutputField(dataset.schema, map)
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
// For now, just check the first row of inputCol for vector length.
@@ -334,13 +337,13 @@ class VectorIndexerModel private[ml] (
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val dataType = new VectorUDT
require(map.contains(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
require(map.contains(outputCol),
s"VectorIndexerModel requires output column parameter: $outputCol")
- checkInputColumn(schema, map(inputCol), dataType)
+ SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index dfb89cc8d4..195333a5cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -18,8 +18,10 @@
package org.apache.spark.ml.impl.estimator
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -53,14 +55,14 @@ private[spark] trait PredictorParams extends Params
paramMap: ParamMap,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
- checkInputColumn(schema, map(featuresCol), featuresDataType)
+ SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType)
if (fitting) {
// TODO: Allow other numeric types
- checkInputColumn(schema, map(labelCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)
}
- addOutputColumn(schema, map(predictionCol), DoubleType)
+ SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType)
}
}
@@ -98,7 +100,7 @@ private[spark] abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val model = train(dataset, map)
Params.inheritValues(map, this, model) // copy params to model
model
@@ -141,7 +143,7 @@ private[spark] abstract class Predictor[
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
dataset.select(map(labelCol), map(featuresCol))
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
@@ -201,7 +203,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
// Prepare model
val tmpModel = if (paramMap.size != 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 7d5178d0ab..849c60433c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.param
+import java.lang.reflect.Modifier
+import java.util.NoSuchElementException
+
import scala.annotation.varargs
import scala.collection.mutable
-import java.lang.reflect.Modifier
-
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.Identifiable
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-
/**
* :: AlphaComponent ::
@@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
* @tparam T param value type
*/
@AlphaComponent
-class Param[T] (
- val parent: Params,
- val name: String,
- val doc: String,
- val defaultValue: Option[T] = None)
- extends Serializable {
+class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable {
/**
* Creates a param pair with the given value (for Java).
@@ -55,58 +49,55 @@ class Param[T] (
*/
def ->(value: T): ParamPair[T] = ParamPair(this, value)
+ /**
+ * Converts this param's name, doc, and optionally its default value and the user-supplied
+ * value in its parent to string.
+ */
override def toString: String = {
- if (defaultValue.isDefined) {
- s"$name: $doc (default: ${defaultValue.get})"
+ val valueStr = if (parent.isDefined(this)) {
+ val defaultValueStr = parent.getDefault(this).map("default: " + _)
+ val currentValueStr = parent.get(this).map("current: " + _)
+ (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
} else {
- s"$name: $doc"
+ "(undefined)"
}
+ s"$name: $doc $valueStr"
}
}
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
/** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double])
- extends Param[Double](parent, name, doc, defaultValue) {
-
- def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+class DoubleParam(parent: Params, name: String, doc: String)
+ extends Param[Double](parent, name, doc) {
override def w(value: Double): ParamPair[Double] = super.w(value)
}
/** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int])
- extends Param[Int](parent, name, doc, defaultValue) {
-
- def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+class IntParam(parent: Params, name: String, doc: String)
+ extends Param[Int](parent, name, doc) {
override def w(value: Int): ParamPair[Int] = super.w(value)
}
/** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float])
- extends Param[Float](parent, name, doc, defaultValue) {
-
- def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+class FloatParam(parent: Params, name: String, doc: String)
+ extends Param[Float](parent, name, doc) {
override def w(value: Float): ParamPair[Float] = super.w(value)
}
/** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long])
- extends Param[Long](parent, name, doc, defaultValue) {
-
- def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+class LongParam(parent: Params, name: String, doc: String)
+ extends Param[Long](parent, name, doc) {
override def w(value: Long): ParamPair[Long] = super.w(value)
}
/** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean])
- extends Param[Boolean](parent, name, doc, defaultValue) {
-
- def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+class BooleanParam(parent: Params, name: String, doc: String)
+ extends Param[Boolean](parent, name, doc) {
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
@@ -124,8 +115,11 @@ case class ParamPair[T](param: Param[T], value: T)
@AlphaComponent
trait Params extends Identifiable with Serializable {
- /** Returns all params. */
- def params: Array[Param[_]] = {
+ /**
+ * Returns all params sorted by their names. The default implementation uses Java reflection to
+ * list all public methods that have no arguments and return [[Param]].
+ */
+ lazy val params: Array[Param[_]] = {
val methods = this.getClass.getMethods
methods.filter { m =>
Modifier.isPublic(m.getModifiers) &&
@@ -153,25 +147,29 @@ trait Params extends Identifiable with Serializable {
def explainParams(): String = params.mkString("\n")
/** Checks whether a param is explicitly set. */
- def isSet(param: Param[_]): Boolean = {
- require(param.parent.eq(this))
+ final def isSet(param: Param[_]): Boolean = {
+ shouldOwn(param)
paramMap.contains(param)
}
+ /** Checks whether a param is explicitly set or has a default value. */
+ final def isDefined(param: Param[_]): Boolean = {
+ shouldOwn(param)
+ defaultParamMap.contains(param) || paramMap.contains(param)
+ }
+
/** Gets a param by its name. */
- private[ml] def getParam(paramName: String): Param[Any] = {
- val m = this.getClass.getMethod(paramName)
- assert(Modifier.isPublic(m.getModifiers) &&
- classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
- m.getParameterTypes.isEmpty)
- m.invoke(this).asInstanceOf[Param[Any]]
+ def getParam(paramName: String): Param[Any] = {
+ params.find(_.name == paramName).getOrElse {
+ throw new NoSuchElementException(s"Param $paramName does not exist.")
+ }.asInstanceOf[Param[Any]]
}
/**
* Sets a parameter in the embedded param map.
*/
- protected def set[T](param: Param[T], value: T): this.type = {
- require(param.parent.eq(this))
+ protected final def set[T](param: Param[T], value: T): this.type = {
+ shouldOwn(param)
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
}
@@ -179,52 +177,102 @@ trait Params extends Identifiable with Serializable {
/**
* Sets a parameter (by name) in the embedded param map.
*/
- private[ml] def set(param: String, value: Any): this.type = {
+ protected final def set(param: String, value: Any): this.type = {
set(getParam(param), value)
}
/**
- * Gets the value of a parameter in the embedded param map.
+ * Optionally returns the user-supplied value of a param.
+ */
+ final def get[T](param: Param[T]): Option[T] = {
+ shouldOwn(param)
+ paramMap.get(param)
+ }
+
+ /**
+ * Clears the user-supplied value for the input param.
+ */
+ protected final def clear(param: Param[_]): this.type = {
+ shouldOwn(param)
+ paramMap.remove(param)
+ this
+ }
+
+ /**
+ * Gets the value of a param in the embedded param map or its default value. Throws an exception
+ * if neither is set.
+ */
+ final def getOrDefault[T](param: Param[T]): T = {
+ shouldOwn(param)
+ get(param).orElse(getDefault(param)).get
+ }
+
+ /**
+ * Sets a default value for a param.
+ * @param param param to set the default value. Make sure that this param is initialized before
+ * this method gets called.
+ * @param value the default value
*/
- protected def get[T](param: Param[T]): T = {
- require(param.parent.eq(this))
- paramMap(param)
+ protected final def setDefault[T](param: Param[T], value: T): this.type = {
+ shouldOwn(param)
+ defaultParamMap.put(param, value)
+ this
}
/**
- * Internal param map.
+ * Sets default values for a list of params.
+ * @param paramPairs a list of param pairs that specify params and their default values to set
+ * respectively. Make sure that the params are initialized before this method
+ * gets called.
*/
- protected val paramMap: ParamMap = ParamMap.empty
+ protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ setDefault(p.param.asInstanceOf[Param[Any]], p.value)
+ }
+ this
+ }
/**
- * Check whether the given schema contains an input column.
- * @param colName Input column name
- * @param dataType Input column DataType
+ * Gets the default value of a parameter.
*/
- protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
- val actualDataType = schema(colName).dataType
- require(actualDataType.equals(dataType), s"Input column $colName must be of type $dataType" +
- s" but was actually $actualDataType. Column param description: ${getParam(colName)}")
+ final def getDefault[T](param: Param[T]): Option[T] = {
+ shouldOwn(param)
+ defaultParamMap.get(param)
}
/**
- * Add an output column to the given schema.
- * This fails if the given output column already exists.
- * @param schema Initial schema (not modified)
- * @param colName Output column name. If this column name is an empy String "", this method
- * returns the initial schema, unchanged. This allows users to disable output
- * columns.
- * @param dataType Output column DataType
- */
- protected def addOutputColumn(
- schema: StructType,
- colName: String,
- dataType: DataType): StructType = {
- if (colName.length == 0) return schema
- val fieldNames = schema.fieldNames
- require(!fieldNames.contains(colName), s"Output column $colName already exists.")
- val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false))
- StructType(outputFields)
+ * Tests whether the input param has a default value set.
+ */
+ final def hasDefault[T](param: Param[T]): Boolean = {
+ shouldOwn(param)
+ defaultParamMap.contains(param)
+ }
+
+ /**
+ * Extracts the embedded default param values and user-supplied values, and then merges them with
+ * extra values from input into a flat param map, where the latter value is used if there exist
+ * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap.
+ */
+ protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
+ defaultParamMap ++ paramMap ++ extraParamMap
+ }
+
+ /**
+ * [[extractParamMap]] with no extra values.
+ */
+ protected final def extractParamMap(): ParamMap = {
+ extractParamMap(ParamMap.empty)
+ }
+
+ /** Internal param map for user-supplied values. */
+ private val paramMap: ParamMap = ParamMap.empty
+
+ /** Internal param map for default values. */
+ private val defaultParamMap: ParamMap = ParamMap.empty
+
+ /** Validates that the input param belongs to this instance. */
+ private def shouldOwn(param: Param[_]): Unit = {
+ require(param.parent.eq(this), s"Param $param does not belong to $this.")
}
}
@@ -261,12 +309,13 @@ private[spark] object Params {
* A param to value map.
*/
@AlphaComponent
-class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable {
+final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
+ extends Serializable {
/**
* Creates an empty param map.
*/
- def this() = this(mutable.Map.empty[Param[Any], Any])
+ def this() = this(mutable.Map.empty)
/**
* Puts a (param, value) pair (overwrites if the input param exists).
@@ -288,12 +337,17 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
}
/**
- * Optionally returns the value associated with a param or its default.
+ * Optionally returns the value associated with a param.
*/
def get[T](param: Param[T]): Option[T] = {
- map.get(param.asInstanceOf[Param[Any]])
- .orElse(param.defaultValue)
- .asInstanceOf[Option[T]]
+ map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]]
+ }
+
+ /**
+ * Returns the value associated with a param or a default value.
+ */
+ def getOrElse[T](param: Param[T], default: T): T = {
+ get(param).getOrElse(default)
}
/**
@@ -301,10 +355,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* Raises a NoSuchElementException if there is no value associated with the input param.
*/
def apply[T](param: Param[T]): T = {
- val value = get(param)
- if (value.isDefined) {
- value.get
- } else {
+ get(param).getOrElse {
throw new NoSuchElementException(s"Cannot find param ${param.name}.")
}
}
@@ -317,6 +368,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
}
/**
+ * Removes a key from this map and returns its value associated previously as an option.
+ */
+ def remove[T](param: Param[T]): Option[T] = {
+ map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]]
+ }
+
+ /**
* Filters this param map for the given parent.
*/
def filter(parent: Params): ParamMap = {
@@ -325,7 +383,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
}
/**
- * Make a copy of this param map.
+ * Creates a copy of this param map.
*/
def copy: ParamMap = new ParamMap(map.clone())
@@ -337,7 +395,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
/**
* Returns a new param map that contains parameters in this map and the given map,
- * where the latter overwrites this if there exists conflicts.
+ * where the latter overwrites this if there exist conflicts.
*/
def ++(other: ParamMap): ParamMap = {
// TODO: Provide a better method name for Java users.
@@ -363,7 +421,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
}
/**
- * Number of param pairs in this set.
+ * Number of param pairs in this map.
*/
def size: Int = map.size
}
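A hedged usage sketch (again not part of the commit) of the revised `ParamMap` semantics shown in the hunks above: `get` now returns only explicitly stored values with no fallback to a per-`Param` default, `getOrElse` takes a caller-supplied fallback, `remove` deletes a pair and returns the old value, and `++` merges two maps with the right-hand side winning on conflicts. The `Holder` object and its `n` param are hypothetical scaffolding, needed only because every `Param` requires a parent `Params`.

{{{
package org.apache.spark.ml.param.example

import org.apache.spark.ml.param.{IntParam, ParamMap, Params}

// Hypothetical parent for the illustrative param below.
object Holder extends Params {
  val n: IntParam = new IntParam(this, "n", "an example count")
}

object ParamMapSketch {
  def main(args: Array[String]): Unit = {
    import Holder.n

    val defaults  = ParamMap(n -> 1)
    val overrides = ParamMap(n -> 5)

    val merged = defaults ++ overrides           // right-hand side wins on conflicts: n -> 5
    assert(merged(n) == 5)                       // apply throws NoSuchElementException if absent
    assert(ParamMap.empty.get(n).isEmpty)        // get no longer falls back to a default
    assert(ParamMap.empty.getOrElse(n, 7) == 7)  // explicit caller-side fallback
    assert(merged.remove(n) == Some(5))          // remove returns the previous value, if any
  }
}
}}}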
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
new file mode 100644
index 0000000000..95d7e64790
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import java.io.PrintWriter
+
+import scala.reflect.ClassTag
+
+/**
+ * Code generator for shared params (sharedParams.scala). Run under the Spark folder with
+ * {{{
+ * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen"
+ * }}}
+ */
+private[shared] object SharedParamsCodeGen {
+
+ def main(args: Array[String]): Unit = {
+ val params = Seq(
+ ParamDesc[Double]("regParam", "regularization parameter"),
+ ParamDesc[Int]("maxIter", "max number of iterations"),
+ ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
+ ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
+ ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
+ ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
+ Some("\"rawPrediction\"")),
+ ParamDesc[String]("probabilityCol",
+ "column name for predicted class conditional probabilities", Some("\"probability\"")),
+ ParamDesc[Double]("threshold", "threshold in binary classification prediction"),
+ ParamDesc[String]("inputCol", "input column name"),
+ ParamDesc[Array[String]]("inputCols", "input column names"),
+ ParamDesc[String]("outputCol", "output column name"),
+ ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
+ ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")))
+
+ val code = genSharedParams(params)
+ val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
+ val writer = new PrintWriter(file)
+ writer.write(code)
+ writer.close()
+ }
+
+ /** Description of a param. */
+ private case class ParamDesc[T: ClassTag](
+ name: String,
+ doc: String,
+ defaultValueStr: Option[String] = None) {
+
+ require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
+ require(doc.nonEmpty) // TODO: more rigorous on doc
+
+ def paramTypeName: String = {
+ val c = implicitly[ClassTag[T]].runtimeClass
+ c match {
+ case _ if c == classOf[Int] => "IntParam"
+ case _ if c == classOf[Long] => "LongParam"
+ case _ if c == classOf[Float] => "FloatParam"
+ case _ if c == classOf[Double] => "DoubleParam"
+ case _ if c == classOf[Boolean] => "BooleanParam"
+ case _ => s"Param[${getTypeString(c)}]"
+ }
+ }
+
+ def valueTypeName: String = {
+ val c = implicitly[ClassTag[T]].runtimeClass
+ getTypeString(c)
+ }
+
+ private def getTypeString(c: Class[_]): String = {
+ c match {
+ case _ if c == classOf[Int] => "Int"
+ case _ if c == classOf[Long] => "Long"
+ case _ if c == classOf[Float] => "Float"
+ case _ if c == classOf[Double] => "Double"
+ case _ if c == classOf[Boolean] => "Boolean"
+ case _ if c == classOf[String] => "String"
+ case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]"
+ }
+ }
+ }
+
+ /** Generates the HasParam trait code for the input param. */
+ private def genHasParamTrait(param: ParamDesc[_]): String = {
+ val name = param.name
+ val Name = name(0).toUpper +: name.substring(1)
+ val Param = param.paramTypeName
+ val T = param.valueTypeName
+ val doc = param.doc
+ val defaultValue = param.defaultValueStr
+ val defaultValueDoc = defaultValue.map { v =>
+ s" (default: $v)"
+ }.getOrElse("")
+ val setDefault = defaultValue.map { v =>
+ s"""
+ | setDefault($name, $v)
+ |""".stripMargin
+ }.getOrElse("")
+
+ s"""
+ |/**
+ | * :: DeveloperApi ::
+ | * Trait for shared param $name$defaultValueDoc.
+ | */
+ |@DeveloperApi
+ |trait Has$Name extends Params {
+ |
+ | /**
+ | * Param for $doc.
+ | * @group param
+ | */
+ | final val $name: $Param = new $Param(this, "$name", "$doc")
+ |$setDefault
+ | /** @group getParam */
+ | final def get$Name: $T = getOrDefault($name)
+ |}
+ |""".stripMargin
+ }
+
+ /** Generates Scala source code for the input params with header. */
+ private def genSharedParams(params: Seq[ParamDesc[_]]): String = {
+ val header =
+ """/*
+ | * Licensed to the Apache Software Foundation (ASF) under one or more
+ | * contributor license agreements. See the NOTICE file distributed with
+ | * this work for additional information regarding copyright ownership.
+ | * The ASF licenses this file to You under the Apache License, Version 2.0
+ | * (the "License"); you may not use this file except in compliance with
+ | * the License. You may obtain a copy of the License at
+ | *
+ | * http://www.apache.org/licenses/LICENSE-2.0
+ | *
+ | * Unless required by applicable law or agreed to in writing, software
+ | * distributed under the License is distributed on an "AS IS" BASIS,
+ | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ | * See the License for the specific language governing permissions and
+ | * limitations under the License.
+ | */
+ |
+ |package org.apache.spark.ml.param.shared
+ |
+ |import org.apache.spark.annotation.DeveloperApi
+ |import org.apache.spark.ml.param._
+ |
+ |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
+ |
+ |// scalastyle:off
+ |""".stripMargin
+
+ val footer = "// scalastyle:on\n"
+
+ val traits = params.map(genHasParamTrait).mkString
+
+ header + traits + footer
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
new file mode 100644
index 0000000000..72b08bf276
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.param._
+
+// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
+
+// scalastyle:off
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param regParam.
+ */
+@DeveloperApi
+trait HasRegParam extends Params {
+
+ /**
+ * Param for regularization parameter.
+ * @group param
+ */
+ final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+
+ /** @group getParam */
+ final def getRegParam: Double = getOrDefault(regParam)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param maxIter.
+ */
+@DeveloperApi
+trait HasMaxIter extends Params {
+
+ /**
+ * Param for max number of iterations.
+ * @group param
+ */
+ final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+
+ /** @group getParam */
+ final def getMaxIter: Int = getOrDefault(maxIter)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param featuresCol (default: "features").
+ */
+@DeveloperApi
+trait HasFeaturesCol extends Params {
+
+ /**
+ * Param for features column name.
+ * @group param
+ */
+ final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name")
+
+ setDefault(featuresCol, "features")
+
+ /** @group getParam */
+ final def getFeaturesCol: String = getOrDefault(featuresCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param labelCol (default: "label").
+ */
+@DeveloperApi
+trait HasLabelCol extends Params {
+
+ /**
+ * Param for label column name.
+ * @group param
+ */
+ final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name")
+
+ setDefault(labelCol, "label")
+
+ /** @group getParam */
+ final def getLabelCol: String = getOrDefault(labelCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param predictionCol (default: "prediction").
+ */
+@DeveloperApi
+trait HasPredictionCol extends Params {
+
+ /**
+ * Param for prediction column name.
+ * @group param
+ */
+ final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name")
+
+ setDefault(predictionCol, "prediction")
+
+ /** @group getParam */
+ final def getPredictionCol: String = getOrDefault(predictionCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param rawPredictionCol (default: "rawPrediction").
+ */
+@DeveloperApi
+trait HasRawPredictionCol extends Params {
+
+ /**
+ * Param for raw prediction (a.k.a. confidence) column name.
+ * @group param
+ */
+ final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name")
+
+ setDefault(rawPredictionCol, "rawPrediction")
+
+ /** @group getParam */
+ final def getRawPredictionCol: String = getOrDefault(rawPredictionCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param probabilityCol (default: "probability").
+ */
+@DeveloperApi
+trait HasProbabilityCol extends Params {
+
+ /**
+ * Param for column name for predicted class conditional probabilities.
+ * @group param
+ */
+ final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities")
+
+ setDefault(probabilityCol, "probability")
+
+ /** @group getParam */
+ final def getProbabilityCol: String = getOrDefault(probabilityCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param threshold.
+ */
+@DeveloperApi
+trait HasThreshold extends Params {
+
+ /**
+ * Param for threshold in binary classification prediction.
+ * @group param
+ */
+ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction")
+
+ /** @group getParam */
+ final def getThreshold: Double = getOrDefault(threshold)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param inputCol.
+ */
+@DeveloperApi
+trait HasInputCol extends Params {
+
+ /**
+ * Param for input column name.
+ * @group param
+ */
+ final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
+
+ /** @group getParam */
+ final def getInputCol: String = getOrDefault(inputCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param inputCols.
+ */
+@DeveloperApi
+trait HasInputCols extends Params {
+
+ /**
+ * Param for input column names.
+ * @group param
+ */
+ final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
+
+ /** @group getParam */
+ final def getInputCols: Array[String] = getOrDefault(inputCols)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param outputCol.
+ */
+@DeveloperApi
+trait HasOutputCol extends Params {
+
+ /**
+ * Param for output column name.
+ * @group param
+ */
+ final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
+
+ /** @group getParam */
+ final def getOutputCol: String = getOrDefault(outputCol)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param checkpointInterval.
+ */
+@DeveloperApi
+trait HasCheckpointInterval extends Params {
+
+ /**
+ * Param for checkpoint interval.
+ * @group param
+ */
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+
+ /** @group getParam */
+ final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param fitIntercept (default: true).
+ */
+@DeveloperApi
+trait HasFitIntercept extends Params {
+
+ /**
+ * Param for whether to fit an intercept term.
+ * @group param
+ */
+ final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term")
+
+ setDefault(fitIntercept, true)
+
+ /** @group getParam */
+ final def getFitIntercept: Boolean = getOrDefault(fitIntercept)
+}
+// scalastyle:on
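The generated traits are meant to be mixed into an algorithm's params, with the algorithm adding its own public setters (the builder-pattern note from the old hand-written file). A sketch under that assumption; MyScalerParams and MyScaler are made up for illustration:

    import org.apache.spark.ml.param._
    import org.apache.spark.ml.param.shared._

    // Hypothetical example, not part of this patch.
    trait MyScalerParams extends Params with HasInputCol with HasOutputCol

    class MyScaler extends MyScalerParams {
      // Defaults live alongside the params; explicit set() calls override them.
      setDefault(outputCol, "myScaler_output")

      /** @group setParam */
      def setInputCol(value: String): this.type = set(inputCol, value)

      /** @group setParam */
      def setOutputCol(value: String): this.type = set(outputCol, value)
    }

    // new MyScaler().setInputCol("raw").setOutputCol("scaled")  // builder pattern, as before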
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
deleted file mode 100644
index 07e6eb4177..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.param
-
-/* NOTE TO DEVELOPERS:
- * If you mix these parameter traits into your algorithm, please add a setter method as well
- * so that users may use a builder pattern:
- * val myLearner = new MyLearner().setParam1(x).setParam2(y)...
- */
-
-private[ml] trait HasRegParam extends Params {
- /**
- * param for regularization parameter
- * @group param
- */
- val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
-
- /** @group getParam */
- def getRegParam: Double = get(regParam)
-}
-
-private[ml] trait HasMaxIter extends Params {
- /**
- * param for max number of iterations
- * @group param
- */
- val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
-
- /** @group getParam */
- def getMaxIter: Int = get(maxIter)
-}
-
-private[ml] trait HasFeaturesCol extends Params {
- /**
- * param for features column name
- * @group param
- */
- val featuresCol: Param[String] =
- new Param(this, "featuresCol", "features column name", Some("features"))
-
- /** @group getParam */
- def getFeaturesCol: String = get(featuresCol)
-}
-
-private[ml] trait HasLabelCol extends Params {
- /**
- * param for label column name
- * @group param
- */
- val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label"))
-
- /** @group getParam */
- def getLabelCol: String = get(labelCol)
-}
-
-private[ml] trait HasPredictionCol extends Params {
- /**
- * param for prediction column name
- * @group param
- */
- val predictionCol: Param[String] =
- new Param(this, "predictionCol", "prediction column name", Some("prediction"))
-
- /** @group getParam */
- def getPredictionCol: String = get(predictionCol)
-}
-
-private[ml] trait HasRawPredictionCol extends Params {
- /**
- * param for raw prediction column name
- * @group param
- */
- val rawPredictionCol: Param[String] =
- new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
- Some("rawPrediction"))
-
- /** @group getParam */
- def getRawPredictionCol: String = get(rawPredictionCol)
-}
-
-private[ml] trait HasProbabilityCol extends Params {
- /**
- * param for predicted class conditional probabilities column name
- * @group param
- */
- val probabilityCol: Param[String] =
- new Param(this, "probabilityCol", "column name for predicted class conditional probabilities",
- Some("probability"))
-
- /** @group getParam */
- def getProbabilityCol: String = get(probabilityCol)
-}
-
-private[ml] trait HasFitIntercept extends Params {
- /**
- * param for fitting the intercept term, defaults to true
- * @group param
- */
- val fitIntercept: BooleanParam =
- new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true))
-
- /** @group getParam */
- def getFitIntercept: Boolean = get(fitIntercept)
-}
-
-private[ml] trait HasThreshold extends Params {
- /**
- * param for threshold in (binary) prediction
- * @group param
- */
- val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
-
- /** @group getParam */
- def getThreshold: Double = get(threshold)
-}
-
-private[ml] trait HasInputCol extends Params {
- /**
- * param for input column name
- * @group param
- */
- val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
-
- /** @group getParam */
- def getInputCol: String = get(inputCol)
-}
-
-private[ml] trait HasInputCols extends Params {
- /**
- * Param for input column names.
- */
- val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names")
-
- /** @group getParam */
- def getInputCols: Array[String] = get(inputCols)
-}
-
-private[ml] trait HasOutputCol extends Params {
- /**
- * param for output column name
- * @group param
- */
- val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
-
- /** @group getParam */
- def getOutputCol: String = get(outputCol)
-}
-
-private[ml] trait HasCheckpointInterval extends Params {
- /**
- * param for checkpoint interval
- * @group param
- */
- val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
-
- /** @group getParam */
- def getCheckpointInterval: Int = get(checkpointInterval)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 52c9e95d60..bd793beba3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -34,6 +34,7 @@ import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
@@ -54,86 +55,88 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* Param for rank of the matrix factorization.
* @group param
*/
- val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
+ val rank = new IntParam(this, "rank", "rank of the factorization")
/** @group getParam */
- def getRank: Int = get(rank)
+ def getRank: Int = getOrDefault(rank)
/**
* Param for number of user blocks.
* @group param
*/
- val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
+ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks")
/** @group getParam */
- def getNumUserBlocks: Int = get(numUserBlocks)
+ def getNumUserBlocks: Int = getOrDefault(numUserBlocks)
/**
* Param for number of item blocks.
* @group param
*/
val numItemBlocks =
- new IntParam(this, "numItemBlocks", "number of item blocks", Some(10))
+ new IntParam(this, "numItemBlocks", "number of item blocks")
/** @group getParam */
- def getNumItemBlocks: Int = get(numItemBlocks)
+ def getNumItemBlocks: Int = getOrDefault(numItemBlocks)
/**
* Param to decide whether to use implicit preference.
* @group param
*/
- val implicitPrefs =
- new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
+ val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference")
/** @group getParam */
- def getImplicitPrefs: Boolean = get(implicitPrefs)
+ def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs)
/**
* Param for the alpha parameter in the implicit preference formulation.
* @group param
*/
- val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
+ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference")
/** @group getParam */
- def getAlpha: Double = get(alpha)
+ def getAlpha: Double = getOrDefault(alpha)
/**
* Param for the column name for user ids.
* @group param
*/
- val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
+ val userCol = new Param[String](this, "userCol", "column name for user ids")
/** @group getParam */
- def getUserCol: String = get(userCol)
+ def getUserCol: String = getOrDefault(userCol)
/**
* Param for the column name for item ids.
* @group param
*/
- val itemCol =
- new Param[String](this, "itemCol", "column name for item ids", Some("item"))
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids")
/** @group getParam */
- def getItemCol: String = get(itemCol)
+ def getItemCol: String = getOrDefault(itemCol)
/**
* Param for the column name for ratings.
* @group param
*/
- val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
+ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")
/** @group getParam */
- def getRatingCol: String = get(ratingCol)
+ def getRatingCol: String = getOrDefault(ratingCol)
/**
* Param for whether to apply nonnegativity constraints.
* @group param
*/
val nonnegative = new BooleanParam(
- this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false))
+ this, "nonnegative", "whether to use nonnegative constraint for least squares")
/** @group getParam */
- val getNonnegative: Boolean = get(nonnegative)
+ def getNonnegative: Boolean = getOrDefault(nonnegative)
+
+ setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
+ implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
+ ratingCol -> "rating", nonnegative -> false)
/**
* Validates and transforms the input schema.
@@ -142,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
assert(schema(map(userCol)).dataType == IntegerType)
    assert(schema(map(itemCol)).dataType == IntegerType)
val ratingType = schema(map(ratingCol)).dataType
@@ -171,7 +174,7 @@ class ALSModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext.implicits._
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val users = userFactors.toDF("id", "features")
val items = itemFactors.toDF("id", "features")
@@ -283,7 +286,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setCheckpointInterval(10)
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val ratings = dataset
.select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
.map { row =>
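The ParamPair-based setDefault used in ALSParams above keeps default values separate from explicitly set ones; getOrDefault then prefers the explicit value. A toy illustration (ToyParams is not part of the patch):

    import org.apache.spark.ml.param._

    class ToyParams extends Params {
      val rank = new IntParam(this, "rank", "rank of the factorization")
      def setRank(value: Int): this.type = set(rank, value)
      def getRank: Int = getOrDefault(rank)
      setDefault(rank -> 10)
    }

    val p = new ToyParams
    p.getRank       // 10: resolved from the default
    p.setRank(20)
    p.getRank       // 20: the explicitly set value overrides the default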
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 65f6627a0c..26ca7459c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -18,7 +18,8 @@
package org.apache.spark.ml.regression
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.sql.DataFrame
@@ -41,8 +42,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams
class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
with LinearRegressionParams {
- setRegParam(0.1)
- setMaxIter(100)
+ setDefault(regParam -> 0.1, maxIter -> 100)
/** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
@@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] (
override protected def copy(): LinearRegressionModel = {
val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
- Params.inheritValues(this.paramMap, this, m)
+ Params.inheritValues(extractParamMap(), this, m)
m
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2eb1dac56f..4bb4ed813c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
private[ml] trait CrossValidatorParams extends Params {
+
/**
* param for the estimator to be cross-validated
* @group param
@@ -38,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params {
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
/** @group getParam */
- def getEstimator: Estimator[_] = get(estimator)
+ def getEstimator: Estimator[_] = getOrDefault(estimator)
/**
* param for estimator param maps
@@ -48,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params {
new Param(this, "estimatorParamMaps", "param maps for the estimator")
/** @group getParam */
- def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
+ def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps)
/**
* param for the evaluator for selection
@@ -57,17 +58,18 @@ private[ml] trait CrossValidatorParams extends Params {
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
/** @group getParam */
- def getEvaluator: Evaluator = get(evaluator)
+ def getEvaluator: Evaluator = getOrDefault(evaluator)
/**
* param for number of folds for cross validation
* @group param
*/
- val numFolds: IntParam =
- new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
+ val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation")
/** @group getParam */
- def getNumFolds: Int = get(numFolds)
+ def getNumFolds: Int = getOrDefault(numFolds)
+
+ setDefault(numFolds -> 3)
}
/**
@@ -92,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
def setNumFolds(value: Int): this.type = set(numFolds, value)
override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val schema = dataset.schema
transformSchema(dataset.schema, paramMap, logging = true)
val sqlCtx = dataset.sqlContext
@@ -130,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
map(estimator).transformSchema(schema, paramMap)
}
}
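Throughout the patch, the old `this.paramMap ++ paramMap` merge in fit() and transformSchema() is replaced by extractParamMap(paramMap). Reading the call sites, the extracted map layers defaults, then explicitly set values, then the caller-supplied map; a hedged sketch (CvDemo is made up, and the precedence comments are an assumption based on these call sites, not a quote of the implementation):

    import org.apache.spark.ml.param._

    class CvDemo extends Params {
      val numFolds = new IntParam(this, "numFolds", "number of folds for cross validation")
      def setNumFolds(value: Int): this.type = set(numFolds, value)
      setDefault(numFolds -> 3)
    }

    val demo = new CvDemo
    val extra = ParamMap(demo.numFolds -> 5)
    // Inside fit(dataset, extra), extractParamMap(extra) behaves roughly like
    //   defaults ++ set values ++ extra
    // so: nothing set and an empty extra map -> 3; after demo.setNumFolds(10) -> 10;
    // and with the extra map above -> 5 (values supplied at call time win).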
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
new file mode 100644
index 0000000000..0383bf0b38
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
+/**
+ * :: DeveloperApi ::
+ * Utils for handling schemas.
+ */
+@DeveloperApi
+object SchemaUtils {
+
+ // TODO: Move the utility methods to SQL.
+
+ /**
+ * Check whether the given schema contains a column of the required data type.
+ * @param colName column name
+ * @param dataType required column data type
+ */
+ def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
+ val actualDataType = schema(colName).dataType
+ require(actualDataType.equals(dataType),
+ s"Column $colName must be of type $dataType but was actually $actualDataType.")
+ }
+
+ /**
+ * Appends a new column to the input schema. This fails if the given output column already exists.
+ * @param schema input schema
+ * @param colName new column name. If this column name is an empty string "", this method returns
+ * the input schema unchanged. This allows users to disable output columns.
+ * @param dataType new column data type
+ * @return new schema with the input column appended
+ */
+ def appendColumn(
+ schema: StructType,
+ colName: String,
+ dataType: DataType): StructType = {
+ if (colName.isEmpty) return schema
+ val fieldNames = schema.fieldNames
+ require(!fieldNames.contains(colName), s"Column $colName already exists.")
+ val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
+ StructType(outputFields)
+ }
+}
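The new SchemaUtils helpers take over the schema checks that previously lived in Params. A sketch of how a transformer's schema validation might use them; the column names here are illustrative:

    import org.apache.spark.ml.util.SchemaUtils
    import org.apache.spark.sql.types.{DoubleType, StructType}

    def validateAndTransform(schema: StructType): StructType = {
      // Fail fast if the input column is missing or has the wrong type ...
      SchemaUtils.checkColumnType(schema, "input", DoubleType)
      // ... then append the output column (returns the schema unchanged if the name is "").
      SchemaUtils.appendColumn(schema, "scaled", DoubleType)
    }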
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 1ce2987612..88ea679eea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -21,19 +21,25 @@ import org.scalatest.FunSuite
class ParamsSuite extends FunSuite {
- val solver = new TestParams()
- import solver.{inputCol, maxIter}
-
test("param") {
+ val solver = new TestParams()
+ import solver.{maxIter, inputCol}
+
assert(maxIter.name === "maxIter")
assert(maxIter.doc === "max number of iterations")
- assert(maxIter.defaultValue.get === 100)
assert(maxIter.parent.eq(solver))
- assert(maxIter.toString === "maxIter: max number of iterations (default: 100)")
- assert(inputCol.defaultValue === None)
+ assert(maxIter.toString === "maxIter: max number of iterations (default: 10)")
+
+ solver.setMaxIter(5)
+ assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)")
+
+ assert(inputCol.toString === "inputCol: input column name (undefined)")
}
test("param pair") {
+ val solver = new TestParams()
+ import solver.maxIter
+
val pair0 = maxIter -> 5
val pair1 = maxIter.w(5)
val pair2 = ParamPair(maxIter, 5)
@@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite {
}
test("param map") {
+ val solver = new TestParams()
+ import solver.{maxIter, inputCol}
+
val map0 = ParamMap.empty
assert(!map0.contains(maxIter))
- assert(map0(maxIter) === maxIter.defaultValue.get)
map0.put(maxIter, 10)
assert(map0.contains(maxIter))
assert(map0(maxIter) === 10)
@@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite {
}
test("params") {
+ val solver = new TestParams()
+ import solver.{maxIter, inputCol}
+
val params = solver.params
- assert(params.size === 2)
+ assert(params.length === 2)
assert(params(0).eq(inputCol), "params must be ordered by name")
assert(params(1).eq(maxIter))
+
+ assert(!solver.isSet(maxIter))
+ assert(solver.isDefined(maxIter))
+ assert(solver.getMaxIter === 10)
+ solver.setMaxIter(100)
+ assert(solver.isSet(maxIter))
+ assert(solver.getMaxIter === 100)
+ assert(!solver.isSet(inputCol))
+ assert(!solver.isDefined(inputCol))
+ intercept[NoSuchElementException](solver.getInputCol)
+
assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+
assert(solver.getParam("inputCol").eq(inputCol))
assert(solver.getParam("maxIter").eq(maxIter))
- intercept[NoSuchMethodException] {
+ intercept[NoSuchElementException] {
solver.getParam("abc")
}
- assert(!solver.isSet(inputCol))
+
intercept[IllegalArgumentException] {
solver.validate()
}
solver.validate(ParamMap(inputCol -> "input"))
solver.setInputCol("input")
assert(solver.isSet(inputCol))
+ assert(solver.isDefined(inputCol))
assert(solver.getInputCol === "input")
solver.validate()
intercept[IllegalArgumentException] {
@@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite {
intercept[IllegalArgumentException] {
solver.validate()
}
+
+ solver.clearMaxIter()
+ assert(!solver.isSet(maxIter))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index ce52f2f230..8f9ab687c0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -20,17 +20,21 @@ package org.apache.spark.ml.param
/** A subclass of Params for testing. */
class TestParams extends Params {
- val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100))
+ val maxIter = new IntParam(this, "maxIter", "max number of iterations")
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
- def getMaxIter: Int = get(maxIter)
+ def getMaxIter: Int = getOrDefault(maxIter)
val inputCol = new Param[String](this, "inputCol", "input column name")
def setInputCol(value: String): this.type = { set(inputCol, value); this }
- def getInputCol: String = get(inputCol)
+ def getInputCol: String = getOrDefault(inputCol)
+
+ setDefault(maxIter -> 10)
override def validate(paramMap: ParamMap): Unit = {
- val m = this.paramMap ++ paramMap
+ val m = extractParamMap(paramMap)
require(m(maxIter) >= 0)
require(m.contains(inputCol))
}
+
+ def clearMaxIter(): this.type = clear(maxIter)
}
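Taken together, the updated TestParams and ParamsSuite walk a parameter through its full default / set / clear lifecycle; read as a REPL-style sketch of what the assertions above check:

    val solver = new TestParams
    solver.isSet(solver.maxIter)      // false: only the default (10) is defined
    solver.getMaxIter                 // 10
    solver.setMaxIter(100)
    solver.isSet(solver.maxIter)      // true
    solver.getMaxIter                 // 100
    solver.clearMaxIter()
    solver.getMaxIter                 // 10 again: back to the default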