-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java | 93
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala | 86
-rw-r--r--  mllib/pom.xml | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Estimator.scala | 105
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala | 39
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala | 33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Model.scala | 40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 172
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 127
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 148
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 71
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 42
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 105
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 39
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package-info.java | 25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package.scala | 24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 321
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala | 74
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 126
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala | 112
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 72
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 80
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 76
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 82
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 57
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 108
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 36
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 51
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala | 63
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala | 21
33 files changed, 2425 insertions, 16 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
new file mode 100644
index 0000000000..22ba68d8c3
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.HashingTF;
+import org.apache.spark.ml.feature.Tokenizer;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.SparkConf;
+
+/**
+ * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
+ * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of
+ * this example {@link SimpleTextClassificationPipeline}. Run with
+ * <pre>
+ * bin/run-example ml.JavaSimpleTextClassificationPipeline
+ * </pre>
+ */
+public class JavaSimpleTextClassificationPipeline {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ JavaSQLContext jsql = new JavaSQLContext(jsc);
+
+ // Prepare training documents, which are labeled.
+ List<LabeledDocument> localTraining = Lists.newArrayList(
+ new LabeledDocument(0L, "a b c d e spark", 1.0),
+ new LabeledDocument(1L, "b d", 0.0),
+ new LabeledDocument(2L, "spark f g h", 1.0),
+ new LabeledDocument(3L, "hadoop mapreduce", 0.0));
+ JavaSchemaRDD training =
+ jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ Tokenizer tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words");
+ HashingTF hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol())
+ .setOutputCol("features");
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01);
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
+
+ // Fit the pipeline to training documents.
+ PipelineModel model = pipeline.fit(training);
+
+ // Prepare test documents, which are unlabeled.
+ List<Document> localTest = Lists.newArrayList(
+ new Document(4L, "spark i j k"),
+ new Document(5L, "l m n"),
+ new Document(6L, "mapreduce spark"),
+ new Document(7L, "apache hadoop"));
+ JavaSchemaRDD test =
+ jsql.applySchema(jsc.parallelize(localTest), Document.class);
+
+ // Make predictions on test documents.
+ model.transform(test).registerAsTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ for (Row r: predictions.collect()) {
+ System.out.println(r);
+ }
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
new file mode 100644
index 0000000000..ee7897d906
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.sql.SQLContext
+
+@BeanInfo
+case class LabeledDocument(id: Long, text: String, label: Double)
+
+@BeanInfo
+case class Document(id: Long, text: String)
+
+/**
+ * A simple text classification pipeline that recognizes "spark" from input text. This is to show
+ * how to create and configure an ML pipeline. Run with
+ * {{{
+ * bin/run-example ml.SimpleTextClassificationPipeline
+ * }}}
+ */
+object SimpleTextClassificationPipeline {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Prepare training documents, which are labeled.
+ val training = sparkContext.parallelize(Seq(
+ LabeledDocument(0L, "a b c d e spark", 1.0),
+ LabeledDocument(1L, "b d", 0.0),
+ LabeledDocument(2L, "spark f g h", 1.0),
+ LabeledDocument(3L, "hadoop mapreduce", 0.0)))
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ val tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words")
+ val hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol)
+ .setOutputCol("features")
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01)
+ val pipeline = new Pipeline()
+ .setStages(Array(tokenizer, hashingTF, lr))
+
+ // Fit the pipeline to training documents.
+ val model = pipeline.fit(training)
+
+ // Prepare test documents, which are unlabeled.
+ val test = sparkContext.parallelize(Seq(
+ Document(4L, "spark i j k"),
+ Document(5L, "l m n"),
+ Document(6L, "mapreduce spark"),
+ Document(7L, "apache hadoop")))
+
+ // Make predictions on test documents.
+ model.transform(test)
+ .select('id, 'text, 'score, 'prediction)
+ .collect()
+ .foreach(println)
+ }
+}
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 87a7ddaba9..dd68b27a78 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -101,6 +101,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
new file mode 100644
index 0000000000..fdbee743e8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.api.java.JavaSchemaRDD
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for estimators that fit models to data.
+ */
+@AlphaComponent
+abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @return fitted model
+ */
+ @varargs
+ def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+ val map = new ParamMap().put(paramPairs: _*)
+ fit(dataset, map)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted model
+ */
+ def fit(dataset: SchemaRDD, paramMap: ParamMap): M
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters.
+ * The default implementation uses a for loop on each parameter map.
+ * Subclasses could override this to optimize multi-model training.
+ *
+ * @param dataset input dataset
+ * @param paramMaps an array of parameter maps
+ * @return fitted models, matching the input parameter maps
+ */
+ def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+ paramMaps.map(fit(dataset, _))
+ }
+
+ // Java-friendly versions of fit.
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @return fitted model
+ */
+ @varargs
+ def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
+ fit(dataset.schemaRDD, paramPairs: _*)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted model
+ */
+ def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
+ fit(dataset.schemaRDD, paramMap)
+ }
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters.
+ *
+ * @param dataset input dataset
+ * @param paramMaps an array of parameter maps
+ * @return fitted models, matching the input parameter maps
+ */
+ def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
+ fit(dataset.schemaRDD, paramMaps).asJava
+ }
+}
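
A minimal sketch of the varargs fit overload defined above, with inline ParamPair overrides; it assumes a SchemaRDD named training (with "label" and "features" columns) and the LogisticRegression estimator added later in this commit:

  import org.apache.spark.ml.classification.LogisticRegression

  val lr = new LogisticRegression()
  // Inline ParamPairs take precedence over the estimator's embedded params for this call.
  val model = lr.fit(training, lr.maxIter -> 20, lr.regParam -> 0.05)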
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
new file mode 100644
index 0000000000..db563dd550
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SchemaRDD
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for evaluators that compute metrics from predictions.
+ */
+@AlphaComponent
+abstract class Evaluator extends Identifiable {
+
+ /**
+ * Evaluates the output.
+ *
+ * @param dataset a dataset that contains labels/observations and predictions.
+ * @param paramMap parameter map that specifies the input columns and output metrics
+ * @return metric
+ */
+ def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
new file mode 100644
index 0000000000..cd84b05bfb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import java.util.UUID
+
+/**
+ * Object with a unique id.
+ */
+private[ml] trait Identifiable extends Serializable {
+
+ /**
+ * A unique id for the object. The default implementation concatenates the class name, "-", and 8
+ * random hex chars.
+ */
+ private[ml] val uid: String =
+ this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
new file mode 100644
index 0000000000..cae5082b51
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.ParamMap
+
+/**
+ * :: AlphaComponent ::
+ * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
+ *
+ * @tparam M model type
+ */
+@AlphaComponent
+abstract class Model[M <: Model[M]] extends Transformer {
+ /**
+ * The parent estimator that produced this model.
+ */
+ val parent: Estimator[M]
+
+ /**
+ * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
+ */
+ val fittingParamMap: ParamMap
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
new file mode 100644
index 0000000000..e545df1e37
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Params, Param, ParamMap}
+import org.apache.spark.sql.{SchemaRDD, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
+ */
+@AlphaComponent
+abstract class PipelineStage extends Serializable with Logging {
+
+ /**
+ * Derives the output schema from the input schema and parameters.
+ */
+ private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+
+ /**
+ * Derives the output schema from the input schema and parameters, optionally with logging.
+ */
+ protected def transformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ logging: Boolean): StructType = {
+ if (logging) {
+ logDebug(s"Input schema: ${schema.json}")
+ }
+ val outputSchema = transformSchema(schema, paramMap)
+ if (logging) {
+ logDebug(s"Expected output schema: ${outputSchema.json}")
+ }
+ outputSchema
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
+ * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the
+ * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will
+ * be called on the input dataset to fit a model. Then the model, which is a transformer, will be
+ * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]],
+ * its [[Transformer.transform]] method will be called to produce the dataset for the next stage.
+ * The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of fitted models and
+ * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
+ * an identity transformer.
+ */
+@AlphaComponent
+class Pipeline extends Estimator[PipelineModel] {
+
+ /** param for pipeline stages */
+ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
+ def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
+ def getStages: Array[PipelineStage] = get(stages)
+
+ /**
+ * Fits the pipeline to the input dataset with additional parameters. If a stage is an
+ * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model.
+ * Then the model, which is a transformer, will be used to transform the dataset as the input to
+ * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be
+ * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is a
+ * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the
+ * pipeline stages. If there are no stages, the output model acts as an identity transformer.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted pipeline
+ */
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val theStages = map(stages)
+ // Search for the last estimator.
+ var indexOfLastEstimator = -1
+ theStages.view.zipWithIndex.foreach { case (stage, index) =>
+ stage match {
+ case _: Estimator[_] =>
+ indexOfLastEstimator = index
+ case _ =>
+ }
+ }
+ var curDataset = dataset
+ val transformers = ListBuffer.empty[Transformer]
+ theStages.view.zipWithIndex.foreach { case (stage, index) =>
+ if (index <= indexOfLastEstimator) {
+ val transformer = stage match {
+ case estimator: Estimator[_] =>
+ estimator.fit(curDataset, paramMap)
+ case t: Transformer =>
+ t
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Do not support stage $stage of type ${stage.getClass}")
+ }
+ curDataset = transformer.transform(curDataset, paramMap)
+ transformers += transformer
+ } else {
+ transformers += stage.asInstanceOf[Transformer]
+ }
+ }
+
+ new PipelineModel(this, map, transformers.toArray)
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val theStages = map(stages)
+ require(theStages.toSet.size == theStages.size,
+ "Cannot have duplicate components in a pipeline.")
+ theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Represents a compiled pipeline.
+ */
+@AlphaComponent
+class PipelineModel private[ml] (
+ override val parent: Pipeline,
+ override val fittingParamMap: ParamMap,
+ private[ml] val stages: Array[Transformer])
+ extends Model[PipelineModel] with Logging {
+
+ /**
+ * Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
+ * estimator does not exist in the pipeline.
+ */
+ def getModel[M <: Model[M]](stage: Estimator[M]): M = {
+ val matched = stages.filter {
+ case m: Model[_] => m.parent.eq(stage)
+ case _ => false
+ }
+ if (matched.isEmpty) {
+ throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
+ } else if (matched.size > 1) {
+ throw new IllegalStateException(s"Cannot have duplicate estimators in the same pipeline.")
+ } else {
+ matched.head.asInstanceOf[M]
+ }
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+ }
+}
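
A minimal sketch of PipelineModel.getModel, which looks up the fitted model for a given estimator stage; it assumes the tokenizer, hashingTF, lr, and training values from the SimpleTextClassificationPipeline example earlier in this diff:

  val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
  val pipelineModel = pipeline.fit(training)
  // Retrieves the LogisticRegressionModel fitted for the lr stage; throws
  // NoSuchElementException if lr is not a stage of this pipeline.
  val lrModel = pipelineModel.getModel(lr)
  println(lrModel.fittingParamMap)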
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
new file mode 100644
index 0000000000..490e6609ad
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for transformers that transform one dataset into another.
+ */
+@AlphaComponent
+abstract class Transformer extends PipelineStage with Params {
+
+ /**
+ * Transforms the dataset with optional parameters
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @return transformed dataset
+ */
+ @varargs
+ def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+ val map = new ParamMap()
+ paramPairs.foreach(map.put(_))
+ transform(dataset, map)
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
+
+ // Java-friendly versions of transform.
+
+ /**
+ * Transforms the dataset with optional parameters.
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @return transformed dataset
+ */
+ @varargs
+ def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
+ transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
+ transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
+ }
+}
+
+/**
+ * Abstract class for transformers that take one input column, apply transformation, and output the
+ * result as a new column.
+ */
+private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+ extends Transformer with HasInputCol with HasOutputCol with Logging {
+
+ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
+ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
+
+ /**
+ * Creates the transform function using the given param map. The input param map already takes
+ * the embedded param map into account, so the param values should be determined solely by the
+ * input param map.
+ */
+ protected def createTransformFunc(paramMap: ParamMap): IN => OUT
+
+ /**
+ * Validates the input type. Throws an exception if it is invalid.
+ */
+ protected def validateInputType(inputType: DataType): Unit = {}
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ validateInputType(inputType)
+ if (schema.fieldNames.contains(map(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
+ }
+ val output = ScalaReflection.schemaFor[OUT]
+ val outputFields = schema.fields :+
+ StructField(map(outputCol), output.dataType, output.nullable)
+ StructType(outputFields)
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val udf = this.createTransformFunc(map)
+ dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+ }
+}
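
A minimal sketch of the two transform overloads above (inline ParamPairs vs. a ParamMap); it assumes a SchemaRDD named dataset with a "words" column and uses the HashingTF transformer added later in this commit:

  import org.apache.spark.ml.feature.HashingTF
  import org.apache.spark.ml.param.ParamMap

  val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
  // ParamPair form: overrides the embedded numFeatures for this call only.
  val out1 = hashingTF.transform(dataset, hashingTF.numFeatures -> 4096)
  // ParamMap form: equivalent, but the map can be reused across calls.
  val overrides = new ParamMap().put(hashingTF.numFeatures, 4096)
  val out2 = hashingTF.transform(dataset, overrides)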
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
new file mode 100644
index 0000000000..85b8899636
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: AlphaComponent ::
+ * Params for logistic regression.
+ */
+@AlphaComponent
+private[classification] trait LogisticRegressionParams extends Params
+ with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
+ with HasScoreCol with HasPredictionCol {
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param paramMap additional parameters
+ * @param fitting whether this call is made during fitting (training)
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean): StructType = {
+ val map = this.paramMap ++ paramMap
+ val featuresType = schema(map(featuresCol)).dataType
+ // TODO: Support casting Array[Double] and Array[Float] to Vector.
+ require(featuresType.isInstanceOf[VectorUDT],
+ s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
+ if (fitting) {
+ val labelType = schema(map(labelCol)).dataType
+ require(labelType == DoubleType,
+ s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
+ }
+ val fieldNames = schema.fieldNames
+ require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
+ require(!fieldNames.contains(map(predictionCol)),
+ s"Prediction column ${map(predictionCol)} already exists.")
+ val outputFields = schema.fields ++ Seq(
+ StructField(map(scoreCol), DoubleType, false),
+ StructField(map(predictionCol), DoubleType, false))
+ StructType(outputFields)
+ }
+}
+
+/**
+ * Logistic regression.
+ */
+class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {
+
+ setRegParam(0.1)
+ setMaxIter(100)
+ setThreshold(0.5)
+
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+ def setThreshold(value: Double): this.type = set(threshold, value)
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ .map { case Row(label: Double, features: Vector) =>
+ LabeledPoint(label, features)
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+ val lr = new LogisticRegressionWithLBFGS
+ lr.optimizer
+ .setRegParam(map(regParam))
+ .setNumIterations(map(maxIter))
+ val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
+ instances.unpersist()
+ // copy model params
+ Params.inheritValues(map, this, lrm)
+ lrm
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = true)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model produced by [[LogisticRegression]].
+ */
+@AlphaComponent
+class LogisticRegressionModel private[ml] (
+ override val parent: LogisticRegression,
+ override val fittingParamMap: ParamMap,
+ weights: Vector)
+ extends Model[LogisticRegressionModel] with LogisticRegressionParams {
+
+ def setThreshold(value: Double): this.type = set(threshold, value)
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = false)
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val score: Vector => Double = (v) => {
+ val margin = BLAS.dot(v, weights)
+ 1.0 / (1.0 + math.exp(-margin))
+ }
+ val t = map(threshold)
+ val predict: Double => Double = (score) => {
+ if (score > t) 1.0 else 0.0
+ }
+ dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
+ .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ }
+}
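
A minimal sketch of multi-model fitting via Estimator.fit(dataset, Array[ParamMap]), applied to the LogisticRegression estimator above; training is assumed to be a SchemaRDD with "label" and "features" columns:

  import org.apache.spark.ml.param.ParamMap

  val lr = new LogisticRegression().setMaxIter(50)
  val paramMaps = Array(
    new ParamMap().put(lr.regParam, 0.01),
    new ParamMap().put(lr.regParam, 0.1))
  // One fitted LogisticRegressionModel per ParamMap, in the same order.
  val models: Seq[LogisticRegressionModel] = lr.fit(training, paramMaps)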
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
new file mode 100644
index 0000000000..0b0504e036
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.sql.{DoubleType, Row, SchemaRDD}
+
+/**
+ * :: AlphaComponent ::
+ * Evaluator for binary classification, which expects two input columns: score and label.
+ */
+@AlphaComponent
+class BinaryClassificationEvaluator extends Evaluator with Params
+ with HasScoreCol with HasLabelCol {
+
+ /** param for metric name in evaluation */
+ val metricName: Param[String] = new Param(this, "metricName",
+ "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+ def getMetricName: String = get(metricName)
+ def setMetricName(value: String): this.type = set(metricName, value)
+
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
+ val map = this.paramMap ++ paramMap
+
+ val schema = dataset.schema
+ val scoreType = schema(map(scoreCol)).dataType
+ require(scoreType == DoubleType,
+ s"Score column ${map(scoreCol)} must be double type but found $scoreType")
+ val labelType = schema(map(labelCol)).dataType
+ require(labelType == DoubleType,
+ s"Label column ${map(labelCol)} must be double type but found $labelType")
+
+ import dataset.sqlContext._
+ val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
+ .map { case Row(score: Double, label: Double) =>
+ (score, label)
+ }
+ val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+ val metric = map(metricName) match {
+ case "areaUnderROC" =>
+ metrics.areaUnderROC()
+ case "areaUnderPR" =>
+ metrics.areaUnderPR()
+ case other =>
+ throw new IllegalArgumentException(s"Does not support metric $other.")
+ }
+ metrics.unpersist()
+ metric
+ }
+}
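
A minimal sketch of evaluating scored output with the evaluator above; lrModel is assumed to be a fitted LogisticRegressionModel and labeledTest a SchemaRDD that still carries a "label" column:

  import org.apache.spark.ml.param.ParamMap

  val evaluator = new BinaryClassificationEvaluator()
    .setScoreCol("score")
    .setLabelCol("label")
  val scored = lrModel.transform(labeledTest)
  // areaUnderROC is the default metric; metricName can be overridden per call.
  val auROC = evaluator.evaluate(scored, ParamMap.empty)
  val auPR = evaluator.evaluate(scored, new ParamMap().put(evaluator.metricName, "areaUnderPR"))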
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
new file mode 100644
index 0000000000..b98b1755a3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * :: AlphaComponent ::
+ * Maps a sequence of terms to their term frequencies using the hashing trick.
+ */
+@AlphaComponent
+class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+
+ /** number of features */
+ val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
+ def setNumFeatures(value: Int) = set(numFeatures, value)
+ def getNumFeatures: Int = get(numFeatures)
+
+ override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
+ val hashingTF = new feature.HashingTF(paramMap(numFeatures))
+ hashingTF.transform
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
new file mode 100644
index 0000000000..896a6b83b6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+
+/**
+ * Params for [[StandardScaler]] and [[StandardScalerModel]].
+ */
+private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
+
+/**
+ * :: AlphaComponent ::
+ * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * statistics on the samples in the training set.
+ */
+@AlphaComponent
+class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+
+ def setInputCol(value: String): this.type = set(inputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val input = dataset.select(map(inputCol).attr)
+ .map { case Row(v: Vector) =>
+ v
+ }
+ val scaler = new feature.StandardScaler().fit(input)
+ val model = new StandardScalerModel(this, map, scaler)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ require(inputType.isInstanceOf[VectorUDT],
+ s"Input column ${map(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains(map(outputCol)),
+ s"Output column ${map(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[StandardScaler]].
+ */
+@AlphaComponent
+class StandardScalerModel private[ml] (
+ override val parent: StandardScaler,
+ override val fittingParamMap: ParamMap,
+ scaler: feature.StandardScalerModel)
+ extends Model[StandardScalerModel] with StandardScalerParams {
+
+ def setInputCol(value: String): this.type = set(inputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val scale: (Vector) => Vector = (v) => {
+ scaler.transform(v)
+ }
+ dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ require(inputType.isInstanceOf[VectorUDT],
+ s"Input column ${map(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains(map(outputCol)),
+ s"Output column ${map(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
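
A minimal sketch of the fit-then-transform flow for StandardScaler, assuming dataset is a SchemaRDD with a vector-valued "features" column:

  val scaler = new StandardScaler()
    .setInputCol("features")
    .setOutputCol("scaledFeatures")
  // fit() computes column summary statistics; the fitted model applies them.
  val scalerModel = scaler.fit(dataset)
  val scaled = scalerModel.transform(dataset)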
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
new file mode 100644
index 0000000000..0a6599b64c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataType, StringType}
+
+/**
+ * :: AlphaComponent ::
+ * A tokenizer that converts the input string to lowercase and then splits it by white spaces.
+ */
+@AlphaComponent
+class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
+
+ protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+ _.toLowerCase.split("\\s")
+ }
+
+ protected override def validateInputType(inputType: DataType): Unit = {
+ require(inputType == StringType, s"Input type must be string type but got $inputType.")
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
new file mode 100644
index 0000000000..00d9c802e9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * assemble and configure practical machine learning pipelines.
+ */
+@AlphaComponent
+package org.apache.spark.ml;
+
+import org.apache.spark.annotation.AlphaComponent;
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
new file mode 100644
index 0000000000..51cd48c904
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * assemble and configure practical machine learning pipelines.
+ */
+package object ml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
new file mode 100644
index 0000000000..8fd46aef4b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+import java.lang.reflect.Modifier
+
+import org.apache.spark.annotation.AlphaComponent
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+import org.apache.spark.ml.Identifiable
+
+/**
+ * :: AlphaComponent ::
+ * A param with self-contained documentation and an optional default value. Primitive-typed params
+ * should use the specialized versions, which are friendlier to Java users.
+ *
+ * @param parent parent object
+ * @param name param name
+ * @param doc documentation
+ * @tparam T param value type
+ */
+@AlphaComponent
+class Param[T] (
+ val parent: Params,
+ val name: String,
+ val doc: String,
+ val defaultValue: Option[T] = None)
+ extends Serializable {
+
+ /**
+ * Creates a param pair with the given value (for Java).
+ */
+ def w(value: T): ParamPair[T] = this -> value
+
+ /**
+ * Creates a param pair with the given value (for Scala).
+ */
+ def ->(value: T): ParamPair[T] = ParamPair(this, value)
+
+ override def toString: String = {
+ if (defaultValue.isDefined) {
+ s"$name: $doc (default: ${defaultValue.get})"
+ } else {
+ s"$name: $doc"
+ }
+ }
+}
+
+// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
+
+/** Specialized version of [[Param[Double]]] for Java. */
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+ extends Param[Double](parent, name, doc, defaultValue) {
+
+ override def w(value: Double): ParamPair[Double] = super.w(value)
+}
+
+/** Specialized version of [[Param[Int]]] for Java. */
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+ extends Param[Int](parent, name, doc, defaultValue) {
+
+ override def w(value: Int): ParamPair[Int] = super.w(value)
+}
+
+/** Specialized version of [[Param[Float]]] for Java. */
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+ extends Param[Float](parent, name, doc, defaultValue) {
+
+ override def w(value: Float): ParamPair[Float] = super.w(value)
+}
+
+/** Specialized version of [[Param[Long]]] for Java. */
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+ extends Param[Long](parent, name, doc, defaultValue) {
+
+ override def w(value: Long): ParamPair[Long] = super.w(value)
+}
+
+/** Specialized version of [[Param[Boolean]]] for Java. */
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+ extends Param[Boolean](parent, name, doc, defaultValue) {
+
+ override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
+}
+
+/**
+ * A param and its value.
+ */
+case class ParamPair[T](param: Param[T], value: T)
+
+/**
+ * :: AlphaComponent ::
+ * Trait for components that take parameters. This also provides an internal param map to store
+ * parameter values attached to the instance.
+ */
+@AlphaComponent
+trait Params extends Identifiable with Serializable {
+
+ /** Returns all params. */
+ def params: Array[Param[_]] = {
+ val methods = this.getClass.getMethods
+ methods.filter { m =>
+ Modifier.isPublic(m.getModifiers) &&
+ classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
+ m.getParameterTypes.isEmpty
+ }.sortBy(_.getName)
+ .map(m => m.invoke(this).asInstanceOf[Param[_]])
+ }
+
+ /**
+ * Validates parameter values stored internally plus the input parameter map.
+ * Raises an exception if any parameter is invalid.
+ */
+ def validate(paramMap: ParamMap): Unit = {}
+
+ /**
+ * Validates parameter values stored internally.
+ * Raises an exception if any parameter value is invalid.
+ */
+ def validate(): Unit = validate(ParamMap.empty)
+
+ /**
+ * Returns the documentation of all params.
+ */
+ def explainParams(): String = params.mkString("\n")
+
+ /** Checks whether a param is explicitly set. */
+ def isSet(param: Param[_]): Boolean = {
+ require(param.parent.eq(this))
+ paramMap.contains(param)
+ }
+
+ /** Gets a param by its name. */
+ private[ml] def getParam(paramName: String): Param[Any] = {
+ val m = this.getClass.getMethod(paramName)
+ assert(Modifier.isPublic(m.getModifiers) &&
+ classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
+ m.getParameterTypes.isEmpty)
+ m.invoke(this).asInstanceOf[Param[Any]]
+ }
+
+ /**
+ * Sets a parameter in the embedded param map.
+ */
+ private[ml] def set[T](param: Param[T], value: T): this.type = {
+ require(param.parent.eq(this))
+ paramMap.put(param.asInstanceOf[Param[Any]], value)
+ this
+ }
+
+ /**
+ * Gets the value of a parameter in the embedded param map.
+ */
+ private[ml] def get[T](param: Param[T]): T = {
+ require(param.parent.eq(this))
+ paramMap(param)
+ }
+
+ /**
+ * Internal param map.
+ */
+ protected val paramMap: ParamMap = ParamMap.empty
+}
+
+private[ml] object Params {
+
+ /**
+ * Copies parameter values from the parent estimator to the child model it produced.
+ * @param paramMap the param map that holds parameters of the parent
+ * @param parent the parent estimator
+ * @param child the child model
+ */
+ def inheritValues[E <: Params, M <: E](
+ paramMap: ParamMap,
+ parent: E,
+ child: M): Unit = {
+ parent.params.foreach { param =>
+ if (paramMap.contains(param)) {
+ child.set(child.getParam(param.name), paramMap(param))
+ }
+ }
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * A param to value map.
+ */
+@AlphaComponent
+class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable {
+
+ /**
+ * Creates an empty param map.
+ */
+ def this() = this(mutable.Map.empty[Param[Any], Any])
+
+ /**
+ * Puts a (param, value) pair (overwrites if the input param exists).
+ */
+ def put[T](param: Param[T], value: T): this.type = {
+ map(param.asInstanceOf[Param[Any]]) = value
+ this
+ }
+
+ /**
+   * Puts a list of param pairs (overwrites if the input params exist).
+ */
+ def put(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ put(p.param.asInstanceOf[Param[Any]], p.value)
+ }
+ this
+ }
+
+ /**
+ * Optionally returns the value associated with a param or its default.
+ */
+ def get[T](param: Param[T]): Option[T] = {
+ map.get(param.asInstanceOf[Param[Any]])
+ .orElse(param.defaultValue)
+ .asInstanceOf[Option[T]]
+ }
+
+ /**
+   * Gets the value of the input param, falling back to its default value if no value is set.
+   * Raises a NoSuchElementException if neither a value nor a default exists.
+ */
+ def apply[T](param: Param[T]): T = {
+ val value = get(param)
+ if (value.isDefined) {
+ value.get
+ } else {
+ throw new NoSuchElementException(s"Cannot find param ${param.name}.")
+ }
+ }
+
+ /**
+ * Checks whether a parameter is explicitly specified.
+ */
+ def contains(param: Param[_]): Boolean = {
+ map.contains(param.asInstanceOf[Param[Any]])
+ }
+
+ /**
+ * Filters this param map for the given parent.
+ */
+ def filter(parent: Params): ParamMap = {
+ val filtered = map.filterKeys(_.parent == parent)
+ new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
+ }
+
+ /**
+   * Makes a copy of this param map.
+ */
+ def copy: ParamMap = new ParamMap(map.clone())
+
+ override def toString: String = {
+ map.map { case (param, value) =>
+ s"\t${param.parent.uid}-${param.name}: $value"
+ }.mkString("{\n", ",\n", "\n}")
+ }
+
+ /**
+   * Returns a new param map that contains parameters in this map and the given map,
+   * where the latter overwrites this map in case of conflicts.
+ */
+ def ++(other: ParamMap): ParamMap = {
+ new ParamMap(this.map ++ other.map)
+ }
+
+ /**
+ * Adds all parameters from the input param map into this param map.
+ */
+ def ++=(other: ParamMap): this.type = {
+ this.map ++= other.map
+ this
+ }
+
+ /**
+ * Converts this param map to a sequence of param pairs.
+ */
+ def toSeq: Seq[ParamPair[_]] = {
+ map.toSeq.map { case (param, value) =>
+ ParamPair(param, value)
+ }
+ }
+}
+
+object ParamMap {
+
+ /**
+ * Returns an empty param map.
+ */
+ def empty: ParamMap = new ParamMap()
+
+ /**
+ * Constructs a param map by specifying its entries.
+ */
+ @varargs
+ def apply(paramPairs: ParamPair[_]*): ParamMap = {
+ new ParamMap().put(paramPairs: _*)
+ }
+}
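(Illustrative sketch, not part of the patch.) The params API above is intended to be used roughly as follows; MySolver and MySolverExample are hypothetical names introduced only for this example, and the class is placed in the org.apache.spark.ml.param package because set/get are private[ml], mirroring the TestParams helper added later in this patch.

package org.apache.spark.ml.param

// Hypothetical Params implementation: params are declared as public vals so
// that Params.params can discover them via reflection.
class MySolver extends Params {
  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations", Some(100))
  def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
  def getMaxIter: Int = get(maxIter)

  val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
  def setInputCol(value: String): this.type = { set(inputCol, value); this }
  def getInputCol: String = get(inputCol)
}

object MySolverExample {
  def main(args: Array[String]): Unit = {
    val solver = new MySolver().setMaxIter(50)

    // Three equivalent ways to build a ParamPair; `w` exists mainly for Java callers.
    val pair0 = solver.maxIter -> 10
    val pair1 = solver.maxIter.w(10)
    val pair2 = ParamPair(solver.maxIter, 10)
    assert(pair1.value == pair2.value)

    // Explicit pairs can be collected into a ParamMap, which falls back to a
    // param's default value when no value has been put for it.
    val map = ParamMap(pair0, solver.inputCol -> "input")
    assert(map(solver.maxIter) == 10)
    assert(map(solver.inputCol) == "input")
  }
}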
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
new file mode 100644
index 0000000000..ef141d3eb2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+private[ml] trait HasRegParam extends Params {
+ /** param for regularization parameter */
+ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+ def getRegParam: Double = get(regParam)
+}
+
+private[ml] trait HasMaxIter extends Params {
+ /** param for max number of iterations */
+ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+ def getMaxIter: Int = get(maxIter)
+}
+
+private[ml] trait HasFeaturesCol extends Params {
+ /** param for features column name */
+ val featuresCol: Param[String] =
+ new Param(this, "featuresCol", "features column name", Some("features"))
+ def getFeaturesCol: String = get(featuresCol)
+}
+
+private[ml] trait HasLabelCol extends Params {
+ /** param for label column name */
+ val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label"))
+ def getLabelCol: String = get(labelCol)
+}
+
+private[ml] trait HasScoreCol extends Params {
+ /** param for score column name */
+ val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score"))
+ def getScoreCol: String = get(scoreCol)
+}
+
+private[ml] trait HasPredictionCol extends Params {
+ /** param for prediction column name */
+ val predictionCol: Param[String] =
+ new Param(this, "predictionCol", "prediction column name", Some("prediction"))
+ def getPredictionCol: String = get(predictionCol)
+}
+
+private[ml] trait HasThreshold extends Params {
+ /** param for threshold in (binary) prediction */
+ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
+ def getThreshold: Double = get(threshold)
+}
+
+private[ml] trait HasInputCol extends Params {
+ /** param for input column name */
+ val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
+ def getInputCol: String = get(inputCol)
+}
+
+private[ml] trait HasOutputCol extends Params {
+ /** param for output column name */
+ val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
+ def getOutputCol: String = get(outputCol)
+}
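(Illustrative sketch, not part of the patch.) These shared traits are meant to be mixed into concrete estimators and transformers so that common params such as inputCol and outputCol are defined once; HashingTF and LogisticRegression in this patch do exactly that. A minimal hypothetical holder, assumed to live in the ml.param package:

package org.apache.spark.ml.param

// Hypothetical component mixing in two of the shared traits above; it inherits
// the inputCol/outputCol params and their getters and only adds setters.
class MyFeaturizerParams extends Params with HasInputCol with HasOutputCol {
  def setInputCol(value: String): this.type = { set(inputCol, value); this }
  def setOutputCol(value: String): this.type = { set(outputCol, value); this }
}

object MyFeaturizerParamsExample {
  def main(args: Array[String]): Unit = {
    val p = new MyFeaturizerParams()
      .setInputCol("text")
      .setOutputCol("features")
    assert(p.getInputCol == "text")      // getters come from the shared traits
    assert(p.getOutputCol == "features")
    println(p.explainParams())           // lists both params with their docs
  }
}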
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
new file mode 100644
index 0000000000..194b9bfd9a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import com.github.fommil.netlib.F2jBLAS
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{SchemaRDD, StructType}
+
+/**
+ * Params for [[CrossValidator]] and [[CrossValidatorModel]].
+ */
+private[ml] trait CrossValidatorParams extends Params {
+ /** param for the estimator to be cross-validated */
+ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+ def getEstimator: Estimator[_] = get(estimator)
+
+ /** param for estimator param maps */
+ val estimatorParamMaps: Param[Array[ParamMap]] =
+ new Param(this, "estimatorParamMaps", "param maps for the estimator")
+ def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
+
+ /** param for the evaluator for selection */
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+ def getEvaluator: Evaluator = get(evaluator)
+
+ /** param for number of folds for cross validation */
+ val numFolds: IntParam =
+ new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
+ def getNumFolds: Int = get(numFolds)
+}
+
+/**
+ * :: AlphaComponent ::
+ * K-fold cross validation.
+ */
+@AlphaComponent
+class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging {
+
+ private val f2jBLAS = new F2jBLAS
+
+ def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+ def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+ def setNumFolds(value: Int): this.type = set(numFolds, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = {
+ val map = this.paramMap ++ paramMap
+ val schema = dataset.schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val sqlCtx = dataset.sqlContext
+ val est = map(estimator)
+ val eval = map(evaluator)
+ val epm = map(estimatorParamMaps)
+ val numModels = epm.size
+ val metrics = new Array[Double](epm.size)
+ val splits = MLUtils.kFold(dataset, map(numFolds), 0)
+ splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+ val trainingDataset = sqlCtx.applySchema(training, schema).cache()
+ val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ // multi-model training
+ logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+ val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ var i = 0
+ while (i < numModels) {
+ val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
+ logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+ metrics(i) += metric
+ i += 1
+ }
+ }
+ f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
+ logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
+ val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+ logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ logInfo(s"Best cross-validation metric: $bestMetric.")
+ val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+ val cvModel = new CrossValidatorModel(this, map, bestModel)
+ Params.inheritValues(map, this, cvModel)
+ cvModel
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ map(estimator).transformSchema(schema, paramMap)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model from k-fold cross validation.
+ */
+@AlphaComponent
+class CrossValidatorModel private[ml] (
+ override val parent: CrossValidator,
+ override val fittingParamMap: ParamMap,
+ val bestModel: Model[_])
+ extends Model[CrossValidatorModel] with CrossValidatorParams {
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ bestModel.transform(dataset, paramMap)
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ bestModel.transformSchema(schema, paramMap)
+ }
+}
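(Illustrative sketch, not part of the patch.) CrossValidator fits the estimator once per (fold, param map) pair, averages the evaluator's metric over the folds, and then refits on the full dataset with the best param map. A spark-shell-style sketch, assuming a `dataset: SchemaRDD` of LabeledPoints already exists; the CrossValidatorSuite added below shows the full setup.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.maxIter, Array(10, 50))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setNumFolds(3)                 // numFolds defaults to 3 if not set

// Trains numFolds * paramGrid.length models, then refits with the best params.
val cvModel = cv.fit(dataset)
val best = cvModel.bestModel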
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
new file mode 100644
index 0000000000..dafe73d82c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+
+/**
+ * :: AlphaComponent ::
+ * Builder for a param grid used in grid search-based model selection.
+ */
+@AlphaComponent
+class ParamGridBuilder {
+
+ private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]
+
+ /**
+ * Sets the given parameters in this grid to fixed values.
+ */
+ def baseOn(paramMap: ParamMap): this.type = {
+ baseOn(paramMap.toSeq: _*)
+ this
+ }
+
+ /**
+ * Sets the given parameters in this grid to fixed values.
+ */
+ @varargs
+ def baseOn(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value))
+ }
+ this
+ }
+
+ /**
+ * Adds a param with multiple values (overwrites if the input param exists).
+ */
+ def addGrid[T](param: Param[T], values: Iterable[T]): this.type = {
+ paramGrid.put(param, values)
+ this
+ }
+
+ // specialized versions of addGrid for Java.
+
+ /**
+ * Adds a double param with multiple values.
+ */
+ def addGrid(param: DoubleParam, values: Array[Double]): this.type = {
+ addGrid[Double](param, values)
+ }
+
+ /**
+   * Adds an int param with multiple values.
+ */
+ def addGrid(param: IntParam, values: Array[Int]): this.type = {
+ addGrid[Int](param, values)
+ }
+
+ /**
+ * Adds a float param with multiple values.
+ */
+ def addGrid(param: FloatParam, values: Array[Float]): this.type = {
+ addGrid[Float](param, values)
+ }
+
+ /**
+ * Adds a long param with multiple values.
+ */
+ def addGrid(param: LongParam, values: Array[Long]): this.type = {
+ addGrid[Long](param, values)
+ }
+
+ /**
+ * Adds a boolean param with true and false.
+ */
+ def addGrid(param: BooleanParam): this.type = {
+ addGrid[Boolean](param, Array(true, false))
+ }
+
+ /**
+ * Builds and returns all combinations of parameters specified by the param grid.
+ */
+ def build(): Array[ParamMap] = {
+ var paramMaps = Array(new ParamMap)
+ paramGrid.foreach { case (param, values) =>
+ val newParamMaps = values.flatMap { v =>
+ paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v))
+ }
+ paramMaps = newParamMaps.toArray
+ }
+ paramMaps
+ }
+}
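(Illustrative sketch, not part of the patch.) build() takes the cross product of all value lists, with baseOn values fixed in every produced map. The sketch below mirrors the ParamGridBuilderSuite added later in this patch and reuses its TestParams helper, so it only runs in test scope.

import org.apache.spark.ml.param.{ParamMap, TestParams}
import org.apache.spark.ml.tuning.ParamGridBuilder

val solver = new TestParams()
val grid: Array[ParamMap] = new ParamGridBuilder()
  .baseOn(solver.inputCol -> "input")           // fixed in every produced map
  .addGrid(solver.maxIter, Array(10, 20, 30))   // values to sweep
  .build()                                      // 1 x 3 = 3 param maps

assert(grid.length == 3)
grid.foreach(m => assert(m(solver.inputCol) == "input"))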
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 54ee930d61..89539e600f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -25,7 +25,7 @@ import org.apache.spark.Logging
/**
* BLAS routines for MLlib's vectors and matrices.
*/
-private[mllib] object BLAS extends Serializable with Logging {
+private[spark] object BLAS extends Serializable with Logging {
@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index ac217edc61..9fccd6341b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -115,6 +115,9 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def deserialize(datum: Any): Vector = {
datum match {
+ // TODO: something wrong with UDT serialization
+ case v: Vector =>
+ v
case row: Row =>
require(row.length == 4,
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 17c753c566..2067b36f24 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.regression
+import scala.beans.BeanInfo
+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
@@ -27,6 +29,7 @@ import org.apache.spark.SparkException
* @param label Label for this data point.
* @param features List of features for this data point.
*/
+@BeanInfo
case class LabeledPoint(label: Double, features: Vector) {
override def toString: String = {
"(%s,%s)".format(label, features)
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
new file mode 100644
index 0000000000..42846677ed
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.StandardScaler;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+/**
+ * Test Pipeline construction and fitting in Java.
+ */
+public class JavaPipelineSuite {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaPipelineSuite");
+ jsql = new JavaSQLContext(jsc);
+ JavaRDD<LabeledPoint> points =
+ jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
+ dataset = jsql.applySchema(points, LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void pipeline() {
+ StandardScaler scaler = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("scaledFeatures");
+ LogisticRegression lr = new LogisticRegression()
+ .setFeaturesCol("scaledFeatures");
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {scaler, lr});
+ PipelineModel model = pipeline.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
new file mode 100644
index 0000000000..76eb7f0032
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+public class JavaLogisticRegressionSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ jsql = new JavaSQLContext(jsc);
+ List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void logisticRegression() {
+ LogisticRegression lr = new LogisticRegression();
+ LogisticRegressionModel model = lr.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+
+ @Test
+ public void logisticRegressionWithSetters() {
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0);
+ LogisticRegressionModel model = lr.fit(dataset);
+ model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
+ .registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+
+ @Test
+ public void logisticRegressionFitWithVarargs() {
+ LogisticRegression lr = new LogisticRegression();
+ lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
new file mode 100644
index 0000000000..a266ebd207
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+public class JavaCrossValidatorSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
+ jsql = new JavaSQLContext(jsc);
+ List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void crossValidationWithLogisticRegression() {
+ LogisticRegression lr = new LogisticRegression();
+ ParamMap[] lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
+ .addGrid(lr.maxIter(), new int[] {0, 10})
+ .build();
+ BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
+ CrossValidator cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(3);
+ CrossValidatorModel cvModel = cv.fit(dataset);
+ ParamMap bestParamMap = cvModel.bestModel().fittingParamMap();
+ Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam()));
+ Assert.assertEquals(10, bestParamMap.apply(lr.maxIter()));
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
new file mode 100644
index 0000000000..4515084bc7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.when
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar.mock
+
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SchemaRDD
+
+class PipelineSuite extends FunSuite {
+
+ abstract class MyModel extends Model[MyModel]
+
+ test("pipeline") {
+ val estimator0 = mock[Estimator[MyModel]]
+ val model0 = mock[MyModel]
+ val transformer1 = mock[Transformer]
+ val estimator2 = mock[Estimator[MyModel]]
+ val model2 = mock[MyModel]
+ val transformer3 = mock[Transformer]
+ val dataset0 = mock[SchemaRDD]
+ val dataset1 = mock[SchemaRDD]
+ val dataset2 = mock[SchemaRDD]
+ val dataset3 = mock[SchemaRDD]
+ val dataset4 = mock[SchemaRDD]
+
+ when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
+ when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
+ when(model0.parent).thenReturn(estimator0)
+ when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2)
+ when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2)
+ when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3)
+ when(model2.parent).thenReturn(estimator2)
+ when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(estimator0, transformer1, estimator2, transformer3))
+ val pipelineModel = pipeline.fit(dataset0)
+
+ assert(pipelineModel.stages.size === 4)
+ assert(pipelineModel.stages(0).eq(model0))
+ assert(pipelineModel.stages(1).eq(transformer1))
+ assert(pipelineModel.stages(2).eq(model2))
+ assert(pipelineModel.stages(3).eq(transformer3))
+
+ assert(pipelineModel.getModel(estimator0).eq(model0))
+ assert(pipelineModel.getModel(estimator2).eq(model2))
+ intercept[NoSuchElementException] {
+ pipelineModel.getModel(mock[Estimator[MyModel]])
+ }
+ val output = pipelineModel.transform(dataset0)
+ assert(output.eq(dataset4))
+ }
+
+ test("pipeline with duplicate stages") {
+ val estimator = mock[Estimator[MyModel]]
+ val pipeline = new Pipeline()
+ .setStages(Array(estimator, estimator))
+ val dataset = mock[SchemaRDD]
+ intercept[IllegalArgumentException] {
+ pipeline.fit(dataset)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
new file mode 100644
index 0000000000..625af299a5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.sql.SchemaRDD
+
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
+
+ import sqlContext._
+
+ val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+
+ test("logistic regression") {
+ val lr = new LogisticRegression
+ val model = lr.fit(dataset)
+ model.transform(dataset)
+ .select('label, 'prediction)
+ .collect()
+ }
+
+ test("logistic regression with setters") {
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
+ .select('label, 'score, 'prediction)
+ .collect()
+ }
+
+ test("logistic regression fit and transform with varargs") {
+ val lr = new LogisticRegression
+ val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
+ model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
+ .select('label, 'probability, 'prediction)
+ .collect()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
new file mode 100644
index 0000000000..1ce2987612
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+import org.scalatest.FunSuite
+
+class ParamsSuite extends FunSuite {
+
+ val solver = new TestParams()
+ import solver.{inputCol, maxIter}
+
+ test("param") {
+ assert(maxIter.name === "maxIter")
+ assert(maxIter.doc === "max number of iterations")
+ assert(maxIter.defaultValue.get === 100)
+ assert(maxIter.parent.eq(solver))
+ assert(maxIter.toString === "maxIter: max number of iterations (default: 100)")
+ assert(inputCol.defaultValue === None)
+ }
+
+ test("param pair") {
+ val pair0 = maxIter -> 5
+ val pair1 = maxIter.w(5)
+ val pair2 = ParamPair(maxIter, 5)
+ for (pair <- Seq(pair0, pair1, pair2)) {
+ assert(pair.param.eq(maxIter))
+ assert(pair.value === 5)
+ }
+ }
+
+ test("param map") {
+ val map0 = ParamMap.empty
+
+ assert(!map0.contains(maxIter))
+ assert(map0(maxIter) === maxIter.defaultValue.get)
+ map0.put(maxIter, 10)
+ assert(map0.contains(maxIter))
+ assert(map0(maxIter) === 10)
+
+ assert(!map0.contains(inputCol))
+ intercept[NoSuchElementException] {
+ map0(inputCol)
+ }
+ map0.put(inputCol -> "input")
+ assert(map0.contains(inputCol))
+ assert(map0(inputCol) === "input")
+
+ val map1 = map0.copy
+ val map2 = ParamMap(maxIter -> 10, inputCol -> "input")
+ val map3 = new ParamMap()
+ .put(maxIter, 10)
+ .put(inputCol, "input")
+ val map4 = ParamMap.empty ++ map0
+ val map5 = ParamMap.empty
+ map5 ++= map0
+
+ for (m <- Seq(map1, map2, map3, map4, map5)) {
+ assert(m.contains(maxIter))
+ assert(m(maxIter) === 10)
+ assert(m.contains(inputCol))
+ assert(m(inputCol) === "input")
+ }
+ }
+
+ test("params") {
+ val params = solver.params
+ assert(params.size === 2)
+ assert(params(0).eq(inputCol), "params must be ordered by name")
+ assert(params(1).eq(maxIter))
+ assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+ assert(solver.getParam("inputCol").eq(inputCol))
+ assert(solver.getParam("maxIter").eq(maxIter))
+ intercept[NoSuchMethodException] {
+ solver.getParam("abc")
+ }
+ assert(!solver.isSet(inputCol))
+ intercept[IllegalArgumentException] {
+ solver.validate()
+ }
+ solver.validate(ParamMap(inputCol -> "input"))
+ solver.setInputCol("input")
+ assert(solver.isSet(inputCol))
+ assert(solver.getInputCol === "input")
+ solver.validate()
+ intercept[IllegalArgumentException] {
+ solver.validate(ParamMap(maxIter -> -10))
+ }
+ solver.setMaxIter(-10)
+ intercept[IllegalArgumentException] {
+ solver.validate()
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
new file mode 100644
index 0000000000..1a65883d78
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+/** A subclass of Params for testing. */
+class TestParams extends Params {
+
+ val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100))
+ def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
+ def getMaxIter: Int = get(maxIter)
+
+ val inputCol = new Param[String](this, "inputCol", "input column name")
+ def setInputCol(value: String): this.type = { set(inputCol, value); this }
+ def getInputCol: String = get(inputCol)
+
+ override def validate(paramMap: ParamMap) = {
+ val m = this.paramMap ++ paramMap
+ require(m(maxIter) >= 0)
+ require(m.contains(inputCol))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
new file mode 100644
index 0000000000..72a334ae93
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.sql.SchemaRDD
+
+class CrossValidatorSuite extends FunSuite with LocalSparkContext {
+
+ import sqlContext._
+
+ val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+
+ test("cross validation with logistic regression") {
+ val lr = new LogisticRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.001, 1000.0))
+ .addGrid(lr.maxIter, Array(0, 10))
+ .build()
+ val eval = new BinaryClassificationEvaluator
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(3)
+ val cvModel = cv.fit(dataset)
+ val bestParamMap = cvModel.bestModel.fittingParamMap
+ assert(bestParamMap(lr.regParam) === 0.001)
+ assert(bestParamMap(lr.maxIter) === 10)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
new file mode 100644
index 0000000000..20aa100112
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.param.{ParamMap, TestParams}
+
+class ParamGridBuilderSuite extends FunSuite {
+
+ val solver = new TestParams()
+ import solver.{inputCol, maxIter}
+
+ test("param grid builder") {
+ def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = {
+ assert(maps.size === expected.size)
+ maps.foreach { m =>
+ val tuple = (m(maxIter), m(inputCol))
+ assert(expected.contains(tuple))
+ expected.remove(tuple)
+ }
+ assert(expected.isEmpty)
+ }
+
+ val maps0 = new ParamGridBuilder()
+ .baseOn(maxIter -> 10)
+ .addGrid(inputCol, Array("input0", "input1"))
+ .build()
+ val expected0 = mutable.Set(
+ (10, "input0"),
+ (10, "input1"))
+ validateGrid(maps0, expected0)
+
+ val maps1 = new ParamGridBuilder()
+ .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten
+ .addGrid(maxIter, Array(10, 20))
+ .addGrid(inputCol, Array("input0", "input1"))
+ .build()
+ val expected1 = mutable.Set(
+ (10, "input0"),
+ (20, "input0"),
+ (10, "input1"),
+ (20, "input1"))
+ validateGrid(maps1, expected1)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 7857d9e5ee..4417d66adf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -17,26 +17,17 @@
package org.apache.spark.mllib.util
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
- @transient var sc: SparkContext = _
-
- override def beforeAll() {
- val conf = new SparkConf()
- .setMaster("local")
- .setAppName("test")
- sc = new SparkContext(conf)
- super.beforeAll()
- }
+ @transient val sc = new SparkContext("local", "test")
+ @transient lazy val sqlContext = new SQLContext(sc)
override def afterAll() {
- if (sc != null) {
- sc.stop()
- }
+ sc.stop()
super.afterAll()
}
}