author    Joseph K. Bradley <joseph@databricks.com>  2016-06-06 09:49:45 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-06-06 09:49:45 -0700
commit    4c74ee8d8e1c3139d3d322ae68977f2ab53295df
tree      8581f91d642ca8ff4f766177197f5487f29805e3
parent    fa4bc8ea8bab1277d1482da370dac79947cac719
[SPARK-15721][ML] Make DefaultParamsReadable, DefaultParamsWritable public
## What changes were proposed in this pull request?

Made DefaultParamsReadable, DefaultParamsWritable public. Also added relevant doc and annotations. Added UnaryTransformerExample to demonstrate use of UnaryTransformer and DefaultParamsReadable/DefaultParamsWritable.

## How was this patch tested?

Wrote an example making use of the now-public APIs. Compiled and ran locally.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #13461 from jkbradley/defaultparamswritable.
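For quick orientation, here is a minimal sketch of the user-facing pattern this patch enables. `Doubler` and the save path are illustrative placeholders, not part of the commit; the complete, parameterized version is the new UnaryTransformerExample in the diff below.

{{{
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.types.{DataType, DataTypes}

// A stage that keeps all of its state in Params (here just the inherited
// inputCol/outputCol), so the default persistence implementations suffice.
class Doubler(override val uid: String)
  extends UnaryTransformer[Double, Double, Doubler] with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("doubler"))

  override protected def createTransformFunc: Double => Double = _ * 2.0

  override protected def outputDataType: DataType = DataTypes.DoubleType
}

// The companion object inherits read/load from DefaultParamsReadable.
object Doubler extends DefaultParamsReadable[Doubler]

// Round trip: `write` comes from the DefaultParamsWritable mix-in,
// `load` from the companion (the path shown is illustrative):
//   new Doubler().setInputCol("in").setOutputCol("out")
//     .write.overwrite().save("/tmp/doubler")
//   val restored = Doubler.load("/tmp/doubler")
}}}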
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala | 122
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                      |  44
2 files changed, 163 insertions(+), 3 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
new file mode 100644
index 0000000000..13c72f88cc
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.DoubleParam
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.sql.functions.col
+// $example off$
+import org.apache.spark.sql.SparkSession
+// $example on$
+import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.util.Utils
+// $example off$
+
+/**
+ * An example demonstrating creating a custom [[org.apache.spark.ml.Transformer]] using
+ * the [[UnaryTransformer]] abstraction.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.UnaryTransformerExample
+ * }}}
+ */
+object UnaryTransformerExample {
+
+ // $example on$
+ /**
+ * Simple Transformer which adds a constant value to input Doubles.
+ *
+ * [[UnaryTransformer]] can be used to create a stage usable within Pipelines.
+ * It defines parameters for specifying input and output columns:
+ * [[UnaryTransformer.inputCol]] and [[UnaryTransformer.outputCol]].
+ * It can optionally handle schema validation.
+ *
+ * [[DefaultParamsWritable]] provides a default implementation for persisting instances
+ * of this Transformer.
+ */
+ class MyTransformer(override val uid: String)
+ extends UnaryTransformer[Double, Double, MyTransformer] with DefaultParamsWritable {
+
+ final val shift: DoubleParam = new DoubleParam(this, "shift", "Value added to input")
+
+ def getShift: Double = $(shift)
+
+ def setShift(value: Double): this.type = set(shift, value)
+
+ def this() = this(Identifiable.randomUID("myT"))
+
+ override protected def createTransformFunc: Double => Double = (input: Double) => {
+ input + $(shift)
+ }
+
+ override protected def outputDataType: DataType = DataTypes.DoubleType
+
+ override protected def validateInputType(inputType: DataType): Unit = {
+ require(inputType == DataTypes.DoubleType, s"Bad input type: $inputType. Requires Double.")
+ }
+ }
+
+ /**
+ * Companion object for our simple Transformer.
+ *
+ * [[DefaultParamsReadable]] provides a default implementation for loading instances
+ * of this Transformer which were persisted using [[DefaultParamsWritable]].
+ */
+ object MyTransformer extends DefaultParamsReadable[MyTransformer]
+ // $example off$
+
+ def main(args: Array[String]) {
+ val spark = SparkSession
+ .builder()
+ .appName("UnaryTransformerExample")
+ .getOrCreate()
+
+ // $example on$
+ val myTransformer = new MyTransformer()
+ .setShift(0.5)
+ .setInputCol("input")
+ .setOutputCol("output")
+
+ // Create data, transform, and display it.
+ val data = spark.range(0, 5).toDF("input")
+ .select(col("input").cast("double").as("input"))
+ val result = myTransformer.transform(data)
+ result.show()
+
+ // Save and load the Transformer.
+ val tmpDir = Utils.createTempDir()
+ val dirName = tmpDir.getCanonicalPath
+ myTransformer.write.overwrite().save(dirName)
+ val sameTransformer = MyTransformer.load(dirName)
+
+ // Transform the data to show the results are identical.
+ val sameResult = sameTransformer.transform(data)
+ sameResult.show()
+
+ Utils.deleteRecursively(tmpDir)
+ // $example off$
+
+ spark.stop()
+ }
+}
+// scalastyle:on println
+
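As the example's scaladoc notes, a UnaryTransformer is an ordinary pipeline stage. Below is a brief sketch of that composition, assuming the `myTransformer` and `data` values from `main` above; the Pipeline wiring itself is not part of this patch.

{{{
import org.apache.spark.ml.{Pipeline, PipelineStage}

// A pure Transformer passes through Pipeline.fit unchanged, so fitting
// simply assembles a PipelineModel around myTransformer.
val stages: Array[PipelineStage] = Array(myTransformer)
val model = new Pipeline().setStages(stages).fit(data)
model.transform(data).show()

// Every stage here mixes in MLWritable (via DefaultParamsWritable), so the
// assembled pipeline model can be persisted as well (path is illustrative):
model.write.overwrite().save("/tmp/unary-transformer-pipeline")
}}}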
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 8ed40c379c..90b8d7df7b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -68,6 +68,8 @@ private[util] sealed trait BaseReadWrite {
}
/**
+ * :: Experimental ::
+ *
* Abstract class for utility classes that can save ML instances.
*/
@Experimental
@@ -120,8 +122,11 @@ abstract class MLWriter extends BaseReadWrite with Logging {
}
/**
+ * :: Experimental ::
+ *
* Trait for classes that provide [[MLWriter]].
*/
+@Experimental
@Since("1.6.0")
trait MLWritable {
@@ -139,12 +144,27 @@ trait MLWritable {
def save(path: String): Unit = write.save(path)
}
-private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
+/**
+ * :: Experimental ::
+ *
+ * Helper trait for making simple [[Params]] types writable. If a [[Params]] class stores
+ * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
+ * a default implementation of writing saved instances of the class.
+ * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
+ * [[org.apache.spark.sql.Dataset]].
+ *
+ * @see [[DefaultParamsReadable]], the counterpart to this trait
+ */
+@Experimental
+@Since("2.0.0")
+trait DefaultParamsWritable extends MLWritable { self: Params =>
override def write: MLWriter = new DefaultParamsWriter(this)
}
/**
+ * :: Experimental ::
+ *
* Abstract class for utility classes that can load ML instances.
*
* @tparam T ML instance type
@@ -164,6 +184,8 @@ abstract class MLReader[T] extends BaseReadWrite {
}
/**
+ * :: Experimental ::
+ *
* Trait for objects that provide [[MLReader]].
*
* @tparam T ML instance type
@@ -187,9 +209,25 @@ trait MLReadable[T] {
def load(path: String): T = read.load(path)
}
-private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
- override def read: MLReader[T] = new DefaultParamsReader
+/**
+ * :: Experimental ::
+ *
+ * Helper trait for making simple [[Params]] types readable. If a [[Params]] class stores
+ * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
+ * a default implementation of reading saved instances of the class.
+ * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
+ * [[org.apache.spark.sql.Dataset]].
+ *
+ * @tparam T ML instance type
+ *
+ * @see [[DefaultParamsWritable]], the counterpart to this trait
+ */
+@Experimental
+@Since("2.0.0")
+trait DefaultParamsReadable[T] extends MLReadable[T] {
+
+ override def read: MLReader[T] = new DefaultParamsReader[T]
}
/**