Diffstat (limited to 'site/docs/1.6.0/ml-classification-regression.html')
-rw-r--r-- | site/docs/1.6.0/ml-classification-regression.html | 2595 |
1 files changed, 2595 insertions, 0 deletions
diff --git a/site/docs/1.6.0/ml-classification-regression.html b/site/docs/1.6.0/ml-classification-regression.html new file mode 100644 index 000000000..0b8a9b418 --- /dev/null +++ b/site/docs/1.6.0/ml-classification-regression.html @@ -0,0 +1,2595 @@ + +<!DOCTYPE html> +<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]--> +<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]--> +<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]--> +<!--[if gt IE 8]><!--> <html class="no-js"> <!--<![endif]--> + <head> + <meta charset="utf-8"> + <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1"> + <title>Classification and regression - spark.ml - Spark 1.6.0 Documentation</title> + + + + + <link rel="stylesheet" href="css/bootstrap.min.css"> + <style> + body { + padding-top: 60px; + padding-bottom: 40px; + } + </style> + <meta name="viewport" content="width=device-width"> + <link rel="stylesheet" href="css/bootstrap-responsive.min.css"> + <link rel="stylesheet" href="css/main.css"> + + <script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script> + + <link rel="stylesheet" href="css/pygments-default.css"> + + + <!-- Google analytics script --> + <script type="text/javascript"> + var _gaq = _gaq || []; + _gaq.push(['_setAccount', 'UA-32518208-2']); + _gaq.push(['_trackPageview']); + + (function() { + var ga = document.createElement('script'); ga.type = 'text/javascript'; ga.async = true; + ga.src = ('https:' == document.location.protocol ? 'https://ssl' : 'http://www') + '.google-analytics.com/ga.js'; + var s = document.getElementsByTagName('script')[0]; s.parentNode.insertBefore(ga, s); + })(); + </script> + + + </head> + <body> + <!--[if lt IE 7]> + <p class="chromeframe">You are using an outdated browser. 
<a href="http://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p> + <![endif]--> + + <!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html --> + + <div class="navbar navbar-fixed-top" id="topbar"> + <div class="navbar-inner"> + <div class="container"> + <div class="brand"><a href="index.html"> + <img src="img/spark-logo-hd.png" style="height:50px;"/></a><span class="version">1.6.0</span> + </div> + <ul class="nav"> + <!--TODO(andyk): Add class="active" attribute to li some how.--> + <li><a href="index.html">Overview</a></li> + + <li class="dropdown"> + <a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a> + <ul class="dropdown-menu"> + <li><a href="quick-start.html">Quick Start</a></li> + <li><a href="programming-guide.html">Spark Programming Guide</a></li> + <li class="divider"></li> + <li><a href="streaming-programming-guide.html">Spark Streaming</a></li> + <li><a href="sql-programming-guide.html">DataFrames, Datasets and SQL</a></li> + <li><a href="mllib-guide.html">MLlib (Machine Learning)</a></li> + <li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li> + <li><a href="bagel-programming-guide.html">Bagel (Pregel on Spark)</a></li> + <li><a href="sparkr.html">SparkR (R on Spark)</a></li> + </ul> + </li> + + <li class="dropdown"> + <a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a> + <ul class="dropdown-menu"> + <li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li> + <li><a href="api/java/index.html">Java</a></li> + <li><a href="api/python/index.html">Python</a></li> + <li><a href="api/R/index.html">R</a></li> + </ul> + </li> + + <li class="dropdown"> + <a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a> + <ul class="dropdown-menu"> 
+ <li><a href="cluster-overview.html">Overview</a></li> + <li><a href="submitting-applications.html">Submitting Applications</a></li> + <li class="divider"></li> + <li><a href="spark-standalone.html">Spark Standalone</a></li> + <li><a href="running-on-mesos.html">Mesos</a></li> + <li><a href="running-on-yarn.html">YARN</a></li> + <li class="divider"></li> + <li><a href="ec2-scripts.html">Amazon EC2</a></li> + </ul> + </li> + + <li class="dropdown"> + <a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a> + <ul class="dropdown-menu"> + <li><a href="configuration.html">Configuration</a></li> + <li><a href="monitoring.html">Monitoring</a></li> + <li><a href="tuning.html">Tuning Guide</a></li> + <li><a href="job-scheduling.html">Job Scheduling</a></li> + <li><a href="security.html">Security</a></li> + <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li> + <li class="divider"></li> + <li><a href="building-spark.html">Building Spark</a></li> + <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li> + <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects">Supplemental Projects</a></li> + </ul> + </li> + </ul> + <!--<p class="navbar-text pull-right"><span class="version-text">v1.6.0</span></p>--> + </div> + </div> + </div> + + <div class="container-wrapper"> + + + <div class="left-menu-wrapper"> + <div class="left-menu"> + <h3><a href="ml-guide.html">spark.ml package</a></h3> + +<ul> + + <li> + <a href="ml-guide.html"> + + Overview: estimators, transformers and pipelines + + </a> + </li> + + + <li> + <a href="ml-features.html"> + + Extracting, transforming and selecting features + + </a> + </li> + + + <li> + <a href="ml-classification-regression.html"> + + <b>Classification and Regression</b> + + </a> + </li> + + + <li> + <a href="ml-clustering.html"> + + Clustering + + </a> + </li> + + + <li> + <a 
href="ml-advanced.html"> + + Advanced topics + + </a> + </li> + + +</ul> + + <h3><a href="mllib-guide.html">spark.mllib package</a></h3> + +<ul> + + <li> + <a href="mllib-data-types.html"> + + Data types + + </a> + </li> + + + <li> + <a href="mllib-statistics.html"> + + Basic statistics + + </a> + </li> + + + <li> + <a href="mllib-classification-regression.html"> + + Classification and regression + + </a> + </li> + + + <li> + <a href="mllib-collaborative-filtering.html"> + + Collaborative filtering + + </a> + </li> + + + <li> + <a href="mllib-clustering.html"> + + Clustering + + </a> + </li> + + + <li> + <a href="mllib-dimensionality-reduction.html"> + + Dimensionality reduction + + </a> + </li> + + + <li> + <a href="mllib-feature-extraction.html"> + + Feature extraction and transformation + + </a> + </li> + + + <li> + <a href="mllib-frequent-pattern-mining.html"> + + Frequent pattern mining + + </a> + </li> + + + <li> + <a href="mllib-evaluation-metrics.html"> + + Evaluation metrics + + </a> + </li> + + + <li> + <a href="mllib-pmml-model-export.html"> + + PMML model export + + </a> + </li> + + + <li> + <a href="mllib-optimization.html"> + + Optimization (developer) + + </a> + </li> + + +</ul> + + </div> +</div> + <input id="nav-trigger" class="nav-trigger" checked type="checkbox"> + <label for="nav-trigger"></label> + <div class="content-with-sidebar" id="content"> + + <h1 class="title">Classification and regression - spark.ml</h1> + + + <p><code>\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]</code></p> + +<p><strong>Table of Contents</strong></p> + +<ul id="markdown-toc"> + <li><a 
href="#classification" id="markdown-toc-classification">Classification</a> <ul> + <li><a href="#logistic-regression" id="markdown-toc-logistic-regression">Logistic regression</a></li> + <li><a href="#decision-tree-classifier" id="markdown-toc-decision-tree-classifier">Decision tree classifier</a></li> + <li><a href="#random-forest-classifier" id="markdown-toc-random-forest-classifier">Random forest classifier</a></li> + <li><a href="#gradient-boosted-tree-classifier" id="markdown-toc-gradient-boosted-tree-classifier">Gradient-boosted tree classifier</a></li> + <li><a href="#multilayer-perceptron-classifier" id="markdown-toc-multilayer-perceptron-classifier">Multilayer perceptron classifier</a></li> + <li><a href="#one-vs-rest-classifier-aka-one-vs-all" id="markdown-toc-one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</a></li> + </ul> + </li> + <li><a href="#regression" id="markdown-toc-regression">Regression</a> <ul> + <li><a href="#linear-regression" id="markdown-toc-linear-regression">Linear regression</a></li> + <li><a href="#decision-tree-regression" id="markdown-toc-decision-tree-regression">Decision tree regression</a></li> + <li><a href="#random-forest-regression" id="markdown-toc-random-forest-regression">Random forest regression</a></li> + <li><a href="#gradient-boosted-tree-regression" id="markdown-toc-gradient-boosted-tree-regression">Gradient-boosted tree regression</a></li> + <li><a href="#survival-regression" id="markdown-toc-survival-regression">Survival regression</a></li> + </ul> + </li> + <li><a href="#decision-trees" id="markdown-toc-decision-trees">Decision trees</a> <ul> + <li><a href="#inputs-and-outputs" id="markdown-toc-inputs-and-outputs">Inputs and Outputs</a> <ul> + <li><a href="#input-columns" id="markdown-toc-input-columns">Input Columns</a></li> + <li><a href="#output-columns" id="markdown-toc-output-columns">Output Columns</a></li> + </ul> + </li> + </ul> + </li> + <li><a href="#tree-ensembles" 
id="markdown-toc-tree-ensembles">Tree Ensembles</a> <ul> + <li><a href="#random-forests" id="markdown-toc-random-forests">Random Forests</a> <ul> + <li><a href="#inputs-and-outputs-1" id="markdown-toc-inputs-and-outputs-1">Inputs and Outputs</a> <ul> + <li><a href="#input-columns-1" id="markdown-toc-input-columns-1">Input Columns</a></li> + <li><a href="#output-columns-predictions" id="markdown-toc-output-columns-predictions">Output Columns (Predictions)</a></li> + </ul> + </li> + </ul> + </li> + <li><a href="#gradient-boosted-trees-gbts" id="markdown-toc-gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</a> <ul> + <li><a href="#inputs-and-outputs-2" id="markdown-toc-inputs-and-outputs-2">Inputs and Outputs</a> <ul> + <li><a href="#input-columns-2" id="markdown-toc-input-columns-2">Input Columns</a></li> + <li><a href="#output-columns-predictions-1" id="markdown-toc-output-columns-predictions-1">Output Columns (Predictions)</a></li> + </ul> + </li> + </ul> + </li> + </ul> + </li> +</ul> + +<p>In <code>spark.ml</code>, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to <a href="mllib-linear-methods.html">the linear methods in mllib</a> for +details about implementation and tuning. We also include a DataFrame API for <a href="http://en.wikipedia.org/wiki/Elastic_net_regularization">Elastic +net</a>, a hybrid +of $L_1$ and $L_2$ regularization proposed in <a href="http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf">Zou et al, Regularization +and variable selection via the elastic +net</a>. +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +<code>\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]</code> +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. 
For example, if a <a href="https://en.wikipedia.org/wiki/Linear_regression">linear +regression</a> model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +<a href="http://en.wikipedia.org/wiki/Least_squares#Lasso_method">Lasso</a> model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a <a href="http://en.wikipedia.org/wiki/Tikhonov_regularization">ridge +regression</a> model. +We implement the Pipelines API for both linear regression and logistic +regression with elastic net regularization.</p> + +<h1 id="classification">Classification</h1> + +<h2 id="logistic-regression">Logistic regression</h2> + +<p>Logistic regression is a popular method to predict a binary response. It is a special case of <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">generalized linear models</a> that predicts the probability of the outcome. +For more background and details about the implementation, refer to the documentation of the <a href="mllib-linear-methods.html#logistic-regression">logistic regression in <code>spark.mllib</code></a>.</p> + +<blockquote> + <p>The current implementation of logistic regression in <code>spark.ml</code> only supports binary classes. Support for multiclass classification will be added in the future.</p> +</blockquote> + +<p><strong>Example</strong></p> + +<p>The following example shows how to train a logistic regression model +with elastic net regularization. 
<code>elasticNetParam</code> corresponds to +$\alpha$ and <code>regParam</code> corresponds to $\lambda$.</p> + +<div class="codetabs"> + +<div data-lang="scala"> + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> + +<span class="c1">// Load training data</span> +<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sqlCtx</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> + <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> + <span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> + +<span class="c1">// Fit the model</span> +<span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> + +<span class="c1">// Print the coefficients and intercept for logistic regression</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}"</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at 
"examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load training data</span> +<span class="n">DataFrame</span> <span class="n">training</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> + <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span> + <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> + <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> + +<span class="c1">// Fit the model</span> +<span class="n">LogisticRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span 
class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> + +<span class="c1">// Print the coefficients and intercept for logistic regression</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> + <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> + +<span class="c"># Load training data</span> +<span class="n">training</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span 
class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> + +<span class="c"># Fit the model</span> +<span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> + +<span class="c"># Print the coefficients and intercept for logistic regression</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/logistic_regression_with_elastic_net.py" in the Spark repo.</small></div> + </div> + +</div> + +<p>The <code>spark.ml</code> implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as <code>DataFrame</code> in +<code>BinaryLogisticRegressionSummary</code> are annotated <code>@transient</code> and hence +only available on the driver.</p> + +<div class="codetabs"> + +<div data-lang="scala"> + + <p><a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary"><code>LogisticRegressionTrainingSummary</code></a> +provides a summary for a +<a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel"><code>LogisticRegressionModel</code></a>. 
+Currently, only binary classification is supported and the +summary must be explicitly cast to +<a href="api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"><code>BinaryLogisticRegressionTrainingSummary</code></a>. +This will likely change when multiclass classification is supported.</p> + + <p>Continuing the earlier example:</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">BinaryLogisticRegressionSummary</span><span class="o">,</span> <span class="nc">LogisticRegression</span><span class="o">}</span> + +<span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span> +<span class="c1">// example</span> +<span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> + +<span class="c1">// Obtain the objective per iteration.</span> +<span class="k">val</span> <span class="n">objectiveHistory</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span> +<span class="n">objectiveHistory</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">loss</span> <span class="k">=></span> <span class="n">println</span><span class="o">(</span><span class="n">loss</span><span class="o">))</span> + +<span class="c1">// Obtain the metrics useful to judge performance on test data.</span> +<span class="c1">// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a</span> +<span class="c1">// binary classification problem.</span> +<span class="k">val</span> <span class="n">binarySummary</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">asInstanceOf</span><span 
class="o">[</span><span class="kt">BinaryLogisticRegressionSummary</span><span class="o">]</span> + +<span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> +<span class="k">val</span> <span class="n">roc</span> <span class="k">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="n">roc</span> +<span class="n">roc</span><span class="o">.</span><span class="n">show</span><span class="o">()</span> +<span class="n">println</span><span class="o">(</span><span class="n">binarySummary</span><span class="o">.</span><span class="n">areaUnderROC</span><span class="o">)</span> + +<span class="c1">// Set the model threshold to maximize F-Measure</span> +<span class="k">val</span> <span class="n">fMeasure</span> <span class="k">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="n">fMeasureByThreshold</span> +<span class="k">val</span> <span class="n">maxFMeasure</span> <span class="k">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="n">max</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">)).</span><span class="n">head</span><span class="o">().</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> +<span class="k">val</span> <span class="n">bestThreshold</span> <span class="k">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">where</span><span class="o">(</span><span class="n">$</span><span class="s">"F-Measure"</span> <span class="o">===</span> <span class="n">maxFMeasure</span><span class="o">)</span> + <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"threshold"</span><span class="o">).</span><span class="n">head</span><span class="o">().</span><span class="n">getDouble</span><span class="o">(</span><span 
class="mi">0</span><span class="o">)</span> +<span class="n">lrModel</span><span class="o">.</span><span class="n">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + <p><a href="api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html"><code>LogisticRegressionTrainingSummary</code></a> +provides a summary for a +<a href="api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html"><code>LogisticRegressionModel</code></a>. +Currently, only binary classification is supported and the +summary must be explicitly cast to +<a href="api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html"><code>BinaryLogisticRegressionTrainingSummary</code></a>. +This will likely change when multiclass classification is supported.</p> + + <p>Continuing the earlier example:</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.BinaryLogisticRegressionSummary</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionTrainingSummary</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> +<span class="kn">import</span> <span 
class="nn">org.apache.spark.sql.functions</span><span class="o">;</span> + +<span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span> +<span class="c1">// example</span> +<span class="n">LogisticRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> + +<span class="c1">// Obtain the loss per iteration.</span> +<span class="kt">double</span><span class="o">[]</span> <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">();</span> +<span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">lossPerIteration</span> <span class="o">:</span> <span class="n">objectiveHistory</span><span class="o">)</span> <span class="o">{</span> + <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lossPerIteration</span><span class="o">);</span> +<span class="o">}</span> + +<span class="c1">// Obtain the metrics useful to judge performance on test data.</span> +<span class="c1">// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary</span> +<span class="c1">// classification problem.</span> +<span class="n">BinaryLogisticRegressionSummary</span> <span class="n">binarySummary</span> <span class="o">=</span> + <span class="o">(</span><span class="n">BinaryLogisticRegressionSummary</span><span class="o">)</span> <span class="n">trainingSummary</span><span class="o">;</span> + +<span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> +<span class="n">DataFrame</span> <span class="n">roc</span> <span 
class="o">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="na">roc</span><span class="o">();</span> +<span class="n">roc</span><span class="o">.</span><span class="na">show</span><span class="o">();</span> +<span class="n">roc</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"FPR"</span><span class="o">).</span><span class="na">show</span><span class="o">();</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">binarySummary</span><span class="o">.</span><span class="na">areaUnderROC</span><span class="o">());</span> + +<span class="c1">// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with</span> +<span class="c1">// this selected threshold.</span> +<span class="n">DataFrame</span> <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="na">fMeasureByThreshold</span><span class="o">();</span> +<span class="kt">double</span> <span class="n">maxFMeasure</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="n">functions</span><span class="o">.</span><span class="na">max</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">)).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">bestThreshold</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">where</span><span class="o">(</span><span class="n">fMeasure</span><span class="o">.</span><span class="na">col</span><span class="o">(</span><span 
class="s">"F-Measure"</span><span class="o">).</span><span class="na">equalTo</span><span class="o">(</span><span class="n">maxFMeasure</span><span class="o">))</span> + <span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"threshold"</span><span class="o">).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span> +<span class="n">lrModel</span><span class="o">.</span><span class="na">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">);</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java" in the Spark repo.</small></div> + </div> + +<!--- TODO: Add python model summaries once implemented --> +<div data-lang="python"> + <p>Logistic regression model summary is not yet supported in Python.</p> + </div> + +</div> + +<h2 id="decision-tree-classifier">Decision tree classifier</h2> + +<p>Decision trees are a popular family of classification and regression methods. +More information about the <code>spark.ml</code> implementation can be found further in the <a href="#decision-trees">section on decision trees</a>.</p> + +<p><strong>Example</strong></p> + +<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. 
+We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code>DataFrame</code> which the Decision Tree algorithm can recognize.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier">Scala API documentation</a>.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> + +<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="c1">// Index labels, adding metadata to the label column.</span> +<span class="c1">// Fit on whole dataset to include all labels in index.</span> +<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span 
class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with > 4 distinct values are treated as continuous</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span 
class="o">))</span> + +<span class="c1">// Train a DecisionTree model.</span> +<span class="k">val</span> <span class="n">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + +<span class="c1">// Convert indexed labels back to original labels.</span> +<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span> + +<span class="c1">// Chain indexers and tree in a Pipeline</span> +<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> + +<span class="c1">// Train model. 
This also runs the indexers.</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> + +<span class="c1">// Make predictions.</span> +<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span 
class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span> + +<span class="k">val</span> <span class="n">treeModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeClassificationModel</span><span class="o">]</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Learned classification tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala" in the Spark repo.</small></div> + + </div> + +<div data-lang="java"> + + <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html">Java API documentation</a>.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> +<span class="kn">import</span> <span 
class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="c1">// Index labels, adding metadata to the label column.</span> +<span class="c1">// Fit on whole dataset to include all labels in index.</span> +<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span 
class="n">data</span><span class="o">);</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with > 4 distinct values are treated as continuous</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> + +<span class="c1">// Train a DecisionTree model.</span> +<span class="n">DecisionTreeClassifier</span> 
<span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">DecisionTreeClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> + +<span class="c1">// Convert indexed labels back to original labels.</span> +<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span> + +<span class="c1">// Chain indexers and tree in a Pipeline</span> +<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> + +<span class="c1">// Train model. 
This also runs the indexers.</span> +<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> + +<span class="c1">// Make predictions.</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span 
class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> + +<span class="n">DecisionTreeClassificationModel</span> <span class="n">treeModel</span> <span class="o">=</span> + <span class="o">(</span><span class="n">DecisionTreeClassificationModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java" in the Spark repo.</small></div> + + </div> + +<div data-lang="python"> + + <p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier">Python API documentation</a>.</p> + + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">SparkContext</span><span class="p">,</span> <span class="n">SQLContext</span> +<span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> +<span class="kn">from</span> <span 
class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">DecisionTreeClassifier</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> + +<span class="c"># Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="c"># Index labels, adding metadata to the label column.</span> +<span class="c"># Fit on whole dataset to include all labels in index.</span> +<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> +<span class="c"># Automatically identify categorical features, and index them.</span> +<span class="c"># We specify maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">featureIndexer</span> <span class="o">=</span>\ + <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span 
class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + +<span class="c"># Split the data into training and test sets (30% held out for testing)</span> +<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> + +<span class="c"># Train a DecisionTree model.</span> +<span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> + +<span class="c"># Chain indexers and tree in a Pipeline</span> +<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span> + +<span class="c"># Train model. 
This also runs the indexers.</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> + +<span class="c"># Make predictions.</span> +<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> + +<span class="c"># Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"indexedLabel"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> + +<span class="c"># Select (prediction, true label) and compute test error</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> + <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"precision"</span><span class="p">)</span> +<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Test Error = </span><span class="si">%g</span><span class="s"> "</span> <span class="o">%</span> <span 
class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> + +<span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> +<span class="c"># summary only</span> +<span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/decision_tree_classification_example.py" in the Spark repo.</small></div> + + </div> + +</div> + +<h2 id="random-forest-classifier">Random forest classifier</h2> + +<p>Random forests are a popular family of classification and regression methods. +More information about the <code>spark.ml</code> implementation can be found further in the <a href="#random-forests">section on random forests</a>.</p> + +<p><strong>Example</strong></p> + +<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
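</p>

<p>As a plain-Python illustration of the ensemble idea behind a random forest (not Spark API code; the helper names below are ours): each tree is trained on a bootstrap sample of the data, and the forest predicts by majority vote over the trees:</p>

```python
# Illustrative sketch (not Spark code): the two ingredients of a random
# forest -- bootstrap sampling for each tree's training set, and majority
# voting to combine the individual trees' class predictions.
import random
from collections import Counter

def bootstrap_sample(rows, rng):
    """Sample len(rows) rows with replacement (one tree's training set)."""
    return [rows[rng.randrange(len(rows))] for _ in rows]

def majority_vote(predictions):
    """Combine per-tree class predictions into the forest's prediction."""
    return Counter(predictions).most_common(1)[0][0]
```

<p>Spark additionally subsamples the features considered at each split, which further decorrelates the trees; the sketch omits that detail.</p>

<p>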
+We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code>DataFrame</code> which the tree-based algorithms can recognize.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier">Scala API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">RandomForestClassificationModel</span><span class="o">,</span> <span class="nc">RandomForestClassifier</span><span class="o">}</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="c1">// Index labels, adding metadata to the label column.</span> +<span class="c1">// Fit on whole dataset to include all labels in index.</span> +<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span 
class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span 
class="mf">0.3</span><span class="o">))</span> + +<span class="c1">// Train a RandomForest model.</span> +<span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setNumTrees</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + +<span class="c1">// Convert indexed labels back to original labels.</span> +<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span> + +<span class="c1">// Chain indexers and forest in a Pipeline</span> +<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span 
class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> + +<span class="c1">// Train model. This also runs the indexers.</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> + +<span class="c1">// Make predictions.</span> +<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span 
class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span> + +<span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestClassificationModel</span><span class="o">]</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Learned classification forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + + <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/RandomForestClassifier.html">Java API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassificationModel</span><span class="o">;</span> +<span class="kn">import</span> <span 
class="nn">org.apache.spark.ml.classification.RandomForestClassifier</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="c1">// Index labels, adding metadata to the label column.</span> +<span class="c1">// Fit on whole dataset to include all labels in index.</span> +<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set 
maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> + +<span class="c1">// Train a RandomForest model.</span> +<span class="n">RandomForestClassifier</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span 
class="nf">RandomForestClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> + +<span class="c1">// Convert indexed labels back to original labels.</span> +<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span> + +<span class="c1">// Chain indexers and forest in a Pipeline</span> +<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> + +<span class="c1">// Train model. 
This also runs the indexers.</span> +<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> + +<span class="c1">// Make predictions.</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span 
class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> + +<span class="n">RandomForestClassificationModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">RandomForestClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + + <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier">Python API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span 
class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> + +<span class="c"># Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="c"># Index labels, adding metadata to the label column.</span> +<span class="c"># Fit on whole dataset to include all labels in index.</span> +<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> +<span class="c"># Automatically identify categorical features, and index them.</span> +<span class="c"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">featureIndexer</span> <span class="o">=</span>\ + <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span 
class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + +<span class="c"># Split the data into training and test sets (30% held out for testing)</span> +<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> + +<span class="c"># Train a RandomForest model.</span> +<span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> + +<span class="c"># Chain indexers and forest in a Pipeline</span> +<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span> + +<span class="c"># Train model. 
This also runs the indexers.</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> + +<span class="c"># Make predictions.</span> +<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> + +<span class="c"># Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"indexedLabel"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> + +<span class="c"># Select (prediction, true label) and compute test error</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> + <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"precision"</span><span class="p">)</span> +<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Test Error = </span><span class="si">%g</span><span class="s">"</span> <span class="o">%</span> <span 
class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> + +<span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> +<span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c"># summary only</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/random_forest_classifier_example.py" in the Spark repo.</small></div> + </div> +</div> + +<h2 id="gradient-boosted-tree-classifier">Gradient-boosted tree classifier</h2> + +<p>Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. +More information about the <code>spark.ml</code> implementation can be found in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p> + +<p><strong>Example</strong></p> + +<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. 
+We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code>DataFrame</code> which the tree-based algorithms can recognize.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier">Scala API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">GBTClassificationModel</span><span class="o">,</span> <span class="nc">GBTClassifier</span><span class="o">}</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="c1">// Index labels, adding metadata to the label column.</span> +<span class="c1">// Fit on whole dataset to include all labels in index.</span> +<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span 
class="nc">StringIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> 
+ +<span class="c1">// Train a GBT model.</span> +<span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + +<span class="c1">// Convert indexed labels back to original labels.</span> +<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span> + +<span class="c1">// Chain indexers and GBT in a Pipeline</span> +<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span 
class="o">))</span> + +<span class="c1">// Train model. This also runs the indexers.</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> + +<span class="c1">// Make predictions.</span> +<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span 
class="o">(</span><span class="n">predictions</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span> + +<span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTClassificationModel</span><span class="o">]</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Learned classification GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + + <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/GBTClassifier.html">Java API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassificationModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassifier</span><span class="o">;</span> +<span class="kn">import</span> <span 
class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="c1">// Index labels, adding metadata to the label column.</span> +<span class="c1">// Fit on whole dataset to include all labels in index.</span> +<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">VectorIndexerModel</span> <span 
class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> + +<span class="c1">// Train a GBT model.</span> +<span class="n">GBTClassifier</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">GBTClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span 
class="o">)</span> + <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span> + +<span class="c1">// Convert indexed labels back to original labels.</span> +<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span> + +<span class="c1">// Chain indexers and GBT in a Pipeline</span> +<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> + +<span class="c1">// Train model. 
This also runs the indexers.</span> +<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> + +<span class="c1">// Make predictions.</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span 
class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> + +<span class="n">GBTClassificationModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">GBTClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + + <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier">Python API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">GBTClassifier</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> 
<span class="n">VectorIndexer</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> + +<span class="c"># Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="c"># Index labels, adding metadata to the label column.</span> +<span class="c"># Fit on whole dataset to include all labels in index.</span> +<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> +<span class="c"># Automatically identify categorical features, and index them.</span> +<span class="c"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">featureIndexer</span> <span class="o">=</span>\ + <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span 
class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + +<span class="c"># Split the data into training and test sets (30% held out for testing)</span> +<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> + +<span class="c"># Train a GBT model.</span> +<span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> + +<span class="c"># Chain indexers and GBT in a Pipeline</span> +<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span> + +<span class="c"># Train model. 
This also runs the indexers.</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> + +<span class="c"># Make predictions.</span> +<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> + +<span class="c"># Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"indexedLabel"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> + +<span class="c"># Select (prediction, true label) and compute test error</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> + <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"precision"</span><span class="p">)</span> +<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Test Error = </span><span class="si">%g</span><span class="s">"</span> <span class="o">%</span> <span 
class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> + +<span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> +<span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c"># summary only</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py" in the Spark repo.</small></div> + </div> +</div> + +<h2 id="multilayer-perceptron-classifier">Multilayer perceptron classifier</h2> + +<p>Multilayer perceptron classifier (MLPC) is a classifier based on the <a href="https://en.wikipedia.org/wiki/Feedforward_neural_network">feedforward artificial neural network</a>. +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs +by performing a linear combination of the inputs with the node’s weights <code>$\wv$</code> and bias <code>$\bv$</code> and applying an activation function. +It can be written in matrix form for MLPC with <code>$K+1$</code> layers as follows: +<code>\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]</code> +Nodes in intermediate layers use the sigmoid (logistic) function: +<code>\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]</code> +Nodes in the output layer use the softmax function: +<code>\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]</code> +The number of nodes <code>$N$</code> in the output layer corresponds to the number of classes.</p> + +<p>MLPC employs backpropagation for learning the model. 
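</p>

<p>As a concrete illustration of the formulas above, the forward pass they describe can be sketched in a few lines of plain Python (hypothetical toy weights, not Spark API code): a sigmoid activation on the intermediate layer and a softmax on the output layer.</p>

```python
import math

# Sketch of an MLPC-style forward pass: sigmoid on intermediate layers,
# softmax on the output layer. The weights below are arbitrary toy values.

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(zs):
    exps = [math.exp(z) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

def layer(weights, bias, x):
    # z_j = sum_i w[j][i] * x[i] + b[j]
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]

x = [0.5, -1.0]                                  # input layer (2 features)
w1, b1 = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.4]], [0.0, 0.1, -0.2]
hidden = [sigmoid(z) for z in layer(w1, b1, x)]  # intermediate layer (3 nodes)
w2, b2 = [[0.2, -0.3, 0.1], [-0.1, 0.4, 0.2]], [0.05, -0.05]
probs = softmax(layer(w2, b2, hidden))           # output layer (2 classes)
print(probs)  # class probabilities, summing to 1
```

<p>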
We use the logistic loss function for optimization and L-BFGS as the optimization routine.</p> + +<p><strong>Example</strong></p> + +<div class="codetabs"> + +<div data-lang="scala"> + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> + +<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> + <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> +<span class="c1">// Split the data into train and test</span> +<span class="k">val</span> <span class="n">splits</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">1234L</span><span class="o">)</span> +<span class="k">val</span> <span class="n">train</span> <span class="k">=</span> <span class="n">splits</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> +<span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">splits</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> +<span class="c1">// specify layers for the neural network:</span> +<span class="c1">// input layer of 
size 4 (features), two intermediate of size 5 and 4</span> +<span class="c1">// and output of size 3 (classes)</span> +<span class="k">val</span> <span class="n">layers</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">[</span><span class="kt">Int</span><span class="o">](</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">)</span> +<span class="c1">// create the trainer and set its parameters</span> +<span class="k">val</span> <span class="n">trainer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span> + <span class="o">.</span><span class="n">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span> + <span class="o">.</span><span class="n">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">)</span> +<span class="c1">// train the model</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span> +<span class="c1">// compute precision on the test set</span> +<span class="k">val</span> <span class="n">result</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> +<span class="k">val</span> <span class="n">predictionAndLabels</span> <span class="k">=</span> <span 
class="n">result</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Precision:"</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">))</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> 
+ +<span class="c1">// Load training data</span> +<span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">;</span> +<span class="n">DataFrame</span> <span class="n">dataFrame</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="n">path</span><span class="o">);</span> +<span class="c1">// Split the data into train and test</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span> +<span class="n">DataFrame</span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> +<span class="c1">// specify layers for the neural network:</span> +<span class="c1">// input layer of size 4 (features), two intermediate of size 5 and 4</span> +<span class="c1">// and output of size 3 (classes)</span> +<span class="kt">int</span><span class="o">[]</span> <span class="n">layers</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">int</span><span class="o">[]</span> <span 
class="o">{</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">};</span> +<span class="c1">// create the trainer and set its parameters</span> +<span class="n">MultilayerPerceptronClassifier</span> <span class="n">trainer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MultilayerPerceptronClassifier</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span> + <span class="o">.</span><span class="na">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span> + <span class="o">.</span><span class="na">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">);</span> +<span class="c1">// train the model</span> +<span class="n">MultilayerPerceptronClassificationModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> +<span class="c1">// compute precision on the test set</span> +<span class="n">DataFrame</span> <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> +<span class="n">DataFrame</span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span 
class="o">);</span> +<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"precision"</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision = "</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">));</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">MultilayerPerceptronClassifier</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> + +<span class="c"># Load training data</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span>\ + <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> +<span class="c"># Split the data into train and test</span> +<span class="n">splits</span> <span class="o">=</span> 
<span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span> +<span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> +<span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> +<span class="c"># specify layers for the neural network:</span> +<span class="c"># input layer of size 4 (features), two intermediate of size 5 and 4</span> +<span class="c"># and output of size 3 (classes)</span> +<span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span> +<span class="c"># create the trainer and set its parameters</span> +<span class="n">trainer</span> <span class="o">=</span> <span class="n">MultilayerPerceptronClassifier</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">layers</span><span class="o">=</span><span class="n">layers</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">1234</span><span class="p">)</span> +<span class="c"># train the model</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> +<span class="c"># compute precision on 
the test set</span> +<span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> +<span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">)</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s">"precision"</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Precision:"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)))</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/multilayer_perceptron_classification.py" in the Spark repo.</small></div> + </div> + +</div> + +<h2 id="one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</h2> + +<p><a href="http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest">OneVsRest</a> is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as “One-vs-All.”</p> + +<p><code>OneVsRest</code> is implemented as an <code>Estimator</code>. For the base classifier it takes instances of <code>Classifier</code> and creates a binary classification problem for each of the k classes. 
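</p>

<p>For intuition, the prediction side of this reduction can be sketched in plain Python with stand-ins for the k trained binary classifiers (hypothetical scorers, not the Spark API):</p>

```python
# One-vs-rest prediction sketch: evaluate every binary classifier and
# output the index of the most confident one as the predicted label.

def make_scorer(weights):
    # Stand-in for one trained binary "class i vs. rest" model:
    # returns a confidence score for its class, given a feature vector.
    def score(x):
        return sum(w * xi for w, xi in zip(weights, x))
    return score

def one_vs_rest_predict(scorers, x):
    scores = [s(x) for s in scorers]
    return max(range(len(scores)), key=lambda i: scores[i])

scorers = [make_scorer([1.0, -0.5]),   # class 0 vs. rest
           make_scorer([-0.2, 0.8]),   # class 1 vs. rest
           make_scorer([0.1, 0.1])]    # class 2 vs. rest
print(one_vs_rest_predict(scorers, [0.0, 1.0]))  # -> 1
```

<p>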
The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.</p> + +<p>Predictions are made by evaluating each binary classifier, and the index of the most confident classifier is output as the label.</p> + +<p><strong>Example</strong></p> + +<p>The example below demonstrates how to load the +<a href="http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale">Iris dataset</a>, parse it as a DataFrame and perform multiclass classification using <code>OneVsRest</code>. The test error is calculated to measure the accuracy of the algorithm.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.OneVsRest">Scala API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.examples.mllib.AbstractParams</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">OneVsRest</span><span class="o">,</span> <span class="nc">LogisticRegression</span><span class="o">}</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.util.MetadataUtils</span> +<span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MulticlassMetrics</span> +<span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vector</span> +<span class="k">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span> + +<span class="k">val</span> <span class="n">inputData</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="n">input</span><span
class="o">)</span> +<span class="c1">// compute the train/test split: if testInput is not provided use part of input.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">params</span><span class="o">.</span><span class="n">testInput</span> <span class="k">match</span> <span class="o">{</span> + <span class="k">case</span> <span class="nc">Some</span><span class="o">(</span><span class="n">t</span><span class="o">)</span> <span class="k">=></span> <span class="o">{</span> + <span class="c1">// compute the number of features in the training set.</span> + <span class="k">val</span> <span class="n">numFeatures</span> <span class="k">=</span> <span class="n">inputData</span><span class="o">.</span><span class="n">first</span><span class="o">().</span><span class="n">getAs</span><span class="o">[</span><span class="kt">Vector</span><span class="o">](</span><span class="mi">1</span><span class="o">).</span><span class="n">size</span> + <span class="k">val</span> <span class="n">testData</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">option</span><span class="o">(</span><span class="s">"numFeatures"</span><span class="o">,</span> <span class="n">numFeatures</span><span class="o">.</span><span class="n">toString</span><span class="o">)</span> + <span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="n">t</span><span class="o">)</span> + <span class="nc">Array</span><span class="o">[</span><span class="kt">DataFrame</span><span class="o">](</span><span class="n">inputData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> + <span class="o">}</span> + <span class="k">case</span> <span class="nc">None</span> <span class="k">=></span> <span 
class="o">{</span> + <span class="k">val</span> <span class="n">f</span> <span class="k">=</span> <span class="n">params</span><span class="o">.</span><span class="n">fracTest</span> + <span class="n">inputData</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">f</span><span class="o">,</span> <span class="n">f</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">12345</span><span class="o">)</span> + <span class="o">}</span> +<span class="o">}</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">train</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">cache</span><span class="o">())</span> + +<span class="c1">// instantiate the base classifier</span> +<span class="k">val</span> <span class="n">classifier</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> + <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="n">maxIter</span><span class="o">)</span> + <span class="o">.</span><span class="n">setTol</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="n">tol</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFitIntercept</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="n">fitIntercept</span><span class="o">)</span> + +<span class="c1">// Set regParam, elasticNetParam if specified in params</span> +<span 
class="n">params</span><span class="o">.</span><span class="n">regParam</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">setRegParam</span><span class="o">)</span> +<span class="n">params</span><span class="o">.</span><span class="n">elasticNetParam</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">)</span> + +<span class="c1">// instantiate the One Vs Rest Classifier.</span> + +<span class="k">val</span> <span class="n">ovr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">OneVsRest</span><span class="o">()</span> +<span class="n">ovr</span><span class="o">.</span><span class="n">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">)</span> + +<span class="c1">// train the multiclass model.</span> +<span class="k">val</span> <span class="o">(</span><span class="n">trainingDuration</span><span class="o">,</span> <span class="n">ovrModel</span><span class="o">)</span> <span class="k">=</span> <span class="n">time</span><span class="o">(</span><span class="n">ovr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">))</span> + +<span class="c1">// score the model on test data.</span> +<span class="k">val</span> <span class="o">(</span><span class="n">predictionDuration</span><span class="o">,</span> <span class="n">predictions</span><span class="o">)</span> <span class="k">=</span> <span class="n">time</span><span class="o">(</span><span class="n">ovrModel</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">))</span> + +<span class="c1">// evaluate the model</span> +<span class="k">val</span> <span 
class="n">predictionsAndLabels</span> <span class="k">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="n">row</span> <span class="k">=></span> <span class="o">(</span><span class="n">row</span><span class="o">.</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">),</span> <span class="n">row</span><span class="o">.</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">1</span><span class="o">)))</span> + +<span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassMetrics</span><span class="o">(</span><span class="n">predictionsAndLabels</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">confusionMatrix</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">confusionMatrix</span> + +<span class="c1">// compute the false positive rate per label</span> +<span class="k">val</span> <span class="n">predictionColSchema</span> <span class="k">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">schema</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">numClasses</span> <span class="k">=</span> <span class="nc">MetadataUtils</span><span class="o">.</span><span class="n">getNumClasses</span><span class="o">(</span><span class="n">predictionColSchema</span><span class="o">).</span><span class="n">get</span> +<span class="k">val</span> <span class="n">fprs</span> <span class="k">=</span> <span class="nc">Range</span><span class="o">(</span><span class="mi">0</span><span 
class="o">,</span> <span class="n">numClasses</span><span class="o">).</span><span class="n">map</span><span class="o">(</span><span class="n">p</span> <span class="k">=></span> <span class="o">(</span><span class="n">p</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">falsePositiveRate</span><span class="o">(</span><span class="n">p</span><span class="o">.</span><span class="n">toDouble</span><span class="o">)))</span> + +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">" Training Time ${trainingDuration} sec\n"</span><span class="o">)</span> + +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">" Prediction Time ${predictionDuration} sec\n"</span><span class="o">)</span> + +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">" Confusion Matrix\n ${confusionMatrix.toString}\n"</span><span class="o">)</span> + +<span class="n">println</span><span class="o">(</span><span class="s">"label\tfpr"</span><span class="o">)</span> + +<span class="n">println</span><span class="o">(</span><span class="n">fprs</span><span class="o">.</span><span class="n">map</span> <span class="o">{</span><span class="k">case</span> <span class="o">(</span><span class="n">label</span><span class="o">,</span> <span class="n">fpr</span><span class="o">)</span> <span class="k">=></span> <span class="n">label</span> <span class="o">+</span> <span class="s">"\t"</span> <span class="o">+</span> <span class="n">fpr</span><span class="o">}.</span><span class="n">mkString</span><span class="o">(</span><span class="s">"\n"</span><span class="o">))</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + + <p>Refer to the <a 
href="api/java/org/apache/spark/ml/classification/OneVsRest.html">Java API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRest</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRestModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.util.MetadataUtils</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MulticlassMetrics</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Matrix</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vector</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructField</span><span class="o">;</span> + +<span class="c1">// configure the base classifier</span> +<span class="n">LogisticRegression</span> <span class="n">classifier</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span> + <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">maxIter</span><span class="o">)</span> + <span class="o">.</span><span class="na">setTol</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">tol</span><span class="o">)</span> + <span 
class="o">.</span><span class="na">setFitIntercept</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">fitIntercept</span><span class="o">);</span> + +<span class="k">if</span> <span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">regParam</span> <span class="o">!=</span> <span class="kc">null</span><span class="o">)</span> <span class="o">{</span> + <span class="n">classifier</span><span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">regParam</span><span class="o">);</span> +<span class="o">}</span> +<span class="k">if</span> <span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">elasticNetParam</span> <span class="o">!=</span> <span class="kc">null</span><span class="o">)</span> <span class="o">{</span> + <span class="n">classifier</span><span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="n">params</span><span class="o">.</span><span class="na">elasticNetParam</span><span class="o">);</span> +<span class="o">}</span> + +<span class="c1">// instantiate the One Vs Rest Classifier</span> +<span class="n">OneVsRest</span> <span class="n">ovr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">OneVsRest</span><span class="o">().</span><span class="na">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">);</span> + +<span class="n">String</span> <span class="n">input</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="na">input</span><span class="o">;</span> +<span class="n">DataFrame</span> <span class="n">inputData</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span 
class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="n">input</span><span class="o">);</span> +<span class="n">DataFrame</span> <span class="n">train</span><span class="o">;</span> +<span class="n">DataFrame</span> <span class="n">test</span><span class="o">;</span> + +<span class="c1">// compute the train/ test split: if testInput is not provided use part of input</span> +<span class="n">String</span> <span class="n">testInput</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="na">testInput</span><span class="o">;</span> +<span class="k">if</span> <span class="o">(</span><span class="n">testInput</span> <span class="o">!=</span> <span class="kc">null</span><span class="o">)</span> <span class="o">{</span> + <span class="n">train</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">;</span> + <span class="c1">// compute the number of features in the training set.</span> + <span class="kt">int</span> <span class="n">numFeatures</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="na">first</span><span class="o">().<</span><span class="n">Vector</span><span class="o">></span><span class="n">getAs</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="na">size</span><span class="o">();</span> + <span class="n">test</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">option</span><span class="o">(</span><span class="s">"numFeatures"</span><span class="o">,</span> + <span class="n">String</span><span class="o">.</span><span class="na">valueOf</span><span class="o">(</span><span class="n">numFeatures</span><span 
class="o">)).</span><span class="na">load</span><span class="o">(</span><span class="n">testInput</span><span class="o">);</span> +<span class="o">}</span> <span class="k">else</span> <span class="o">{</span> + <span class="kt">double</span> <span class="n">f</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="na">fracTest</span><span class="o">;</span> + <span class="n">DataFrame</span><span class="o">[]</span> <span class="n">tmp</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mi">1</span> <span class="o">-</span> <span class="n">f</span><span class="o">,</span> <span class="n">f</span><span class="o">},</span> <span class="mi">12345</span><span class="o">);</span> + <span class="n">train</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> + <span class="n">test</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> +<span class="o">}</span> + +<span class="c1">// train the multiclass model</span> +<span class="n">OneVsRestModel</span> <span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">.</span><span class="na">cache</span><span class="o">());</span> + +<span class="c1">// score the model on test data</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">.</span><span class="na">cache</span><span class="o">())</span> + <span 
class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">);</span> + +<span class="c1">// obtain metrics</span> +<span class="n">MulticlassMetrics</span> <span class="n">metrics</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassMetrics</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">StructField</span> <span class="n">predictionColSchema</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="na">schema</span><span class="o">().</span><span class="na">apply</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">);</span> +<span class="n">Integer</span> <span class="n">numClasses</span> <span class="o">=</span> <span class="o">(</span><span class="n">Integer</span><span class="o">)</span> <span class="n">MetadataUtils</span><span class="o">.</span><span class="na">getNumClasses</span><span class="o">(</span><span class="n">predictionColSchema</span><span class="o">).</span><span class="na">get</span><span class="o">();</span> + +<span class="c1">// compute the false positive rate per label</span> +<span class="n">StringBuilder</span> <span class="n">results</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringBuilder</span><span class="o">();</span> +<span class="n">results</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="s">"label\tfpr\n"</span><span class="o">);</span> +<span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">label</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> <span class="n">label</span> <span class="o"><</span> <span class="n">numClasses</span><span class="o">;</span> <span class="n">label</span><span 
class="o">++)</span> <span class="o">{</span> + <span class="n">results</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="n">label</span><span class="o">);</span> + <span class="n">results</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="s">"\t"</span><span class="o">);</span> + <span class="n">results</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="n">metrics</span><span class="o">.</span><span class="na">falsePositiveRate</span><span class="o">((</span><span class="kt">double</span><span class="o">)</span> <span class="n">label</span><span class="o">));</span> + <span class="n">results</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="s">"\n"</span><span class="o">);</span> +<span class="o">}</span> + +<span class="n">Matrix</span> <span class="n">confusionMatrix</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">confusionMatrix</span><span class="o">();</span> +<span class="c1">// output the Confusion Matrix</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Confusion Matrix"</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">confusionMatrix</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">();</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">results</span><span 
class="o">);</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java" in the Spark repo.</small></div> + </div> +</div> + +<h1 id="regression">Regression</h1> + +<h2 id="linear-regression">Linear regression</h2> + +<p>The interface for working with linear regression models and model +summaries is similar to the logistic regression case.</p> + +<p><strong>Example</strong></p> + +<p>The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics.</p> + +<div class="codetabs"> + +<div data-lang="scala"> + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span> + +<span class="c1">// Load training data</span> +<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sqlCtx</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> + <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LinearRegression</span><span class="o">()</span> + <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> + <span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> + +<span class="c1">// Fit the model</span> +<span class="k">val</span> <span 
class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> + +<span class="c1">// Print the coefficients and intercept for linear regression</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}"</span><span class="o">)</span> + +<span class="c1">// Summarize the model over the training set and print out some metrics</span> +<span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"numIterations: ${trainingSummary.totalIterations}"</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"objectiveHistory: ${trainingSummary.objectiveHistory.toList}"</span><span class="o">)</span> +<span class="n">trainingSummary</span><span class="o">.</span><span class="n">residuals</span><span class="o">.</span><span class="n">show</span><span class="o">()</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"RMSE: ${trainingSummary.rootMeanSquaredError}"</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"r2: ${trainingSummary.r2}"</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span><span 
class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionTrainingSummary</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vectors</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load training data</span> +<span class="n">DataFrame</span> <span class="n">training</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> + <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">);</span> + +<span class="n">LinearRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LinearRegression</span><span class="o">()</span> + <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> + <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> + +<span class="c1">// Fit the model</span> +<span class="n">LinearRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span 
class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> + +<span class="c1">// Print the coefficients and intercept for linear regression</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> + <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> + +<span class="c1">// Summarize the model over the training set and print out some metrics</span> +<span class="n">LinearRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"numIterations: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">totalIterations</span><span class="o">());</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"objectiveHistory: "</span> <span class="o">+</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">()));</span> +<span class="n">trainingSummary</span><span class="o">.</span><span class="na">residuals</span><span 
class="o">().</span><span class="na">show</span><span class="o">();</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"RMSE: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">rootMeanSquaredError</span><span class="o">());</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"r2: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">r2</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + <!--- TODO: Add python model summaries once implemented --> + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">LinearRegression</span> + +<span class="c"># Load training data</span> +<span class="n">training</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span>\ + <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">)</span> + +<span class="n">lr</span> <span class="o">=</span> <span class="n">LinearRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span 
class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> + +<span class="c"># Fit the model</span> +<span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> + +<span class="c"># Print the coefficients and intercept for linear regression</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/linear_regression_with_elastic_net.py" in the Spark repo.</small></div> + </div> + +</div> + +<h2 id="decision-tree-regression">Decision tree regression</h2> + +<p>Decision trees are a popular family of classification and regression methods. +More information about the <code>spark.ml</code> implementation can be found in the <a href="#decision-trees">section on decision trees</a>.</p> + +<p><strong>Example</strong></p> + +<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. 
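</p>

<p>As background, the split criterion a regression tree optimizes can be sketched in plain Python (an illustrative, hypothetical snippet, not the <code>spark.ml</code> implementation): each candidate threshold is scored by the summed squared error of the target values on either side of the split, and the threshold with the lowest total is chosen.</p>

```python
# Hypothetical sketch of variance-reduction splitting for a regression tree.

def sse(ys):
    # Sum of squared errors around the mean (0.0 for an empty partition).
    if not ys:
        return 0.0
    m = sum(ys) / len(ys)
    return sum((y - m) ** 2 for y in ys)

def best_split(xs, ys):
    # Candidate thresholds: midpoints between sorted distinct feature values.
    values = sorted(set(xs))
    best_t, best_cost = None, float("inf")
    for lo, hi in zip(values, values[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        cost = sse(left) + sse(right)
        if cost < best_cost:
            best_t, best_cost = t, cost
    return best_t

xs = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
ys = [0.9, 1.1, 1.0, 5.0, 5.2, 4.8]
print(best_split(xs, ys))  # 6.5: separates the low-target and high-target groups
```

<p>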
+We use a feature transformer to index categorical features, adding metadata to the <code>DataFrame</code> which the Decision Tree algorithm can recognize.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor">Scala API documentation</a>.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> + +<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Here, we treat features with > 4 distinct values as continuous.</span> +<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span 
class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> + +<span class="c1">// Train a DecisionTree model.</span> +<span class="k">val</span> <span class="n">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeRegressor</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + +<span class="c1">// Chain indexer and tree in a Pipeline</span> +<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">))</span> + +<span 
class="c1">// Train model. This also runs the indexer.</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> + +<span class="c1">// Make predictions.</span> +<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span 
class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">treeModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeRegressionModel</span><span class="o">]</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Learned regression tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + + <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html">Java API documentation</a>.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> +<span class="kn">import</span> <span 
class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> + <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span 
class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> + +<span class="c1">// Train a DecisionTree model.</span> +<span class="n">DecisionTreeRegressor</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">DecisionTreeRegressor</span><span class="o">()</span> + <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> + +<span class="c1">// Chain indexer and tree in a Pipeline</span> +<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]{</span><span 
class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">});</span> + +<span class="c1">// Train model. This also runs the indexer.</span> +<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> + +<span class="c1">// Make predictions.</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span> +  <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> +  <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> +  <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span
class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> + +<span class="n">DecisionTreeRegressionModel</span> <span class="n">treeModel</span> <span class="o">=</span> + <span class="o">(</span><span class="n">DecisionTreeRegressionModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + + <p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor">Python API documentation</a>.</p> + + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">DecisionTreeRegressor</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span 
class="n">VectorIndexer</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> + +<span class="c"># Load the data stored in LIBSVM format as a DataFrame.</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="c"># Automatically identify categorical features, and index them.</span> +<span class="c"># We specify maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">featureIndexer</span> <span class="o">=</span>\ + <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + +<span class="c"># Split the data into training and test sets (30% held out for testing)</span> +<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> + +<span class="c"># Train a 
DecisionTree model.</span> +<span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> + +<span class="c"># Chain indexer and tree in a Pipeline</span> +<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span> + +<span class="c"># Train model. This also runs the indexer.</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> + +<span class="c"># Make predictions.</span> +<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> + +<span class="c"># Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> + +<span class="c"># Select (prediction, true label) and compute test error</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> + <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span 
class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> +<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> + +<span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> +<span class="c"># summary only</span> +<span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/decision_tree_regression_example.py" in the Spark repo.</small></div> + </div> + +</div> + +<h2 id="random-forest-regression">Random forest regression</h2> + +<p>Random forests are a popular family of classification and regression methods. +More information about the <code>spark.ml</code> implementation can be found further in the <a href="#random-forests">section on random forests</a>.</p> + +<p><strong>Example</strong></p> + +<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
+We use a feature transformer to index categorical features, adding metadata to the <code>DataFrame</code> which the tree-based algorithms can recognize.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor">Scala API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">RandomForestRegressionModel</span><span class="o">,</span> <span class="nc">RandomForestRegressor</span><span class="o">}</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span 
class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> + +<span class="c1">// Train a RandomForest model.</span> +<span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestRegressor</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + +<span class="c1">// Chain indexer and forest in a Pipeline</span> +<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span 
class="o">))</span> + +<span class="c1">// Train model. This also runs the indexer.</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> + +<span class="c1">// Make predictions.</span> +<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span 
class="n">predictions</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestRegressionModel</span><span class="o">]</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Learned regression forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + + <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/RandomForestRegressor.html">Java API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> +<span 
class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressor</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> + +<span 
class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> + +<span class="c1">// Train a RandomForest model.</span> +<span class="n">RandomForestRegressor</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RandomForestRegressor</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> + +<span class="c1">// Chain indexer and forest in a Pipeline</span> +<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> + <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span 
class="n">rf</span><span class="o">});</span> + +<span class="c1">// Train model. This also runs the indexer.</span> +<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> + +<span class="c1">// Make predictions.</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span 
class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> + +<span class="n">RandomForestRegressionModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">RandomForestRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + + <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor">Python API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> +<span class="kn">from</span> <span 
class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> + +<span class="c"># Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="c"># Automatically identify categorical features, and index them.</span> +<span class="c"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">featureIndexer</span> <span class="o">=</span>\ + <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + +<span class="c"># Split the data into training and test sets (30% held out for testing)</span> +<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> + +<span class="c"># Train a RandomForest model.</span> +<span class="n">rf</span> <span 
class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> + +<span class="c"># Chain indexer and forest in a Pipeline</span> +<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span> + +<span class="c"># Train model. This also runs the indexer.</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> + +<span class="c"># Make predictions.</span> +<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> + +<span class="c"># Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> + +<span class="c"># Select (prediction, true label) and compute test error</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> + <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span 
class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> +<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> + +<span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> +<span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c"># summary only</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/random_forest_regressor_example.py" in the Spark repo.</small></div> + </div> +</div> + +<h2 id="gradient-boosted-tree-regression">Gradient-boosted tree regression</h2> + +<p>Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. 
+More information about the <code>spark.ml</code> implementation can be found further in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p> + +<p><strong>Example</strong></p> + +<p>Note: For this example dataset, <code>GBTRegressor</code> actually only needs 1 iteration, but that will not +be true in general.</p> + +<div class="codetabs"> +<div data-lang="scala"> + + <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor">Scala API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> +<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">GBTRegressionModel</span><span class="o">,</span> <span class="nc">GBTRegressor</span><span class="o">}</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> + <span 
class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> + +<span class="c1">// Train a GBT model.</span> +<span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTRegressor</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> + +<span class="c1">// Chain indexer and GBT in a Pipeline</span> +<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span 
class="o">()</span> + <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">))</span> + +<span class="c1">// Train model. This also runs the indexer.</span> +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> + +<span class="c1">// Make predictions.</span> +<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span 
class="s">"rmse"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTRegressionModel</span><span class="o">]</span> +<span class="n">println</span><span class="o">(</span><span class="s">"Learned regression GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + + <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GBTRegressor.html">Java API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> +<span 
class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressor</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> + +<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> + +<span class="c1">// Automatically identify categorical features, and index them.</span> +<span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span> + <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span 
class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> + <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> + +<span class="c1">// Split the data into training and test sets (30% held out for testing)</span> +<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> +<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> + +<span class="c1">// Train a GBT model.</span> +<span class="n">GBTRegressor</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">GBTRegressor</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span> + +<span class="c1">// Chain indexer and GBT in a Pipeline</span> +<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span 
class="k">new</span> <span class="nf">Pipeline</span><span class="o">().</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">});</span> + +<span class="c1">// Train model. This also runs the indexer.</span> +<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> + +<span class="c1">// Make predictions.</span> +<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> + +<span class="c1">// Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> + +<span class="c1">// Select (prediction, true label) and compute test error</span> +<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span> + <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span 
class="s">"prediction"</span><span class="o">)</span> + <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> +<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> + +<span class="n">GBTRegressionModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">GBTRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + + <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor">Python API docs</a> for more details.</p> + + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> +<span 
class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GBTRegressor</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> +<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> + +<span class="c"># Load and parse the data file, converting it to a DataFrame.</span> +<span class="n">data</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> + +<span class="c"># Automatically identify categorical features, and index them.</span> +<span class="c"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> +<span class="n">featureIndexer</span> <span class="o">=</span>\ + <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + +<span class="c"># Split the data into training and test sets (30% held out for testing)</span> +<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> 
<span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> + +<span class="c"># Train a GBT model.</span> +<span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> + +<span class="c"># Chain indexer and GBT in a Pipeline</span> +<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span> + +<span class="c"># Train model. 
This also runs the indexer.</span> +<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> + +<span class="c"># Make predictions.</span> +<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> + +<span class="c"># Select example rows to display.</span> +<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> + +<span class="c"># Select (prediction, true label) and compute test error</span> +<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> + <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> +<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">"</span> <span class="o">%</span> <span 
class="n">rmse</span><span class="p">)</span> + +<span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> +<span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c"># summary only</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py" in the Spark repo.</small></div> + </div> +</div> + +<h2 id="survival-regression">Survival regression</h2> + +<p>In <code>spark.ml</code>, we implement the <a href="https://en.wikipedia.org/wiki/Accelerated_failure_time_model">Accelerated failure time (AFT)</a> +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it’s often called +log-linear model for survival analysis. Different from +<a href="https://en.wikipedia.org/wiki/Proportional_hazards_model">Proportional hazards</a> model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently.</p> + +<p>Given the values of the covariates $x^{‘}$, for random lifetime $t_{i}$ of +subjects i = 1, …, n, with possible right-censoring, +the likelihood function under the AFT model is given as: +<code>\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]</code> +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. 
+Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +<code>\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]</code> +where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is the corresponding density function.</p> + +<p>The most commonly used AFT model is based on the Weibull distribution of the survival time. +A Weibull distribution for the lifetime corresponds to an extreme value distribution for the +log of the lifetime. The $S_{0}(\epsilon)$ function is: +<code>\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]</code> +and the $f_{0}(\epsilon_{i})$ function is: +<code>\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]</code> +The log-likelihood function for the AFT model with a Weibull distribution of the lifetime is: +<code>\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]</code> +Since minimizing the negative log-likelihood is equivalent to maximizing the a posteriori probability, +the loss function we optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +<code>\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]</code> +<code>\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]</code></p> + +<p>The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of the convex function $-\iota(\beta,\sigma)$ +that depends on the coefficients vector $\beta$ and the log of the scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. 
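</p>

<p>The Weibull negative log-likelihood and its gradients above can be sanity-checked numerically. The following standalone NumPy sketch is not part of Spark, and the dataset and parameter values are made up purely for illustration; it evaluates $-\iota(\beta,\sigma)$ with $\epsilon_{i}=(\log{t_{i}}-x_{i}\beta)/\sigma$ and verifies the analytic gradients against central finite differences:</p>

```python
import numpy as np

def aft_weibull_neg_loglik(beta, log_sigma, X, log_t, delta):
    # -iota(beta, sigma) = sum_i [delta_i*log(sigma) - delta_i*eps_i + exp(eps_i)]
    eps = (log_t - X.dot(beta)) / np.exp(log_sigma)
    return np.sum(delta * log_sigma - delta * eps + np.exp(eps))

def aft_weibull_grad(beta, log_sigma, X, log_t, delta):
    # Analytic gradients of -iota w.r.t. beta and log(sigma).
    sigma = np.exp(log_sigma)
    eps = (log_t - X.dot(beta)) / sigma
    grad_beta = X.T.dot((delta - np.exp(eps)) / sigma)
    grad_log_sigma = np.sum(delta + (delta - np.exp(eps)) * eps)
    return grad_beta, grad_log_sigma

# Tiny synthetic dataset (values are illustrative only).
rng = np.random.RandomState(7)
X = rng.randn(20, 2)
log_t = X.dot(np.array([0.5, -0.3])) + 0.2 * rng.randn(20)
delta = (rng.rand(20) > 0.3).astype(float)   # 1.0 = event observed (uncensored)

beta = np.array([0.1, -0.1])
log_sigma = 0.0
g_beta, g_ls = aft_weibull_grad(beta, log_sigma, X, log_t, delta)

# Verify against central finite differences.
h = 1e-6
for j in range(len(beta)):
    e = np.zeros_like(beta); e[j] = h
    fd = (aft_weibull_neg_loglik(beta + e, log_sigma, X, log_t, delta)
          - aft_weibull_neg_loglik(beta - e, log_sigma, X, log_t, delta)) / (2 * h)
    assert abs(fd - g_beta[j]) < 1e-4
fd_ls = (aft_weibull_neg_loglik(beta, log_sigma + h, X, log_t, delta)
         - aft_weibull_neg_loglik(beta, log_sigma - h, X, log_t, delta)) / (2 * h)
assert abs(fd_ls - g_ls) < 1e-4
print("analytic gradients match finite differences")
```

<p>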
+The implementation matches the result from R’s survival function
+<a href="https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html">survreg</a>.</p>
+
+<p><strong>Example</strong></p>
+
+<div class="codetabs">
+
+<div data-lang="scala">
+  <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span>
+<span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vectors</span>
+
+<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
+  <span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span>
+  <span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span>
+  <span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span>
+  <span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span
class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span> + <span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span> +<span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"censor"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span> +<span class="k">val</span> <span class="n">quantileProbabilities</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">)</span> +<span class="k">val</span> <span class="n">aft</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">AFTSurvivalRegression</span><span class="o">()</span> + <span class="o">.</span><span class="n">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span> + <span class="o">.</span><span class="n">setQuantilesCol</span><span class="o">(</span><span class="s">"quantiles"</span><span class="o">)</span> + +<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">aft</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> + +<span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span> +<span class="n">println</span><span class="o">(</span><span class="n">s</span><span 
class="s">"Coefficients: ${model.coefficients} Intercept: "</span> <span class="o">+</span> + <span class="n">s</span><span class="s">"${model.intercept} Scale: ${model.scale}"</span><span class="o">)</span> +<span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="kc">false</span><span class="o">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala" in the Spark repo.</small></div> + </div> + +<div data-lang="java"> + <div class="highlight"><pre><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> + +<span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegressionModel</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.*</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.RowFactory</span><span class="o">;</span> +<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span> +<span class="kn">import</span> 
<span class="nn">org.apache.spark.sql.types.*</span><span class="o">;</span> + +<span class="n">List</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> + <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span> + <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span> + <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span> + <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span 
class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span> + <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span> +<span class="o">);</span> +<span class="n">StructType</span> <span class="n">schema</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StructType</span><span class="o">(</span><span class="k">new</span> <span class="n">StructField</span><span class="o">[]{</span> + <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> + <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"censor"</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> + <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="k">new</span> <span class="nf">VectorUDT</span><span class="o">(),</span> <span class="kc">false</span><span 
class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">())</span> +<span class="o">});</span> +<span class="n">DataFrame</span> <span class="n">training</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">data</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span> +<span class="kt">double</span><span class="o">[]</span> <span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">};</span> +<span class="n">AFTSurvivalRegression</span> <span class="n">aft</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">AFTSurvivalRegression</span><span class="o">()</span> + <span class="o">.</span><span class="na">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span> + <span class="o">.</span><span class="na">setQuantilesCol</span><span class="o">(</span><span class="s">"quantiles"</span><span class="o">);</span> + +<span class="n">AFTSurvivalRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> + +<span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span> +<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">coefficients</span><span 
class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> + <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">intercept</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Scale: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">scale</span><span class="o">());</span> +<span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="kc">false</span><span class="o">);</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java" in the Spark repo.</small></div> + </div> + +<div data-lang="python"> + <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">AFTSurvivalRegression</span> +<span class="kn">from</span> <span class="nn">pyspark.mllib.linalg</span> <span class="kn">import</span> <span class="n">Vectors</span> + +<span class="n">training</span> <span class="o">=</span> <span class="n">sqlContext</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span> + <span class="p">(</span><span class="mf">1.218</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.560</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="p">)),</span> + <span class="p">(</span><span class="mf">2.949</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span 
class="p">(</span><span class="mf">0.346</span><span class="p">,</span> <span class="mf">2.158</span><span class="p">)),</span> + <span class="p">(</span><span class="mf">3.627</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.380</span><span class="p">,</span> <span class="mf">0.231</span><span class="p">)),</span> + <span class="p">(</span><span class="mf">0.273</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.520</span><span class="p">,</span> <span class="mf">1.151</span><span class="p">)),</span> + <span class="p">(</span><span class="mf">4.199</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.795</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="p">))],</span> <span class="p">[</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"censor"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">])</span> +<span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]</span> +<span class="n">aft</span> <span class="o">=</span> <span class="n">AFTSurvivalRegression</span><span class="p">(</span><span class="n">quantileProbabilities</span><span class="o">=</span><span class="n">quantileProbabilities</span><span class="p">,</span> + <span class="n">quantilesCol</span><span class="o">=</span><span class="s">"quantiles"</span><span class="p">)</span> + +<span class="n">model</span> 
<span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> + +<span class="c"># Print the coefficients, intercept and scale parameter for AFT survival regression</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> +<span class="k">print</span><span class="p">(</span><span class="s">"Scale: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">scale</span><span class="p">))</span> +<span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">training</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="n">truncate</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> +</pre></div> + <div><small>Find full example code at "examples/src/main/python/ml/aft_survival_regression.py" in the Spark repo.</small></div> + </div> + +</div> + +<h1 id="decision-trees">Decision trees</h1> + +<p><a href="http://en.wikipedia.org/wiki/Decision_tree_learning">Decision trees</a> +and their ensembles are popular methods for the machine learning tasks of +classification and regression. 
Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks.</p> + +<p>The <code>spark.ml</code> implementation supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances.</p> + +<p>Users can find more information about the decision tree algorithm in the <a href="mllib-decision-tree.html">MLlib Decision Tree guide</a>. +The main differences between this API and the <a href="mllib-decision-tree.html">original MLlib Decision Tree API</a> are:</p> + +<ul> + <li>support for ML Pipelines</li> + <li>separation of Decision Trees for classification vs. regression</li> + <li>use of DataFrame metadata to distinguish continuous and categorical features</li> +</ul> + +<p>The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities).</p> + +<p>Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the <a href="#tree-ensembles">Tree ensembles section</a>.</p> + +<h2 id="inputs-and-outputs">Inputs and Outputs</h2> + +<p>We list the input and output (prediction) column types here. 
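+</p>
+
+<p>As a small illustration of the relationship between the classification output columns described in the tables below (an illustrative sketch, not Spark code; the function names are invented for this example): <code>rawPrediction</code> holds per-class counts at the predicting tree node, and <code>probability</code> is that vector normalized to sum to one:</p>

```python
# Illustrative sketch of the rawPrediction / probability relationship for
# tree classifiers described below. Not Spark code; names are invented.

def to_probability(raw_prediction):
    """Normalize a vector of per-class counts to a multinomial distribution."""
    total = sum(raw_prediction)
    return [c / total for c in raw_prediction]

def predicted_label(raw_prediction):
    """The predicted label is the index of the largest raw prediction."""
    return max(range(len(raw_prediction)), key=lambda i: raw_prediction[i])
```

+<p>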
+All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> + +<h3 id="input-columns">Input Columns</h3> + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + </tr> + </thead> + <tbody> + <tr> + <td>labelCol</td> + <td>Double</td> + <td>"label"</td> + <td>Label to predict</td> + </tr> + <tr> + <td>featuresCol</td> + <td>Vector</td> + <td>"features"</td> + <td>Feature vector</td> + </tr> + </tbody> +</table> + +<h3 id="output-columns">Output Columns</h3> + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + <th align="left">Notes</th> + </tr> + </thead> + <tbody> + <tr> + <td>predictionCol</td> + <td>Double</td> + <td>"prediction"</td> + <td>Predicted label</td> + <td></td> + </tr> + <tr> + <td>rawPredictionCol</td> + <td>Vector</td> + <td>"rawPrediction"</td> + <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td> + <td>Classification only</td> + </tr> + <tr> + <td>probabilityCol</td> + <td>Vector</td> + <td>"probability"</td> + <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td> + <td>Classification only</td> + </tr> + </tbody> +</table> + +<h1 id="tree-ensembles">Tree Ensembles</h1> + +<p>The DataFrame API supports two major tree ensemble algorithms: <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forests</a> and <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>. 
+Both use <a href="ml-classification-regression.html#decision-trees"><code>spark.ml</code> decision trees</a> as their base models.</p> + +<p>Users can find more information about ensemble algorithms in the <a href="mllib-ensembles.html">MLlib Ensemble guide</a>.<br /> +In this section, we demonstrate the DataFrame API for ensembles.</p> + +<p>The main differences between this API and the <a href="mllib-ensembles.html">original MLlib ensembles API</a> are:</p> + +<ul> + <li>support for DataFrames and ML Pipelines</li> + <li>separation of classification vs. regression</li> + <li>use of DataFrame metadata to distinguish continuous and categorical features</li> + <li>more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification.</li> +</ul> + +<h2 id="random-forests">Random Forests</h2> + +<p><a href="http://en.wikipedia.org/wiki/Random_forest">Random forests</a> +are ensembles of <a href="ml-decision-tree.html">decision trees</a>. +Random forests combine many decision trees in order to reduce the risk of overfitting. +The <code>spark.ml</code> implementation supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features.</p> + +<p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html"><code>spark.mllib</code> documentation on random forests</a>.</p> + +<h3 id="inputs-and-outputs-1">Inputs and Outputs</h3> + +<p>We list the input and output (prediction) column types here. 
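+</p>
+
+<p>As a toy illustration of the ensemble idea (a sketch under the assumption of one common combination rule, not necessarily the rule Spark uses), a forest can combine its trees by averaging their per-class probability vectors:</p>

```python
# Toy sketch: combine individual trees' class-probability vectors by
# averaging. Illustrative only; not the Spark implementation.

def forest_probability(per_tree_probs):
    """Average per-tree probability vectors into a single forest prediction."""
    n_trees = len(per_tree_probs)
    n_classes = len(per_tree_probs[0])
    return [sum(p[c] for p in per_tree_probs) / n_trees for c in range(n_classes)]
```

+<p>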
+All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> + +<h4 id="input-columns-1">Input Columns</h4> + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + </tr> + </thead> + <tbody> + <tr> + <td>labelCol</td> + <td>Double</td> + <td>"label"</td> + <td>Label to predict</td> + </tr> + <tr> + <td>featuresCol</td> + <td>Vector</td> + <td>"features"</td> + <td>Feature vector</td> + </tr> + </tbody> +</table> + +<h4 id="output-columns-predictions">Output Columns (Predictions)</h4> + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + <th align="left">Notes</th> + </tr> + </thead> + <tbody> + <tr> + <td>predictionCol</td> + <td>Double</td> + <td>"prediction"</td> + <td>Predicted label</td> + <td></td> + </tr> + <tr> + <td>rawPredictionCol</td> + <td>Vector</td> + <td>"rawPrediction"</td> + <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td> + <td>Classification only</td> + </tr> + <tr> + <td>probabilityCol</td> + <td>Vector</td> + <td>"probability"</td> + <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td> + <td>Classification only</td> + </tr> + </tbody> +</table> + +<h2 id="gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</h2> + +<p><a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a> +are ensembles of <a href="ml-decision-tree.html">decision trees</a>. +GBTs iteratively train decision trees in order to minimize a loss function. 
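+</p>
+
+<p>The iterative idea can be sketched in a few lines of plain Python: for squared-error regression, each round fits a new tree to the current residuals (the negative gradient of the loss). This is an illustrative sketch using depth-1 stumps, not the Spark implementation:</p>

```python
# Toy sketch of gradient boosting for squared-error regression
# (pure Python, illustrative only -- not the Spark implementation).
# Each "tree" is a depth-1 regression stump fit to the current residuals.

def fit_stump(xs, residuals):
    """Pick the threshold split minimizing squared error on the residuals."""
    best = None
    for thr in xs:
        left = [r for x, r in zip(xs, residuals) if x <= thr]
        right = [r for x, r in zip(xs, residuals) if x > thr]
        if not left or not right:
            continue
        lm = sum(left) / len(left)
        rm = sum(right) / len(right)
        err = sum((r - lm) ** 2 for r in left) + sum((r - rm) ** 2 for r in right)
        if best is None or err < best[0]:
            best = (err, thr, lm, rm)
    _, thr, lm, rm = best
    return lambda x: lm if x <= thr else rm

def gbt_fit(xs, ys, n_trees=10, learning_rate=0.5):
    """Iteratively fit stumps to residuals; return the ensemble predictor."""
    stumps = []
    predict = lambda x: sum(learning_rate * s(x) for s in stumps)
    for _ in range(n_trees):
        # Residuals are the negative gradient of the squared-error loss.
        residuals = [y - predict(x) for x, y in zip(xs, ys)]
        stumps.append(fit_stump(xs, residuals))
    return predict
```

+<p>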
+The <code>spark.ml</code> implementation supports GBTs for binary classification and for regression, +using both continuous and categorical features.</p> + +<p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html"><code>spark.mllib</code> documentation on GBTs</a>.</p> + +<h3 id="inputs-and-outputs-2">Inputs and Outputs</h3> + +<p>We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> + +<h4 id="input-columns-2">Input Columns</h4> + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + </tr> + </thead> + <tbody> + <tr> + <td>labelCol</td> + <td>Double</td> + <td>"label"</td> + <td>Label to predict</td> + </tr> + <tr> + <td>featuresCol</td> + <td>Vector</td> + <td>"features"</td> + <td>Feature vector</td> + </tr> + </tbody> +</table> + +<p>Note that <code>GBTClassifier</code> currently only supports binary labels.</p> + +<h4 id="output-columns-predictions-1">Output Columns (Predictions)</h4> + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + <th align="left">Notes</th> + </tr> + </thead> + <tbody> + <tr> + <td>predictionCol</td> + <td>Double</td> + <td>"prediction"</td> + <td>Predicted label</td> + <td></td> + </tr> + </tbody> +</table> + +<p>In the future, <code>GBTClassifier</code> will also output columns for <code>rawPrediction</code> and <code>probability</code>, just as <code>RandomForestClassifier</code> does.</p> + + + + </div> + + <!-- /container --> + </div> + + <script src="js/vendor/jquery-1.8.0.min.js"></script> + <script src="js/vendor/bootstrap.min.js"></script> + <script src="js/vendor/anchor.min.js"></script> + <script src="js/main.js"></script> 
+
+    <!-- MathJax Section -->
+    <script type="text/x-mathjax-config">
+    MathJax.Hub.Config({
+        TeX: { equationNumbers: { autoNumber: "AMS" } }
+    });
+    </script>
+    <script>
+    // Note that we load MathJax this way to work with local file (file://), HTTP and HTTPS.
+    // We could use "//cdn.mathjax...", but that won't support "file://".
+    (function(d, script) {
+        script = d.createElement('script');
+        script.type = 'text/javascript';
+        script.async = true;
+        script.onload = function(){
+            MathJax.Hub.Config({
+                tex2jax: {
+                    inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ],
+                    displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
+                    processEscapes: true,
+                    skipTags: ['script', 'noscript', 'style', 'textarea', 'pre']
+                }
+            });
+        };
+        script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') +
+            'cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
+        d.getElementsByTagName('head')[0].appendChild(script);
+    }(document));
+    </script>
+  </body>
+</html>