From e3727c409fe7d1fb6e27a14faddd0602f963745e Mon Sep 17 00:00:00 2001 From: Takahashi Hiroshi Date: Wed, 20 Jan 2016 11:44:04 -0800 Subject: [SPARK-10263][ML] Add @Since annotation to ml.param and ml.* Add Since annotations to ml.param and ml.* Author: Takahashi Hiroshi Author: Hiroshi Takahashi Closes #8935 from taishi-oss/issue10263. --- .../main/scala/org/apache/spark/ml/Pipeline.scala | 21 ++++++++++++++--- .../scala/org/apache/spark/ml/param/params.scala | 26 ++++++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 32570a16e6..cbac7bbf49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -85,25 +85,32 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ +@Since("1.2.0") @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { +class Pipeline @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("pipeline")) /** * param for pipeline stages * @group param */ + @Since("1.2.0") val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") /** @group setParam */ + @Since("1.2.0") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. /** @group getParam */ + @Since("1.2.0") def getStages: Array[PipelineStage] = $(stages).clone() + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() $(stages).foreach(_.validateParams()) @@ -121,6 +128,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M * @param dataset input dataset * @return fitted pipeline */ + @Since("1.2.0") override def fit(dataset: DataFrame): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) @@ -158,12 +166,14 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M new PipelineModel(uid, transformers.toArray).setParent(this) } + @Since("1.4.0") override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) new Pipeline().setStages(newStages) } + @Since("1.2.0") override def transformSchema(schema: StructType): StructType = { validateParams() val theStages = $(stages) @@ -275,10 +285,11 @@ object Pipeline extends MLReadable[Pipeline] { * :: Experimental :: * Represents a fitted pipeline. */ +@Since("1.2.0") @Experimental class PipelineModel private[ml] ( - override val uid: String, - val stages: Array[Transformer]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") val stages: Array[Transformer]) extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. */ @@ -286,21 +297,25 @@ class PipelineModel private[ml] ( this(uid, stages.asScala.toArray) } + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) } + @Since("1.2.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) } + @Since("1.2.0") override def transformSchema(schema: StructType): StructType = { validateParams() stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } + @Since("1.4.0") override def copy(extra: ParamMap): PipelineModel = { new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c0546695e4..f48923d699 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -504,8 +504,11 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In * :: Experimental :: * A param and its value. */ +@Since("1.2.0") @Experimental -case class ParamPair[T](param: Param[T], value: T) { +case class ParamPair[T] @Since("1.2.0") ( + @Since("1.2.0") param: Param[T], + @Since("1.2.0") value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. param.validate(value) @@ -786,6 +789,7 @@ abstract class JavaParams extends Params * :: Experimental :: * A param to value map. */ +@Since("1.2.0") @Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -799,17 +803,20 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Creates an empty param map. */ + @Since("1.2.0") def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). */ + @Since("1.2.0") def put[T](param: Param[T], value: T): this.type = put(param -> value) /** * Puts a list of param pairs (overwrites if the input params exists). */ @varargs + @Since("1.2.0") def put(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => map(p.param.asInstanceOf[Param[Any]]) = p.value @@ -820,6 +827,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Optionally returns the value associated with a param. */ + @Since("1.2.0") def get[T](param: Param[T]): Option[T] = { map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } @@ -827,6 +835,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Returns the value associated with a param or a default value. */ + @Since("1.4.0") def getOrElse[T](param: Param[T], default: T): T = { get(param).getOrElse(default) } @@ -835,6 +844,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Gets the value of the input param or its default value if it does not exist. * Raises a NoSuchElementException if there is no value associated with the input param. */ + @Since("1.2.0") def apply[T](param: Param[T]): T = { get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") @@ -844,6 +854,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Checks whether a parameter is explicitly specified. */ + @Since("1.2.0") def contains(param: Param[_]): Boolean = { map.contains(param.asInstanceOf[Param[Any]]) } @@ -851,6 +862,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Removes a key from this map and returns its value associated previously as an option. */ + @Since("1.4.0") def remove[T](param: Param[T]): Option[T] = { map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } @@ -858,6 +870,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Filters this param map for the given parent. */ + @Since("1.2.0") def filter(parent: Params): ParamMap = { // Don't use filterKeys because mutable.Map#filterKeys // returns the instance of collections.Map, not mutable.Map. @@ -870,8 +883,10 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Creates a copy of this param map. */ + @Since("1.2.0") def copy: ParamMap = new ParamMap(map.clone()) + @Since("1.2.0") override def toString: String = { map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent}-${param.name}: $value" @@ -882,6 +897,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Returns a new param map that contains parameters in this map and the given map, * where the latter overwrites this if there exist conflicts. */ + @Since("1.2.0") def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. new ParamMap(this.map ++ other.map) @@ -890,6 +906,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Adds all parameters from the input param map into this param map. */ + @Since("1.2.0") def ++=(other: ParamMap): this.type = { // TODO: Provide a better method name for Java users. this.map ++= other.map @@ -899,6 +916,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Converts this param map to a sequence of param pairs. */ + @Since("1.2.0") def toSeq: Seq[ParamPair[_]] = { map.toSeq.map { case (param, value) => ParamPair(param, value) @@ -908,21 +926,25 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Number of param pairs in this map. */ + @Since("1.3.0") def size: Int = map.size } +@Since("1.2.0") @Experimental object ParamMap { /** * Returns an empty param map. */ + @Since("1.2.0") def empty: ParamMap = new ParamMap() /** * Constructs a param map by specifying its entries. */ @varargs + @Since("1.2.0") def apply(paramPairs: ParamPair[_]*): ParamMap = { new ParamMap().put(paramPairs: _*) } -- cgit v1.2.3