diff --git a/examples/server/main.cpp b/examples/server/main.cpp index b0ac7eef9..c9a4c31bb 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -732,6 +732,327 @@ int main(int argc, const char** argv) { } }); + // sdapi endpoints (AUTOMATIC1111 / Forge) + + auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + + json j = json::parse(req.body); + + std::string prompt = j.value("prompt", ""); + std::string negative_prompt = j.value("negative_prompt", ""); + int width = j.value("width", 512); + int height = j.value("height", 512); + int steps = j.value("steps", -1); + float cfg_scale = j.value("cfg_scale", 7.f); + int64_t seed = j.value("seed", -1); + int batch_size = j.value("batch_size", 1); + int clip_skip = j.value("clip_skip", -1); + std::string sampler_name = j.value("sampler_name", ""); + std::string scheduler_name = j.value("scheduler", ""); + + auto bad = [&](const std::string& msg) { + res.status = 400; + res.set_content("{\"error\":\"" + msg + "\"}", "application/json"); + return; + }; + + if (width <= 0 || height <= 0) { + return bad("width and height must be positive"); + } + + if (steps < 1 || steps > 150) { + return bad("steps must be in range [1, 150]"); + } + + if (batch_size < 1 || batch_size > 8) { + return bad("batch_size must be in range [1, 8]"); + } + + if (cfg_scale < 0.f) { + return bad("cfg_scale must be positive"); + } + + if (prompt.empty()) { + return bad("prompt required"); + } + + auto get_sample_method = [](std::string name) -> enum sample_method_t { + enum sample_method_t result = str_to_sample_method(name.c_str()); + if (result != SAMPLE_METHOD_COUNT) return result; + // some applications use a hardcoded sampler list + std::transform(name.begin(), name.end(), name.begin(), + [](unsigned char c) { return std::tolower(c); }); + static const std::unordered_map hardcoded{ + {"euler a", EULER_A_SAMPLE_METHOD}, + {"k_euler_a", EULER_A_SAMPLE_METHOD}, + {"euler", EULER_SAMPLE_METHOD}, + {"k_euler", EULER_SAMPLE_METHOD}, + {"heun", HEUN_SAMPLE_METHOD}, + {"k_heun", HEUN_SAMPLE_METHOD}, + {"dpm2", DPM2_SAMPLE_METHOD}, + {"k_dpm_2", DPM2_SAMPLE_METHOD}, + {"lcm", LCM_SAMPLE_METHOD}, + {"ddim", DDIM_TRAILING_SAMPLE_METHOD}, + {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD}, + {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}}; + auto it = hardcoded.find(name); + if (it != hardcoded.end()) return it->second; + return SAMPLE_METHOD_COUNT; + }; + + enum sample_method_t sample_method = get_sample_method(sampler_name); + + enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); + + // avoid excessive resource usage + + SDGenerationParams gen_params = default_gen_params; + gen_params.prompt = prompt; + gen_params.negative_prompt = negative_prompt; + gen_params.width = width; + gen_params.height = height; + gen_params.seed = seed; + gen_params.sample_params.sample_steps = steps; + gen_params.batch_count = batch_size; + + if (clip_skip > 0) { + gen_params.clip_skip = clip_skip; + } + + if (sample_method != SAMPLE_METHOD_COUNT) { + gen_params.sample_params.sample_method = sample_method; + } + + if (scheduler != SCHEDULER_COUNT) { + gen_params.sample_params.scheduler = scheduler; + } + + LOG_DEBUG("%s\n", gen_params.to_string().c_str()); + + sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; + std::vector mask_data; + std::vector pmid_images; + std::vector ref_images; + + if (img2img) { + auto decode_image = [](sd_image_t& image, std::string encoded) -> bool { + // remove data URI prefix if present ("data:image/png;base64,") + auto comma_pos = encoded.find(','); + if (comma_pos != std::string::npos) { + encoded = encoded.substr(comma_pos + 1); + } + std::vector img_data = base64_decode(encoded); + if (!img_data.empty()) { + int img_w = image.width; + int img_h = image.height; + uint8_t* raw_data = load_image_from_memory( + (const char*)img_data.data(), (int)img_data.size(), + img_w, img_h, + image.width, image.height, image.channel); + if (raw_data) { + image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data}; + return true; + } + } + return false; + }; + + if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { + std::string encoded = j["init_images"][0].get(); + decode_image(init_image, encoded); + } + + if (j.contains("mask") && j["mask"].is_string()) { + std::string encoded = j["mask"].get(); + decode_image(mask_image, encoded); + bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0; + if (inpainting_mask_invert && mask_image.data != nullptr) { + for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) { + mask_image.data[i] = 255 - mask_image.data[i]; + } + } + } else { + mask_data = std::vector(width * height, 255); + mask_image.width = width; + mask_image.height = height; + mask_image.channel = 1; + mask_image.data = mask_data.data(); + } + + if (j.contains("extra_images") && j["extra_images"].is_array()) { + for (auto extra_image : j["extra_images"]) { + std::string encoded = extra_image.get(); + sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + if (decode_image(tmp_image, encoded)) { + ref_images.push_back(tmp_image); + } + } + } + + float denoising_strength = j.value("denoising_strength", -1.f); + if (denoising_strength >= 0.f) { + denoising_strength = std::min(denoising_strength, 1.0f); + gen_params.strength = denoising_strength; + } + } + + sd_img_gen_params_t img_gen_params = { + gen_params.lora_vec.data(), + static_cast(gen_params.lora_vec.size()), + gen_params.prompt.c_str(), + gen_params.negative_prompt.c_str(), + gen_params.clip_skip, + init_image, + ref_images.data(), + (int)ref_images.size(), + gen_params.auto_resize_ref_image, + gen_params.increase_ref_index, + mask_image, + gen_params.width, + gen_params.height, + gen_params.sample_params, + gen_params.strength, + gen_params.seed, + gen_params.batch_count, + control_image, + gen_params.control_strength, + { + pmid_images.data(), + (int)pmid_images.size(), + gen_params.pm_id_embed_path.c_str(), + gen_params.pm_style_strength, + }, // pm_params + ctx_params.vae_tiling_params, + gen_params.cache_params, + }; + + sd_image_t* results = nullptr; + int num_results = 0; + + { + std::lock_guard lock(sd_ctx_mutex); + results = generate_image(sd_ctx, &img_gen_params); + num_results = gen_params.batch_count; + } + + json out; + out["images"] = json::array(); + out["parameters"] = j; // TODO should return changed defaults + out["info"] = ""; + + for (int i = 0; i < num_results; i++) { + if (results[i].data == nullptr) { + continue; + } + + auto image_bytes = write_image_to_vector(ImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel); + + if (image_bytes.empty()) { + LOG_ERROR("write image to mem failed"); + continue; + } + + std::string b64 = base64_encode(image_bytes); + out["images"].push_back(b64); + } + + res.set_content(out.dump(), "application/json"); + res.status = 200; + + if (init_image.data) { + stbi_image_free(init_image.data); + } + if (mask_image.data && mask_data.empty()) { + stbi_image_free(mask_image.data); + } + for (auto ref_image : ref_images) { + stbi_image_free(ref_image.data); + } + + } catch (const std::exception& e) { + res.status = 500; + json err; + err["error"] = "server_error"; + err["message"] = e.what(); + res.set_content(err.dump(), "application/json"); + } + }; + + svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) { + sdapi_any2img(req, res, false); + }); + + svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) { + sdapi_any2img(req, res, true); + }); + + svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) { + std::vector sampler_names; + sampler_names.push_back("default"); + for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { + sampler_names.push_back(sd_sample_method_name((sample_method_t)i)); + } + json r = json::array(); + for (auto name : sampler_names) { + json entry; + entry["name"] = name; + entry["aliases"] = json::array({name}); + entry["options"] = json::object(); + r.push_back(entry); + } + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) { + std::vector scheduler_names; + scheduler_names.push_back("default"); + for (int i = 0; i < SCHEDULER_COUNT; i++) { + scheduler_names.push_back(sd_scheduler_name((scheduler_t)i)); + } + json r = json::array(); + for (auto name : scheduler_names) { + json entry; + entry["name"] = name; + entry["label"] = name; + r.push_back(entry); + } + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) { + fs::path model_path = ctx_params.model_path; + json entry; + entry["title"] = model_path.stem(); + entry["model_name"] = model_path.stem(); + entry["filename"] = model_path.filename(); + entry["hash"] = "8888888888"; + entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888"; + entry["config"] = nullptr; + json r = json::array(); + r.push_back(entry); + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) { + fs::path model_path = ctx_params.model_path; + json r; + r["samples_format"] = "png"; + r["sd_model_checkpoint"] = model_path.stem(); + res.set_content(r.dump(), "application/json"); + }); + LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port);