<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js"> <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>Spark ML Programming Guide - Spark 1.4.1 Documentation</title>
<link rel="stylesheet" href="css/bootstrap.min.css">
<style>
body {
padding-top: 60px;
padding-bottom: 40px;
}
</style>
<meta name="viewport" content="width=device-width">
<link rel="stylesheet" href="css/bootstrap-responsive.min.css">
<link rel="stylesheet" href="css/main.css">
<script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script>
<link rel="stylesheet" href="css/pygments-default.css">
<!-- Google analytics script -->
<script type="text/javascript">
var _gaq = _gaq || [];
_gaq.push(['_setAccount', 'UA-32518208-2']);
_gaq.push(['_trackPageview']);
(function() {
var ga = document.createElement('script'); ga.type = 'text/javascript'; ga.async = true;
ga.src = ('https:' == document.location.protocol ? 'https://ssl' : 'http://www') + '.google-analytics.com/ga.js';
var s = document.getElementsByTagName('script')[0]; s.parentNode.insertBefore(ga, s);
})();
</script>
</head>
<body>
<!--[if lt IE 7]>
<p class="chromeframe">You are using an outdated browser. <a href="http://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p>
<![endif]-->
<!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html -->
<div class="navbar navbar-fixed-top" id="topbar">
<div class="navbar-inner">
<div class="container">
<div class="brand"><a href="index.html">
<img src="img/spark-logo-hd.png" style="height:50px;"/></a><span class="version">1.4.1</span>
</div>
<ul class="nav">
<!--TODO(andyk): Add class="active" attribute to li some how.-->
<li><a href="index.html">Overview</a></li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="quick-start.html">Quick Start</a></li>
<li><a href="programming-guide.html">Spark Programming Guide</a></li>
<li class="divider"></li>
<li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
<li><a href="sql-programming-guide.html">DataFrames and SQL</a></li>
<li><a href="mllib-guide.html">MLlib (Machine Learning)</a></li>
<li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li>
<li><a href="bagel-programming-guide.html">Bagel (Pregel on Spark)</a></li>
<li><a href="sparkr.html">SparkR (R on Spark)</a></li>
</ul>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li>
<li><a href="api/java/index.html">Java</a></li>
<li><a href="api/python/index.html">Python</a></li>
<li><a href="api/R/index.html">R</a></li>
</ul>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="cluster-overview.html">Overview</a></li>
<li><a href="submitting-applications.html">Submitting Applications</a></li>
<li class="divider"></li>
<li><a href="spark-standalone.html">Spark Standalone</a></li>
<li><a href="running-on-mesos.html">Mesos</a></li>
<li><a href="running-on-yarn.html">YARN</a></li>
<li class="divider"></li>
<li><a href="ec2-scripts.html">Amazon EC2</a></li>
</ul>
</li>
<li class="dropdown">
<a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="configuration.html">Configuration</a></li>
<li><a href="monitoring.html">Monitoring</a></li>
<li><a href="tuning.html">Tuning Guide</a></li>
<li><a href="job-scheduling.html">Job Scheduling</a></li>
<li><a href="security.html">Security</a></li>
<li><a href="hardware-provisioning.html">Hardware Provisioning</a></li>
<li><a href="hadoop-third-party-distributions.html">3<sup>rd</sup>-Party Hadoop Distros</a></li>
<li class="divider"></li>
<li><a href="building-spark.html">Building Spark</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects">Supplemental Projects</a></li>
</ul>
</li>
</ul>
<!--<p class="navbar-text pull-right"><span class="version-text">v1.4.1</span></p>-->
</div>
</div>
</div>
<div class="container" id="content">
<h1 class="title">Spark ML Programming Guide</h1>
<p>Spark 1.2 introduced a new package called <code>spark.ml</code>, which aims to provide a uniform set of
high-level APIs that help users create and tune practical machine learning pipelines.</p>
<p><em>Graduated from Alpha!</em> The Pipelines API is no longer an alpha component, although many elements of it are still <code>Experimental</code> or <code>DeveloperApi</code>.</p>
<p>Note that we will continue to support and add features to <code>spark.mllib</code> alongside the
development of <code>spark.ml</code>.
Users can keep relying on <code>spark.mllib</code> features and can expect more features to come.
Developers should contribute new algorithms to <code>spark.mllib</code> and can optionally contribute
them to <code>spark.ml</code> as well.</p>
<p>Guides for sub-packages of <code>spark.ml</code> include:</p>
<ul>
<li><a href="ml-features.html">Feature Extraction, Transformation, and Selection</a>: Details on transformers supported in the Pipelines API, including a few not in the lower-level <code>spark.mllib</code> API</li>
<li><a href="ml-ensembles.html">Ensembles</a>: Details on ensemble learning methods in the Pipelines API</li>
</ul>
<p><strong>Table of Contents</strong></p>
<ul id="markdown-toc">
<li><a href="#main-concepts" id="markdown-toc-main-concepts">Main Concepts</a> <ul>
<li><a href="#ml-dataset" id="markdown-toc-ml-dataset">ML Dataset</a></li>
<li><a href="#ml-algorithms" id="markdown-toc-ml-algorithms">ML Algorithms</a> <ul>
<li><a href="#transformers" id="markdown-toc-transformers">Transformers</a></li>
<li><a href="#estimators" id="markdown-toc-estimators">Estimators</a></li>
<li><a href="#properties-of-ml-algorithms" id="markdown-toc-properties-of-ml-algorithms">Properties of ML Algorithms</a></li>
</ul>
</li>
<li><a href="#pipeline" id="markdown-toc-pipeline">Pipeline</a> <ul>
<li><a href="#how-it-works" id="markdown-toc-how-it-works">How It Works</a></li>
<li><a href="#details" id="markdown-toc-details">Details</a></li>
</ul>
</li>
<li><a href="#parameters" id="markdown-toc-parameters">Parameters</a></li>
</ul>
</li>
<li><a href="#code-examples" id="markdown-toc-code-examples">Code Examples</a> <ul>
<li><a href="#example-estimator-transformer-and-param" id="markdown-toc-example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</a></li>
<li><a href="#example-pipeline" id="markdown-toc-example-pipeline">Example: Pipeline</a></li>
<li><a href="#example-model-selection-via-cross-validation" id="markdown-toc-example-model-selection-via-cross-validation">Example: Model Selection via Cross-Validation</a></li>
</ul>
</li>
<li><a href="#dependencies" id="markdown-toc-dependencies">Dependencies</a></li>
<li><a href="#migration-guide" id="markdown-toc-migration-guide">Migration Guide</a> <ul>
<li><a href="#from-13-to-14" id="markdown-toc-from-13-to-14">From 1.3 to 1.4</a></li>
<li><a href="#from-12-to-13" id="markdown-toc-from-12-to-13">From 1.2 to 1.3</a></li>
</ul>
</li>
</ul>
<h1 id="main-concepts">Main Concepts</h1>
<p>Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API.</p>
<ul>
<li>
<p><strong><a href="ml-guide.html#ml-dataset">ML Dataset</a></strong>: Spark ML uses the <a href="api/scala/index.html#org.apache.spark.sql.DataFrame"><code>DataFrame</code></a> from Spark SQL as a dataset which can hold a variety of data types.
E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions.</p>
</li>
<li>
<p><strong><a href="ml-guide.html#transformers"><code>Transformer</code></a></strong>: A <code>Transformer</code> is an algorithm which can transform one <code>DataFrame</code> into another <code>DataFrame</code>.
E.g., an ML model is a <code>Transformer</code> which transforms a <code>DataFrame</code> with features into a <code>DataFrame</code> with predictions.</p>
</li>
<li>
<p><strong><a href="ml-guide.html#estimators"><code>Estimator</code></a></strong>: An <code>Estimator</code> is an algorithm which can be fit on a <code>DataFrame</code> to produce a <code>Transformer</code>.
E.g., a learning algorithm is an <code>Estimator</code> which trains on a dataset and produces a model.</p>
</li>
<li>
<p><strong><a href="ml-guide.html#pipeline"><code>Pipeline</code></a></strong>: A <code>Pipeline</code> chains multiple <code>Transformer</code>s and <code>Estimator</code>s together to specify an ML workflow.</p>
</li>
<li>
<p><strong><a href="ml-guide.html#parameters"><code>Param</code></a></strong>: All <code>Transformer</code>s and <code>Estimator</code>s now share a common API for specifying parameters.</p>
</li>
</ul>
<h2 id="ml-dataset">ML Dataset</h2>
<p>Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data.
Spark ML adopts the <a href="api/scala/index.html#org.apache.spark.sql.DataFrame"><code>DataFrame</code></a> from Spark SQL in order to support a variety of data types under a unified Dataset concept.</p>
<p><code>DataFrame</code> supports many basic and structured types; see the <a href="sql-programming-guide.html#spark-sql-datatype-reference">Spark SQL datatype reference</a> for a list of supported types.
In addition to the types listed in the Spark SQL guide, <code>DataFrame</code> can use ML <a href="api/scala/index.html#org.apache.spark.mllib.linalg.Vector"><code>Vector</code></a> types.</p>
<p>A <code>DataFrame</code> can be created either implicitly or explicitly from a regular <code>RDD</code>. See the code examples below and the <a href="sql-programming-guide.html">Spark SQL programming guide</a> for examples.</p>
<p>Columns in a <code>DataFrame</code> are named. The code examples below use names such as “text,” “features,” and “label.”</p>
<h2 id="ml-algorithms">ML Algorithms</h2>
<h3 id="transformers">Transformers</h3>
<p>A <a href="api/scala/index.html#org.apache.spark.ml.Transformer"><code>Transformer</code></a> is an abstraction which includes feature transformers and learned models. Technically, a <code>Transformer</code> implements a method <code>transform()</code> which converts one <code>DataFrame</code> into another, generally by appending one or more columns.
For example:</p>
<ul>
<li>A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset.</li>
<li>A learning model might take a dataset, read the column containing feature vectors, predict the label for each feature vector, append the labels as a new column, and output the updated dataset.</li>
</ul>
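<p>The first bullet can be sketched in plain Python (a toy illustration of the pattern, not Spark's API; the function and column names <code>tokenizer_transform</code>, <code>text</code>, and <code>words</code> are hypothetical):</p>

```python
# Toy sketch of a feature transformer: read one column, derive a new column,
# and return a new dataset with the column appended, leaving the input intact.
def tokenizer_transform(dataset, input_col="text", output_col="words"):
    out = dict(dataset)  # shallow copy: a transformer returns a new dataset
    out[output_col] = [text.split() for text in dataset[input_col]]
    return out

data = {"text": ["spark is fast", "ml pipelines"], "label": [1.0, 0.0]}
transformed = tokenizer_transform(data)
print(transformed["words"])  # [['spark', 'is', 'fast'], ['ml', 'pipelines']]
print("words" in data)       # False: the original dataset is untouched
```

<p>Returning a copy rather than mutating the input mirrors the fact that <code>transform()</code> produces a new <code>DataFrame</code> instead of modifying its argument.</p>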
<h3 id="estimators">Estimators</h3>
<p>An <a href="api/scala/index.html#org.apache.spark.ml.Estimator"><code>Estimator</code></a> abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an <code>Estimator</code> implements a method <code>fit()</code> which accepts a <code>DataFrame</code> and produces a <code>Transformer</code>.
For example, a learning algorithm such as <code>LogisticRegression</code> is an <code>Estimator</code>, and calling <code>fit()</code> trains a <code>LogisticRegressionModel</code>, which is a <code>Transformer</code>.</p>
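<p>The fit-then-transform relationship can be sketched without Spark. In this toy Python example (all names are hypothetical, and the "algorithm" is deliberately trivial), fitting learns a threshold from the training data and returns a closure that plays the role of the fitted model, i.e., a transformer:</p>

```python
# Toy sketch of an Estimator: fit() consumes a dataset and returns a model,
# which is itself a transformer over new datasets.
def mean_threshold_fit(dataset, feature_col="feature"):
    xs = dataset[feature_col]
    threshold = sum(xs) / len(xs)  # "training": learn the mean of the features

    def model_transform(new_dataset):
        out = dict(new_dataset)
        out["prediction"] = [1.0 if x > threshold else 0.0
                             for x in new_dataset[feature_col]]
        return out

    return model_transform  # the fitted "model" is a transformer

train = {"feature": [0.0, 1.0, 4.0, 5.0], "label": [0.0, 0.0, 1.0, 1.0]}
model = mean_threshold_fit(train)
print(model({"feature": [0.5, 4.5]})["prediction"])  # [0.0, 1.0]
```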
<h3 id="properties-of-ml-algorithms">Properties of ML Algorithms</h3>
<p><code>Transformer</code>s and <code>Estimator</code>s are both stateless. In the future, stateful algorithms may be supported via alternative concepts.</p>
<p>Each instance of a <code>Transformer</code> or <code>Estimator</code> has a unique ID, which is useful in specifying parameters (discussed below).</p>
<h2 id="pipeline">Pipeline</h2>
<p>In machine learning, it is common to run a sequence of algorithms to process and learn from data.
E.g., a simple text document processing workflow might include several stages:</p>
<ul>
<li>Split each document’s text into words.</li>
<li>Convert each document’s words into a numerical feature vector.</li>
<li>Learn a prediction model using the feature vectors and labels.</li>
</ul>
<p>Spark ML represents such a workflow as a <a href="api/scala/index.html#org.apache.spark.ml.Pipeline"><code>Pipeline</code></a>,
which consists of a sequence of <a href="api/scala/index.html#org.apache.spark.ml.PipelineStage"><code>PipelineStage</code>s</a> (<code>Transformer</code>s and <code>Estimator</code>s) to be run in a specific order. We will use this simple workflow as a running example in this section.</p>
<h3 id="how-it-works">How It Works</h3>
<p>A <code>Pipeline</code> is specified as a sequence of stages, and each stage is either a <code>Transformer</code> or an <code>Estimator</code>.
These stages are run in order, and the input dataset is modified as it passes through each stage.
For <code>Transformer</code> stages, the <code>transform()</code> method is called on the dataset.
For <code>Estimator</code> stages, the <code>fit()</code> method is called to produce a <code>Transformer</code> (which becomes part of the <code>PipelineModel</code>, or fitted <code>Pipeline</code>), and that <code>Transformer</code>’s <code>transform()</code> method is called on the dataset.</p>
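<p>The fitting behavior described above can be modeled in a few lines of plain Python (a simplified sketch, not Spark's implementation; the stage tagging and helper names are invented for illustration):</p>

```python
def pipeline_fit(stages, dataset):
    """stages: list of ('transformer', fn) or ('estimator', fit_fn) pairs."""
    fitted = []
    for kind, stage in stages:
        if kind == "estimator":
            stage = stage(dataset)  # fit() produces a transformer
        fitted.append(stage)
        dataset = stage(dataset)    # pass the transformed data onward
    return fitted                   # the fitted pipeline: transformers only

def pipeline_model_transform(fitted, dataset):
    for transform in fitted:        # test time: transform through every stage
        dataset = transform(dataset)
    return dataset

# A two-stage example: a doubling transformer, then an "estimator" that
# learns the maximum of the transformed training column.
stages = [
    ("transformer", lambda d: {**d, "x2": [x * 2 for x in d["x"]]}),
    ("estimator", lambda train: (
        lambda d: {**d, "pred": [x + max(train["x2"]) for x in d["x2"]]})),
]
fitted = pipeline_fit(stages, {"x": [1, 2, 3]})
print(pipeline_model_transform(fitted, {"x": [0]})["pred"])  # [6]
```

<p>Note that the returned sequence contains only transformers, mirroring how a fitted <code>Pipeline</code> becomes a <code>PipelineModel</code>.</p>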
<p>We illustrate this for the simple text document workflow. The figure below is for the <em>training time</em> usage of a <code>Pipeline</code>.</p>
<p style="text-align: center;">
<img src="img/ml-Pipeline.png" title="Spark ML Pipeline Example" alt="Spark ML Pipeline Example" width="80%" />
</p>
<p>Above, the top row represents a <code>Pipeline</code> with three stages.
The first two (<code>Tokenizer</code> and <code>HashingTF</code>) are <code>Transformer</code>s (blue), and the third (<code>LogisticRegression</code>) is an <code>Estimator</code> (red).
The bottom row represents data flowing through the pipeline, where cylinders indicate <code>DataFrame</code>s.
The <code>Pipeline.fit()</code> method is called on the original dataset which has raw text documents and labels.
The <code>Tokenizer.transform()</code> method splits the raw text documents into words, adding a new column with words into the dataset.
The <code>HashingTF.transform()</code> method converts the words column into feature vectors, adding a new column with those vectors to the dataset.
Now, since <code>LogisticRegression</code> is an <code>Estimator</code>, the <code>Pipeline</code> first calls <code>LogisticRegression.fit()</code> to produce a <code>LogisticRegressionModel</code>.
If the <code>Pipeline</code> had more stages, it would call the <code>LogisticRegressionModel</code>’s <code>transform()</code> method on the dataset before passing the dataset to the next stage.</p>
<p>A <code>Pipeline</code> is an <code>Estimator</code>.
Thus, after a <code>Pipeline</code>’s <code>fit()</code> method runs, it produces a <code>PipelineModel</code> which is a <code>Transformer</code>. This <code>PipelineModel</code> is used at <em>test time</em>; the figure below illustrates this usage.</p>
<p style="text-align: center;">
<img src="img/ml-PipelineModel.png" title="Spark ML PipelineModel Example" alt="Spark ML PipelineModel Example" width="80%" />
</p>
<p>In the figure above, the <code>PipelineModel</code> has the same number of stages as the original <code>Pipeline</code>, but all <code>Estimator</code>s in the original <code>Pipeline</code> have become <code>Transformer</code>s.
When the <code>PipelineModel</code>’s <code>transform()</code> method is called on a test dataset, the data are passed through the <code>Pipeline</code> in order.
Each stage’s <code>transform()</code> method updates the dataset and passes it to the next stage.</p>
<p><code>Pipeline</code>s and <code>PipelineModel</code>s help to ensure that training and test data go through identical feature processing steps.</p>
<h3 id="details">Details</h3>
<p><em>DAG <code>Pipeline</code>s</em>: A <code>Pipeline</code>’s stages are specified as an ordered array. The examples given here are all for linear <code>Pipeline</code>s, i.e., <code>Pipeline</code>s in which each stage uses data produced by the previous stage. It is possible to create non-linear <code>Pipeline</code>s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the <code>Pipeline</code> forms a DAG, then the stages must be specified in topological order.</p>
<p><em>Runtime checking</em>: Since <code>Pipeline</code>s can operate on datasets with varied types, they cannot use compile-time type checking. <code>Pipeline</code>s and <code>PipelineModel</code>s instead do runtime checking before actually running the <code>Pipeline</code>. This type checking is done using the dataset <em>schema</em>, a description of the data types of columns in the <code>DataFrame</code>.</p>
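<p>As a simplified sketch of such a check (plain Python, not Spark's implementation), suppose each hypothetical stage declares the column and type it reads and the column and type it appends; validation can then walk the schema through the stages without touching any data:</p>

```python
def check_schema(stages, schema):
    """stages: (input_col, input_type, output_col, output_type) tuples."""
    schema = dict(schema)  # don't mutate the caller's schema
    for in_col, in_type, out_col, out_type in stages:
        actual = schema.get(in_col)
        if actual != in_type:
            raise TypeError(
                f"stage expects column '{in_col}' of type {in_type}, got {actual}")
        schema[out_col] = out_type  # the stage appends this column
    return schema                   # schema of the final output

final = check_schema(
    [("text", "string", "words", "array<string>"),
     ("words", "array<string>", "features", "vector")],
    {"text": "string", "label": "double"})
print(sorted(final))  # ['features', 'label', 'text', 'words']
```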
<h2 id="parameters">Parameters</h2>
<p>Spark ML <code>Estimator</code>s and <code>Transformer</code>s use a uniform API for specifying parameters.</p>
<p>A <a href="api/scala/index.html#org.apache.spark.ml.param.Param"><code>Param</code></a> is a named parameter with self-contained documentation.
A <a href="api/scala/index.html#org.apache.spark.ml.param.ParamMap"><code>ParamMap</code></a> is a set of (parameter, value) pairs.</p>
<p>There are two main ways to pass parameters to an algorithm:</p>
<ol>
<li>Set parameters for an instance. E.g., if <code>lr</code> is an instance of <code>LogisticRegression</code>, one could call <code>lr.setMaxIter(10)</code> to make <code>lr.fit()</code> use at most 10 iterations. This API resembles the API used in MLlib.</li>
<li>Pass a <code>ParamMap</code> to <code>fit()</code> or <code>transform()</code>. Any parameters in the <code>ParamMap</code> will override parameters previously specified via setter methods.</li>
</ol>
<p>Parameters belong to specific instances of <code>Estimator</code>s and <code>Transformer</code>s.
For example, if we have two <code>LogisticRegression</code> instances <code>lr1</code> and <code>lr2</code>, then we can build a <code>ParamMap</code> with both <code>maxIter</code> parameters specified: <code>ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)</code>.
This is useful if there are two algorithms with the <code>maxIter</code> parameter in a <code>Pipeline</code>.</p>
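<p>The override semantics and per-instance ownership of parameters can be illustrated with a toy Python class (not Spark's <code>Param</code>/<code>ParamMap</code> API; all names are hypothetical):</p>

```python
class ToyAlgorithm:
    """Hypothetical algorithm with one parameter and a default value."""
    def __init__(self):
        self.params = {"maxIter": 100}  # default value

    def set_max_iter(self, n):          # the "setter" style of specification
        self.params["maxIter"] = n
        return self

    def fit(self, param_map=None):
        effective = dict(self.params)
        for (instance, name), value in (param_map or {}).items():
            if instance is self:        # params belong to specific instances
                effective[name] = value # the map overrides setter values
        return effective                # stand-in for a fitted model

lr1 = ToyAlgorithm().set_max_iter(10)
lr2 = ToyAlgorithm()
# One map carries maxIter for both instances, keyed per instance.
param_map = {(lr1, "maxIter"): 30, (lr2, "maxIter"): 20}
print(lr1.fit(param_map)["maxIter"])  # 30: the map overrides the setter's 10
print(lr2.fit()["maxIter"])           # 100: defaults apply without a map
```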
<h1 id="code-examples">Code Examples</h1>
<p>This section gives code examples illustrating the functionality discussed above.
Spark ML does not yet have per-algorithm documentation; for now, please refer to the <a href="api/scala/index.html#org.apache.spark.ml.package">API Documentation</a>. Spark ML algorithms are currently wrappers for MLlib algorithms, and the <a href="mllib-guide.html">MLlib programming guide</a> has details on specific algorithms.</p>
<h2 id="example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</h2>
<p>This example covers the concepts of <code>Estimator</code>, <code>Transformer</code>, and <code>Param</code>.</p>
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.</span><span class="o">{</span><span class="nc">SparkConf</span><span class="o">,</span> <span class="nc">SparkContext</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.</span><span class="o">{</span><span class="nc">Vector</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span>
<span class="k">val</span> <span class="n">conf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkConf</span><span class="o">().</span><span class="n">setAppName</span><span class="o">(</span><span class="s">"SimpleParamsExample"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">sc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">)</span>
<span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span>
<span class="k">import</span> <span class="nn">sqlContext.implicits._</span>
<span class="c1">// Prepare training data.</span>
<span class="c1">// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes</span>
<span class="c1">// into DataFrames, where it uses the case class metadata to infer the schema.</span>
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">))))</span>
<span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span>
<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="c1">// Print out the parameters, documentation, and any default values.</span>
<span class="n">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="n">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">)</span>
<span class="c1">// We may set parameters using setter methods.</span>
<span class="n">lr</span><span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">)</span>
<span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span>
<span class="k">val</span> <span class="n">model1</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">.</span><span class="n">toDF</span><span class="o">)</span>
<span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span>
<span class="c1">// we can view the parameters it used during fit().</span>
<span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span>
<span class="c1">// LogisticRegression instance.</span>
<span class="n">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">extractParamMap</span><span class="o">)</span>
<span class="c1">// We may alternatively specify parameters using a ParamMap,</span>
<span class="c1">// which supports several methods for specifying parameters.</span>
<span class="k">val</span> <span class="n">paramMap</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span> <span class="o">-></span> <span class="mi">20</span><span class="o">)</span>
<span class="n">paramMap</span><span class="o">.</span><span class="n">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span><span class="o">,</span> <span class="mi">30</span><span class="o">)</span> <span class="c1">// Specify 1 Param. This overwrites the original maxIter.</span>
<span class="n">paramMap</span><span class="o">.</span><span class="n">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">regParam</span> <span class="o">-></span> <span class="mf">0.1</span><span class="o">,</span> <span class="n">lr</span><span class="o">.</span><span class="n">threshold</span> <span class="o">-></span> <span class="mf">0.55</span><span class="o">)</span> <span class="c1">// Specify multiple Params.</span>
<span class="c1">// One can also combine ParamMaps.</span>
<span class="k">val</span> <span class="n">paramMap2</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">probabilityCol</span> <span class="o">-></span> <span class="s">"myProbability"</span><span class="o">)</span> <span class="c1">// Change output column name</span>
<span class="k">val</span> <span class="n">paramMapCombined</span> <span class="k">=</span> <span class="n">paramMap</span> <span class="o">++</span> <span class="n">paramMap2</span>
<span class="c1">// Now learn a new model using the paramMapCombined parameters.</span>
<span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span>
<span class="k">val</span> <span class="n">model2</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">.</span><span class="n">toDF</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">extractParamMap</span><span class="o">)</span>
<span class="c1">// Prepare test data.</span>
<span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span>
<span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">))))</span>
<span class="c1">// Make predictions on test data using the Transformer.transform() method.</span>
<span class="c1">// LogisticRegression.transform will only use the 'features' column.</span>
<span class="c1">// Note that model2.transform() outputs a 'myProbability' column instead of the usual</span>
<span class="c1">// 'probability' column since we renamed the lr.probabilityCol parameter previously.</span>
<span class="n">model2</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">.</span><span class="n">toDF</span><span class="o">)</span>
<span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"myProbability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span>
<span class="o">.</span><span class="n">collect</span><span class="o">()</span>
<span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">features</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($features, $label) -> prob=$prob, prediction=$prediction"</span><span class="o">)</span>
<span class="o">}</span>
<span class="n">sc</span><span class="o">.</span><span class="n">stop</span><span class="o">()</span></code></pre></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">com.google.common.collect.Lists</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vectors</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">"JavaSimpleParamsExample"</span><span class="o">);</span>
<span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span>
<span class="n">SQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span>
<span class="c1">// Prepare training data.</span>
<span class="c1">// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans</span>
<span class="c1">// into DataFrames, where it uses the bean metadata to infer the schema.</span>
<span class="n">List</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">localTraining</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">)));</span>
<span class="n">DataFrame</span> <span class="n">training</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTraining</span><span class="o">),</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span>
<span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">();</span>
<span class="c1">// Print out the parameters, documentation, and any default values.</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="na">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span>
<span class="c1">// We may set parameters using setter methods.</span>
<span class="n">lr</span><span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span>
<span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span>
<span class="n">LogisticRegressionModel</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span>
<span class="c1">// we can view the parameters it used during fit().</span>
<span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span>
<span class="c1">// LogisticRegression instance.</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="na">parent</span><span class="o">().</span><span class="na">extractParamMap</span><span class="o">());</span>
<span class="c1">// We may alternatively specify parameters using a ParamMap.</span>
<span class="n">ParamMap</span> <span class="n">paramMap</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamMap</span><span class="o">();</span>
<span class="n">paramMap</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mi">20</span><span class="o">));</span> <span class="c1">// Specify 1 Param.</span>
<span class="n">paramMap</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">(),</span> <span class="mi">30</span><span class="o">);</span> <span class="c1">// This overwrites the original maxIter.</span>
<span class="n">paramMap</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">regParam</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mf">0.1</span><span class="o">),</span> <span class="n">lr</span><span class="o">.</span><span class="na">threshold</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mf">0.55</span><span class="o">));</span> <span class="c1">// Specify multiple Params.</span>
<span class="c1">// One can also combine ParamMaps.</span>
<span class="n">ParamMap</span> <span class="n">paramMap2</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamMap</span><span class="o">();</span>
<span class="n">paramMap2</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">probabilityCol</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="s">"myProbability"</span><span class="o">));</span> <span class="c1">// Change output column name</span>
<span class="n">ParamMap</span> <span class="n">paramMapCombined</span> <span class="o">=</span> <span class="n">paramMap</span><span class="o">.</span><span class="na">$plus$plus</span><span class="o">(</span><span class="n">paramMap2</span><span class="o">);</span>
<span class="c1">// Now learn a new model using the paramMapCombined parameters.</span>
<span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span>
<span class="n">LogisticRegressionModel</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="na">parent</span><span class="o">().</span><span class="na">extractParamMap</span><span class="o">());</span>
<span class="c1">// Prepare test documents.</span>
<span class="n">List</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">localTest</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span>
<span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">)));</span>
<span class="n">DataFrame</span> <span class="n">test</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTest</span><span class="o">),</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Make predictions on test documents using the Transformer.transform() method.</span>
<span class="c1">// LogisticRegressionModel.transform will only use the 'features' column.</span>
<span class="c1">// Note that model2.transform() outputs a 'myProbability' column instead of the usual</span>
<span class="c1">// 'probability' column since we renamed the lr.probabilityCol parameter previously.</span>
<span class="n">DataFrame</span> <span class="n">results</span> <span class="o">=</span> <span class="n">model2</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">results</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"myProbability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">).</span><span class="na">collect</span><span class="o">())</span> <span class="o">{</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") -> prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span>
<span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span>
<span class="o">}</span>
<span class="n">jsc</span><span class="o">.</span><span class="na">stop</span><span class="o">();</span></code></pre></div>
</div>
</div>
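<p>The <code>ParamMap</code> behavior shown above follows last-write-wins semantics: a later <code>put</code> for the same <code>Param</code> overwrites an earlier one, and a combined map overrides values set earlier via <code>lr.set*</code> methods. As a rough sketch of that merge rule only (this uses no Spark API; the class and method names below are illustrative, not part of Spark):</p>

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch: last-write-wins merging, as ParamMap combination behaves.
// Not a Spark API; names here are made up for the example.
public class ParamOverrideSketch {
  public static Map<String, Object> combine(Map<String, Object> first,
                                            Map<String, Object> second) {
    Map<String, Object> combined = new HashMap<>(first);
    combined.putAll(second);  // entries from 'second' win on key collisions
    return combined;
  }

  public static void main(String[] args) {
    Map<String, Object> paramMap = new HashMap<>();
    paramMap.put("maxIter", 20);
    paramMap.put("maxIter", 30);      // overwrites the original maxIter
    paramMap.put("regParam", 0.1);

    Map<String, Object> paramMap2 = new HashMap<>();
    paramMap2.put("probabilityCol", "myProbability");

    Map<String, Object> combined = combine(paramMap, paramMap2);
    System.out.println(combined.get("maxIter"));        // 30
    System.out.println(combined.get("probabilityCol")); // myProbability
  }
}
```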
<h2 id="example-pipeline">Example: Pipeline</h2>
<p>This example follows the simple text document <code>Pipeline</code> illustrated in the figures above.</p>
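<p>In the pipelines below, <code>HashingTF</code> turns each token sequence into a feature vector via the hashing trick: each term is hashed to a column index modulo <code>numFeatures</code>, and term counts accumulate in that bucket. The following plain-Java sketch illustrates the idea only; it is not Spark's actual implementation, and the class name is made up for the example:</p>

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch of the hashing trick behind HashingTF:
// index = nonNegativeMod(term.hashCode(), numFeatures), counts accumulated per index.
public class HashingTFSketch {
  public static Map<Integer, Integer> termFrequencies(String[] words, int numFeatures) {
    Map<Integer, Integer> tf = new HashMap<>();
    for (String word : words) {
      // Force a non-negative bucket index even for negative hash codes.
      int index = ((word.hashCode() % numFeatures) + numFeatures) % numFeatures;
      tf.merge(index, 1, Integer::sum);
    }
    return tf;
  }

  public static void main(String[] args) {
    String[] words = "spark i j spark".split(" ");
    Map<Integer, Integer> tf = termFrequencies(words, 1000);
    int sparkIndex = (("spark".hashCode() % 1000) + 1000) % 1000;
    System.out.println(tf.get(sparkIndex)); // 2, since "spark" appears twice
  }
}
```

<p>Distinct terms can collide into the same bucket; a larger <code>numFeatures</code> (the examples use 1000) makes collisions less likely at the cost of a wider feature vector.</p>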
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.</span><span class="o">{</span><span class="nc">SparkConf</span><span class="o">,</span> <span class="nc">SparkContext</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">HashingTF</span><span class="o">,</span> <span class="nc">Tokenizer</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vector</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span>
<span class="c1">// Labeled and unlabeled instance types.</span>
<span class="c1">// Spark SQL can infer schema from case classes.</span>
<span class="k">case</span> <span class="k">class</span> <span class="nc">LabeledDocument</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span>
<span class="k">case</span> <span class="k">class</span> <span class="nc">Document</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">)</span>
<span class="c1">// Set up contexts. Import implicit conversions to DataFrame from sqlContext.</span>
<span class="k">val</span> <span class="n">conf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkConf</span><span class="o">().</span><span class="n">setAppName</span><span class="o">(</span><span class="s">"SimpleTextClassificationPipeline"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">sc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">)</span>
<span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span>
<span class="k">import</span> <span class="nn">sqlContext.implicits._</span>
<span class="c1">// Prepare training documents, which are labeled.</span>
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)))</span>
<span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="k">val</span> <span class="n">tokenizer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">hashingTF</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span>
<span class="o">.</span><span class="n">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">)</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">))</span>
<span class="c1">// Fit the pipeline to training documents.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">.</span><span class="n">toDF</span><span class="o">)</span>
<span class="c1">// Prepare test documents, which are unlabeled.</span>
<span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)))</span>
<span class="c1">// Make predictions on test documents.</span>
<span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">.</span><span class="n">toDF</span><span class="o">)</span>
<span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span>
<span class="o">.</span><span class="n">collect</span><span class="o">()</span>
<span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($id, $text) --> prob=$prob, prediction=$prediction"</span><span class="o">)</span>
<span class="o">}</span>
<span class="n">sc</span><span class="o">.</span><span class="n">stop</span><span class="o">()</span></code></pre></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">java.io.Serializable</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">com.google.common.collect.Lists</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.HashingTF</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.Tokenizer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span>
<span class="c1">// Labeled and unlabeled instance types.</span>
<span class="c1">// Spark SQL can infer schema from JavaBeans.</span>
<span class="kd">public</span> <span class="kd">class</span> <span class="nc">Document</span> <span class="kd">implements</span> <span class="n">Serializable</span> <span class="o">{</span>
<span class="kd">private</span> <span class="kt">long</span> <span class="n">id</span><span class="o">;</span>
<span class="kd">private</span> <span class="n">String</span> <span class="n">text</span><span class="o">;</span>
<span class="kd">public</span> <span class="nf">Document</span><span class="o">(</span><span class="kt">long</span> <span class="n">id</span><span class="o">,</span> <span class="n">String</span> <span class="n">text</span><span class="o">)</span> <span class="o">{</span>
<span class="k">this</span><span class="o">.</span><span class="na">id</span> <span class="o">=</span> <span class="n">id</span><span class="o">;</span>
<span class="k">this</span><span class="o">.</span><span class="na">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">;</span>
<span class="o">}</span>
<span class="kd">public</span> <span class="kt">long</span> <span class="nf">getId</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">id</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="kt">void</span> <span class="nf">setId</span><span class="o">(</span><span class="kt">long</span> <span class="n">id</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">id</span> <span class="o">=</span> <span class="n">id</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="n">String</span> <span class="nf">getText</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">text</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="kt">void</span> <span class="nf">setText</span><span class="o">(</span><span class="n">String</span> <span class="n">text</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">;</span> <span class="o">}</span>
<span class="o">}</span>
<span class="kd">public</span> <span class="kd">class</span> <span class="nc">LabeledDocument</span> <span class="kd">extends</span> <span class="n">Document</span> <span class="kd">implements</span> <span class="n">Serializable</span> <span class="o">{</span>
<span class="kd">private</span> <span class="kt">double</span> <span class="n">label</span><span class="o">;</span>
<span class="kd">public</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="kt">long</span> <span class="n">id</span><span class="o">,</span> <span class="n">String</span> <span class="n">text</span><span class="o">,</span> <span class="kt">double</span> <span class="n">label</span><span class="o">)</span> <span class="o">{</span>
<span class="kd">super</span><span class="o">(</span><span class="n">id</span><span class="o">,</span> <span class="n">text</span><span class="o">);</span>
<span class="k">this</span><span class="o">.</span><span class="na">label</span> <span class="o">=</span> <span class="n">label</span><span class="o">;</span>
<span class="o">}</span>
<span class="kd">public</span> <span class="kt">double</span> <span class="nf">getLabel</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">label</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="kt">void</span> <span class="nf">setLabel</span><span class="o">(</span><span class="kt">double</span> <span class="n">label</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">label</span> <span class="o">=</span> <span class="n">label</span><span class="o">;</span> <span class="o">}</span>
<span class="o">}</span>
<span class="c1">// Set up contexts.</span>
<span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">"JavaSimpleTextClassificationPipeline"</span><span class="o">);</span>
<span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span>
<span class="n">SQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span>
<span class="c1">// Prepare training documents, which are labeled.</span>
<span class="n">List</span><span class="o"><</span><span class="n">LabeledDocument</span><span class="o">></span> <span class="n">localTraining</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">));</span>
<span class="n">DataFrame</span> <span class="n">training</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTraining</span><span class="o">),</span> <span class="n">LabeledDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="n">Tokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Tokenizer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">);</span>
<span class="n">HashingTF</span> <span class="n">hashingTF</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">HashingTF</span><span class="o">()</span>
<span class="o">.</span><span class="na">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="na">getOutputCol</span><span class="o">())</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">);</span>
<span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">});</span>
<span class="c1">// Fit the pipeline to training documents.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Prepare test documents, which are unlabeled.</span>
<span class="n">List</span><span class="o"><</span><span class="n">Document</span><span class="o">></span> <span class="n">localTest</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">));</span>
<span class="n">DataFrame</span> <span class="n">test</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTest</span><span class="o">),</span> <span class="n">Document</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Make predictions on test documents.</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">).</span><span class="na">collect</span><span class="o">())</span> <span class="o">{</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") --> prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span>
<span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span>
<span class="o">}</span>
<span class="n">jsc</span><span class="o">.</span><span class="na">stop</span><span class="o">();</span></code></pre></div>
</div>
<div data-lang="python">
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">SparkContext</span>
<span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">HashingTF</span><span class="p">,</span> <span class="n">Tokenizer</span>
<span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">Row</span><span class="p">,</span> <span class="n">SQLContext</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="p">(</span><span class="n">appName</span><span class="o">=</span><span class="s">"SimpleTextClassificationPipeline"</span><span class="p">)</span>
<span class="n">sqlContext</span> <span class="o">=</span> <span class="n">SQLContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
<span class="c"># Prepare training documents, which are labeled.</span>
<span class="n">LabeledDocument</span> <span class="o">=</span> <span class="n">Row</span><span class="p">(</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">)</span>
<span class="n">training</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="p">([(</span><span class="il">0L</span><span class="p">,</span> <span class="s">"a b c d e spark"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span>
<span class="p">(</span><span class="il">1L</span><span class="p">,</span> <span class="s">"b d"</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">),</span>
<span class="p">(</span><span class="il">2L</span><span class="p">,</span> <span class="s">"spark f g h"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span>
<span class="p">(</span><span class="il">3L</span><span class="p">,</span> <span class="s">"hadoop mapreduce"</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)])</span> \
<span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">LabeledDocument</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="p">))</span><span class="o">.</span><span class="n">toDF</span><span class="p">()</span>
<span class="c"># Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">Tokenizer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"text"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"words"</span><span class="p">)</span>
<span class="n">hashingTF</span> <span class="o">=</span> <span class="n">HashingTF</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="p">(),</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">)</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">hashingTF</span><span class="p">,</span> <span class="n">lr</span><span class="p">])</span>
<span class="c"># Fit the pipeline to training documents.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span>
<span class="c"># Prepare test documents, which are unlabeled.</span>
<span class="n">Document</span> <span class="o">=</span> <span class="n">Row</span><span class="p">(</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">)</span>
<span class="n">test</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="p">([(</span><span class="il">4L</span><span class="p">,</span> <span class="s">"spark i j k"</span><span class="p">),</span>
<span class="p">(</span><span class="il">5L</span><span class="p">,</span> <span class="s">"l m n"</span><span class="p">),</span>
<span class="p">(</span><span class="il">6L</span><span class="p">,</span> <span class="s">"mapreduce spark"</span><span class="p">),</span>
<span class="p">(</span><span class="il">7L</span><span class="p">,</span> <span class="s">"apache hadoop"</span><span class="p">)])</span> \
<span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">Document</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="p">))</span><span class="o">.</span><span class="n">toDF</span><span class="p">()</span>
<span class="c"># Make predictions on test documents and print columns of interest.</span>
<span class="n">prediction</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span>
<span class="n">selected</span> <span class="o">=</span> <span class="n">prediction</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">,</span> <span class="s">"prediction"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">selected</span><span class="o">.</span><span class="n">collect</span><span class="p">():</span>
<span class="k">print</span> <span class="n">row</span>
<span class="n">sc</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span></code></pre></div>
</div>
</div>
<h2 id="example-model-selection-via-cross-validation">Example: Model Selection via Cross-Validation</h2>
<p>An important task in ML is <em>model selection</em>, or using data to find the best model or parameters for a given task. This is also called <em>tuning</em>.
<code>Pipeline</code>s facilitate model selection by making it easy to tune an entire <code>Pipeline</code> at once, rather than tuning each element in the <code>Pipeline</code> separately.</p>
<p>Currently, <code>spark.ml</code> supports model selection using the <a href="api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator"><code>CrossValidator</code></a> class, which takes an <code>Estimator</code>, a set of <code>ParamMap</code>s, and an <a href="api/scala/index.html#org.apache.spark.ml.Evaluator"><code>Evaluator</code></a>.
<code>CrossValidator</code> begins by splitting the dataset into a set of <em>folds</em> which are used as separate training and test datasets; e.g., with <code>$k=3$</code> folds, <code>CrossValidator</code> will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
<code>CrossValidator</code> iterates through the set of <code>ParamMap</code>s. For each <code>ParamMap</code>, it trains the given <code>Estimator</code> and evaluates it using the given <code>Evaluator</code>.
The <code>ParamMap</code> which produces the best evaluation metric (averaged over the <code>$k$</code> folds) is selected as the best set of parameters.
<code>CrossValidator</code> finally fits the <code>Estimator</code> using the best <code>ParamMap</code> and the entire dataset.</p>
<p>The following example demonstrates using <code>CrossValidator</code> to select from a grid of parameters.
To help construct the parameter grid, we use the <a href="api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder"><code>ParamGridBuilder</code></a> utility.</p>
<p>Note that cross-validation over a grid of parameters is expensive.
E.g., in the example below, the parameter grid has 3 values for <code>hashingTF.numFeatures</code> and 2 values for <code>lr.regParam</code>, and <code>CrossValidator</code> uses 2 folds. This multiplies out to <code>$(3 \times 2) \times 2 = 12$</code> different models being trained.
In realistic settings, it is common to try many more parameters and use more folds (<code>$k=3$</code> and <code>$k=10$</code> are common), so using <code>CrossValidator</code> can be very expensive.
However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.</p>
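As a quick sanity check on the cost arithmetic above, the number of models trained during the search can be computed directly. This is a standalone sketch; <code>num_models_trained</code> is a hypothetical helper for illustration, not part of Spark's API:

```python
def num_models_trained(grid_sizes, num_folds):
    """Count the models CrossValidator fits during the search: one per
    (parameter setting, fold) pair. The final refit on the full dataset
    with the best ParamMap adds one more on top of this count."""
    num_settings = 1
    for size in grid_sizes:
        num_settings *= size  # the grid is the cross-product of all values
    return num_settings * num_folds

# The guide's grid: 3 values for hashingTF.numFeatures,
# 2 values for lr.regParam, evaluated with 2 folds.
print(num_models_trained([3, 2], 2))  # (3 x 2) x 2 = 12
```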
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.</span><span class="o">{</span><span class="nc">SparkConf</span><span class="o">,</span> <span class="nc">SparkContext</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.BinaryClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">HashingTF</span><span class="o">,</span> <span class="nc">Tokenizer</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.tuning.</span><span class="o">{</span><span class="nc">ParamGridBuilder</span><span class="o">,</span> <span class="nc">CrossValidator</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vector</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span>
<span class="c1">// Labeled and unlabeled instance types.</span>
<span class="c1">// Spark SQL can infer schema from case classes.</span>
<span class="k">case</span> <span class="k">class</span> <span class="nc">LabeledDocument</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span>
<span class="k">case</span> <span class="k">class</span> <span class="nc">Document</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">)</span>
<span class="k">val</span> <span class="n">conf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkConf</span><span class="o">().</span><span class="n">setAppName</span><span class="o">(</span><span class="s">"CrossValidatorExample"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">sc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">)</span>
<span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span>
<span class="k">import</span> <span class="nn">sqlContext.implicits._</span>
<span class="c1">// Prepare training documents, which are labeled.</span>
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"b spark who"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"g d a y"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark fly"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"was mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">8L</span><span class="o">,</span> <span class="s">"e spark program"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">9L</span><span class="o">,</span> <span class="s">"a e c l"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">10L</span><span class="o">,</span> <span class="s">"spark compile"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">11L</span><span class="o">,</span> <span class="s">"hadoop software"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)))</span>
<span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="k">val</span> <span class="n">tokenizer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">hashingTF</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span>
<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">))</span>
<span class="c1">// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.</span>
<span class="c1">// This will allow us to jointly choose parameters for all Pipeline stages.</span>
<span class="c1">// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.</span>
<span class="k">val</span> <span class="n">crossval</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">CrossValidator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setEstimator</span><span class="o">(</span><span class="n">pipeline</span><span class="o">)</span>
<span class="o">.</span><span class="n">setEvaluator</span><span class="o">(</span><span class="k">new</span> <span class="nc">BinaryClassificationEvaluator</span><span class="o">)</span>
<span class="c1">// We use a ParamGridBuilder to construct a grid of parameters to search over.</span>
<span class="c1">// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,</span>
<span class="c1">// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.</span>
<span class="k">val</span> <span class="n">paramGrid</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">ParamGridBuilder</span><span class="o">()</span>
<span class="o">.</span><span class="n">addGrid</span><span class="o">(</span><span class="n">hashingTF</span><span class="o">.</span><span class="n">numFeatures</span><span class="o">,</span> <span class="nc">Array</span><span class="o">(</span><span class="mi">10</span><span class="o">,</span> <span class="mi">100</span><span class="o">,</span> <span class="mi">1000</span><span class="o">))</span>
<span class="o">.</span><span class="n">addGrid</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">regParam</span><span class="o">,</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.1</span><span class="o">,</span> <span class="mf">0.01</span><span class="o">))</span>
<span class="o">.</span><span class="n">build</span><span class="o">()</span>
<span class="n">crossval</span><span class="o">.</span><span class="n">setEstimatorParamMaps</span><span class="o">(</span><span class="n">paramGrid</span><span class="o">)</span>
<span class="n">crossval</span><span class="o">.</span><span class="n">setNumFolds</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> <span class="c1">// Use 3+ in practice</span>
<span class="c1">// Run cross-validation, and choose the best set of parameters.</span>
<span class="k">val</span> <span class="n">cvModel</span> <span class="k">=</span> <span class="n">crossval</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">.</span><span class="n">toDF</span><span class="o">)</span>
<span class="c1">// Prepare test documents, which are unlabeled.</span>
<span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span>
<span class="nc">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)))</span>
<span class="c1">// Make predictions on test documents. cvModel uses the best model found.</span>
<span class="n">cvModel</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">.</span><span class="n">toDF</span><span class="o">)</span>
<span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span>
<span class="o">.</span><span class="n">collect</span><span class="o">()</span>
<span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($id, $text) --> prob=$prob, prediction=$prediction"</span><span class="o">)</span>
<span class="o">}</span>
<span class="n">sc</span><span class="o">.</span><span class="n">stop</span><span class="o">()</span></code></pre></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">java.io.Serializable</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">com.google.common.collect.Lists</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.BinaryClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.HashingTF</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.Tokenizer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.tuning.CrossValidator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.tuning.CrossValidatorModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.tuning.ParamGridBuilder</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span>
<span class="c1">// Labeled and unlabeled instance types.</span>
<span class="c1">// Spark SQL can infer schema from Java Beans.</span>
<span class="kd">public</span> <span class="kd">class</span> <span class="nc">Document</span> <span class="kd">implements</span> <span class="n">Serializable</span> <span class="o">{</span>
<span class="kd">private</span> <span class="kt">long</span> <span class="n">id</span><span class="o">;</span>
<span class="kd">private</span> <span class="n">String</span> <span class="n">text</span><span class="o">;</span>
<span class="kd">public</span> <span class="nf">Document</span><span class="o">(</span><span class="kt">long</span> <span class="n">id</span><span class="o">,</span> <span class="n">String</span> <span class="n">text</span><span class="o">)</span> <span class="o">{</span>
<span class="k">this</span><span class="o">.</span><span class="na">id</span> <span class="o">=</span> <span class="n">id</span><span class="o">;</span>
<span class="k">this</span><span class="o">.</span><span class="na">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">;</span>
<span class="o">}</span>
<span class="kd">public</span> <span class="kt">long</span> <span class="nf">getId</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">id</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="kt">void</span> <span class="nf">setId</span><span class="o">(</span><span class="kt">long</span> <span class="n">id</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">id</span> <span class="o">=</span> <span class="n">id</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="n">String</span> <span class="nf">getText</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">text</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="kt">void</span> <span class="nf">setText</span><span class="o">(</span><span class="n">String</span> <span class="n">text</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">;</span> <span class="o">}</span>
<span class="o">}</span>
<span class="kd">public</span> <span class="kd">class</span> <span class="nc">LabeledDocument</span> <span class="kd">extends</span> <span class="n">Document</span> <span class="kd">implements</span> <span class="n">Serializable</span> <span class="o">{</span>
<span class="kd">private</span> <span class="kt">double</span> <span class="n">label</span><span class="o">;</span>
<span class="kd">public</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="kt">long</span> <span class="n">id</span><span class="o">,</span> <span class="n">String</span> <span class="n">text</span><span class="o">,</span> <span class="kt">double</span> <span class="n">label</span><span class="o">)</span> <span class="o">{</span>
<span class="kd">super</span><span class="o">(</span><span class="n">id</span><span class="o">,</span> <span class="n">text</span><span class="o">);</span>
<span class="k">this</span><span class="o">.</span><span class="na">label</span> <span class="o">=</span> <span class="n">label</span><span class="o">;</span>
<span class="o">}</span>
<span class="kd">public</span> <span class="kt">double</span> <span class="nf">getLabel</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">label</span><span class="o">;</span> <span class="o">}</span>
<span class="kd">public</span> <span class="kt">void</span> <span class="nf">setLabel</span><span class="o">(</span><span class="kt">double</span> <span class="n">label</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">label</span> <span class="o">=</span> <span class="n">label</span><span class="o">;</span> <span class="o">}</span>
<span class="o">}</span>
<span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">"JavaCrossValidatorExample"</span><span class="o">);</span>
<span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span>
<span class="n">SQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span>
<span class="c1">// Prepare training documents, which are labeled.</span>
<span class="n">List</span><span class="o"><</span><span class="n">LabeledDocument</span><span class="o">></span> <span class="n">localTraining</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"b spark who"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"g d a y"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark fly"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"was mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">8L</span><span class="o">,</span> <span class="s">"e spark program"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">9L</span><span class="o">,</span> <span class="s">"a e c l"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">10L</span><span class="o">,</span> <span class="s">"spark compile"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">11L</span><span class="o">,</span> <span class="s">"hadoop software"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">));</span>
<span class="n">DataFrame</span> <span class="n">training</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTraining</span><span class="o">),</span> <span class="n">LabeledDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="n">Tokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Tokenizer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">);</span>
<span class="n">HashingTF</span> <span class="n">hashingTF</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">HashingTF</span><span class="o">()</span>
<span class="o">.</span><span class="na">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="na">getOutputCol</span><span class="o">())</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">);</span>
<span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">});</span>
<span class="c1">// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.</span>
<span class="c1">// This will allow us to jointly choose parameters for all Pipeline stages.</span>
<span class="c1">// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.</span>
<span class="n">CrossValidator</span> <span class="n">crossval</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">CrossValidator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setEstimator</span><span class="o">(</span><span class="n">pipeline</span><span class="o">)</span>
<span class="o">.</span><span class="na">setEvaluator</span><span class="o">(</span><span class="k">new</span> <span class="nf">BinaryClassificationEvaluator</span><span class="o">());</span>
<span class="c1">// We use a ParamGridBuilder to construct a grid of parameters to search over.</span>
<span class="c1">// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,</span>
<span class="c1">// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.</span>
<span class="n">ParamMap</span><span class="o">[]</span> <span class="n">paramGrid</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamGridBuilder</span><span class="o">()</span>
<span class="o">.</span><span class="na">addGrid</span><span class="o">(</span><span class="n">hashingTF</span><span class="o">.</span><span class="na">numFeatures</span><span class="o">(),</span> <span class="k">new</span> <span class="kt">int</span><span class="o">[]{</span><span class="mi">10</span><span class="o">,</span> <span class="mi">100</span><span class="o">,</span> <span class="mi">1000</span><span class="o">})</span>
<span class="o">.</span><span class="na">addGrid</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">regParam</span><span class="o">(),</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.1</span><span class="o">,</span> <span class="mf">0.01</span><span class="o">})</span>
<span class="o">.</span><span class="na">build</span><span class="o">();</span>
<span class="n">crossval</span><span class="o">.</span><span class="na">setEstimatorParamMaps</span><span class="o">(</span><span class="n">paramGrid</span><span class="o">);</span>
<span class="n">crossval</span><span class="o">.</span><span class="na">setNumFolds</span><span class="o">(</span><span class="mi">2</span><span class="o">);</span> <span class="c1">// Use 3+ in practice</span>
<span class="c1">// Run cross-validation, and choose the best set of parameters.</span>
<span class="n">CrossValidatorModel</span> <span class="n">cvModel</span> <span class="o">=</span> <span class="n">crossval</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Prepare test documents, which are unlabeled.</span>
<span class="n">List</span><span class="o"><</span><span class="n">Document</span><span class="o">></span> <span class="n">localTest</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">));</span>
<span class="n">DataFrame</span> <span class="n">test</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTest</span><span class="o">),</span> <span class="n">Document</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Make predictions on test documents. cvModel uses the best model found during cross-validation.</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">cvModel</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">).</span><span class="na">collect</span><span class="o">())</span> <span class="o">{</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") --> prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span>
<span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span>
<span class="o">}</span>
<span class="n">jsc</span><span class="o">.</span><span class="na">stop</span><span class="o">();</span></code></pre></div>
</div>
</div>
<h1 id="dependencies">Dependencies</h1>
<p>Spark ML currently depends on MLlib and has the same dependencies.
Please see the <a href="mllib-guide.html#dependencies">MLlib Dependencies guide</a> for more info.</p>
<p>Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies.</p>
<h1 id="migration-guide">Migration Guide</h1>
<h2 id="from-13-to-14">From 1.3 to 1.4</h2>
<p>Several major API changes occurred, including:</p>
<ul>
<li><code>Param</code> and other APIs for specifying parameters</li>
<li><code>uid</code> unique IDs for Pipeline components</li>
<li>Reorganization of certain classes</li>
</ul>
<p>Since the <code>spark.ml</code> API was an Alpha Component in Spark 1.3, we do not list all changes here.</p>
<p>However, now that <code>spark.ml</code> is no longer an Alpha Component, we will provide details on any API changes for future releases.</p>
<h2 id="from-12-to-13">From 1.2 to 1.3</h2>
<p>The main API changes are from Spark SQL. We list the most important changes here:</p>
<ul>
<li>The old <a href="http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD">SchemaRDD</a> has been replaced with <a href="api/scala/index.html#org.apache.spark.sql.DataFrame">DataFrame</a> with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame.</li>
<li>In Spark 1.2, we used implicit conversions from <code>RDD</code>s of <code>LabeledPoint</code> into <code>SchemaRDD</code>s by calling <code>import sqlContext._</code> where <code>sqlContext</code> was an instance of <code>SQLContext</code>. These implicits have been moved, so we now call <code>import sqlContext.implicits._</code>.</li>
<li>Java APIs for SQL have also changed accordingly. Please see the examples above and the <a href="sql-programming-guide.html">Spark SQL Programming Guide</a> for details.</li>
</ul>
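<p>For example, the implicit-conversion change can be sketched as follows (a minimal sketch, assuming <code>sqlContext</code> is an existing <code>SQLContext</code> and <code>labeledPointsRDD</code> is a hypothetical <code>RDD[LabeledPoint]</code>):</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="c1">// Spark 1.2:</span>
<span class="c1">// import sqlContext._             // brought the RDD-to-SchemaRDD implicits into scope</span>
<span class="c1">// Spark 1.3+:</span>
<span class="k">import</span> <span class="nn">sqlContext.implicits._</span>  <span class="c1">// brings the RDD-to-DataFrame implicits into scope</span>
<span class="k">val</span> <span class="n">df</span> <span class="k">=</span> <span class="n">labeledPointsRDD</span><span class="o">.</span><span class="n">toDF</span><span class="o">()</span>  <span class="c1">// labeledPointsRDD is a hypothetical RDD[LabeledPoint]</span></code></pre></div>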
<p>Other changes were in <code>LogisticRegression</code>:</p>
<ul>
<li>The <code>scoreCol</code> output column (with default value “score”) was renamed to be <code>probabilityCol</code> (with default value “probability”). The type was originally <code>Double</code> (for the probability of class 1.0), but it is now <code>Vector</code> (for the probability of each class, to support multiclass classification in the future).</li>
<li>In Spark 1.2, <code>LogisticRegressionModel</code> did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for <a href="api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS">spark.mllib.LogisticRegressionWithLBFGS</a>. The option to use an intercept will be added in the future.</li>
</ul>
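<p>Concretely, code that read the old <code>score</code> column must now read the <code>probability</code> column and handle a <code>Vector</code> rather than a <code>Double</code>. The following is a hedged sketch, assuming <code>predictions</code> is a DataFrame produced by a fitted <code>LogisticRegressionModel</code>:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vector</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.Row</span>
<span class="c1">// Spark 1.2:</span>
<span class="c1">// predictions.select("score")     // Double: probability of class 1.0</span>
<span class="c1">// Spark 1.3+:</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"probability"</span><span class="o">).</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">)</span> <span class="k">=></span>
  <span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"P(class 0)=${prob(0)}, P(class 1)=${prob(1)}"</span><span class="o">)</span>
<span class="o">}</span></code></pre></div>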
</div> <!-- /container -->
<script src="js/vendor/jquery-1.8.0.min.js"></script>
<script src="js/vendor/bootstrap.min.js"></script>
<script src="js/vendor/anchor.min.js"></script>
<script src="js/main.js"></script>
<!-- MathJax Section -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
TeX: { equationNumbers: { autoNumber: "AMS" } }
});
</script>
<script>
// Note that we load MathJax this way to work with local file (file://), HTTP and HTTPS.
// We could use "//cdn.mathjax...", but that won't support "file://".
(function(d, script) {
script = d.createElement('script');
script.type = 'text/javascript';
script.async = true;
script.onload = function(){
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ],
displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
processEscapes: true,
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre']
}
});
};
script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') +
'cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
d.getElementsByTagName('head')[0].appendChild(script);
}(document));
</script>
</body>
</html>