diff --git a/.gitignore b/.gitignore index c382db2..a11a7cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .dart_tool target turso +*.g.dart \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8ee338a..9245f12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,6 +72,15 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -115,6 +124,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -124,6 +142,32 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -228,6 +272,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -248,6 +303,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -256,6 +312,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -604,12 +670,14 @@ checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" name = "libsqlite3_turso" version = "0.1.0" dependencies = [ + "futures-util", "num_cpus", "regex", "reqwest", "serde", "serde_json", "tokio", + "tokio-tungstenite", ] [[package]] @@ -781,6 +849,15 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -805,6 +882,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.3", +] + [[package]] name = "regex" version = "1.11.1" @@ -1032,6 +1138,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1137,6 +1254,26 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "thiserror" +version = "2.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b0949c3a6c842cbde3f1686d6eea5a010516deb7085f79db747562d4102f41e" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc5b44b4ab9c2fdd0e0512e6bece8388e214c0749f5862b114cc5b7a25daf227" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinystr" version = "0.8.1" @@ -1184,6 +1321,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "489a59b6730eda1b0171fcfda8b121f4bee2b35cba8645ca35c5f7ba3eb736c1" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.16" @@ -1267,6 +1418,30 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadc29d668c91fcc564941132e17b28a7ceb2f3ebf0b9dae3e03fd7a6748eb0d" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "native-tls", + "rand", + "sha1", + "thiserror", + "utf-8", +] + +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -1290,6 +1465,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -1302,6 +1483,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "want" version = "0.3.1" @@ -1637,6 +1824,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.6" diff --git a/Cargo.toml b/Cargo.toml index 2623fe4..28279fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,5 @@ serde_json = "1.0.134" tokio = { version = "1.42.0", features = ["rt-multi-thread"] } reqwest = { version = "0.12.9", features = ["json", "blocking", "gzip"] } num_cpus = "1.17.0" +tokio-tungstenite = { version = "0.27.0", features = ["native-tls"] } +futures-util = "0.3.31" diff --git a/README.md b/README.md new file mode 100644 index 0000000..646c12f --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# libsqlite-turso + +A rewrite for `libsqlite3.so` dynamic library that allows any SQLite client to seamlessly connect to [Turso](https://turso.tech/) or LibSQL databases — with zero client-side changes. + +This project provides drop-in `libsqlite3.so` support with automatic strategy selection depending on runtime context. + +## ✨ Features + +- ✅ Works with **any SQLite client** that uses `libsqlite3.so` +- 🔁 Supports both `Http` & `Websocket` protocol for LibSQL +- 🔌 No custom SQLite client logic or [Hrana](https://github.com/tursodatabase/libsql/blob/main/docs/HRANA_3_SPEC.md) knowledge required + +--- + +## 🔧 Setup + +### 1. Build the custom `libsqlite3.so` + +```bash +cargo build --release +``` + +### 2. Place `libsqlite3.so` in your system + +This project assumes `libsqlite3.so` is available at runtime. + +Place it in a standard library path (e.g., `/usr/lib`, or use `/usr/local/lib/`). + +--- + +## 🚀 Usage + +Use **any standard SQLite library** in your language/runtime — this project handles the dynamic strategy and connection logic under the hood. diff --git a/src/auth.rs b/src/auth.rs index 0093b28..18696c3 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,6 +1,6 @@ use std::{future::Future, pin::Pin}; -use crate::utils::TursoConfig; +use crate::transport::TursoConfig; pub trait DbAuthStrategy { fn resolve<'a>( diff --git a/src/lib.rs b/src/lib.rs index eb307de..a355e85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ use sqlite::{ }; use crate::{ - auth::{DbAuthStrategy, GlobeStrategy}, + auth::{DbAuthStrategy, EnvVarStrategy, GlobeStrategy}, utils::{ count_parameters, execute_async_task, get_tokio, is_aligned, sql_is_begin_transaction, sql_is_commit, sql_is_pragma, sql_is_rollback, @@ -22,8 +22,8 @@ use crate::{ }; mod auth; -mod proxy; mod sqlite; +mod transport; mod utils; #[no_mangle] @@ -59,25 +59,30 @@ pub unsafe extern "C" fn sqlite3_open_v2( return SQLITE_ERROR; } - let filename = CStr::from_ptr(filename).to_str().unwrap(); - if filename.contains(":memory") { + let db_name = CStr::from_ptr(filename).to_str().unwrap(); + if db_name.contains(":memory") { return SQLITE_CANTOPEN; } - let reqwest_client = reqwest::Client::builder() - .user_agent("libsqlite3_turso/1.0.0") - .timeout(std::time::Duration::from_secs(30)) - .build() - .unwrap(); - - let auth_strategy = Box::new(GlobeStrategy); - let turso_config = get_tokio().block_on(auth_strategy.resolve(filename, &reqwest_client)); - if turso_config.is_err() { + // Check if running in Globe environment + let auth_strategy: Box = { + let is_globe_env = std::env::var("GLOBE") + .and_then(|v| Ok(v == "1")) + .unwrap_or(false); + if is_globe_env { + Box::new(GlobeStrategy) + } else { + Box::new(EnvVarStrategy) + } + }; + let connection = + get_tokio().block_on(transport::DatabaseConnection::open(db_name, auth_strategy)); + if connection.is_err() { return SQLITE_CANTOPEN; } let mock_db = Box::into_raw(Box::new(SQLite3 { - client: reqwest_client, + connection: connection.unwrap(), error_stack: Mutex::new(vec![]), transaction_baton: Mutex::new(None), last_insert_rowid: Mutex::new(None), @@ -85,7 +90,6 @@ pub unsafe extern "C" fn sqlite3_open_v2( delete_hook: Mutex::new(None), insert_hook: Mutex::new(None), update_hook: Mutex::new(None), - turso_config: turso_config.unwrap(), })); *db = mock_db; diff --git a/src/proxy.rs b/src/proxy.rs deleted file mode 100644 index 7583460..0000000 --- a/src/proxy.rs +++ /dev/null @@ -1,306 +0,0 @@ -use reqwest::Client; -use serde::Deserialize; -use std::{collections::HashMap, error::Error, time::Duration}; - -use crate::{ - sqlite::{SQLite3, SqliteError, Value, SQLITE_ERROR}, - utils::TursoConfig, -}; - -#[derive(Debug, Deserialize)] -pub struct RemoteSqliteResponse { - pub baton: Option, - pub results: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct RemoteSQliteResultType { - pub response: RemoteSQLiteResult, -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum RemoteSQLiteResult { - Execute { result: QueryResult }, - Error { message: String, code: String }, - Close, -} - -#[derive(Debug, Deserialize, Clone)] -pub struct RemoteCol { - pub name: String, -} - -#[derive(Debug, Deserialize, Clone)] -pub struct RemoteRow { - pub r#type: String, - pub value: Option, -} - -#[derive(Debug, Deserialize, Clone)] -pub struct QueryResult { - pub cols: Vec, - pub rows: Vec>, - pub last_insert_rowid: Option, -} - -pub async fn execute_sql_and_params( - db: &SQLite3, - sql: &str, - params: Vec, - baton: Option<&String>, -) -> Result { - let mut query_request = serde_json::Map::new(); - - let mut json_array: Vec = Vec::new(); - - json_array.push(serde_json::json!({ - "type": "execute", - "stmt": { - "sql": sql, - "args": params - } - })); - - if db.has_began_transaction() { - query_request.insert("baton".to_string(), serde_json::json!(baton)); - } else { - json_array.push(serde_json::json!({ - "type": "close" - })); - } - - query_request.insert("requests".to_string(), json_array.into()); - - let result = send_sql_request( - &db.client, - &db.turso_config, - serde_json::Value::from(query_request), - ) - .await; - - if let Err(e) = result { - return Err(SqliteError::new(e.to_string(), Some(SQLITE_ERROR))); - } - - Ok(result.unwrap()) -} - -pub async fn get_transaction_baton( - client: &Client, - config: &TursoConfig, - sql: &str, -) -> Result { - let request = serde_json::json!({ - "requests": [ - { - "type": "execute", - "stmt": { - "sql": sql - } - } - ] - }); - - let result = send_sql_request(client, config, request).await; - if let Err(e) = result { - return Err(SqliteError::new( - format!("Failed to get transaction baton: {}", e), - Some(SQLITE_ERROR), - )); - } - let result = result.unwrap(); - let baton = result.baton.ok_or(SqliteError::new( - "Failed to get transaction baton", - Some(SQLITE_ERROR), - ))?; - - Ok(baton) -} - -async fn send_sql_request( - client: &Client, - config: &TursoConfig, - request: serde_json::Value, -) -> Result> { - if cfg!(debug_assertions) { - println!("Sending SQL Request: {}", request); - } - - let response: serde_json::Value = - send_remote_request(client, config, "v2/pipeline", request).await?; - - let parsed: RemoteSqliteResponse = serde_json::from_value(response)?; - Ok(parsed) -} - -pub async fn send_remote_request( - client: &Client, - turso_config: &TursoConfig, - path: &str, - request: serde_json::Value, -) -> Result> { - const MAX_ATTEMPTS: usize = 5; - let mut last_error = String::new(); - - for attempt in 1..=MAX_ATTEMPTS { - if cfg!(debug_assertions) { - println!( - "Attempt {}: Sending request to {}", - attempt, turso_config.db_url - ); - } - - let resp = client - .post(format!("https://{}/{}", turso_config.db_url, path)) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", turso_config.db_token)) - .json(&request) - .send() - .await; - - let resp = match resp { - Ok(r) => r, - Err(e) => { - last_error = format!("Request failed: {}", e); - if attempt < MAX_ATTEMPTS { - tokio::time::sleep(Duration::from_millis(100)).await; - continue; - } else { - return Err(last_error.into()); - } - } - }; - - let status = resp.status(); - let text = match resp.text().await { - Ok(t) => t, - Err(e) => { - last_error = format!("Failed to read response body: {}", e); - if attempt < MAX_ATTEMPTS { - tokio::time::sleep(Duration::from_millis(100)).await; - continue; - } else { - return Err(last_error.into()); - } - } - }; - - if cfg!(debug_assertions) { - println!("Response received, status: {} : {}", status, text); - } - - if !status.is_success() { - if let Ok(err_json) = serde_json::from_str::(&text) { - if let Some(msg) = err_json.get("error").and_then(|v| v.as_str()) { - last_error = format!("API error: {}", msg); - } else { - last_error = format!("HTTP error {}: {}", status, text); - } - } else { - last_error = format!("HTTP error {} with invalid JSON: {}", status, text); - } - - if attempt < MAX_ATTEMPTS { - tokio::time::sleep(Duration::from_millis(100)).await; - continue; - } else { - return Err(last_error.into()); - } - } - - let parsed: serde_json::Value = match serde_json::from_str(&text) { - Ok(v) => v, - Err(e) => return Err(format!("Failed to parse JSON: {}", e).into()), - }; - - // Check for embedded DB errors - if let Some(results) = parsed.get("results").and_then(|r| r.as_array()) { - for result in results { - if let Some(msg) = result - .get("error") - .and_then(|e| e.get("message")) - .and_then(|m| m.as_str()) - { - return Err(msg.to_string().into()); - } - } - } - - return Ok(parsed); - } - - Err(format!( - "Failed to get successful response after {} attempts: {}", - MAX_ATTEMPTS, last_error - ) - .into()) -} - -pub fn convert_params_to_json(params: &HashMap) -> Vec { - let mut index_value_pairs: Vec<_> = params.iter().collect(); - // Sort by parameter index - index_value_pairs.sort_by_key(|&(k, _)| *k); - - // Map sorted values to JSON - index_value_pairs - .into_iter() - .map(|(_, value)| match value { - Value::Integer(i) => serde_json::json!({ - "type": "integer", - "value": *i.to_string() - }), - - Value::Real(f) => serde_json::json!({ - "type": "float", - "value": *f.to_string() - }), - Value::Text(s) => serde_json::json!({ - "type": "text", - "value": s - }), - Value::Null => serde_json::json!({ - "type": "null", - "value": null - }), - }) - .collect() -} - -pub fn get_execution_result<'a>( - db: &SQLite3, - result: &'a RemoteSqliteResponse, -) -> Result<&'a QueryResult, SqliteError> { - let mut baton = db.transaction_baton.lock().unwrap(); - - if let Some(new_baton) = &result.baton { - baton.replace(new_baton.into()); - } - - let first_execution_result = match result.results.get(0) { - Some(inner) => match &inner.response { - RemoteSQLiteResult::Error { message, code } => { - return Err(SqliteError::new( - format!("Remote SQLite error (code {}): {}", code, message), - Some(SQLITE_ERROR), - )); - } - RemoteSQLiteResult::Execute { result } => Ok(result), - RemoteSQLiteResult::Close => Err::<&'a QueryResult, SqliteError>(SqliteError::new( - "Remote SQLite closed the connection unexpectedly", - None, - )), - }, - None => Err::<&'a QueryResult, SqliteError>(SqliteError::new( - "No results returned from remote SQLite", - None, - )), - }?; - - if let Some(last_insert_rowid) = &first_execution_result.last_insert_rowid { - let mut last_insert_rowid_lock = db.last_insert_rowid.lock().unwrap(); - *last_insert_rowid_lock = Some(last_insert_rowid.parse::().unwrap_or(0)); - } - - Ok(first_execution_result) -} diff --git a/src/sqlite.rs b/src/sqlite.rs index 0d07fd1..fb4d0b4 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -7,15 +7,12 @@ use std::{ }; use crate::{ - proxy::{ - convert_params_to_json, execute_sql_and_params, get_execution_result, get_transaction_baton, - }, - utils::TursoConfig, + transport::{self, RemoteSqliteResponse}, + utils::{convert_params_to_json, get_execution_result}, }; pub const SQLITE_OK: c_int = 0; pub const SQLITE_ERROR: c_int = 1; -pub const SQLITE_INTERNAL: c_int = 2; pub const SQLITE_MISUSE: c_int = 21; pub const SQLITE_ROW: c_int = 100; pub const SQLITE_DONE: c_int = 101; @@ -81,15 +78,14 @@ impl fmt::Display for SqliteError { #[repr(C)] pub struct SQLite3 { - pub client: reqwest::Client, // HTTP client for making requests - pub last_insert_rowid: Mutex>, // Last inserted row ID - pub error_stack: Mutex>, // Stack to store error messages - pub transaction_baton: Mutex>, // Baton for transaction management - pub transaction_has_began: Mutex, // Flag to check if a transaction has started + pub connection: transport::DatabaseConnection, // Connection to the database + pub last_insert_rowid: Mutex>, // Last inserted row ID + pub error_stack: Mutex>, // Stack to store error messages + pub transaction_baton: Mutex>, // Baton for transaction management + pub transaction_has_began: Mutex, // Flag to check if a transaction has started pub update_hook: Mutex>, // Update hook callback pub insert_hook: Mutex>, // Insert hook callback pub delete_hook: Mutex>, // Delete hook callback - pub turso_config: TursoConfig, // Configuration for Turso } impl SQLite3 { @@ -282,7 +278,7 @@ pub async fn begin_tnx_on_db(db: *mut SQLite3, sql: &str) -> Result Result Result Result { - let db: &SQLite3 = unsafe { &*stmt.db }; - let baton_str = { - let baton = db.transaction_baton.lock().unwrap(); - baton.as_ref().map(|s| s.as_str()).map(|s| s.to_owned()) - }; + let db: &mut SQLite3 = unsafe { &mut *stmt.db }; let params = convert_params_to_json(&stmt.params); - let response = execute_sql_and_params(db, &stmt.sql, params, baton_str.as_ref()).await?; + let response = execute_sql_and_params(db, &stmt.sql, params).await?; let response = get_execution_result(db, &response)?; stmt.column_names = response.cols.iter().map(|col| col.name.clone()).collect(); @@ -359,3 +349,31 @@ pub async fn execute_stmt(stmt: &mut SQLite3PreparedStmt) -> Result, +) -> Result { + if let transport::ActiveStrategy::Websocket = db.connection.strategy { + let mut request = db.connection.get_json_request(db, sql, ¶ms); + match db.connection.send(&mut request).await { + Ok(response) => return Ok(response), + Err(err) => { + db.connection.strategy = transport::ActiveStrategy::Http; + if cfg!(debug_assertions) { + eprintln!("WebSocket failed, retrying with HTTP... {}", err); + } + } + } + } + + let request = &mut db.connection.get_json_request(db, sql, ¶ms); + let result = db.connection.send(request).await; + + if let Err(e) = result { + return Err(SqliteError::new(e.to_string(), Some(SQLITE_ERROR))); + } + + Ok(result.unwrap()) +} diff --git a/src/transport/http.rs b/src/transport/http.rs new file mode 100644 index 0000000..0890730 --- /dev/null +++ b/src/transport/http.rs @@ -0,0 +1,194 @@ +use std::{sync::Arc, time::Duration}; + +use crate::{ + sqlite::{SqliteError, SQLITE_ERROR}, + transport::{LibsqlInterface, RemoteSqliteResponse, TursoConfig}, +}; + +pub struct HttpStrategy { + client: reqwest::Client, + turso_config: Arc, +} + +impl HttpStrategy { + pub fn new(client: reqwest::Client, turso_config: Arc) -> Self { + Self { + client, + turso_config, + } + } +} + +impl LibsqlInterface for HttpStrategy { + async fn get_transaction_baton(&mut self, sql: &str) -> Result { + let mut request = serde_json::json!({ + "requests": [ + { + "type": "execute", + "stmt": { + "sql": sql + } + } + ] + }); + + let result = self.send(&mut request).await; + if let Err(e) = result { + return Err(SqliteError::new( + format!("Failed to get transaction baton: {}", e), + Some(SQLITE_ERROR), + )); + } + let result = result.unwrap(); + let baton = result.baton.ok_or(SqliteError::new( + "Failed to get transaction baton", + Some(SQLITE_ERROR), + ))?; + + Ok(baton) + } + + async fn send( + &mut self, + request: &mut serde_json::Value, + ) -> Result { + const MAX_ATTEMPTS: usize = 5; + let mut last_error = String::new(); + + for attempt in 1..=MAX_ATTEMPTS { + if cfg!(debug_assertions) { + println!( + "Attempt {}: Sending request to {}", + attempt, self.turso_config.db_url + ); + } + + let resp = self + .client + .post(format!("https://{}/v2/pipeline", self.turso_config.db_url)) + .header("Content-Type", "application/json") + .header( + "Authorization", + format!("Bearer {}", self.turso_config.db_token), + ) + .json(&request) + .send() + .await; + + let resp = match resp { + Ok(r) => r, + Err(e) => { + last_error = format!("Request failed: {}", e); + if attempt < MAX_ATTEMPTS { + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } else { + return Err(SqliteError::new(last_error, Some(SQLITE_ERROR))); + } + } + }; + + let status = resp.status(); + let text = match resp.text().await { + Ok(t) => t, + Err(e) => { + last_error = format!("Failed to read response body: {}", e); + if attempt < MAX_ATTEMPTS { + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } else { + return Err(SqliteError::new(last_error, Some(SQLITE_ERROR))); + } + } + }; + + if cfg!(debug_assertions) { + println!("Response received, status: {} : {}", status, text); + } + + if !status.is_success() { + if let Ok(err_json) = serde_json::from_str::(&text) { + if let Some(msg) = err_json.get("error").and_then(|v| v.as_str()) { + last_error = format!("API error: {}", msg); + } else { + last_error = format!("HTTP error {}: {}", status, text); + } + } else { + last_error = format!("HTTP error {} with invalid JSON: {}", status, text); + } + + if attempt < MAX_ATTEMPTS { + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } else { + return Err(SqliteError::new(last_error, Some(SQLITE_ERROR))); + } + } + + let parsed: serde_json::Value = match serde_json::from_str(&text) { + Ok(v) => v, + Err(e) => { + return Err(SqliteError::new( + format!("Failed to parse JSON: {}", e), + Some(SQLITE_ERROR), + )) + } + }; + + // Check for embedded DB errors + if let Some(results) = parsed.get("results").and_then(|r| r.as_array()) { + for result in results { + if let Some(msg) = result + .get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + { + return Err(SqliteError::new(msg.to_string(), Some(SQLITE_ERROR))); + } + } + } + + let parsed: RemoteSqliteResponse = serde_json::from_value(parsed).map_err(|e| { + SqliteError::new( + format!("Failed to parse response: {}", e), + Some(SQLITE_ERROR), + ) + })?; + return Ok(parsed); + } + + Err(SqliteError::new(last_error, Some(SQLITE_ERROR))) + } + + fn get_json_request( + &self, + sql: &str, + params: &Vec, + baton: Option<&String>, + is_transacting: bool, + ) -> serde_json::Value { + let mut query_request = serde_json::Map::new(); + + let mut json_array: Vec = Vec::new(); + + json_array.push(serde_json::json!({ + "type": "execute", + "stmt": { + "sql": sql, + "args": params + } + })); + + if is_transacting { + query_request.insert("baton".to_string(), serde_json::json!(baton)); + } else { + json_array.push(serde_json::json!({ + "type": "close" + })); + } + + query_request.insert("requests".to_string(), json_array.into()); + + serde_json::Value::from(query_request) + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 0000000..42a3a3d --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,160 @@ +use std::sync::Arc; + +use serde::Deserialize; + +use crate::{ + auth::DbAuthStrategy, + sqlite::{SQLite3, SqliteError, SQLITE_CANTOPEN}, + transport::{http::HttpStrategy, wss::WebSocketStrategy}, +}; + +mod http; +mod wss; + +#[derive(Debug, Deserialize, Clone)] +pub struct TursoConfig { + pub db_url: String, + pub db_token: String, +} + +#[derive(Debug, Deserialize)] +pub struct RemoteSqliteResponse { + pub baton: Option, + pub results: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct RemoteSQliteResultType { + pub response: RemoteSQLiteResult, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum RemoteSQLiteResult { + Execute { result: QueryResult }, + Error { message: String, code: String }, + Close, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct RemoteCol { + pub name: String, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct RemoteRow { + pub r#type: String, + pub value: Option, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct QueryResult { + pub cols: Vec, + pub rows: Vec>, + pub last_insert_rowid: Option, +} + +pub trait LibsqlInterface { + fn get_json_request( + &self, + sql: &str, + params: &Vec, + baton: Option<&String>, + is_transacting: bool, + ) -> serde_json::Value; + + async fn get_transaction_baton(&mut self, sql: &str) -> Result; + + async fn send( + &mut self, + request: &mut serde_json::Value, + ) -> Result; +} + +#[derive(PartialEq)] +pub enum ActiveStrategy { + Http, + Websocket, +} + +pub struct DatabaseConnection { + pub http: HttpStrategy, + pub websocket: WebSocketStrategy, + pub strategy: ActiveStrategy, +} + +impl DatabaseConnection { + pub async fn open(db_name: &str, auth: Box) -> Result { + let reqwest_client = reqwest::Client::builder() + .user_agent("libsqlite3_turso/1.0.0") + .timeout(std::time::Duration::from_secs(30)) + .build() + .unwrap(); + + let turso_config = auth.resolve(db_name, &reqwest_client).await; + if turso_config.is_err() { + let error = turso_config.unwrap_err(); + return Err(SqliteError::new(error.to_string(), Some(SQLITE_CANTOPEN))); + } + + let turso_config = Arc::new(turso_config.unwrap()); + + let http = HttpStrategy::new(reqwest_client, turso_config.clone()); + let mut websocket = WebSocketStrategy::new(turso_config.clone()); + + websocket.connect().await?; + + if cfg!(debug_assertions) { + println!("WebSocket connection established for {}", db_name); + } + + Ok(Self { + http, + websocket, + strategy: ActiveStrategy::Websocket, + }) + } + + pub async fn get_transaction_baton(&mut self, sql: &str) -> Result { + match self.strategy { + ActiveStrategy::Http => self.http.get_transaction_baton(sql).await, + ActiveStrategy::Websocket => self.websocket.get_transaction_baton(sql).await, + } + } + + pub async fn send( + &mut self, + mut request: &mut serde_json::Value, + ) -> Result { + match self.strategy { + ActiveStrategy::Http => self.http.send(&mut request).await, + ActiveStrategy::Websocket => self.websocket.send(&mut request).await, + } + } + + pub fn get_json_request( + &self, + db: &SQLite3, + sql: &str, + params: &Vec, + ) -> serde_json::Value { + let baton_str = { + let baton = db.transaction_baton.lock().unwrap(); + baton.as_ref().map(|s| s.as_str()).map(|s| s.to_owned()) + }; + let has_begun_transaction = db.has_began_transaction(); + + match self.strategy { + ActiveStrategy::Http => { + self.http + .get_json_request(sql, params, baton_str.as_ref(), has_begun_transaction) + } + ActiveStrategy::Websocket => self.websocket.get_json_request( + sql, + params, + baton_str.as_ref(), + has_begun_transaction, + ), + } + } +} diff --git a/src/transport/wss.rs b/src/transport/wss.rs new file mode 100644 index 0000000..fe6a857 --- /dev/null +++ b/src/transport/wss.rs @@ -0,0 +1,390 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +use crate::{ + sqlite::{SqliteError, SQLITE_ERROR}, + transport::{ + LibsqlInterface, RemoteSQLiteResult, RemoteSQliteResultType, RemoteSqliteResponse, + TursoConfig, + }, + utils::get_tokio, +}; +use futures_util::{sink::SinkExt, stream::SplitSink, StreamExt}; +use tokio_tungstenite::{ + tungstenite::{Message, Utf8Bytes}, + MaybeTlsStream, WebSocketStream, +}; + +use serde_json::Value; +use tokio::{ + net::TcpStream, + sync::{oneshot, Mutex}, +}; + +static REQUEST_ID: AtomicU64 = AtomicU64::new(1); +static STREAM_ID: AtomicU64 = AtomicU64::new(1); + +#[derive(PartialEq)] +enum WebSocketConnState { + Connected, + Disconnected, +} + +pub struct WebSocketStrategy { + turso_config: Arc, + bus: ResponseBus, + websocket_handle: Option< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + websocket_state: Arc>, +} + +impl WebSocketStrategy { + pub fn new(turso_config: Arc) -> Self { + Self { + turso_config, + bus: ResponseBus::new(), + websocket_handle: None, + websocket_state: Arc::new(Mutex::new(WebSocketConnState::Disconnected)), + } + } + + fn next_request_id() -> i32 { + REQUEST_ID.fetch_add(1, Ordering::Relaxed) as i32 + } + + fn next_stream_id() -> i32 { + STREAM_ID.fetch_add(1, Ordering::Relaxed) as i32 + } + + async fn open_stream(&mut self) -> Result<(i32, ResponseBus), SqliteError> { + let (writer, bus) = self.get_client().await?; + let request_id = WebSocketStrategy::next_request_id(); + let stream_id = WebSocketStrategy::next_stream_id(); + + let request = serde_json::json!({ + "type": "request", + "request_id": request_id, + "request": { + "type": "open_stream", + "stream_id": stream_id, + } + }); + + writer + .send(Message::Text(Utf8Bytes::from(request.to_string()))) + .await + .map_err(|e| { + SqliteError::new( + format!("Failed to send open_stream message: {}", e), + Some(SQLITE_ERROR), + ) + })?; + + let id = format!("request_id:{}", request_id); + bus.wait_for(id.as_str()).await?; + + Ok((stream_id, bus)) + } + + pub async fn connect(&mut self) -> Result<(), SqliteError> { + let url = format!("wss://{}", self.turso_config.db_url); + if cfg!(debug_assertions) { + println!("Connecting to WebSocket at {}", url); + } + + let (socket, _) = tokio_tungstenite::connect_async(url).await.map_err(|e| { + SqliteError::new( + format!("Failed to connect to WebSocket: {}", e), + Some(SQLITE_ERROR), + ) + })?; + let (mut writer, mut reader) = socket.split(); + + let bus = self.bus.clone(); + let websocket_state = self.websocket_state.clone(); + + get_tokio().spawn(async move { + while let Some(message) = reader.next().await { + match message { + Err(_) | Ok(Message::Close(_)) => { + let mut state = websocket_state.lock().await; + *state = WebSocketConnState::Disconnected; + break; + } + _ => (), + } + + let message = message.unwrap(); + let value: Value = match message { + Message::Text(text) => match serde_json::from_str(&text) { + Ok(value) => value, + Err(_) => { + continue; + } + }, + Message::Binary(binary) => { + match serde_json::from_slice(&binary) { + Ok(value) => value, + Err(e) => { + eprintln!( + "Failed to parse WebSocket binary message as JSON: {}", + e + ); + continue; + } + } + continue; + } + _ => { + eprintln!("Received unsupported WebSocket message type: {:?}", message); + continue; + } + }; + + // store the key and value as {key:value} + + let key = { + if value.get("request_id").is_some() { + format!( + "request_id:{}", + value.get("request_id").unwrap().as_i64().unwrap() + ) + } else if value.get("id").is_some() { + format!("id:{}", value.get("id").unwrap().as_i64().unwrap()) + } else { + format!("type:{}", value.get("type").unwrap().as_str().unwrap()) + } + }; + + bus.respond(key.as_str(), value.clone()).await; + } + }); + + let json = serde_json::json!({ + "type": "hello", + "jwt": self.turso_config.db_token, + }); + + writer + .send(tokio_tungstenite::tungstenite::Message::Text( + Utf8Bytes::from(json.to_string()), + )) + .await + .map_err(|e| { + SqliteError::new( + format!("Failed to send initial message over WebSocket: {}", e), + Some(SQLITE_ERROR), + ) + })?; + + let result = self.bus.wait_for("type:hello_ok").await; + if result.is_err() { + return Err(SqliteError::new( + "Failed to validate database URL & Token. Try again".to_string(), + Some(SQLITE_ERROR), + )); + } + + self.websocket_handle = Some(writer); + self.websocket_state = Arc::new(Mutex::new(WebSocketConnState::Connected)); + Ok(()) + } + + async fn get_client( + &mut self, + ) -> Result< + ( + &mut SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + ResponseBus, + ), + SqliteError, + > { + let state = self.websocket_state.lock().await; + let need_connect = + self.websocket_handle.is_none() || *state == WebSocketConnState::Disconnected; + drop(state); + + if need_connect { + self.connect().await?; + } + + Ok((self.websocket_handle.as_mut().unwrap(), self.bus.clone())) + } +} + +impl LibsqlInterface for WebSocketStrategy { + async fn get_transaction_baton(&mut self, sql: &str) -> Result { + let (stream_id, _) = self.open_stream().await?; + let mut request = serde_json::json!({ + "type": "execute", + "stream_id": stream_id, + "stmt": { + "sql": sql + } + }); + + let result = self.send(&mut request).await; + if let Err(e) = result { + return Err(SqliteError::new( + format!("Failed to get transaction baton: {}", e), + Some(SQLITE_ERROR), + )); + } + + Ok(stream_id.to_string()) + } + + async fn send( + &mut self, + request: &mut serde_json::Value, + ) -> Result { + if let WebSocketConnState::Disconnected = *self.websocket_state.lock().await { + return Err(SqliteError::new( + "WebSocket connection is disconnected".to_string(), + Some(SQLITE_ERROR), + )); + } + + let bus: ResponseBus; + + if request.get("stream_id").is_none() { + let (stream_id, actual_bus) = self.open_stream().await?; + request["stream_id"] = serde_json::Value::from(stream_id); + bus = actual_bus; + } else { + bus = self.bus.clone(); + } + + let request_id = WebSocketStrategy::next_request_id(); + let request = serde_json::json!({ + "type": "request", + "request_id": request_id, + "request": request + }); + + let writer = self.websocket_handle.as_mut().unwrap(); + + if cfg!(debug_assertions) { + println!("Sending request over WebSocket: {:?}", request); + } + + writer + .send(tokio_tungstenite::tungstenite::Message::Text( + Utf8Bytes::from(request.to_string()), + )) + .await + .unwrap(); + + let result = bus + .wait_for(format!("request_id:{}", request_id).as_str()) + .await?; + + let parsed: RemoteSQliteResultType = serde_json::from_value(result).map_err(|e| { + SqliteError::new( + format!("Failed to parse response: {}", e), + Some(SQLITE_ERROR), + ) + })?; + let result = parsed.response; + if let RemoteSQLiteResult::Error { message, code } = result { + return Err(SqliteError::new( + format!("Remote SQLite error (code {}): {}", code, message), + Some(SQLITE_ERROR), + )); + } + if let RemoteSQLiteResult::Close = result { + return Err(SqliteError::new( + "Remote SQLite closed the connection unexpectedly".to_string(), + None, + )); + } + + if let RemoteSQLiteResult::Execute { result } = result { + return Ok(RemoteSqliteResponse { + baton: None, + results: vec![RemoteSQliteResultType { + response: RemoteSQLiteResult::Execute { result }, + }], + }); + } + + Ok(RemoteSqliteResponse { + baton: None, + results: vec![], + }) + } + + fn get_json_request( + &self, + sql: &str, + params: &Vec, + stream_id: Option<&String>, + is_transacting: bool, + ) -> serde_json::Value { + let mut request = serde_json::json!({ + "type": "execute", + "stmt": { + "sql": sql, + "args": params + } + }); + + if is_transacting { + let stream_id: i32 = stream_id.and_then(|s| s.parse::().ok()).unwrap(); + request["stream_id"] = serde_json::json!(stream_id); + } + + request + } +} + +#[derive(Clone)] +struct ResponseBus { + map: Arc>>>, +} + +impl ResponseBus { + pub fn new() -> Self { + Self { + map: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub async fn wait_for(&self, id: &str) -> Result { + let (tx, rx) = oneshot::channel(); + self.map.lock().await.insert(id.to_string(), tx); + + match tokio::time::timeout(Duration::from_secs(10), rx).await { + Ok(result) => match result { + Ok(value) => Ok(value), + Err(_) => Err(SqliteError::new( + "Failed to receive response".to_string(), + Some(SQLITE_ERROR), + )), + }, + Err(_) => Err(SqliteError::new( + "Response timed out".to_string(), + Some(SQLITE_ERROR), + )), + } + } + + pub async fn respond(&self, id: &str, value: Value) { + if let Some(sender) = self.map.lock().await.remove(id) { + let _ = sender.send(value); + } + } +} diff --git a/src/utils.rs b/src/utils.rs index 5cfcd7d..6f86658 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,10 +1,12 @@ -use std::{ffi::c_int, sync::OnceLock}; +use std::{collections::HashMap, ffi::c_int, sync::OnceLock}; use regex::Regex; -use serde::Deserialize; use tokio::runtime::{self, Runtime}; -use crate::sqlite::{push_error, SQLite3, SqliteError, SQLITE_ERROR}; +use crate::{ + sqlite::{push_error, SQLite3, SqliteError, Value, SQLITE_ERROR}, + transport::{QueryResult, RemoteSQLiteResult, RemoteSqliteResponse}, +}; static RUNTIME: OnceLock = OnceLock::new(); @@ -52,12 +54,6 @@ where } } -#[derive(Debug, Deserialize)] -pub struct TursoConfig { - pub db_url: String, - pub db_token: String, -} - #[inline] pub fn sql_is_begin_transaction(sql: &String) -> bool { sql.starts_with("BEGIN") @@ -82,3 +78,71 @@ pub fn sql_is_commit(sql: &String) -> bool { pub fn is_aligned(ptr: *const T) -> bool { !ptr.is_null() && (ptr as usize) % std::mem::align_of::() == 0 } + +pub fn convert_params_to_json(params: &HashMap) -> Vec { + let mut index_value_pairs: Vec<_> = params.iter().collect(); + // Sort by parameter index + index_value_pairs.sort_by_key(|&(k, _)| *k); + + // Map sorted values to JSON + index_value_pairs + .into_iter() + .map(|(_, value)| match value { + Value::Integer(i) => serde_json::json!({ + "type": "integer", + "value": *i.to_string() + }), + + Value::Real(f) => serde_json::json!({ + "type": "float", + "value": *f.to_string() + }), + Value::Text(s) => serde_json::json!({ + "type": "text", + "value": s + }), + Value::Null => serde_json::json!({ + "type": "null", + "value": null + }), + }) + .collect() +} + +pub fn get_execution_result<'a>( + db: &SQLite3, + result: &'a RemoteSqliteResponse, +) -> Result<&'a QueryResult, SqliteError> { + let mut baton = db.transaction_baton.lock().unwrap(); + + if let Some(new_baton) = &result.baton { + baton.replace(new_baton.into()); + } + + let first_execution_result = match result.results.get(0) { + Some(inner) => match &inner.response { + RemoteSQLiteResult::Error { message, code } => { + return Err(SqliteError::new( + format!("Remote SQLite error (code {}): {}", code, message), + Some(SQLITE_ERROR), + )); + } + RemoteSQLiteResult::Execute { result } => Ok(result), + RemoteSQLiteResult::Close => Err::<&'a QueryResult, SqliteError>(SqliteError::new( + "Remote SQLite closed the connection unexpectedly", + None, + )), + }, + None => Err::<&'a QueryResult, SqliteError>(SqliteError::new( + "No results returned from remote SQLite", + None, + )), + }?; + + if let Some(last_insert_rowid) = &first_execution_result.last_insert_rowid { + let mut last_insert_rowid_lock = db.last_insert_rowid.lock().unwrap(); + *last_insert_rowid_lock = Some(last_insert_rowid.parse::().unwrap_or(0)); + } + + Ok(first_execution_result) +}