diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala index 8de227f9d07c..8b60f309ef6d 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.connect.client.jdbc import java.sql.{Array => _, _} +import org.apache.spark.sql.connect.client.SparkResult + class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { private var operationId: String = _ @@ -49,33 +51,51 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { } override def executeQuery(sql: String): ResultSet = { - checkOpen() - - val df = conn.spark.sql(sql) - val sparkResult = df.collectResult() - operationId = sparkResult.operationId - resultSet = new SparkConnectResultSet(sparkResult, this) - resultSet + val hasResultSet = execute(sql) + if (hasResultSet) { + assert(resultSet != null) + resultSet + } else { + throw new SQLException("The query does not produce a ResultSet.") + } } override def executeUpdate(sql: String): Int = { - checkOpen() - - val df = conn.spark.sql(sql) - val sparkResult = df.collectResult() - operationId = sparkResult.operationId - resultSet = null + val hasResultSet = execute(sql) + if (hasResultSet) { + // user are not expected to access the result set in this case, + // we must close it to avoid memory leak. + resultSet.close() + throw new SQLException("The query produces a ResultSet.") + } else { + assert(resultSet == null) + getUpdateCount + } + } - // always return 0 because affected rows is not supported yet - 0 + private def hasResultSet(sparkResult: SparkResult[_]): Boolean = { + // suppose this works in most cases + sparkResult.schema.length > 0 } override def execute(sql: String): Boolean = { checkOpen() - // always perform executeQuery and reture a ResultSet - executeQuery(sql) - true + // stmt can be reused to execute more than one queries, + // reset before executing new query + operationId = null + resultSet = null + + val df = conn.spark.sql(sql) + val sparkResult = df.collectResult() + operationId = sparkResult.operationId + if (hasResultSet(sparkResult)) { + resultSet = new SparkConnectResultSet(sparkResult, this) + true + } else { + sparkResult.close() + false + } } override def getResultSet: ResultSet = { @@ -123,8 +143,15 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { override def setCursorName(name: String): Unit = throw new SQLFeatureNotSupportedException - override def getUpdateCount: Int = - throw new SQLFeatureNotSupportedException + override def getUpdateCount: Int = { + checkOpen() + + if (resultSet != null) { + -1 + } else { + 0 // always return 0 because affected rows is not supported yet + } + } override def getMoreResults: Boolean = throw new SQLFeatureNotSupportedException diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala new file mode 100644 index 000000000000..8e3b616372d8 --- /dev/null +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.sql.{Array => _, _} + +import scala.util.Using + +import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} + +class SparkConnectStatementSuite extends ConnectFunSuite with RemoteSparkSession + with JdbcHelper with SQLHelper { + + override def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort" + + test("returned result set and update count of execute* methods") { + withTable("t1", "t2", "t3") { + withStatement { stmt => + // CREATE TABLE + assert(!stmt.execute("CREATE TABLE t1 (id INT) USING Parquet")) + assert(stmt.getUpdateCount === 0) + assert(stmt.getResultSet === null) + + var se = intercept[SQLException] { + stmt.executeQuery("CREATE TABLE t2 (id INT) USING Parquet") + } + assert(se.getMessage === "The query does not produce a ResultSet.") + + assert(stmt.executeUpdate("CREATE TABLE t3 (id INT) USING Parquet") === 0) + assert(stmt.getResultSet === null) + + // INSERT INTO + assert(!stmt.execute("INSERT INTO t1 VALUES (1)")) + assert(stmt.getUpdateCount === 0) + assert(stmt.getResultSet === null) + + se = intercept[SQLException] { + stmt.executeQuery("INSERT INTO t1 VALUES (1)") + } + assert(se.getMessage === "The query does not produce a ResultSet.") + + assert(stmt.executeUpdate("INSERT INTO t1 VALUES (1)") === 0) + assert(stmt.getResultSet === null) + + // SELECT + assert(stmt.execute("SELECT id FROM t1")) + assert(stmt.getUpdateCount === -1) + Using.resource(stmt.getResultSet) { rs => + assert(rs !== null) + } + + Using.resource(stmt.executeQuery("SELECT id FROM t1")) { rs => + assert(stmt.getUpdateCount === -1) + assert(rs !== null) + } + + se = intercept[SQLException] { + stmt.executeUpdate("SELECT id FROM t1") + } + assert(se.getMessage === "The query produces a ResultSet.") + } + } + } +}