From d88afabdfa83be47f36d833105aadd6b818ceeee Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 11 May 2016 12:49:41 +0200 Subject: [SPARK-15150][EXAMPLE][DOC] Update LDA examples ## What changes were proposed in this pull request? 1,create a libsvm-type dataset for lda: `data/mllib/sample_lda_libsvm_data.txt` 2,add python example 3,directly read the datafile in examples 4,BTW, change to `SparkSession` in `aft_survival_regression.py` ## How was this patch tested? manual tests `./bin/spark-submit examples/src/main/python/ml/lda_example.py` Author: Zheng RuiFeng Closes #12927 from zhengruifeng/lda_pe. --- .../ml/JavaAFTSurvivalRegressionExample.java | 7 +++ .../apache/spark/examples/ml/JavaLDAExample.java | 67 +++++++--------------- .../src/main/python/ml/aft_survival_regression.py | 19 ++++-- examples/src/main/python/ml/lda_example.py | 64 +++++++++++++++++++++ .../examples/ml/AFTSurvivalRegressionExample.scala | 6 +- .../org/apache/spark/examples/ml/LDAExample.scala | 41 ++++++------- 6 files changed, 125 insertions(+), 79 deletions(-) create mode 100644 examples/src/main/python/ml/lda_example.py (limited to 'examples') diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java index 2c2aa6df47..b0115756cf 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -31,6 +31,13 @@ import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; // $example off$ +/** + * An example demonstrating AFTSurvivalRegression. + * Run with + *
+ * bin/run-example ml.JavaAFTSurvivalRegressionExample
+ * 
+ */ public class JavaAFTSurvivalRegressionExample { public static void main(String[] args) { SparkSession spark = SparkSession diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java index 1c52f37867..7102ddd801 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -17,26 +17,15 @@ package org.apache.spark.examples.ml; // $example on$ -import java.util.regex.Pattern; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.clustering.LDA; import org.apache.spark.ml.clustering.LDAModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; // $example off$ /** - * An example demonstrating LDA + * An example demonstrating LDA. * Run with *
  * bin/run-example ml.JavaLDAExample
@@ -44,53 +33,37 @@ import org.apache.spark.sql.types.StructType;
  */
 public class JavaLDAExample {
 
-  // $example on$
-  private static class ParseVector implements Function {
-    private static final Pattern separator = Pattern.compile(" ");
-
-    @Override
-    public Row call(String line) {
-      String[] tok = separator.split(line);
-      double[] point = new double[tok.length];
-      for (int i = 0; i < tok.length; ++i) {
-        point[i] = Double.parseDouble(tok[i]);
-      }
-      Vector[] points = {Vectors.dense(point)};
-      return new GenericRow(points);
-    }
-  }
-
   public static void main(String[] args) {
-
-    String inputFile = "data/mllib/sample_lda_data.txt";
-
-    // Parses the arguments
+    // Creates a SparkSession
     SparkSession spark = SparkSession
       .builder()
       .appName("JavaLDAExample")
       .getOrCreate();
 
-    // Loads data
-    JavaRDD points = spark.read().text(inputFile).javaRDD().map(new ParseVector());
-    StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
-    StructType schema = new StructType(fields);
-    Dataset dataset = spark.createDataFrame(points, schema);
+    // $example on$
+    // Loads data.
+    Dataset dataset = spark.read().format("libsvm")
+      .load("data/mllib/sample_lda_libsvm_data.txt");
 
-    // Trains a LDA model
-    LDA lda = new LDA()
-      .setK(10)
-      .setMaxIter(10);
+    // Trains a LDA model.
+    LDA lda = new LDA().setK(10).setMaxIter(10);
     LDAModel model = lda.fit(dataset);
 
-    System.out.println(model.logLikelihood(dataset));
-    System.out.println(model.logPerplexity(dataset));
-
-    // Shows the result
+    double ll = model.logLikelihood(dataset);
+    double lp = model.logPerplexity(dataset);
+    System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
+    System.out.println("The upper bound bound on perplexity: " + lp);
+    
+    // Describe topics.
     Dataset topics = model.describeTopics(3);
+    System.out.println("The topics described by their top-weighted terms:");
     topics.show(false);
-    model.transform(dataset).show(false);
+
+    // Shows the result.
+    Dataset transformed = model.transform(dataset);
+    transformed.show(false);
+    // $example off$
 
     spark.stop();
   }
-  // $example off$
 }
diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py
index 0ee01fd825..9879679829 100644
--- a/examples/src/main/python/ml/aft_survival_regression.py
+++ b/examples/src/main/python/ml/aft_survival_regression.py
@@ -17,19 +17,26 @@
 
 from __future__ import print_function
 
-from pyspark import SparkContext
-from pyspark.sql import SQLContext
 # $example on$
 from pyspark.ml.regression import AFTSurvivalRegression
 from pyspark.mllib.linalg import Vectors
 # $example off$
+from pyspark.sql import SparkSession
+
+"""
+An example demonstrating aft survival regression.
+Run with:
+  bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py
+"""
 
 if __name__ == "__main__":
-    sc = SparkContext(appName="AFTSurvivalRegressionExample")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession \
+        .builder \
+        .appName("PythonAFTSurvivalRegressionExample") \
+        .getOrCreate()
 
     # $example on$
-    training = sqlContext.createDataFrame([
+    training = spark.createDataFrame([
         (1.218, 1.0, Vectors.dense(1.560, -0.605)),
         (2.949, 0.0, Vectors.dense(0.346, 2.158)),
         (3.627, 0.0, Vectors.dense(1.380, 0.231)),
@@ -48,4 +55,4 @@ if __name__ == "__main__":
     model.transform(training).show(truncate=False)
     # $example off$
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py
new file mode 100644
index 0000000000..6ca56adf3c
--- /dev/null
+++ b/examples/src/main/python/ml/lda_example.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from __future__ import print_function
+
+# $example on$
+from pyspark.ml.clustering import LDA
+# $example off$
+from pyspark.sql import SparkSession
+
+
+"""
+An example demonstrating LDA.
+Run with:
+  bin/spark-submit examples/src/main/python/ml/lda_example.py
+"""
+
+
+if __name__ == "__main__":
+    # Creates a SparkSession
+    spark = SparkSession \
+        .builder \
+        .appName("PythonKMeansExample") \
+        .getOrCreate()
+
+    # $example on$
+    # Loads data.
+    dataset = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")
+
+    # Trains a LDA model.
+    lda = LDA(k=10, maxIter=10)
+    model = lda.fit(dataset)
+
+    ll = model.logLikelihood(dataset)
+    lp = model.logPerplexity(dataset)
+    print("The lower bound on the log likelihood of the entire corpus: " + str(ll))
+    print("The upper bound bound on perplexity: " + str(lp))
+
+    # Describe topics.
+    topics = model.describeTopics(3)
+    print("The topics described by their top-weighted terms:")
+    topics.show(truncate=False)
+
+    # Shows the result
+    transformed = model.transform(dataset)
+    transformed.show(truncate=False)
+    # $example off$
+
+    spark.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
index 2b224d50a0..b44304d810 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
@@ -25,7 +25,11 @@ import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.sql.SparkSession
 
 /**
- * An example for AFTSurvivalRegression.
+ * An example demonstrating AFTSurvivalRegression.
+ * Run with
+ * {{{
+ * bin/run-example ml.AFTSurvivalRegressionExample
+ * }}}
  */
 object AFTSurvivalRegressionExample {
 
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala
index c2920f6a5d..22b3b0e3ad 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala
@@ -20,57 +20,48 @@ package org.apache.spark.examples.ml
 // scalastyle:off println
 // $example on$
 import org.apache.spark.ml.clustering.LDA
-import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
-import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.types.{StructField, StructType}
 // $example off$
+import org.apache.spark.sql.SparkSession
 
 /**
- * An example demonstrating a LDA of ML pipeline.
+ * An example demonstrating LDA.
  * Run with
  * {{{
  * bin/run-example ml.LDAExample
  * }}}
  */
 object LDAExample {
-
-  final val FEATURES_COL = "features"
-
   def main(args: Array[String]): Unit = {
-
-    val input = "data/mllib/sample_lda_data.txt"
-    // Creates a Spark context and a SQL context
+    // Creates a SparkSession
     val spark = SparkSession
       .builder
       .appName(s"${this.getClass.getSimpleName}")
       .getOrCreate()
 
     // $example on$
-    // Loads data
-    val rowRDD = spark.read.text(input).rdd.filter(_.nonEmpty)
-      .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
-    val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
-    val dataset = spark.createDataFrame(rowRDD, schema)
+    // Loads data.
+    val dataset = spark.read.format("libsvm")
+      .load("data/mllib/sample_lda_libsvm_data.txt")
 
-    // Trains a LDA model
-    val lda = new LDA()
-      .setK(10)
-      .setMaxIter(10)
-      .setFeaturesCol(FEATURES_COL)
+    // Trains a LDA model.
+    val lda = new LDA().setK(10).setMaxIter(10)
     val model = lda.fit(dataset)
-    val transformed = model.transform(dataset)
 
     val ll = model.logLikelihood(dataset)
     val lp = model.logPerplexity(dataset)
+    println(s"The lower bound on the log likelihood of the entire corpus: $ll")
+    println(s"The upper bound bound on perplexity: $lp")
 
-    // describeTopics
+    // Describe topics.
     val topics = model.describeTopics(3)
-
-    // Shows the result
+    println("The topics described by their top-weighted terms:")
     topics.show(false)
-    transformed.show(false)
 
+    // Shows the result.
+    val transformed = model.transform(dataset)
+    transformed.show(false)
     // $example off$
+
     spark.stop()
   }
 }
-- 
cgit v1.2.3