feat(stablediffusion-ggml): add support to ref images (flux Kontext) (#5935)
* feat(stablediffusion-ggml): add support to ref images Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add it to the model gallery Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
4438b4361e
commit
3d22bfc27c
@@ -198,7 +198,7 @@ int load_model(char *model, char* options[], int threads, int diff) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) {
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
|
||||
|
||||
sd_image_t* results;
|
||||
|
||||
@@ -221,15 +221,187 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
||||
p.seed = seed;
|
||||
p.input_id_images_path = "";
|
||||
|
||||
// Handle input image for img2img
|
||||
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
|
||||
bool has_mask_image = (mask_image != NULL && strlen(mask_image) > 0);
|
||||
|
||||
uint8_t* input_image_buffer = NULL;
|
||||
uint8_t* mask_image_buffer = NULL;
|
||||
std::vector<uint8_t> default_mask_image_vec;
|
||||
|
||||
if (has_input_image) {
|
||||
fprintf(stderr, "Loading input image: %s\n", src_image);
|
||||
|
||||
int c = 0;
|
||||
int img_width = 0;
|
||||
int img_height = 0;
|
||||
input_image_buffer = stbi_load(src_image, &img_width, &img_height, &c, 3);
|
||||
if (input_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to load input image from '%s'\n", src_image);
|
||||
return 1;
|
||||
}
|
||||
if (c < 3) {
|
||||
fprintf(stderr, "Input image must have at least 3 channels, got %d\n", c);
|
||||
free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Resize input image if dimensions don't match
|
||||
if (img_width != width || img_height != height) {
|
||||
fprintf(stderr, "Resizing input image from %dx%d to %dx%d\n", img_width, img_height, width, height);
|
||||
|
||||
uint8_t* resized_image_buffer = (uint8_t*)malloc(height * width * 3);
|
||||
if (resized_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to allocate memory for resized image\n");
|
||||
free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
stbir_resize(input_image_buffer, img_width, img_height, 0,
|
||||
resized_image_buffer, width, height, 0, STBIR_TYPE_UINT8,
|
||||
3, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
|
||||
free(input_image_buffer);
|
||||
input_image_buffer = resized_image_buffer;
|
||||
}
|
||||
|
||||
p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
|
||||
p.strength = strength;
|
||||
fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
|
||||
} else {
|
||||
// No input image, use empty image for text-to-image
|
||||
p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
|
||||
p.strength = 0.0f;
|
||||
}
|
||||
|
||||
// Handle mask image for inpainting
|
||||
if (has_mask_image) {
|
||||
fprintf(stderr, "Loading mask image: %s\n", mask_image);
|
||||
|
||||
int c = 0;
|
||||
int mask_width = 0;
|
||||
int mask_height = 0;
|
||||
mask_image_buffer = stbi_load(mask_image, &mask_width, &mask_height, &c, 1);
|
||||
if (mask_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to load mask image from '%s'\n", mask_image);
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Resize mask if dimensions don't match
|
||||
if (mask_width != width || mask_height != height) {
|
||||
fprintf(stderr, "Resizing mask image from %dx%d to %dx%d\n", mask_width, mask_height, width, height);
|
||||
|
||||
uint8_t* resized_mask_buffer = (uint8_t*)malloc(height * width);
|
||||
if (resized_mask_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to allocate memory for resized mask\n");
|
||||
free(mask_image_buffer);
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
stbir_resize(mask_image_buffer, mask_width, mask_height, 0,
|
||||
resized_mask_buffer, width, height, 0, STBIR_TYPE_UINT8,
|
||||
1, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
|
||||
free(mask_image_buffer);
|
||||
mask_image_buffer = resized_mask_buffer;
|
||||
}
|
||||
|
||||
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
|
||||
fprintf(stderr, "Using inpainting with mask\n");
|
||||
} else {
|
||||
// No mask image, create default full mask
|
||||
default_mask_image_vec.resize(width * height, 255);
|
||||
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
|
||||
}
|
||||
|
||||
// Handle reference images
|
||||
std::vector<sd_image_t> ref_images_vec;
|
||||
std::vector<uint8_t*> ref_image_buffers;
|
||||
|
||||
if (ref_images_count > 0 && ref_images != NULL) {
|
||||
fprintf(stderr, "Loading %d reference images\n", ref_images_count);
|
||||
|
||||
for (int i = 0; i < ref_images_count; i++) {
|
||||
if (ref_images[i] == NULL || strlen(ref_images[i]) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Loading reference image %d: %s\n", i + 1, ref_images[i]);
|
||||
|
||||
int c = 0;
|
||||
int ref_width = 0;
|
||||
int ref_height = 0;
|
||||
uint8_t* ref_image_buffer = stbi_load(ref_images[i], &ref_width, &ref_height, &c, 3);
|
||||
if (ref_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to load reference image from '%s'\n", ref_images[i]);
|
||||
continue;
|
||||
}
|
||||
if (c < 3) {
|
||||
fprintf(stderr, "Reference image must have at least 3 channels, got %d\n", c);
|
||||
free(ref_image_buffer);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resize reference image if dimensions don't match
|
||||
if (ref_width != width || ref_height != height) {
|
||||
fprintf(stderr, "Resizing reference image from %dx%d to %dx%d\n", ref_width, ref_height, width, height);
|
||||
|
||||
uint8_t* resized_ref_buffer = (uint8_t*)malloc(height * width * 3);
|
||||
if (resized_ref_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to allocate memory for resized reference image\n");
|
||||
free(ref_image_buffer);
|
||||
continue;
|
||||
}
|
||||
|
||||
stbir_resize(ref_image_buffer, ref_width, ref_height, 0,
|
||||
resized_ref_buffer, width, height, 0, STBIR_TYPE_UINT8,
|
||||
3, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
|
||||
free(ref_image_buffer);
|
||||
ref_image_buffer = resized_ref_buffer;
|
||||
}
|
||||
|
||||
ref_image_buffers.push_back(ref_image_buffer);
|
||||
ref_images_vec.push_back({(uint32_t)width, (uint32_t)height, 3, ref_image_buffer});
|
||||
}
|
||||
|
||||
if (!ref_images_vec.empty()) {
|
||||
p.ref_images = ref_images_vec.data();
|
||||
p.ref_images_count = ref_images_vec.size();
|
||||
fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
|
||||
}
|
||||
}
|
||||
|
||||
results = generate_image(sd_c, &p);
|
||||
|
||||
if (results == NULL) {
|
||||
fprintf (stderr, "NO results\n");
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
if (mask_image_buffer) free(mask_image_buffer);
|
||||
for (auto buffer : ref_image_buffers) {
|
||||
if (buffer) free(buffer);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (results[0].data == NULL) {
|
||||
fprintf (stderr, "Results with no data\n");
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
if (mask_image_buffer) free(mask_image_buffer);
|
||||
for (auto buffer : ref_image_buffers) {
|
||||
if (buffer) free(buffer);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -245,11 +417,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
||||
results[0].data, 0, NULL);
|
||||
fprintf (stderr, "Saved resulting image to '%s'\n", dst);
|
||||
|
||||
// TODO: free results. Why does it crash?
|
||||
|
||||
// Clean up
|
||||
free(results[0].data);
|
||||
results[0].data = NULL;
|
||||
free(results);
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
if (mask_image_buffer) free(mask_image_buffer);
|
||||
for (auto buffer : ref_image_buffers) {
|
||||
if (buffer) free(buffer);
|
||||
}
|
||||
fprintf (stderr, "gen_image is done", dst);
|
||||
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user