Forráskód Böngészése

[python] Add task sql (#6968)

* [python] Add task sql

* Add java gateway function doc
Jiajie Zhong 3 éve
szülő
commit
41e8836c91

+ 1 - 0
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py

@@ -70,6 +70,7 @@ class TaskType(str):
     SHELL = "SHELL"
     HTTP = "HTTP"
     PYTHON = "PYTHON"
+    SQL = "SQL"
 
 
 class DefaultTaskCodeNum(str):

+ 128 - 0
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/sql.py

@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Task sql."""
+
+import re
+from typing import Dict, Optional
+
+from pydolphinscheduler.constants import TaskType
+from pydolphinscheduler.core.task import Task, TaskParams
+from pydolphinscheduler.java_gateway import launch_gateway
+
+
+class SqlType:
+    """SQL type, for now it just contain `SELECT` and `NO_SELECT`."""
+
+    SELECT = 0
+    NOT_SELECT = 1
+
+
+class SqlTaskParams(TaskParams):
+    """Parameter only for Sql task type."""
+
+    def __init__(
+        self,
+        type: str,
+        datasource: str,
+        sql: str,
+        sql_type: Optional[int] = SqlType.NOT_SELECT,
+        display_rows: Optional[int] = 10,
+        pre_statements: Optional[str] = None,
+        post_statements: Optional[str] = None,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.type = type
+        self.datasource = datasource
+        self.sql = sql
+        self.sql_type = sql_type
+        self.display_rows = display_rows
+        self.pre_statements = pre_statements or []
+        self.post_statements = post_statements or []
+
+
+class Sql(Task):
+    """Task SQL object, declare behavior for SQL task to dolphinscheduler.
+
+    It should run sql job in multiply sql lik engine, such as:
+    - ClickHouse
+    - DB2
+    - HIVE
+    - MySQL
+    - Oracle
+    - Postgresql
+    - Presto
+    - SQLServer
+    You provider datasource_name contain connection information, it decisions which
+    database type and database instance would run this sql.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        datasource_name: str,
+        sql: str,
+        pre_sql: Optional[str] = None,
+        post_sql: Optional[str] = None,
+        display_rows: Optional[int] = 10,
+        *args,
+        **kwargs
+    ):
+        self._sql = sql
+        self._datasource_name = datasource_name
+        self._datasource = {}
+        task_params = SqlTaskParams(
+            type=self.get_datasource_type(),
+            datasource=self.get_datasource_id(),
+            sql=sql,
+            sql_type=self.sql_type,
+            display_rows=display_rows,
+            pre_statements=pre_sql,
+            post_statements=post_sql,
+        )
+        super().__init__(name, TaskType.SQL, task_params, *args, **kwargs)
+
+    def get_datasource_type(self) -> str:
+        """Get datasource type from java gateway, a wrapper for :func:`get_datasource_info`."""
+        return self.get_datasource_info(self._datasource_name).get("type")
+
+    def get_datasource_id(self) -> str:
+        """Get datasource id from java gateway, a wrapper for :func:`get_datasource_info`."""
+        return self.get_datasource_info(self._datasource_name).get("id")
+
+    def get_datasource_info(self, name) -> Dict:
+        """Get datasource info from java gateway, contains datasource id, type, name."""
+        if self._datasource:
+            return self._datasource
+        else:
+            gateway = launch_gateway()
+            self._datasource = gateway.entry_point.getDatasourceInfo(name)
+            return self._datasource
+
+    @property
+    def sql_type(self) -> int:
+        """Judgement sql type, use regexp to check which type of the sql is."""
+        pattern_select_str = (
+            "^(?!(.* |)insert |(.* |)delete |(.* |)drop |(.* |)update |(.* |)alter ).*"
+        )
+        pattern_select = re.compile(pattern_select_str, re.IGNORECASE)
+        if pattern_select.match(self._sql) is None:
+            return SqlType.NOT_SELECT
+        else:
+            return SqlType.SELECT

+ 131 - 0
dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_sql.py

@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test Task Sql."""
+
+
+from unittest.mock import patch
+
+import pytest
+
+from pydolphinscheduler.tasks.sql import Sql, SqlType
+
+
+@patch(
+    "pydolphinscheduler.core.task.Task.gen_code_and_version",
+    return_value=(123, 1),
+)
+@patch(
+    "pydolphinscheduler.tasks.sql.Sql.get_datasource_info",
+    return_value=({"id": 1, "type": "mock_type"}),
+)
+def test_get_datasource_detail(mock_datasource, mock_code_version):
+    """Test :func:`get_datasource_type` and :func:`get_datasource_id` can return expect value."""
+    name = "test_get_sql_type"
+    datasource_name = "test_datasource"
+    sql = "select 1"
+    task = Sql(name, datasource_name, sql)
+    assert 1 == task.get_datasource_id()
+    assert "mock_type" == task.get_datasource_type()
+
+
+@pytest.mark.parametrize(
+    "sql, sql_type",
+    [
+        ("select 1", SqlType.SELECT),
+        (" select 1", SqlType.SELECT),
+        (" select 1 ", SqlType.SELECT),
+        (" select 'insert' ", SqlType.SELECT),
+        (" select 'insert ' ", SqlType.SELECT),
+        ("with tmp as (select 1) select * from tmp ", SqlType.SELECT),
+        ("insert into table_name(col1, col2) value (val1, val2)", SqlType.NOT_SELECT),
+        (
+            "insert into table_name(select, col2) value ('select', val2)",
+            SqlType.NOT_SELECT,
+        ),
+        ("update table_name SET col1=val1 where col1=val2", SqlType.NOT_SELECT),
+        ("update table_name SET col1='select' where col1=val2", SqlType.NOT_SELECT),
+        ("delete from table_name where id < 10", SqlType.NOT_SELECT),
+        ("delete from table_name where id < 10", SqlType.NOT_SELECT),
+        ("alter table table_name add column col1 int", SqlType.NOT_SELECT),
+    ],
+)
+@patch(
+    "pydolphinscheduler.core.task.Task.gen_code_and_version",
+    return_value=(123, 1),
+)
+@patch(
+    "pydolphinscheduler.tasks.sql.Sql.get_datasource_info",
+    return_value=({"id": 1, "type": "mock_type"}),
+)
+def test_get_sql_type(mock_datasource, mock_code_version, sql, sql_type):
+    """Test property sql_type could return correct type."""
+    name = "test_get_sql_type"
+    datasource_name = "test_datasource"
+    task = Sql(name, datasource_name, sql)
+    assert (
+        sql_type == task.sql_type
+    ), f"Sql {sql} expect sql type is {sql_type} but got {task.sql_type}"
+
+
+@patch(
+    "pydolphinscheduler.tasks.sql.Sql.get_datasource_info",
+    return_value=({"id": 1, "type": "MYSQL"}),
+)
+def test_sql_to_dict(mock_datasource):
+    """Test task sql function to_dict."""
+    code = 123
+    version = 1
+    name = "test_sql_dict"
+    command = "select 1"
+    datasource_name = "test_datasource"
+    expect = {
+        "code": code,
+        "name": name,
+        "version": 1,
+        "description": None,
+        "delayTime": 0,
+        "taskType": "SQL",
+        "taskParams": {
+            "type": "MYSQL",
+            "datasource": 1,
+            "sql": command,
+            "sqlType": SqlType.SELECT,
+            "displayRows": 10,
+            "preStatements": [],
+            "postStatements": [],
+            "localParams": [],
+            "resourceList": [],
+            "dependence": {},
+            "conditionResult": {"successNode": [""], "failedNode": [""]},
+            "waitStartTimeout": {},
+        },
+        "flag": "YES",
+        "taskPriority": "MEDIUM",
+        "workerGroup": "default",
+        "failRetryTimes": 0,
+        "failRetryInterval": 1,
+        "timeoutFlag": "CLOSE",
+        "timeoutNotifyStrategy": None,
+        "timeout": 0,
+    }
+    with patch(
+        "pydolphinscheduler.core.task.Task.gen_code_and_version",
+        return_value=(code, version),
+    ):
+        task = Sql(name, datasource_name, command)
+        assert task.to_dict() == expect

+ 31 - 0
dolphinscheduler-python/src/main/java/org/apache/dolphinscheduler/server/PythonGatewayServer.java

@@ -37,6 +37,7 @@ import org.apache.dolphinscheduler.common.enums.TaskDependType;
 import org.apache.dolphinscheduler.common.enums.UserType;
 import org.apache.dolphinscheduler.common.enums.WarningType;
 import org.apache.dolphinscheduler.common.utils.CodeGenerateUtils;
+import org.apache.dolphinscheduler.dao.entity.DataSource;
 import org.apache.dolphinscheduler.dao.entity.ProcessDefinition;
 import org.apache.dolphinscheduler.dao.entity.Project;
 import org.apache.dolphinscheduler.dao.entity.Queue;
@@ -44,6 +45,7 @@ import org.apache.dolphinscheduler.dao.entity.Schedule;
 import org.apache.dolphinscheduler.dao.entity.TaskDefinition;
 import org.apache.dolphinscheduler.dao.entity.Tenant;
 import org.apache.dolphinscheduler.dao.entity.User;
+import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
 import org.apache.dolphinscheduler.dao.mapper.ProcessDefinitionMapper;
 import org.apache.dolphinscheduler.dao.mapper.ProjectMapper;
 import org.apache.dolphinscheduler.dao.mapper.ScheduleMapper;
@@ -124,6 +126,9 @@ public class PythonGatewayServer extends SpringBootServletInitializer {
     @Autowired
     private ScheduleMapper scheduleMapper;
 
+    @Autowired
+    private DataSourceMapper dataSourceMapper;
+
     // TODO replace this user to build in admin user if we make sure build in one could not be change
     private final User dummyAdminUser = new User() {
         {
@@ -360,6 +365,32 @@ public class PythonGatewayServer extends SpringBootServletInitializer {
         }
     }
 
+    /**
+     * Get datasource by given datasource name. It return map contain datasource id, type, name.
+     * Useful in Python API create sql task which need datasource information.
+     *
+     * @param datasourceName   user who create or update schedule
+     */
+    public Map<String, Object> getDatasourceInfo(String datasourceName) {
+        Map<String, Object> result = new HashMap<>();
+        List<DataSource> dataSourceList = dataSourceMapper.queryDataSourceByName(datasourceName);
+        if (dataSourceList.size() > 1) {
+            String msg = String.format("Get more than one datasource by name %s", datasourceName);
+            logger.error(msg);
+            throw new IllegalArgumentException(msg);
+        } else if (dataSourceList.size() == 0) {
+            String msg = String.format("Can not find any datasource by name %s", datasourceName);
+            logger.error(msg);
+            throw new IllegalArgumentException(msg);
+        } else {
+            DataSource dataSource = dataSourceList.get(0);
+            result.put("id", dataSource.getId());
+            result.put("type", dataSource.getType().name());
+            result.put("name", dataSource.getName());
+        }
+        return result;
+    }
+
     @PostConstruct
     public void run() {
         GatewayServer server = new GatewayServer(this);