author    Vincenzo Selvaggio <vselvaggio@hotmail.it>  2015-04-29 23:21:21 -0700
committer Xiangrui Meng <meng@databricks.com>  2015-04-29 23:21:21 -0700
commit    254e0509762937acc9c72b432d5d953bf72c3c52
tree      29ebcc119912d02a58689c14633b3c2b770e2c02
parent    4459514497eb76e6f2465d071857854390453805
[SPARK-1406] Mllib pmml model export
See PDF attached to the JIRA issue 1406.

The contribution is my original work and I license the work to the project under the project's open source license.

Author: Vincenzo Selvaggio <vselvaggio@hotmail.it>
Author: Xiangrui Meng <meng@databricks.com>
Author: selvinsource <vselvaggio@hotmail.it>

Closes #3062 from selvinsource/mllib_pmml_model_export_SPARK-1406 and squashes the following commits:

852aac6 [Vincenzo Selvaggio] [SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file
085cf42 [Vincenzo Selvaggio] [SPARK-1406] Added Double Min and Max
  Fixed scala style
30165c4 [Vincenzo Selvaggio] [SPARK-1406] Fixed extreme cases for logit
7a5e0ec [Vincenzo Selvaggio] [SPARK-1406] Binary classification for SVM and Logistic Regression
cfcb596 [Vincenzo Selvaggio] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression
25dce33 [Vincenzo Selvaggio] [SPARK-1406] Update code to latest pmml model
dea98ca [Vincenzo Selvaggio] [SPARK-1406] Exclude transitive dependency for pmml model
66b7c12 [Vincenzo Selvaggio] [SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible
a0a55f7 [Vincenzo Selvaggio] Merge pull request #2 from mengxr/SPARK-1406
3c22f79 [Xiangrui Meng] more code style
e2313df [Vincenzo Selvaggio] Merge pull request #1 from mengxr/SPARK-1406
472d757 [Xiangrui Meng] fix code style
1676e15 [Vincenzo Selvaggio] fixed scala issue
e2ffae8 [Vincenzo Selvaggio] fixed scala style
b8823b0 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
b25bbf7 [Vincenzo Selvaggio] [SPARK-1406] Added export of pmml to distributed file system using the spark context
7a949d0 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style
f46c75c [Vincenzo Selvaggio] [SPARK-1406] Added PMMLExportable to supported models
7b33b4e [Vincenzo Selvaggio] [SPARK-1406] Added a PMMLExportable interface
  Restructured code in a new package mllib.pmml
  Supported models implements the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso
d559ec5 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
8fe12bb [Vincenzo Selvaggio] [SPARK-1406] Adjusted logistic regression export description and target categories
03bc3a5 [Vincenzo Selvaggio] added logistic regression
da2ec11 [Vincenzo Selvaggio] [SPARK-1406] added linear SVM PMML export
82f2131 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
19adf29 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style
1faf985 [Vincenzo Selvaggio] [SPARK-1406] Added target field to the regression model for completeness
  Adjusted unit test to deal with this change
3ae8ae5 [Vincenzo Selvaggio] [SPARK-1406] Adjusted imported order according to the guidelines
c67ce81 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
78515ec [Vincenzo Selvaggio] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel
e29dfb9 [Vincenzo Selvaggio] removed version, by default is set to 4.2 (latest from jpmml)
  removed copyright
ae8b993 [Vincenzo Selvaggio] updated some commented tests to use the new ModelExporter object
  reordered the imports
df8a89e [Vincenzo Selvaggio] added pmml version to pmml model
  changed the copyright to spark
a1b4dc3 [Vincenzo Selvaggio] updated imports
834ca44 [Vincenzo Selvaggio] reordered the import accordingly to the guidelines
349a76b [Vincenzo Selvaggio] new helper object to serialize the models to pmml format
c3ef9b8 [Vincenzo Selvaggio] set it to private
6357b98 [Vincenzo Selvaggio] set it to private
e1eb251 [Vincenzo Selvaggio] removed serialization part, this will be part of the ModelExporter helper object
aba5ee1 [Vincenzo Selvaggio] fixed cluster export
cd6c07c [Vincenzo Selvaggio] fixed scala style to run tests
f75b988 [Vincenzo Selvaggio] Merge remote-tracking branch 'origin/master' into mllib_pmml_model_export_SPARK-1406
07a29bf [selvinsource] Update LICENSE
8841439 [Vincenzo Selvaggio] adjust scala style in order to compile
1433b11 [Vincenzo Selvaggio] complete suite tests
8e71b8d [Vincenzo Selvaggio] kmeans pmml export implementation
9bc494f [Vincenzo Selvaggio] added scala suite tests
  added saveLocalFile to ModelExport trait
226e184 [Vincenzo Selvaggio] added javadoc and export model type in case there is a need to support other types of export (not just PMML)
a0e3679 [Vincenzo Selvaggio] export and pmml export traits
  kmeans test implementation
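
For orientation, here is a minimal usage sketch of the `PMMLExportable` trait this commit mixes into the supported models. It assumes a Spark MLlib build that contains this commit (plus `org.jpmml:pmml-model`) on the classpath. The object name `PmmlExportSketch`, the method name `exportKMeans`, and the cluster centers are made up for illustration. Only the `toPMML()` and `toPMML(outputStream)` overloads shown in the diff are used.

```scala
import java.io.ByteArrayOutputStream

import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical wrapper object; not part of the commit.
object PmmlExportSketch {

  // Build a tiny KMeansModel by hand (illustrative centers) and export it as a PMML string.
  def exportKMeans(): String = {
    val model = new KMeansModel(Array(
      Vectors.dense(1.0, 2.0),
      Vectors.dense(5.0, 6.0)))
    model.toPMML() // no-argument overload returns the PMML document as a String
  }

  def main(args: Array[String]): Unit = {
    println(exportKMeans())

    // The OutputStream overload writes the same document to any stream.
    val model = new KMeansModel(Array(Vectors.dense(0.0, 0.0)))
    val out = new ByteArrayOutputStream()
    model.toPMML(out)
    println(s"wrote ${out.size()} bytes of PMML")
  }
}
```

The remaining overloads in the diff, `toPMML(localPath)` and `toPMML(sc, path)`, take a local file path and a `SparkContext` plus an output directory respectively. No `SparkContext` is needed for the two calls in this sketch.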
-rw-r--r-- LICENSE | 1
-rw-r--r-- mllib/pom.xml | 15
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala | 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala | 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala | 74
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala | 90
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala | 75
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala | 83
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala | 47
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala | 64
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala | 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala | 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala | 84
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala | 84
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala | 49
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala | 95
18 files changed, 774 insertions, 6 deletions
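
The "Fixed extreme cases for logit" change in `BinaryClassificationPMMLModelExport` maps the model's probability threshold to the intercept of the category "0" regression table. For a threshold t strictly between 0 and 1 the value `-log(1/t - 1)` equals `log(t / (1 - t))`, the logit of t, which is the score at which the sigmoid equals t. The sketch below reproduces that mapping in plain Scala without Spark or JPMML. The names `LogitThresholdSketch`, `sigmoid`, and `interceptForThreshold` are illustrative and do not appear in the commit.

```scala
// Hypothetical stand-alone sketch of the threshold handling in the diff.
object LogitThresholdSketch {

  // Standard logistic function.
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  // Same three-way branch as the exporter: clamp degenerate thresholds,
  // otherwise return the logit of the threshold.
  def interceptForThreshold(threshold: Double): Double =
    if (threshold <= 0) Double.MinValue      // most negative finite Double in Scala
    else if (threshold >= 1) Double.MaxValue // largest finite Double
    else -math.log(1 / threshold - 1)

  def main(args: Array[String]): Unit = {
    for (t <- Seq(0.2, 0.5, 0.8)) {
      val b = interceptForThreshold(t)
      // Applying the sigmoid to the intercept recovers the threshold (up to rounding).
      println(s"threshold=$t intercept=$b sigmoid(intercept)=${sigmoid(b)}")
    }
  }
}
```

Scala's `Double.MinValue` is the most negative finite double, unlike Java's `Double.MIN_VALUE`, which is the smallest positive one. The clamped intercepts therefore sit at the two ends of the finite range. The apparent intent, inferred from the code rather than stated in the commit, is that a threshold of 0 or below always favors category "1" and a threshold of 1 or above always favors category "0".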
diff --git a/LICENSE b/LICENSE
index 9b364a4d00..21c42e9a20 100644
--- a/LICENSE
+++ b/LICENSE
@@ -814,6 +814,7 @@ BSD-style licenses
The following components are provided under a BSD-style license. See project link for details.
(BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
+ (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model)
(BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 5dfab36c76..a3c57ae260 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -109,6 +109,21 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.jpmml</groupId>
+ <artifactId>pmml-model</artifactId>
+ <version>1.1.15</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.xml.fastinfoset</groupId>
+ <artifactId>FastInfoset</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.istack</groupId>
+ <artifactId>istack-commons-runtime</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
<profiles>
<profile>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 057b628c6a..bd2e9079ce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
@@ -46,7 +47,7 @@ class LogisticRegressionModel (
val numFeatures: Int,
val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
- with Saveable {
+ with Saveable with PMMLExportable {
if (numClasses == 2) {
require(weights.size == numFeatures,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 52fb62dcff..33104cf06c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
import org.apache.spark.rdd.RDD
@@ -36,7 +37,7 @@ class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
- with Saveable {
+ with Saveable with PMMLExportable {
private var threshold: Option[Double] = Some(0.0)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index e4e411a3c8..ba228b11fc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
@@ -34,7 +35,8 @@ import org.apache.spark.sql.Row
/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
*/
-class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
+class KMeansModel (
+ val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable {
/** A Java-friendly constructor that takes an Iterable of Vectors. */
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
new file mode 100644
index 0000000000..354e90f3ee
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml
+
+import java.io.{File, OutputStream, StringWriter}
+import javax.xml.transform.stream.StreamResult
+
+import org.jpmml.model.JAXBUtil
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
+
+/**
+ * Export model to the PMML format
+ * Predictive Model Markup Language (PMML) is an XML-based file format
+ * developed by the Data Mining Group (www.dmg.org).
+ */
+trait PMMLExportable {
+
+ /**
+ * Export the model to the stream result in PMML format
+ */
+ private def toPMML(streamResult: StreamResult): Unit = {
+ val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
+ JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
+ }
+
+ /**
+ * Export the model to a local file in PMML format
+ */
+ def toPMML(localPath: String): Unit = {
+ toPMML(new StreamResult(new File(localPath)))
+ }
+
+ /**
+ * Export the model to a directory on a distributed file system in PMML format
+ */
+ def toPMML(sc: SparkContext, path: String): Unit = {
+ val pmml = toPMML()
+ sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
+ }
+
+ /**
+ * Export the model to the OutputStream in PMML format
+ */
+ def toPMML(outputStream: OutputStream): Unit = {
+ toPMML(new StreamResult(outputStream))
+ }
+
+ /**
+ * Export the model to a String in PMML format
+ */
+ def toPMML(): String = {
+ val writer = new StringWriter
+ toPMML(new StreamResult(writer))
+ writer.toString
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
new file mode 100644
index 0000000000..34b447584e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.{Array => SArray}
+
+import org.dmg.pmml._
+
+import org.apache.spark.mllib.regression.GeneralizedLinearModel
+
+/**
+ * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
+ */
+private[mllib] class BinaryClassificationPMMLModelExport(
+ model : GeneralizedLinearModel,
+ description : String,
+ normalizationMethod : RegressionNormalizationMethodType,
+ threshold: Double)
+ extends PMMLModelExport {
+
+ populateBinaryClassificationPMML()
+
+ /**
+ * Export the input LogisticRegressionModel or SVMModel to PMML format.
+ */
+ private def populateBinaryClassificationPMML(): Unit = {
+ pmml.getHeader.setDescription(description)
+
+ if (model.weights.size > 0) {
+ val fields = new SArray[FieldName](model.weights.size)
+ val dataDictionary = new DataDictionary
+ val miningSchema = new MiningSchema
+ val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
+ var interceptNO = threshold
+ if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
+ if (threshold <= 0) {
+ interceptNO = Double.MinValue
+ } else if (threshold >= 1) {
+ interceptNO = Double.MaxValue
+ } else {
+ interceptNO = -math.log(1 / threshold - 1)
+ }
+ }
+ val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
+ val regressionModel = new RegressionModel()
+ .withFunctionName(MiningFunctionType.CLASSIFICATION)
+ .withMiningSchema(miningSchema)
+ .withModelName(description)
+ .withNormalizationMethod(normalizationMethod)
+ .withRegressionTables(regressionTableYES, regressionTableNO)
+
+ for (i <- 0 until model.weights.size) {
+ fields(i) = FieldName.create("field_" + i)
+ dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
+ miningSchema
+ .withMiningFields(new MiningField(fields(i))
+ .withUsageType(FieldUsageType.ACTIVE))
+ regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
+ }
+
+ // add target field
+ val targetField = FieldName.create("target")
+ dataDictionary
+ .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
+ miningSchema
+ .withMiningFields(new MiningField(targetField)
+ .withUsageType(FieldUsageType.TARGET))
+
+ dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
+
+ pmml.setDataDictionary(dataDictionary)
+ pmml.withModels(regressionModel)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala
new file mode 100644
index 0000000000..1874786af0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.{Array => SArray}
+
+import org.dmg.pmml._
+
+import org.apache.spark.mllib.regression.GeneralizedLinearModel
+
+/**
+ * PMML Model Export for GeneralizedLinearModel abstract class
+ */
+private[mllib] class GeneralizedLinearPMMLModelExport(
+ model: GeneralizedLinearModel,
+ description: String)
+ extends PMMLModelExport {
+
+ populateGeneralizedLinearPMML(model)
+
+ /**
+ * Export the input GeneralizedLinearModel model to PMML format.
+ */
+ private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
+ pmml.getHeader.setDescription(description)
+
+ if (model.weights.size > 0) {
+ val fields = new SArray[FieldName](model.weights.size)
+ val dataDictionary = new DataDictionary
+ val miningSchema = new MiningSchema
+ val regressionTable = new RegressionTable(model.intercept)
+ val regressionModel = new RegressionModel()
+ .withFunctionName(MiningFunctionType.REGRESSION)
+ .withMiningSchema(miningSchema)
+ .withModelName(description)
+ .withRegressionTables(regressionTable)
+
+ for (i <- 0 until model.weights.size) {
+ fields(i) = FieldName.create("field_" + i)
+ dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
+ miningSchema
+ .withMiningFields(new MiningField(fields(i))
+ .withUsageType(FieldUsageType.ACTIVE))
+ regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
+ }
+
+ // for completeness add target field
+ val targetField = FieldName.create("target")
+ dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
+ miningSchema
+ .withMiningFields(new MiningField(targetField)
+ .withUsageType(FieldUsageType.TARGET))
+
+ dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
+
+ pmml.setDataDictionary(dataDictionary)
+ pmml.withModels(regressionModel)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
new file mode 100644
index 0000000000..069e7afc9f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.{Array => SArray}
+
+import org.dmg.pmml._
+
+import org.apache.spark.mllib.clustering.KMeansModel
+
+/**
+ * PMML Model Export for KMeansModel class
+ */
+private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
+
+ populateKMeansPMML(model)
+
+ /**
+ * Export the input KMeansModel model to PMML format.
+ */
+ private def populateKMeansPMML(model : KMeansModel): Unit = {
+ pmml.getHeader.setDescription("k-means clustering")
+
+ if (model.clusterCenters.length > 0) {
+ val clusterCenter = model.clusterCenters(0)
+ val fields = new SArray[FieldName](clusterCenter.size)
+ val dataDictionary = new DataDictionary
+ val miningSchema = new MiningSchema
+ val comparisonMeasure = new ComparisonMeasure()
+ .withKind(ComparisonMeasure.Kind.DISTANCE)
+ .withMeasure(new SquaredEuclidean())
+ val clusteringModel = new ClusteringModel()
+ .withModelName("k-means")
+ .withMiningSchema(miningSchema)
+ .withComparisonMeasure(comparisonMeasure)
+ .withFunctionName(MiningFunctionType.CLUSTERING)
+ .withModelClass(ClusteringModel.ModelClass.CENTER_BASED)
+ .withNumberOfClusters(model.clusterCenters.length)
+
+ for (i <- 0 until clusterCenter.size) {
+ fields(i) = FieldName.create("field_" + i)
+ dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
+ miningSchema
+ .withMiningFields(new MiningField(fields(i))
+ .withUsageType(FieldUsageType.ACTIVE))
+ clusteringModel.withClusteringFields(
+ new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
+ }
+
+ dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
+
+ for (i <- 0 until model.clusterCenters.length) {
+ val cluster = new Cluster()
+ .withName("cluster_" + i)
+ .withArray(new org.dmg.pmml.Array()
+ .withType(Array.Type.REAL)
+ .withN(clusterCenter.size)
+ .withValue(model.clusterCenters(i).toArray.mkString(" ")))
+ // we don't have the size of the single cluster but only the centroids (withValue)
+ // .withSize(value)
+ clusteringModel.withClusters(cluster)
+ }
+
+ pmml.setDataDictionary(dataDictionary)
+ pmml.withModels(clusteringModel)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
new file mode 100644
index 0000000000..ebdeae50bb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import scala.beans.BeanProperty
+
+import org.dmg.pmml.{Application, Header, PMML, Timestamp}
+
+private[mllib] trait PMMLModelExport {
+
+ /**
+ * Holder of the exported model in PMML format
+ */
+ @BeanProperty
+ val pmml: PMML = new PMML
+
+ setHeader(pmml)
+
+ private def setHeader(pmml: PMML): Unit = {
+ val version = getClass.getPackage.getImplementationVersion
+ val app = new Application().withName("Apache Spark MLlib").withVersion(version)
+ val timestamp = new Timestamp()
+ .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date()))
+ val header = new Header()
+ .withApplication(app)
+ .withTimestamp(timestamp)
+ pmml.setHeader(header)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
new file mode 100644
index 0000000000..c16e83d6a0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.RegressionNormalizationMethodType
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel
+import org.apache.spark.mllib.classification.SVMModel
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.regression.LassoModel
+import org.apache.spark.mllib.regression.LinearRegressionModel
+import org.apache.spark.mllib.regression.RidgeRegressionModel
+
+private[mllib] object PMMLModelExportFactory {
+
+ /**
+ * Factory object to help creating the necessary PMMLModelExport implementation
+ * taking as input the machine learning model (for example KMeansModel).
+ */
+ def createPMMLModelExport(model: Any): PMMLModelExport = {
+ model match {
+ case kmeans: KMeansModel =>
+ new KMeansPMMLModelExport(kmeans)
+ case linear: LinearRegressionModel =>
+ new GeneralizedLinearPMMLModelExport(linear, "linear regression")
+ case ridge: RidgeRegressionModel =>
+ new GeneralizedLinearPMMLModelExport(ridge, "ridge regression")
+ case lasso: LassoModel =>
+ new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
+ case svm: SVMModel =>
+ new BinaryClassificationPMMLModelExport(
+ svm, "linear SVM", RegressionNormalizationMethodType.NONE,
+ svm.getThreshold.getOrElse(0.0))
+ case logistic: LogisticRegressionModel =>
+ if (logistic.numClasses == 2) {
+ new BinaryClassificationPMMLModelExport(
+ logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT,
+ logistic.getThreshold.getOrElse(0.5))
+ } else {
+ throw new IllegalArgumentException(
+ "PMML Export not supported for Multinomial Logistic Regression")
+ }
+ case _ =>
+ throw new IllegalArgumentException(
+ "PMML Export not supported for model: " + model.getClass.getName)
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index e8b0381657..4f482384f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
@@ -34,7 +35,7 @@ class LassoModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable with Saveable {
+ with RegressionModel with Serializable with Saveable with PMMLExportable {
override protected def predictPoint(
dataMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 6fa7ad52a5..9453c4f66c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
@@ -34,7 +35,7 @@ class LinearRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
- with Saveable {
+ with Saveable with PMMLExportable {
override protected def predictPoint(
dataMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 309f9af466..e0c03d8180 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
@@ -35,7 +36,7 @@ class RidgeRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable with Saveable {
+ with RegressionModel with Serializable with Saveable with PMMLExportable {
override protected def predictPoint(
dataMatrix: Vector,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
new file mode 100644
index 0000000000..0b646cf1ce
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.RegressionModel
+import org.dmg.pmml.RegressionNormalizationMethodType
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel
+import org.apache.spark.mllib.classification.SVMModel
+import org.apache.spark.mllib.util.LinearDataGenerator
+
+class BinaryClassificationPMMLModelExportSuite extends FunSuite {
+
+ test("logistic regression PMML export") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+ val logisticRegressionModel =
+ new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
+
+ val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
+
+ // assert that the PMML format is as expected
+ assert(logisticModelExport.isInstanceOf[PMMLModelExport])
+ val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml
+ assert(pmml.getHeader.getDescription === "logistic regression")
+ // check that the number of fields matches the weights size plus one
+ assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1)
+ // This verifies that there is a model attached to the pmml object and the model is a regression
+ // one. It also verifies that the pmml model has a regression table (for target category 1)
+ // with the same number of predictors as the model weights.
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1")
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+ === logisticRegressionModel.weights.size)
+ // verify that there is a second table with target category 0 and no predictors
+ assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
+ assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
+ // ensure logistic regression has normalization method set to LOGIT
+ assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
+ }
+
+ test("linear SVM PMML export") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+ val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
+
+ val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
+
+ // assert that the PMML format is as expected
+ assert(svmModelExport.isInstanceOf[PMMLModelExport])
+ val pmml = svmModelExport.getPmml
+ assert(pmml.getHeader.getDescription
+ === "linear SVM")
+ // check that the number of fields matches the weights size plus one
+ assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1)
+ // This verifies that there is a model attached to the pmml object and the model is a regression
+ // one. It also verifies that the pmml model has a regression table (for target category 1)
+ // with the same number of predictors as the model weights.
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1")
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+ === svmModel.weights.size)
+ // verify that there is a second table with target category 0 and no predictors
+ assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
+ assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
+ // ensure linear SVM has normalization method set to NONE
+ assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
new file mode 100644
index 0000000000..f9afbd888d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.RegressionModel
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
+import org.apache.spark.mllib.util.LinearDataGenerator
+
+class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
+
+ test("linear regression PMML export") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+ val linearRegressionModel =
+ new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
+ val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
+ // assert that the PMML format is as expected
+ assert(linearModelExport.isInstanceOf[PMMLModelExport])
+ val pmml = linearModelExport.getPmml
+ assert(pmml.getHeader.getDescription === "linear regression")
+ // check that the number of fields matches the weights size plus one
+ assert(pmml.getDataDictionary.getNumberOfFields === linearRegressionModel.weights.size + 1)
+ // This verifies that there is a model attached to the pmml object and the model is a regression
+ // one. It also verifies that the pmml model has a regression table with the same number of
+ // predictors as the model weights.
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+ === linearRegressionModel.weights.size)
+ }
+
+ test("ridge regression PMML export") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+ val ridgeRegressionModel =
+ new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
+ val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
+ // assert that the PMML format is as expected
+ assert(ridgeModelExport.isInstanceOf[PMMLModelExport])
+ val pmml = ridgeModelExport.getPmml
+ assert(pmml.getHeader.getDescription === "ridge regression")
+ // check that the number of fields matches the weights size plus one
+ assert(pmml.getDataDictionary.getNumberOfFields === ridgeRegressionModel.weights.size + 1)
+ // This verifies that there is a model attached to the pmml object and the model is a regression
+ // one. It also verifies that the pmml model has a regression table with the same number of
+ // predictors as the model weights.
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+ === ridgeRegressionModel.weights.size)
+ }
+
+ test("lasso PMML export") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+ val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
+ val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
+ // assert that the PMML format is as expected
+ assert(lassoModelExport.isInstanceOf[PMMLModelExport])
+ val pmml = lassoModelExport.getPmml
+ assert(pmml.getHeader.getDescription === "lasso regression")
+ // check that the number of fields matches the weights size plus one
+ assert(pmml.getDataDictionary.getNumberOfFields === lassoModel.weights.size + 1)
+ // This verifies that there is a model attached to the pmml object and the model is a regression
+ // one. It also verifies that the pmml model has a regression table with the same number of
+ // predictors as the model weights.
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+ assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+ === lassoModel.weights.size)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
new file mode 100644
index 0000000000..b985d0446d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.ClusteringModel
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.linalg.Vectors
+
+class KMeansPMMLModelExportSuite extends FunSuite {
+
+ test("KMeansPMMLModelExport generate PMML format") {
+ val clusterCenters = Array(
+ Vectors.dense(1.0, 2.0, 6.0),
+ Vectors.dense(1.0, 3.0, 0.0),
+ Vectors.dense(1.0, 4.0, 6.0))
+ val kmeansModel = new KMeansModel(clusterCenters)
+
+ val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
+
+ // assert that the PMML format is as expected
+ assert(modelExport.isInstanceOf[PMMLModelExport])
+ val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml
+ assert(pmml.getHeader.getDescription === "k-means clustering")
+ // check that the number of fields matches the size of a single cluster center vector
+ assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
+ // This verifies that there is a model attached to the pmml object and the model is a clustering
+ // one. It also verifies that the pmml model has the same number of clusters as the spark model.
+ val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
+ assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
new file mode 100644
index 0000000000..f28a4ac8ad
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
+import org.apache.spark.mllib.util.LinearDataGenerator
+
+class PMMLModelExportFactorySuite extends FunSuite {
+
+ test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
+ val clusterCenters = Array(
+ Vectors.dense(1.0, 2.0, 6.0),
+ Vectors.dense(1.0, 3.0, 0.0),
+ Vectors.dense(1.0, 4.0, 6.0))
+ val kmeansModel = new KMeansModel(clusterCenters)
+
+ val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
+
+ assert(modelExport.isInstanceOf[KMeansPMMLModelExport])
+ }
+
+ test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
+ + "LinearRegressionModel, RidgeRegressionModel or LassoModel") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+
+ val linearRegressionModel =
+ new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
+ val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
+ assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
+
+ val ridgeRegressionModel =
+ new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
+ val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
+ assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
+
+ val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
+ val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
+ assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
+ }
+
+ test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
+ + "when passing a LogisticRegressionModel or SVMModel") {
+ val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+
+ val logisticRegressionModel =
+ new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
+ val logisticRegressionModelExport =
+ PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
+ assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
+
+ val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
+ val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
+ assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
+ }
+
+ test("PMMLModelExportFactory throw IllegalArgumentException "
+ + "when passing a Multinomial Logistic Regression") {
+ /** 3 classes, 2 features */
+ val multiclassLogisticRegressionModel = new LogisticRegressionModel(
+ weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
+ numFeatures = 2, numClasses = 3)
+
+ intercept[IllegalArgumentException] {
+ PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
+ }
+ }
+
+ test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
+ val invalidModel = new Object
+
+ intercept[IllegalArgumentException] {
+ PMMLModelExportFactory.createPMMLModelExport(invalidModel)
+ }
+ }
+}