diff --git a/Cargo.toml b/Cargo.toml index 837d1dd..eb1ad2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,7 +89,7 @@ tokio = { version = "1.19", features = ["rt-multi-thread", "net", "macros", "tim rustls-pki-types = { version = "1.10" } rusqlite = { version = "0.36.0", features = ["column_decltype"] } ## for duckdb example -duckdb = { version = "1.0.0" } +duckdb = { version = "1"} ## for loading custom cert files rustls-pemfile = "2.0" diff --git a/examples/duckdb.rs b/examples/duckdb.rs index 514c88b..b8a4626 100644 --- a/examples/duckdb.rs +++ b/examples/duckdb.rs @@ -51,7 +51,7 @@ impl SimpleQueryHandler for DuckDBBackend { C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); - if query.to_uppercase().starts_with("SELECT") { + if is_result_query(query) { let mut stmt = conn .prepare(query) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -74,6 +74,15 @@ impl SimpleQueryHandler for DuckDBBackend { } } +fn is_result_query(query: &str) -> bool { + let query_upper = query.trim().to_uppercase(); + query_upper.starts_with("SELECT") + || query_upper.starts_with("WITH") + || query_upper.starts_with("EXPLAIN") + || query_upper.starts_with("DESCRIBE") + || query_upper.starts_with("FROM") +} + fn into_pg_type(df_type: &DataType) -> PgWireResult { Ok(match df_type { DataType::Null => Type::UNKNOWN, @@ -261,7 +270,7 @@ impl ExtendedQueryHandler for DuckDBBackend { .map(|f| f.as_ref()) .collect::>(); - if query.to_uppercase().starts_with("SELECT") { + if is_result_query(query) { let rows: Rows<'_> = stmt .query::<&[&dyn duckdb::ToSql]>(params_ref.as_ref()) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -288,9 +297,21 @@ impl ExtendedQueryHandler for DuckDBBackend { { let conn = self.conn.lock().unwrap(); let param_types = stmt.parameter_types.clone(); - let stmt = conn - .prepare_cached(&stmt.statement) + let stmt_sql = &stmt.statement; + let mut stmt = conn + .prepare_cached(stmt_sql) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + if is_result_query(stmt_sql) { + let _ = stmt + .query([]) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + } else { + let _ = stmt + .execute([]) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + } + row_desc_from_stmt(&stmt, &Format::UnifiedBinary) .map(|fields| DescribeStatementResponse::new(param_types, fields)) } @@ -304,9 +325,26 @@ impl ExtendedQueryHandler for DuckDBBackend { C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); - let stmt = conn + let mut stmt = conn .prepare_cached(&portal.statement.statement) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + let params = get_params(portal); + let params_ref = params + .iter() + .map(|f| f.as_ref()) + .collect::>(); + + if is_result_query(&portal.statement.statement) { + let _ = stmt + .query::<&[&dyn duckdb::ToSql]>(params_ref.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + } else { + let _ = stmt + .execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + } + row_desc_from_stmt(&stmt, &portal.result_column_format) .map(|fields| DescribePortalResponse::new(fields)) }