diff --git a/common/arg.cpp b/common/arg.cpp index 1302065498..4eeeda51bd 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3518,15 +3518,16 @@ void common_params_add_preset_options(std::vector & args) { [](common_params &, const std::string &) { /* unused */ } ).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only()); + args.push_back(common_arg( + {"unsafe-allow-api-override"}, "PARAM1,PARAM2,...", + "allow overriding these params via /models/load endpoint (unsafe)", + [](common_params &, const std::string &) { /* unused */ } + ).set_env(COMMON_ARG_PRESET_UNSAFE_ALLOW_API_OVERRIDE).set_preset_only()); + + // TODO: // args.push_back(common_arg( // {"pin"}, // "in server router mode, do not unload this model if models_max is exceeded", // [](common_params &) { /* unused */ } // ).set_preset_only()); - - // args.push_back(common_arg( - // {"unload-idle-seconds"}, "SECONDS", - // "in server router mode, unload models idle for more than this many seconds", - // [](common_params &, int) { /* unused */ } - // ).set_preset_only()); } diff --git a/common/arg.h b/common/arg.h index f5111c658f..1024bbdc74 100644 --- a/common/arg.h +++ b/common/arg.h @@ -9,7 +9,8 @@ #include // pseudo-env variable to identify preset-only arguments -#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP" +#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP" +#define COMMON_ARG_PRESET_UNSAFE_ALLOW_API_OVERRIDE "__PRESET_UNSAFE_ALLOW_API_OVERRIDE" // // CLI argument parsing diff --git a/common/preset.cpp b/common/preset.cpp index e2fc18c5da..42b42446ea 100644 --- a/common/preset.cpp +++ b/common/preset.cpp @@ -236,32 +236,41 @@ common_preset_context::common_preset_context(llama_example ex) key_to_opt = get_map_key_opt(ctx_params); } +common_preset common_preset_context::load_from_map(const std::map & arg_map) const { + common_preset preset; + preset.name = COMMON_PRESET_DEFAULT_NAME; + + for (const auto & [key, value] : arg_map) { + LOG_DBG("option: %s = 
%s\n", key.c_str(), value.c_str()); + if (key_to_opt.find(key) != key_to_opt.end()) { + const auto & opt = key_to_opt.at(key); + if (is_bool_arg(opt)) { + preset.options[opt] = parse_bool_arg(opt, key, value); + } else { + preset.options[opt] = value; + } + LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str()); + } else { + LOG_WRN("ignoring unknown option: %s\n", key.c_str()); + } + } + + return preset; +} + common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const { common_presets out; auto ini_data = parse_ini_from_file(path); for (auto section : ini_data) { - common_preset preset; + common_preset preset = load_from_map(section.second); + if (section.first.empty()) { preset.name = COMMON_PRESET_DEFAULT_NAME; } else { preset.name = section.first; } - LOG_DBG("loading preset: %s\n", preset.name.c_str()); - for (const auto & [key, value] : section.second) { - LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str()); - if (key_to_opt.find(key) != key_to_opt.end()) { - const auto & opt = key_to_opt.at(key); - if (is_bool_arg(opt)) { - preset.options[opt] = parse_bool_arg(opt, key, value); - } else { - preset.options[opt] = value; - } - LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str()); - } else { - // TODO: maybe warn about unknown key? 
- } - } + LOG_DBG("loaded preset: %s\n", preset.name.c_str()); if (preset.name == "*") { // handle global preset diff --git a/common/preset.h b/common/preset.h index 3a84d1be29..a55dc91bc2 100644 --- a/common/preset.h +++ b/common/preset.h @@ -65,6 +65,10 @@ struct common_preset_context { // generate one preset from CLI arguments common_preset load_from_args(int argc, char ** argv) const; + // generate one preset from mapping string to string + // key can be either arg name or env variable + common_preset load_from_map(const std::map & arg_map) const; + // cascade multiple presets if exist on both: base < added // if preset does not exist in base, it will be added without modification common_presets cascade(const common_presets & base, const common_presets & added) const; diff --git a/tools/server/README.md b/tools/server/README.md index 1ae5eae4c6..2e56496620 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1486,6 +1486,7 @@ The precedence rule for preset options is as follows: We also offer additional options that are exclusive to presets (these aren't treated as command-line arguments): - `load-on-startup` (boolean): Controls whether the model loads automatically when the server starts +- `unsafe-allow-api-override` (string): Specifies which parameters can be overridden via the `/models/load` API endpoint. Accepts multiple values separated by commas. Example: `n-gpu-layers,jinja`. **Warning:** This feature is **unsafe** and must only be used in trusted environments. ### Routing requests @@ -1571,11 +1572,15 @@ Load a model Payload: - `model`: name of the model to be loaded. +- `overrides`: (optional) preset parameter overrides (an object mapping string to string). Parameters must be whitelisted via the `unsafe-allow-api-override` preset parameter. 
```json { "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", - "extra_args": ["-n", "128", "--top-k", "4"] + "overrides": { + "c": "1024", + "jinja": "false" + } } ``` diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 08a0da5c87..3763ac6676 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #ifdef _WIN32 #include @@ -244,7 +245,7 @@ void server_models::load_models() { } for (const auto & name : models_to_load) { SRV_INF("(startup) loading model %s\n", name.c_str()); - load(name); + load(name, {}); } } @@ -379,7 +380,7 @@ void server_models::unload_lru() { } } -void server_models::load(const std::string & name) { +std::vector server_models::load(const std::string & name, const std::map & override_params) { if (!has_model(name)) { throw std::runtime_error("model name=" + name + " is not found"); } @@ -390,7 +391,7 @@ void server_models::load(const std::string & name) { auto meta = mapping[name].meta; if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model %s is not ready\n", name.c_str()); - return; + return meta.args; } // prepare new instance info @@ -404,12 +405,38 @@ void server_models::load(const std::string & name) { throw std::runtime_error("failed to get a port number"); } + // prepare arguments + if (override_params.empty()) { + inst.meta.update_args(ctx_preset, bin_path); // render args + } else { + std::unordered_set allowed_keys; + std::string val; + if (inst.meta.preset.get_option(COMMON_ARG_PRESET_UNSAFE_ALLOW_API_OVERRIDE, val)) { + auto keys = string_split(val, ','); + for (auto & key : keys) { + allowed_keys.insert(key); + } + } + common_preset orig_preset = inst.meta.preset; // copy + for (const auto & [key, value] : override_params) { + if (allowed_keys.find(key) != allowed_keys.end()) { + inst.meta.preset.set_option(ctx_preset, key, value); + } else { + throw std::invalid_argument(string_format( + "overriding option '%s' is 
not allowed for model '%s'", + key.c_str(), + name.c_str() + )); + } + } + inst.meta.update_args(ctx_preset, bin_path); // render args + inst.meta.preset = orig_preset; // restore + } + inst.subproc = std::make_shared(); { SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); - inst.meta.update_args(ctx_preset, bin_path); // render args - std::vector child_args = inst.meta.args; // copy std::vector child_env = base_env; // copy child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); @@ -484,8 +511,11 @@ void server_models::load(const std::string & name) { } } + auto args = inst.meta.args; // save args for return mapping[name] = std::move(inst); cv.notify_all(); + + return args; } static void interrupt_subprocess(FILE * stdin_file) { @@ -565,7 +595,7 @@ bool server_models::ensure_model_loaded(const std::string & name) { } if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); - load(name); + load(name, {}); } SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); @@ -743,8 +773,19 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } - models.load(name); - res_ok(res, {{"success", true}}); + std::map overrides; + if (body.contains("overrides")) { + json overrides_json = body["overrides"]; + for (auto it = overrides_json.begin(); it != overrides_json.end(); ++it) { + if (!it.value().is_string()) { + res_err(res, format_error_response("override values must be strings", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + overrides[it.key()] = it.value().get(); + } + } + auto args = models.load(name, overrides); + res_ok(res, {{"success", true}, {"args", args}}); return res; }; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 3e1868c27c..a1194f1713 100644 --- 
a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -114,7 +114,8 @@ struct server_models { // load and unload model instances // these functions are thread-safe - void load(const std::string & name); + // load() returns the argument list used to launch the model instance + std::vector load(const std::string & name, const std::map & override_params); void unload(const std::string & name); void unload_all();