Compare commits

...

4 Commits

56 changed files with 527 additions and 468 deletions

View File

@ -535,31 +535,30 @@ jobs:
# Add ROCm to PATH for current session # Add ROCm to PATH for current session
echo "/opt/rocm/bin" >> $GITHUB_PATH echo "/opt/rocm/bin" >> $GITHUB_PATH
# Build case pattern from GPU_TARGETS # Build regex pattern from ${{ env.GPU_TARGETS }} (match target as substring)
PATTERN=$(printf '%s' "$GPU_TARGETS" | sed 's/;/\*|\*/g') TARGET_REGEX="($(printf '%s' "${{ env.GPU_TARGETS }}" | sed 's/;/|/g'))"
PATTERN="*${PATTERN}*"
# Remove library files for architectures we're not building for to save disk space # Remove library files for architectures we're not building for to save disk space
echo "Cleaning up unneeded architecture files..." echo "Cleaning up unneeded architecture files..."
cd /opt/rocm/lib/rocblas/library cd /opt/rocm/lib/rocblas/library
# Keep only our target architectures # Keep only our target architectures
for file in *; do for file in *; do
case "$file" in if printf '%s' "$file" | grep -q 'gfx'; then
$PATTERN) if ! printf '%s' "$file" | grep -Eq "$TARGET_REGEX"; then
;; echo "Removing $file" &&
*) sudo rm -f "$file";
sudo rm -f "$file" ;; fi
esac; fi
done done
cd /opt/rocm/lib/hipblaslt/library cd /opt/rocm/lib/hipblaslt/library
for file in *; do for file in *; do
case "$file" in if printf '%s' "$file" | grep -q 'gfx'; then
$PATTERN) if ! printf '%s' "$file" | grep -Eq "$TARGET_REGEX"; then
;; echo "Removing $file" &&
*) sudo rm -f "$file";
sudo rm -f "$file" ;; fi
esac; fi
done done
- name: Build - name: Build
@ -592,21 +591,15 @@ jobs:
cp ggml/LICENSE ./build/bin/ggml.txt cp ggml/LICENSE ./build/bin/ggml.txt
cp LICENSE ./build/bin/stable-diffusion.cpp.txt cp LICENSE ./build/bin/stable-diffusion.cpp.txt
# Create directories for ROCm libraries # Move ROCm runtime libraries (to avoid double space consumption)
mkdir -p ./build/bin/rocblas/library sudo mv /opt/rocm/lib/librocsparse.so* ./build/bin/
mkdir -p ./build/bin/hipblaslt/library sudo mv /opt/rocm/lib/libhsa-runtime64.so* ./build/bin/
sudo mv /opt/rocm/lib/libamdhip64.so* ./build/bin/
# Copy ROCm runtime libraries (use || true to continue if files don't exist) sudo mv /opt/rocm/lib/libhipblas.so* ./build/bin/
cp /opt/rocm/lib/librocsparse.so* ./build/bin/ || true sudo mv /opt/rocm/lib/libhipblaslt.so* ./build/bin/
cp /opt/rocm/lib/libhsa-runtime64.so* ./build/bin/ || true sudo mv /opt/rocm/lib/librocblas.so* ./build/bin/
cp /opt/rocm/lib/libamdhip64.so* ./build/bin/ || true sudo mv /opt/rocm/lib/rocblas/ ./build/bin/
cp /opt/rocm/lib/libhipblas.so* ./build/bin/ || true sudo mv /opt/rocm/lib/hipblaslt/ ./build/bin/
cp /opt/rocm/lib/libhipblaslt.so* ./build/bin/ || true
cp /opt/rocm/lib/librocblas.so* ./build/bin/ || true
# Copy library files (already filtered to target architectures)
cp /opt/rocm/lib/rocblas/library/* ./build/bin/rocblas/library/ || true
cp /opt/rocm/lib/hipblaslt/library/* ./build/bin/hipblaslt/library/ || true
- name: Fetch system info - name: Fetch system info
id: system-info id: system-info
@ -622,7 +615,7 @@ jobs:
run: | run: |
cp ggml/LICENSE ./build/bin/ggml.txt cp ggml/LICENSE ./build/bin/ggml.txt
cp LICENSE ./build/bin/stable-diffusion.cpp.txt cp LICENSE ./build/bin/stable-diffusion.cpp.txt
zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip ./build/bin/* zip -y -r sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip ./build/bin
- name: Upload artifacts - name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}

View File

@ -87,9 +87,11 @@ endif()
set(SD_LIB stable-diffusion) set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES file(GLOB SD_LIB_SOURCES
"*.h" "src/*.h"
"*.cpp" "src/*.cpp"
"*.hpp" "src/*.hpp"
"src/vocab/*.h"
"src/vocab/*.cpp"
) )
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@ -119,7 +121,7 @@ endif()
message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}") message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}")
set_property( set_property(
SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/src/version.cpp
APPEND PROPERTY COMPILE_DEFINITIONS APPEND PROPERTY COMPILE_DEFINITIONS
SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION} SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION}
) )
@ -182,6 +184,7 @@ endif()
add_subdirectory(thirdparty) add_subdirectory(thirdparty)
target_link_libraries(${SD_LIB} PUBLIC ggml zip) target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . include)
target_include_directories(${SD_LIB} PUBLIC . thirdparty) target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)
@ -190,7 +193,7 @@ if (SD_BUILD_EXAMPLES)
add_subdirectory(examples) add_subdirectory(examples)
endif() endif()
set(SD_PUBLIC_HEADERS stable-diffusion.h) set(SD_PUBLIC_HEADERS include/stable-diffusion.h)
set_target_properties(${SD_LIB} PROPERTIES PUBLIC_HEADER "${SD_PUBLIC_HEADERS}") set_target_properties(${SD_LIB} PROPERTIES PUBLIC_HEADER "${SD_PUBLIC_HEADERS}")
install(TARGETS ${SD_LIB} LIBRARY PUBLIC_HEADER) install(TARGETS ${SD_LIB} LIBRARY PUBLIC_HEADER)

View File

@ -404,8 +404,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", ""); std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png"); std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100); int output_compression = j.value("output_compression", 100);
int width = 512; int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
int height = 512; int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -593,7 +593,7 @@ int main(int argc, const char** argv) {
n = std::clamp(n, 1, 8); n = std::clamp(n, 1, 8);
std::string size = req.form.get_field("size"); std::string size = req.form.get_field("size");
int width = 512, height = 512; int width = -1, height = -1;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -650,15 +650,31 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size()); ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) { for (auto& bytes : images_bytes) {
int img_w = width; int img_w;
int img_h = height; int img_h;
uint8_t* raw_pixels = load_image_from_memory( uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(bytes.data()), reinterpret_cast<const char*>(bytes.data()),
static_cast<int>(bytes.size()), static_cast<int>(bytes.size()),
@ -670,22 +686,31 @@ int main(int argc, const char** argv) {
} }
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}; sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
gen_params.set_width_and_height_if_unset(img.width, img.height);
ref_images.push_back(img); ref_images.push_back(img);
} }
sd_image_t mask_image = {0}; sd_image_t mask_image = {0};
if (!mask_bytes.empty()) { if (!mask_bytes.empty()) {
int mask_w = width; int expected_width = 0;
int mask_h = height; int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int mask_w;
int mask_h;
uint8_t* mask_raw = load_image_from_memory( uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()), reinterpret_cast<const char*>(mask_bytes.data()),
static_cast<int>(mask_bytes.size()), static_cast<int>(mask_bytes.size()),
mask_w, mask_h, mask_w, mask_h,
width, height, 1); expected_width, expected_height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}; mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
} else { } else {
mask_image.width = width; mask_image.width = get_resolved_width();
mask_image.height = height; mask_image.height = get_resolved_height();
mask_image.channel = 1; mask_image.channel = 1;
mask_image.data = nullptr; mask_image.data = nullptr;
} }
@ -702,8 +727,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
gen_params.width, get_resolved_width(),
gen_params.height, get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
@ -886,8 +911,6 @@ int main(int argc, const char** argv) {
SDGenerationParams gen_params = default_gen_params; SDGenerationParams gen_params = default_gen_params;
gen_params.prompt = prompt; gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt; gen_params.negative_prompt = negative_prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.seed = seed; gen_params.seed = seed;
gen_params.sample_params.sample_steps = steps; gen_params.sample_params.sample_steps = steps;
gen_params.batch_count = batch_size; gen_params.batch_count = batch_size;
@ -905,38 +928,66 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler; gen_params.sample_params.scheduler = scheduler;
} }
// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1);
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<uint8_t> mask_data; std::vector<uint8_t> mask_data;
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
if (img2img) { auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool { if (gen_params.width > 0)
// remove data URI prefix if present ("data:image/png;base64,") return gen_params.width;
auto comma_pos = encoded.find(','); if (default_gen_params.width > 0)
if (comma_pos != std::string::npos) { return default_gen_params.width;
encoded = encoded.substr(comma_pos + 1); return 512;
} };
std::vector<uint8_t> img_data = base64_decode(encoded); auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (!img_data.empty()) { if (gen_params.height > 0)
int img_w = image.width; return gen_params.height;
int img_h = image.height; if (default_gen_params.height > 0)
uint8_t* raw_data = load_image_from_memory( return default_gen_params.height;
(const char*)img_data.data(), (int)img_data.size(), return 512;
img_w, img_h, };
image.width, image.height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
return true;
}
}
return false;
};
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int expected_width = 0;
int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int img_w;
int img_h;
uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
expected_width, expected_height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
}
}
return false;
};
if (img2img) {
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
std::string encoded = j["init_images"][0].get<std::string>(); std::string encoded = j["init_images"][0].get<std::string>();
decode_image(init_image, encoded); decode_image(init_image, encoded);
@ -952,23 +1003,15 @@ int main(int argc, const char** argv) {
} }
} }
} else { } else {
mask_data = std::vector<uint8_t>(width * height, 255); int m_width = get_resolved_width();
mask_image.width = width; int m_height = get_resolved_height();
mask_image.height = height; mask_data = std::vector<uint8_t>(m_width * m_height, 255);
mask_image.width = m_width;
mask_image.height = m_height;
mask_image.channel = 1; mask_image.channel = 1;
mask_image.data = mask_data.data(); mask_image.data = mask_data.data();
} }
if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}
float denoising_strength = j.value("denoising_strength", -1.f); float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) { if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f); denoising_strength = std::min(denoising_strength, 1.0f);
@ -976,6 +1019,16 @@ int main(int argc, const char** argv) {
} }
} }
if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
sd_loras.data(), sd_loras.data(),
static_cast<uint32_t>(sd_loras.size()), static_cast<uint32_t>(sd_loras.size()),
@ -988,8 +1041,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
gen_params.width, get_resolved_width(),
gen_params.height, get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,

View File

@ -1,4 +1,4 @@
for f in *.cpp *.h *.hpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do
[[ "$f" == vocab* ]] && continue [[ "$f" == vocab* ]] && continue
echo "formatting '$f'" echo "formatting '$f'"
# if [ "$f" != "stable-diffusion.h" ]; then # if [ "$f" != "stable-diffusion.h" ]; then

View File

@ -4,6 +4,7 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "model.h" #include "model.h"
#include "tokenize_util.h" #include "tokenize_util.h"
#include "vocab/vocab.h"
/*================================================== CLIPTokenizer ===================================================*/ /*================================================== CLIPTokenizer ===================================================*/
@ -110,7 +111,7 @@ public:
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);
} else { } else {
load_from_merges(ModelLoader::load_merges()); load_from_merges(load_clip_merges());
} }
add_special_token("<|startoftext|>"); add_special_token("<|startoftext|>");
add_special_token("<|endoftext|>"); add_special_token("<|endoftext|>");

View File

@ -19,6 +19,7 @@
#include "json.hpp" #include "json.hpp"
#include "rope.hpp" #include "rope.hpp"
#include "tokenize_util.h" #include "tokenize_util.h"
#include "vocab/vocab.h"
namespace LLM { namespace LLM {
constexpr int LLM_GRAPH_SIZE = 10240; constexpr int LLM_GRAPH_SIZE = 10240;
@ -365,7 +366,7 @@ namespace LLM {
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);
} else { } else {
load_from_merges(ModelLoader::load_qwen2_merges()); load_from_merges(load_qwen2_merges());
} }
} }
}; };
@ -466,7 +467,7 @@ namespace LLM {
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str, vocab_utf8_str); load_from_merges(merges_utf8_str, vocab_utf8_str);
} else { } else {
load_from_merges(ModelLoader::load_mistral_merges(), ModelLoader::load_mistral_vocab_json()); load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
} }
} }
}; };

View File

@ -16,10 +16,6 @@
#include "model.h" #include "model.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#include "util.h" #include "util.h"
#include "vocab.hpp"
#include "vocab_mistral.hpp"
#include "vocab_qwen.hpp"
#include "vocab_umt5.hpp"
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h" #include "ggml-backend.h"
@ -1340,36 +1336,6 @@ void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_ru
} }
} }
std::string ModelLoader::load_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_qwen2_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_mistral_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(mistral_merges_utf8_c_str), sizeof(mistral_merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_mistral_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str), sizeof(mistral_vocab_json_utf8_c_str));
return json_str;
}
std::string ModelLoader::load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str;
}
std::string ModelLoader::load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) { bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) {
int64_t process_time_ms = 0; int64_t process_time_ms = 0;
std::atomic<int64_t> read_time_ms(0); std::atomic<int64_t> read_time_ms(0);

View File

@ -331,13 +331,6 @@ public:
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default; ~ModelLoader() = default;
static std::string load_merges();
static std::string load_qwen2_merges();
static std::string load_mistral_merges();
static std::string load_mistral_vocab_json();
static std::string load_t5_tokenizer_json();
static std::string load_umt5_tokenizer_json();
}; };
#endif // __MODEL_H__ #endif // __MODEL_H__

View File

@ -2679,7 +2679,7 @@ public:
}; };
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else { } else {
if(!first_stage_model->compute(n_threads, x, true, &result, work_ctx)){ if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) {
LOG_ERROR("Failed to decode latetnts"); LOG_ERROR("Failed to decode latetnts");
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
return nullptr; return nullptr;
@ -2695,7 +2695,7 @@ public:
}; };
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling); sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
} else { } else {
if(!tae_first_stage->compute(n_threads, x, true, &result)){ if (!tae_first_stage->compute(n_threads, x, true, &result)) {
LOG_ERROR("Failed to decode latetnts"); LOG_ERROR("Failed to decode latetnts");
tae_first_stage->free_compute_buffer(); tae_first_stage->free_compute_buffer();
return nullptr; return nullptr;

View File

@ -14,6 +14,7 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "json.hpp" #include "json.hpp"
#include "model.h" #include "model.h"
#include "vocab/vocab.h"
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h // Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. // and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h.
@ -341,9 +342,9 @@ protected:
public: public:
explicit T5UniGramTokenizer(bool is_umt5 = false) { explicit T5UniGramTokenizer(bool is_umt5 = false) {
if (is_umt5) { if (is_umt5) {
InitializePieces(ModelLoader::load_umt5_tokenizer_json()); InitializePieces(load_umt5_tokenizer_json());
} else { } else {
InitializePieces(ModelLoader::load_t5_tokenizer_json()); InitializePieces(load_t5_tokenizer_json());
} }
min_score_ = FLT_MAX; min_score_ = FLT_MAX;

View File

View File

@ -1,4 +1,4 @@
static unsigned char merges_utf8_c_str[] = { static const unsigned char clip_merges_utf8_c_str[] = {
0x23, 0x23,
0x76, 0x76,
0x65, 0x65,
@ -524620,7 +524620,7 @@ static unsigned char merges_utf8_c_str[] = {
0x0a, 0x0a,
}; };
static unsigned char t5_tokenizer_json_str[] = { static const unsigned char t5_tokenizer_json_str[] = {
0x7b, 0x7b,
0x0a, 0x0a,
0x20, 0x20,

View File

@ -1,4 +1,4 @@
unsigned char mistral_merges_utf8_c_str[] = { static const unsigned char mistral_merges_utf8_c_str[] = {
0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0x0a, 0x65, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0x0a, 0x65,
0x20, 0x72, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x20, 0x72, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,
0xc4, 0xa0, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,
@ -260614,7 +260614,7 @@ unsigned char mistral_merges_utf8_c_str[] = {
0xc3, 0xa5, 0xc4, 0xb2, 0xc4, 0xb0, 0x20, 0xc3, 0xa6, 0xc2, 0xb1, 0xc4, 0xc3, 0xa5, 0xc4, 0xb2, 0xc4, 0xb0, 0x20, 0xc3, 0xa6, 0xc2, 0xb1, 0xc4,
0xab, 0xc3, 0xa4, 0xc2, 0xb9, 0xc2, 0xa6, 0x0a, 0xab, 0xc3, 0xa4, 0xc2, 0xb9, 0xc2, 0xa6, 0x0a,
}; };
unsigned char mistral_vocab_json_utf8_c_str[] = { static const unsigned char mistral_vocab_json_utf8_c_str[] = {
0x7b, 0x22, 0x3c, 0x75, 0x6e, 0x6b, 0x3e, 0x22, 0x3a, 0x20, 0x30, 0x2c, 0x7b, 0x22, 0x3c, 0x75, 0x6e, 0x6b, 0x3e, 0x22, 0x3a, 0x20, 0x30, 0x2c,
0x20, 0x22, 0x3c, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x31, 0x2c, 0x20, 0x22, 0x20, 0x22, 0x3c, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x31, 0x2c, 0x20, 0x22,
0x3c, 0x2f, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x32, 0x2c, 0x20, 0x22, 0x5b, 0x3c, 0x2f, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x32, 0x2c, 0x20, 0x22, 0x5b,

View File

@ -1,4 +1,4 @@
unsigned char qwen2_merges_utf8_c_str[] = { static const unsigned char qwen2_merges_utf8_c_str[] = {
0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4,
0xa0, 0xc4, 0xa0, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0xa0, 0xc4, 0xa0, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0x74,
0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,

View File

@ -1,4 +1,4 @@
unsigned char umt5_tokenizer_json_str[] = { static const unsigned char umt5_tokenizer_json_str[] = {
0x7b, 0x22, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x7b, 0x22, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20,
0x22, 0x31, 0x2e, 0x30, 0x22, 0x2c, 0x20, 0x22, 0x74, 0x72, 0x75, 0x6e, 0x22, 0x31, 0x2e, 0x30, 0x22, 0x2c, 0x20, 0x22, 0x74, 0x72, 0x75, 0x6e,
0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x6e, 0x75, 0x6c, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x6e, 0x75, 0x6c,

35
src/vocab/vocab.cpp Normal file
View File

@ -0,0 +1,35 @@
#include "vocab.h"
#include "clip_t5.hpp"
#include "mistral.hpp"
#include "qwen.hpp"
#include "umt5.hpp"
std::string load_clip_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(clip_merges_utf8_c_str), sizeof(clip_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_qwen2_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_mistral_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(mistral_merges_utf8_c_str), sizeof(mistral_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_mistral_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str), sizeof(mistral_vocab_json_utf8_c_str));
return json_str;
}
std::string load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str;
}
std::string load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}

13
src/vocab/vocab.h Normal file
View File

@ -0,0 +1,13 @@
#ifndef __VOCAB_H__
#define __VOCAB_H__
#include <string>
std::string load_clip_merges();
std::string load_qwen2_merges();
std::string load_mistral_merges();
std::string load_mistral_vocab_json();
std::string load_t5_tokenizer_json();
std::string load_umt5_tokenizer_json();
#endif // __VOCAB_H__