about summary refs log tree commit diff
path: root/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/param/params.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/params.scala 75
1 files changed, 44 insertions, 31 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index df6360dce6..51ce19d29c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -23,7 +23,7 @@ import java.util.NoSuchElementException
import scala.annotation.varargs
import scala.collection.mutable
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
/**
@@ -49,7 +49,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
* Assert that the given value is valid for this parameter.
*
* Note: Parameter checks involving interactions between multiple parameters should be
- * implemented in [[Params.validate()]]. Checks for input/output columns should be
+ * implemented in [[Params.validateParams()]]. Checks for input/output columns should be
* implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
*
* DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
@@ -258,7 +258,9 @@ trait Params extends Identifiable with Serializable {
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
*/
- def validate(paramMap: ParamMap): Unit = { }
+ def validateParams(paramMap: ParamMap): Unit = {
+ copy(paramMap).validateParams()
+ }
/**
* Validates parameter values stored internally.
@@ -269,7 +271,11 @@ trait Params extends Identifiable with Serializable {
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
*/
- def validate(): Unit = validate(ParamMap.empty)
+ def validateParams(): Unit = {
+ params.filter(isDefined _).foreach { param =>
+ param.asInstanceOf[Param[Any]].validate($(param))
+ }
+ }
/**
* Returns the documentation of all params.
@@ -288,6 +294,11 @@ trait Params extends Identifiable with Serializable {
defaultParamMap.contains(param) || paramMap.contains(param)
}
+ /** Tests whether this instance contains a param with a given name. */
+ def hasParam(paramName: String): Boolean = {
+ params.exists(_.name == paramName)
+ }
+
/** Gets a param by its name. */
def getParam(paramName: String): Param[Any] = {
params.find(_.name == paramName).getOrElse {
@@ -337,6 +348,9 @@ trait Params extends Identifiable with Serializable {
get(param).orElse(getDefault(param)).get
}
+ /** An alias for [[getOrDefault()]]. */
+ protected final def $[T](param: Param[T]): T = getOrDefault(param)
+
/**
* Sets a default value for a param.
* @param param param to set the default value. Make sure that this param is initialized before
@@ -383,18 +397,30 @@ trait Params extends Identifiable with Serializable {
}
/**
+ * Creates a copy of this instance with a randomly generated uid and some extra params.
+ * The default implementation calls the default constructor to create a new instance, then
+ * copies the embedded and extra parameters over and returns the new instance.
+ * Subclasses should override this method if the default approach is not sufficient.
+ */
+ def copy(extra: ParamMap): Params = {
+ val that = this.getClass.newInstance()
+ copyValues(that, extra)
+ that
+ }
+
+ /**
* Extracts the embedded default param values and user-supplied values, and then merges them with
* extra values from input into a flat param map, where the latter value is used if there exist
* conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap.
*/
- protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
+ final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
defaultParamMap ++ paramMap ++ extraParamMap
}
/**
* [[extractParamMap]] with no extra values.
*/
- protected final def extractParamMap(): ParamMap = {
+ final def extractParamMap(): ParamMap = {
extractParamMap(ParamMap.empty)
}
@@ -408,34 +434,21 @@ trait Params extends Identifiable with Serializable {
private def shouldOwn(param: Param[_]): Unit = {
require(param.parent.eq(this), s"Param $param does not belong to $this.")
}
-}
-/**
- * :: DeveloperApi ::
- *
- * Helper functionality for developers.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@DeveloperApi
-private[spark] object Params {
-
- /**
- * Copies parameter values from the parent estimator to the child model it produced.
- * @param paramMap the param map that holds parameters of the parent
- * @param parent the parent estimator
- * @param child the child model
- */
- def inheritValues[E <: Params, M <: E](
- paramMap: ParamMap,
- parent: E,
- child: M): Unit = {
- val childParams = child.params.map(_.name).toSet
- parent.params.foreach { param =>
- if (paramMap.contains(param) && childParams.contains(param.name)) {
- child.set(child.getParam(param.name), paramMap(param))
+ /**
+ * Copies param values from this instance to another instance for params shared by them.
+ * @param to the target instance
+ * @param extra extra params to be copied
+ * @return the target instance with param values copied
+ */
+ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
+ val map = extractParamMap(extra)
+ params.foreach { param =>
+ if (map.contains(param) && to.hasParam(param.name)) {
+ to.set(param.name, map(param))
}
}
+ to
}
}