diff --git a/lib/index.ts b/lib/index.ts index 7821adf..4b00a64 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -47,6 +47,12 @@ const addon = bindings<{ databaseInitTokenizer(db: NativeDatabase): void; databaseExec(db: NativeDatabase, query: string): void; databaseClose(db: NativeDatabase): void; + databaseCreateFunction( + db: NativeDatabase, + name: string, + fn: (...args: ReadonlyArray) => void, + bigint: boolean, + ): void; signalTokenize(value: string): Array; @@ -116,6 +122,14 @@ export type RowType = Options extends { ? SqliteValue : Record>; +export type FunctionOptions = Readonly<{ + /** + * If `true` - all integers passed to the fucntion will be big + * integers instead of regular (floating-point) numbers. + */ + bigint?: boolean; +}>; + /** * A compiled SQL statement class. */ @@ -452,6 +466,35 @@ export default class Database { addon.databaseExec(this.#native, sql); } + /** + * Create custom SQL function with a given `name`. + * + * @param name - name of the function + * @param fn - function implementation + * @param options - function options. + */ + public createFunction( + name: string, + fn: (...args: ReadonlyArray) => void, + options: FunctionOptions = {}, + ): void { + if (this.#native === undefined) { + throw new Error('Database closed'); + } + if (typeof name !== 'string') { + throw new TypeError('Invalid name argument'); + } + if (typeof fn !== 'function') { + throw new TypeError('Invalid fn argument'); + } + addon.databaseCreateFunction( + this.#native, + name, + fn, + options.bigint === true, + ); + } + /** * Compile a single SQL statement. * diff --git a/src/addon.cc b/src/addon.cc index 818e562..1f5353e 100644 --- a/src/addon.cc +++ b/src/addon.cc @@ -74,6 +74,90 @@ static Napi::Value SignalTokenize(const Napi::CallbackInfo& info) { return result; } +// Functions + +class FunctionWrap { + public: + FunctionWrap(Napi::Function fn, bool is_bigint) : is_bigint_(is_bigint) { + fn_.Reset(fn, 1); + } + + static void Run(sqlite3_context* ctx, int argc, sqlite3_value** argv) { + auto wrap = static_cast(sqlite3_user_data(ctx)); + + wrap->Call(ctx, argc, argv); + } + + static void Final(void* p_app) { delete static_cast(p_app); } + + protected: + void Call(sqlite3_context* ctx, int argc, sqlite3_value** argv) { + auto env = fn_.Env(); + Napi::HandleScope scope(env); + + assert(argc >= 0); + + auto args = std::vector(static_cast(argc)); + + for (int i = 0; i < argc; i++) { + args[i] = TranslateValue(argv[i]); + } + + auto result = fn_.Value().Call(args); + + // Ignore exceptions + if (result.IsEmpty()) { + auto e = env.GetAndClearPendingException(); + sqlite3_result_error(ctx, e.Message().c_str(), SQLITE_ERROR); + } else if (result.IsUndefined()) { + sqlite3_result_null(ctx); + } else { + sqlite3_result_error(ctx, "Function must not return a value", + SQLITE_ERROR); + } + } + + Napi::Value TranslateValue(sqlite3_value* value) { + auto env = fn_.Env(); + int type = sqlite3_value_type(value); + switch (type) { + case SQLITE_INTEGER: { + auto val = sqlite3_value_int64(value); + if (is_bigint_) { + return Napi::BigInt::New(env, static_cast(val)); + } + if (static_cast(INT32_MIN) <= val && + val <= static_cast(INT32_MAX)) { + napi_value n_value; + NAPI_THROW_IF_FAILED( + env, napi_create_int32(env, static_cast(val), &n_value), + Napi::Value()); + return Napi::Value(env, n_value); + } else { + return Napi::Number::New(env, val); + } + } + case SQLITE_TEXT: + return Napi::String::New( + env, reinterpret_cast(sqlite3_value_text(value)), + sqlite3_value_bytes(value)); + case SQLITE_FLOAT: + return Napi::Number::New(env, sqlite3_value_double(value)); + case SQLITE_BLOB: + return Napi::Buffer::Copy( + env, reinterpret_cast(sqlite3_value_blob(value)), + sqlite3_value_bytes(value)); + case SQLITE_NULL: + return env.Null(); + } + return Napi::Value(); + } + + private: + Napi::Reference fn_; + bool is_bigint_; +}; + // Global Settings thread_local Napi::Reference logger_fn_; @@ -150,6 +234,8 @@ Napi::Object Database::Init(Napi::Env env, Napi::Object exports) { Napi::Function::New(env, &Database::InitTokenizer); exports["databaseClose"] = Napi::Function::New(env, &Database::Close); exports["databaseExec"] = Napi::Function::New(env, &Database::Exec); + exports["databaseCreateFunction"] = + Napi::Function::New(env, &Database::CreateFunction); return exports; } @@ -289,6 +375,42 @@ Napi::Value Database::Exec(const Napi::CallbackInfo& info) { return Napi::Value(); } +Napi::Value Database::CreateFunction(const Napi::CallbackInfo& info) { + auto env = info.Env(); + + auto db = FromExternal(info[0]); + auto name = info[1].As(); + auto fn = info[2].As(); + auto is_bigint = info[3].As(); + + assert(name.IsString()); + assert(fn.IsFunction()); + assert(is_bigint.IsBoolean()); + + if (db == nullptr) { + return Napi::Value(); + } + + if (db->handle_ == nullptr) { + NAPI_THROW(Napi::Error::New(env, "Database closed"), Napi::Value()); + } + + auto name_utf8 = name.Utf8Value(); + + auto fn_wrap = new FunctionWrap(fn, is_bigint); + + int r = sqlite3_create_function_v2(db->handle_, name_utf8.c_str(), -1, + SQLITE_UTF8, // TODO(indutny): or UTF16? + fn_wrap, FunctionWrap::Run, nullptr, + nullptr, FunctionWrap::Final); + + if (r != SQLITE_OK) { + delete fn_wrap; + return db->ThrowSqliteError(env, r); + } + return Napi::Value(); +} + Napi::Value Database::ThrowSqliteError(Napi::Env env, int error) { assert(handle_ != nullptr); const char* msg = sqlite3_errmsg(handle_); diff --git a/src/addon.h b/src/addon.h index 2d76f4b..800aff5 100644 --- a/src/addon.h +++ b/src/addon.h @@ -30,6 +30,7 @@ class Database { static Napi::Value InitTokenizer(const Napi::CallbackInfo& info); static Napi::Value Close(const Napi::CallbackInfo& info); static Napi::Value Exec(const Napi::CallbackInfo& info); + static Napi::Value CreateFunction(const Napi::CallbackInfo& info); fts5_api* GetFTS5API(Napi::Env env); diff --git a/test/memory.test.ts b/test/memory.test.ts index 4d4e83c..531c251 100644 --- a/test/memory.test.ts +++ b/test/memory.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, test, beforeEach, afterEach } from 'vitest'; +import { describe, expect, test, beforeEach, afterEach, vi } from 'vitest'; import Database, { setLogger } from '../lib/index.js'; @@ -497,3 +497,51 @@ describe('statement cache', () => { ); }); }); + +describe('custom function', () => { + let fnDb: Database; + let fn: ReturnType; + let bigFn: ReturnType; + beforeEach(() => { + fnDb = new Database(':memory:'); + + fn = vi.fn(); + fnDb.createFunction('fn', fn); + + bigFn = vi.fn(); + fnDb.createFunction('bigFn', bigFn, { + bigint: true, + }); + }); + + afterEach(() => { + fnDb.close(); + }); + + test('it calls the function without args', () => { + fnDb.exec(`SELECT fn()`); + expect(fn).toHaveBeenCalledWith(); + }); + + test('it calls the function with multiple args', () => { + fnDb.exec(`SELECT fn(1, '123', NULL)`); + expect(fn).toHaveBeenCalledWith(1, '123', null); + }); + + test('it calls the function with blob', () => { + fnDb.exec(`SELECT fn(x'abba')`); + expect(fn).toHaveBeenCalledWith(Buffer.from('abba', 'hex')); + }); + + test('it uses bigints when configured', () => { + fnDb.exec(`SELECT bigFn(123456)`); + expect(bigFn).toHaveBeenCalledWith(123456n); + }); + + test('it throws when function returns a value', () => { + fnDb.createFunction('intFn', () => { + return 1; + }); + expect(() => fnDb.exec(`SELECT intFn()`)).toThrowError('SQLITE_ERROR'); + }); +});