diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-06-06 09:49:45 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-06-06 09:49:45 -0700 |
commit | 4c74ee8d8e1c3139d3d322ae68977f2ab53295df (patch) | |
tree | 8581f91d642ca8ff4f766177197f5487f29805e3 | |
parent | fa4bc8ea8bab1277d1482da370dac79947cac719 (diff) | |
download | spark-4c74ee8d8e1c3139d3d322ae68977f2ab53295df.tar.gz spark-4c74ee8d8e1c3139d3d322ae68977f2ab53295df.tar.bz2 spark-4c74ee8d8e1c3139d3d322ae68977f2ab53295df.zip |
[SPARK-15721][ML] Make DefaultParamsReadable, DefaultParamsWritable public
## What changes were proposed in this pull request?
Made DefaultParamsReadable, DefaultParamsWritable public. Also added relevant doc and annotations. Added UnaryTransformerExample to demonstrate use of UnaryTransformer and DefaultParamsReadable,Writable.
## How was this patch tested?
Wrote example making use of the now-public APIs. Compiled and ran locally
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #13461 from jkbradley/defaultparamswritable.
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala | 122 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 44 |
2 files changed, 163 insertions, 3 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala new file mode 100644 index 0000000000..13c72f88cc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.DoubleParam +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.functions.col +// $example off$ +import org.apache.spark.sql.SparkSession +// $example on$ +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.util.Utils +// $example off$ + +/** + * An example demonstrating creating a custom [[org.apache.spark.ml.Transformer]] using + * the [[UnaryTransformer]] abstraction. + * + * Run with + * {{{ + * bin/run-example ml.UnaryTransformerExample + * }}} + */ +object UnaryTransformerExample { + + // $example on$ + /** + * Simple Transformer which adds a constant value to input Doubles. + * + * [[UnaryTransformer]] can be used to create a stage usable within Pipelines. + * It defines parameters for specifying input and output columns: + * [[UnaryTransformer.inputCol]] and [[UnaryTransformer.outputCol]]. + * It can optionally handle schema validation. + * + * [[DefaultParamsWritable]] provides a default implementation for persisting instances + * of this Transformer. + */ + class MyTransformer(override val uid: String) + extends UnaryTransformer[Double, Double, MyTransformer] with DefaultParamsWritable { + + final val shift: DoubleParam = new DoubleParam(this, "shift", "Value added to input") + + def getShift: Double = $(shift) + + def setShift(value: Double): this.type = set(shift, value) + + def this() = this(Identifiable.randomUID("myT")) + + override protected def createTransformFunc: Double => Double = (input: Double) => { + input + $(shift) + } + + override protected def outputDataType: DataType = DataTypes.DoubleType + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == DataTypes.DoubleType, s"Bad input type: $inputType. Requires Double.") + } + } + + /** + * Companion object for our simple Transformer. + * + * [[DefaultParamsReadable]] provides a default implementation for loading instances + * of this Transformer which were persisted using [[DefaultParamsWritable]]. + */ + object MyTransformer extends DefaultParamsReadable[MyTransformer] + // $example off$ + + def main(args: Array[String]) { + val spark = SparkSession + .builder() + .appName("UnaryTransformerExample") + .getOrCreate() + + // $example on$ + val myTransformer = new MyTransformer() + .setShift(0.5) + .setInputCol("input") + .setOutputCol("output") + + // Create data, transform, and display it. + val data = spark.range(0, 5).toDF("input") + .select(col("input").cast("double").as("input")) + val result = myTransformer.transform(data) + result.show() + + // Save and load the Transformer. + val tmpDir = Utils.createTempDir() + val dirName = tmpDir.getCanonicalPath + myTransformer.write.overwrite().save(dirName) + val sameTransformer = MyTransformer.load(dirName) + + // Transform the data to show the results are identical. + val sameResult = sameTransformer.transform(data) + sameResult.show() + + Utils.deleteRecursively(tmpDir) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println + diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 8ed40c379c..90b8d7df7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -68,6 +68,8 @@ private[util] sealed trait BaseReadWrite { } /** + * :: Experimental :: + * * Abstract class for utility classes that can save ML instances. */ @Experimental @@ -120,8 +122,11 @@ abstract class MLWriter extends BaseReadWrite with Logging { } /** + * :: Experimental :: + * * Trait for classes that provide [[MLWriter]]. */ +@Experimental @Since("1.6.0") trait MLWritable { @@ -139,12 +144,27 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } -private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => +/** + * :: Experimental :: + * + * Helper trait for making simple [[Params]] types writable. If a [[Params]] class stores + * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide + * a default implementation of writing saved instances of the class. + * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle + * [[org.apache.spark.sql.Dataset]]. + * + * @see [[DefaultParamsReadable]], the counterpart to this trait + */ +@Experimental +@Since("2.0.0") +trait DefaultParamsWritable extends MLWritable { self: Params => override def write: MLWriter = new DefaultParamsWriter(this) } /** + * :: Experimental :: + * * Abstract class for utility classes that can load ML instances. * * @tparam T ML instance type @@ -164,6 +184,8 @@ abstract class MLReader[T] extends BaseReadWrite { } /** + * :: Experimental :: + * * Trait for objects that provide [[MLReader]]. * * @tparam T ML instance type @@ -187,9 +209,25 @@ trait MLReadable[T] { def load(path: String): T = read.load(path) } -private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { - override def read: MLReader[T] = new DefaultParamsReader +/** + * :: Experimental :: + * + * Helper trait for making simple [[Params]] types readable. If a [[Params]] class stores + * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide + * a default implementation of reading saved instances of the class. + * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle + * [[org.apache.spark.sql.Dataset]]. + * + * @tparam T ML instance type + * + * @see [[DefaultParamsWritable]], the counterpart to this trait + */ +@Experimental +@Since("2.0.0") +trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader[T] } /** |