aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-06-06 09:49:45 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-06-06 09:49:45 -0700
commit4c74ee8d8e1c3139d3d322ae68977f2ab53295df (patch)
tree8581f91d642ca8ff4f766177197f5487f29805e3 /examples
parentfa4bc8ea8bab1277d1482da370dac79947cac719 (diff)
downloadspark-4c74ee8d8e1c3139d3d322ae68977f2ab53295df.tar.gz
spark-4c74ee8d8e1c3139d3d322ae68977f2ab53295df.tar.bz2
spark-4c74ee8d8e1c3139d3d322ae68977f2ab53295df.zip
[SPARK-15721][ML] Make DefaultParamsReadable, DefaultParamsWritable public
## What changes were proposed in this pull request? Made DefaultParamsReadable and DefaultParamsWritable public. Also added relevant doc and annotations. Added UnaryTransformerExample to demonstrate use of UnaryTransformer together with DefaultParamsReadable and DefaultParamsWritable. ## How was this patch tested? Wrote an example making use of the now-public APIs; compiled and ran it locally. Author: Joseph K. Bradley <joseph@databricks.com> Closes #13461 from jkbradley/defaultparamswritable.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala122
1 file changed, 122 insertions, 0 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
new file mode 100644
index 0000000000..13c72f88cc
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.DoubleParam
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.sql.functions.col
+// $example off$
+import org.apache.spark.sql.SparkSession
+// $example on$
+import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.util.Utils
+// $example off$
+
+/**
+ * An example demonstrating creating a custom [[org.apache.spark.ml.Transformer]] using
+ * the [[UnaryTransformer]] abstraction.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.UnaryTransformerExample
+ * }}}
+ */
+object UnaryTransformerExample {
+
+  // $example on$
+  /**
+   * Simple Transformer which adds a constant value (the `shift` Param) to input Doubles.
+   *
+   * [[UnaryTransformer]] can be used to create a stage usable within Pipelines.
+   * It defines parameters for specifying input and output columns:
+   * [[UnaryTransformer.inputCol]] and [[UnaryTransformer.outputCol]].
+   * It can optionally handle schema validation (see `validateInputType` below).
+   *
+   * [[DefaultParamsWritable]] provides a default implementation for persisting instances
+   * of this Transformer.
+   */
+  class MyTransformer(override val uid: String)
+    extends UnaryTransformer[Double, Double, MyTransformer] with DefaultParamsWritable {
+
+    final val shift: DoubleParam = new DoubleParam(this, "shift", "Value added to input")
+
+    def getShift: Double = $(shift)
+
+    def setShift(value: Double): this.type = set(shift, value)
+
+    def this() = this(Identifiable.randomUID("myT"))  // no-arg constructor; generates a random UID
+
+    override protected def createTransformFunc: Double => Double = (input: Double) => {
+      input + $(shift)
+    }
+
+    override protected def outputDataType: DataType = DataTypes.DoubleType  // type of the output column
+
+    override protected def validateInputType(inputType: DataType): Unit = {
+      require(inputType == DataTypes.DoubleType, s"Bad input type: $inputType. Requires Double.")
+    }
+  }
+
+  /**
+   * Companion object for our simple Transformer.
+   *
+   * [[DefaultParamsReadable]] provides a default implementation for loading instances
+   * of this Transformer which were persisted using [[DefaultParamsWritable]].
+   */
+  object MyTransformer extends DefaultParamsReadable[MyTransformer]
+  // $example off$
+
+  def main(args: Array[String]) {
+    val spark = SparkSession
+      .builder()
+      .appName("UnaryTransformerExample")
+      .getOrCreate()
+
+    // $example on$
+    val myTransformer = new MyTransformer()
+      .setShift(0.5)
+      .setInputCol("input")
+      .setOutputCol("output")
+
+    // Create data (doubles 0.0 through 4.0), transform, and display it.
+    val data = spark.range(0, 5).toDF("input")
+      .select(col("input").cast("double").as("input"))
+    val result = myTransformer.transform(data)
+    result.show()
+
+    // Save and load the Transformer.
+    val tmpDir = Utils.createTempDir()  // temp dir holding the persisted Transformer
+    val dirName = tmpDir.getCanonicalPath
+    myTransformer.write.overwrite().save(dirName)
+    val sameTransformer = MyTransformer.load(dirName)  // provided by DefaultParamsReadable
+
+    // Transform the data to show the results are identical.
+    val sameResult = sameTransformer.transform(data)
+    sameResult.show()
+
+    Utils.deleteRecursively(tmpDir)
+    // $example off$
+
+    spark.stop()
+  }
+}
+// scalastyle:on println
+