author     Xiangrui Meng <meng@databricks.com>    2014-11-03 22:29:48 -0800
committer  Xiangrui Meng <meng@databricks.com>    2014-11-03 22:29:48 -0800
commit     1a9c6cddadebdc53d083ac3e0da276ce979b5d1f (patch)
tree       b485818ba52a9287ae7124e57ef55f1d974f3a1f /examples
parent     04450d11548cfb25d4fb77d4a33e3a7cd4254183 (diff)
[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD
Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map an RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala and Python are attached. The Scala code was copied from jkbradley. ~~This PR contains the changes from #3068. I will rebase after #3068 is merged.~~ marmbrus jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3070 from mengxr/SPARK-3573 and squashes the following commits:

3a0b6e5 [Xiangrui Meng] organize imports
236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/python/mllib/dataset_example.py                              62
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala  121
2 files changed, 183 insertions(+), 0 deletions(-)
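
To make the diffs below easier to follow, here is a minimal sketch of the round trip this UDT enables, assuming the Spark 1.2-era API the patch targets (the createSchemaRDD implicits on SQLContext, plus saveAsParquetFile/parquetFile); the app name and output path are illustrative, not from the patch:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.sql.{SQLContext, SchemaRDD}

    val sc = new SparkContext(new SparkConf().setAppName("VectorUDTSketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._ // implicit RDD[Product] -> SchemaRDD conversion

    // LabeledPoint's `features` field is an MLlib Vector; with the Vector UDT
    // registered, schema inference and Parquet I/O treat it like any other column.
    val points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    val dataset: SchemaRDD = points                   // implicit conversion
    dataset.saveAsParquetFile("/tmp/points.parquet")  // Vector column survives the round trip
    val restored = sqlContext.parquetFile("/tmp/points.parquet")
    println(restored.schema.prettyJson)

    sc.stop()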
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000000..540dae785f
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+ bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+ print "schema: %s" % dataset.schema().json()
+ labels = dataset.map(lambda r: r.label)
+ print "label average: %f" % labels.mean()
+ features = dataset.map(lambda r: r.features)
+ summary = Statistics.colStats(features)
+ print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+ if len(sys.argv) > 2:
+ print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
+ exit(-1)
+ sc = SparkContext(appName="DatasetExample")
+ sqlCtx = SQLContext(sc)
+ if len(sys.argv) == 2:
+ input = sys.argv[1]
+ else:
+ input = "data/mllib/sample_libsvm_data.txt"
+ points = MLUtils.loadLibSVMFile(sc, input)
+ dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ summarize(dataset0)
+ tempdir = tempfile.NamedTemporaryFile(delete=False).name
+ os.unlink(tempdir)
+ print "Save dataset as a Parquet file to %s." % tempdir
+ dataset0.saveAsParquetFile(tempdir)
+ print "Load it back and summarize it again."
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ summarize(dataset1)
+ shutil.rmtree(tempdir)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000000..f8d83f4ec7
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+ case class Params(
+ input: String = "data/mllib/sample_libsvm_data.txt",
+ dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DatasetExample") {
+ head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ opt[String]("input")
+ .text(s"input path to dataset")
+ .action((x, c) => c.copy(input = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+      .action((x, c) => c.copy(dataFormat = x))
+ checkConfig { params =>
+ success
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._ // for implicit conversions
+
+ // Load input data
+ val origData: RDD[LabeledPoint] = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+ }
+ println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+ // Convert input data to SchemaRDD explicitly.
+ val schemaRDD: SchemaRDD = origData
+ println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+ println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+ // Select columns, using implicit conversion to SchemaRDD.
+ val labelsSchemaRDD: SchemaRDD = origData.select('label)
+ val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ val numLabels = labels.count()
+ val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+ println(s"Selected label column with average value $meanLabel")
+
+ val featuresSchemaRDD: SchemaRDD = origData.select('features)
+ val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val outputDir = new File(tmpDir, "dataset").toString
+ println(s"Saving to $outputDir as Parquet file.")
+ schemaRDD.saveAsParquetFile(outputDir)
+
+ println(s"Loading Parquet file with UDT from $outputDir.")
+ val newDataset = sqlContext.parquetFile(outputDir)
+
+ println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+ val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+ sc.stop()
+ }
+
+}
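
The examples above only consume the UDT; the registration itself lives on MLlib's Vector, outside this 'examples' diffstat. For orientation, a rough sketch of what implementing a UDT for a custom type looks like, assuming the Spark 1.2-era UserDefinedType contract (sqlType/serialize/deserialize/userClass) and catalyst package layout, both of which moved in later releases; Point and PointUDT here are hypothetical, not part of the patch:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical user class; the annotation binds it to its UDT so that
    // schema inference recognizes values of this type.
    @SQLUserDefinedType(udt = classOf[PointUDT])
    case class Point(x: Double, y: Double)

    // The UDT describes how Point maps to and from a plain SQL struct, which is
    // what lets it flow through SchemaRDDs and Parquet files.
    class PointUDT extends UserDefinedType[Point] {

      override def sqlType: DataType = StructType(Seq(
        StructField("x", DoubleType, nullable = false),
        StructField("y", DoubleType, nullable = false)))

      override def serialize(obj: Any): Any = obj match {
        case Point(x, y) => Row(x, y)
      }

      override def deserialize(datum: Any): Point = datum match {
        case row: Row => Point(row.getDouble(0), row.getDouble(1))
      }

      override def userClass: Class[Point] = classOf[Point]
    }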