path: root/mllib
author     Xiangrui Meng <meng@databricks.com>    2014-11-12 10:38:57 -0800
committer  Xiangrui Meng <meng@databricks.com>    2014-11-12 10:38:57 -0800
commit     4b736dbab3e177e5265439d37063bb501657d830 (patch)
tree       db811e6310ee6035e4d96cb4ace17066572ede6c /mllib
parent     84324fbcb987db6e10e435f463eacace1bae43e2 (diff)
download   spark-4b736dbab3e177e5265439d37063bb501657d830.tar.gz
           spark-4b736dbab3e177e5265439d37063bb501657d830.tar.bz2
           spark-4b736dbab3e177e5265439d37063bb501657d830.zip
[SPARK-3530][MLLIB] pipeline and parameters with examples
This PR adds package "org.apache.spark.ml" with pipeline and parameters, as discussed on the JIRA. This is a joint work of jkbradley, etrain, shivaram, and many others who helped on the design, also with help from marmbrus and liancheng on the Spark SQL side. The design doc can be found at: https://docs.google.com/document/d/1rVwXRjWKfIb-7PI6b86ipytwbUH7irSNLF1_6dLmh8o/edit?usp=sharing

**org.apache.spark.ml**

This is a new package with a new set of ML APIs that address practical machine learning pipelines. (Sorry for taking so long!) It will be an alpha component, so this is definitely not something set in stone. The new set of APIs, inspired by the MLI project from AMPLab and scikit-learn, leverages Spark SQL's schema support and execution plan optimization. It introduces the following components that help build a practical pipeline:

1. Transformer, which transforms a dataset into another
2. Estimator, which fits models to data, where models are transformers
3. Evaluator, which evaluates model output and returns a scalar metric
4. Pipeline, a simple pipeline that consists of transformers and estimators

Parameters can be supplied at fit/transform time or embedded with components:

1. Param: a strongly typed parameter key with self-contained doc
2. ParamMap: a param -> value map
3. Params: trait for components with parameters

For any component that implements `Params`, users can easily check the docs by calling `explainParams`:

~~~
> val lr = new LogisticRegression
> lr.explainParams
maxIter: max number of iterations (default: 100)
regParam: regularization constant (default: 0.1)
labelCol: label column name (default: label)
featuresCol: features column name (default: features)
~~~

or check an individual param:

~~~
> lr.maxIter
maxIter: max number of iterations (default: 100)
~~~

**Please start with the example code in test suites and under `org.apache.spark.examples.ml`, where I put several examples:**

1. run a simple logistic regression job

~~~
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
  .select('label, 'score, 'prediction).collect()
  .foreach(println)
~~~

2. run logistic regression with cross-validation and grid search using areaUnderROC (default) as the metric

~~~
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 100.0))
  .addGrid(lr.maxIter, Array(0, 5))
  .build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(lrParamMaps)
  .setEvaluator(eval)
  .setNumFolds(3)
val bestModel = cv.fit(dataset)
~~~

3. run a pipeline that consists of a standard scaler and a logistic regression component

~~~
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
val lr = new LogisticRegression()
  .setFeaturesCol(scaler.getOutputCol)
val pipeline = new Pipeline()
  .setStages(Array(scaler, lr))
val model = pipeline.fit(dataset)
val predictions = model.transform(dataset)
  .select('label, 'score, 'prediction)
  .collect()
  .foreach(println)
~~~

4. a simple text classification pipeline, which recognizes "spark":

~~~
val training = sparkContext.parallelize(Seq(
  LabeledDocument(0L, "a b c d e spark", 1.0),
  LabeledDocument(1L, "b d", 0.0),
  LabeledDocument(2L, "spark f g h", 1.0),
  LabeledDocument(3L, "hadoop mapreduce", 0.0)))
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))
val model = pipeline.fit(training)
val test = sparkContext.parallelize(Seq(
  Document(4L, "spark i j k"),
  Document(5L, "l m"),
  Document(6L, "mapreduce spark"),
  Document(7L, "apache hadoop")))
model.transform(test)
  .select('id, 'text, 'prediction, 'score)
  .collect()
  .foreach(println)
~~~

Java examples are very similar. I put example code that creates a simple text classification pipeline in Scala and Java, where a simple tokenizer is defined as a transformer outside `org.apache.spark.ml`.

**What is missing now and will be added soon:**

1. ~~Runtime check of schemas. So before we touch the data, we will go through the schema and make sure column names and types match the input parameters.~~
2. ~~Java examples.~~
3. ~~Store training parameters in trained models.~~
4. (later) Serialization and Python API.

Author: Xiangrui Meng <meng@databricks.com>

Closes #3099 from mengxr/SPARK-3530 and squashes the following commits:

2cc93fd [Xiangrui Meng] hide APIs as much as I can
34319ba [Xiangrui Meng] use local instead local[2] for unit tests
2524251 [Xiangrui Meng] rename PipelineStage.transform to transformSchema
c9daab4 [Xiangrui Meng] remove mockito version
1397ab5 [Xiangrui Meng] use sqlContext from LocalSparkContext instead of TestSQLContext
6ffc389 [Xiangrui Meng] try to fix unit test
a59d8b7 [Xiangrui Meng] doc updates
977fd9d [Xiangrui Meng] add scala ml package object
6d97fe6 [Xiangrui Meng] add AlphaComponent annotation
731f0e4 [Xiangrui Meng] update package doc
0435076 [Xiangrui Meng] remove ;this from setters
fa21d9b [Xiangrui Meng] update extends indentation
f1091b3 [Xiangrui Meng] typo
228a9f4 [Xiangrui Meng] do not persist before calling binary classification metrics
f51cd27 [Xiangrui Meng] rename default to defaultValue
b3be094 [Xiangrui Meng] refactor schema transform in lr
8791e8e [Xiangrui Meng] rename copyValues to inheritValues and make it do the right thing
51f1c06 [Xiangrui Meng] remove leftover code in Transformer
494b632 [Xiangrui Meng] compure score once
ad678e9 [Xiangrui Meng] more doc for Transformer
4306ed4 [Xiangrui Meng] org imports in text pipeline
6e7c1c7 [Xiangrui Meng] update pipeline
4f9e34f [Xiangrui Meng] more doc for pipeline
aa5dbd4 [Xiangrui Meng] fix typo
11be383 [Xiangrui Meng] fix unit tests
3df7952 [Xiangrui Meng] clean up
986593e [Xiangrui Meng] re-org java test suites
2b11211 [Xiangrui Meng] remove external data deps
9fd4933 [Xiangrui Meng] add unit test for pipeline
2a0df46 [Xiangrui Meng] update tests
2d52e4d [Xiangrui Meng] add @AlphaComponent to package-info
27582a4 [Xiangrui Meng] doc changes
73a000b [Xiangrui Meng] add schema transformation layer
6736e87 [Xiangrui Meng] more doc / remove HasMetricName trait
80a8b5e [Xiangrui Meng] rename SimpleTransformer to UnaryTransformer
62ca2bb [Xiangrui Meng] check param parent in set/get
1622349 [Xiangrui Meng] add getModel to PipelineModel
a0e0054 [Xiangrui Meng] update StandardScaler to use SimpleTransformer
d0faa04 [Xiangrui Meng] remove implicit mapping from ParamMap
c7f6921 [Xiangrui Meng] move ParamGridBuilder test to ParamGridBuilderSuite
e246f29 [Xiangrui Meng] re-org:
7772430 [Xiangrui Meng] remove modelParams add a simple text classification pipeline
b95c408 [Xiangrui Meng] remove implicits add unit tests to params
bab3e5b [Xiangrui Meng] update params
fe0ee92 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3530
6e86d98 [Xiangrui Meng] some code clean-up
2d040b3 [Xiangrui Meng] implement setters inside each class, add Params.copyValues [ci skip]
fd751fc [Xiangrui Meng] add java-friendly versions of fit and tranform
3f810cd [Xiangrui Meng] use multi-model training api in cv
5b8f413 [Xiangrui Meng] rename model to modelParams
9d2d35d [Xiangrui Meng] test varargs and chain model params
f46e927 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3530
1ef26e0 [Xiangrui Meng] specialize methods/types for Java
df293ed [Xiangrui Meng] switch to setter/getter
376db0a [Xiangrui Meng] pipeline and parameters
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/pom.xml | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Estimator.scala | 105
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala | 39
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala | 33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Model.scala | 40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 172
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 127
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 148
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 71
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 42
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 105
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 39
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package-info.java | 25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package.scala | 24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 321
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala | 74
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 126
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala | 112
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 72
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 80
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 76
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 82
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 57
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 108
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 36
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 51
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala | 63
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala | 21
31 files changed, 2246 insertions, 16 deletions
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 87a7ddaba9..dd68b27a78 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -101,6 +101,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
new file mode 100644
index 0000000000..fdbee743e8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.api.java.JavaSchemaRDD
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for estimators that fit models to data.
+ */
+@AlphaComponent
+abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @return fitted model
+ */
+ @varargs
+ def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+ val map = new ParamMap().put(paramPairs: _*)
+ fit(dataset, map)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted model
+ */
+ def fit(dataset: SchemaRDD, paramMap: ParamMap): M
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters.
+ * The default implementation uses a for loop on each parameter map.
+ * Subclasses could override this to optimize multi-model training.
+ *
+ * @param dataset input dataset
+ * @param paramMaps an array of parameter maps
+ * @return fitted models, matching the input parameter maps
+ */
+ def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+ paramMaps.map(fit(dataset, _))
+ }
+
+ // Java-friendly versions of fit.
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @return fitted model
+ */
+ @varargs
+ def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
+ fit(dataset.schemaRDD, paramPairs: _*)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted model
+ */
+ def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
+ fit(dataset.schemaRDD, paramMap)
+ }
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters.
+ *
+ * @param dataset input dataset
+ * @param paramMaps an array of parameter maps
+ * @return fitted models, matching the input parameter maps
+ */
+ def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
+ fit(dataset.schemaRDD, paramMaps).asJava
+ }
+}
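A minimal sketch of the three `fit` entry points above, assuming `dataset` is a prepared `SchemaRDD` and using the `LogisticRegression` estimator added later in this patch (its `maxIter` and `regParam` params come from the shared param traits):

~~~
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression()

// Varargs of param pairs overwrite the estimator's embedded params for this call only.
val m1 = lr.fit(dataset, lr.maxIter -> 20, lr.regParam -> 0.01)

// The same parameters can be passed as an explicit ParamMap.
val m2 = lr.fit(dataset, new ParamMap().put(lr.maxIter -> 20, lr.regParam -> 0.01))

// Multi-model training: the default implementation loops over the maps,
// but subclasses may override it with an optimized path.
val maps = Array(
  new ParamMap().put(lr.regParam -> 0.1),
  new ParamMap().put(lr.regParam -> 1.0))
val models = lr.fit(dataset, maps)  // Seq[LogisticRegressionModel], one per map
~~~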
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
new file mode 100644
index 0000000000..db563dd550
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SchemaRDD
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for evaluators that compute metrics from predictions.
+ */
+@AlphaComponent
+abstract class Evaluator extends Identifiable {
+
+ /**
+ * Evaluates the output.
+ *
+ * @param dataset a dataset that contains labels/observations and predictions.
+ * @param paramMap parameter map that specifies the input columns and output metrics
+ * @return metric
+ */
+ def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
new file mode 100644
index 0000000000..cd84b05bfb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import java.util.UUID
+
+/**
+ * Object with a unique id.
+ */
+private[ml] trait Identifiable extends Serializable {
+
+ /**
+ * A unique id for the object. The default implementation concatenates the class name, "-", and 8
+ * random hex chars.
+ */
+ private[ml] val uid: String =
+ this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
new file mode 100644
index 0000000000..cae5082b51
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.ParamMap
+
+/**
+ * :: AlphaComponent ::
+ * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
+ *
+ * @tparam M model type
+ */
+@AlphaComponent
+abstract class Model[M <: Model[M]] extends Transformer {
+ /**
+ * The parent estimator that produced this model.
+ */
+ val parent: Estimator[M]
+
+ /**
+ * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
+ */
+ val fittingParamMap: ParamMap
+}
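The two members above make a fitted model traceable back to its origin; a small sketch, assuming `lr` and `dataset` as in the earlier example:

~~~
val model = lr.fit(dataset)
model.parent           // the LogisticRegression instance that produced this model
model.fittingParamMap  // parameters used for fitting; parent.fit(dataset, model.fittingParamMap)
                       // should reproduce an equivalent model
~~~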
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
new file mode 100644
index 0000000000..e545df1e37
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Params, Param, ParamMap}
+import org.apache.spark.sql.{SchemaRDD, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
+ */
+@AlphaComponent
+abstract class PipelineStage extends Serializable with Logging {
+
+ /**
+ * Derives the output schema from the input schema and parameters.
+ */
+ private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+
+ /**
+ * Derives the output schema from the input schema and parameters, optionally with logging.
+ */
+ protected def transformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ logging: Boolean): StructType = {
+ if (logging) {
+ logDebug(s"Input schema: ${schema.json}")
+ }
+ val outputSchema = transformSchema(schema, paramMap)
+ if (logging) {
+ logDebug(s"Expected output schema: ${outputSchema.json}")
+ }
+ outputSchema
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
+ * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the
+ * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will
+ * be called on the input dataset to fit a model. Then the model, which is a transformer, will be
+ * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]],
+ * its [[Transformer.transform]] method will be called to produce the dataset for the next stage.
+ * The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of fitted models and
+ * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
+ * an identity transformer.
+ */
+@AlphaComponent
+class Pipeline extends Estimator[PipelineModel] {
+
+ /** param for pipeline stages */
+ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
+ def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
+ def getStages: Array[PipelineStage] = get(stages)
+
+ /**
+ * Fits the pipeline to the input dataset with additional parameters. If a stage is an
+ * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model.
+ * Then the model, which is a transformer, will be used to transform the dataset as the input to
+ * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be
+ * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is a
+ * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the
+ * pipeline stages. If there are no stages, the output model acts as an identity transformer.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted pipeline
+ */
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val theStages = map(stages)
+ // Search for the last estimator.
+ var indexOfLastEstimator = -1
+ theStages.view.zipWithIndex.foreach { case (stage, index) =>
+ stage match {
+ case _: Estimator[_] =>
+ indexOfLastEstimator = index
+ case _ =>
+ }
+ }
+ var curDataset = dataset
+ val transformers = ListBuffer.empty[Transformer]
+ theStages.view.zipWithIndex.foreach { case (stage, index) =>
+ if (index <= indexOfLastEstimator) {
+ val transformer = stage match {
+ case estimator: Estimator[_] =>
+ estimator.fit(curDataset, paramMap)
+ case t: Transformer =>
+ t
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Do not support stage $stage of type ${stage.getClass}")
+ }
+ curDataset = transformer.transform(curDataset, paramMap)
+ transformers += transformer
+ } else {
+ transformers += stage.asInstanceOf[Transformer]
+ }
+ }
+
+ new PipelineModel(this, map, transformers.toArray)
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val theStages = map(stages)
+ require(theStages.toSet.size == theStages.size,
+ "Cannot have duplicate components in a pipeline.")
+ theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Represents a compiled pipeline.
+ */
+@AlphaComponent
+class PipelineModel private[ml] (
+ override val parent: Pipeline,
+ override val fittingParamMap: ParamMap,
+ private[ml] val stages: Array[Transformer])
+ extends Model[PipelineModel] with Logging {
+
+ /**
+   * Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
+   * estimator does not exist in the pipeline.
+ */
+ def getModel[M <: Model[M]](stage: Estimator[M]): M = {
+ val matched = stages.filter {
+ case m: Model[_] => m.parent.eq(stage)
+ case _ => false
+ }
+ if (matched.isEmpty) {
+ throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
+ } else if (matched.size > 1) {
+ throw new IllegalStateException(s"Cannot have duplicate estimators in the same pipeline.")
+ } else {
+ matched.head.asInstanceOf[M]
+ }
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+ }
+}
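A sketch of fitting a pipeline and pulling an individual fitted stage back out with `getModel`, assuming the `tokenizer`, `hashingTF`, and `lr` stages configured as in the text classification example in the commit message, with `training` and `test` as the corresponding datasets:

~~~
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))
val pipelineModel = pipeline.fit(training)

// Retrieve the LogisticRegressionModel fitted for the lr stage.
val lrModel = pipelineModel.getModel(lr)

// The pipeline model still transforms end to end, stage by stage.
val scored = pipelineModel.transform(test)
~~~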
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
new file mode 100644
index 0000000000..490e6609ad
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for transformers that transform one dataset into another.
+ */
+@AlphaComponent
+abstract class Transformer extends PipelineStage with Params {
+
+ /**
+ * Transforms the dataset with optional parameters
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @return transformed dataset
+ */
+ @varargs
+ def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+ val map = new ParamMap()
+ paramPairs.foreach(map.put(_))
+ transform(dataset, map)
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
+
+ // Java-friendly versions of transform.
+
+ /**
+ * Transforms the dataset with optional parameters.
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @return transformed dataset
+ */
+ @varargs
+ def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
+ transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
+ transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
+ }
+}
+
+/**
+ * Abstract class for transformers that take one input column, apply transformation, and output the
+ * result as a new column.
+ */
+private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+ extends Transformer with HasInputCol with HasOutputCol with Logging {
+
+ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
+ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
+
+ /**
+ * Creates the transform function using the given param map. The input param map already takes
+ * account of the embedded param map. So the param values should be determined solely by the input
+ * param map.
+ */
+ protected def createTransformFunc(paramMap: ParamMap): IN => OUT
+
+ /**
+ * Validates the input type. Throws an exception if it is invalid.
+ */
+ protected def validateInputType(inputType: DataType): Unit = {}
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ validateInputType(inputType)
+ if (schema.fieldNames.contains(map(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
+ }
+ val output = ScalaReflection.schemaFor[OUT]
+ val outputFields = schema.fields :+
+ StructField(map(outputCol), output.dataType, output.nullable)
+ StructType(outputFields)
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val udf = this.createTransformFunc(map)
+ dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+ }
+}
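Since `UnaryTransformer` is `private[ml]`, only in-package transformers (such as `Tokenizer` and `HashingTF` below) can extend it. The following is a sketch of the pattern using a hypothetical `Reverser` that is not part of this patch:

~~~
package org.apache.spark.ml.feature

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataType, StringType}

/** Hypothetical transformer that reverses each input string. */
class Reverser extends UnaryTransformer[String, String, Reverser] {

  // The param map passed in already merges embedded and call-time params.
  override protected def createTransformFunc(paramMap: ParamMap): String => String = _.reverse

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be StringType but got $inputType.")
  }
}
~~~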
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
new file mode 100644
index 0000000000..85b8899636
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: AlphaComponent ::
+ * Params for logistic regression.
+ */
+@AlphaComponent
+private[classification] trait LogisticRegressionParams extends Params
+ with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
+ with HasScoreCol with HasPredictionCol {
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param paramMap additional parameters
+ * @param fitting whether this is in fitting
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean): StructType = {
+ val map = this.paramMap ++ paramMap
+ val featuresType = schema(map(featuresCol)).dataType
+ // TODO: Support casting Array[Double] and Array[Float] to Vector.
+ require(featuresType.isInstanceOf[VectorUDT],
+ s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
+ if (fitting) {
+ val labelType = schema(map(labelCol)).dataType
+ require(labelType == DoubleType,
+ s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
+ }
+ val fieldNames = schema.fieldNames
+ require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
+ require(!fieldNames.contains(map(predictionCol)),
+ s"Prediction column ${map(predictionCol)} already exists.")
+ val outputFields = schema.fields ++ Seq(
+ StructField(map(scoreCol), DoubleType, false),
+ StructField(map(predictionCol), DoubleType, false))
+ StructType(outputFields)
+ }
+}
+
+/**
+ * Logistic regression.
+ */
+class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {
+
+ setRegParam(0.1)
+ setMaxIter(100)
+ setThreshold(0.5)
+
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+ def setThreshold(value: Double): this.type = set(threshold, value)
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ .map { case Row(label: Double, features: Vector) =>
+ LabeledPoint(label, features)
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+ val lr = new LogisticRegressionWithLBFGS
+ lr.optimizer
+ .setRegParam(map(regParam))
+ .setNumIterations(map(maxIter))
+ val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
+ instances.unpersist()
+ // copy model params
+ Params.inheritValues(map, this, lrm)
+ lrm
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = true)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model produced by [[LogisticRegression]].
+ */
+@AlphaComponent
+class LogisticRegressionModel private[ml] (
+ override val parent: LogisticRegression,
+ override val fittingParamMap: ParamMap,
+ weights: Vector)
+ extends Model[LogisticRegressionModel] with LogisticRegressionParams {
+
+ def setThreshold(value: Double): this.type = set(threshold, value)
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = false)
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val score: Vector => Double = (v) => {
+ val margin = BLAS.dot(v, weights)
+ 1.0 / (1.0 + math.exp(-margin))
+ }
+ val t = map(threshold)
+ val predict: Double => Double = (score) => {
+ if (score > t) 1.0 else 0.0
+ }
+ dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
+ .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ }
+}
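For readability, the per-row computation performed by `transform` above is a logistic score followed by a threshold cut; a plain-Scala sketch of the same math, with dense arrays standing in for the vectors:

~~~
// margin = w . x
def margin(weights: Array[Double], features: Array[Double]): Double =
  weights.zip(features).map { case (w, x) => w * x }.sum

// score = sigmoid(margin); this is the value written to scoreCol
def score(weights: Array[Double], features: Array[Double]): Double =
  1.0 / (1.0 + math.exp(-margin(weights, features)))

// prediction = 1.0 if score > threshold, else 0.0; written to predictionCol
def predict(score: Double, threshold: Double): Double =
  if (score > threshold) 1.0 else 0.0
~~~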
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
new file mode 100644
index 0000000000..0b0504e036
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.sql.{DoubleType, Row, SchemaRDD}
+
+/**
+ * :: AlphaComponent ::
+ * Evaluator for binary classification, which expects two input columns: score and label.
+ */
+@AlphaComponent
+class BinaryClassificationEvaluator extends Evaluator with Params
+ with HasScoreCol with HasLabelCol {
+
+ /** param for metric name in evaluation */
+ val metricName: Param[String] = new Param(this, "metricName",
+ "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+ def getMetricName: String = get(metricName)
+ def setMetricName(value: String): this.type = set(metricName, value)
+
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
+ val map = this.paramMap ++ paramMap
+
+ val schema = dataset.schema
+ val scoreType = schema(map(scoreCol)).dataType
+ require(scoreType == DoubleType,
+ s"Score column ${map(scoreCol)} must be double type but found $scoreType")
+ val labelType = schema(map(labelCol)).dataType
+ require(labelType == DoubleType,
+ s"Label column ${map(labelCol)} must be double type but found $labelType")
+
+ import dataset.sqlContext._
+ val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
+ .map { case Row(score: Double, label: Double) =>
+ (score, label)
+ }
+ val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+ val metric = map(metricName) match {
+ case "areaUnderROC" =>
+ metrics.areaUnderROC()
+ case "areaUnderPR" =>
+ metrics.areaUnderPR()
+ case other =>
+ throw new IllegalArgumentException(s"Does not support metric $other.")
+ }
+ metrics.unpersist()
+ metric
+ }
+}
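A usage sketch, assuming `scored` is a `SchemaRDD` with double-typed score and label columns (for example, the output of `LogisticRegressionModel.transform`):

~~~
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.param.ParamMap

val eval = new BinaryClassificationEvaluator()
  .setScoreCol("score")
  .setLabelCol("label")
  .setMetricName("areaUnderPR")  // default metric is areaUnderROC

val areaUnderPR = eval.evaluate(scored, ParamMap.empty)
~~~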
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
new file mode 100644
index 0000000000..b98b1755a3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * :: AlphaComponent ::
+ * Maps a sequence of terms to their term frequencies using the hashing trick.
+ */
+@AlphaComponent
+class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+
+ /** number of features */
+ val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
+ def setNumFeatures(value: Int) = set(numFeatures, value)
+ def getNumFeatures: Int = get(numFeatures)
+
+ override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
+ val hashingTF = new feature.HashingTF(paramMap(numFeatures))
+ hashingTF.transform
+ }
+}
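A short usage sketch, assuming `tokenized` is a `SchemaRDD` with a `words` column holding token sequences (for example, Tokenizer output):

~~~
import org.apache.spark.ml.feature.HashingTF

val hashingTF = new HashingTF()
  .setNumFeatures(1 << 16)  // default is 1 << 18
  .setInputCol("words")
  .setOutputCol("features")

val featurized = hashingTF.transform(tokenized)
~~~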
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
new file mode 100644
index 0000000000..896a6b83b6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+
+/**
+ * Params for [[StandardScaler]] and [[StandardScalerModel]].
+ */
+private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
+
+/**
+ * :: AlphaComponent ::
+ * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * statistics on the samples in the training set.
+ */
+@AlphaComponent
+class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+
+ def setInputCol(value: String): this.type = set(inputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val input = dataset.select(map(inputCol).attr)
+ .map { case Row(v: Vector) =>
+ v
+ }
+ val scaler = new feature.StandardScaler().fit(input)
+ val model = new StandardScalerModel(this, map, scaler)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ require(inputType.isInstanceOf[VectorUDT],
+ s"Input column ${map(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains(map(outputCol)),
+ s"Output column ${map(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[StandardScaler]].
+ */
+@AlphaComponent
+class StandardScalerModel private[ml] (
+ override val parent: StandardScaler,
+ override val fittingParamMap: ParamMap,
+ scaler: feature.StandardScalerModel)
+ extends Model[StandardScalerModel] with StandardScalerParams {
+
+ def setInputCol(value: String): this.type = set(inputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val scale: (Vector) => Vector = (v) => {
+ scaler.transform(v)
+ }
+ dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ require(inputType.isInstanceOf[VectorUDT],
+ s"Input column ${map(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains(map(outputCol)),
+ s"Output column ${map(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
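A usage sketch of the estimator/model pair above, assuming `dataset` has a vector-typed `features` column:

~~~
import org.apache.spark.ml.feature.StandardScaler

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")

val scalerModel = scaler.fit(dataset)        // computes column summary statistics
val scaled = scalerModel.transform(dataset)  // appends the scaledFeatures column
~~~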
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
new file mode 100644
index 0000000000..0a6599b64c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataType, StringType}
+
+/**
+ * :: AlphaComponent ::
+ * A tokenizer that converts the input string to lowercase and then splits it by white spaces.
+ */
+@AlphaComponent
+class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
+
+ protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+ _.toLowerCase.split("\\s")
+ }
+
+ protected override def validateInputType(inputType: DataType): Unit = {
+ require(inputType == StringType, s"Input type must be string type but got $inputType.")
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
new file mode 100644
index 0000000000..00d9c802e9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * assemble and configure practical machine learning pipelines.
+ */
+@AlphaComponent
+package org.apache.spark.ml;
+
+import org.apache.spark.annotation.AlphaComponent;
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
new file mode 100644
index 0000000000..51cd48c904
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * assemble and configure practical machine learning pipelines.
+ */
+package object ml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
new file mode 100644
index 0000000000..8fd46aef4b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+import java.lang.reflect.Modifier
+
+import org.apache.spark.annotation.AlphaComponent
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+import org.apache.spark.ml.Identifiable
+
+/**
+ * :: AlphaComponent ::
+ * A param with self-contained documentation and optionally a default value. Primitive-typed params
+ * should use the specialized versions, which are more friendly to Java users.
+ *
+ * @param parent parent object
+ * @param name param name
+ * @param doc documentation
+ * @tparam T param value type
+ */
+@AlphaComponent
+class Param[T] (
+ val parent: Params,
+ val name: String,
+ val doc: String,
+ val defaultValue: Option[T] = None)
+ extends Serializable {
+
+ /**
+ * Creates a param pair with the given value (for Java).
+ */
+ def w(value: T): ParamPair[T] = this -> value
+
+ /**
+ * Creates a param pair with the given value (for Scala).
+ */
+ def ->(value: T): ParamPair[T] = ParamPair(this, value)
+
+ override def toString: String = {
+ if (defaultValue.isDefined) {
+ s"$name: $doc (default: ${defaultValue.get})"
+ } else {
+ s"$name: $doc"
+ }
+ }
+}
+
+// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
+
+/** Specialized version of [[Param[Double]]] for Java. */
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+ extends Param[Double](parent, name, doc, defaultValue) {
+
+ override def w(value: Double): ParamPair[Double] = super.w(value)
+}
+
+/** Specialized version of [[Param[Int]]] for Java. */
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+ extends Param[Int](parent, name, doc, defaultValue) {
+
+ override def w(value: Int): ParamPair[Int] = super.w(value)
+}
+
+/** Specialized version of [[Param[Float]]] for Java. */
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+ extends Param[Float](parent, name, doc, defaultValue) {
+
+ override def w(value: Float): ParamPair[Float] = super.w(value)
+}
+
+/** Specialized version of [[Param[Long]]] for Java. */
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+ extends Param[Long](parent, name, doc, defaultValue) {
+
+ override def w(value: Long): ParamPair[Long] = super.w(value)
+}
+
+/** Specialized version of [[Param[Boolean]]] for Java. */
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+ extends Param[Boolean](parent, name, doc, defaultValue) {
+
+ override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
+}
+
+/**
+ * A param and its value.
+ */
+case class ParamPair[T](param: Param[T], value: T)
+
+/**
+ * :: AlphaComponent ::
+ * Trait for components that take parameters. This also provides an internal param map to store
+ * parameter values attached to the instance.
+ */
+@AlphaComponent
+trait Params extends Identifiable with Serializable {
+
+ /** Returns all params. */
+ def params: Array[Param[_]] = {
+ val methods = this.getClass.getMethods
+ methods.filter { m =>
+ Modifier.isPublic(m.getModifiers) &&
+ classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
+ m.getParameterTypes.isEmpty
+ }.sortBy(_.getName)
+ .map(m => m.invoke(this).asInstanceOf[Param[_]])
+ }
+
+ /**
+ * Validates parameter values stored internally plus the input parameter map.
+ * Raises an exception if any parameter is invalid.
+ */
+ def validate(paramMap: ParamMap): Unit = {}
+
+ /**
+ * Validates parameter values stored internally.
+ * Raises an exception if any parameter value is invalid.
+ */
+ def validate(): Unit = validate(ParamMap.empty)
+
+ /**
+ * Returns the documentation of all params.
+ */
+ def explainParams(): String = params.mkString("\n")
+
+ /** Checks whether a param is explicitly set. */
+ def isSet(param: Param[_]): Boolean = {
+ require(param.parent.eq(this))
+ paramMap.contains(param)
+ }
+
+ /** Gets a param by its name. */
+ private[ml] def getParam(paramName: String): Param[Any] = {
+ val m = this.getClass.getMethod(paramName)
+ assert(Modifier.isPublic(m.getModifiers) &&
+ classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
+ m.getParameterTypes.isEmpty)
+ m.invoke(this).asInstanceOf[Param[Any]]
+ }
+
+ /**
+ * Sets a parameter in the embedded param map.
+ */
+ private[ml] def set[T](param: Param[T], value: T): this.type = {
+ require(param.parent.eq(this))
+ paramMap.put(param.asInstanceOf[Param[Any]], value)
+ this
+ }
+
+ /**
+ * Gets the value of a parameter in the embedded param map.
+ */
+ private[ml] def get[T](param: Param[T]): T = {
+ require(param.parent.eq(this))
+ paramMap(param)
+ }
+
+ /**
+ * Internal param map.
+ */
+ protected val paramMap: ParamMap = ParamMap.empty
+}
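+
+// Illustrative sketch of a minimal Params implementation (hypothetical `MyScaler`): params
+// are discovered via reflection, so they must be public, parameterless members of type Param.
+//
+//   class MyScaler extends Params {
+//     val inputCol = new Param[String](this, "inputCol", "input column name")
+//     def setInputCol(value: String): this.type = { set(inputCol, value); this }
+//   }
+//   new MyScaler().explainParams()  // "inputCol: input column name"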
+
+private[ml] object Params {
+
+ /**
+ * Copies parameter values from the parent estimator to the child model it produced.
+ * @param paramMap the param map that holds parameters of the parent
+ * @param parent the parent estimator
+ * @param child the child model
+ */
+ def inheritValues[E <: Params, M <: E](
+ paramMap: ParamMap,
+ parent: E,
+ child: M): Unit = {
+ parent.params.foreach { param =>
+ if (paramMap.contains(param)) {
+ child.set(child.getParam(param.name), paramMap(param))
+ }
+ }
+ }
+}
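+
+// Sketch of inheritValues semantics: given a fitted (estimator, model) pair, parameter
+// values present in the fitting param map are copied onto the model's params of the same
+// name, e.g. Params.inheritValues(fittingParamMap, estimator, model).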
+
+/**
+ * :: AlphaComponent ::
+ * A param to value map.
+ */
+@AlphaComponent
+class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable {
+
+ /**
+ * Creates an empty param map.
+ */
+ def this() = this(mutable.Map.empty[Param[Any], Any])
+
+ /**
+ * Puts a (param, value) pair (overwrites if the input param exists).
+ */
+ def put[T](param: Param[T], value: T): this.type = {
+ map(param.asInstanceOf[Param[Any]]) = value
+ this
+ }
+
+ /**
+ * Puts a list of param pairs (overwrites if any of the input params already exist).
+ */
+ def put(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ put(p.param.asInstanceOf[Param[Any]], p.value)
+ }
+ this
+ }
+
+ /**
+ * Optionally returns the value associated with a param or its default.
+ */
+ def get[T](param: Param[T]): Option[T] = {
+ map.get(param.asInstanceOf[Param[Any]])
+ .orElse(param.defaultValue)
+ .asInstanceOf[Option[T]]
+ }
+
+ /**
+ * Gets the value of the input param, falling back to its default value if no value is set.
+ * Raises a NoSuchElementException if the param has neither a value nor a default.
+ */
+ def apply[T](param: Param[T]): T = {
+ val value = get(param)
+ if (value.isDefined) {
+ value.get
+ } else {
+ throw new NoSuchElementException(s"Cannot find param ${param.name}.")
+ }
+ }
+
+ /**
+ * Checks whether a parameter is explicitly specified.
+ */
+ def contains(param: Param[_]): Boolean = {
+ map.contains(param.asInstanceOf[Param[Any]])
+ }
+
+ /**
+ * Filters this param map for the given parent.
+ */
+ def filter(parent: Params): ParamMap = {
+ val filtered = map.filterKeys(_.parent == parent)
+ new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
+ }
+
+ /**
+ * Makes a copy of this param map.
+ */
+ def copy: ParamMap = new ParamMap(map.clone())
+
+ override def toString: String = {
+ map.map { case (param, value) =>
+ s"\t${param.parent.uid}-${param.name}: $value"
+ }.mkString("{\n", ",\n", "\n}")
+ }
+
+ /**
+ * Returns a new param map that contains parameters in this map and the given map,
+ * where the latter overwrites the former in case of conflicts.
+ */
+ def ++(other: ParamMap): ParamMap = {
+ new ParamMap(this.map ++ other.map)
+ }
+
+ /**
+ * Adds all parameters from the input param map into this param map.
+ */
+ def ++=(other: ParamMap): this.type = {
+ this.map ++= other.map
+ this
+ }
+
+ /**
+ * Converts this param map to a sequence of param pairs.
+ */
+ def toSeq: Seq[ParamPair[_]] = {
+ map.toSeq.map { case (param, value) =>
+ ParamPair(param, value)
+ }
+ }
+}
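+
+// Illustrative sketch of ParamMap construction and lookup; `maxIter` and `inputCol` are
+// assumed to be params declared on some component.
+//
+//   val map = ParamMap(maxIter -> 10) ++ ParamMap(inputCol -> "features")
+//   map(maxIter)           // 10
+//   map.get(inputCol)      // Some("features")
+//   map.contains(maxIter)  // true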
+
+object ParamMap {
+
+ /**
+ * Returns an empty param map.
+ */
+ def empty: ParamMap = new ParamMap()
+
+ /**
+ * Constructs a param map by specifying its entries.
+ */
+ @varargs
+ def apply(paramPairs: ParamPair[_]*): ParamMap = {
+ new ParamMap().put(paramPairs: _*)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
new file mode 100644
index 0000000000..ef141d3eb2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+private[ml] trait HasRegParam extends Params {
+ /** param for regularization parameter */
+ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+ def getRegParam: Double = get(regParam)
+}
+
+private[ml] trait HasMaxIter extends Params {
+ /** param for max number of iterations */
+ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+ def getMaxIter: Int = get(maxIter)
+}
+
+private[ml] trait HasFeaturesCol extends Params {
+ /** param for features column name */
+ val featuresCol: Param[String] =
+ new Param(this, "featuresCol", "features column name", Some("features"))
+ def getFeaturesCol: String = get(featuresCol)
+}
+
+private[ml] trait HasLabelCol extends Params {
+ /** param for label column name */
+ val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label"))
+ def getLabelCol: String = get(labelCol)
+}
+
+private[ml] trait HasScoreCol extends Params {
+ /** param for score column name */
+ val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score"))
+ def getScoreCol: String = get(scoreCol)
+}
+
+private[ml] trait HasPredictionCol extends Params {
+ /** param for prediction column name */
+ val predictionCol: Param[String] =
+ new Param(this, "predictionCol", "prediction column name", Some("prediction"))
+ def getPredictionCol: String = get(predictionCol)
+}
+
+private[ml] trait HasThreshold extends Params {
+ /** param for threshold in (binary) prediction */
+ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
+ def getThreshold: Double = get(threshold)
+}
+
+private[ml] trait HasInputCol extends Params {
+ /** param for input column name */
+ val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
+ def getInputCol: String = get(inputCol)
+}
+
+private[ml] trait HasOutputCol extends Params {
+ /** param for output column name */
+ val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
+ def getOutputCol: String = get(outputCol)
+}
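+
+// Illustrative sketch: a component picks up shared params and their getters by mixing in
+// these traits, e.g. a hypothetical solver that needs regParam and maxIter:
+//
+//   private[ml] trait MySolverParams extends Params with HasRegParam with HasMaxIter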
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
new file mode 100644
index 0000000000..194b9bfd9a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import com.github.fommil.netlib.F2jBLAS
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{SchemaRDD, StructType}
+
+/**
+ * Params for [[CrossValidator]] and [[CrossValidatorModel]].
+ */
+private[ml] trait CrossValidatorParams extends Params {
+ /** param for the estimator to be cross-validated */
+ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+ def getEstimator: Estimator[_] = get(estimator)
+
+ /** param for estimator param maps */
+ val estimatorParamMaps: Param[Array[ParamMap]] =
+ new Param(this, "estimatorParamMaps", "param maps for the estimator")
+ def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
+
+ /** param for the evaluator for selection */
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+ def getEvaluator: Evaluator = get(evaluator)
+
+ /** param for number of folds for cross validation */
+ val numFolds: IntParam =
+ new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
+ def getNumFolds: Int = get(numFolds)
+}
+
+/**
+ * :: AlphaComponent ::
+ * K-fold cross validation.
+ */
+@AlphaComponent
+class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging {
+
+ private val f2jBLAS = new F2jBLAS
+
+ def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+ def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+ def setNumFolds(value: Int): this.type = set(numFolds, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = {
+ val map = this.paramMap ++ paramMap
+ val schema = dataset.schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val sqlCtx = dataset.sqlContext
+ val est = map(estimator)
+ val eval = map(evaluator)
+ val epm = map(estimatorParamMaps)
+ val numModels = epm.size
+ val metrics = new Array[Double](epm.size)
+ val splits = MLUtils.kFold(dataset, map(numFolds), 0)
+ splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+ val trainingDataset = sqlCtx.applySchema(training, schema).cache()
+ val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ // multi-model training
+ logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+ val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ var i = 0
+ while (i < numModels) {
+ val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
+ logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+ metrics(i) += metric
+ i += 1
+ }
+ }
+ f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
+ logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
+ val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+ logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ logInfo(s"Best cross-validation metric: $bestMetric.")
+ val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+ val cvModel = new CrossValidatorModel(this, map, bestModel)
+ Params.inheritValues(map, this, cvModel)
+ cvModel
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ map(estimator).transformSchema(schema, paramMap)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model from k-fold cross validation.
+ */
+@AlphaComponent
+class CrossValidatorModel private[ml] (
+ override val parent: CrossValidator,
+ override val fittingParamMap: ParamMap,
+ val bestModel: Model[_])
+ extends Model[CrossValidatorModel] with CrossValidatorParams {
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ bestModel.transform(dataset, paramMap)
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ bestModel.transformSchema(schema, paramMap)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
new file mode 100644
index 0000000000..dafe73d82c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+
+/**
+ * :: AlphaComponent ::
+ * Builder for a param grid used in grid search-based model selection.
+ */
+@AlphaComponent
+class ParamGridBuilder {
+
+ private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]
+
+ /**
+ * Sets the given parameters in this grid to fixed values.
+ */
+ def baseOn(paramMap: ParamMap): this.type = {
+ baseOn(paramMap.toSeq: _*)
+ this
+ }
+
+ /**
+ * Sets the given parameters in this grid to fixed values.
+ */
+ @varargs
+ def baseOn(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value))
+ }
+ this
+ }
+
+ /**
+ * Adds a param with multiple values (overwrites if the input param exists).
+ */
+ def addGrid[T](param: Param[T], values: Iterable[T]): this.type = {
+ paramGrid.put(param, values)
+ this
+ }
+
+ // specialized versions of addGrid for Java.
+
+ /**
+ * Adds a double param with multiple values.
+ */
+ def addGrid(param: DoubleParam, values: Array[Double]): this.type = {
+ addGrid[Double](param, values)
+ }
+
+ /**
+ * Adds an int param with multiple values.
+ */
+ def addGrid(param: IntParam, values: Array[Int]): this.type = {
+ addGrid[Int](param, values)
+ }
+
+ /**
+ * Adds a float param with multiple values.
+ */
+ def addGrid(param: FloatParam, values: Array[Float]): this.type = {
+ addGrid[Float](param, values)
+ }
+
+ /**
+ * Adds a long param with multiple values.
+ */
+ def addGrid(param: LongParam, values: Array[Long]): this.type = {
+ addGrid[Long](param, values)
+ }
+
+ /**
+ * Adds a boolean param with true and false.
+ */
+ def addGrid(param: BooleanParam): this.type = {
+ addGrid[Boolean](param, Array(true, false))
+ }
+
+ /**
+ * Builds and returns all combinations of parameters specified by the param grid.
+ */
+ def build(): Array[ParamMap] = {
+ var paramMaps = Array(new ParamMap)
+ paramGrid.foreach { case (param, values) =>
+ val newParamMaps = values.flatMap { v =>
+ paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v))
+ }
+ paramMaps = newParamMaps.toArray
+ }
+ paramMaps
+ }
+}
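+
+// Illustrative sketch of how build() expands the grid: two values for each of two params
+// yield 2 x 2 = 4 ParamMaps. `maxIter` and `inputCol` are assumed params of some component.
+//
+//   new ParamGridBuilder()
+//     .addGrid(maxIter, Array(10, 20))
+//     .addGrid(inputCol, Array("input0", "input1"))
+//     .build()  // (10, input0), (10, input1), (20, input0), (20, input1), in some order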
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 54ee930d61..89539e600f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -25,7 +25,7 @@ import org.apache.spark.Logging
/**
* BLAS routines for MLlib's vectors and matrices.
*/
-private[mllib] object BLAS extends Serializable with Logging {
+private[spark] object BLAS extends Serializable with Logging {
@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index ac217edc61..9fccd6341b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -115,6 +115,9 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def deserialize(datum: Any): Vector = {
datum match {
+ // TODO: something wrong with UDT serialization
+ case v: Vector =>
+ v
case row: Row =>
require(row.length == 4,
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 17c753c566..2067b36f24 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.regression
+import scala.beans.BeanInfo
+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
@@ -27,6 +29,7 @@ import org.apache.spark.SparkException
* @param label Label for this data point.
* @param features List of features for this data point.
*/
+@BeanInfo
case class LabeledPoint(label: Double, features: Vector) {
override def toString: String = {
"(%s,%s)".format(label, features)
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
new file mode 100644
index 0000000000..42846677ed
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.StandardScaler;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+/**
+ * Test Pipeline construction and fitting in Java.
+ */
+public class JavaPipelineSuite {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaPipelineSuite");
+ jsql = new JavaSQLContext(jsc);
+ JavaRDD<LabeledPoint> points =
+ jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
+ dataset = jsql.applySchema(points, LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void pipeline() {
+ StandardScaler scaler = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("scaledFeatures");
+ LogisticRegression lr = new LogisticRegression()
+ .setFeaturesCol("scaledFeatures");
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {scaler, lr});
+ PipelineModel model = pipeline.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
new file mode 100644
index 0000000000..76eb7f0032
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+public class JavaLogisticRegressionSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ jsql = new JavaSQLContext(jsc);
+ List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void logisticRegression() {
+ LogisticRegression lr = new LogisticRegression();
+ LogisticRegressionModel model = lr.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+
+ @Test
+ public void logisticRegressionWithSetters() {
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0);
+ LogisticRegressionModel model = lr.fit(dataset);
+ model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
+ .registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+
+ @Test
+ public void logisticRegressionFitWithVarargs() {
+ LogisticRegression lr = new LogisticRegression();
+ lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
new file mode 100644
index 0000000000..a266ebd207
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+public class JavaCrossValidatorSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
+ jsql = new JavaSQLContext(jsc);
+ List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void crossValidationWithLogisticRegression() {
+ LogisticRegression lr = new LogisticRegression();
+ ParamMap[] lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
+ .addGrid(lr.maxIter(), new int[] {0, 10})
+ .build();
+ BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
+ CrossValidator cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(3);
+ CrossValidatorModel cvModel = cv.fit(dataset);
+ ParamMap bestParamMap = cvModel.bestModel().fittingParamMap();
+ Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam()));
+ Assert.assertEquals(10, bestParamMap.apply(lr.maxIter()));
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
new file mode 100644
index 0000000000..4515084bc7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.when
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar.mock
+
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SchemaRDD
+
+class PipelineSuite extends FunSuite {
+
+ abstract class MyModel extends Model[MyModel]
+
+ test("pipeline") {
+ val estimator0 = mock[Estimator[MyModel]]
+ val model0 = mock[MyModel]
+ val transformer1 = mock[Transformer]
+ val estimator2 = mock[Estimator[MyModel]]
+ val model2 = mock[MyModel]
+ val transformer3 = mock[Transformer]
+ val dataset0 = mock[SchemaRDD]
+ val dataset1 = mock[SchemaRDD]
+ val dataset2 = mock[SchemaRDD]
+ val dataset3 = mock[SchemaRDD]
+ val dataset4 = mock[SchemaRDD]
+
+ when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
+ when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
+ when(model0.parent).thenReturn(estimator0)
+ when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2)
+ when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2)
+ when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3)
+ when(model2.parent).thenReturn(estimator2)
+ when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(estimator0, transformer1, estimator2, transformer3))
+ val pipelineModel = pipeline.fit(dataset0)
+
+ assert(pipelineModel.stages.size === 4)
+ assert(pipelineModel.stages(0).eq(model0))
+ assert(pipelineModel.stages(1).eq(transformer1))
+ assert(pipelineModel.stages(2).eq(model2))
+ assert(pipelineModel.stages(3).eq(transformer3))
+
+ assert(pipelineModel.getModel(estimator0).eq(model0))
+ assert(pipelineModel.getModel(estimator2).eq(model2))
+ intercept[NoSuchElementException] {
+ pipelineModel.getModel(mock[Estimator[MyModel]])
+ }
+ val output = pipelineModel.transform(dataset0)
+ assert(output.eq(dataset4))
+ }
+
+ test("pipeline with duplicate stages") {
+ val estimator = mock[Estimator[MyModel]]
+ val pipeline = new Pipeline()
+ .setStages(Array(estimator, estimator))
+ val dataset = mock[SchemaRDD]
+ intercept[IllegalArgumentException] {
+ pipeline.fit(dataset)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
new file mode 100644
index 0000000000..625af299a5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.sql.SchemaRDD
+
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
+
+ import sqlContext._
+
+ val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+
+ test("logistic regression") {
+ val lr = new LogisticRegression
+ val model = lr.fit(dataset)
+ model.transform(dataset)
+ .select('label, 'prediction)
+ .collect()
+ }
+
+ test("logistic regression with setters") {
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
+ .select('label, 'score, 'prediction)
+ .collect()
+ }
+
+ test("logistic regression fit and transform with varargs") {
+ val lr = new LogisticRegression
+ val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
+ model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
+ .select('label, 'probability, 'prediction)
+ .collect()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
new file mode 100644
index 0000000000..1ce2987612
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+import org.scalatest.FunSuite
+
+class ParamsSuite extends FunSuite {
+
+ val solver = new TestParams()
+ import solver.{inputCol, maxIter}
+
+ test("param") {
+ assert(maxIter.name === "maxIter")
+ assert(maxIter.doc === "max number of iterations")
+ assert(maxIter.defaultValue.get === 100)
+ assert(maxIter.parent.eq(solver))
+ assert(maxIter.toString === "maxIter: max number of iterations (default: 100)")
+ assert(inputCol.defaultValue === None)
+ }
+
+ test("param pair") {
+ val pair0 = maxIter -> 5
+ val pair1 = maxIter.w(5)
+ val pair2 = ParamPair(maxIter, 5)
+ for (pair <- Seq(pair0, pair1, pair2)) {
+ assert(pair.param.eq(maxIter))
+ assert(pair.value === 5)
+ }
+ }
+
+ test("param map") {
+ val map0 = ParamMap.empty
+
+ assert(!map0.contains(maxIter))
+ assert(map0(maxIter) === maxIter.defaultValue.get)
+ map0.put(maxIter, 10)
+ assert(map0.contains(maxIter))
+ assert(map0(maxIter) === 10)
+
+ assert(!map0.contains(inputCol))
+ intercept[NoSuchElementException] {
+ map0(inputCol)
+ }
+ map0.put(inputCol -> "input")
+ assert(map0.contains(inputCol))
+ assert(map0(inputCol) === "input")
+
+ val map1 = map0.copy
+ val map2 = ParamMap(maxIter -> 10, inputCol -> "input")
+ val map3 = new ParamMap()
+ .put(maxIter, 10)
+ .put(inputCol, "input")
+ val map4 = ParamMap.empty ++ map0
+ val map5 = ParamMap.empty
+ map5 ++= map0
+
+ for (m <- Seq(map1, map2, map3, map4, map5)) {
+ assert(m.contains(maxIter))
+ assert(m(maxIter) === 10)
+ assert(m.contains(inputCol))
+ assert(m(inputCol) === "input")
+ }
+ }
+
+ test("params") {
+ val params = solver.params
+ assert(params.size === 2)
+ assert(params(0).eq(inputCol), "params must be ordered by name")
+ assert(params(1).eq(maxIter))
+ assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+ assert(solver.getParam("inputCol").eq(inputCol))
+ assert(solver.getParam("maxIter").eq(maxIter))
+ intercept[NoSuchMethodException] {
+ solver.getParam("abc")
+ }
+ assert(!solver.isSet(inputCol))
+ intercept[IllegalArgumentException] {
+ solver.validate()
+ }
+ solver.validate(ParamMap(inputCol -> "input"))
+ solver.setInputCol("input")
+ assert(solver.isSet(inputCol))
+ assert(solver.getInputCol === "input")
+ solver.validate()
+ intercept[IllegalArgumentException] {
+ solver.validate(ParamMap(maxIter -> -10))
+ }
+ solver.setMaxIter(-10)
+ intercept[IllegalArgumentException] {
+ solver.validate()
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
new file mode 100644
index 0000000000..1a65883d78
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+/** A subclass of Params for testing. */
+class TestParams extends Params {
+
+ val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100))
+ def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
+ def getMaxIter: Int = get(maxIter)
+
+ val inputCol = new Param[String](this, "inputCol", "input column name")
+ def setInputCol(value: String): this.type = { set(inputCol, value); this }
+ def getInputCol: String = get(inputCol)
+
+ override def validate(paramMap: ParamMap) = {
+ val m = this.paramMap ++ paramMap
+ require(m(maxIter) >= 0)
+ require(m.contains(inputCol))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
new file mode 100644
index 0000000000..72a334ae93
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.sql.SchemaRDD
+
+class CrossValidatorSuite extends FunSuite with LocalSparkContext {
+
+ import sqlContext._
+
+ val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+
+ test("cross validation with logistic regression") {
+ val lr = new LogisticRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.001, 1000.0))
+ .addGrid(lr.maxIter, Array(0, 10))
+ .build()
+ val eval = new BinaryClassificationEvaluator
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(3)
+ val cvModel = cv.fit(dataset)
+ val bestParamMap = cvModel.bestModel.fittingParamMap
+ assert(bestParamMap(lr.regParam) === 0.001)
+ assert(bestParamMap(lr.maxIter) === 10)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
new file mode 100644
index 0000000000..20aa100112
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.param.{ParamMap, TestParams}
+
+class ParamGridBuilderSuite extends FunSuite {
+
+ val solver = new TestParams()
+ import solver.{inputCol, maxIter}
+
+ test("param grid builder") {
+ def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = {
+ assert(maps.size === expected.size)
+ maps.foreach { m =>
+ val tuple = (m(maxIter), m(inputCol))
+ assert(expected.contains(tuple))
+ expected.remove(tuple)
+ }
+ assert(expected.isEmpty)
+ }
+
+ val maps0 = new ParamGridBuilder()
+ .baseOn(maxIter -> 10)
+ .addGrid(inputCol, Array("input0", "input1"))
+ .build()
+ val expected0 = mutable.Set(
+ (10, "input0"),
+ (10, "input1"))
+ validateGrid(maps0, expected0)
+
+ val maps1 = new ParamGridBuilder()
+ .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten
+ .addGrid(maxIter, Array(10, 20))
+ .addGrid(inputCol, Array("input0", "input1"))
+ .build()
+ val expected1 = mutable.Set(
+ (10, "input0"),
+ (20, "input0"),
+ (10, "input1"),
+ (20, "input1"))
+ validateGrid(maps1, expected1)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 7857d9e5ee..4417d66adf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -17,26 +17,17 @@
package org.apache.spark.mllib.util
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
- @transient var sc: SparkContext = _
-
- override def beforeAll() {
- val conf = new SparkConf()
- .setMaster("local")
- .setAppName("test")
- sc = new SparkContext(conf)
- super.beforeAll()
- }
+ @transient val sc = new SparkContext("local", "test")
+ @transient lazy val sqlContext = new SQLContext(sc)
override def afterAll() {
- if (sc != null) {
- sc.stop()
- }
+ sc.stop()
super.afterAll()
}
}