author     Xiangrui Meng <meng@databricks.com>  2015-11-22 21:45:46 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-11-22 21:45:46 -0800
commit     fe89c1817d668e46adf70d0896c42c22a547c76a (patch)
tree       2c57d80a52027c08afd219615f2ad9d18fd8b107 /examples
parent     426004a9c9a864f90494d08601e6974709091a56 (diff)
[SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples
We used the name `Dataset` to refer to `SchemaRDD` in Spark 1.2 in ML pipelines and created this example file. Since `Dataset` has a new meaning in Spark 1.6, we should rename it to avoid confusion. This PR also removes support for the dense format to simplify the example code.

cc: yinxusen

Author: Xiangrui Meng <meng@databricks.com>

Closes #9873 from mengxr/SPARK-11895.
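For context, a minimal, self-contained sketch of the DataFrame-based loading that the refactored example below relies on. The object name LibSVMLoadSketch is illustrative; the "libsvm" data source and the read/printSchema/describe calls are the ones introduced in the diff.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object LibSVMLoadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("LibSVMLoadSketch")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // The "libsvm" data source yields a DataFrame with "label" and "features"
    // columns, so no explicit RDD[LabeledPoint] conversion is needed anymore.
    val df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    df.printSchema()
    df.describe("label").show()

    sc.stop()
  }
}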
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala (renamed from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala)  71
1 file changed, 26 insertions, 45 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
index dc13f82488..424f00158c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
@@ -16,7 +16,7 @@
*/
// scalastyle:off println
-package org.apache.spark.examples.mllib
+package org.apache.spark.examples.ml
import java.io.File
@@ -24,25 +24,22 @@ import com.google.common.io.Files
import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
/**
- * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with
* {{{
- * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * ./bin/run-example ml.DataFrameExample [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
-object DatasetExample {
+object DataFrameExample {
- case class Params(
- input: String = "data/mllib/sample_libsvm_data.txt",
- dataFormat: String = "libsvm") extends AbstractParams[Params]
+ case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -52,9 +49,6 @@ object DatasetExample {
opt[String]("input")
.text(s"input path to dataset")
.action((x, c) => c.copy(input = x))
- opt[String]("dataFormat")
- .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
- .action((x, c) => c.copy(input = x))
checkConfig { params =>
success
}
@@ -69,55 +63,42 @@ object DatasetExample {
def run(params: Params) {
- val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._ // for implicit conversions
// Load input data
- val origData: RDD[LabeledPoint] = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
- }
- println(s"Loaded ${origData.count()} instances from file: ${params.input}")
-
- // Convert input data to DataFrame explicitly.
- val df: DataFrame = origData.toDF()
- println(s"Inferred schema:\n${df.schema.prettyJson}")
- println(s"Converted to DataFrame with ${df.count()} records")
-
- // Select columns
- val labelsDf: DataFrame = df.select("label")
- val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
- val numLabels = labels.count()
- val meanLabel = labels.fold(0.0)(_ + _) / numLabels
- println(s"Selected label column with average value $meanLabel")
-
- val featuresDf: DataFrame = df.select("features")
- val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
+ println(s"Loading LIBSVM file with UDT from ${params.input}.")
+ val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
+ println("Schema from LIBSVM:")
+ df.printSchema()
+ println(s"Loaded training data as a DataFrame with ${df.count()} records.")
+
+ // Show statistical summary of labels.
+ val labelSummary = df.describe("label")
+ labelSummary.show()
+
+ // Convert features column to an RDD of vectors.
+ val features = df.select("features").map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+ // Save the records in a parquet file.
val tmpDir = Files.createTempDir()
tmpDir.deleteOnExit()
val outputDir = new File(tmpDir, "dataset").toString
println(s"Saving to $outputDir as Parquet file.")
df.write.parquet(outputDir)
+ // Load the records back.
println(s"Loading Parquet file with UDT from $outputDir.")
- val newDataset = sqlContext.read.parquet(outputDir)
-
- println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
- val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
- val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
- (summary, feat) => summary.add(feat),
- (sum1, sum2) => sum1.merge(sum2))
- println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+ val newDF = sqlContext.read.parquet(outputDir)
+ println(s"Schema from Parquet:")
+ newDF.printSchema()
sc.stop()
}
-
}
// scalastyle:on println
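For reference, a minimal invocation of the renamed example, following the scaladoc above; the --input flag is the one defined in the Params parser, and the path shown is its default value:

./bin/run-example ml.DataFrameExample --input data/mllib/sample_libsvm_data.txt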