From 114bad606e7a17f980ea6c99e31c8ab0179fec2e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 29 Apr 2015 17:26:46 -0700 Subject: [SPARK-7176] [ML] Add validation functionality to Param Main change: Added isValid field to Param. Modified all usages to use isValid when relevant. Added helper methods in ParamValidate. Also overrode Params.validate() in: * CrossValidator + model * Pipeline + model I made a few updates for the elastic net patch: * I changed "tol" to "convergenceTol" * I added some documentation This PR is Scala + Java only. Python will be in a follow-up PR. CC: mengxr Author: Joseph K. Bradley Closes #5740 from jkbradley/enforce-validate and squashes the following commits: ad9c6c1 [Joseph K. Bradley] re-generated sharedParams after merging with current master 76415e8 [Joseph K. Bradley] reverted convergenceTol to tol af62f4b [Joseph K. Bradley] Removed changes to SparkBuild, python linalg. Fixed test failures. Renamed ParamValidate to ParamValidators. Removed explicit type from ParamValidators calls where possible. bb2665a [Joseph K. Bradley] merged with elastic net pr ecda302 [Joseph K. Bradley] fix rat tests, plus add a little doc 6895dfc [Joseph K. Bradley] small cleanups 069ac6d [Joseph K. Bradley] many cleanups 928fb84 [Joseph K. Bradley] Maybe done a910ac7 [Joseph K. Bradley] still workin 6d60e2e [Joseph K. Bradley] Still workin b987319 [Joseph K. Bradley] Partly done with adding checks, but blocking on adding checking functionality to Param dbc9fb2 [Joseph K. Bradley] merged with master. enforcing Params.validate --- .../spark/examples/ml/JavaDeveloperApiExample.java | 14 +- .../main/scala/org/apache/spark/ml/Pipeline.scala | 19 ++- .../spark/ml/classification/GBTClassifier.scala | 13 +- .../org/apache/spark/ml/feature/HashingTF.scala | 12 +- .../org/apache/spark/ml/feature/Normalizer.scala | 11 +- .../spark/ml/feature/PolynomialExpansion.scala | 9 +- .../apache/spark/ml/feature/StandardScaler.scala | 10 +- .../org/apache/spark/ml/feature/Tokenizer.scala | 20 +-- .../apache/spark/ml/feature/VectorIndexer.scala | 18 +-- .../org/apache/spark/ml/impl/tree/treeParams.scala | 115 +++++-------- .../scala/org/apache/spark/ml/param/params.scala | 179 +++++++++++++++++++-- .../ml/param/shared/SharedParamsCodeGen.scala | 35 ++-- .../spark/ml/param/shared/sharedParams.scala | 122 +++++--------- .../org/apache/spark/ml/recommendation/ALS.scala | 35 ++-- .../apache/spark/ml/regression/GBTRegressor.scala | 13 +- .../spark/ml/regression/LinearRegression.scala | 16 +- .../apache/spark/ml/tuning/CrossValidator.scala | 22 ++- .../spark/mllib/tree/GradientBoostedTrees.scala | 4 +- .../org/apache/spark/ml/param/JavaParamsSuite.java | 66 ++++++++ .../org/apache/spark/ml/param/JavaTestParams.java | 63 ++++++++ .../org/apache/spark/ml/param/ParamsSuite.scala | 69 +++++++- .../org/apache/spark/ml/param/TestParams.scala | 2 +- 22 files changed, 593 insertions(+), 274 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index eaf00d09f5..46377a99c4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -28,7 +28,6 @@ import 
org.apache.spark.ml.classification.Classifier; import org.apache.spark.ml.classification.ClassificationModel; import org.apache.spark.ml.param.IntParam; import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.param.Params; import org.apache.spark.ml.param.Params$; import org.apache.spark.mllib.linalg.BLAS; import org.apache.spark.mllib.linalg.Vector; @@ -100,11 +99,12 @@ public class JavaDeveloperApiExample { /** * Example of defining a type of {@link Classifier}. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to + * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. + * However, this should still compile and run successfully. */ class MyJavaLogisticRegression - extends Classifier - implements Params { + extends Classifier { /** * Param for max number of iterations @@ -145,10 +145,12 @@ class MyJavaLogisticRegression /** * Example of defining a type of {@link ClassificationModel}. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to + * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. + * However, this should still compile and run successfully. */ class MyJavaLogisticRegressionModel - extends ClassificationModel implements Params { + extends ClassificationModel { private MyJavaLogisticRegression parent_; public MyJavaLogisticRegression parent() { return parent_; } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 8eddf79cdf..6bfeecd764 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.{Params, Param, ParamMap} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -86,6 +86,14 @@ class Pipeline extends Estimator[PipelineModel] { def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } def getStages: Array[PipelineStage] = getOrDefault(stages) + override def validate(paramMap: ParamMap): Unit = { + val map = extractParamMap(paramMap) + getStages.foreach { + case pStage: Params => pStage.validate(map) + case _ => + } + } + /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. 
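For illustration, a minimal sketch of what the cascading check above gives user code; the stages (tokenizer, hashingTF, lr) are hypothetical, already-configured instances, not names from this patch:

    // Validate an assembled pipeline up front: validate() merges the supplied
    // ParamMap into each stage's params and calls stage.validate(map) on every
    // stage that implements Params.
    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
    pipeline.validate()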
@@ -140,7 +148,7 @@ class Pipeline extends Estimator[PipelineModel] {
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = extractParamMap(paramMap)
     val theStages = map(stages)
-    require(theStages.toSet.size == theStages.size,
+    require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
   }
@@ -157,6 +165,11 @@ class PipelineModel private[ml] (
     private[ml] val stages: Array[Transformer])
   extends Model[PipelineModel] with Logging {
 
+  override def validate(paramMap: ParamMap): Unit = {
+    val map = fittingParamMap ++ extractParamMap(paramMap)
+    stages.foreach(_.validate(map))
+  }
+
   /**
    * Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
    * estimator does not exist in the pipeline.
@@ -168,7 +181,7 @@ class PipelineModel private[ml] (
     }
     if (matched.isEmpty) {
       throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
-    } else if (matched.size > 1) {
+    } else if (matched.length > 1) {
       throw new IllegalStateException(s"Cannot have duplicate estimators in the same pipeline.")
     } else {
       matched.head.asInstanceOf[M]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index d2e052fbbb..3d849867d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -103,21 +103,16 @@ final class GBTClassifier
    */
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
 
   setDefault(lossType -> "logistic")
 
   /** @group setParam */
-  def setLossType(value: String): this.type = {
-    val lossStr = value.toLowerCase
-    require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" +
-      s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}")
-    set(lossType, lossStr)
-    this
-  }
+  def setLossType(value: String): this.type = set(lossType, value)
 
   /** @group getParam */
-  def getLossType: String = getOrDefault(lossType)
+  def getLossType: String = getOrDefault(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss.
*/ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index b20f2fc49a..0b3128f9ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.sql.types.DataType @@ -32,10 +32,14 @@ import org.apache.spark.sql.types.DataType class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { /** - * number of features + * Number of features. Should be > 0. + * (default = 2^18^) * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features") + val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)", + ParamValidators.gt(0)) + + setDefault(numFeatures -> (1 << 18)) /** @group getParam */ def getNumFeatures: Int = getOrDefault(numFeatures) @@ -43,8 +47,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - setDefault(numFeatures -> (1 << 18)) - override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index decaeb0da6..bd2b5f6067 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{DoubleParam, ParamMap} +import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.sql.types.DataType @@ -32,10 +32,13 @@ import org.apache.spark.sql.types.DataType class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { /** - * Normalization in L^p^ space, p = 2 by default. + * Normalization in L^p^ space. Must be >= 1. 
+   * (default: p = 2)
    * @group param
    */
-  val p = new DoubleParam(this, "p", "the p norm value")
+  val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1))
+
+  setDefault(p -> 2.0)
 
   /** @group getParam */
   def getP: Double = getOrDefault(p)
@@ -43,8 +46,6 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)
 
-  setDefault(p -> 2.0)
-
   override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
     val normalizer = new feature.Normalizer(paramMap(p))
     normalizer.transform
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index d855f04799..1b7c939c2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql.types.DataType
 
@@ -37,10 +37,13 @@ import org.apache.spark.sql.types.DataType
 class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
 
   /**
-   * The polynomial degree to expand, which should be larger than 1.
+   * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion.
+   * Default: 2
    * @group param
    */
-  val degree = new IntParam(this, "degree", "the polynomial degree to expand")
+  val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)",
+    ParamValidators.gtEq(1))
+
   setDefault(degree -> 2)
 
   /** @group getParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 447851ec03..a0e9ed32e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -31,17 +31,19 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * Params for [[StandardScaler]] and [[StandardScalerModel]].
  */
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
-
+
   /**
-   * False by default. Centers the data with mean before scaling.
+   * Centers the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
    * and will raise an exception.
+   * Default: false
    * @group param
    */
   val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
 
   /**
-   * True by default. Scales the data to unit standard deviation.
+   * Scales the data to unit standard deviation.
+   * Default: true
    * @group param
    */
   val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
@@ -56,7 +58,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
 class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
 
   setDefault(withMean -> false, withStd -> true)
-
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 376a004858..01752ba482 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
 
 /**
@@ -43,20 +43,20 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
 /**
  * :: AlphaComponent ::
  * A regex based tokenizer that extracts tokens either by repeatedly matching the regex (default)
- * or using it to split the text (set matching to false). Optional parameters also allow to fold
- * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length.
+ * or using it to split the text (set matching to false). Optional parameters also allow filtering
+ * tokens using a minimal length.
  * It returns an array of strings that can be empty.
- * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true,
- * lowercase = false, minTokenLength = 1
  */
 @AlphaComponent
 class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
 
   /**
-   * param for minimum token length, default is one to avoid returning empty strings
+   * Minimum token length, >= 0.
+   * Default: 1, to avoid returning empty strings
    * @group param
    */
-  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")
+  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
+    ParamValidators.gtEq(0))
 
   /** @group setParam */
   def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
@@ -65,7 +65,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getMinTokenLength: Int = getOrDefault(minTokenLength)
 
   /**
-   * param sets regex as splitting on gaps (true) or matching tokens (false)
+   * Indicates whether regex splits on gaps (true) or matching tokens (false).
+   * Default: false
    * @group param
    */
   val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
@@ -77,7 +78,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getGaps: Boolean = getOrDefault(gaps)
 
   /**
-   * param sets regex pattern used by tokenizer
+   * Regex pattern used by tokenizer.
+ * Default: `"\\p{L}+|[^\\p{L}\\s]+"` * @group param */ val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 1e5ffd15af..ed833c63c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, Attribute, AttributeGroup} -import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} import org.apache.spark.sql.{Row, DataFrame} @@ -37,17 +37,19 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu /** * Threshold for the number of values a categorical feature can take. * If a feature is found to have > maxCategories values, then it is declared continuous. + * Must be >= 2. * * (default = 20) */ val maxCategories = new IntParam(this, "maxCategories", - "Threshold for the number of values a categorical feature can take." + - " If a feature is found to have > maxCategories values, then it is declared continuous.") + "Threshold for the number of values a categorical feature can take (>= 2)." + + " If a feature is found to have > maxCategories values, then it is declared continuous.", + ParamValidators.gtEq(2)) + + setDefault(maxCategories -> 20) /** @group getParam */ def getMaxCategories: Int = getOrDefault(maxCategories) - - setDefault(maxCategories -> 20) } /** @@ -90,11 +92,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { /** @group setParam */ - def setMaxCategories(value: Int): this.type = { - require(value > 1, - s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.") - set(maxCategories, value) - } + def setMaxCategories(value: Int): this.type = set(maxCategories, value) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index ab6281b9b2..fb770622e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -38,14 +38,15 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} private[ml] trait DecisionTreeParams extends PredictorParams { /** - * Maximum depth of the tree. + * Maximum depth of the tree (>= 0). * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * (default = 5) * @group param */ final val maxDepth: IntParam = - new IntParam(this, "maxDepth", "Maximum depth of the tree." + - " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + new IntParam(this, "maxDepth", "Maximum depth of the tree. 
(>= 0)" + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", + ParamValidators.gtEq(0)) /** * Maximum number of bins used for discretizing continuous features and for choosing how to split @@ -56,7 +57,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { */ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + " discretizing continuous features. Must be >=2 and >= number of categories for any" + - " categorical feature.") + " categorical feature.", ParamValidators.gtEq(2)) /** * Minimum number of instances each child must have after split. @@ -69,7 +70,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + " number of instances each child must have after split. If a split causes the left or right" + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + - " Should be >= 1.") + " Should be >= 1.", ParamValidators.gtEq(1)) /** * Minimum information gain for a split to be considered at a tree node. @@ -85,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams { * @group expertParam */ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", - "Maximum memory in MB allocated to histogram aggregation.") + "Maximum memory in MB allocated to histogram aggregation.", + ParamValidators.gtEq(0)) /** * If false, the algorithm will pass trees to executors to match instances with nodes. @@ -111,34 +113,26 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext. Must be >= 1.") + " checkpoint directory is set in the SparkContext. Must be >= 1.", + ParamValidators.gtEq(1)) setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) /** @group setParam */ - def setMaxDepth(value: Int): this.type = { - require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") - set(maxDepth, value) - } + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group getParam */ final def getMaxDepth: Int = getOrDefault(maxDepth) /** @group setParam */ - def setMaxBins(value: Int): this.type = { - require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") - set(maxBins, value) - } + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group getParam */ final def getMaxBins: Int = getOrDefault(maxBins) /** @group setParam */ - def setMinInstancesPerNode(value: Int): this.type = { - require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") - set(minInstancesPerNode, value) - } + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group getParam */ final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) @@ -150,10 +144,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final def getMinInfoGain: Double = getOrDefault(minInfoGain) /** @group expertSetParam */ - def setMaxMemoryInMB(value: Int): this.type = { - require(value > 0, s"maxMemoryInMB parameter must be > 0. 
Given bad value: $value") - set(maxMemoryInMB, value) - } + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertGetParam */ final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) @@ -165,10 +156,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) /** @group expertSetParam */ - def setCheckpointInterval(value: Int): this.type = { - require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") - set(checkpointInterval, value) - } + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group expertGetParam */ final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) @@ -209,21 +197,16 @@ private[ml] trait TreeClassifierParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) setDefault(impurity -> "gini") /** @group setParam */ - def setImpurity(value: String): this.type = { - val impurityStr = value.toLowerCase - require(TreeClassifierParams.supportedImpurities.contains(impurityStr), - s"Tree-based classifier was given unrecognized impurity: $value." + - s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") - set(impurity, impurityStr) - } + def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity).toLowerCase /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -256,21 +239,16 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) setDefault(impurity -> "variance") /** @group setParam */ - def setImpurity(value: String): this.type = { - val impurityStr = value.toLowerCase - require(TreeRegressorParams.supportedImpurities.contains(impurityStr), - s"Tree-based regressor was given unrecognized impurity: $value." + - s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") - set(impurity, impurityStr) - } + def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity).toLowerCase /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -299,21 +277,18 @@ private[ml] object TreeRegressorParams { private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** - * Fraction of the training data used for learning each decision tree. + * Fraction of the training data used for learning each decision tree, in range (0, 1]. 
* (default = 1.0) * @group param */ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", - "Fraction of the training data used for learning each decision tree.") + "Fraction of the training data used for learning each decision tree, in range (0, 1].", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) setDefault(subsamplingRate -> 1.0) /** @group setParam */ - def setSubsamplingRate(value: Double): this.type = { - require(value > 0.0 && value <= 1.0, - s"Subsampling rate must be in range (0,1]. Bad rate: $value") - set(subsamplingRate, value) - } + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group getParam */ final def getSubsamplingRate: Double = getOrDefault(subsamplingRate) @@ -350,7 +325,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { * (default = 20) * @group param */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)") + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) /** * The number of features to consider for splits at each tree node. @@ -378,30 +354,23 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + (value: String) => + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") /** @group setParam */ - def setNumTrees(value: Int): this.type = { - require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.") - set(numTrees, value) - } + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ final def getNumTrees: Int = getOrDefault(numTrees) /** @group setParam */ - def setFeatureSubsetStrategy(value: String): this.type = { - val strategyStr = value.toLowerCase - require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr), - s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" + - s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") - set(featureSubsetStrategy, strategyStr) - } + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy) + final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy).toLowerCase } private[ml] object RandomForestParams { @@ -426,7 +395,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { * @group param */ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator") + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. 
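As a quick illustration of the (0, 1] convention used by subsamplingRate and stepSize above, a hedged sketch (the value name inUnitInterval is ours, for illustration only):

    // inRange with lowerInclusive = false, upperInclusive = true models (0, 1]:
    val inUnitInterval: Double => Boolean =
      ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)
    assert(inUnitInterval(0.5))   // strictly inside the interval
    assert(!inUnitInterval(0.0))  // lower bound excluded
    assert(inUnitInterval(1.0))   // upper bound included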
@@ -442,17 +412,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
   setDefault(maxIter -> 20, stepSize -> 0.1)
 
   /** @group setParam */
-  def setMaxIter(value: Int): this.type = {
-    require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.")
-    set(maxIter, value)
-  }
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   /** @group setParam */
-  def setStepSize(value: Double): this.type = {
-    require(value > 0.0 && value <= 1.0,
-      s"GBT given invalid step size ($value). Value should be in (0,1].")
-    set(stepSize, value)
-  }
+  def setStepSize(value: Double): this.type = set(stepSize, value)
 
   /** @group getParam */
   final def getStepSize: Double = getOrDefault(stepSize)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 014e124e44..df6360dce6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -34,10 +34,35 @@ import org.apache.spark.ml.util.Identifiable
 * @param parent parent object
 * @param name param name
 * @param doc documentation
+ * @param isValid optional validation method which indicates if a value is valid.
+ *                See [[ParamValidators]] for factory methods for common validation functions.
 * @tparam T param value type
 */
@AlphaComponent
-class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable {
+class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
+  extends Serializable {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue[T])
+
+  /**
+   * Assert that the given value is valid for this parameter.
+   *
+   * Note: Parameter checks involving interactions between multiple parameters should be
+   *       implemented in [[Params.validate()]].  Checks for input/output columns should be
+   *       implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+   *
+   * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
+   *             should be specified via [[ParamPair]].
+   *
+   * @throws IllegalArgumentException if the value is invalid
+   */
+  private[param] def validate(value: T): Unit = {
+    if (!isValid(value)) {
+      throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." +
+        s" Parameter description: $toString")
+    }
+  }
 
   /**
    * Creates a param pair with the given value (for Java).
@@ -65,38 +90,129 @@ class Param[T] (val parent: Params, val name: String, val doc: String) extends S
   }
 }
 
+/**
+ * Factory methods for common validation functions for [[Param.isValid]].
+ * The numerical methods only support Int, Long, Float, and Double.
+ */
+object ParamValidators {
+
+  /** (private[param]) Default validation always returns true */
+  private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
+
+  /**
+   * Private method for checking numerical types and converting to Double.
+   * This is mainly for the sake of compilation; type checks are really handled
+   * by [[Params]] setters and the [[ParamPair]] constructor.
+   */
+  private def getDouble[T](value: T): Double = value match {
+    case x: Int => x.toDouble
+    case x: Long => x.toDouble
+    case x: Float => x.toDouble
+    case x: Double => x.toDouble
+    case _ =>
+      // The type should be checked before this is ever called.
+      throw new IllegalArgumentException("Numerical Param validation failed because" +
+        s" of unexpected input type: ${value.getClass}")
+  }
+
+  /** Check if value > lowerBound */
+  def gt[T](lowerBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) > lowerBound
+  }
+
+  /** Check if value >= lowerBound */
+  def gtEq[T](lowerBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) >= lowerBound
+  }
+
+  /** Check if value < upperBound */
+  def lt[T](upperBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) < upperBound
+  }
+
+  /** Check if value <= upperBound */
+  def ltEq[T](upperBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) <= upperBound
+  }
+
+  /**
+   * Check for value in range lowerBound to upperBound.
+   * @param lowerInclusive If true, check for value >= lowerBound.
+   *                       If false, check for value > lowerBound.
+   * @param upperInclusive If true, check for value <= upperBound.
+   *                       If false, check for value < upperBound.
+   */
+  def inRange[T](
+      lowerBound: Double,
+      upperBound: Double,
+      lowerInclusive: Boolean,
+      upperInclusive: Boolean): T => Boolean = { (value: T) =>
+    val x: Double = getDouble(value)
+    val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound
+    val upperValid = if (upperInclusive) x <= upperBound else x < upperBound
+    lowerValid && upperValid
+  }
+
+  /** Version of [[inRange()]] which is inclusive by default: [lowerBound, upperBound] */
+  def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = {
+    inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true)
+  }
+
+  /** Check for value in an allowed set of values. */
+  def inArray[T](allowed: Array[T]): T => Boolean = { (value: T) =>
+    allowed.contains(value)
+  }
+
+  /** Check for value in an allowed set of values. */
+  def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
+    allowed.contains(value)
+  }
+}
+
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
 
 /** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String)
-  extends Param[Double](parent, name, doc) {
+class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean)
+  extends Param[Double](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Double): ParamPair[Double] = super.w(value)
 }
 
 /** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String)
-  extends Param[Int](parent, name, doc) {
+class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean)
+  extends Param[Int](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Int): ParamPair[Int] = super.w(value)
 }
 
 /** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String)
-  extends Param[Float](parent, name, doc) {
+class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean)
+  extends Param[Float](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Float): ParamPair[Float] = super.w(value)
 }
 
 /** Specialized version of [[Param[Long]]] for Java.
 */
-class LongParam(parent: Params, name: String, doc: String)
-  extends Param[Long](parent, name, doc) {
+class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean)
+  extends Param[Long](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Long): ParamPair[Long] = super.w(value)
 }
 
 /** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String)
+class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid
   extends Param[Boolean](parent, name, doc) {
 
   override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
@@ -105,7 +221,11 @@ class BooleanParam(parent: Params, name: String, doc: String)
 /**
  * A param and its value.
 */
-case class ParamPair[T](param: Param[T], value: T)
+case class ParamPair[T](param: Param[T], value: T) {
+  // This is *the* place Param.validate is called.  Whenever a parameter is specified, we should
+  // always construct a ParamPair so that validate is called.
+  param.validate(value)
+}
 
 /**
  * :: AlphaComponent ::
@@ -132,12 +252,22 @@ trait Params extends Identifiable with Serializable {
   /**
    * Validates parameter values stored internally plus the input parameter map.
    * Raises an exception if any parameter is invalid.
+   *
+   * This only needs to check for interactions between parameters.
+   * Parameter value checks which do not depend on other parameters are handled by
+   * [[Param.validate()]].  This method does not handle input/output column parameters;
+   * those are checked during schema validation.
   */
-  def validate(paramMap: ParamMap): Unit = {}
+  def validate(paramMap: ParamMap): Unit = { }
 
   /**
    * Validates parameter values stored internally.
    * Raises an exception if any parameter value is invalid.
+   *
+   * This only needs to check for interactions between parameters.
+   * Parameter value checks which do not depend on other parameters are handled by
+   * [[Param.validate()]].  This method does not handle input/output column parameters;
+   * those are checked during schema validation.
   */
  def validate(): Unit = validate(ParamMap.empty)
 
@@ -221,6 +351,10 @@ trait Params extends Identifiable with Serializable {
 
  /**
   * Sets default values for a list of params.
+   *
+   * Note: Java developers should use the single-parameter [[setDefault()]].
+   *       Annotating this with varargs causes compilation failures.
+   *
   * @param paramPairs  a list of param pairs that specify params and their default values to set
   *                    respectively. Make sure that the params are initialized before this method
   *                    gets called.
@@ -305,6 +439,14 @@ private[spark] object Params {
   }
 }
 
+/**
+ * Java-friendly wrapper for [[Params]].
+ * Java developers who need to extend [[Params]] should use this class instead.
+ * If you need to extend an abstract class which already extends [[Params]], then that abstract
+ * class should be Java-friendly as well.
+ */
+abstract class JavaParams extends Params
+
 /**
  * :: AlphaComponent ::
  * A param to value map.
@@ -313,6 +455,12 @@ private[spark] object Params {
 final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
   extends Serializable {
 
+  /* DEVELOPERS: About validating parameter values
+   *   This and ParamPair are the only two collections of parameters.
+   *   This class should always create ParamPairs when
+   *   specifying new parameter values.  ParamPair will then call Param.validate().
+   */
+
   /**
    * Creates an empty param map.
   */
@@ -321,10 +469,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
   /**
    * Puts a (param, value) pair (overwrites if the input param exists).
    */
-  def put[T](param: Param[T], value: T): this.type = {
-    map(param.asInstanceOf[Param[Any]]) = value
-    this
-  }
+  def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value))
 
   /**
    * Puts a list of param pairs (overwrites if the input params exist).
    */
   @varargs
   def put(paramPairs: ParamPair[_]*): this.type = {
     paramPairs.foreach { p =>
-      put(p.param.asInstanceOf[Param[Any]], p.value)
+      map(p.param.asInstanceOf[Param[Any]]) = p.value
     }
     this
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 654cd72d53..7da4bb4b4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -21,6 +21,8 @@ import java.io.PrintWriter
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.ml.param.ParamValidators
+
 /**
  * Code generator for shared params (sharedParams.scala). Run under the Spark folder with
  * {{{
@@ -31,8 +33,10 @@ private[shared] object SharedParamsCodeGen {
 
   def main(args: Array[String]): Unit = {
     val params = Seq(
-      ParamDesc[Double]("regParam", "regularization parameter"),
-      ParamDesc[Int]("maxIter", "max number of iterations"),
+      ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
+        isValid = "ParamValidators.gtEq(0)"),
+      ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+        isValid = "ParamValidators.gtEq(0)"),
       ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
       ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
       ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
@@ -40,14 +44,19 @@ private[shared] object SharedParamsCodeGen {
         Some("\"rawPrediction\"")),
       ParamDesc[String]("probabilityCol",
         "column name for predicted class conditional probabilities", Some("\"probability\"")),
-      ParamDesc[Double]("threshold", "threshold in binary classification prediction"),
+      ParamDesc[Double]("threshold",
+        "threshold in binary classification prediction, in range [0, 1]",
+        isValid = "ParamValidators.inRange(0, 1)"),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name"),
-      ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
+      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
+        isValid = "ParamValidators.gtEq(1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
-      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
+        " For alpha = 0, the penalty is an L2 penalty.
For alpha = 1, it is an L1 penalty.", + isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) @@ -62,7 +71,8 @@ private[shared] object SharedParamsCodeGen { private case class ParamDesc[T: ClassTag]( name: String, doc: String, - defaultValueStr: Option[String] = None) { + defaultValueStr: Option[String] = None, + isValid: String = "") { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(doc.nonEmpty) // TODO: more rigorous on doc @@ -113,20 +123,23 @@ private[shared] object SharedParamsCodeGen { | setDefault($name, $v) |""".stripMargin }.getOrElse("") + val isValid = if (param.isValid != "") { + ", " + param.isValid + } else { + "" + } s""" |/** - | * :: DeveloperApi :: - | * Trait for shared param $name$defaultValueDoc. + | * (private[ml]) Trait for shared param $name$defaultValueDoc. | */ - |@DeveloperApi - |trait Has$Name extends Params { + |private[ml] trait Has$Name extends Params { | | /** | * Param for $doc. | * @group param | */ - | final val $name: $Param = new $Param(this, "$name", "$doc") + | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group getParam */ | final def get$Name: $T = getOrDefault($name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 96d11ed76f..e1549f46a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -26,45 +26,39 @@ import org.apache.spark.util.Utils // scalastyle:off /** - * :: DeveloperApi :: - * Trait for shared param regParam. + * (private[ml]) Trait for shared param regParam. */ -@DeveloperApi -trait HasRegParam extends Params { +private[ml] trait HasRegParam extends Params { /** - * Param for regularization parameter. + * Param for regularization parameter (>= 0). * @group param */ - final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getRegParam: Double = getOrDefault(regParam) } /** - * :: DeveloperApi :: - * Trait for shared param maxIter. + * (private[ml]) Trait for shared param maxIter. */ -@DeveloperApi -trait HasMaxIter extends Params { +private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations. + * Param for max number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = getOrDefault(maxIter) } /** - * :: DeveloperApi :: - * Trait for shared param featuresCol (default: "features"). + * (private[ml]) Trait for shared param featuresCol (default: "features"). */ -@DeveloperApi -trait HasFeaturesCol extends Params { +private[ml] trait HasFeaturesCol extends Params { /** * Param for features column name. @@ -79,11 +73,9 @@ trait HasFeaturesCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param labelCol (default: "label"). 
+ * (private[ml]) Trait for shared param labelCol (default: "label"). */ -@DeveloperApi -trait HasLabelCol extends Params { +private[ml] trait HasLabelCol extends Params { /** * Param for label column name. @@ -98,11 +90,9 @@ trait HasLabelCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param predictionCol (default: "prediction"). + * (private[ml]) Trait for shared param predictionCol (default: "prediction"). */ -@DeveloperApi -trait HasPredictionCol extends Params { +private[ml] trait HasPredictionCol extends Params { /** * Param for prediction column name. @@ -117,11 +107,9 @@ trait HasPredictionCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param rawPredictionCol (default: "rawPrediction"). + * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction"). */ -@DeveloperApi -trait HasRawPredictionCol extends Params { +private[ml] trait HasRawPredictionCol extends Params { /** * Param for raw prediction (a.k.a. confidence) column name. @@ -136,11 +124,9 @@ trait HasRawPredictionCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param probabilityCol (default: "probability"). + * (private[ml]) Trait for shared param probabilityCol (default: "probability"). */ -@DeveloperApi -trait HasProbabilityCol extends Params { +private[ml] trait HasProbabilityCol extends Params { /** * Param for column name for predicted class conditional probabilities. @@ -155,28 +141,24 @@ trait HasProbabilityCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param threshold. + * (private[ml]) Trait for shared param threshold. */ -@DeveloperApi -trait HasThreshold extends Params { +private[ml] trait HasThreshold extends Params { /** - * Param for threshold in binary classification prediction. + * Param for threshold in binary classification prediction, in range [0, 1]. * @group param */ - final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getThreshold: Double = getOrDefault(threshold) } /** - * :: DeveloperApi :: - * Trait for shared param inputCol. + * (private[ml]) Trait for shared param inputCol. */ -@DeveloperApi -trait HasInputCol extends Params { +private[ml] trait HasInputCol extends Params { /** * Param for input column name. @@ -189,11 +171,9 @@ trait HasInputCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param inputCols. + * (private[ml]) Trait for shared param inputCols. */ -@DeveloperApi -trait HasInputCols extends Params { +private[ml] trait HasInputCols extends Params { /** * Param for input column names. @@ -206,11 +186,9 @@ trait HasInputCols extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param outputCol. + * (private[ml]) Trait for shared param outputCol. */ -@DeveloperApi -trait HasOutputCol extends Params { +private[ml] trait HasOutputCol extends Params { /** * Param for output column name. @@ -223,28 +201,24 @@ trait HasOutputCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param checkpointInterval. + * (private[ml]) Trait for shared param checkpointInterval. */ -@DeveloperApi -trait HasCheckpointInterval extends Params { +private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval. + * Param for checkpoint interval (>= 1). 
   * @group param
   */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
 
   /** @group getParam */
   final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param fitIntercept (default: true).
+ * (private[ml]) Trait for shared param fitIntercept (default: true).
 */
-@DeveloperApi
-trait HasFitIntercept extends Params {
+private[ml] trait HasFitIntercept extends Params {
 
   /**
    * Param for whether to fit an intercept term.
@@ -259,11 +233,9 @@ trait HasFitIntercept extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param seed (default: Utils.random.nextLong()).
+ * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()).
 */
-@DeveloperApi
-trait HasSeed extends Params {
+private[ml] trait HasSeed extends Params {
 
   /**
    * Param for random seed.
@@ -278,28 +250,24 @@ trait HasSeed extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param elasticNetParam.
+ * (private[ml]) Trait for shared param elasticNetParam.
 */
-@DeveloperApi
-trait HasElasticNetParam extends Params {
+private[ml] trait HasElasticNetParam extends Params {
 
   /**
-   * Param for the ElasticNet mixing parameter.
+   * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    * @group param
    */
-  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
   final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param tol.
+ * (private[ml]) Trait for shared param tol.
 */
-@DeveloperApi
-trait HasTol extends Params {
+private[ml] trait HasTol extends Params {
 
   /**
    * Param for the convergence tolerance for iterative algorithms.
@@ -312,11 +280,9 @@ trait HasTol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param stepSize.
+ * (private[ml]) Trait for shared param stepSize.
 */
-@DeveloperApi
-trait HasStepSize extends Params {
+private[ml] trait HasStepSize extends Params {
 
   /**
    * Param for Step size to be used for each iteration of optimization.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index bd793beba3..f9f2b2764d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -52,35 +52,40 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
     with HasPredictionCol with HasCheckpointInterval {
 
   /**
-   * Param for rank of the matrix factorization.
+   * Param for rank of the matrix factorization (>= 1).
+   * Default: 10
    * @group param
    */
-  val rank = new IntParam(this, "rank", "rank of the factorization")
+  val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))
 
   /** @group getParam */
   def getRank: Int = getOrDefault(rank)
 
   /**
-   * Param for number of user blocks.
+   * Param for number of user blocks (>= 1).
+ * Default: 10 * @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", + ParamValidators.gtEq(1)) /** @group getParam */ def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** - * Param for number of item blocks. + * Param for number of item blocks (>= 1). + * Default: 10 * @group param */ - val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks") + val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks", + ParamValidators.gtEq(1)) /** @group getParam */ def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. + * Default: false * @group param */ val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") @@ -89,16 +94,19 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** - * Param for the alpha parameter in the implicit preference formulation. + * Param for the alpha parameter in the implicit preference formulation (>= 0). + * Default: 1.0 * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", + ParamValidators.gtEq(0)) /** @group getParam */ def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. + * Default: "user" * @group param */ val userCol = new Param[String](this, "userCol", "column name for user ids") @@ -108,6 +116,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Param for the column name for item ids. + * Default: "item" * @group param */ val itemCol = new Param[String](this, "itemCol", "column name for item ids") @@ -117,6 +126,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Param for the column name for ratings. + * Default: "rating" * @group param */ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") @@ -126,6 +136,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Param for whether to apply nonnegativity constraints. + * Default: false * @group param */ val nonnegative = new BooleanParam( @@ -136,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", - ratingCol -> "rating", nonnegative -> false) + ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10) /** * Validates and transforms the input schema. 
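An illustrative sketch of the net effect of these ALS validators (the als value is hypothetical):

    val als = new ALS()
    als.setRank(10)   // accepted: 10 satisfies ParamValidators.gtEq(1)
    // als.setRank(0) would now throw IllegalArgumentException at set time, since
    // setting a param constructs a ParamPair, whose constructor calls Param.validate.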
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index c784cf39ed..76c9837693 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -102,21 +102,16 @@ final class GBTRegressor
    */
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+    s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))

   setDefault(lossType -> "squared")

   /** @group setParam */
-  def setLossType(value: String): this.type = {
-    val lossStr = value.toLowerCase
-    require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" +
-      s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}")
-    set(lossType, lossStr)
-    this
-  }
+  def setLossType(value: String): this.type = set(lossType, value)

   /** @group getParam */
-  def getLossType: String = getOrDefault(lossType)
+  def getLossType: String = getOrDefault(lossType).toLowerCase

   /** (private[ml]) Convert new loss to old loss. */
   override private[ml] def getOldLossType: OldLoss = {
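Moving the check into the Param means setLossType no longer needs its own require block: the validator lowercases the candidate value before the membership test, and getLossType normalizes on read. Roughly (assuming the supported loss types are "squared" and "absolute", per GBTRegressor.supportedLossTypes):

    import org.apache.spark.ml.regression.GBTRegressor

    val gbt = new GBTRegressor().setLossType("Squared")  // accepted case-insensitively
    assert(gbt.getLossType == "squared")                 // lowercased by the getter
    // gbt.setLossType("huber") would throw IllegalArgumentException at set time.
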
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index cc9ad22cb8..11c6cea0f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,7 +25,8 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction}

 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared.{HasTol, HasElasticNetParam, HasMaxIter,
+  HasRegParam}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -46,6 +47,16 @@ private[regression] trait LinearRegressionParams extends RegressorParams
  * :: AlphaComponent ::
  *
  * Linear regression.
+ *
+ * The learning objective is to minimize the squared error, with regularization.
+ * The specific squared error loss function used is:
+ *   L = 1/2n ||A weights - y||^2^
+ *
+ * This supports multiple types of regularization:
+ *  - none (a.k.a. ordinary least squares)
+ *  - L2 (ridge regression)
+ *  - L1 (Lasso)
+ *  - L2 + L1 (elastic net)
  */
 @AlphaComponent
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
@@ -135,7 +146,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
     } else {
-      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam,
+        paramMap(tol))
     }

     val initialWeights = Vectors.zeros(numFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 4bb4ed813c..d1ad0893cd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -22,7 +22,7 @@ import com.github.fommil.netlib.F2jBLAS
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml._
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -61,10 +61,12 @@ private[ml] trait CrossValidatorParams extends Params {
   def getEvaluator: Evaluator = getOrDefault(evaluator)

   /**
-   * param for number of folds for cross validation
+   * Param for number of folds for cross validation. Must be >= 2.
+   * Default: 3
    * @group param
    */
-  val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation")
+  val numFolds: IntParam = new IntParam(this, "numFolds",
+    "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2))

   /** @group getParam */
   def getNumFolds: Int = getOrDefault(numFolds)
@@ -93,6 +95,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
   /** @group setParam */
   def setNumFolds(value: Int): this.type = set(numFolds, value)

+  override def validate(paramMap: ParamMap): Unit = {
+    getEstimatorParamMaps.foreach { eMap =>
+      getEstimator.validate(eMap ++ paramMap)
+    }
+  }
+
   override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
     val map = extractParamMap(paramMap)
     val schema = dataset.schema
@@ -101,8 +109,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
     val est = map(estimator)
     val eval = map(evaluator)
     val epm = map(estimatorParamMaps)
-    val numModels = epm.size
-    val metrics = new Array[Double](epm.size)
+    val numModels = epm.length
+    val metrics = new Array[Double](epm.length)
     val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
@@ -148,6 +156,10 @@ class CrossValidatorModel private[ml] (
     val bestModel: Model[_])
   extends Model[CrossValidatorModel] with CrossValidatorParams {

+  override def validate(paramMap: ParamMap): Unit = {
+    bestModel.validate(paramMap)
+  }
+
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     bestModel.transform(dataset, paramMap)
   }
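CrossValidator.validate now fails fast: it replays each candidate ParamMap against the estimator's own validate() before any folds are trained, and CrossValidatorModel simply delegates to its best model. A sketch of the eager check (hypothetical wiring; assumes ParamGridBuilder and the LinearRegression estimator shown above):

    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val lr = new LinearRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.maxIter, Array(10, 50))
      .build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)
    // Runs lr.validate(eMap ++ paramMap) for every candidate map in the grid,
    // so estimator-specific problems surface before any training starts.
    cv.validate(ParamMap.empty)
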
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index deac390130..1f779584dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -180,7 +180,9 @@ object GradientBoostedTrees extends Logging {
     val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
       input.persist(StorageLevel.MEMORY_AND_DISK)
       true
-    } else false
+    } else {
+      false
+    }

     timer.stop("init")

diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
new file mode 100644
index 0000000000..e7df10dfa6
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * Test Param and related classes in Java
+ */
+public class JavaParamsSuite {
+
+  private transient JavaSparkContext jsc;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaParamsSuite");
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void testParams() {
+    JavaTestParams testParams = new JavaTestParams();
+    Assert.assertEquals(testParams.getMyIntParam(), 1);
+    testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
+    Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
+    Assert.assertEquals(testParams.getMyStringParam(), "a");
+  }
+
+  @Test
+  public void testParamValidate() {
+    ParamValidators.gt(1.0);
+    ParamValidators.gtEq(1.0);
+    ParamValidators.lt(1.0);
+    ParamValidators.ltEq(1.0);
+    ParamValidators.inRange(0, 1, true, false);
+    ParamValidators.inRange(0, 1);
+    ParamValidators.inArray(Lists.newArrayList(0, 1, 3));
+    ParamValidators.inArray(Lists.newArrayList("a", "b"));
+  }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
new file mode 100644
index 0000000000..8abe575610
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+/**
+ * A subclass of Params for testing.
+ */
+public class JavaTestParams extends JavaParams {
+
+  public IntParam myIntParam;
+
+  public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
+
+  public JavaTestParams setMyIntParam(int value) {
+    set(myIntParam, value); return this;
+  }
+
+  public DoubleParam myDoubleParam;
+
+  public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
+
+  public JavaTestParams setMyDoubleParam(double value) {
+    set(myDoubleParam, value); return this;
+  }
+
+  public Param<String> myStringParam;
+
+  public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
+
+  public JavaTestParams setMyStringParam(String value) {
+    set(myStringParam, value); return this;
+  }
+
+  public JavaTestParams() {
+    myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+    myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
+      ParamValidators.inRange(0.0, 1.0));
+    List<String> validStrings = Lists.newArrayList("a", "b");
+    myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+      ParamValidators.inArray(validStrings));
+    setDefault(myIntParam, 1);
+    setDefault(myDoubleParam, 0.5);
+  }
+}
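The Scala side exposes the same factories, and inArray also accepts a plain Array (exercised in ParamsSuite below), which keeps string-valued params self-validating. A hypothetical Scala counterpart of the string param above, for illustration only:

    import org.apache.spark.ml.param.{Param, ParamValidators, Params}

    class MyStringParams extends Params {
      // Only "a" or "b" can ever be bound to this param.
      val myStringParam: Param[String] = new Param[String](this, "myStringParam",
        "this is a string param", ParamValidators.inArray(Array("a", "b")))
      setDefault(myStringParam -> "a")
    }
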
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 88ea679eea..f8852606ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -26,14 +26,22 @@ class ParamsSuite extends FunSuite {
     import solver.{maxIter, inputCol}

     assert(maxIter.name === "maxIter")
-    assert(maxIter.doc === "max number of iterations")
+    assert(maxIter.doc === "max number of iterations (>= 0)")
     assert(maxIter.parent.eq(solver))
-    assert(maxIter.toString === "maxIter: max number of iterations (default: 10)")
+    assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)")
+    assert(!maxIter.isValid(-1))
+    assert(maxIter.isValid(0))
+    assert(maxIter.isValid(1))

     solver.setMaxIter(5)
-    assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)")
+    assert(maxIter.toString ===
+      "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")

     assert(inputCol.toString === "inputCol: input column name (undefined)")
+
+    intercept[IllegalArgumentException] {
+      solver.setMaxIter(-1)
+    }
   }

   test("param pair") {
@@ -47,6 +55,9 @@ class ParamsSuite extends FunSuite {
       assert(pair.param.eq(maxIter))
       assert(pair.value === 5)
     }
+    intercept[IllegalArgumentException] {
+      val pair = maxIter -> -1
+    }
   }

   test("param map") {
@@ -59,6 +70,9 @@ class ParamsSuite extends FunSuite {
     map0.put(maxIter, 10)
     assert(map0.contains(maxIter))
     assert(map0(maxIter) === 10)
+    intercept[IllegalArgumentException] {
+      map0.put(maxIter, -1)
+    }

     assert(!map0.contains(inputCol))
     intercept[NoSuchElementException] {
@@ -122,14 +136,57 @@ class ParamsSuite extends FunSuite {
     assert(solver.getInputCol === "input")
     solver.validate()
     intercept[IllegalArgumentException] {
-      solver.validate(ParamMap(maxIter -> -10))
+      ParamMap(maxIter -> -10)
     }
-    solver.setMaxIter(-10)
     intercept[IllegalArgumentException] {
-      solver.validate()
+      solver.setMaxIter(-10)
     }

     solver.clearMaxIter()
     assert(!solver.isSet(maxIter))
   }
+
+  test("ParamValidate") {
+    val alwaysTrue = ParamValidators.alwaysTrue[Int]
+    assert(alwaysTrue(1))
+
+    val gt1Int = ParamValidators.gt[Int](1)
+    assert(!gt1Int(1) && gt1Int(2))
+    val gt1Double = ParamValidators.gt[Double](1)
+    assert(!gt1Double(1.0) && gt1Double(1.1))
+
+    val gtEq1Int = ParamValidators.gtEq[Int](1)
+    assert(!gtEq1Int(0) && gtEq1Int(1))
+    val gtEq1Double = ParamValidators.gtEq[Double](1)
+    assert(!gtEq1Double(0.9) && gtEq1Double(1.0))
+
+    val lt1Int = ParamValidators.lt[Int](1)
+    assert(lt1Int(0) && !lt1Int(1))
+    val lt1Double = ParamValidators.lt[Double](1)
+    assert(lt1Double(0.9) && !lt1Double(1.0))
+
+    val ltEq1Int = ParamValidators.ltEq[Int](1)
+    assert(ltEq1Int(1) && !ltEq1Int(2))
+    val ltEq1Double = ParamValidators.ltEq[Double](1)
+    assert(ltEq1Double(1.0) && !ltEq1Double(1.1))
+
+    val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2)
+    assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) &&
+      !inRange02IntInclusive(-1) && !inRange02IntInclusive(3))
+    val inRange02IntExclusive =
+      ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false)
+    assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2))
+
+    val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2)
+    assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) &&
+      inRange02DoubleInclusive(2) &&
+      !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1))
+    val inRange02DoubleExclusive =
+      ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false)
+    assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) &&
+      !inRange02DoubleExclusive(2))
+
+    val inArray = ParamValidators.inArray[Int](Array(1, 2))
+    assert(inArray(1) && inArray(2) && !inArray(0))
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 641b64b42a..6f9c9cb516 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -29,7 +29,7 @@ class TestParams extends Params with HasMaxIter with HasInputCol {

   override def validate(paramMap: ParamMap): Unit = {
     val m = extractParamMap(paramMap)
-    require(m(maxIter) >= 0)
+    // Note: maxIter is validated when it is set.
     require(m.contains(inputCol))
   }
--
cgit v1.2.3