Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ const addon = bindings<{
databaseInitTokenizer(db: NativeDatabase): void;
databaseExec(db: NativeDatabase, query: string): void;
databaseClose(db: NativeDatabase): void;
databaseCreateFunction(
db: NativeDatabase,
name: string,
fn: (...args: ReadonlyArray<unknown>) => void,
bigint: boolean,
): void;

signalTokenize(value: string): Array<string>;

Expand Down Expand Up @@ -116,6 +122,14 @@ export type RowType<Options extends StatementOptions> = Options extends {
? SqliteValue<Options>
: Record<string, SqliteValue<Options>>;

export type FunctionOptions = Readonly<{
/**
* If `true` - all integers passed to the fucntion will be big
* integers instead of regular (floating-point) numbers.
*/
bigint?: boolean;
}>;

/**
* A compiled SQL statement class.
*/
Expand Down Expand Up @@ -452,6 +466,35 @@ export default class Database {
addon.databaseExec(this.#native, sql);
}

/**
* Create custom SQL function with a given `name`.
*
* @param name - name of the function
* @param fn - function implementation
* @param options - function options.
*/
public createFunction(
name: string,
fn: (...args: ReadonlyArray<unknown>) => void,
options: FunctionOptions = {},
): void {
if (this.#native === undefined) {
throw new Error('Database closed');
}
if (typeof name !== 'string') {
throw new TypeError('Invalid name argument');
}
if (typeof fn !== 'function') {
throw new TypeError('Invalid fn argument');
}
addon.databaseCreateFunction(
this.#native,
name,
fn,
options.bigint === true,
);
}

/**
* Compile a single SQL statement.
*
Expand Down
122 changes: 122 additions & 0 deletions src/addon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,90 @@ static Napi::Value SignalTokenize(const Napi::CallbackInfo& info) {
return result;
}

// Functions

class FunctionWrap {
public:
FunctionWrap(Napi::Function fn, bool is_bigint) : is_bigint_(is_bigint) {
fn_.Reset(fn, 1);
}

static void Run(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
auto wrap = static_cast<FunctionWrap*>(sqlite3_user_data(ctx));

wrap->Call(ctx, argc, argv);
}

static void Final(void* p_app) { delete static_cast<FunctionWrap*>(p_app); }

protected:
void Call(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
auto env = fn_.Env();
Napi::HandleScope scope(env);

assert(argc >= 0);

auto args = std::vector<Napi::Value>(static_cast<size_t>(argc));

for (int i = 0; i < argc; i++) {
args[i] = TranslateValue(argv[i]);
}

auto result = fn_.Value().Call(args);

// Ignore exceptions
if (result.IsEmpty()) {
auto e = env.GetAndClearPendingException();
sqlite3_result_error(ctx, e.Message().c_str(), SQLITE_ERROR);
} else if (result.IsUndefined()) {
sqlite3_result_null(ctx);
} else {
sqlite3_result_error(ctx, "Function must not return a value",
SQLITE_ERROR);
}
}

Napi::Value TranslateValue(sqlite3_value* value) {
auto env = fn_.Env();
int type = sqlite3_value_type(value);
switch (type) {
case SQLITE_INTEGER: {
auto val = sqlite3_value_int64(value);
if (is_bigint_) {
return Napi::BigInt::New(env, static_cast<int64_t>(val));
}
if (static_cast<int64_t>(INT32_MIN) <= val &&
val <= static_cast<int64_t>(INT32_MAX)) {
napi_value n_value;
NAPI_THROW_IF_FAILED(
env, napi_create_int32(env, static_cast<int32_t>(val), &n_value),
Napi::Value());
return Napi::Value(env, n_value);
} else {
return Napi::Number::New(env, val);
}
}
case SQLITE_TEXT:
return Napi::String::New(
env, reinterpret_cast<const char*>(sqlite3_value_text(value)),
sqlite3_value_bytes(value));
case SQLITE_FLOAT:
return Napi::Number::New(env, sqlite3_value_double(value));
case SQLITE_BLOB:
return Napi::Buffer<uint8_t>::Copy(
env, reinterpret_cast<const uint8_t*>(sqlite3_value_blob(value)),
sqlite3_value_bytes(value));
case SQLITE_NULL:
return env.Null();
}
return Napi::Value();
}

private:
Napi::Reference<Napi::Function> fn_;
bool is_bigint_;
};

// Global Settings

thread_local Napi::Reference<Napi::Function> logger_fn_;
Expand Down Expand Up @@ -150,6 +234,8 @@ Napi::Object Database::Init(Napi::Env env, Napi::Object exports) {
Napi::Function::New(env, &Database::InitTokenizer);
exports["databaseClose"] = Napi::Function::New(env, &Database::Close);
exports["databaseExec"] = Napi::Function::New(env, &Database::Exec);
exports["databaseCreateFunction"] =
Napi::Function::New(env, &Database::CreateFunction);
return exports;
}

Expand Down Expand Up @@ -289,6 +375,42 @@ Napi::Value Database::Exec(const Napi::CallbackInfo& info) {
return Napi::Value();
}

Napi::Value Database::CreateFunction(const Napi::CallbackInfo& info) {
auto env = info.Env();

auto db = FromExternal(info[0]);
auto name = info[1].As<Napi::String>();
auto fn = info[2].As<Napi::Function>();
auto is_bigint = info[3].As<Napi::Boolean>();

assert(name.IsString());
assert(fn.IsFunction());
assert(is_bigint.IsBoolean());

if (db == nullptr) {
return Napi::Value();
}

if (db->handle_ == nullptr) {
NAPI_THROW(Napi::Error::New(env, "Database closed"), Napi::Value());
}

auto name_utf8 = name.Utf8Value();

auto fn_wrap = new FunctionWrap(fn, is_bigint);

int r = sqlite3_create_function_v2(db->handle_, name_utf8.c_str(), -1,
SQLITE_UTF8, // TODO(indutny): or UTF16?
fn_wrap, FunctionWrap::Run, nullptr,
nullptr, FunctionWrap::Final);

if (r != SQLITE_OK) {
delete fn_wrap;
return db->ThrowSqliteError(env, r);
}
return Napi::Value();
}

Napi::Value Database::ThrowSqliteError(Napi::Env env, int error) {
assert(handle_ != nullptr);
const char* msg = sqlite3_errmsg(handle_);
Expand Down
1 change: 1 addition & 0 deletions src/addon.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Database {
static Napi::Value InitTokenizer(const Napi::CallbackInfo& info);
static Napi::Value Close(const Napi::CallbackInfo& info);
static Napi::Value Exec(const Napi::CallbackInfo& info);
static Napi::Value CreateFunction(const Napi::CallbackInfo& info);

fts5_api* GetFTS5API(Napi::Env env);

Expand Down
50 changes: 49 additions & 1 deletion test/memory.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, expect, test, beforeEach, afterEach } from 'vitest';
import { describe, expect, test, beforeEach, afterEach, vi } from 'vitest';

import Database, { setLogger } from '../lib/index.js';

Expand Down Expand Up @@ -497,3 +497,51 @@ describe('statement cache', () => {
);
});
});

describe('custom function', () => {
let fnDb: Database;
let fn: ReturnType<typeof vi.fn>;
let bigFn: ReturnType<typeof vi.fn>;
beforeEach(() => {
fnDb = new Database(':memory:');

fn = vi.fn();
fnDb.createFunction('fn', fn);

bigFn = vi.fn();
fnDb.createFunction('bigFn', bigFn, {
bigint: true,
});
});

afterEach(() => {
fnDb.close();
});

test('it calls the function without args', () => {
fnDb.exec(`SELECT fn()`);
expect(fn).toHaveBeenCalledWith();
});

test('it calls the function with multiple args', () => {
fnDb.exec(`SELECT fn(1, '123', NULL)`);
expect(fn).toHaveBeenCalledWith(1, '123', null);
});

test('it calls the function with blob', () => {
fnDb.exec(`SELECT fn(x'abba')`);
expect(fn).toHaveBeenCalledWith(Buffer.from('abba', 'hex'));
});

test('it uses bigints when configured', () => {
fnDb.exec(`SELECT bigFn(123456)`);
expect(bigFn).toHaveBeenCalledWith(123456n);
});

test('it throws when function returns a value', () => {
fnDb.createFunction('intFn', () => {
return 1;
});
expect(() => fnDb.exec(`SELECT intFn()`)).toThrowError('SQLITE_ERROR');
});
});
Loading