author    Joseph K. Bradley <joseph@databricks.com>  2015-04-29 17:26:46 -0700
committer Xiangrui Meng <meng@databricks.com>        2015-04-29 17:26:46 -0700
commit    114bad606e7a17f980ea6c99e31c8ab0179fec2e (patch)
tree      f0a1a5f81f6f626412a9a526b72d0b6e5edf570c /mllib
parent    1fdfdb47b44315ff8ccb0ef92e56d3f2a070f1f1 (diff)
[SPARK-7176] [ML] Add validation functionality to Param
Main change: Added isValid field to Param. Modified all usages to use isValid when relevant.
Added helper methods in ParamValidate.

Also overrode Params.validate() in:
* CrossValidator + model
* Pipeline + model

I made a few updates for the elastic net patch:
* I changed "tol" to "convergenceTol"
* I added some documentation

This PR is Scala + Java only. Python will be in a follow-up PR.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5740 from jkbradley/enforce-validate and squashes the following commits:

ad9c6c1 [Joseph K. Bradley] re-generated sharedParams after merging with current master
76415e8 [Joseph K. Bradley] reverted convergenceTol to tol
af62f4b [Joseph K. Bradley] Removed changes to SparkBuild, python linalg. Fixed test failures. Renamed ParamValidate to ParamValidators. Removed explicit type from ParamValidators calls where possible.
bb2665a [Joseph K. Bradley] merged with elastic net pr
ecda302 [Joseph K. Bradley] fix rat tests, plus add a little doc
6895dfc [Joseph K. Bradley] small cleanups
069ac6d [Joseph K. Bradley] many cleanups
928fb84 [Joseph K. Bradley] Maybe done
a910ac7 [Joseph K. Bradley] still workin
6d60e2e [Joseph K. Bradley] Still workin
b987319 [Joseph K. Bradley] Partly done with adding checks, but blocking on adding checking functionality to Param
dbc9fb2 [Joseph K. Bradley] merged with master. enforcing Params.validate
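In short, a Param now carries an isValid predicate, and validation fires whenever a ParamPair
is constructed (i.e., whenever a value is set or put into a ParamMap). A minimal sketch of
declaring a validated parameter under this design (the MyScalerParams trait and its "scale"
param are hypothetical, for illustration only):

    import org.apache.spark.ml.param.{DoubleParam, ParamValidators, Params}

    trait MyScalerParams extends Params {
      // Hypothetical param: scaling factor, must be strictly positive.
      val scale: DoubleParam =
        new DoubleParam(this, "scale", "scaling factor (> 0)", ParamValidators.gt(0))

      /** @group getParam */
      def getScale: Double = getOrDefault(scale)

      /** @group setParam */
      def setScale(value: Double): this.type = set(scale, value) // invalid values throw here
    }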
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala | 9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala | 115
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 179
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala | 122
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java | 66
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 63
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 69
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 2
21 files changed, 585 insertions(+), 268 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 8eddf79cdf..6bfeecd764 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.{Params, Param, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -86,6 +86,14 @@ class Pipeline extends Estimator[PipelineModel] {
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
def getStages: Array[PipelineStage] = getOrDefault(stages)
+ override def validate(paramMap: ParamMap): Unit = {
+ val map = extractParamMap(paramMap)
+ getStages.foreach {
+ case pStage: Params => pStage.validate(map)
+ case _ =>
+ }
+ }
+
/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -140,7 +148,7 @@ class Pipeline extends Estimator[PipelineModel] {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = extractParamMap(paramMap)
val theStages = map(stages)
- require(theStages.toSet.size == theStages.size,
+ require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
}
@@ -157,6 +165,11 @@ class PipelineModel private[ml] (
private[ml] val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
+ override def validate(paramMap: ParamMap): Unit = {
+ val map = fittingParamMap ++ extractParamMap(paramMap)
+ stages.foreach(_.validate(map))
+ }
+
/**
* Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
* estimator does not exist in the pipeline.
@@ -168,7 +181,7 @@ class PipelineModel private[ml] (
}
if (matched.isEmpty) {
throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
- } else if (matched.size > 1) {
+ } else if (matched.length > 1) {
throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
} else {
matched.head.asInstanceOf[M]
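With the override above, validating a Pipeline cascades into every stage that implements
Params. A hedged usage sketch (the stage instances are placeholders for any configured
estimators/transformers):

    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
    pipeline.validate() // equivalent to validate(ParamMap.empty); validates each stage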
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index d2e052fbbb..3d849867d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -103,21 +103,16 @@ final class GBTClassifier
*/
val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
" tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+ s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
setDefault(lossType -> "logistic")
/** @group setParam */
- def setLossType(value: String): this.type = {
- val lossStr = value.toLowerCase
- require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" +
- s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}")
- set(lossType, lossStr)
- this
- }
+ def setLossType(value: String): this.type = set(lossType, value)
/** @group getParam */
- def getLossType: String = getOrDefault(lossType)
+ def getLossType: String = getOrDefault(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index b20f2fc49a..0b3128f9ee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.sql.types.DataType
@@ -32,10 +32,14 @@ import org.apache.spark.sql.types.DataType
class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
/**
- * number of features
+ * Number of features. Should be > 0.
+ * (default = 2^18^)
* @group param
*/
- val numFeatures = new IntParam(this, "numFeatures", "number of features")
+ val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
+ ParamValidators.gt(0))
+
+ setDefault(numFeatures -> (1 << 18))
/** @group getParam */
def getNumFeatures: Int = getOrDefault(numFeatures)
@@ -43,8 +47,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
- setDefault(numFeatures -> (1 << 18))
-
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
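With the gt(0) validator attached, a bad numFeatures now fails at set time instead of
surfacing later during transform. A small usage sketch (spark-shell style, assuming the
1.4-era setter API shown above):

    import org.apache.spark.ml.feature.HashingTF

    val tf = new HashingTF().setNumFeatures(1 << 20) // OK: 2^20 > 0
    // new HashingTF().setNumFeatures(0)             // throws IllegalArgumentException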
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index decaeb0da6..bd2b5f6067 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{DoubleParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.sql.types.DataType
@@ -32,10 +32,13 @@ import org.apache.spark.sql.types.DataType
class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
/**
- * Normalization in L^p^ space, p = 2 by default.
+ * Normalization in L^p^ space. Must be >= 1.
+ * (default: p = 2)
* @group param
*/
- val p = new DoubleParam(this, "p", "the p norm value")
+ val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1))
+
+ setDefault(p -> 2.0)
/** @group getParam */
def getP: Double = getOrDefault(p)
@@ -43,8 +46,6 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
/** @group setParam */
def setP(value: Double): this.type = set(p, value)
- setDefault(p -> 2.0)
-
override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
val normalizer = new feature.Normalizer(paramMap(p))
normalizer.transform
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index d855f04799..1b7c939c2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -37,10 +37,13 @@ import org.apache.spark.sql.types.DataType
class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
/**
- * The polynomial degree to expand, which should be larger than 1.
+ * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion.
+ * Default: 2
* @group param
*/
- val degree = new IntParam(this, "degree", "the polynomial degree to expand")
+ val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)",
+ ParamValidators.gtEq(1))
+
setDefault(degree -> 2)
/** @group getParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 447851ec03..a0e9ed32e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -31,17 +31,19 @@ import org.apache.spark.sql.types.{StructField, StructType}
* Params for [[StandardScaler]] and [[StandardScalerModel]].
*/
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
-
+
/**
- * False by default. Centers the data with mean before scaling.
+ * Centers the data with mean before scaling.
* It will build a dense output, so this does not work on sparse input
* and will raise an exception.
+ * Default: false
* @group param
*/
val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
/**
- * True by default. Scales the data to unit standard deviation.
+ * Scales the data to unit standard deviation.
+ * Default: true
* @group param
*/
val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
@@ -56,7 +58,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
setDefault(withMean -> false, withStd -> true)
-
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 376a004858..01752ba482 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
+import org.apache.spark.ml.param._
import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
/**
@@ -43,20 +43,20 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
/**
* :: AlphaComponent ::
* A regex-based tokenizer that extracts tokens either by repeatedly matching the regex (default)
- * or using it to split the text (set matching to false). Optional parameters also allow to fold
- * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length.
+ * or using it to split the text (set matching to false). Optional parameters also allow filtering
+ * tokens using a minimal length.
* It returns an array of strings that can be empty.
- * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true,
- * lowercase = false, minTokenLength = 1
*/
@AlphaComponent
class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
/**
- * param for minimum token length, default is one to avoid returning empty strings
+ * Minimum token length, >= 0.
+ * Default: 1, to avoid returning empty strings
* @group param
*/
- val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")
+ val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
+ ParamValidators.gtEq(0))
/** @group setParam */
def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
@@ -65,7 +65,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
def getMinTokenLength: Int = getOrDefault(minTokenLength)
/**
- * param sets regex as splitting on gaps (true) or matching tokens (false)
+ * Indicates whether regex splits on gaps (true) or matching tokens (false).
+ * Default: false
* @group param
*/
val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
@@ -77,7 +78,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
def getGaps: Boolean = getOrDefault(gaps)
/**
- * param sets regex pattern used by tokenizer
+ * Regex pattern used by tokenizer.
+ * Default: `"\\p{L}+|[^\\p{L}\\s]+"`
* @group param
*/
val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 1e5ffd15af..ed833c63c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
Attribute, AttributeGroup}
-import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
import org.apache.spark.sql.{Row, DataFrame}
@@ -37,17 +37,19 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
/**
* Threshold for the number of values a categorical feature can take.
* If a feature is found to have > maxCategories values, then it is declared continuous.
+ * Must be >= 2.
*
* (default = 20)
*/
val maxCategories = new IntParam(this, "maxCategories",
- "Threshold for the number of values a categorical feature can take." +
- " If a feature is found to have > maxCategories values, then it is declared continuous.")
+ "Threshold for the number of values a categorical feature can take (>= 2)." +
+ " If a feature is found to have > maxCategories values, then it is declared continuous.",
+ ParamValidators.gtEq(2))
+
+ setDefault(maxCategories -> 20)
/** @group getParam */
def getMaxCategories: Int = getOrDefault(maxCategories)
-
- setDefault(maxCategories -> 20)
}
/**
@@ -90,11 +92,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams {
/** @group setParam */
- def setMaxCategories(value: Int): this.type = {
- require(value > 1,
- s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.")
- set(maxCategories, value)
- }
+ def setMaxCategories(value: Int): this.type = set(maxCategories, value)
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index ab6281b9b2..fb770622e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -38,14 +38,15 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
private[ml] trait DecisionTreeParams extends PredictorParams {
/**
- * Maximum depth of the tree.
+ * Maximum depth of the tree (>= 0).
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* (default = 5)
* @group param
*/
final val maxDepth: IntParam =
- new IntParam(this, "maxDepth", "Maximum depth of the tree." +
- " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+ new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
+ ParamValidators.gtEq(0))
/**
* Maximum number of bins used for discretizing continuous features and for choosing how to split
@@ -56,7 +57,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
*/
final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
" discretizing continuous features. Must be >=2 and >= number of categories for any" +
- " categorical feature.")
+ " categorical feature.", ParamValidators.gtEq(2))
/**
* Minimum number of instances each child must have after split.
@@ -69,7 +70,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
" number of instances each child must have after split. If a split causes the left or right" +
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
- " Should be >= 1.")
+ " Should be >= 1.", ParamValidators.gtEq(1))
/**
* Minimum information gain for a split to be considered at a tree node.
@@ -85,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
* @group expertParam
*/
final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
- "Maximum memory in MB allocated to histogram aggregation.")
+ "Maximum memory in MB allocated to histogram aggregation.",
+ ParamValidators.gtEq(0))
/**
* If false, the algorithm will pass trees to executors to match instances with nodes.
@@ -111,34 +113,26 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
" how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
" checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
- " checkpoint directory is set in the SparkContext. Must be >= 1.")
+ " checkpoint directory is set in the SparkContext. Must be >= 1.",
+ ParamValidators.gtEq(1))
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
/** @group setParam */
- def setMaxDepth(value: Int): this.type = {
- require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
- set(maxDepth, value)
- }
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** @group getParam */
final def getMaxDepth: Int = getOrDefault(maxDepth)
/** @group setParam */
- def setMaxBins(value: Int): this.type = {
- require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
- set(maxBins, value)
- }
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
/** @group getParam */
final def getMaxBins: Int = getOrDefault(maxBins)
/** @group setParam */
- def setMinInstancesPerNode(value: Int): this.type = {
- require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
- set(minInstancesPerNode, value)
- }
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group getParam */
final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
@@ -150,10 +144,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
final def getMinInfoGain: Double = getOrDefault(minInfoGain)
/** @group expertSetParam */
- def setMaxMemoryInMB(value: Int): this.type = {
- require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
- set(maxMemoryInMB, value)
- }
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** @group expertGetParam */
final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
@@ -165,10 +156,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
/** @group expertSetParam */
- def setCheckpointInterval(value: Int): this.type = {
- require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
- set(checkpointInterval, value)
- }
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** @group expertGetParam */
final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
@@ -209,21 +197,16 @@ private[ml] trait TreeClassifierParams extends Params {
*/
final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
- s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
+ (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
setDefault(impurity -> "gini")
/** @group setParam */
- def setImpurity(value: String): this.type = {
- val impurityStr = value.toLowerCase
- require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
- s"Tree-based classifier was given unrecognized impurity: $value." +
- s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
- set(impurity, impurityStr)
- }
+ def setImpurity(value: String): this.type = set(impurity, value)
/** @group getParam */
- final def getImpurity: String = getOrDefault(impurity)
+ final def getImpurity: String = getOrDefault(impurity).toLowerCase
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -256,21 +239,16 @@ private[ml] trait TreeRegressorParams extends Params {
*/
final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
- s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+ s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+ (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
setDefault(impurity -> "variance")
/** @group setParam */
- def setImpurity(value: String): this.type = {
- val impurityStr = value.toLowerCase
- require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
- s"Tree-based regressor was given unrecognized impurity: $value." +
- s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
- set(impurity, impurityStr)
- }
+ def setImpurity(value: String): this.type = set(impurity, value)
/** @group getParam */
- final def getImpurity: String = getOrDefault(impurity)
+ final def getImpurity: String = getOrDefault(impurity).toLowerCase
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -299,21 +277,18 @@ private[ml] object TreeRegressorParams {
private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
/**
- * Fraction of the training data used for learning each decision tree.
+ * Fraction of the training data used for learning each decision tree, in range (0, 1].
* (default = 1.0)
* @group param
*/
final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
- "Fraction of the training data used for learning each decision tree.")
+ "Fraction of the training data used for learning each decision tree, in range (0, 1].",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
setDefault(subsamplingRate -> 1.0)
/** @group setParam */
- def setSubsamplingRate(value: Double): this.type = {
- require(value > 0.0 && value <= 1.0,
- s"Subsampling rate must be in range (0,1]. Bad rate: $value")
- set(subsamplingRate, value)
- }
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
/** @group getParam */
final def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
@@ -350,7 +325,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
* (default = 20)
* @group param
*/
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)")
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
/**
* The number of features to consider for splits at each tree node.
@@ -378,30 +354,23 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
*/
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
"The number of features to consider for splits at each tree node." +
- s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+ s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
+ (value: String) =>
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
/** @group setParam */
- def setNumTrees(value: Int): this.type = {
- require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.")
- set(numTrees, value)
- }
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
final def getNumTrees: Int = getOrDefault(numTrees)
/** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = {
- val strategyStr = value.toLowerCase
- require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr),
- s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" +
- s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
- set(featureSubsetStrategy, strategyStr)
- }
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy)
+ final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy).toLowerCase
}
private[ml] object RandomForestParams {
@@ -426,7 +395,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
* @group param
*/
final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
- " learning rate) in interval (0, 1] for shrinking the contribution of each estimator")
+ " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
/* TODO: Add this doc when we add this param. SPARK-7132
* Threshold for stopping early when runWithValidation is used.
@@ -442,17 +412,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
setDefault(maxIter -> 20, stepSize -> 0.1)
/** @group setParam */
- def setMaxIter(value: Int): this.type = {
- require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.")
- set(maxIter, value)
- }
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
/** @group setParam */
- def setStepSize(value: Double): this.type = {
- require(value > 0.0 && value <= 1.0,
- s"GBT given invalid step size ($value). Value should be in (0,1].")
- set(stepSize, value)
- }
+ def setStepSize(value: Double): this.type = set(stepSize, value)
/** @group getParam */
final def getStepSize: Double = getOrDefault(stepSize)
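All of the case-insensitive string params above follow one pattern: the isValid predicate
lowercases the candidate before the membership check, and the getter lowercases the stored
value so callers always see the canonical form. A condensed standalone sketch of that
pattern (the supported-options array is a made-up example):

    val supported = Array("gini", "entropy")
    val impurityIsValid: String => Boolean =
      (value: String) => supported.contains(value.toLowerCase)

    impurityIsValid("Gini") // true: the check is case-insensitive
    impurityIsValid("mse")  // false: setting this value would make ParamPair throw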
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 014e124e44..df6360dce6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -34,10 +34,35 @@ import org.apache.spark.ml.util.Identifiable
* @param parent parent object
* @param name param name
* @param doc documentation
+ * @param isValid optional validation method which indicates if a value is valid.
+ * See [[ParamValidators]] for factory methods for common validation functions.
* @tparam T param value type
*/
@AlphaComponent
-class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable {
+class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
+ extends Serializable {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue[T])
+
+ /**
+ * Assert that the given value is valid for this parameter.
+ *
+ * Note: Parameter checks involving interactions between multiple parameters should be
+ * implemented in [[Params.validate()]]. Checks for input/output columns should be
+ * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+ *
+ * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
+ * should be specified via [[ParamPair]].
+ *
+ * @throws IllegalArgumentException if the value is invalid
+ */
+ private[param] def validate(value: T): Unit = {
+ if (!isValid(value)) {
+ throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." +
+ s" Parameter description: $toString")
+ }
+ }
/**
* Creates a param pair with the given value (for Java).
@@ -65,38 +90,129 @@ class Param[T] (val parent: Params, val name: String, val doc: String) extends S
}
}
+/**
+ * Factory methods for common validation functions for [[Param.isValid]].
+ * The numerical methods only support Int, Long, Float, and Double.
+ */
+object ParamValidators {
+
+ /** (private[param]) Default validation always returns true */
+ private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
+
+ /**
+ * Private method for checking numerical types and converting to Double.
+ * This is mainly for the sake of compilation; type checks are really handled
+ * by [[Params]] setters and the [[ParamPair]] constructor.
+ */
+ private def getDouble[T](value: T): Double = value match {
+ case x: Int => x.toDouble
+ case x: Long => x.toDouble
+ case x: Float => x.toDouble
+ case x: Double => x.toDouble
+ case _ =>
+ // The type should be checked before this is ever called.
+ throw new IllegalArgumentException("Numerical Param validation failed because" +
+ s" of unexpected input type: ${value.getClass}")
+ }
+
+ /** Check if value > lowerBound */
+ def gt[T](lowerBound: Double): T => Boolean = { (value: T) =>
+ getDouble(value) > lowerBound
+ }
+
+ /** Check if value >= lowerBound */
+ def gtEq[T](lowerBound: Double): T => Boolean = { (value: T) =>
+ getDouble(value) >= lowerBound
+ }
+
+ /** Check if value < upperBound */
+ def lt[T](upperBound: Double): T => Boolean = { (value: T) =>
+ getDouble(value) < upperBound
+ }
+
+ /** Check if value <= upperBound */
+ def ltEq[T](upperBound: Double): T => Boolean = { (value: T) =>
+ getDouble(value) <= upperBound
+ }
+
+ /**
+ * Check for value in range lowerBound to upperBound.
+ * @param lowerInclusive If true, check for value >= lowerBound.
+ * If false, check for value > lowerBound.
+ * @param upperInclusive If true, check for value <= upperBound.
+ * If false, check for value < upperBound.
+ */
+ def inRange[T](
+ lowerBound: Double,
+ upperBound: Double,
+ lowerInclusive: Boolean,
+ upperInclusive: Boolean): T => Boolean = { (value: T) =>
+ val x: Double = getDouble(value)
+ val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound
+ val upperValid = if (upperInclusive) x <= upperBound else x < upperBound
+ lowerValid && upperValid
+ }
+
+ /** Version of [[inRange()]] which uses inclusive bounds by default: [lowerBound, upperBound] */
+ def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = {
+ inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true)
+ }
+
+ /** Check for value in an allowed set of values. */
+ def inArray[T](allowed: Array[T]): T => Boolean = { (value: T) =>
+ allowed.contains(value)
+ }
+
+ /** Check for value in an allowed set of values. */
+ def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
+ allowed.contains(value)
+ }
+}
+
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
/** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String)
- extends Param[Double](parent, name, doc) {
+class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean)
+ extends Param[Double](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
override def w(value: Double): ParamPair[Double] = super.w(value)
}
/** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String)
- extends Param[Int](parent, name, doc) {
+class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean)
+ extends Param[Int](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
override def w(value: Int): ParamPair[Int] = super.w(value)
}
/** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String)
- extends Param[Float](parent, name, doc) {
+class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean)
+ extends Param[Float](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
override def w(value: Float): ParamPair[Float] = super.w(value)
}
/** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String)
- extends Param[Long](parent, name, doc) {
+class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean)
+ extends Param[Long](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
override def w(value: Long): ParamPair[Long] = super.w(value)
}
/** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String)
+class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid
extends Param[Boolean](parent, name, doc) {
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
@@ -105,7 +221,11 @@ class BooleanParam(parent: Params, name: String, doc: String)
/**
* A param and its value.
*/
-case class ParamPair[T](param: Param[T], value: T)
+case class ParamPair[T](param: Param[T], value: T) {
+ // This is *the* place Param.validate is called. Whenever a parameter is specified, we should
+ // always construct a ParamPair so that validate is called.
+ param.validate(value)
+}
/**
* :: AlphaComponent ::
@@ -132,12 +252,22 @@ trait Params extends Identifiable with Serializable {
/**
* Validates parameter values stored internally plus the input parameter map.
* Raises an exception if any parameter is invalid.
+ *
+ * This only needs to check for interactions between parameters.
+ * Parameter value checks which do not depend on other parameters are handled by
+ * [[Param.validate()]]. This method does not handle input/output column parameters;
+ * those are checked during schema validation.
*/
- def validate(paramMap: ParamMap): Unit = {}
+ def validate(paramMap: ParamMap): Unit = { }
/**
* Validates parameter values stored internally.
* Raises an exception if any parameter value is invalid.
+ *
+ * This only needs to check for interactions between parameters.
+ * Parameter value checks which do not depend on other parameters are handled by
+ * [[Param.validate()]]. This method does not handle input/output column parameters;
+ * those are checked during schema validation.
*/
def validate(): Unit = validate(ParamMap.empty)
@@ -221,6 +351,10 @@ trait Params extends Identifiable with Serializable {
/**
* Sets default values for a list of params.
+ *
+ * Note: Java developers should use the single-parameter [[setDefault()]].
+ * Annotating this with varargs causes compilation failures.
+ *
* @param paramPairs a list of param pairs that specify params and their default values to set
* respectively. Make sure that the params are initialized before this method
* gets called.
@@ -306,6 +440,14 @@ private[spark] object Params {
}
/**
+ * Java-friendly wrapper for [[Params]].
+ * Java developers who need to extend [[Params]] should use this class instead.
+ * If you need to extend an abstract class which already extends [[Params]], then that abstract
+ * class should be Java-friendly as well.
+ */
+abstract class JavaParams extends Params
+
+/**
* :: AlphaComponent ::
* A param to value map.
*/
@@ -313,6 +455,12 @@ private[spark] object Params {
final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
extends Serializable {
+ /* DEVELOPERS: About validating parameter values
+ * This and ParamPair are the only two collections of parameters.
+ * This class should always create ParamPairs when
+ * specifying new parameter values. ParamPair will then call Param.validate().
+ */
+
/**
* Creates an empty param map.
*/
@@ -321,10 +469,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/**
* Puts a (param, value) pair (overwrites if the input param exists).
*/
- def put[T](param: Param[T], value: T): this.type = {
- map(param.asInstanceOf[Param[Any]]) = value
- this
- }
+ def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value))
/**
* Puts a list of param pairs (overwrites if the input params exists).
@@ -332,7 +477,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
@varargs
def put(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
- put(p.param.asInstanceOf[Param[Any]], p.value)
+ map(p.param.asInstanceOf[Param[Any]]) = p.value
}
this
}
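The ParamValidators factory methods return plain Scala functions (T => Boolean), so they
compose with ordinary function logic when a param needs more than one constraint. A small
sketch (the combined predicate is hand-rolled, not part of the API):

    import org.apache.spark.ml.param.ParamValidators

    val positive = ParamValidators.gt[Double](0)
    val atMostOne = ParamValidators.ltEq[Double](1)
    val inUnitInterval: Double => Boolean = v => positive(v) && atMostOne(v)

    // Equivalent to the built-in factory with explicit bounds:
    val same = ParamValidators.inRange[Double](
      0, 1, lowerInclusive = false, upperInclusive = true)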
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 654cd72d53..7da4bb4b4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -21,6 +21,8 @@ import java.io.PrintWriter
import scala.reflect.ClassTag
+import org.apache.spark.ml.param.ParamValidators
+
/**
* Code generator for shared params (sharedParams.scala). Run under the Spark folder with
* {{{
@@ -31,8 +33,10 @@ private[shared] object SharedParamsCodeGen {
def main(args: Array[String]): Unit = {
val params = Seq(
- ParamDesc[Double]("regParam", "regularization parameter"),
- ParamDesc[Int]("maxIter", "max number of iterations"),
+ ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
+ isValid = "ParamValidators.gtEq(0)"),
+ ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+ isValid = "ParamValidators.gtEq(0)"),
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
@@ -40,14 +44,19 @@ private[shared] object SharedParamsCodeGen {
Some("\"rawPrediction\"")),
ParamDesc[String]("probabilityCol",
"column name for predicted class conditional probabilities", Some("\"probability\"")),
- ParamDesc[Double]("threshold", "threshold in binary classification prediction"),
+ ParamDesc[Double]("threshold",
+ "threshold in binary classification prediction, in range [0, 1]",
+ isValid = "ParamValidators.inRange(0, 1)"),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name"),
- ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
+ ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
+ isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
- ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+ ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
+ " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
+ isValid = "ParamValidators.inRange(0, 1)"),
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))
@@ -62,7 +71,8 @@ private[shared] object SharedParamsCodeGen {
private case class ParamDesc[T: ClassTag](
name: String,
doc: String,
- defaultValueStr: Option[String] = None) {
+ defaultValueStr: Option[String] = None,
+ isValid: String = "") {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -113,20 +123,23 @@ private[shared] object SharedParamsCodeGen {
| setDefault($name, $v)
|""".stripMargin
}.getOrElse("")
+ val isValid = if (param.isValid != "") {
+ ", " + param.isValid
+ } else {
+ ""
+ }
s"""
|/**
- | * :: DeveloperApi ::
- | * Trait for shared param $name$defaultValueDoc.
+ | * (private[ml]) Trait for shared param $name$defaultValueDoc.
| */
- |@DeveloperApi
- |trait Has$Name extends Params {
+ |private[ml] trait Has$Name extends Params {
|
| /**
| * Param for $doc.
| * @group param
| */
- | final val $name: $Param = new $Param(this, "$name", "$doc")
+ | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault
| /** @group getParam */
| final def get$Name: $T = getOrDefault($name)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 96d11ed76f..e1549f46a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -26,45 +26,39 @@ import org.apache.spark.util.Utils
// scalastyle:off
/**
- * :: DeveloperApi ::
- * Trait for shared param regParam.
+ * (private[ml]) Trait for shared param regParam.
*/
-@DeveloperApi
-trait HasRegParam extends Params {
+private[ml] trait HasRegParam extends Params {
/**
- * Param for regularization parameter.
+ * Param for regularization parameter (>= 0).
* @group param
*/
- final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+ final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getRegParam: Double = getOrDefault(regParam)
}
/**
- * :: DeveloperApi ::
- * Trait for shared param maxIter.
+ * (private[ml]) Trait for shared param maxIter.
*/
-@DeveloperApi
-trait HasMaxIter extends Params {
+private[ml] trait HasMaxIter extends Params {
/**
- * Param for max number of iterations.
+ * Param for max number of iterations (>= 0).
* @group param
*/
- final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+ final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getMaxIter: Int = getOrDefault(maxIter)
}
/**
- * :: DeveloperApi ::
- * Trait for shared param featuresCol (default: "features").
+ * (private[ml]) Trait for shared param featuresCol (default: "features").
*/
-@DeveloperApi
-trait HasFeaturesCol extends Params {
+private[ml] trait HasFeaturesCol extends Params {
/**
* Param for features column name.
@@ -79,11 +73,9 @@ trait HasFeaturesCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param labelCol (default: "label").
+ * (private[ml]) Trait for shared param labelCol (default: "label").
*/
-@DeveloperApi
-trait HasLabelCol extends Params {
+private[ml] trait HasLabelCol extends Params {
/**
* Param for label column name.
@@ -98,11 +90,9 @@ trait HasLabelCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param predictionCol (default: "prediction").
+ * (private[ml]) Trait for shared param predictionCol (default: "prediction").
*/
-@DeveloperApi
-trait HasPredictionCol extends Params {
+private[ml] trait HasPredictionCol extends Params {
/**
* Param for prediction column name.
@@ -117,11 +107,9 @@ trait HasPredictionCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param rawPredictionCol (default: "rawPrediction").
+ * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
*/
-@DeveloperApi
-trait HasRawPredictionCol extends Params {
+private[ml] trait HasRawPredictionCol extends Params {
/**
* Param for raw prediction (a.k.a. confidence) column name.
@@ -136,11 +124,9 @@ trait HasRawPredictionCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param probabilityCol (default: "probability").
+ * (private[ml]) Trait for shared param probabilityCol (default: "probability").
*/
-@DeveloperApi
-trait HasProbabilityCol extends Params {
+private[ml] trait HasProbabilityCol extends Params {
/**
* Param for column name for predicted class conditional probabilities.
@@ -155,28 +141,24 @@ trait HasProbabilityCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param threshold.
+ * (private[ml]) Trait for shared param threshold.
*/
-@DeveloperApi
-trait HasThreshold extends Params {
+private[ml] trait HasThreshold extends Params {
/**
- * Param for threshold in binary classification prediction.
+ * Param for threshold in binary classification prediction, in range [0, 1].
* @group param
*/
- final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction")
+ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
/** @group getParam */
final def getThreshold: Double = getOrDefault(threshold)
}
/**
- * :: DeveloperApi ::
- * Trait for shared param inputCol.
+ * (private[ml]) Trait for shared param inputCol.
*/
-@DeveloperApi
-trait HasInputCol extends Params {
+private[ml] trait HasInputCol extends Params {
/**
* Param for input column name.
@@ -189,11 +171,9 @@ trait HasInputCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param inputCols.
+ * (private[ml]) Trait for shared param inputCols.
*/
-@DeveloperApi
-trait HasInputCols extends Params {
+private[ml] trait HasInputCols extends Params {
/**
* Param for input column names.
@@ -206,11 +186,9 @@ trait HasInputCols extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param outputCol.
+ * (private[ml]) Trait for shared param outputCol.
*/
-@DeveloperApi
-trait HasOutputCol extends Params {
+private[ml] trait HasOutputCol extends Params {
/**
* Param for output column name.
@@ -223,28 +201,24 @@ trait HasOutputCol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param checkpointInterval.
+ * (private[ml]) Trait for shared param checkpointInterval.
*/
-@DeveloperApi
-trait HasCheckpointInterval extends Params {
+private[ml] trait HasCheckpointInterval extends Params {
/**
- * Param for checkpoint interval.
+ * Param for checkpoint interval (>= 1).
* @group param
*/
- final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
/** @group getParam */
final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
}
/**
- * :: DeveloperApi ::
- * Trait for shared param fitIntercept (default: true).
+ * (private[ml]) Trait for shared param fitIntercept (default: true).
*/
-@DeveloperApi
-trait HasFitIntercept extends Params {
+private[ml] trait HasFitIntercept extends Params {
/**
* Param for whether to fit an intercept term.
@@ -259,11 +233,9 @@ trait HasFitIntercept extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param seed (default: Utils.random.nextLong()).
+ * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()).
*/
-@DeveloperApi
-trait HasSeed extends Params {
+private[ml] trait HasSeed extends Params {
/**
* Param for random seed.
@@ -278,28 +250,24 @@ trait HasSeed extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param elasticNetParam.
+ * (private[ml]) Trait for shared param elasticNetParam.
*/
-@DeveloperApi
-trait HasElasticNetParam extends Params {
+private[ml] trait HasElasticNetParam extends Params {
/**
- * Param for the ElasticNet mixing parameter.
+ * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* @group param
*/
- final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+ final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
/** @group getParam */
final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
}
/**
- * :: DeveloperApi ::
- * Trait for shared param tol.
+ * (private[ml]) Trait for shared param tol.
*/
-@DeveloperApi
-trait HasTol extends Params {
+private[ml] trait HasTol extends Params {
/**
* Param for the convergence tolerance for iterative algorithms.
@@ -312,11 +280,9 @@ trait HasTol extends Params {
}
/**
- * :: DeveloperApi ::
- * Trait for shared param stepSize.
+ * (private[ml]) Trait for shared param stepSize.
*/
-@DeveloperApi
-trait HasStepSize extends Params {
+private[ml] trait HasStepSize extends Params {
/**
* Param for step size to be used for each iteration of optimization.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index bd793beba3..f9f2b2764d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -52,35 +52,40 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
with HasPredictionCol with HasCheckpointInterval {
/**
- * Param for rank of the matrix factorization.
+ * Param for rank of the matrix factorization (>= 1).
+ * Default: 10
* @group param
*/
- val rank = new IntParam(this, "rank", "rank of the factorization")
+ val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))
/** @group getParam */
def getRank: Int = getOrDefault(rank)
/**
- * Param for number of user blocks.
+ * Param for number of user blocks (>= 1).
+ * Default: 10
* @group param
*/
- val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks")
+ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
+ ParamValidators.gtEq(1))
/** @group getParam */
def getNumUserBlocks: Int = getOrDefault(numUserBlocks)
/**
- * Param for number of item blocks.
+ * Param for number of item blocks (>= 1).
+ * Default: 10
* @group param
*/
- val numItemBlocks =
- new IntParam(this, "numItemBlocks", "number of item blocks")
+ val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
+ ParamValidators.gtEq(1))
/** @group getParam */
def getNumItemBlocks: Int = getOrDefault(numItemBlocks)
/**
* Param to decide whether to use implicit preference.
+ * Default: false
* @group param
*/
val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference")
@@ -89,16 +94,19 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs)
/**
- * Param for the alpha parameter in the implicit preference formulation.
+ * Param for the alpha parameter in the implicit preference formulation (>= 0).
+ * Default: 1.0
* @group param
*/
- val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference")
+ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference",
+ ParamValidators.gtEq(0))
/** @group getParam */
def getAlpha: Double = getOrDefault(alpha)
/**
* Param for the column name for user ids.
+ * Default: "user"
* @group param
*/
val userCol = new Param[String](this, "userCol", "column name for user ids")
@@ -108,6 +116,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
/**
* Param for the column name for item ids.
+ * Default: "item"
* @group param
*/
val itemCol = new Param[String](this, "itemCol", "column name for item ids")
@@ -117,6 +126,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
/**
* Param for the column name for ratings.
+ * Default: "rating"
* @group param
*/
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")
@@ -126,6 +136,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
/**
* Param for whether to apply nonnegativity constraints.
+ * Default: false
* @group param
*/
val nonnegative = new BooleanParam(
@@ -136,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
- ratingCol -> "rating", nonnegative -> false)
+ ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
/**
* Validates and transforms the input schema.
@@ -281,10 +292,6 @@ class ALS extends Estimator[ALSModel] with ALSParams {
this
}
- setMaxIter(20)
- setRegParam(1.0)
- setCheckpointInterval(10)
-
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = extractParamMap(paramMap)
val ratings = dataset
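
With validators attached to the ALS params above, out-of-range values now fail at set time
instead of surfacing later during fit. A minimal sketch of the fail-fast behavior, using the
setters this patch defines (the commented-out call is illustrative, not from the patch):

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setRank(10)          // ok: ParamValidators.gtEq(1)
      .setNumUserBlocks(4)  // ok: ParamValidators.gtEq(1)
      .setAlpha(1.0)        // ok: ParamValidators.gtEq(0)
    // als.setRank(0)       // would throw IllegalArgumentException immediately
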
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index c784cf39ed..76c9837693 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -102,21 +102,16 @@ final class GBTRegressor
*/
val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
" tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+ s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
setDefault(lossType -> "squared")
/** @group setParam */
- def setLossType(value: String): this.type = {
- val lossStr = value.toLowerCase
- require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" +
- s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}")
- set(lossType, lossStr)
- this
- }
+ def setLossType(value: String): this.type = set(lossType, value)
/** @group getParam */
- def getLossType: String = getOrDefault(lossType)
+ def getLossType: String = getOrDefault(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
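
The per-setter require is gone: the validator on lossType lower-cases the candidate before
checking membership in supportedLossTypes, and the getter normalizes case on the way out. A
minimal sketch of the resulting behavior (loss names taken from this file's supported set):

    import org.apache.spark.ml.regression.GBTRegressor

    val gbt = new GBTRegressor().setLossType("Squared") // accepted: validator compares lower-cased
    assert(gbt.getLossType == "squared")                // getter returns the normalized form
    // new GBTRegressor().setLossType("huber")          // would throw: not a supported loss
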
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index cc9ad22cb8..11c6cea0f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,7 +25,8 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction}
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared.{HasTol, HasElasticNetParam, HasMaxIter,
+ HasRegParam}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
@@ -46,6 +47,16 @@ private[regression] trait LinearRegressionParams extends RegressorParams
* :: AlphaComponent ::
*
* Linear regression.
+ *
+ * The learning objective is to minimize the squared error, with regularization.
+ * The specific squared error loss function used is:
+ * L = 1/(2n) ||A weights - y||^2^
+ *
+ * This supports multiple types of regularization:
+ * - none (a.k.a. ordinary least squares)
+ * - L2 (ridge regression)
+ * - L1 (Lasso)
+ * - L2 + L1 (elastic net)
*/
@AlphaComponent
class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
@@ -135,7 +146,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
} else {
- new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+ new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam,
+ paramMap(tol))
}
val initialWeights = Vectors.zeros(numFeatures)
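
The optimizer switch above keys off how regParam is split between the two penalties. As a
sketch of the elastic net blend (alpha standing in for elasticNetParam, lambda for regParam;
effectiveL2RegParam is a hypothetical name mirroring the diff's effectiveL1RegParam):

    val alpha = 0.5   // elasticNetParam, in [0, 1]
    val lambda = 0.1  // regParam
    // penalty = alpha * lambda * ||w||_1  +  (1 - alpha) * lambda / 2 * ||w||_2^2
    val effectiveL1RegParam = alpha * lambda         // passed to OWLQN when alpha > 0
    val effectiveL2RegParam = (1.0 - alpha) * lambda // folded into the smooth objective
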
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 4bb4ed813c..d1ad0893cd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -22,7 +22,7 @@ import com.github.fommil.netlib.F2jBLAS
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -61,10 +61,12 @@ private[ml] trait CrossValidatorParams extends Params {
def getEvaluator: Evaluator = getOrDefault(evaluator)
/**
- * param for number of folds for cross validation
+ * Param for number of folds for cross validation. Must be >= 2.
+ * Default: 3
* @group param
*/
- val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation")
+ val numFolds: IntParam = new IntParam(this, "numFolds",
+ "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2))
/** @group getParam */
def getNumFolds: Int = getOrDefault(numFolds)
@@ -93,6 +95,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
/** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
+ override def validate(paramMap: ParamMap): Unit = {
+ getEstimatorParamMaps.foreach { eMap =>
+ getEstimator.validate(eMap ++ paramMap)
+ }
+ }
+
override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
val map = extractParamMap(paramMap)
val schema = dataset.schema
@@ -101,8 +109,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val est = map(estimator)
val eval = map(evaluator)
val epm = map(estimatorParamMaps)
- val numModels = epm.size
- val metrics = new Array[Double](epm.size)
+ val numModels = epm.length
+ val metrics = new Array[Double](epm.length)
val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
@@ -148,6 +156,10 @@ class CrossValidatorModel private[ml] (
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
+ override def validate(paramMap: ParamMap): Unit = {
+ bestModel.validate(paramMap)
+ }
+
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
bestModel.transform(dataset, paramMap)
}
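
The two overrides make validation recursive: a CrossValidator checks its estimator against
every grid point up front, and a fitted CrossValidatorModel defers to its bestModel. A usage
sketch (lr, paramGrid, and eval are placeholders for an estimator, an Array[ParamMap], and an
Evaluator):

    import org.apache.spark.ml.tuning.CrossValidator

    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(eval)
      .setNumFolds(3)   // must satisfy ParamValidators.gtEq(2)
    cv.validate()       // runs lr.validate(eMap ++ paramMap) for each grid point
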
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index deac390130..1f779584dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -180,7 +180,9 @@ object GradientBoostedTrees extends Logging {
val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
input.persist(StorageLevel.MEMORY_AND_DISK)
true
- } else false
+ } else {
+ false
+ }
timer.stop("init")
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
new file mode 100644
index 0000000000..e7df10dfa6
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * Test Param and related classes in Java
+ */
+public class JavaParamsSuite {
+
+ private transient JavaSparkContext jsc;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaParamsSuite");
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void testParams() {
+ JavaTestParams testParams = new JavaTestParams();
+ Assert.assertEquals(testParams.getMyIntParam(), 1);
+ testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
+ Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
+ Assert.assertEquals(testParams.getMyStringParam(), "a");
+ }
+
+ @Test
+ public void testParamValidate() {
+ ParamValidators.gt(1.0);
+ ParamValidators.gtEq(1.0);
+ ParamValidators.lt(1.0);
+ ParamValidators.ltEq(1.0);
+ ParamValidators.inRange(0, 1, true, false);
+ ParamValidators.inRange(0, 1);
+ ParamValidators.inArray(Lists.newArrayList(0, 1, 3));
+ ParamValidators.inArray(Lists.newArrayList("a", "b"));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
new file mode 100644
index 0000000000..8abe575610
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+/**
+ * A subclass of Params for testing.
+ */
+public class JavaTestParams extends JavaParams {
+
+ public IntParam myIntParam;
+
+ public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
+
+ public JavaTestParams setMyIntParam(int value) {
+ set(myIntParam, value); return this;
+ }
+
+ public DoubleParam myDoubleParam;
+
+ public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
+
+ public JavaTestParams setMyDoubleParam(double value) {
+ set(myDoubleParam, value); return this;
+ }
+
+ public Param<String> myStringParam;
+
+ public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
+
+ public JavaTestParams setMyStringParam(String value) {
+ set(myStringParam, value); return this;
+ }
+
+ public JavaTestParams() {
+ myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+ myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
+ ParamValidators.inRange(0.0, 1.0));
+ List<String> validStrings = Lists.newArrayList("a", "b");
+ myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+ ParamValidators.inArray(validStrings));
+ setDefault(myIntParam, 1);
+ setDefault(myDoubleParam, 0.5);
+ }
+}
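
Because the validators travel with the params themselves, the same fail-fast checks apply when
this Java test class is driven from Scala. A short sketch against the class above:

    val p = new JavaTestParams()
    p.setMyIntParam(2)        // ok: ParamValidators.gt(0)
    p.setMyDoubleParam(0.4)   // ok: ParamValidators.inRange(0.0, 1.0)
    p.setMyStringParam("a")   // ok: member of the validStrings list
    // p.setMyIntParam(0)     // would throw IllegalArgumentException
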
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 88ea679eea..f8852606ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -26,14 +26,22 @@ class ParamsSuite extends FunSuite {
import solver.{maxIter, inputCol}
assert(maxIter.name === "maxIter")
- assert(maxIter.doc === "max number of iterations")
+ assert(maxIter.doc === "max number of iterations (>= 0)")
assert(maxIter.parent.eq(solver))
- assert(maxIter.toString === "maxIter: max number of iterations (default: 10)")
+ assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)")
+ assert(!maxIter.isValid(-1))
+ assert(maxIter.isValid(0))
+ assert(maxIter.isValid(1))
solver.setMaxIter(5)
- assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)")
+ assert(maxIter.toString ===
+ "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
assert(inputCol.toString === "inputCol: input column name (undefined)")
+
+ intercept[IllegalArgumentException] {
+ solver.setMaxIter(-1)
+ }
}
test("param pair") {
@@ -47,6 +55,9 @@ class ParamsSuite extends FunSuite {
assert(pair.param.eq(maxIter))
assert(pair.value === 5)
}
+ intercept[IllegalArgumentException] {
+ val pair = maxIter -> -1
+ }
}
test("param map") {
@@ -59,6 +70,9 @@ class ParamsSuite extends FunSuite {
map0.put(maxIter, 10)
assert(map0.contains(maxIter))
assert(map0(maxIter) === 10)
+ intercept[IllegalArgumentException] {
+ map0.put(maxIter, -1)
+ }
assert(!map0.contains(inputCol))
intercept[NoSuchElementException] {
@@ -122,14 +136,57 @@ class ParamsSuite extends FunSuite {
assert(solver.getInputCol === "input")
solver.validate()
intercept[IllegalArgumentException] {
- solver.validate(ParamMap(maxIter -> -10))
+ ParamMap(maxIter -> -10)
}
- solver.setMaxIter(-10)
intercept[IllegalArgumentException] {
- solver.validate()
+ solver.setMaxIter(-10)
}
solver.clearMaxIter()
assert(!solver.isSet(maxIter))
}
+
+ test("ParamValidate") {
+ val alwaysTrue = ParamValidators.alwaysTrue[Int]
+ assert(alwaysTrue(1))
+
+ val gt1Int = ParamValidators.gt[Int](1)
+ assert(!gt1Int(1) && gt1Int(2))
+ val gt1Double = ParamValidators.gt[Double](1)
+ assert(!gt1Double(1.0) && gt1Double(1.1))
+
+ val gtEq1Int = ParamValidators.gtEq[Int](1)
+ assert(!gtEq1Int(0) && gtEq1Int(1))
+ val gtEq1Double = ParamValidators.gtEq[Double](1)
+ assert(!gtEq1Double(0.9) && gtEq1Double(1.0))
+
+ val lt1Int = ParamValidators.lt[Int](1)
+ assert(lt1Int(0) && !lt1Int(1))
+ val lt1Double = ParamValidators.lt[Double](1)
+ assert(lt1Double(0.9) && !lt1Double(1.0))
+
+ val ltEq1Int = ParamValidators.ltEq[Int](1)
+ assert(ltEq1Int(1) && !ltEq1Int(2))
+ val ltEq1Double = ParamValidators.ltEq[Double](1)
+ assert(ltEq1Double(1.0) && !ltEq1Double(1.1))
+
+ val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2)
+ assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) &&
+ !inRange02IntInclusive(-1) && !inRange02IntInclusive(3))
+ val inRange02IntExclusive =
+ ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false)
+ assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2))
+
+ val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2)
+ assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) &&
+ inRange02DoubleInclusive(2) &&
+ !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1))
+ val inRange02DoubleExclusive =
+ ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false)
+ assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) &&
+ !inRange02DoubleExclusive(2))
+
+ val inArray = ParamValidators.inArray[Int](Array(1, 2))
+ assert(inArray(1) && inArray(2) && !inArray(0))
+ }
}
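
The setter, ParamPair, and ParamMap.put failures exercised earlier in this suite all funnel
through one check: a ParamPair (and therefore ParamMap.put and Params.set) consults the
param's isValid predicate before storing the value. A simplified sketch of that flow,
condensed from params.scala in this patch (error message abbreviated):

    class Param[T](val parent: Params, val name: String, val doc: String,
        val isValid: T => Boolean) {
      def validate(value: T): Unit = {
        if (!isValid(value)) {
          throw new IllegalArgumentException(
            s"$parent parameter $name given invalid value $value.")
        }
      }
    }

    case class ParamPair[T](param: Param[T], value: T) {
      // runs at construction, so `maxIter -> -1` throws immediately
      param.validate(value)
    }
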
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 641b64b42a..6f9c9cb516 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -29,7 +29,7 @@ class TestParams extends Params with HasMaxIter with HasInputCol {
override def validate(paramMap: ParamMap): Unit = {
val m = extractParamMap(paramMap)
- require(m(maxIter) >= 0)
+ // Note: maxIter is validated when it is set.
require(m.contains(inputCol))
}