Parcourir la source

[feat-4496][server] Add to! {} is used to mark the custom parameters to be output as-is in sql (#4497)

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496
liuxuedongcn il y a 4 ans
Parent
commit
43586da376

+ 103 - 78
dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java

@@ -18,7 +18,9 @@ package org.apache.dolphinscheduler.server.worker.task.sql;
 
 import com.fasterxml.jackson.databind.node.ArrayNode;
 import com.fasterxml.jackson.databind.node.ObjectNode;
+
 import org.apache.commons.lang.StringUtils;
+
 import org.apache.dolphinscheduler.alert.utils.MailUtils;
 import org.apache.dolphinscheduler.common.Constants;
 import org.apache.dolphinscheduler.common.enums.*;
@@ -41,6 +43,7 @@ import org.apache.dolphinscheduler.server.utils.ParamUtils;
 import org.apache.dolphinscheduler.server.utils.UDFUtils;
 import org.apache.dolphinscheduler.server.worker.task.AbstractTask;
 import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
+
 import org.slf4j.Logger;
 
 import java.sql.*;
@@ -51,17 +54,18 @@ import java.util.stream.Collectors;
 
 import static org.apache.dolphinscheduler.common.Constants.*;
 import static org.apache.dolphinscheduler.common.enums.DbType.HIVE;
+
 /**
  * sql task
  */
 public class SqlTask extends AbstractTask {
 
     /**
-     *  sql parameters
+     * sql parameters
      */
     private SqlParameters sqlParameters;
     /**
-     *  alert dao
+     * alert dao
      */
     private AlertDao alertDao;
     /**
@@ -148,10 +152,11 @@ public class SqlTask extends AbstractTask {
 
     /**
      * ready to execute SQL and parameter entity Map
+     *
      * @return SqlBinds
      */
     private SqlBinds getSqlAndSqlParamsMap(String sql) {
-        Map<Integer,Property> sqlParamsMap =  new HashMap<>();
+        Map<Integer, Property> sqlParamsMap = new HashMap<>();
         StringBuilder sqlBuilder = new StringBuilder();
 
         // find process instance by task id
@@ -164,25 +169,27 @@ public class SqlTask extends AbstractTask {
                 taskExecutionContext.getScheduleTime());
 
         // spell SQL according to the final user-defined variable
-        if(paramsMap == null){
+        if (paramsMap == null) {
             sqlBuilder.append(sql);
             return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
         }
 
-        if (StringUtils.isNotEmpty(sqlParameters.getTitle())){
+        if (StringUtils.isNotEmpty(sqlParameters.getTitle())) {
             String title = ParameterUtils.convertParameterPlaceholders(sqlParameters.getTitle(),
                     ParamUtils.convert(paramsMap));
-            logger.info("SQL title : {}",title);
+            logger.info("SQL title : {}", title);
             sqlParameters.setTitle(title);
         }
-        
+
         //new
         //replace variable TIME with $[YYYYmmddd...] in sql when history run job and batch complement job
         sql = ParameterUtils.replaceScheduleTime(sql, taskExecutionContext.getScheduleTime());
         // special characters need to be escaped, ${} needs to be escaped
         String rgex = "['\"]*\\$\\{(.*?)\\}['\"]*";
         setSqlParamsMap(sql, rgex, sqlParamsMap, paramsMap);
-
+        //Replace the original value in sql !{...} ,Does not participate in precompilation
+        String rgexo = "['\"]*\\!\\{(.*?)\\}['\"]*";
+        sql = replaceOriginalValue(sql, rgexo, paramsMap);
         // replace the ${} of the SQL statement with the Placeholder
         String formatSql = sql.replaceAll(rgex, "?");
         sqlBuilder.append(formatSql);
@@ -192,6 +199,20 @@ public class SqlTask extends AbstractTask {
         return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
     }
 
+    public String replaceOriginalValue(String content, String rgex, Map<String, Property> sqlParamsMap) {
+        Pattern pattern = Pattern.compile(rgex);
+        while (true) {
+            Matcher m = pattern.matcher(content);
+            if (!m.find()) {
+                break;
+            }
+            String paramName = m.group(1);
+            String paramValue = sqlParamsMap.get(paramName).getValue();
+            content = m.replaceFirst(paramValue);
+        }
+        return content;
+    }
+
     @Override
     public AbstractParameters getParameters() {
         return this.sqlParameters;
@@ -199,15 +220,16 @@ public class SqlTask extends AbstractTask {
 
     /**
      * execute function and sql
-     * @param mainSqlBinds          main sql binds
-     * @param preStatementsBinds    pre statements binds
-     * @param postStatementsBinds   post statements binds
-     * @param createFuncs           create functions
+     *
+     * @param mainSqlBinds main sql binds
+     * @param preStatementsBinds pre statements binds
+     * @param postStatementsBinds post statements binds
+     * @param createFuncs create functions
      */
     public void executeFuncAndSql(SqlBinds mainSqlBinds,
-                                        List<SqlBinds> preStatementsBinds,
-                                        List<SqlBinds> postStatementsBinds,
-                                        List<String> createFuncs){
+                                  List<SqlBinds> preStatementsBinds,
+                                  List<SqlBinds> postStatementsBinds,
+                                  List<String> createFuncs) {
         Connection connection = null;
         PreparedStatement stmt = null;
         ResultSet resultSet = null;
@@ -218,11 +240,11 @@ public class SqlTask extends AbstractTask {
             connection = createConnection();
             // create temp function
             if (CollectionUtils.isNotEmpty(createFuncs)) {
-                createTempFunction(connection,createFuncs);
+                createTempFunction(connection, createFuncs);
             }
 
             // pre sql
-            preSql(connection,preStatementsBinds);
+            preSql(connection, preStatementsBinds);
             stmt = prepareStatementAndBind(connection, mainSqlBinds);
 
             // decide whether to executeQuery or executeUpdate based on sqlType
@@ -236,13 +258,13 @@ public class SqlTask extends AbstractTask {
                 stmt.executeUpdate();
             }
 
-            postSql(connection,postStatementsBinds);
+            postSql(connection, postStatementsBinds);
 
         } catch (Exception e) {
-            logger.error("execute sql error",e);
+            logger.error("execute sql error", e);
             throw new RuntimeException("execute sql error");
         } finally {
-            close(resultSet,stmt,connection);
+            close(resultSet, stmt, connection);
         }
     }
 
@@ -252,7 +274,7 @@ public class SqlTask extends AbstractTask {
      * @param resultSet resultSet
      * @throws Exception Exception
      */
-    private void resultProcess(ResultSet resultSet) throws Exception{
+    private void resultProcess(ResultSet resultSet) throws Exception {
         ArrayNode resultJSONArray = JSONUtils.createArrayNode();
         ResultSetMetaData md = resultSet.getMetaData();
         int num = md.getColumnCount();
@@ -271,22 +293,22 @@ public class SqlTask extends AbstractTask {
         logger.debug("execute sql : {}", result);
 
         sendAttachment(StringUtils.isNotEmpty(sqlParameters.getTitle()) ?
-                        sqlParameters.getTitle(): taskExecutionContext.getTaskName() + " query result sets",
+                        sqlParameters.getTitle() : taskExecutionContext.getTaskName() + " query result sets",
                 JSONUtils.toJsonString(resultJSONArray));
     }
 
     /**
-     *  pre sql
+     * pre sql
      *
      * @param connection connection
      * @param preStatementsBinds preStatementsBinds
      */
     private void preSql(Connection connection,
-                        List<SqlBinds> preStatementsBinds) throws Exception{
-        for (SqlBinds sqlBind: preStatementsBinds) {
-            try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){
+                        List<SqlBinds> preStatementsBinds) throws Exception {
+        for (SqlBinds sqlBind : preStatementsBinds) {
+            try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) {
                 int result = pstmt.executeUpdate();
-                logger.info("pre statement execute result: {}, for sql: {}",result,sqlBind.getSql());
+                logger.info("pre statement execute result: {}, for sql: {}", result, sqlBind.getSql());
 
             }
         }
@@ -297,26 +319,25 @@ public class SqlTask extends AbstractTask {
      *
      * @param connection connection
      * @param postStatementsBinds postStatementsBinds
-     * @throws Exception
      */
     private void postSql(Connection connection,
-                         List<SqlBinds> postStatementsBinds) throws Exception{
-        for (SqlBinds sqlBind: postStatementsBinds) {
-            try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){
+                         List<SqlBinds> postStatementsBinds) throws Exception {
+        for (SqlBinds sqlBind : postStatementsBinds) {
+            try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) {
                 int result = pstmt.executeUpdate();
-                logger.info("post statement execute result: {},for sql: {}",result,sqlBind.getSql());
+                logger.info("post statement execute result: {},for sql: {}", result, sqlBind.getSql());
             }
         }
     }
+
     /**
      * create temp function
      *
      * @param connection connection
      * @param createFuncs createFuncs
-     * @throws Exception
      */
     private void createTempFunction(Connection connection,
-                                    List<String> createFuncs) throws Exception{
+                                    List<String> createFuncs) throws Exception {
         try (Statement funcStmt = connection.createStatement()) {
             for (String createFunc : createFuncs) {
                 logger.info("hive create function sql: {}", createFunc);
@@ -324,14 +345,14 @@ public class SqlTask extends AbstractTask {
             }
         }
     }
-    
+
     /**
      * create connection
      *
      * @return connection
      * @throws Exception Exception
      */
-    private Connection createConnection() throws Exception{
+    private Connection createConnection() throws Exception {
         // if hive , load connection params if exists
         Connection connection = null;
         if (HIVE == DbType.valueOf(sqlParameters.getType())) {
@@ -345,7 +366,7 @@ public class SqlTask extends AbstractTask {
 
             connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(),
                     paramProp);
-        }else{
+        } else {
             connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(),
                     baseDataSource.getUser(),
                     baseDataSource.getPassword());
@@ -354,7 +375,7 @@ public class SqlTask extends AbstractTask {
     }
 
     /**
-     *  close jdbc resource
+     * close jdbc resource
      *
      * @param resultSet resultSet
      * @param pstmt pstmt
@@ -362,36 +383,37 @@ public class SqlTask extends AbstractTask {
      */
     private void close(ResultSet resultSet,
                        PreparedStatement pstmt,
-                       Connection connection){
-        if (resultSet != null){
+                       Connection connection) {
+        if (resultSet != null) {
             try {
                 resultSet.close();
             } catch (SQLException e) {
-                logger.error("close result set error : {}",e.getMessage(),e);
+                logger.error("close result set error : {}", e.getMessage(), e);
             }
         }
 
-        if (pstmt != null){
+        if (pstmt != null) {
             try {
                 pstmt.close();
             } catch (SQLException e) {
-                logger.error("close prepared statement error : {}",e.getMessage(),e);
+                logger.error("close prepared statement error : {}", e.getMessage(), e);
             }
         }
 
-        if (connection != null){
+        if (connection != null) {
             try {
                 connection.close();
             } catch (SQLException e) {
-                logger.error("close connection error : {}",e.getMessage(),e);
+                logger.error("close connection error : {}", e.getMessage(), e);
             }
         }
     }
 
     /**
      * preparedStatement bind
+     *
      * @param connection connection
-     * @param sqlBinds  sqlBinds
+     * @param sqlBinds sqlBinds
      * @return PreparedStatement
      * @throws Exception Exception
      */
@@ -400,11 +422,11 @@ public class SqlTask extends AbstractTask {
         boolean timeoutFlag = TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.FAILED ||
                 TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.WARNFAILED;
         PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql());
-        if(timeoutFlag){
+        if (timeoutFlag) {
             stmt.setQueryTimeout(taskExecutionContext.getTaskTimeout());
         }
         Map<Integer, Property> params = sqlBinds.getParamsMap();
-        if(params != null) {
+        if (params != null) {
             for (Map.Entry<Integer, Property> entry : params.entrySet()) {
                 Property prop = entry.getValue();
                 ParameterUtils.setInParameter(entry.getKey(), stmt, prop.getType(), prop.getValue());
@@ -416,23 +438,24 @@ public class SqlTask extends AbstractTask {
 
     /**
      * send mail as an attachment
-     * @param title     title
-     * @param content   content
+     *
+     * @param title title
+     * @param content content
      */
-    public void sendAttachment(String title,String content){
+    public void sendAttachment(String title, String content) {
 
         List<User> users = alertDao.queryUserByAlertGroupId(taskExecutionContext.getSqlTaskExecutionContext().getWarningGroupId());
 
         // receiving group list
         List<String> receiversList = new ArrayList<>();
-        for(User user:users){
+        for (User user : users) {
             receiversList.add(user.getEmail().trim());
         }
         // custom receiver
         String receivers = sqlParameters.getReceivers();
-        if (StringUtils.isNotEmpty(receivers)){
+        if (StringUtils.isNotEmpty(receivers)) {
             String[] splits = receivers.split(COMMA);
-            for (String receiver : splits){
+            for (String receiver : splits) {
                 receiversList.add(receiver.trim());
             }
         }
@@ -441,60 +464,62 @@ public class SqlTask extends AbstractTask {
         List<String> receiversCcList = new ArrayList<>();
         // Custom Copier
         String receiversCc = sqlParameters.getReceiversCc();
-        if (StringUtils.isNotEmpty(receiversCc)){
+        if (StringUtils.isNotEmpty(receiversCc)) {
             String[] splits = receiversCc.split(COMMA);
-            for (String receiverCc : splits){
+            for (String receiverCc : splits) {
                 receiversCcList.add(receiverCc.trim());
             }
         }
 
-        String showTypeName = sqlParameters.getShowType().replace(COMMA,"").trim();
-        if(EnumUtils.isValidEnum(ShowType.class,showTypeName)){
+        String showTypeName = sqlParameters.getShowType().replace(COMMA, "").trim();
+        if (EnumUtils.isValidEnum(ShowType.class, showTypeName)) {
             Map<String, Object> mailResult = MailUtils.sendMails(receiversList,
                     receiversCcList, title, content, ShowType.valueOf(showTypeName).getDescp());
-            if(!(boolean) mailResult.get(STATUS)){
+            if (!(boolean) mailResult.get(STATUS)) {
                 throw new RuntimeException("send mail failed!");
             }
-        }else{
-            logger.error("showType: {} is not valid "  ,showTypeName);
-            throw new RuntimeException(String.format("showType: %s is not valid ",showTypeName));
+        } else {
+            logger.error("showType: {} is not valid ", showTypeName);
+            throw new RuntimeException(String.format("showType: %s is not valid ", showTypeName));
         }
     }
 
     /**
      * regular expressions match the contents between two specified strings
-     * @param content           content
-     * @param rgex              rgex
-     * @param sqlParamsMap      sql params map
-     * @param paramsPropsMap    params props map
+     *
+     * @param content content
+     * @param rgex rgex
+     * @param sqlParamsMap sql params map
+     * @param paramsPropsMap params props map
      */
-    public void setSqlParamsMap(String content, String rgex, Map<Integer,Property> sqlParamsMap, Map<String,Property> paramsPropsMap){
+    public void setSqlParamsMap(String content, String rgex, Map<Integer, Property> sqlParamsMap, Map<String, Property> paramsPropsMap) {
         Pattern pattern = Pattern.compile(rgex);
         Matcher m = pattern.matcher(content);
         int index = 1;
         while (m.find()) {
 
             String paramName = m.group(1);
-            Property prop =  paramsPropsMap.get(paramName);
+            Property prop = paramsPropsMap.get(paramName);
 
-            sqlParamsMap.put(index,prop);
-            index ++;
+            sqlParamsMap.put(index, prop);
+            index++;
         }
     }
 
     /**
      * print replace sql
-     * @param content       content
-     * @param formatSql     format sql
-     * @param rgex          rgex
-     * @param sqlParamsMap  sql params map
+     *
+     * @param content content
+     * @param formatSql format sql
+     * @param rgex rgex
+     * @param sqlParamsMap sql params map
      */
-    public void printReplacedSql(String content, String formatSql,String rgex, Map<Integer,Property> sqlParamsMap){
+    public void printReplacedSql(String content, String formatSql, String rgex, Map<Integer, Property> sqlParamsMap) {
         //parameter print style
-        logger.info("after replace sql , preparing : {}" , formatSql);
+        logger.info("after replace sql , preparing : {}", formatSql);
         StringBuilder logPrint = new StringBuilder("replaced sql , parameters:");
-        for(int i=1;i<=sqlParamsMap.size();i++){
-            logPrint.append(sqlParamsMap.get(i).getValue()+"("+sqlParamsMap.get(i).getType()+")");
+        for (int i = 1; i <= sqlParamsMap.size(); i++) {
+            logPrint.append(sqlParamsMap.get(i).getValue() + "(" + sqlParamsMap.get(i).getType() + ")");
         }
         logger.info("Sql Params are {}", logPrint);
     }