Revert "feat(stablediffusion): Passthrough more parameters to support z-image and flux2" (#7417)
Revert "feat(stablediffusion): Passthrough more parameters to support z-image…"
This reverts commit 4018e59b2a.
This commit is contained in:
committed by
GitHub
parent
d8c7e90a69
commit
fea9018dc5
@@ -1,5 +1,4 @@
|
||||
#include "stable-diffusion.h"
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#define GGML_MAX_NAME 128
|
||||
|
||||
@@ -22,7 +21,6 @@
|
||||
#define STB_IMAGE_RESIZE_IMPLEMENTATION
|
||||
#define STB_IMAGE_RESIZE_STATIC
|
||||
#include "stb_image_resize.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
||||
const char* sample_method_str[] = {
|
||||
@@ -57,68 +55,6 @@ const char* schedulers[] = {
|
||||
|
||||
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
|
||||
|
||||
// New enum string arrays
|
||||
const char* rng_type_str[] = {
|
||||
"std_default",
|
||||
"cuda",
|
||||
"cpu",
|
||||
};
|
||||
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
|
||||
|
||||
const char* prediction_str[] = {
|
||||
"default",
|
||||
"epsilon",
|
||||
"v",
|
||||
"edm_v",
|
||||
"sd3_flow",
|
||||
"flux_flow",
|
||||
"flux2_flow",
|
||||
};
|
||||
static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
|
||||
|
||||
const char* lora_apply_mode_str[] = {
|
||||
"auto",
|
||||
"immediately",
|
||||
"at_runtime",
|
||||
};
|
||||
static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
|
||||
|
||||
const char* sd_type_str[] = {
|
||||
[0] = "f32",
|
||||
[1] = "f16",
|
||||
[2] = "q4_0",
|
||||
[3] = "q4_1",
|
||||
[6] = "q5_0",
|
||||
[7] = "q5_1",
|
||||
[8] = "q8_0",
|
||||
[9] = "q8_1",
|
||||
[10] = "q2_k",
|
||||
[11] = "q3_k",
|
||||
[12] = "q4_k",
|
||||
[13] = "q5_k",
|
||||
[14] = "q6_k",
|
||||
[15] = "q8_k",
|
||||
[16] = "iq2_xxs",
|
||||
[17] = "iq2_xs",
|
||||
[18] = "iq3_xxs",
|
||||
[19] = "iq1_s",
|
||||
[20] = "iq4_nl",
|
||||
[21] = "iq3_s",
|
||||
[22] = "iq2_s",
|
||||
[23] = "iq4_xs",
|
||||
[24] = "i8",
|
||||
[25] = "i16",
|
||||
[26] = "i32",
|
||||
[27] = "i64",
|
||||
[28] = "f64",
|
||||
[29] = "iq1_m",
|
||||
[30] = "bf16",
|
||||
[34] = "tq1_0",
|
||||
[35] = "tq2_0",
|
||||
[39] = "mxfp4",
|
||||
};
|
||||
static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
|
||||
|
||||
sd_ctx_t* sd_c;
|
||||
// Moved from the context (load time) to generation time params
|
||||
scheduler_t scheduler = SCHEDULER_COUNT;
|
||||
@@ -174,41 +110,9 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *vae_path = "";
|
||||
const char *scheduler_str = "";
|
||||
const char *sampler = "";
|
||||
const char *clip_vision_path = "";
|
||||
const char *llm_path = "";
|
||||
const char *llm_vision_path = "";
|
||||
const char *diffusion_model_path = stableDiffusionModel;
|
||||
const char *high_noise_diffusion_model_path = "";
|
||||
const char *taesd_path = "";
|
||||
const char *control_net_path = "";
|
||||
const char *embedding_dir = "";
|
||||
const char *photo_maker_path = "";
|
||||
const char *tensor_type_rules = "";
|
||||
char *lora_dir = model_path;
|
||||
bool lora_dir_allocated = false;
|
||||
|
||||
bool vae_decode_only = true;
|
||||
bool free_params_immediately = true;
|
||||
int n_threads = threads;
|
||||
enum sd_type_t wtype = SD_TYPE_COUNT;
|
||||
enum rng_type_t rng_type = STD_DEFAULT_RNG;
|
||||
enum rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
|
||||
enum prediction_t prediction = PREDICTION_COUNT;
|
||||
enum lora_apply_mode_t lora_apply_mode = LORA_APPLY_MODE_COUNT;
|
||||
bool offload_params_to_cpu = false;
|
||||
bool keep_clip_on_cpu = false;
|
||||
bool keep_control_net_on_cpu = false;
|
||||
bool keep_vae_on_cpu = false;
|
||||
bool diffusion_flash_attn = false;
|
||||
bool tae_preview_only = false;
|
||||
bool diffusion_conv_direct = false;
|
||||
bool vae_conv_direct = false;
|
||||
bool force_sdxl_vae_conv_scale = false;
|
||||
bool chroma_use_dit_mask = true;
|
||||
bool chroma_use_t5_mask = false;
|
||||
int chroma_t5_mask_pad = 0;
|
||||
float flow_shift = INFINITY;
|
||||
|
||||
fprintf(stderr, "parsing options: %p\n", options);
|
||||
|
||||
// If options is not NULL, parse options
|
||||
@@ -252,113 +156,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
|
||||
}
|
||||
}
|
||||
|
||||
// New parsing
|
||||
if (!strcmp(optname, "clip_vision_path")) clip_vision_path = optval;
|
||||
if (!strcmp(optname, "llm_path")) llm_path = optval;
|
||||
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = optval;
|
||||
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = optval;
|
||||
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = optval;
|
||||
if (!strcmp(optname, "taesd_path")) taesd_path = optval;
|
||||
if (!strcmp(optname, "control_net_path")) control_net_path = optval;
|
||||
if (!strcmp(optname, "embedding_dir")) embedding_dir = optval;
|
||||
if (!strcmp(optname, "photo_maker_path")) photo_maker_path = optval;
|
||||
if (!strcmp(optname, "tensor_type_rules")) tensor_type_rules = optval;
|
||||
|
||||
if (!strcmp(optname, "vae_decode_only")) vae_decode_only = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "free_params_immediately")) free_params_immediately = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "offload_params_to_cpu")) offload_params_to_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "keep_clip_on_cpu")) keep_clip_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "keep_control_net_on_cpu")) keep_control_net_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "keep_vae_on_cpu")) keep_vae_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "diffusion_flash_attn")) diffusion_flash_attn = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "tae_preview_only")) tae_preview_only = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "diffusion_conv_direct")) diffusion_conv_direct = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "vae_conv_direct")) vae_conv_direct = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "force_sdxl_vae_conv_scale")) force_sdxl_vae_conv_scale = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "chroma_use_dit_mask")) chroma_use_dit_mask = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
if (!strcmp(optname, "chroma_use_t5_mask")) chroma_use_t5_mask = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
|
||||
|
||||
if (!strcmp(optname, "n_threads")) n_threads = atoi(optval);
|
||||
if (!strcmp(optname, "chroma_t5_mask_pad")) chroma_t5_mask_pad = atoi(optval);
|
||||
|
||||
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
|
||||
|
||||
if (!strcmp(optname, "rng_type")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
||||
if (!strcmp(optval, rng_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
rng_type = (rng_type_t)found;
|
||||
fprintf(stderr, "Found rng_type: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "sampler_rng_type")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
||||
if (!strcmp(optval, rng_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
sampler_rng_type = (rng_type_t)found;
|
||||
fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "prediction")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < PREDICTION_COUNT; m++) {
|
||||
if (!strcmp(optval, prediction_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
prediction = (prediction_t)found;
|
||||
fprintf(stderr, "Found prediction: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "lora_apply_mode")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
|
||||
if (!strcmp(optval, lora_apply_mode_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
lora_apply_mode = (lora_apply_mode_t)found;
|
||||
fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "wtype")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < SD_TYPE_COUNT; m++) {
|
||||
if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
wtype = (sd_type_t)found;
|
||||
fprintf(stderr, "Found wtype: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "parsed options\n");
|
||||
@@ -369,40 +166,17 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.model_path = model;
|
||||
ctx_params.clip_l_path = clip_l_path;
|
||||
ctx_params.clip_g_path = clip_g_path;
|
||||
ctx_params.clip_vision_path = clip_vision_path;
|
||||
ctx_params.t5xxl_path = t5xxl_path;
|
||||
ctx_params.llm_path = llm_path;
|
||||
ctx_params.llm_vision_path = llm_vision_path;
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.diffusion_model_path = stableDiffusionModel;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.taesd_path = taesd_path;
|
||||
ctx_params.control_net_path = control_net_path;
|
||||
ctx_params.taesd_path = "";
|
||||
ctx_params.control_net_path = "";
|
||||
ctx_params.lora_model_dir = lora_dir;
|
||||
ctx_params.embedding_dir = embedding_dir;
|
||||
ctx_params.photo_maker_path = photo_maker_path;
|
||||
ctx_params.tensor_type_rules = tensor_type_rules;
|
||||
ctx_params.vae_decode_only = vae_decode_only;
|
||||
ctx_params.free_params_immediately = free_params_immediately;
|
||||
ctx_params.n_threads = n_threads;
|
||||
ctx_params.rng_type = rng_type;
|
||||
ctx_params.keep_clip_on_cpu = keep_clip_on_cpu;
|
||||
if (wtype != SD_TYPE_COUNT) ctx_params.wtype = wtype;
|
||||
if (sampler_rng_type != RNG_TYPE_COUNT) ctx_params.sampler_rng_type = sampler_rng_type;
|
||||
if (prediction != PREDICTION_COUNT) ctx_params.prediction = prediction;
|
||||
if (lora_apply_mode != LORA_APPLY_MODE_COUNT) ctx_params.lora_apply_mode = lora_apply_mode;
|
||||
ctx_params.offload_params_to_cpu = offload_params_to_cpu;
|
||||
ctx_params.keep_control_net_on_cpu = keep_control_net_on_cpu;
|
||||
ctx_params.keep_vae_on_cpu = keep_vae_on_cpu;
|
||||
ctx_params.diffusion_flash_attn = diffusion_flash_attn;
|
||||
ctx_params.tae_preview_only = tae_preview_only;
|
||||
ctx_params.diffusion_conv_direct = diffusion_conv_direct;
|
||||
ctx_params.vae_conv_direct = vae_conv_direct;
|
||||
ctx_params.force_sdxl_vae_conv_scale = force_sdxl_vae_conv_scale;
|
||||
ctx_params.chroma_use_dit_mask = chroma_use_dit_mask;
|
||||
ctx_params.chroma_use_t5_mask = chroma_use_t5_mask;
|
||||
ctx_params.chroma_t5_mask_pad = chroma_t5_mask_pad;
|
||||
ctx_params.flow_shift = flow_shift;
|
||||
ctx_params.embedding_dir = "";
|
||||
ctx_params.vae_decode_only = false;
|
||||
ctx_params.free_params_immediately = false;
|
||||
ctx_params.n_threads = threads;
|
||||
ctx_params.rng_type = STD_DEFAULT_RNG;
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
|
||||
|
||||
if (sd_ctx == NULL) {
|
||||
|
||||
Reference in New Issue
Block a user