Compare commits

..

No commits in common. "636d3cb6ff25d1ffa7267e5f6dac9f2925945606" and "32965450907c730ae8d205a53b9574a99906da31" have entirely different histories.

56 changed files with 468 additions and 527 deletions

View File

@ -535,30 +535,31 @@ jobs:
# Add ROCm to PATH for current session # Add ROCm to PATH for current session
echo "/opt/rocm/bin" >> $GITHUB_PATH echo "/opt/rocm/bin" >> $GITHUB_PATH
# Build regex pattern from ${{ env.GPU_TARGETS }} (match target as substring) # Build case pattern from GPU_TARGETS
TARGET_REGEX="($(printf '%s' "${{ env.GPU_TARGETS }}" | sed 's/;/|/g'))" PATTERN=$(printf '%s' "$GPU_TARGETS" | sed 's/;/\*|\*/g')
PATTERN="*${PATTERN}*"
# Remove library files for architectures we're not building for to save disk space # Remove library files for architectures we're not building for to save disk space
echo "Cleaning up unneeded architecture files..." echo "Cleaning up unneeded architecture files..."
cd /opt/rocm/lib/rocblas/library cd /opt/rocm/lib/rocblas/library
# Keep only our target architectures # Keep only our target architectures
for file in *; do for file in *; do
if printf '%s' "$file" | grep -q 'gfx'; then case "$file" in
if ! printf '%s' "$file" | grep -Eq "$TARGET_REGEX"; then $PATTERN)
echo "Removing $file" && ;;
sudo rm -f "$file"; *)
fi sudo rm -f "$file" ;;
fi esac;
done done
cd /opt/rocm/lib/hipblaslt/library cd /opt/rocm/lib/hipblaslt/library
for file in *; do for file in *; do
if printf '%s' "$file" | grep -q 'gfx'; then case "$file" in
if ! printf '%s' "$file" | grep -Eq "$TARGET_REGEX"; then $PATTERN)
echo "Removing $file" && ;;
sudo rm -f "$file"; *)
fi sudo rm -f "$file" ;;
fi esac;
done done
- name: Build - name: Build
@ -591,15 +592,21 @@ jobs:
cp ggml/LICENSE ./build/bin/ggml.txt cp ggml/LICENSE ./build/bin/ggml.txt
cp LICENSE ./build/bin/stable-diffusion.cpp.txt cp LICENSE ./build/bin/stable-diffusion.cpp.txt
# Move ROCm runtime libraries (to avoid double space consumption) # Create directories for ROCm libraries
sudo mv /opt/rocm/lib/librocsparse.so* ./build/bin/ mkdir -p ./build/bin/rocblas/library
sudo mv /opt/rocm/lib/libhsa-runtime64.so* ./build/bin/ mkdir -p ./build/bin/hipblaslt/library
sudo mv /opt/rocm/lib/libamdhip64.so* ./build/bin/
sudo mv /opt/rocm/lib/libhipblas.so* ./build/bin/ # Copy ROCm runtime libraries (use || true to continue if files don't exist)
sudo mv /opt/rocm/lib/libhipblaslt.so* ./build/bin/ cp /opt/rocm/lib/librocsparse.so* ./build/bin/ || true
sudo mv /opt/rocm/lib/librocblas.so* ./build/bin/ cp /opt/rocm/lib/libhsa-runtime64.so* ./build/bin/ || true
sudo mv /opt/rocm/lib/rocblas/ ./build/bin/ cp /opt/rocm/lib/libamdhip64.so* ./build/bin/ || true
sudo mv /opt/rocm/lib/hipblaslt/ ./build/bin/ cp /opt/rocm/lib/libhipblas.so* ./build/bin/ || true
cp /opt/rocm/lib/libhipblaslt.so* ./build/bin/ || true
cp /opt/rocm/lib/librocblas.so* ./build/bin/ || true
# Copy library files (already filtered to target architectures)
cp /opt/rocm/lib/rocblas/library/* ./build/bin/rocblas/library/ || true
cp /opt/rocm/lib/hipblaslt/library/* ./build/bin/hipblaslt/library/ || true
- name: Fetch system info - name: Fetch system info
id: system-info id: system-info
@ -615,7 +622,7 @@ jobs:
run: | run: |
cp ggml/LICENSE ./build/bin/ggml.txt cp ggml/LICENSE ./build/bin/ggml.txt
cp LICENSE ./build/bin/stable-diffusion.cpp.txt cp LICENSE ./build/bin/stable-diffusion.cpp.txt
zip -y -r sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip ./build/bin zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip ./build/bin/*
- name: Upload artifacts - name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}

View File

@ -87,11 +87,9 @@ endif()
set(SD_LIB stable-diffusion) set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES file(GLOB SD_LIB_SOURCES
"src/*.h" "*.h"
"src/*.cpp" "*.cpp"
"src/*.hpp" "*.hpp"
"src/vocab/*.h"
"src/vocab/*.cpp"
) )
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@ -121,7 +119,7 @@ endif()
message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}") message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}")
set_property( set_property(
SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/src/version.cpp SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
APPEND PROPERTY COMPILE_DEFINITIONS APPEND PROPERTY COMPILE_DEFINITIONS
SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION} SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION}
) )
@ -184,7 +182,6 @@ endif()
add_subdirectory(thirdparty) add_subdirectory(thirdparty)
target_link_libraries(${SD_LIB} PUBLIC ggml zip) target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . include)
target_include_directories(${SD_LIB} PUBLIC . thirdparty) target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)
@ -193,7 +190,7 @@ if (SD_BUILD_EXAMPLES)
add_subdirectory(examples) add_subdirectory(examples)
endif() endif()
set(SD_PUBLIC_HEADERS include/stable-diffusion.h) set(SD_PUBLIC_HEADERS stable-diffusion.h)
set_target_properties(${SD_LIB} PROPERTIES PUBLIC_HEADER "${SD_PUBLIC_HEADERS}") set_target_properties(${SD_LIB} PROPERTIES PUBLIC_HEADER "${SD_PUBLIC_HEADERS}")
install(TARGETS ${SD_LIB} LIBRARY PUBLIC_HEADER) install(TARGETS ${SD_LIB} LIBRARY PUBLIC_HEADER)

View File

@ -4,7 +4,6 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "model.h" #include "model.h"
#include "tokenize_util.h" #include "tokenize_util.h"
#include "vocab/vocab.h"
/*================================================== CLIPTokenizer ===================================================*/ /*================================================== CLIPTokenizer ===================================================*/
@ -111,7 +110,7 @@ public:
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);
} else { } else {
load_from_merges(load_clip_merges()); load_from_merges(ModelLoader::load_merges());
} }
add_special_token("<|startoftext|>"); add_special_token("<|startoftext|>");
add_special_token("<|endoftext|>"); add_special_token("<|endoftext|>");

View File

@ -404,8 +404,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", ""); std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png"); std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100); int output_compression = j.value("output_compression", 100);
int width = default_gen_params.width > 0 ? default_gen_params.width : 512; int width = 512;
int height = default_gen_params.width > 0 ? default_gen_params.height : 512; int height = 512;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -593,7 +593,7 @@ int main(int argc, const char** argv) {
n = std::clamp(n, 1, 8); n = std::clamp(n, 1, 8);
std::string size = req.form.get_field("size"); std::string size = req.form.get_field("size");
int width = -1, height = -1; int width = 512, height = 512;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -650,31 +650,15 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {0, 0, 3, nullptr}; sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr}; sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size()); ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) { for (auto& bytes : images_bytes) {
int img_w; int img_w = width;
int img_h; int img_h = height;
uint8_t* raw_pixels = load_image_from_memory( uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(bytes.data()), reinterpret_cast<const char*>(bytes.data()),
static_cast<int>(bytes.size()), static_cast<int>(bytes.size()),
@ -686,31 +670,22 @@ int main(int argc, const char** argv) {
} }
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}; sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
gen_params.set_width_and_height_if_unset(img.width, img.height);
ref_images.push_back(img); ref_images.push_back(img);
} }
sd_image_t mask_image = {0}; sd_image_t mask_image = {0};
if (!mask_bytes.empty()) { if (!mask_bytes.empty()) {
int expected_width = 0; int mask_w = width;
int expected_height = 0; int mask_h = height;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int mask_w;
int mask_h;
uint8_t* mask_raw = load_image_from_memory( uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()), reinterpret_cast<const char*>(mask_bytes.data()),
static_cast<int>(mask_bytes.size()), static_cast<int>(mask_bytes.size()),
mask_w, mask_h, mask_w, mask_h,
expected_width, expected_height, 1); width, height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}; mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
} else { } else {
mask_image.width = get_resolved_width(); mask_image.width = width;
mask_image.height = get_resolved_height(); mask_image.height = height;
mask_image.channel = 1; mask_image.channel = 1;
mask_image.data = nullptr; mask_image.data = nullptr;
} }
@ -727,8 +702,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
get_resolved_width(), gen_params.width,
get_resolved_height(), gen_params.height,
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
@ -911,6 +886,8 @@ int main(int argc, const char** argv) {
SDGenerationParams gen_params = default_gen_params; SDGenerationParams gen_params = default_gen_params;
gen_params.prompt = prompt; gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt; gen_params.negative_prompt = negative_prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.seed = seed; gen_params.seed = seed;
gen_params.sample_params.sample_steps = steps; gen_params.sample_params.sample_steps = steps;
gen_params.batch_count = batch_size; gen_params.batch_count = batch_size;
@ -928,66 +905,38 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler; gen_params.sample_params.scheduler = scheduler;
} }
// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1);
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {0, 0, 3, nullptr}; sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr}; sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr}; sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
std::vector<uint8_t> mask_data; std::vector<uint8_t> mask_data;
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int expected_width = 0;
int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int img_w;
int img_h;
uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
expected_width, expected_height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
}
}
return false;
};
if (img2img) { if (img2img) {
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int img_w = image.width;
int img_h = image.height;
uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
image.width, image.height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
return true;
}
}
return false;
};
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
std::string encoded = j["init_images"][0].get<std::string>(); std::string encoded = j["init_images"][0].get<std::string>();
decode_image(init_image, encoded); decode_image(init_image, encoded);
@ -1003,15 +952,23 @@ int main(int argc, const char** argv) {
} }
} }
} else { } else {
int m_width = get_resolved_width(); mask_data = std::vector<uint8_t>(width * height, 255);
int m_height = get_resolved_height(); mask_image.width = width;
mask_data = std::vector<uint8_t>(m_width * m_height, 255); mask_image.height = height;
mask_image.width = m_width;
mask_image.height = m_height;
mask_image.channel = 1; mask_image.channel = 1;
mask_image.data = mask_data.data(); mask_image.data = mask_data.data();
} }
if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}
float denoising_strength = j.value("denoising_strength", -1.f); float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) { if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f); denoising_strength = std::min(denoising_strength, 1.0f);
@ -1019,16 +976,6 @@ int main(int argc, const char** argv) {
} }
} }
if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
sd_loras.data(), sd_loras.data(),
static_cast<uint32_t>(sd_loras.size()), static_cast<uint32_t>(sd_loras.size()),
@ -1041,8 +988,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
get_resolved_width(), gen_params.width,
get_resolved_height(), gen_params.height,
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,

View File

@ -1,88 +1,88 @@
import os import os
import sys import sys
import numpy as np import numpy as np
import torch import torch
from diffusers.utils import load_image from diffusers.utils import load_image
# pip install insightface==0.7.3 # pip install insightface==0.7.3
from insightface.app import FaceAnalysis from insightface.app import FaceAnalysis
from insightface.data import get_image as ins_get_image from insightface.data import get_image as ins_get_image
from safetensors.torch import save_file from safetensors.torch import save_file
### ###
# https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543 # https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543
### ###
class FaceAnalysis2(FaceAnalysis): class FaceAnalysis2(FaceAnalysis):
# NOTE: allows setting det_size for each detection call. # NOTE: allows setting det_size for each detection call.
# the model allows it but the wrapping code from insightface # the model allows it but the wrapping code from insightface
# doesn't show it, and people end up loading duplicate models # doesn't show it, and people end up loading duplicate models
# for different sizes where there is absolutely no need to # for different sizes where there is absolutely no need to
def get(self, img, max_num=0, det_size=(640, 640)): def get(self, img, max_num=0, det_size=(640, 640)):
if det_size is not None: if det_size is not None:
self.det_model.input_size = det_size self.det_model.input_size = det_size
return super().get(img, max_num) return super().get(img, max_num)
def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)): def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)):
# NOTE: try detect faces, if no faces detected, lower det_size until it does # NOTE: try detect faces, if no faces detected, lower det_size until it does
detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)] detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)]
for size in detection_sizes: for size in detection_sizes:
faces = face_analysis.get(img_data, det_size=size) faces = face_analysis.get(img_data, det_size=size)
if len(faces) > 0: if len(faces) > 0:
return faces return faces
return [] return []
if __name__ == "__main__": if __name__ == "__main__":
#face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition']) #face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition']) face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition'])
face_detector.prepare(ctx_id=0, det_size=(640, 640)) face_detector.prepare(ctx_id=0, det_size=(640, 640))
#input_folder_name = './scarletthead_woman' #input_folder_name = './scarletthead_woman'
input_folder_name = sys.argv[1] input_folder_name = sys.argv[1]
image_basename_list = os.listdir(input_folder_name) image_basename_list = os.listdir(input_folder_name)
image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list]) image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list])
input_id_images = [] input_id_images = []
for image_path in image_path_list: for image_path in image_path_list:
input_id_images.append(load_image(image_path)) input_id_images.append(load_image(image_path))
id_embed_list = [] id_embed_list = []
for img in input_id_images: for img in input_id_images:
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
faces = analyze_faces(face_detector, img) faces = analyze_faces(face_detector, img)
if len(faces) > 0: if len(faces) > 0:
id_embed_list.append(torch.from_numpy((faces[0]['embedding']))) id_embed_list.append(torch.from_numpy((faces[0]['embedding'])))
if len(id_embed_list) == 0: if len(id_embed_list) == 0:
raise ValueError(f"No face detected in input image pool") raise ValueError(f"No face detected in input image pool")
id_embeds = torch.stack(id_embed_list) id_embeds = torch.stack(id_embed_list)
# for r in id_embeds: # for r in id_embeds:
# print(r) # print(r)
# #torch.save(id_embeds, input_folder_name+'/id_embeds.pt'); # #torch.save(id_embeds, input_folder_name+'/id_embeds.pt');
# weights = dict() # weights = dict()
# weights["id_embeds"] = id_embeds # weights["id_embeds"] = id_embeds
# save_file(weights, input_folder_name+'/id_embeds.safetensors') # save_file(weights, input_folder_name+'/id_embeds.safetensors')
binary_data = id_embeds.numpy().tobytes() binary_data = id_embeds.numpy().tobytes()
two = 4 two = 4
zero = 0 zero = 0
one = 1 one = 1
tensor_name = "id_embeds" tensor_name = "id_embeds"
# Write binary data to a file # Write binary data to a file
with open(input_folder_name+'/id_embeds.bin', "wb") as f: with open(input_folder_name+'/id_embeds.bin', "wb") as f:
f.write(two.to_bytes(4, byteorder='little')) f.write(two.to_bytes(4, byteorder='little'))
f.write((len(tensor_name)).to_bytes(4, byteorder='little')) f.write((len(tensor_name)).to_bytes(4, byteorder='little'))
f.write(zero.to_bytes(4, byteorder='little')) f.write(zero.to_bytes(4, byteorder='little'))
f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little')) f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little'))
f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little')) f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little'))
f.write(one.to_bytes(4, byteorder='little')) f.write(one.to_bytes(4, byteorder='little'))
f.write(one.to_bytes(4, byteorder='little')) f.write(one.to_bytes(4, byteorder='little'))
f.write(tensor_name.encode('ascii')) f.write(tensor_name.encode('ascii'))
f.write(binary_data) f.write(binary_data)

View File

@ -1,4 +1,4 @@
for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do for f in *.cpp *.h *.hpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do
[[ "$f" == vocab* ]] && continue [[ "$f" == vocab* ]] && continue
echo "formatting '$f'" echo "formatting '$f'"
# if [ "$f" != "stable-diffusion.h" ]; then # if [ "$f" != "stable-diffusion.h" ]; then

View File

@ -1,234 +1,234 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include "ggml.h" #include "ggml.h"
const float wan_21_latent_rgb_proj[16][3] = { const float wan_21_latent_rgb_proj[16][3] = {
{0.015123f, -0.148418f, 0.479828f}, {0.015123f, -0.148418f, 0.479828f},
{0.003652f, -0.010680f, -0.037142f}, {0.003652f, -0.010680f, -0.037142f},
{0.212264f, 0.063033f, 0.016779f}, {0.212264f, 0.063033f, 0.016779f},
{0.232999f, 0.406476f, 0.220125f}, {0.232999f, 0.406476f, 0.220125f},
{-0.051864f, -0.082384f, -0.069396f}, {-0.051864f, -0.082384f, -0.069396f},
{0.085005f, -0.161492f, 0.010689f}, {0.085005f, -0.161492f, 0.010689f},
{-0.245369f, -0.506846f, -0.117010f}, {-0.245369f, -0.506846f, -0.117010f},
{-0.151145f, 0.017721f, 0.007207f}, {-0.151145f, 0.017721f, 0.007207f},
{-0.293239f, -0.207936f, -0.421135f}, {-0.293239f, -0.207936f, -0.421135f},
{-0.187721f, 0.050783f, 0.177649f}, {-0.187721f, 0.050783f, 0.177649f},
{-0.013067f, 0.265964f, 0.166578f}, {-0.013067f, 0.265964f, 0.166578f},
{0.028327f, 0.109329f, 0.108642f}, {0.028327f, 0.109329f, 0.108642f},
{-0.205343f, 0.043991f, 0.148914f}, {-0.205343f, 0.043991f, 0.148914f},
{0.014307f, -0.048647f, -0.007219f}, {0.014307f, -0.048647f, -0.007219f},
{0.217150f, 0.053074f, 0.319923f}, {0.217150f, 0.053074f, 0.319923f},
{0.155357f, 0.083156f, 0.064780f}}; {0.155357f, 0.083156f, 0.064780f}};
float wan_21_latent_rgb_bias[3] = {-0.270270f, -0.234976f, -0.456853f}; float wan_21_latent_rgb_bias[3] = {-0.270270f, -0.234976f, -0.456853f};
const float wan_22_latent_rgb_proj[48][3] = { const float wan_22_latent_rgb_proj[48][3] = {
{0.017126f, -0.027230f, -0.019257f}, {0.017126f, -0.027230f, -0.019257f},
{-0.113739f, -0.028715f, -0.022885f}, {-0.113739f, -0.028715f, -0.022885f},
{-0.000106f, 0.021494f, 0.004629f}, {-0.000106f, 0.021494f, 0.004629f},
{-0.013273f, -0.107137f, -0.033638f}, {-0.013273f, -0.107137f, -0.033638f},
{-0.000381f, 0.000279f, 0.025877f}, {-0.000381f, 0.000279f, 0.025877f},
{-0.014216f, -0.003975f, 0.040528f}, {-0.014216f, -0.003975f, 0.040528f},
{0.001638f, -0.000748f, 0.011022f}, {0.001638f, -0.000748f, 0.011022f},
{0.029238f, -0.006697f, 0.035933f}, {0.029238f, -0.006697f, 0.035933f},
{0.021641f, -0.015874f, 0.040531f}, {0.021641f, -0.015874f, 0.040531f},
{-0.101984f, -0.070160f, -0.028855f}, {-0.101984f, -0.070160f, -0.028855f},
{0.033207f, -0.021068f, 0.002663f}, {0.033207f, -0.021068f, 0.002663f},
{-0.104711f, 0.121673f, 0.102981f}, {-0.104711f, 0.121673f, 0.102981f},
{0.082647f, -0.004991f, 0.057237f}, {0.082647f, -0.004991f, 0.057237f},
{-0.027375f, 0.031581f, 0.006868f}, {-0.027375f, 0.031581f, 0.006868f},
{-0.045434f, 0.029444f, 0.019287f}, {-0.045434f, 0.029444f, 0.019287f},
{-0.046572f, -0.012537f, 0.006675f}, {-0.046572f, -0.012537f, 0.006675f},
{0.074709f, 0.033690f, 0.025289f}, {0.074709f, 0.033690f, 0.025289f},
{-0.008251f, -0.002745f, -0.006999f}, {-0.008251f, -0.002745f, -0.006999f},
{0.012685f, -0.061856f, -0.048658f}, {0.012685f, -0.061856f, -0.048658f},
{0.042304f, -0.007039f, 0.000295f}, {0.042304f, -0.007039f, 0.000295f},
{-0.007644f, -0.060843f, -0.033142f}, {-0.007644f, -0.060843f, -0.033142f},
{0.159909f, 0.045628f, 0.367541f}, {0.159909f, 0.045628f, 0.367541f},
{0.095171f, 0.086438f, 0.010271f}, {0.095171f, 0.086438f, 0.010271f},
{0.006812f, 0.019643f, 0.029637f}, {0.006812f, 0.019643f, 0.029637f},
{0.003467f, -0.010705f, 0.014252f}, {0.003467f, -0.010705f, 0.014252f},
{-0.099681f, -0.066272f, -0.006243f}, {-0.099681f, -0.066272f, -0.006243f},
{0.047357f, 0.037040f, 0.000185f}, {0.047357f, 0.037040f, 0.000185f},
{-0.041797f, -0.089225f, -0.032257f}, {-0.041797f, -0.089225f, -0.032257f},
{0.008928f, 0.017028f, 0.018684f}, {0.008928f, 0.017028f, 0.018684f},
{-0.042255f, 0.016045f, 0.006849f}, {-0.042255f, 0.016045f, 0.006849f},
{0.011268f, 0.036462f, 0.037387f}, {0.011268f, 0.036462f, 0.037387f},
{0.011553f, -0.016375f, -0.048589f}, {0.011553f, -0.016375f, -0.048589f},
{0.046266f, -0.027189f, 0.056979f}, {0.046266f, -0.027189f, 0.056979f},
{0.009640f, -0.017576f, 0.030324f}, {0.009640f, -0.017576f, 0.030324f},
{-0.045794f, -0.036083f, -0.010616f}, {-0.045794f, -0.036083f, -0.010616f},
{0.022418f, 0.039783f, -0.032939f}, {0.022418f, 0.039783f, -0.032939f},
{-0.052714f, -0.015525f, 0.007438f}, {-0.052714f, -0.015525f, 0.007438f},
{0.193004f, 0.223541f, 0.264175f}, {0.193004f, 0.223541f, 0.264175f},
{-0.059406f, -0.008188f, 0.022867f}, {-0.059406f, -0.008188f, 0.022867f},
{-0.156742f, -0.263791f, -0.007385f}, {-0.156742f, -0.263791f, -0.007385f},
{-0.015717f, 0.016570f, 0.033969f}, {-0.015717f, 0.016570f, 0.033969f},
{0.037969f, 0.109835f, 0.200449f}, {0.037969f, 0.109835f, 0.200449f},
{-0.000782f, -0.009566f, -0.008058f}, {-0.000782f, -0.009566f, -0.008058f},
{0.010709f, 0.052960f, -0.044195f}, {0.010709f, 0.052960f, -0.044195f},
{0.017271f, 0.045839f, 0.034569f}, {0.017271f, 0.045839f, 0.034569f},
{0.009424f, 0.013088f, -0.001714f}, {0.009424f, 0.013088f, -0.001714f},
{-0.024805f, -0.059378f, -0.033756f}, {-0.024805f, -0.059378f, -0.033756f},
{-0.078293f, 0.029070f, 0.026129f}}; {-0.078293f, 0.029070f, 0.026129f}};
float wan_22_latent_rgb_bias[3] = {0.013160f, -0.096492f, -0.071323f}; float wan_22_latent_rgb_bias[3] = {0.013160f, -0.096492f, -0.071323f};
const float flux_latent_rgb_proj[16][3] = { const float flux_latent_rgb_proj[16][3] = {
{-0.041168f, 0.019917f, 0.097253f}, {-0.041168f, 0.019917f, 0.097253f},
{0.028096f, 0.026730f, 0.129576f}, {0.028096f, 0.026730f, 0.129576f},
{0.065618f, -0.067950f, -0.014651f}, {0.065618f, -0.067950f, -0.014651f},
{-0.012998f, -0.014762f, 0.081251f}, {-0.012998f, -0.014762f, 0.081251f},
{0.078567f, 0.059296f, -0.024687f}, {0.078567f, 0.059296f, -0.024687f},
{-0.015987f, -0.003697f, 0.005012f}, {-0.015987f, -0.003697f, 0.005012f},
{0.033605f, 0.138999f, 0.068517f}, {0.033605f, 0.138999f, 0.068517f},
{-0.024450f, -0.063567f, -0.030101f}, {-0.024450f, -0.063567f, -0.030101f},
{-0.040194f, -0.016710f, 0.127185f}, {-0.040194f, -0.016710f, 0.127185f},
{0.112681f, 0.088764f, -0.041940f}, {0.112681f, 0.088764f, -0.041940f},
{-0.023498f, 0.093664f, 0.025543f}, {-0.023498f, 0.093664f, 0.025543f},
{0.082899f, 0.048320f, 0.007491f}, {0.082899f, 0.048320f, 0.007491f},
{0.075712f, 0.074139f, 0.081965f}, {0.075712f, 0.074139f, 0.081965f},
{-0.143501f, 0.018263f, -0.136138f}, {-0.143501f, 0.018263f, -0.136138f},
{-0.025767f, -0.082035f, -0.040023f}, {-0.025767f, -0.082035f, -0.040023f},
{-0.111849f, -0.055589f, -0.032361f}}; {-0.111849f, -0.055589f, -0.032361f}};
float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f}; float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f};
const float flux2_latent_rgb_proj[32][3] = { const float flux2_latent_rgb_proj[32][3] = {
{0.000736f, -0.008385f, -0.019710f}, {0.000736f, -0.008385f, -0.019710f},
{-0.001352f, -0.016392f, 0.020693f}, {-0.001352f, -0.016392f, 0.020693f},
{-0.006376f, 0.002428f, 0.036736f}, {-0.006376f, 0.002428f, 0.036736f},
{0.039384f, 0.074167f, 0.119789f}, {0.039384f, 0.074167f, 0.119789f},
{0.007464f, -0.005705f, -0.004734f}, {0.007464f, -0.005705f, -0.004734f},
{-0.004086f, 0.005287f, -0.000409f}, {-0.004086f, 0.005287f, -0.000409f},
{-0.032835f, 0.050802f, -0.028120f}, {-0.032835f, 0.050802f, -0.028120f},
{-0.003158f, -0.000835f, 0.000406f}, {-0.003158f, -0.000835f, 0.000406f},
{-0.112840f, -0.084337f, -0.023083f}, {-0.112840f, -0.084337f, -0.023083f},
{0.001462f, -0.006656f, 0.000549f}, {0.001462f, -0.006656f, 0.000549f},
{-0.009980f, -0.007480f, 0.009702f}, {-0.009980f, -0.007480f, 0.009702f},
{0.032540f, 0.000214f, -0.061388f}, {0.032540f, 0.000214f, -0.061388f},
{0.011023f, 0.000694f, 0.007143f}, {0.011023f, 0.000694f, 0.007143f},
{-0.001468f, -0.006723f, -0.001678f}, {-0.001468f, -0.006723f, -0.001678f},
{-0.005921f, -0.010320f, -0.003907f}, {-0.005921f, -0.010320f, -0.003907f},
{-0.028434f, 0.027584f, 0.018457f}, {-0.028434f, 0.027584f, 0.018457f},
{0.014349f, 0.011523f, 0.000441f}, {0.014349f, 0.011523f, 0.000441f},
{0.009874f, 0.003081f, 0.001507f}, {0.009874f, 0.003081f, 0.001507f},
{0.002218f, 0.005712f, 0.001563f}, {0.002218f, 0.005712f, 0.001563f},
{0.053010f, -0.019844f, 0.008683f}, {0.053010f, -0.019844f, 0.008683f},
{-0.002507f, 0.005384f, 0.000938f}, {-0.002507f, 0.005384f, 0.000938f},
{-0.002177f, -0.011366f, 0.003559f}, {-0.002177f, -0.011366f, 0.003559f},
{-0.000261f, 0.015121f, -0.003240f}, {-0.000261f, 0.015121f, -0.003240f},
{-0.003944f, -0.002083f, 0.005043f}, {-0.003944f, -0.002083f, 0.005043f},
{-0.009138f, 0.011336f, 0.003781f}, {-0.009138f, 0.011336f, 0.003781f},
{0.011429f, 0.003985f, -0.003855f}, {0.011429f, 0.003985f, -0.003855f},
{0.010518f, -0.005586f, 0.010131f}, {0.010518f, -0.005586f, 0.010131f},
{0.007883f, 0.002912f, -0.001473f}, {0.007883f, 0.002912f, -0.001473f},
{-0.003318f, -0.003160f, 0.003684f}, {-0.003318f, -0.003160f, 0.003684f},
{-0.034560f, -0.008740f, 0.012996f}, {-0.034560f, -0.008740f, 0.012996f},
{0.000166f, 0.001079f, -0.012153f}, {0.000166f, 0.001079f, -0.012153f},
{0.017772f, 0.000937f, -0.011953f}}; {0.017772f, 0.000937f, -0.011953f}};
// Bias applied to the RGB result after projecting Flux2 latent channels.
float flux2_latent_rgb_bias[3] = {-0.028738f, -0.098463f, -0.107619f};
// This one was taken straight from
// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303
// (MiT Licence)
// One 3-float RGB projection row per channel of the 16-channel SD3 latent space.
const float sd3_latent_rgb_proj[16][3] = {
    {-0.0645f, 0.0177f, 0.1052f},
    {0.0028f, 0.0312f, 0.0650f},
    {0.1848f, 0.0762f, 0.0360f},
    {0.0944f, 0.0360f, 0.0889f},
    {0.0897f, 0.0506f, -0.0364f},
    {-0.0020f, 0.1203f, 0.0284f},
    {0.0855f, 0.0118f, 0.0283f},
    {-0.0539f, 0.0658f, 0.1047f},
    {-0.0057f, 0.0116f, 0.0700f},
    {-0.0412f, 0.0281f, -0.0039f},
    {0.1106f, 0.1171f, 0.1220f},
    {-0.0248f, 0.0682f, -0.0481f},
    {0.0815f, 0.0846f, 0.1207f},
    {-0.0120f, -0.0055f, -0.0867f},
    {-0.0749f, -0.0634f, -0.0456f},
    {-0.1418f, -0.1457f, -0.1259f},
};
// SD3 previews use no post-projection bias.
float sd3_latent_rgb_bias[3] = {0, 0, 0};
// One 3-float RGB projection row per channel of the 4-channel SDXL latent space.
const float sdxl_latent_rgb_proj[4][3] = {
    {0.258303f, 0.277640f, 0.329699f},
    {-0.299701f, 0.105446f, 0.014194f},
    {0.050522f, 0.186163f, -0.143257f},
    {-0.211938f, -0.149892f, -0.080036f}};
// Bias applied to the RGB result after the SDXL projection.
float sdxl_latent_rgb_bias[3] = {0.144381f, -0.033313f, 0.007061f};
// One 3-float RGB projection row per channel of the 4-channel SD1/SD2 latent space.
const float sd_latent_rgb_proj[4][3] = {
    {0.337366f, 0.216344f, 0.257386f},
    {0.165636f, 0.386828f, 0.046994f},
    {-0.267803f, 0.237036f, 0.223517f},
    {-0.178022f, -0.200862f, -0.678514f}};
// Bias applied to the RGB result after the SD projection.
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) { void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
size_t buffer_head = 0; size_t buffer_head = 0;
uint32_t latent_width = static_cast<uint32_t>(latents->ne[0]); uint32_t latent_width = static_cast<uint32_t>(latents->ne[0]);
uint32_t latent_height = static_cast<uint32_t>(latents->ne[1]); uint32_t latent_height = static_cast<uint32_t>(latents->ne[1]);
uint32_t dim = static_cast<uint32_t>(latents->ne[ggml_n_dims(latents) - 1]); uint32_t dim = static_cast<uint32_t>(latents->ne[ggml_n_dims(latents) - 1]);
uint32_t frames = 1; uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) { if (ggml_n_dims(latents) == 4) {
frames = static_cast<uint32_t>(latents->ne[2]); frames = static_cast<uint32_t>(latents->ne[2]);
} }
uint32_t rgb_width = latent_width * patch_size; uint32_t rgb_width = latent_width * patch_size;
uint32_t rgb_height = latent_height * patch_size; uint32_t rgb_height = latent_height * patch_size;
uint32_t unpatched_dim = dim / (patch_size * patch_size); uint32_t unpatched_dim = dim / (patch_size * patch_size);
for (uint32_t k = 0; k < frames; k++) { for (uint32_t k = 0; k < frames; k++) {
for (uint32_t rgb_x = 0; rgb_x < rgb_width; rgb_x++) { for (uint32_t rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
for (uint32_t rgb_y = 0; rgb_y < rgb_height; rgb_y++) { for (uint32_t rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
int latent_x = rgb_x / patch_size; int latent_x = rgb_x / patch_size;
int latent_y = rgb_y / patch_size; int latent_y = rgb_y / patch_size;
int channel_offset = 0; int channel_offset = 0;
if (patch_size > 1) { if (patch_size > 1) {
channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size)); channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
} }
size_t latent_id = (latent_x * latents->nb[0] + latent_y * latents->nb[1] + k * latents->nb[2]); size_t latent_id = (latent_x * latents->nb[0] + latent_y * latents->nb[1] + k * latents->nb[2]);
// should be incremented by 1 for each pixel // should be incremented by 1 for each pixel
size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x; size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;
float r = 0, g = 0, b = 0; float r = 0, g = 0, b = 0;
if (latent_rgb_proj != nullptr) { if (latent_rgb_proj != nullptr) {
for (uint32_t d = 0; d < unpatched_dim; d++) { for (uint32_t d = 0; d < unpatched_dim; d++) {
float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]); float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]);
r += value * latent_rgb_proj[d][0]; r += value * latent_rgb_proj[d][0];
g += value * latent_rgb_proj[d][1]; g += value * latent_rgb_proj[d][1];
b += value * latent_rgb_proj[d][2]; b += value * latent_rgb_proj[d][2];
} }
} else { } else {
// interpret first 3 channels as RGB // interpret first 3 channels as RGB
r = *(float*)((char*)latents->data + latent_id + 0 * latents->nb[ggml_n_dims(latents) - 1]); r = *(float*)((char*)latents->data + latent_id + 0 * latents->nb[ggml_n_dims(latents) - 1]);
g = *(float*)((char*)latents->data + latent_id + 1 * latents->nb[ggml_n_dims(latents) - 1]); g = *(float*)((char*)latents->data + latent_id + 1 * latents->nb[ggml_n_dims(latents) - 1]);
b = *(float*)((char*)latents->data + latent_id + 2 * latents->nb[ggml_n_dims(latents) - 1]); b = *(float*)((char*)latents->data + latent_id + 2 * latents->nb[ggml_n_dims(latents) - 1]);
} }
if (latent_rgb_bias != nullptr) { if (latent_rgb_bias != nullptr) {
// bias // bias
r += latent_rgb_bias[0]; r += latent_rgb_bias[0];
g += latent_rgb_bias[1]; g += latent_rgb_bias[1];
b += latent_rgb_bias[2]; b += latent_rgb_bias[2];
} }
// change range // change range
r = r * .5f + .5f; r = r * .5f + .5f;
g = g * .5f + .5f; g = g * .5f + .5f;
b = b * .5f + .5f; b = b * .5f + .5f;
// clamp rgb values to [0,1] range // clamp rgb values to [0,1] range
r = r >= 0 ? r <= 1 ? r : 1 : 0; r = r >= 0 ? r <= 1 ? r : 1 : 0;
g = g >= 0 ? g <= 1 ? g : 1 : 0; g = g >= 0 ? g <= 1 ? g : 1 : 0;
b = b >= 0 ? b <= 1 ? b : 1 : 0; b = b >= 0 ? b <= 1 ? b : 1 : 0;
buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255); buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255);
buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255); buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255);
buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255); buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255);
} }
} }
} }
} }

View File

@ -19,7 +19,6 @@
#include "json.hpp" #include "json.hpp"
#include "rope.hpp" #include "rope.hpp"
#include "tokenize_util.h" #include "tokenize_util.h"
#include "vocab/vocab.h"
namespace LLM { namespace LLM {
constexpr int LLM_GRAPH_SIZE = 10240; constexpr int LLM_GRAPH_SIZE = 10240;
@ -366,7 +365,7 @@ namespace LLM {
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);
} else { } else {
load_from_merges(load_qwen2_merges()); load_from_merges(ModelLoader::load_qwen2_merges());
} }
} }
}; };
@ -467,7 +466,7 @@ namespace LLM {
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str, vocab_utf8_str); load_from_merges(merges_utf8_str, vocab_utf8_str);
} else { } else {
load_from_merges(load_mistral_merges(), load_mistral_vocab_json()); load_from_merges(ModelLoader::load_mistral_merges(), ModelLoader::load_mistral_vocab_json());
} }
} }
}; };

View File

@ -16,6 +16,10 @@
#include "model.h" #include "model.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#include "util.h" #include "util.h"
#include "vocab.hpp"
#include "vocab_mistral.hpp"
#include "vocab_qwen.hpp"
#include "vocab_umt5.hpp"
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h" #include "ggml-backend.h"
@ -1336,6 +1340,36 @@ void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_ru
} }
} }
std::string ModelLoader::load_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_qwen2_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_mistral_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(mistral_merges_utf8_c_str), sizeof(mistral_merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_mistral_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str), sizeof(mistral_vocab_json_utf8_c_str));
return json_str;
}
std::string ModelLoader::load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str;
}
std::string ModelLoader::load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) { bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) {
int64_t process_time_ms = 0; int64_t process_time_ms = 0;
std::atomic<int64_t> read_time_ms(0); std::atomic<int64_t> read_time_ms(0);

View File

@ -331,6 +331,13 @@ public:
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default; ~ModelLoader() = default;
static std::string load_merges();
static std::string load_qwen2_merges();
static std::string load_mistral_merges();
static std::string load_mistral_vocab_json();
static std::string load_t5_tokenizer_json();
static std::string load_umt5_tokenizer_json();
}; };
#endif // __MODEL_H__ #endif // __MODEL_H__

View File

@ -1,35 +0,0 @@
#include "vocab.h"
#include "clip_t5.hpp"
#include "mistral.hpp"
#include "qwen.hpp"
#include "umt5.hpp"
std::string load_clip_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(clip_merges_utf8_c_str), sizeof(clip_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_qwen2_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_mistral_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(mistral_merges_utf8_c_str), sizeof(mistral_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_mistral_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str), sizeof(mistral_vocab_json_utf8_c_str));
return json_str;
}
std::string load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str;
}
std::string load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}

View File

@ -1,13 +0,0 @@
#ifndef __VOCAB_H__
#define __VOCAB_H__
#include <string>
// Accessors for tokenizer data embedded in the binary as byte arrays.
// Each function returns the raw embedded bytes (BPE merges lists or
// tokenizer/vocab JSON) wrapped in a std::string.
std::string load_clip_merges();
std::string load_qwen2_merges();
std::string load_mistral_merges();
std::string load_mistral_vocab_json();
std::string load_t5_tokenizer_json();
std::string load_umt5_tokenizer_json();
#endif // __VOCAB_H__

View File

@ -2679,7 +2679,7 @@ public:
}; };
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else { } else {
if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) { if(!first_stage_model->compute(n_threads, x, true, &result, work_ctx)){
LOG_ERROR("Failed to decode latetnts"); LOG_ERROR("Failed to decode latetnts");
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
return nullptr; return nullptr;
@ -2695,7 +2695,7 @@ public:
}; };
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling); sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
} else { } else {
if (!tae_first_stage->compute(n_threads, x, true, &result)) { if(!tae_first_stage->compute(n_threads, x, true, &result)){
LOG_ERROR("Failed to decode latetnts"); LOG_ERROR("Failed to decode latetnts");
tae_first_stage->free_compute_buffer(); tae_first_stage->free_compute_buffer();
return nullptr; return nullptr;

View File

@ -14,7 +14,6 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "json.hpp" #include "json.hpp"
#include "model.h" #include "model.h"
#include "vocab/vocab.h"
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h // Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. // and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h.
@ -342,9 +341,9 @@ protected:
public: public:
explicit T5UniGramTokenizer(bool is_umt5 = false) { explicit T5UniGramTokenizer(bool is_umt5 = false) {
if (is_umt5) { if (is_umt5) {
InitializePieces(load_umt5_tokenizer_json()); InitializePieces(ModelLoader::load_umt5_tokenizer_json());
} else { } else {
InitializePieces(load_t5_tokenizer_json()); InitializePieces(ModelLoader::load_t5_tokenizer_json());
} }
min_score_ = FLT_MAX; min_score_ = FLT_MAX;

View File

View File

@ -1,4 +1,4 @@
static const unsigned char clip_merges_utf8_c_str[] = { static unsigned char merges_utf8_c_str[] = {
0x23, 0x23,
0x76, 0x76,
0x65, 0x65,
@ -524620,7 +524620,7 @@ static const unsigned char clip_merges_utf8_c_str[] = {
0x0a, 0x0a,
}; };
static const unsigned char t5_tokenizer_json_str[] = { static unsigned char t5_tokenizer_json_str[] = {
0x7b, 0x7b,
0x0a, 0x0a,
0x20, 0x20,

View File

@ -1,4 +1,4 @@
static const unsigned char mistral_merges_utf8_c_str[] = { unsigned char mistral_merges_utf8_c_str[] = {
0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0x0a, 0x65, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0x0a, 0x65,
0x20, 0x72, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x20, 0x72, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,
0xc4, 0xa0, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,
@ -260614,7 +260614,7 @@ static const unsigned char mistral_merges_utf8_c_str[] = {
0xc3, 0xa5, 0xc4, 0xb2, 0xc4, 0xb0, 0x20, 0xc3, 0xa6, 0xc2, 0xb1, 0xc4, 0xc3, 0xa5, 0xc4, 0xb2, 0xc4, 0xb0, 0x20, 0xc3, 0xa6, 0xc2, 0xb1, 0xc4,
0xab, 0xc3, 0xa4, 0xc2, 0xb9, 0xc2, 0xa6, 0x0a, 0xab, 0xc3, 0xa4, 0xc2, 0xb9, 0xc2, 0xa6, 0x0a,
}; };
static const unsigned char mistral_vocab_json_utf8_c_str[] = { unsigned char mistral_vocab_json_utf8_c_str[] = {
0x7b, 0x22, 0x3c, 0x75, 0x6e, 0x6b, 0x3e, 0x22, 0x3a, 0x20, 0x30, 0x2c, 0x7b, 0x22, 0x3c, 0x75, 0x6e, 0x6b, 0x3e, 0x22, 0x3a, 0x20, 0x30, 0x2c,
0x20, 0x22, 0x3c, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x31, 0x2c, 0x20, 0x22, 0x20, 0x22, 0x3c, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x31, 0x2c, 0x20, 0x22,
0x3c, 0x2f, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x32, 0x2c, 0x20, 0x22, 0x5b, 0x3c, 0x2f, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x32, 0x2c, 0x20, 0x22, 0x5b,

View File

@ -1,4 +1,4 @@
static const unsigned char qwen2_merges_utf8_c_str[] = { unsigned char qwen2_merges_utf8_c_str[] = {
0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4,
0xa0, 0xc4, 0xa0, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0xa0, 0xc4, 0xa0, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0x74,
0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,

View File

@ -1,4 +1,4 @@
static const unsigned char umt5_tokenizer_json_str[] = { unsigned char umt5_tokenizer_json_str[] = {
0x7b, 0x22, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x7b, 0x22, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20,
0x22, 0x31, 0x2e, 0x30, 0x22, 0x2c, 0x20, 0x22, 0x74, 0x72, 0x75, 0x6e, 0x22, 0x31, 0x2e, 0x30, 0x22, 0x2c, 0x20, 0x22, 0x74, 0x72, 0x75, 0x6e,
0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x6e, 0x75, 0x6c, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x6e, 0x75, 0x6c,