about | summary | refs | log | tree | commit | diff
path: root/examples/src/main/java
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-11-12 21:29:43 -0800
committer Xiangrui Meng <meng@databricks.com> 2015-11-12 21:29:43 -0800
commit ea5ae2705afa4eaadd4192c37d74c97364378cf9 (patch)
tree 7ebd53577fba8330c0a636b7de85f3a95f4ea7bb /examples/src/main/java
parent 2035ed392e0a9c18ff9c176a7b0f0097ed1276df (diff)
download spark-ea5ae2705afa4eaadd4192c37d74c97364378cf9.tar.gz
spark-ea5ae2705afa4eaadd4192c37d74c97364378cf9.tar.bz2
spark-ea5ae2705afa4eaadd4192c37d74c97364378cf9.zip
[SPARK-11629][ML][PYSPARK][DOC] Python example code for Multilayer Perceptron Classification
Add Python example code for Multilayer Perceptron Classification, and make example code in user guide document testable. mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #9594 from yanboliang/spark-11629.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java 74
1 files changed, 74 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
new file mode 100644
index 0000000000..f48e1339c5
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+// $example on$
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
+import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.DataFrame;
+// $example off$
+
+/**
+ * Trains a multilayer perceptron classifier on the sample multiclass data set
+ * and prints the precision measured on a held-out test split.
+ */
+public class JavaMultilayerPerceptronClassifierExample {
+
+  public static void main(String[] args) {
+    JavaSparkContext jsc = new JavaSparkContext(
+      new SparkConf().setAppName("JavaMultilayerPerceptronClassifierExample"));
+    SQLContext sqlContext = new SQLContext(jsc);
+
+    // $example on$
+    // Read the LIBSVM-formatted samples and expose them as a DataFrame
+    String inputPath = "data/mllib/sample_multiclass_classification_data.txt";
+    JavaRDD<LabeledPoint> points = MLUtils.loadLibSVMFile(jsc.sc(), inputPath).toJavaRDD();
+    DataFrame samples = sqlContext.createDataFrame(points, LabeledPoint.class);
+    // Randomly split with weights 0.6 (training) and 0.4 (testing), fixed seed
+    DataFrame[] parts = samples.randomSplit(new double[]{0.6, 0.4}, 1234L);
+    DataFrame training = parts[0];
+    DataFrame testing = parts[1];
+    // Network topology: 4 input features, hidden layers of 5 and 4 units,
+    // and 3 output classes
+    int[] topology = {4, 5, 4, 3};
+    // Configure the trainer one parameter at a time
+    MultilayerPerceptronClassifier mlp = new MultilayerPerceptronClassifier();
+    mlp.setLayers(topology);
+    mlp.setBlockSize(128);
+    mlp.setSeed(1234L);
+    mlp.setMaxIter(100);
+    // Fit the network on the training split
+    MultilayerPerceptronClassificationModel fitted = mlp.fit(training);
+    // Score the test split, keeping only the prediction and label columns
+    DataFrame scored = fitted.transform(testing).select("prediction", "label");
+    MulticlassClassificationEvaluator precisionEval = new MulticlassClassificationEvaluator();
+    precisionEval.setMetricName("precision");
+    double precision = precisionEval.evaluate(scored);
+    System.out.println("Precision = " + precision);
+    // $example off$
+
+    jsc.stop();
+  }
+}