diff options
Diffstat (limited to 'sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java')
-rw-r--r-- | sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java | 406 |
1 files changed, 406 insertions, 0 deletions
diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java new file mode 100644 index 0000000000..4b2015e0df --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java @@ -0,0 +1,406 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.antlr.runtime.tree.Tree; +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.ql.ErrorMsg; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + +/** + * SemanticAnalyzer. + * + */ +public abstract class SemanticAnalyzer { + public static String charSetString(String charSetName, String charSetString) + throws SemanticException { + try { + // The character set name starts with a _, so strip that + charSetName = charSetName.substring(1); + if (charSetString.charAt(0) == '\'') { + return new String(unescapeSQLString(charSetString).getBytes(), + charSetName); + } else // hex input is also supported + { + assert charSetString.charAt(0) == '0'; + assert charSetString.charAt(1) == 'x'; + charSetString = charSetString.substring(2); + + byte[] bArray = new byte[charSetString.length() / 2]; + int j = 0; + for (int i = 0; i < charSetString.length(); i += 2) { + int val = Character.digit(charSetString.charAt(i), 16) * 16 + + Character.digit(charSetString.charAt(i + 1), 16); + if (val > 127) { + val = val - 256; + } + bArray[j++] = (byte)val; + } + + String res = new String(bArray, charSetName); + return res; + } + } catch (UnsupportedEncodingException e) { + throw new SemanticException(e); + } + } + + /** + * Remove the encapsulating "`" pair from the identifier. We allow users to + * use "`" to escape identifier for table names, column names and aliases, in + * case that coincide with Hive language keywords. + */ + public static String unescapeIdentifier(String val) { + if (val == null) { + return null; + } + if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { + val = val.substring(1, val.length() - 1); + } + return val; + } + + /** + * Converts parsed key/value properties pairs into a map. + * + * @param prop ASTNode parent of the key/value pairs + * + * @param mapProp property map which receives the mappings + */ + public static void readProps( + ASTNode prop, Map<String, String> mapProp) { + + for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { + String key = unescapeSQLString(prop.getChild(propChild).getChild(0) + .getText()); + String value = null; + if (prop.getChild(propChild).getChild(1) != null) { + value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); + } + mapProp.put(key, value); + } + } + + private static final int[] multiplier = new int[] {1000, 100, 10, 1}; + + @SuppressWarnings("nls") + public static String unescapeSQLString(String b) { + Character enclosure = null; + + // Some of the strings can be passed in as unicode. For example, the + // delimiter can be passed in as \002 - So, we first check if the + // string is a unicode number, else go back to the old behavior + StringBuilder sb = new StringBuilder(b.length()); + for (int i = 0; i < b.length(); i++) { + + char currentChar = b.charAt(i); + if (enclosure == null) { + if (currentChar == '\'' || b.charAt(i) == '\"') { + enclosure = currentChar; + } + // ignore all other chars outside the enclosure + continue; + } + + if (enclosure.equals(currentChar)) { + enclosure = null; + continue; + } + + if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { + int code = 0; + int base = i + 2; + for (int j = 0; j < 4; j++) { + int digit = Character.digit(b.charAt(j + base), 16); + code += digit * multiplier[j]; + } + sb.append((char)code); + i += 5; + continue; + } + + if (currentChar == '\\' && (i + 4 < b.length())) { + char i1 = b.charAt(i + 1); + char i2 = b.charAt(i + 2); + char i3 = b.charAt(i + 3); + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') + && (i3 >= '0' && i3 <= '7')) { + byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); + byte[] bValArr = new byte[1]; + bValArr[0] = bVal; + String tmp = new String(bValArr); + sb.append(tmp); + i += 3; + continue; + } + } + + if (currentChar == '\\' && (i + 2 < b.length())) { + char n = b.charAt(i + 1); + switch (n) { + case '0': + sb.append("\0"); + break; + case '\'': + sb.append("'"); + break; + case '"': + sb.append("\""); + break; + case 'b': + sb.append("\b"); + break; + case 'n': + sb.append("\n"); + break; + case 'r': + sb.append("\r"); + break; + case 't': + sb.append("\t"); + break; + case 'Z': + sb.append("\u001A"); + break; + case '\\': + sb.append("\\"); + break; + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } + + /** + * Get the list of FieldSchema out of the ASTNode. + */ + public static List<FieldSchema> getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { + List<FieldSchema> colList = new ArrayList<FieldSchema>(); + int numCh = ast.getChildCount(); + for (int i = 0; i < numCh; i++) { + FieldSchema col = new FieldSchema(); + ASTNode child = (ASTNode) ast.getChild(i); + Tree grandChild = child.getChild(0); + if(grandChild != null) { + String name = grandChild.getText(); + if(lowerCase) { + name = name.toLowerCase(); + } + // child 0 is the name of the column + col.setName(unescapeIdentifier(name)); + // child 1 is the type of the column + ASTNode typeChild = (ASTNode) (child.getChild(1)); + col.setType(getTypeStringFromAST(typeChild)); + + // child 2 is the optional comment of the column + if (child.getChildCount() == 3) { + col.setComment(unescapeSQLString(child.getChild(2).getText())); + } + } + colList.add(col); + } + return colList; + } + + protected static String getTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + switch (typeNode.getType()) { + case SparkSqlParser.TOK_LIST: + return serdeConstants.LIST_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; + case SparkSqlParser.TOK_MAP: + return serdeConstants.MAP_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," + + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; + case SparkSqlParser.TOK_STRUCT: + return getStructTypeStringFromAST(typeNode); + case SparkSqlParser.TOK_UNIONTYPE: + return getUnionTypeStringFromAST(typeNode); + default: + return getTypeName(typeNode); + } + } + + private static String getStructTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.STRUCT_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty struct not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + ASTNode child = (ASTNode) typeNode.getChild(i); + buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); + buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); + if (i < children - 1) { + buffer.append(","); + } + } + + buffer.append(">"); + return buffer.toString(); + } + + private static String getUnionTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty union not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); + if (i < children - 1) { + buffer.append(","); + } + } + buffer.append(">"); + typeStr = buffer.toString(); + return typeStr; + } + + public static String getAstNodeText(ASTNode tree) { + return tree.getChildCount() == 0?tree.getText() : + getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); + } + + public static String generateErrorMessage(ASTNode ast, String message) { + StringBuilder sb = new StringBuilder(); + if (ast == null) { + sb.append(message).append(". Cannot tell the position of null AST."); + return sb.toString(); + } + sb.append(ast.getLine()); + sb.append(":"); + sb.append(ast.getCharPositionInLine()); + sb.append(" "); + sb.append(message); + sb.append(". Error encountered near token '"); + sb.append(getAstNodeText(ast)); + sb.append("'"); + return sb.toString(); + } + + private static final Map<Integer, String> TokenToTypeName = new HashMap<Integer, String>(); + + static { + TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); + } + + public static String getTypeName(ASTNode node) throws SemanticException { + int token = node.getType(); + String typeName; + + // datetime type isn't currently supported + if (token == SparkSqlParser.TOK_DATETIME) { + throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); + } + + switch (token) { + case SparkSqlParser.TOK_CHAR: + CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); + typeName = charTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_VARCHAR: + VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); + typeName = varcharTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_DECIMAL: + DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); + typeName = decTypeInfo.getQualifiedName(); + break; + default: + typeName = TokenToTypeName.get(token); + } + return typeName; + } + + public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { + boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); + if (testMode) { + URI uri = new Path(location).toUri(); + String scheme = uri.getScheme(); + String authority = uri.getAuthority(); + String path = uri.getPath(); + if (!path.startsWith("/")) { + path = (new Path(System.getProperty("test.tmp.dir"), + path)).toUri().getPath(); + } + if (StringUtils.isEmpty(scheme)) { + scheme = "pfile"; + } + try { + uri = new URI(scheme, authority, path, null, null); + } catch (URISyntaxException e) { + throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); + } + return uri.toString(); + } else { + //no-op for non-test mode for now + return location; + } + } +} |