diff --git a/src/lua.rs b/src/lua.rs index 67f03007..ce270d87 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -9,7 +9,7 @@ use gc_arena::{ use crate::{ finalizers::Finalizers, stash::{Fetchable, Stashable}, - stdlib::{load_base, load_coroutine, load_io, load_math, load_string, load_table}, + stdlib::{load_base, load_coroutine, load_io, load_math, load_string, load_table, load_utf8}, string::InternedStringSet, thread::BadThreadMode, Error, ExternError, FromMultiValue, FromValue, Fuel, IntoValue, Registry, RuntimeError, @@ -176,6 +176,7 @@ impl Lua { load_math(ctx); load_string(ctx); load_table(ctx); + load_utf8(ctx); }) } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index aa766153..d7695ee4 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -4,8 +4,9 @@ mod io; mod math; mod string; mod table; +mod utf8; pub use self::{ base::load_base, coroutine::load_coroutine, io::load_io, math::load_math, string::load_string, - table::load_table, + table::load_table, utf8::load_utf8, }; diff --git a/src/stdlib/utf8.rs b/src/stdlib/utf8.rs new file mode 100644 index 00000000..8e3fe99c --- /dev/null +++ b/src/stdlib/utf8.rs @@ -0,0 +1,286 @@ +use crate::{Callback, CallbackReturn, Context, IntoValue, String as LuaString, Table, Value}; + +fn convert_index(i: i64, len: usize) -> Option { + let val = match i { + 0 => 0, + v @ 1.. => v - 1, + v @ ..=-1 => (len as i64 + v).max(0), + }; + usize::try_from(val).ok() +} + +fn convert_index_end(i: i64, len: usize) -> Option { + let val = match i { + v @ 0.. => v, + v @ ..=-1 => (len as i64 + v + 1).max(0), + }; + usize::try_from(val).ok() +} + +pub fn load_utf8(ctx: Context) { + let utf8 = Table::new(&ctx); + + utf8.set_field( + ctx, + "char", + Callback::from_fn(&ctx, |ctx, _, mut stack| { + let mut bytes = Vec::with_capacity(stack.len() * 4); + let iter = stack.into_iter().enumerate(); + + for (idx, i) in iter { + let code = match i.to_integer() { + Some(code) => code as u32, + None => { + return Err(format!( + "bad argument #{} to 'char' (number expected, got {})", + idx + 1, + i.type_name() + ) + .into_value(ctx) + .into()) + } + }; + + if let Some(c) = char::from_u32(code) { + let mut buf = [0; 4]; + let utf8_bytes = c.encode_utf8(&mut buf).as_bytes(); + bytes.extend_from_slice(utf8_bytes); + } else { + return Err(format!( + "bad argument #{} to 'char' (value out of range)", + idx + 1 + ) + .into_value(ctx) + .into()); + } + } + + let result = ctx.intern(&bytes); + stack.replace(ctx, result); + + Ok(CallbackReturn::Return) + }), + ); + + let _ = utf8.set(ctx, "charpattern", r"[\0-\x7F\xC2-\xF4][\x80-\xBF]*"); + + utf8.set_field( + ctx, + "codes", + Callback::from_fn(&ctx, |ctx, _, mut stack| { + let s = stack.consume::(ctx)?; + + let callback = Callback::from_fn(&ctx, |ctx, _, mut stack| { + let (s, n) = stack.consume::<(LuaString, i64)>(ctx)?; + + let bytes = s.as_bytes(); + let n = n as usize; + + if n >= bytes.len() { + stack.replace(ctx, (Value::Nil, Value::Nil)); + return Ok(CallbackReturn::Return); + } + + let bytes = &bytes[n..]; + + let mut chunks = bytes.utf8_chunks(); + + if let Some(chunk) = chunks.next() { + if !chunk.invalid().is_empty() { + return Err("Invalid UTF-8 byte sequence".into_value(ctx).into()); + } + + if let Some(c) = chunk.valid().chars().next() { + if n == 0 { + stack.replace(ctx, (1, c as i64)); + } else { + let len = c.len_utf8(); + stack.replace(ctx, ((n + len) as i64, c as i64)); + } + Ok(CallbackReturn::Return) + } else { + stack.replace(ctx, (Value::Nil, Value::Nil)); + Ok(CallbackReturn::Return) + } + } else { + stack.replace(ctx, (Value::Nil, Value::Nil)); + Ok(CallbackReturn::Return) + } + }); + + stack.replace(ctx, (callback, s, 0)); + + Ok(CallbackReturn::Return) + }), + ); + + utf8.set_field( + ctx, + "len", + Callback::from_fn(&ctx, |ctx, _, mut stack| { + let (s, i, j) = stack.consume::<(String, Option, Option)>(ctx)?; + + let len = s.len(); + + let start = convert_index(i.unwrap_or(1), len).unwrap_or(usize::MAX); + let end = convert_index_end(j.unwrap_or(len as i64), len) + .unwrap_or(usize::MAX) + .min(len); + + // TODO: we need to check this conditions + if start >= len || (end < start && end != 0) { + stack.replace(ctx, 0); + return Ok(CallbackReturn::Return); + } + + let s = &s[start..=end]; + + let s = match std::str::from_utf8(s.as_bytes()) { + Ok(s) => s, + Err(err) => { + let position = err.error_len().unwrap_or_default(); + stack.replace(ctx, (false, position as i64 + 1)); + return Ok(CallbackReturn::Return); + } + }; + + stack.replace(ctx, s.chars().count() as i64); + + Ok(CallbackReturn::Return) + }), + ); + + utf8.set_field( + ctx, + "codepoint", + Callback::from_fn(&ctx, |ctx, _, mut stack| { + let (s, i, j) = stack.consume::<(String, Option, Option)>(ctx)?; + + let len = s.len(); + + let i = i.unwrap_or(1); + let j = j.unwrap_or(i); + + let start = convert_index(i, len).unwrap_or(usize::MAX); + let end = convert_index_end(j, len).unwrap_or(usize::MAX).min(len); + + if start > len { + stack.replace(ctx, Value::Nil); + return Ok(CallbackReturn::Return); + } + + if start < 1 { + return Err("bad argument #2 (out of range)".into_value(ctx).into()); + } + + if start > end { + return Ok(CallbackReturn::Return); + } + + let s = &s[start..=end]; + + let s = std::str::from_utf8(s.as_bytes()).map_err(|err| { + format!( + "bad argument #1 to 'codepoint' (invalid byte sequence at {})", + err.error_len().unwrap_or_default() + ) + .into_value(ctx) + })?; + + stack.extend(s.chars().map(|c| Value::Integer(c as i64))); + + Ok(CallbackReturn::Return) + }), + ); + + utf8.set_field( + ctx, + "offset", + Callback::from_fn(&ctx, |ctx, _, mut stack| { + let (s, n, i): (String, i64, Option) = stack.consume(ctx)?; + let bytes = s.as_bytes(); + let len = bytes.len(); + + let i = i.unwrap_or(if n >= 0 { 1 } else { len as i64 + 1 }); + + if i == 0 { + return Err("bad argument #3 to 'offset' (position out of bounds)" + .into_value(ctx) + .into()); + } + + let mut position = convert_index(i, len).unwrap_or(usize::MAX); + + if n != 0 && position < len && (bytes[position] & 0xC0) == 0x80 { + return Err("initial position is a continuation byte" + .into_value(ctx) + .into()); + } + + if n == 0 { + if position >= len { + stack.replace(ctx, Value::Nil); + return Ok(CallbackReturn::Return); + } + + while position > 0 && (bytes[position] & 0xC0) == 0x80 { + position -= 1; + } + + stack.replace(ctx, (position as i64) + 1); + return Ok(CallbackReturn::Return); + } + + if n > 0 { + let mut count = 0; + + while count < n && position < len { + if (bytes[position] & 0xC0) != 0x80 { + count += 1; + } + + if count == n { + break; + } + + position += 1; + } + + if count == n { + stack.replace(ctx, (position as i64) + 1); + return Ok(CallbackReturn::Return); + } + + if count == n - 1 && position == len { + stack.replace(ctx, (position as i64) + 1); + return Ok(CallbackReturn::Return); + } else if count < n { + stack.replace(ctx, Value::Nil); + return Ok(CallbackReturn::Return); + } + } else if n < 0 { + let target_count = -n; + let mut count = 0i64; + + let mut current_byte_index = convert_index(i, len).unwrap_or(usize::MAX); + + while count < target_count { + if current_byte_index == 0 { + stack.replace(ctx, Value::Nil); + return Ok(CallbackReturn::Return); + } + current_byte_index -= 1; + if (bytes[current_byte_index] & 0xC0) != 0x80 { + count += 1; + } + } + stack.replace(ctx, (current_byte_index as i64) + 1); + return Ok(CallbackReturn::Return); + } + + Ok(CallbackReturn::Return) + }), + ); + + ctx.set_global("utf8", utf8); +} diff --git a/tests/scripts/utf8.lua b/tests/scripts/utf8.lua new file mode 100644 index 00000000..ef7319be --- /dev/null +++ b/tests/scripts/utf8.lua @@ -0,0 +1,260 @@ +function is_err(f, ...) + local status, err = pcall(f, ...) + return not status, err +end + +function collect_codes(s) + local results = {} + local err_status, err_val = pcall(function() + for p, c in utf8.codes(s) do + table.insert(results, {p, c}) + end + end) + if not err_status then + return false, err_val + end + return results +end + +function collect_codepoints(s, i, j) + local results = {} + local args = {s, i, j} + local err_status, err_val = pcall(function() + local values = {utf8.codepoint(table.unpack(args))} + for _, v in ipairs(values) do + table.insert(results, v) + end + end) + if not err_status then + return false, err_val + end + return results +end + +do + assert(utf8.char() == "") + assert(utf8.char(65) == "A") + assert(utf8.char(65, 66, 67) == "ABC") + assert(utf8.char(0x41, 0x42, 0x43) == "ABC") + assert(utf8.char(1055, 1088, 1080, 1074, 1077, 1090) == "Привет") + assert(utf8.char(72, 1080, 33) == "Hи!") + assert(utf8.char(0xC2, 0xA2) == "\195\130\194\162") + assert(utf8.char(162) == "\194\162") + assert(utf8.char(0xE2, 0x82, 0xAC) == "\195\162\194\130\194\172") + assert(utf8.char(8364) == "\226\130\172") + assert(utf8.char(0xF0, 0x9F, 0x98, 0x80) == "\195\176\194\159\194\152\194\128") + assert(utf8.char(128512) == "\240\159\152\128") + assert(utf8.char(0) == "\0") + assert(utf8.char(65, 0, 66) == "A\0B") + assert(utf8.char(0x7F) == "\127") + assert(utf8.char(0x80) == "\194\128") + assert(utf8.char(0x7FF) == "\223\191") + assert(utf8.char(0x800) == "\224\160\128") + assert(utf8.char(0xFFFF) == "\239\191\191") + assert(utf8.char(0x10000) == "\240\144\128\128") + assert(utf8.char(0x10FFFF) == "\244\143\191\191") + assert(is_err(utf8.char, "A")) + assert(is_err(utf8.char, 65, "B")) + assert(is_err(utf8.char, {})) + assert(is_err(utf8.char, nil)) + assert(is_err(utf8.char, -1)) + assert(is_err(utf8.char, 0x110000)) + assert(is_err(utf8.char, 0xD800)) + assert(is_err(utf8.char, 0xDFFF)) + assert(is_err(utf8.char, 0x110000)) + assert(is_err(utf8.char, "not a number")) +end + +do + assert(utf8.charpattern == "[\\0-\\x7F\\xC2-\\xF4][\\x80-\\xBF]*") +end + +do + local empty_codes = collect_codes("") + assert(type(empty_codes) == "table" and #empty_codes == 0) + + local abc_codes = collect_codes("ABC") + assert(type(abc_codes) == "table" and #abc_codes == 3) + assert(abc_codes[1][1] == 1 and abc_codes[1][2] == 65) + assert(abc_codes[2][1] == 2 and abc_codes[2][2] == 66) + assert(abc_codes[3][1] == 3 and abc_codes[3][2] == 67) + + local ab0c_codes = collect_codes("AB\0C") + assert(type(ab0c_codes) == "table" and #ab0c_codes == 4) + assert(ab0c_codes[1][1] == 1 and ab0c_codes[1][2] == 65) + assert(ab0c_codes[2][1] == 2 and ab0c_codes[2][2] == 66) + assert(ab0c_codes[3][1] == 3 and ab0c_codes[3][2] == 0) + assert(ab0c_codes[4][1] == 4 and ab0c_codes[4][2] == 67) + + local privet = "Привет" + local privet_codes = collect_codes(privet) + assert(#privet_codes == 6) + assert(privet_codes[1][1] == 1 and privet_codes[1][2] == 1055) + assert(privet_codes[2][1] == 3 and privet_codes[2][2] == 1088) + assert(privet_codes[3][1] == 5 and privet_codes[3][2] == 1080) + assert(privet_codes[4][1] == 7 and privet_codes[4][2] == 1074) + assert(privet_codes[5][1] == 9 and privet_codes[5][2] == 1077) + assert(privet_codes[6][1] == 11 and privet_codes[6][2] == 1090) + + local hieuro = "Hi€!" + local hieuro_codes = collect_codes(hieuro) + assert(#hieuro_codes == 4) + assert(hieuro_codes[1][1] == 1 and hieuro_codes[1][2] == 72) + assert(hieuro_codes[2][1] == 2 and hieuro_codes[2][2] == 105) + assert(hieuro_codes[3][1] == 3 and hieuro_codes[3][2] == 8364) + assert(hieuro_codes[4][1] == 6 and hieuro_codes[4][2] == 33) + + local emoji = "😀" + local emoji_codes = collect_codes(emoji) + assert(#emoji_codes == 1) + assert(emoji_codes[1][1] == 1 and emoji_codes[1][2] == 128512) + + assert(collect_codes("abc\xE2\x82") == false) + assert(collect_codes("abc\xE2\x82\xFF") == false) + assert(collect_codes("abc\xFF") == false) + assert(collect_codes("\xC0\x80") == false) +end + +do + local s = "ABC" + assert(table.concat(collect_codepoints(s), ",") == "65") + assert(table.concat(collect_codepoints(s, 1), ",") == "65") + assert(table.concat(collect_codepoints(s, 2), ",") == "66") + assert(table.concat(collect_codepoints(s, 3), ",") == "67") + assert(collect_codepoints(s, 4) == false) + assert(table.concat(collect_codepoints(s, 1, 1), ",") == "65") + assert(table.concat(collect_codepoints(s, 1, 2), ",") == "65,66") + assert(table.concat(collect_codepoints(s, 1, 3), ",") == "65,66,67") + assert(table.concat(collect_codepoints(s, 2, 3), ",") == "66,67") + assert(table.concat(collect_codepoints(s, 3, 3), ",") == "67") + assert(collect_codepoints(s, 1, 10) == false) + assert(table.concat(collect_codepoints(s, 3, 1), ",") == "") + assert(table.concat(collect_codepoints(s, -1), ",") == "67") + assert(table.concat(collect_codepoints(s, -2), ",") == "66") + assert(table.concat(collect_codepoints(s, -3), ",") == "65") + assert(table.concat(collect_codepoints(s, -3, -1), ",") == "65,66,67") + assert(table.concat(collect_codepoints(s, -2, -1), ",") == "66,67") + assert(table.concat(collect_codepoints(s, -1, -1), ",") == "67") + assert(table.concat(collect_codepoints(s, 1, -1), ",") == "65,66,67") + assert(table.concat(collect_codepoints(s, 2, -1), ",") == "66,67") + assert(table.concat(collect_codepoints(s, 1, -2), ",") == "65,66") + assert(table.concat(collect_codepoints(s, -3, 3), ",") == "65,66,67") + assert(table.concat(collect_codepoints(s, -3, 1), ",") == "65") + + local privet = "Привет" + assert(table.concat(collect_codepoints(privet, 1), ",") == "1055") + assert(table.concat(collect_codepoints(privet, 1), ",") == "1055") + assert(table.concat(collect_codepoints(privet, 3), ",") == "1088") + assert(table.concat(collect_codepoints(privet, 1, 2), ",") == "1055") + assert(table.concat(collect_codepoints(privet, 1, 3), ",") == "1055,1088") + assert(table.concat(collect_codepoints(privet, 1, 4), ",") == "1055,1088") + assert(table.concat(collect_codepoints(privet, 1, 12), ",") == "1055,1088,1080,1074,1077,1090") + assert(table.concat(collect_codepoints(privet, 3, 7), ",") == "1088,1080,1074") + assert(table.concat(collect_codepoints(privet, -2, -1), ",") == "1090") + assert(table.concat(collect_codepoints(privet, 11, -1), ",") == "1090") + assert(table.concat(collect_codepoints(privet, 1, -1), ",") == "1055,1088,1080,1074,1077,1090") + assert(collect_codepoints("", 1, 1) == false) + + local emoji = "😀" + assert(table.concat(collect_codepoints(emoji), ",") == "128512") + assert(table.concat(collect_codepoints(emoji, 1), ",") == "128512") + assert(table.concat(collect_codepoints(emoji, 1), ",") == "128512") + assert(table.concat(collect_codepoints(emoji, 1, 4), ",") == "128512") + assert(collect_codepoints("abc\xE2\x82", 1) == false) + assert(collect_codepoints("abc\xE2\x82\xFF", 1) == false) + assert(collect_codepoints("abc\xFF", 1) == false) + assert(collect_codepoints("abc\xFF", 4) == false) + assert(collect_codepoints("abc\xE2\x82", 1, 5) == false) + assert(collect_codepoints("abc\xE2\x20\xAC", 1, 6) == false) +end + +do + assert(utf8.len("") == 0) + assert(utf8.len("ABC") == 3) + assert(utf8.len("При") == 3) + assert(utf8.len("Привет") == 6) + assert(utf8.len("😀") == 1) + assert(utf8.len("A😀B") == 3) + assert(utf8.len("A\0B") == 3) + + local s = "Привет" + assert(utf8.len(s, 1, 1) == 1) + assert(utf8.len(s, 1, 2) == 1) + assert(utf8.len(s, 1, 3) == 2) + assert(utf8.len(s, 1, 4) == 2) + assert(utf8.len(s, 3, 4) == 1) + assert(utf8.len(s, 3, 6) == 2) + assert(utf8.len(s, 1, 12) == 6) + assert(utf8.len(s, 1, -1) == 6) + assert(utf8.len(s, -12, -1) == 6) + assert(utf8.len(s, -2, -1) == 1) + assert(utf8.len(s, 11, 12) == 1) + assert(utf8.len(s, 1, 6) == 3) + assert(utf8.len(s, 7, 12) == 3) + assert(utf8.len(s, 13, 20) == 0) + assert(utf8.len(s, 5, 1) == 0) + assert(utf8.len(s, 1, 11) == 6) +end + +do + local s = "Привет" + assert(utf8.offset(s, 0) == 1) + assert(utf8.offset(s, 1) == 1) + assert(utf8.offset(s, 2) == 3) + assert(utf8.offset(s, 6) == 11) + assert(utf8.offset(s, 7) == 13) + assert(utf8.offset(s, 8) == nil) + assert(utf8.offset(s, -1) == 11) + assert(utf8.offset(s, -2) == 9) + assert(utf8.offset(s, -6) == 1) + assert(utf8.offset(s, -7) == nil) + assert(utf8.offset(s, 1, 1) == 1) + assert(is_err(utf8.offset, s, 1, 2)) + assert(utf8.offset(s, 1, 3) == 3) + assert(utf8.offset(s, 2, 3) == 5) + assert(utf8.offset(s, 1, 11) == 11) + assert(is_err(utf8.offset, s, 1, 12)) + assert(utf8.offset(s, 1, 13) == 13) + assert(utf8.offset(s, 2, 11) == 13) + assert(is_err(utf8.offset, s, 2, 12)) + assert(is_err(utf8.offset, s, -1, 12)) + assert(utf8.offset(s, -1, 11) == 9) + assert(utf8.offset(s, -1, 3) == 1) + assert(is_err(utf8.offset, s, -1, 2)) + assert(utf8.offset(s, -1, 1) == nil) + assert(is_err(utf8.offset, s, -2, 12)) + assert(is_err(utf8.offset, s, -6, 12)) + assert(is_err(utf8.offset, s, -7, 12)) + assert(utf8.offset(s, -1, #s + 1) == 11) + assert(utf8.offset(s, 0, 1) == 1) + assert(utf8.offset(s, 0, 2) == 1) + assert(utf8.offset(s, 0, 3) == 3) + assert(utf8.offset(s, 0, 4) == 3) + assert(utf8.offset(s, 0, 11) == 11) + assert(utf8.offset(s, 0, 12) == 11) + assert(utf8.offset(s, 0, 13) == nil) + assert(is_err(utf8.offset, s, 0, 0)) + assert(utf8.offset(s, 0, -1) == 11) + assert(utf8.offset(s, 0, -12) == 1) + + local ascii = "ABCDEFG" + assert(utf8.offset(ascii, 3, 1) == 3) + assert(utf8.offset(ascii, -3, 7) == 4) + assert(utf8.offset(ascii, 0, 5) == 5) + + local emoji = "A😀B" + assert(utf8.offset(emoji, 1) == 1) + assert(utf8.offset(emoji, 2) == 2) + assert(utf8.offset(emoji, 3) == 6) + assert(utf8.offset(emoji, 4) == 7) + assert(utf8.offset(emoji, -1) == 6) + assert(utf8.offset(emoji, -2) == 2) + assert(utf8.offset(emoji, -3) == 1) + assert(utf8.offset(emoji, 0, 1) == 1) + assert(utf8.offset(emoji, 0, 2) == 2) + assert(utf8.offset(emoji, 0, 3) == 2) + assert(utf8.offset(emoji, 0, 4) == 2) + assert(utf8.offset(emoji, 0, 5) == 2) + assert(utf8.offset(emoji, 0, 6) == 6) + assert(utf8.offset(emoji, 0, 7) == nil) +end