author     Xusen Yin <yinxusen@gmail.com>  2016-04-18 13:34:36 -0700
committer  Xiangrui Meng <meng@databricks.com>  2016-04-18 13:34:36 -0700
commit     8c62edb70fdeedf0ca5a7fc154698aea96184cc6 (patch)
tree       2054a81b026f0b55043286227aeb8d27c692b6e2 /examples
parent     f31a62d1b24aea8ddfa40b60378ce065518786e4 (diff)
[SPARK-14299][EXAMPLES] Remove duplications for scala.examples.ml
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14299

Delete duplications in scala/examples/ml:

TrainValidationSplitExample.scala --> ModelSelectionViaTrainValidationSplitExample
CrossValidatorExample.scala --> ModelSelectionViaCrossValidationExample

## How was this patch tested?

Existing tests passed.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #12366 from yinxusen/SPARK-14299-2.
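For context, the surviving ModelSelectionViaCrossValidationExample keeps the same tuning pattern the deleted file demonstrated. The following is a minimal sketch of that pattern, not code taken verbatim from either file; it assumes the 1.6-era spark.ml APIs imported in the diff below:

{{{
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Treat a Pipeline as an Estimator and let CrossValidator pick the best
// parameter combination from a small grid.
val lr = new LogisticRegression().setMaxIter(10)
val pipeline = new Pipeline().setStages(Array(lr))
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)  // expects "rawPrediction" and "label" columns
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)  // use 3+ folds in practice
// val cvModel = cv.fit(trainingDF)  // trainingDF: a DataFrame with "label" and "features" (hypothetical)
}}}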
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala                      | 114
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala    |   9
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala |   8
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala                |  78
4 files changed, 17 insertions, 192 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
deleted file mode 100644
index bca301d412..0000000000
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.ml
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
-import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
-import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{Row, SQLContext}
-
-/**
- * A simple example demonstrating model selection using CrossValidator.
- * This example also demonstrates how Pipelines are Estimators.
- *
- * This example uses the [[LabeledDocument]] and [[Document]] case classes from
- * [[SimpleTextClassificationPipeline]].
- *
- * Run with
- * {{{
- * bin/run-example ml.CrossValidatorExample
- * }}}
- */
-object CrossValidatorExample {
-
- def main(args: Array[String]) {
- val conf = new SparkConf().setAppName("CrossValidatorExample")
- val sc = new SparkContext(conf)
- val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._
-
- // Prepare training documents, which are labeled.
- val training = sc.parallelize(Seq(
- LabeledDocument(0L, "a b c d e spark", 1.0),
- LabeledDocument(1L, "b d", 0.0),
- LabeledDocument(2L, "spark f g h", 1.0),
- LabeledDocument(3L, "hadoop mapreduce", 0.0),
- LabeledDocument(4L, "b spark who", 1.0),
- LabeledDocument(5L, "g d a y", 0.0),
- LabeledDocument(6L, "spark fly", 1.0),
- LabeledDocument(7L, "was mapreduce", 0.0),
- LabeledDocument(8L, "e spark program", 1.0),
- LabeledDocument(9L, "a e c l", 0.0),
- LabeledDocument(10L, "spark compile", 1.0),
- LabeledDocument(11L, "hadoop software", 0.0)))
-
- // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
- val tokenizer = new Tokenizer()
- .setInputCol("text")
- .setOutputCol("words")
- val hashingTF = new HashingTF()
- .setInputCol(tokenizer.getOutputCol)
- .setOutputCol("features")
- val lr = new LogisticRegression()
- .setMaxIter(10)
- val pipeline = new Pipeline()
- .setStages(Array(tokenizer, hashingTF, lr))
-
- // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
- // This will allow us to jointly choose parameters for all Pipeline stages.
- // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
- val crossval = new CrossValidator()
- .setEstimator(pipeline)
- .setEvaluator(new BinaryClassificationEvaluator)
- // We use a ParamGridBuilder to construct a grid of parameters to search over.
- // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
- // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
- val paramGrid = new ParamGridBuilder()
- .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
- .addGrid(lr.regParam, Array(0.1, 0.01))
- .build()
- crossval.setEstimatorParamMaps(paramGrid)
- crossval.setNumFolds(2) // Use 3+ in practice
-
- // Run cross-validation, and choose the best set of parameters.
- val cvModel = crossval.fit(training.toDF())
-
- // Prepare test documents, which are unlabeled.
- val test = sc.parallelize(Seq(
- Document(4L, "spark i j k"),
- Document(5L, "l m n"),
- Document(6L, "mapreduce spark"),
- Document(7L, "apache hadoop")))
-
- // Make predictions on test documents. cvModel uses the best model found (lrModel).
- cvModel.transform(test.toDF())
- .select("id", "text", "probability", "prediction")
- .collect()
- .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
- println(s"($id, $text) --> prob=$prob, prediction=$prediction")
- }
-
- sc.stop()
- }
-}
-// scalastyle:on println
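The scaladoc added below records how to run the consolidated example; as shown there, the replacement is invoked the same way the deleted example was:

{{{
bin/run-example ml.ModelSelectionViaCrossValidationExample
}}}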
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
index 0331d6e7b3..d1441b5497 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
@@ -30,6 +30,15 @@ import org.apache.spark.sql.Row
 // $example off$
 import org.apache.spark.sql.SQLContext
 
+/**
+ * A simple example demonstrating model selection using CrossValidator.
+ * This example also demonstrates how Pipelines are Estimators.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.ModelSelectionViaCrossValidationExample
+ * }}}
+ */
 object ModelSelectionViaCrossValidationExample {
 
   def main(args: Array[String]): Unit = {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
index 5a95344f22..fcad17a817 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
@@ -25,6 +25,14 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
 // $example off$
 import org.apache.spark.sql.SQLContext
 
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.ModelSelectionViaTrainValidationSplitExample
+ * }}}
+ */
 object ModelSelectionViaTrainValidationSplitExample {
 
   def main(args: Array[String]): Unit = {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
deleted file mode 100644
index fbba17eba6..0000000000
--- a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.evaluation.RegressionEvaluator
-import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
-import org.apache.spark.sql.SQLContext
-
-/**
- * A simple example demonstrating model selection using TrainValidationSplit.
- *
- * The example is based on [[SimpleParamsExample]] using linear regression.
- * Run with
- * {{{
- * bin/run-example ml.TrainValidationSplitExample
- * }}}
- */
-object TrainValidationSplitExample {
-
- def main(args: Array[String]): Unit = {
- val conf = new SparkConf().setAppName("TrainValidationSplitExample")
- val sc = new SparkContext(conf)
- val sqlContext = new SQLContext(sc)
-
- // Prepare training and test data.
- val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
- val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
-
- val lr = new LinearRegression()
-
- // We use a ParamGridBuilder to construct a grid of parameters to search over.
- // TrainValidationSplit will try all combinations of values and determine best model using
- // the evaluator.
- val paramGrid = new ParamGridBuilder()
- .addGrid(lr.regParam, Array(0.1, 0.01))
- .addGrid(lr.fitIntercept, Array(true, false))
- .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
- .build()
-
- // In this case the estimator is simply the linear regression.
- // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
- val trainValidationSplit = new TrainValidationSplit()
- .setEstimator(lr)
- .setEvaluator(new RegressionEvaluator)
- .setEstimatorParamMaps(paramGrid)
-
- // 80% of the data will be used for training and the remaining 20% for validation.
- trainValidationSplit.setTrainRatio(0.8)
-
- // Run train validation split, and choose the best set of parameters.
- val model = trainValidationSplit.fit(training)
-
- // Make predictions on test data. model is the model with combination of parameters
- // that performed best.
- model.transform(test)
- .select("features", "label", "prediction")
- .show()
-
- sc.stop()
- }
-}
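Likewise, the train-validation-split functionality removed here remains available in the consolidated example, runnable via the command recorded in its new scaladoc:

{{{
bin/run-example ml.ModelSelectionViaTrainValidationSplitExample
}}}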