aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/mllib/sample_lda_libsvm_data.txt12
-rw-r--r--docs/ml-clustering.md7
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java7
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java67
-rw-r--r--examples/src/main/python/ml/aft_survival_regression.py19
-rw-r--r--examples/src/main/python/ml/lda_example.py64
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala41
8 files changed, 143 insertions, 80 deletions
diff --git a/data/mllib/sample_lda_libsvm_data.txt b/data/mllib/sample_lda_libsvm_data.txt
new file mode 100644
index 0000000000..bf118d7d5b
--- /dev/null
+++ b/data/mllib/sample_lda_libsvm_data.txt
@@ -0,0 +1,12 @@
+0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3
+1 1:1 2:3 3:0 4:1 5:3 6:0 7:0 8:2 9:0 10:0 11:1
+2 1:1 2:4 3:1 4:0 5:0 6:4 7:9 8:0 9:1 10:2 11:0
+3 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:3 11:9
+4 1:3 2:1 3:1 4:9 5:3 6:0 7:2 8:0 9:0 10:1 11:3
+5 1:4 2:2 3:0 4:3 5:4 6:5 7:1 8:1 9:1 10:4 11:0
+6 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:2 11:9
+7 1:1 2:1 3:1 4:9 5:2 6:1 7:2 8:0 9:0 10:1 11:3
+8 1:4 2:4 3:0 4:3 5:4 6:2 7:1 8:3 9:0 10:0 11:0
+9 1:2 2:8 3:2 4:0 5:3 6:0 7:2 8:0 9:2 10:7 11:2
+10 1:1 2:1 3:1 4:9 5:0 6:2 7:2 8:0 9:0 10:3 11:3
+11 1:4 2:1 3:0 4:0 5:4 6:5 7:1 8:3 9:0 10:1 11:0
diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md
index 876a280c4c..0d69bf67df 100644
--- a/docs/ml-clustering.md
+++ b/docs/ml-clustering.md
@@ -109,8 +109,13 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) f
{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
</div>
-</div>
+<div data-lang="python" markdown="1">
+
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.LDA) for more details.
+{% include_example python/ml/lda_example.py %}
+</div>
+</div>
## Bisecting k-means
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
index 2c2aa6df47..b0115756cf 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
@@ -31,6 +31,13 @@ import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
// $example off$
+/**
+ * An example demonstrating AFTSurvivalRegression.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaAFTSurvivalRegressionExample
+ * </pre>
+ */
public class JavaAFTSurvivalRegressionExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java
index 1c52f37867..7102ddd801 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java
@@ -17,26 +17,15 @@
package org.apache.spark.examples.ml;
// $example on$
-import java.util.regex.Pattern;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
// $example off$
/**
- * An example demonstrating LDA
+ * An example demonstrating LDA.
* Run with
* <pre>
* bin/run-example ml.JavaLDAExample
@@ -44,53 +33,37 @@ import org.apache.spark.sql.types.StructType;
*/
public class JavaLDAExample {
- // $example on$
- private static class ParseVector implements Function<String, Row> {
- private static final Pattern separator = Pattern.compile(" ");
-
- @Override
- public Row call(String line) {
- String[] tok = separator.split(line);
- double[] point = new double[tok.length];
- for (int i = 0; i < tok.length; ++i) {
- point[i] = Double.parseDouble(tok[i]);
- }
- Vector[] points = {Vectors.dense(point)};
- return new GenericRow(points);
- }
- }
-
public static void main(String[] args) {
-
- String inputFile = "data/mllib/sample_lda_data.txt";
-
- // Parses the arguments
+ // Creates a SparkSession
SparkSession spark = SparkSession
.builder()
.appName("JavaLDAExample")
.getOrCreate();
- // Loads data
- JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParseVector());
- StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
- StructType schema = new StructType(fields);
- Dataset<Row> dataset = spark.createDataFrame(points, schema);
+ // $example on$
+ // Loads data.
+ Dataset<Row> dataset = spark.read().format("libsvm")
+ .load("data/mllib/sample_lda_libsvm_data.txt");
- // Trains a LDA model
- LDA lda = new LDA()
- .setK(10)
- .setMaxIter(10);
+ // Trains a LDA model.
+ LDA lda = new LDA().setK(10).setMaxIter(10);
LDAModel model = lda.fit(dataset);
- System.out.println(model.logLikelihood(dataset));
- System.out.println(model.logPerplexity(dataset));
-
- // Shows the result
+ double ll = model.logLikelihood(dataset);
+ double lp = model.logPerplexity(dataset);
+ System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
+ System.out.println("The upper bound bound on perplexity: " + lp);
+
+ // Describe topics.
Dataset<Row> topics = model.describeTopics(3);
+ System.out.println("The topics described by their top-weighted terms:");
topics.show(false);
- model.transform(dataset).show(false);
+
+ // Shows the result.
+ Dataset<Row> transformed = model.transform(dataset);
+ transformed.show(false);
+ // $example off$
spark.stop();
}
- // $example off$
}
diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py
index 0ee01fd825..9879679829 100644
--- a/examples/src/main/python/ml/aft_survival_regression.py
+++ b/examples/src/main/python/ml/aft_survival_regression.py
@@ -17,19 +17,26 @@
from __future__ import print_function
-from pyspark import SparkContext
-from pyspark.sql import SQLContext
# $example on$
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.mllib.linalg import Vectors
# $example off$
+from pyspark.sql import SparkSession
+
+"""
+An example demonstrating aft survival regression.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py
+"""
if __name__ == "__main__":
- sc = SparkContext(appName="AFTSurvivalRegressionExample")
- sqlContext = SQLContext(sc)
+ spark = SparkSession \
+ .builder \
+ .appName("PythonAFTSurvivalRegressionExample") \
+ .getOrCreate()
# $example on$
- training = sqlContext.createDataFrame([
+ training = spark.createDataFrame([
(1.218, 1.0, Vectors.dense(1.560, -0.605)),
(2.949, 0.0, Vectors.dense(0.346, 2.158)),
(3.627, 0.0, Vectors.dense(1.380, 0.231)),
@@ -48,4 +55,4 @@ if __name__ == "__main__":
model.transform(training).show(truncate=False)
# $example off$
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py
new file mode 100644
index 0000000000..6ca56adf3c
--- /dev/null
+++ b/examples/src/main/python/ml/lda_example.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from __future__ import print_function
+
+# $example on$
+from pyspark.ml.clustering import LDA
+# $example off$
+from pyspark.sql import SparkSession
+
+
+"""
+An example demonstrating LDA.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/lda_example.py
+"""
+
+
+if __name__ == "__main__":
+ # Creates a SparkSession
+ spark = SparkSession \
+ .builder \
+ .appName("PythonKMeansExample") \
+ .getOrCreate()
+
+ # $example on$
+ # Loads data.
+ dataset = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")
+
+ # Trains a LDA model.
+ lda = LDA(k=10, maxIter=10)
+ model = lda.fit(dataset)
+
+ ll = model.logLikelihood(dataset)
+ lp = model.logPerplexity(dataset)
+ print("The lower bound on the log likelihood of the entire corpus: " + str(ll))
+ print("The upper bound bound on perplexity: " + str(lp))
+
+ # Describe topics.
+ topics = model.describeTopics(3)
+ print("The topics described by their top-weighted terms:")
+ topics.show(truncate=False)
+
+ # Shows the result
+ transformed = model.transform(dataset)
+ transformed.show(truncate=False)
+ # $example off$
+
+ spark.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
index 2b224d50a0..b44304d810 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
@@ -25,7 +25,11 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SparkSession
/**
- * An example for AFTSurvivalRegression.
+ * An example demonstrating AFTSurvivalRegression.
+ * Run with
+ * {{{
+ * bin/run-example ml.AFTSurvivalRegressionExample
+ * }}}
*/
object AFTSurvivalRegressionExample {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala
index c2920f6a5d..22b3b0e3ad 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala
@@ -20,57 +20,48 @@ package org.apache.spark.examples.ml
// scalastyle:off println
// $example on$
import org.apache.spark.ml.clustering.LDA
-import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
-import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.types.{StructField, StructType}
// $example off$
+import org.apache.spark.sql.SparkSession
/**
- * An example demonstrating a LDA of ML pipeline.
+ * An example demonstrating LDA.
* Run with
* {{{
* bin/run-example ml.LDAExample
* }}}
*/
object LDAExample {
-
- final val FEATURES_COL = "features"
-
def main(args: Array[String]): Unit = {
-
- val input = "data/mllib/sample_lda_data.txt"
- // Creates a Spark context and a SQL context
+ // Creates a SparkSession
val spark = SparkSession
.builder
.appName(s"${this.getClass.getSimpleName}")
.getOrCreate()
// $example on$
- // Loads data
- val rowRDD = spark.read.text(input).rdd.filter(_.nonEmpty)
- .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
- val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
- val dataset = spark.createDataFrame(rowRDD, schema)
+ // Loads data.
+ val dataset = spark.read.format("libsvm")
+ .load("data/mllib/sample_lda_libsvm_data.txt")
- // Trains a LDA model
- val lda = new LDA()
- .setK(10)
- .setMaxIter(10)
- .setFeaturesCol(FEATURES_COL)
+ // Trains a LDA model.
+ val lda = new LDA().setK(10).setMaxIter(10)
val model = lda.fit(dataset)
- val transformed = model.transform(dataset)
val ll = model.logLikelihood(dataset)
val lp = model.logPerplexity(dataset)
+ println(s"The lower bound on the log likelihood of the entire corpus: $ll")
+ println(s"The upper bound bound on perplexity: $lp")
- // describeTopics
+ // Describe topics.
val topics = model.describeTopics(3)
-
- // Shows the result
+ println("The topics described by their top-weighted terms:")
topics.show(false)
- transformed.show(false)
+ // Shows the result.
+ val transformed = model.transform(dataset)
+ transformed.show(false)
// $example off$
+
spark.stop()
}
}