diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 128679d020c..2f0ffea1c2e 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -216,7 +216,7 @@ int main(int argc, char ** argv) { ctx_cli.ctx_server.start_loop(); }); - auto inf = ctx_cli.ctx_server.get_info(); + auto inf = ctx_cli.ctx_server.get_meta(); std::string modalities = "text"; if (inf.has_inp_image) { modalities += ", vision"; diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index ab6b3aa7cec..b02afaefda1 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -115,26 +115,14 @@ bool lora_should_clear_cache( !lora_all_alora(next)); } -std::vector<common_adapter_lora_info> parse_lora_request( - const std::vector<common_adapter_lora_info> & lora_base, - const json & data) { - std::vector<common_adapter_lora_info> lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } +std::map<int, float> parse_lora_request(const json & data) { + std::map<int, float> lora; // set value for (const auto & entry : data) { int id = json_value(entry, "id", -1); float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } + lora[id] = scale; } return lora; @@ -1435,7 +1423,7 @@ std::string safe_json_to_str(const json & data) { // TODO: reuse llama_detokenize template <typename Iter> -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { +static std::string tokens_to_str(const llama_vocab * ctx, Iter begin, Iter end) { std::string ret; for (; begin != end; ++begin) { ret += common_token_to_piece(ctx, *begin); @@ -1445,7 +1433,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { } std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) { - return tokens_to_str(ctx, tokens.begin(), tokens.end()); + auto model = llama_get_model(ctx); + return tokens_to_str(llama_model_get_vocab(model), tokens.begin(), tokens.end()); +} + +std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens) { + return tokens_to_str(vocab, tokens.begin(), tokens.end()); } // format incomplete utf-8 multibyte character for output diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 0629bb5edd5..152a2a3c46c 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -107,9 +107,7 @@ bool lora_should_clear_cache( const std::vector<common_adapter_lora_info> & current, const std::vector<common_adapter_lora_info> & next); -std::vector<common_adapter_lora_info> parse_lora_request( - const std::vector<common_adapter_lora_info> & lora_base, - const json & data); +std::map<int, float> parse_lora_request(const json & data); bool are_lora_equal( const std::vector<common_adapter_lora_info> & l1, @@ -325,6 +323,7 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i std::string safe_json_to_str(const json & data); std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens); +std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens); // format incomplete utf-8 multibyte character for output std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index cde34e6533c..a132b87c84d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -507,19 +507,42 @@ struct server_metrics { // struct server_context_impl { + friend struct server_context; + +public: + // only use these pointers outside of this class: + // - when not in sleeping state + // - and, with thread-safe APIs
(e.g., tokenizer calls) + llama_model * model = nullptr; + mtmd_context * mctx = nullptr; + const llama_vocab * vocab = nullptr; + + server_queue queue_tasks; + server_response queue_results; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context_impl() { + if (!sleeping) { + // destroy() is already called when entering sleeping state + // we don't call it again here to avoid double free + destroy(); + } + } + +private: + // note: accessing these fields outside of this class is not thread-safe + // use server_context methods instead + common_params params_base; // note: keep these alive - they determine the lifetime of the model, context, etc. common_init_result_ptr llama_init; common_init_result_ptr llama_init_dft; - llama_model * model = nullptr; llama_context * ctx = nullptr; - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; bool vocab_dft_compatible = true; llama_model * model_dft = nullptr; @@ -537,35 +560,19 @@ struct server_context_impl { int slots_debug = 0; - server_queue queue_tasks; - server_response queue_results; - std::unique_ptr prompt_cache; server_metrics metrics; - // cached responses for HTTP API (read-only from HTTP threads) - json json_server_props = json::object(); - json json_server_model_meta = json::object(); + json json_webui_settings = json::object(); // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; std::string model_name; // name of the loaded model, to be used by API - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - bool sleeping = false; - ~server_context_impl() { - if (!sleeping) { - // destroy() is already called when entering sleeping state - // we don't call it again here to avoid double free - destroy(); - } - } - void destroy() { llama_init.reset(); ctx = nullptr; @@ -871,17 +878,7 @@ struct server_context_impl { metrics.init(); - if (!populate_json_responses()) { - SRV_ERR("%s", "failed to populate JSON responses\n"); - return false; - } - - return true; - } - - bool populate_json_responses() { // populate webui settings - json json_webui_settings = json::object(); { if (!params_base.webui_config_json.empty()) { try { @@ -893,53 +890,6 @@ struct server_context_impl { } } - // populate server properties - { - task_params params; - params.sampling = params_base.sampling; - json default_generation_settings_for_props = json { - {"params", params.to_json(true)}, - {"n_ctx", get_slot_n_ctx()}, - }; - - json_server_props = { - { "default_generation_settings", default_generation_settings_for_props }, - { "total_slots", params_base.n_parallel }, - { "model_alias", model_name }, - { "model_path", params_base.model.path }, - { "modalities", json { - {"vision", oai_parser_opt.allow_image}, - {"audio", oai_parser_opt.allow_audio}, - } }, - { "endpoint_slots", params_base.endpoint_slots }, - { "endpoint_props", params_base.endpoint_props }, - { "endpoint_metrics", params_base.endpoint_metrics }, - { "webui", params_base.webui }, - { "webui_settings", json_webui_settings }, - { "chat_template", common_chat_templates_source(chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx, llama_vocab_bos(vocab), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx, llama_vocab_eos(vocab), /* special= */ true)}, - { "build_info", build_info }, - }; - if (params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(chat_templates.get(), "tool_use")) { - 
json_server_props["chat_template_tool_use"] = tool_use_src; - } - } - } - - // populate model metadata - { - json_server_model_meta = { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } - return true; } @@ -1098,18 +1048,37 @@ struct server_context_impl { return res; } + std::vector construct_lora_list(const std::map & config) { + std::vector output = params_base.lora_adapters; // copy + for (size_t i = 0; i < output.size(); ++i) { + auto it = config.find(i); + if (it != config.end()) { + output[i].scale = it->second; + } else { + output[i].scale = 0.0f; + } + } + return output; + } + bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - if (!are_lora_equal(task.params.lora, slot.lora)) { - // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, task.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); - slot.prompt.tokens.clear(); - } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); + // process per-request lora adapters + if (!task.params.lora.empty()) { + auto task_loras = construct_lora_list(task.params.lora); + if (!are_lora_equal(task_loras, slot.lora)) { + // if lora has changed, check to see if the cache should be cleared + if (lora_should_clear_cache(slot.lora, task_loras)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); + } else { + SLT_INF(slot, "keeping cache for alora. 
%zu target loras\n", task_loras.size()); + } + slot.lora = task_loras; } - slot.lora = task.params.lora; + } else { + slot.lora = params_base.lora_adapters; } // if using alora, make sure it's only a single one requested and active @@ -1839,9 +1808,41 @@ struct server_context_impl { res->n_erased = n_erased; queue_results.send(std::move(res)); } break; + case SERVER_TASK_TYPE_GET_LORA: + { + // TODO @ngxson : make lora_adapters a dedicated member of server_context + auto & loras = params_base.lora_adapters; + auto res = std::make_unique(); + res->id = task.id; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + llama_tokens alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t j = 0; j < n_alora_tokens; ++j) { + alora_invocation_string += common_token_to_piece(vocab, alora_tokens[j]); + alora_invocation_tokens.push_back(alora_tokens[j]); + } + } + res->loras.push_back(server_task_result_get_lora::lora{ + lora, + alora_invocation_string, + alora_invocation_tokens, + }); + } + queue_results.send(std::move(res)); + } break; case SERVER_TASK_TYPE_SET_LORA: { - params_base.lora_adapters = std::move(task.set_lora); + auto new_loras = construct_lora_list(task.set_lora); + // logging + for (size_t i = 0; i < new_loras.size(); ++i) { + SRV_INF("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale); + } + // TODO @ngxson : make lora_adapters a dedicated member of server_context + params_base.lora_adapters = new_loras; auto res = std::make_unique(); res->id = task.id; queue_results.send(std::move(res)); @@ -2781,12 +2782,34 @@ server_response_reader server_context::get_response_reader() { return impl->get_response_reader(); } -server_context_info server_context::get_info() const { - return server_context_info { - /* build_info */ build_info, - /* model_name */ impl->model_name, - /* has_inp_image */ impl->oai_parser_opt.allow_image, - /* has_inp_audio */ impl->oai_parser_opt.allow_audio, +server_context_meta server_context::get_meta() const { + auto tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use"); + return server_context_meta { + /* build_info */ build_info, + /* model_name */ impl->model_name, + /* model_path */ impl->params_base.model.path, + /* has_mtmd */ impl->mctx != nullptr, + /* has_inp_image */ impl->oai_parser_opt.allow_image, + /* has_inp_audio */ impl->oai_parser_opt.allow_audio, + /* json_webui_settings */ impl->json_webui_settings, + /* slot_n_ctx */ impl->get_slot_n_ctx(), + /* pooling_type */ llama_pooling_type(impl->ctx), + + /* chat_template */ common_chat_templates_source(impl->chat_templates.get()), + /* chat_template_tool_use */ tool_use_src ? 
tool_use_src : "", + + /* bos_token_str */ common_token_to_piece(impl->ctx, llama_vocab_bos(impl->vocab), true), + /* eos_token_str */ common_token_to_piece(impl->ctx, llama_vocab_eos(impl->vocab), true), + /* fim_pre_token */ llama_vocab_fim_pre(impl->vocab), + /* fim_sub_token */ llama_vocab_fim_suf(impl->vocab), + /* fim_mid_token */ llama_vocab_fim_mid(impl->vocab), + + /* model_vocab_type */ llama_vocab_type(impl->vocab), + /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model), + /* model_n_embd_inp */ llama_model_n_embd(impl->model), + /* model_n_params */ llama_model_n_params(impl->model), + /* model_size */ llama_model_size(impl->model), }; } @@ -2796,12 +2819,12 @@ server_context_info server_context::get_info() const { // may have bypass_sleep = true if the task does not use ctx_server struct server_res_generator : server_http_res { server_response_reader rd; - server_res_generator(server_context_impl & ctx_server, bool bypass_sleep = false) - : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) { + server_res_generator(server_queue & queue_tasks, server_response & queue_results, int sleep_idle_seconds, bool bypass_sleep = false) + : rd(queue_tasks, queue_results, HTTP_POLLING_SECONDS) { // fast path in case sleeping is disabled - bypass_sleep |= ctx_server.params_base.sleep_idle_seconds < 0; + bypass_sleep |= sleep_idle_seconds < 0; if (!bypass_sleep) { - ctx_server.queue_tasks.wait_until_no_sleep(); + queue_tasks.wait_until_no_sleep(); } } void ok(const json & response_data) { @@ -2820,17 +2843,15 @@ struct server_res_generator : server_http_res { // server_routes // -static std::unique_ptr handle_completions_impl( - std::unique_ptr && res_ptr, - server_context_impl & ctx_server, +std::unique_ptr server_routes::handle_completions_impl( + const server_http_req & req, server_task_type type, const json & data, const std::vector & files, - const std::function & should_stop, task_response_type res_type) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - auto res = std::move(res_ptr); + auto res = create_response(); auto completion_id = gen_chatcmplid(); auto & rd = res->rd; @@ -2852,32 +2873,30 @@ static std::unique_ptr handle_completions_impl( inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } tasks.reserve(inputs.size()); - int idx = 0; for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = idx++; + task.id = rd.get_new_id(); task.tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, + ctx_server.vocab, + params, + meta->slot_n_ctx, data); task.id_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = ctx_server.model_name; + task.params.oaicompat_model = meta->model_name; if (task.params.n_cmpl > 1) { task.n_children = task.params.n_cmpl - 1; for (size_t j = 0; j < task.n_children; j++) { server_task child = task.create_child( task.id, - ctx_server.queue_tasks.get_new_id(), - idx++); + rd.get_new_id()); tasks.push_back(std::move(child)); } } @@ -2895,7 +2914,7 @@ static std::unique_ptr handle_completions_impl( if (!stream) { // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); + auto all_results = 
rd.wait_for_all(req.should_stop); if (all_results.is_terminated) { return res; // connection is closed } else if (all_results.error) { @@ -2927,7 +2946,7 @@ static std::unique_ptr handle_completions_impl( // in streaming mode, the first error must be treated as non-stream response // this is to match the OAI API behavior // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); + server_task_result_ptr first_result = rd.next(req.should_stop); if (first_result == nullptr) { return res; // connection is closed } else if (first_result->is_error()) { @@ -2950,7 +2969,7 @@ static std::unique_ptr handle_completions_impl( } res->status = 200; res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + res->next = [res_this = res.get(), res_type, &req](std::string & output) -> bool { static auto format_error = [](task_response_type res_type, const json & res_json) { if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { return format_anthropic_sse({ @@ -2963,7 +2982,7 @@ static std::unique_ptr handle_completions_impl( }; try { - if (should_stop()) { + if (req.should_stop()) { SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); return false; // should_stop condition met } @@ -2992,7 +3011,7 @@ static std::unique_ptr handle_completions_impl( } // receive subsequent results - auto result = rd.next(should_stop); + auto result = rd.next(req.should_stop); if (result == nullptr) { SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); return false; // should_stop condition met @@ -3033,37 +3052,51 @@ static std::unique_ptr handle_completions_impl( return res; } +std::unique_ptr server_routes::create_response(bool bypass_sleep) { + return std::make_unique(queue_tasks, queue_results, params.sleep_idle_seconds, bypass_sleep); +} + +server_routes::server_routes(const common_params & params, server_context & ctx_server) + : params(params), + ctx_server(*ctx_server.impl), + queue_tasks(ctx_server.impl->queue_tasks), + queue_results(ctx_server.impl->queue_results) { + init_routes(); +} + void server_routes::init_routes() { - // IMPORTANT: all lambda functions must start with std::make_unique + // IMPORTANT: all lambda functions must start with create_response() // this is to ensure that the server_res_generator can handle sleeping case correctly this->get_health = [this](const server_http_req &) { // error and loading states are handled by middleware - auto res = std::make_unique(ctx_server, true); + auto res = create_response(true); + + // this endpoint can be accessed during sleeping + // the next LOC is to avoid someone accidentally use ctx_server + bool server_ctx; // do NOT delete this line + GGML_UNUSED(server_ctx); + res->ok({{"status", "ok"}}); return res; }; - this->get_metrics = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); + this->get_metrics = [this](const server_http_req & req) { + auto res = create_response(); if (!params.endpoint_metrics) { res->error(format_error_response("This server does not support metrics endpoint. 
Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return res; } // request slots data using task queue - // TODO: use server_response_reader - int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + task.id = res->rd.get_new_id(); + res->rd.post_task(std::move(task), true); // high-priority task } // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + auto result = res->rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3149,24 +3182,21 @@ void server_routes::init_routes() { }; this->get_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); if (!params.endpoint_slots) { res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); return res; } // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + task.id = res->rd.get_new_id(); + res->rd.post_task(std::move(task), true); // high-priority task } // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + auto result = res->rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3190,7 +3220,7 @@ void server_routes::init_routes() { }; this->post_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); if (params.slot_save_path.empty()) { res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); return res; @@ -3221,15 +3251,51 @@ void server_routes::init_routes() { }; this->get_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server, true); - auto props = ctx_server.json_server_props; - props["is_sleeping"] = ctx_server.queue_tasks.is_sleeping(); + auto res = create_response(true); + + // this endpoint can be accessed during sleeping + // the next LOC is to avoid someone accidentally use ctx_server + bool server_ctx; // do NOT delete this line + GGML_UNUSED(server_ctx); + + task_params tparams; + tparams.sampling = params.sampling; + json default_generation_settings_for_props = json { + { "params", tparams.to_json(true) }, + { "n_ctx", meta->slot_n_ctx }, + }; + + json props = { + { "default_generation_settings", default_generation_settings_for_props }, + { "total_slots", params.n_parallel }, + { "model_alias", meta->model_name }, + { "model_path", meta->model_path }, + { "modalities", json { + {"vision", meta->has_inp_image}, + {"audio", meta->has_inp_audio}, + } }, + { "endpoint_slots", params.endpoint_slots }, + { "endpoint_props", params.endpoint_props }, + { "endpoint_metrics", params.endpoint_metrics }, + { "webui", params.webui }, + { "webui_settings", meta->json_webui_settings }, + { "chat_template", meta->chat_template }, + { "bos_token", meta->bos_token_str }, + { "eos_token", meta->eos_token_str }, + { "build_info", meta->build_info }, + { "is_sleeping", queue_tasks.is_sleeping() }, + }; + if (params.use_jinja) { + if (!meta->chat_template_tool_use.empty()) { + props["chat_template_tool_use"] = meta->chat_template_tool_use; + } + } res->ok(props); return res; }; this->post_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); if (!params.endpoint_props) { res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); return res; @@ -3241,20 +3307,16 @@ void server_routes::init_routes() { }; this->get_api_show = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool has_mtmd = ctx_server.mctx != nullptr; + auto res = create_response(); json data = { - { - "template", common_chat_templates_source(ctx_server.chat_templates.get()), - }, { "model_info", { - { "llama.context_length", ctx_server.get_slot_n_ctx() }, + { "llama.context_length", meta->slot_n_ctx }, } }, {"modelfile", ""}, {"parameters", ""}, - {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"template", meta->chat_template}, {"details", { {"parent_model", ""}, {"format", "gguf"}, @@ -3264,7 +3326,7 @@ void server_routes::init_routes() { {"quantization_level", ""} }}, {"model_info", ""}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} + {"capabilities", meta->has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} }; res->ok(data); @@ -3272,7 +3334,7 @@ void server_routes::init_routes() { }; this->post_infill = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); // check model compatibility std::string err; if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { @@ -3333,54 +3395,48 @@ void server_routes::init_routes() { data.at("input_prefix"), data.at("input_suffix"), data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.get_slot_n_ctx(), - ctx_server.params_base.spm_infill, + params.n_batch, + params.n_predict, + meta->slot_n_ctx, + params.spm_infill, tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. ); std::vector files; // dummy return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_INFILL, data, files, - req.should_stop, TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible }; this->post_completions = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body, files, - req.should_stop, TASK_RESPONSE_TYPE_NONE); }; this->post_completions_oai = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body, files, - req.should_stop, TASK_RESPONSE_TYPE_OAI_CMPL); }; this->post_chat_completions = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; json body = json::parse(req.body); json body_parsed = oaicompat_chat_params_parse( @@ -3388,17 +3444,15 @@ void server_routes::init_routes() { ctx_server.oai_parser_opt, files); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body_parsed, files, - req.should_stop, TASK_RESPONSE_TYPE_OAI_CHAT); }; this->post_anthropic_messages = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( @@ -3406,17 +3460,15 @@ void server_routes::init_routes() { ctx_server.oai_parser_opt, files); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body_parsed, files, - req.should_stop, TASK_RESPONSE_TYPE_ANTHROPIC); }; this->post_anthropic_count_tokens = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( @@ -3426,14 +3478,13 @@ void server_routes::init_routes() { json prompt = body_parsed.at("prompt"); llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); - res->ok({{"input_tokens", static_cast(tokens.size())}}); return res; }; // same with handle_chat_completions, but without inference part this->post_apply_template = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = 
create_response(); std::vector files; // dummy, unused json body = json::parse(req.body); json data = oaicompat_chat_params_parse( @@ -3444,27 +3495,26 @@ void server_routes::init_routes() { return res; }; - // TODO: this endpoint is unsafe to access during model reloading (i.e. wake up from sleeping) - // how to make it work even during load_model()? this->get_models = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json model_meta = nullptr; - if (is_ready()) { - model_meta = ctx_server.json_server_model_meta; - } - bool has_mtmd = ctx_server.mctx != nullptr; + auto res = create_response(true); + + // this endpoint can be accessed during sleeping + // the next LOC is to avoid someone accidentally use ctx_server + bool server_ctx; // do NOT delete this line + GGML_UNUSED(server_ctx); + json models = { {"models", { { - {"name", ctx_server.model_name}, - {"model", ctx_server.model_name}, + {"name", meta->model_name}, + {"model", meta->model_name}, {"modified_at", ""}, {"size", ""}, {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash {"type", "model"}, {"description", ""}, {"tags", {""}}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, + {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, {"parameters", ""}, {"details", { {"parent_model", ""}, @@ -3479,11 +3529,18 @@ void server_routes::init_routes() { {"object", "list"}, {"data", { { - {"id", ctx_server.model_name}, + {"id", meta->model_name}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", model_meta}, + {"meta", { + {"vocab_type", meta->model_vocab_type}, + {"n_vocab", meta->model_vocab_n_tokens}, + {"n_ctx_train", meta->model_n_ctx_train}, + {"n_embd", meta->model_n_embd_inp}, + {"n_params", meta->model_n_params}, + {"size", meta->model_size}, + }}, }, }} }; @@ -3493,7 +3550,7 @@ void server_routes::init_routes() { }; this->post_tokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json body = json::parse(req.body); json tokens_response = json::array(); if (body.count("content") != 0) { @@ -3505,7 +3562,7 @@ void server_routes::init_routes() { if (with_pieces) { for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); + std::string piece = common_token_to_piece(ctx_server.vocab, token); json piece_json; // Check if the piece is valid UTF-8 @@ -3534,13 +3591,13 @@ void server_routes::init_routes() { }; this->post_detokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) { const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); + content = tokens_to_str(ctx_server.vocab, tokens); } res->ok(json{{"content", std::move(content)}}); @@ -3556,8 +3613,8 @@ void server_routes::init_routes() { }; this->post_rerank = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + auto res = create_response(); + if (!params.embedding || params.pooling_type != LLAMA_POOLING_TYPE_RANK) { res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return res; } @@ -3592,15 +3649,14 @@ void server_routes::init_routes() { // create and queue the task json responses = json::array(); - server_response_reader rd = ctx_server.get_response_reader(); + auto & rd = res->rd; { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; + task.id = rd.get_new_id(); task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } @@ -3626,7 +3682,7 @@ void server_routes::init_routes() { // write JSON response json root = format_response_rerank( body, - ctx_server.model_name, + meta->model_name, responses, is_tei_format, documents, @@ -3636,57 +3692,47 @@ void server_routes::init_routes() { return res; }; - this->get_lora_adapters = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); + this->get_lora_adapters = [this](const server_http_req & req) { + auto res = create_response(); + + auto & rd = res->rd; + { + server_task task(SERVER_TASK_TYPE_GET_LORA); + task.id = rd.get_new_id(); + rd.post_task(std::move(task)); + } + + // get the result + server_task_result_ptr result = rd.next(req.should_stop); + + if (result->is_error()) { + res->error(result->to_json()); + return res; } - res->ok(result); + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); return res; }; this->post_lora_adapters = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json body = json::parse(req.body); if (!body.is_array()) { res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); return res; } - int task_id = ctx_server.queue_tasks.get_new_id(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + task.id = rd.get_new_id(); + task.set_lora = parse_lora_request(body); + rd.post_task(std::move(task)); } // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if 
(result->is_error()) { res->error(result->to_json()); @@ -3700,7 +3746,7 @@ void server_routes::init_routes() { } std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { @@ -3709,21 +3755,17 @@ std::unique_ptr server_routes::handle_slots_save(const ser } std::string filepath = params.slot_save_path + filename; - int task_id = ctx_server.queue_tasks.get_new_id(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; + task.id = rd.get_new_id(); task.slot_action.slot_id = id_slot; task.slot_action.filename = filename; task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + rd.post_task(std::move(task)); } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3735,7 +3777,7 @@ std::unique_ptr server_routes::handle_slots_save(const ser } std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { @@ -3744,21 +3786,17 @@ std::unique_ptr server_routes::handle_slots_restore(const } std::string filepath = params.slot_save_path + filename; - int task_id = ctx_server.queue_tasks.get_new_id(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; + task.id = rd.get_new_id(); task.slot_action.slot_id = id_slot; task.slot_action.filename = filename; task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + rd.post_task(std::move(task)); } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3770,21 +3808,17 @@ std::unique_ptr server_routes::handle_slots_restore(const return res; } -std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); +std::unique_ptr server_routes::handle_slots_erase(const server_http_req & req, int id_slot) { + auto res = create_response(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; + task.id = rd.get_new_id(); task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + rd.post_task(std::move(task)); } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3797,13 
+3831,13 @@ std::unique_ptr server_routes::handle_slots_erase(const se } std::unique_ptr server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding) { + auto res = create_response(); + if (!params.embedding) { res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return res; } - if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + if (res_type != TASK_RESPONSE_TYPE_NONE && meta->pooling_type == LLAMA_POOLING_TYPE_NONE) { res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); return res; } @@ -3824,7 +3858,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons bool use_base64 = false; if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); + const std::string & format = body.at("encoding_format"); if (format == "base64") { use_base64 = true; } else if (format != "float") { @@ -3845,21 +3879,20 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons int embd_normalize = 2; // default to Euclidean/L2 norm if (body.count("embd_normalize") != 0) { embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", meta->pooling_type); } } // create and queue the task json responses = json::array(); - server_response_reader rd = ctx_server.get_response_reader(); + auto & rd = res->rd; { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; + task.id = rd.get_new_id(); task.tokens = std::move(tokenized_prompts[i]); // OAI-compat @@ -3889,7 +3922,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons // write JSON response json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64) + ? 
format_embeddings_response_oaicompat(body, meta->model_name, responses, use_base64) : json(responses); res->ok(root); return res; diff --git a/tools/server/server-context.h b/tools/server/server-context.h index a56be7b8e7e..09bec15ae11 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -9,11 +9,35 @@ struct server_context_impl; // private implementation -struct server_context_info { +struct server_context_meta { std::string build_info; std::string model_name; + std::string model_path; + bool has_mtmd; bool has_inp_image; bool has_inp_audio; + json json_webui_settings; + int slot_n_ctx; + enum llama_pooling_type pooling_type; + + // chat template + std::string chat_template; + std::string chat_template_tool_use; + + // tokens + std::string bos_token_str; + std::string eos_token_str; + llama_token fim_pre_token; + llama_token fim_sub_token; + llama_token fim_mid_token; + + // model meta + enum llama_vocab_type model_vocab_type; + int32_t model_vocab_n_tokens; + int32_t model_n_ctx_train; + int32_t model_n_embd_inp; + uint64_t model_n_params; + uint64_t model_size; }; struct server_context { @@ -33,14 +57,15 @@ struct server_context { void terminate(); // get the underlaying llama_context, can return nullptr if sleeping + // not thread-safe, should only be used from the main thread llama_context * get_llama_context() const; // get a new response reader, used by CLI application server_response_reader get_response_reader(); - // get server info - // used by CLI application - server_context_info get_info() const; + // get server metadata (read-only), can only be called after load_model() + // not thread-safe, should only be used from the main thread + server_context_meta get_meta() const; }; @@ -48,13 +73,17 @@ struct server_context { struct server_res_generator; struct server_routes { - server_routes(const common_params & params, server_context & ctx_server, std::function is_ready = []() { return true; }) - : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) { - init_routes(); - } + server_routes(const common_params & params, server_context & ctx_server); void init_routes(); + + // note: this is not thread-safe and can only when ctx_http.is_ready is false + void update_meta(const server_context & ctx_server) { + this->meta = std::make_unique(ctx_server.get_meta()); + } + // handlers using lambda function, so that they can capture `this` without `std::bind` + // they won't be called until ctx_http.is_ready is set to true server_http_context::handler_t get_health; server_http_context::handler_t get_metrics; server_http_context::handler_t get_slots; @@ -78,13 +107,24 @@ struct server_routes { server_http_context::handler_t get_lora_adapters; server_http_context::handler_t post_lora_adapters; private: - // TODO: move these outside of server_routes? 
+ std::unique_ptr handle_completions_impl( + const server_http_req & req, + server_task_type type, + const json & data, + const std::vector & files, + task_response_type res_type); std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); + // using unique_ptr to allow late initialization of const + std::unique_ptr meta; + const common_params & params; - server_context_impl & ctx_server; - std::function is_ready; + const server_context_impl & ctx_server; + + server_queue & queue_tasks; + server_response & queue_results; + std::unique_ptr create_response(bool bypass_sleep = false); }; diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 622505714cf..5d67e5722d1 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -177,12 +177,11 @@ bool server_http_context::init(const common_params & params) { if (!ready) { auto tmp = string_split(req.path, '.'); if (req.path == "/" || tmp.back() == "html") { - res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; - } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { - // allow the models endpoint to be accessed during loading - return true; + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); } else { + // no endpoints is allowed to be accessed when the server is not ready + // this is to prevent any data races or inconsistent states res.status = 503; res.set_content( safe_json_to_str(json { @@ -334,12 +333,16 @@ static std::map get_headers(const httplib::Request & r return headers; } -static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) { +// using unique_ptr for request to allow safe capturing in lambdas +using server_http_req_ptr = std::unique_ptr; + +static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) { if (response->is_stream()) { res.status = response->status; set_headers(res, response->headers); std::string content_type = response->content_type; // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it + std::shared_ptr q_ptr = std::move(request); std::shared_ptr r_ptr = std::move(response); const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool { std::string chunk; @@ -355,8 +358,9 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Re } return has_next; }; - const auto on_complete = [response = r_ptr](bool) mutable { + const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable { response.reset(); // trigger the destruction of the response object + request.reset(); // trigger the destruction of the request object }; res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete); } else { @@ -368,27 +372,29 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Re void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const { pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { - server_http_res_ptr response = 
handler(server_http_req{ + server_http_req_ptr request = std::make_unique(server_http_req{ get_params(req), get_headers(req), req.path, req.body, req.is_connection_closed }); - process_handler_response(response, res); + server_http_res_ptr response = handler(*request); + process_handler_response(std::move(request), response, res); }); } void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { - server_http_res_ptr response = handler(server_http_req{ + server_http_req_ptr request = std::make_unique(server_http_req{ get_params(req), get_headers(req), req.path, req.body, req.is_connection_closed }); - process_handler_response(response, res); + server_http_res_ptr response = handler(*request); + process_handler_response(std::move(request), response, res); }); } diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 835938bfc25..9a6ba560a36 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -325,23 +325,25 @@ void server_response::terminate() { // server_response_reader // -void server_response_reader::post_task(server_task && task) { +void server_response_reader::post_task(server_task && task, bool front) { GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader"); + task.index = 0; id_tasks.insert(task.id); states.push_back(task.create_state()); queue_results.add_waiting_task_id(task.id); - queue_tasks.post(std::move(task)); + queue_tasks.post(std::move(task), front); } -void server_response_reader::post_tasks(std::vector && tasks) { +void server_response_reader::post_tasks(std::vector && tasks, bool front) { GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader"); id_tasks = server_task::get_list_id(tasks); states.reserve(tasks.size()); for (size_t i = 0; i < tasks.size(); i++) { + tasks[i].index = i; states.push_back(tasks[i].create_state()); } queue_results.add_waiting_tasks(tasks); - queue_tasks.post(std::move(tasks)); + queue_tasks.post(std::move(tasks), front); } bool server_response_reader::has_next() const { @@ -367,7 +369,7 @@ server_task_result_ptr server_response_reader::next(const std::function } if (!states.empty()) { // update the generation state if needed - size_t idx = result->get_index(); + const size_t idx = result->index; GGML_ASSERT(idx < states.size()); result->update(states[idx]); } @@ -383,6 +385,7 @@ server_task_result_ptr server_response_reader::next(const std::function server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { batch_response batch_res; + batch_res.results.clear(); batch_res.results.resize(id_tasks.size()); while (has_next()) { auto res = next(should_stop); @@ -394,7 +397,7 @@ server_response_reader::batch_response server_response_reader::wait_for_all(cons batch_res.error = std::move(res); return batch_res; } - const size_t idx = res->get_index(); + const size_t idx = res->index; GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); batch_res.results[idx] = std::move(res); diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 8ac37a20f6b..3798aa299ef 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -5,6 +5,7 @@ #include #include #include +#include #include // struct for managing server tasks 
@@ -173,8 +174,10 @@ struct server_response_reader { int get_new_id() { return queue_tasks.get_new_id(); } - void post_task(server_task && task); - void post_tasks(std::vector && tasks); + + // if front = true, the task will be posted to the front of the queue (high priority) + void post_task(server_task && task, bool front = false); + void post_tasks(std::vector && tasks, bool front = false); bool has_next() const; // return nullptr if should_stop() is true before receiving a result diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 360826062b1..3ccaff59f46 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -32,8 +32,8 @@ json task_params::to_json(bool only_metrics) const { } json lora = json::array(); - for (size_t i = 0; i < this->lora.size(); ++i) { - lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + for (auto & it : this->lora) { + lora.push_back({{"id", it.first}, {"scale", it.second}}); } if (only_metrics) { @@ -145,12 +145,10 @@ json task_params::to_json(bool only_metrics) const { // task_params server_task::params_from_json_cmpl( - const llama_context * ctx, + const llama_vocab * vocab, const common_params & params_base, + const int n_ctx_slot, const json & data) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - task_params params; // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) @@ -223,12 +221,12 @@ task_params server_task::params_from_json_cmpl( if (data.contains("lora")) { if (data.at("lora").is_array()) { - params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + params.lora = parse_lora_request(data.at("lora")); } else { throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); } } else { - params.lora = params_base.lora_adapters; + params.lora = {}; } // TODO: add more sanity checks for the input parameters @@ -243,11 +241,11 @@ task_params server_task::params_from_json_cmpl( if (params.sampling.penalty_last_n == -1) { // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = llama_n_ctx(ctx); + params.sampling.penalty_last_n = n_ctx_slot; } if (params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + params.sampling.dry_penalty_last_n = n_ctx_slot; } if (params.sampling.dry_base < 1.0f) { @@ -1324,6 +1322,30 @@ json server_task_result_slot_erase::to_json() { }; } +// +// server_task_result_get_lora +// + +json server_task_result_get_lora::to_json() { + json result = json::array(); + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + json entry = { + {"id", i}, + {"path", lora.info.path}, + {"scale", lora.info.scale}, + {"task_name", lora.info.task_name}, + {"prompt_prefix", lora.info.prompt_prefix}, + }; + if (!lora.alora_invocation_tokens.empty()) { + entry["alora_invocation_string"] = lora.alora_invocation_string; + entry["alora_invocation_tokens"] = lora.alora_invocation_tokens; + } + result.push_back(std::move(entry)); + } + return result; +} + // // server_task_result_apply_lora // diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 0759094a01d..687770de5e9 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -6,6 +6,7 @@ #include #include #include +#include // TODO: prevent including the whole server-common.h as we only use server_tokens 
#include "server-common.h" @@ -23,6 +24,7 @@ enum server_task_type { SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_GET_LORA, SERVER_TASK_TYPE_SET_LORA, }; @@ -60,7 +62,7 @@ struct task_params { int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - std::vector lora; + std::map lora; // mapping adapter ID -> scale std::vector antiprompt; std::vector response_fields; @@ -105,8 +107,10 @@ struct task_result_state { }; struct server_task { - int id = -1; // to be filled by server_queue - int index = -1; // used when there are multiple prompts (batch request) + int id = -1; // to be filled by server_queue + + // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader + size_t index = 0; // used when there are multiple prompts (batch request) // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; @@ -138,7 +142,7 @@ struct server_task { bool metrics_reset_bucket = false; // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; + std::map set_lora; // mapping adapter ID -> scale server_task() = default; @@ -149,9 +153,10 @@ struct server_task { } static task_params params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data); + const llama_vocab * vocab, + const common_params & params_base, + const int n_ctx_slot, + const json & data); // utility function static std::unordered_set get_list_id(const std::vector & tasks) { @@ -162,10 +167,9 @@ struct server_task { return ids; } - server_task create_child(int id_parent, int id_child, int idx) const { + server_task create_child(int id_parent, int id_child) const { server_task copy; copy.id = id_child; - copy.index = idx; copy.id_parent = id_parent; copy.params = params; copy.type = type; @@ -212,6 +216,10 @@ struct result_prompt_progress { struct server_task_result { int id = -1; int id_slot = -1; + + // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader + size_t index = 0; // to be used for batched tasks + virtual bool is_error() { // only used by server_task_result_error return false; @@ -220,9 +228,6 @@ struct server_task_result { // only used by server_task_result_cmpl_* return true; } - virtual int get_index() { - return -1; - } virtual void update(task_result_state &) { // only used by server_task_result_cmpl_* } @@ -255,8 +260,6 @@ struct completion_token_output { }; struct server_task_result_cmpl_final : server_task_result { - int index = 0; - std::string content; llama_tokens tokens; @@ -289,10 +292,6 @@ struct server_task_result_cmpl_final : server_task_result { std::vector oaicompat_msg_diffs; // to be populated by update() bool is_updated = false; - virtual int get_index() override { - return index; - } - virtual bool is_stop() override { return true; // in stream mode, final responses are considered stop } @@ -318,8 +317,6 @@ struct server_task_result_cmpl_final : server_task_result { }; struct server_task_result_cmpl_partial : server_task_result { - int index = 0; - std::string content; llama_tokens tokens; @@ -340,10 +337,6 @@ struct server_task_result_cmpl_partial : server_task_result { std::vector oaicompat_msg_diffs; // to be populated by update() bool is_updated = false; - virtual int get_index() override { - return index; - } - virtual bool is_stop() override { return false; // in stream mode, partial responses are not considered stop } @@ -365,7 
+358,6 @@ struct server_task_result_cmpl_partial : server_task_result { }; struct server_task_result_embd : server_task_result { - int index = 0; std::vector> embedding; int32_t n_tokens; @@ -373,10 +365,6 @@ struct server_task_result_embd : server_task_result { // response formatting task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - virtual int get_index() override { - return index; - } - virtual json to_json() override; json to_json_non_oaicompat(); @@ -385,20 +373,14 @@ struct server_task_result_embd : server_task_result { }; struct server_task_result_rerank : server_task_result { - int index = 0; float score = -1e6; int32_t n_tokens; - virtual int get_index() override { - return index; - } - virtual json to_json() override; }; struct server_task_result_error : server_task_result { - int index = 0; error_type err_type = ERROR_TYPE_SERVER; std::string err_msg; @@ -460,6 +442,17 @@ struct server_task_result_slot_erase : server_task_result { virtual json to_json() override; }; +struct server_task_result_get_lora : server_task_result { + struct lora { + common_adapter_lora_info info; + std::string alora_invocation_string; + llama_tokens alora_invocation_tokens; + }; + std::vector loras; + + virtual json to_json() override; +}; + struct server_task_result_apply_lora : server_task_result { virtual json to_json() override; }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index ff650ab2ec1..0fbc7b6d354 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -119,7 +119,7 @@ int main(int argc, char ** argv, char ** envp) { // // register API routes - server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); + server_routes routes(params, ctx_server); bool is_router_server = params.model.path.empty(); std::optional models_routes{}; @@ -252,6 +252,7 @@ int main(int argc, char ** argv, char ** envp) { return 1; } + routes.update_meta(ctx_server); ctx_http.is_ready.store(true); LOG_INF("%s: model loaded\n", __func__);