mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add break pseudo token support (#422)
--------- Co-authored-by: Urs Ganse <urs.ganse@helsinki.fi>
This commit is contained in:
parent
347710f68f
commit
6448430dbb
@ -278,13 +278,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
const std::string& curr_text = item.first;
|
const std::string& curr_text = item.first;
|
||||||
float curr_weight = item.second;
|
float curr_weight = item.second;
|
||||||
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
|
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
|
||||||
|
int32_t clean_index = 0;
|
||||||
|
if (curr_text == "BREAK" && curr_weight == -1.0f) {
|
||||||
|
// Pad token array up to chunk size at this point.
|
||||||
|
// TODO: chunk_len is hardcoded here, as it is in stable-diffusion.cpp; consider making it a parameter in the future.
|
||||||
|
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
|
||||||
|
int padding_size = 75 - (tokens_acc % 75);
|
||||||
|
for (int j = 0; j < padding_size; j++) {
|
||||||
|
clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID);
|
||||||
|
clean_index++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// After padding, continue to the next iteration to process the following text as a new segment
|
||||||
|
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
|
||||||
|
weights.insert(weights.end(), padding_size, curr_weight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular token, process normally
|
||||||
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
|
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
|
||||||
int32_t clean_index = 0;
|
|
||||||
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
|
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
|
||||||
int token_id = curr_tokens[i];
|
int token_id = curr_tokens[i];
|
||||||
if (token_id == image_token)
|
if (token_id == image_token) {
|
||||||
class_token_index.push_back(clean_index - 1);
|
class_token_index.push_back(clean_index - 1);
|
||||||
else {
|
} else {
|
||||||
clean_input_ids.push_back(token_id);
|
clean_input_ids.push_back(token_id);
|
||||||
clean_index++;
|
clean_index++;
|
||||||
}
|
}
|
||||||
@ -387,6 +404,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
for (const auto& item : parsed_attention) {
|
for (const auto& item : parsed_attention) {
|
||||||
const std::string& curr_text = item.first;
|
const std::string& curr_text = item.first;
|
||||||
float curr_weight = item.second;
|
float curr_weight = item.second;
|
||||||
|
|
||||||
|
if (curr_text == "BREAK" && curr_weight == -1.0f) {
|
||||||
|
// Pad token array up to chunk size at this point.
|
||||||
|
// TODO: chunk_len is hardcoded here, as it is in stable-diffusion.cpp; consider making it a parameter in the future.
|
||||||
|
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
|
||||||
|
size_t current_size = tokens.size();
|
||||||
|
size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding
|
||||||
|
|
||||||
|
if (padding_size > 0) {
|
||||||
|
LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
|
||||||
|
tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID);
|
||||||
|
weights.insert(weights.end(), padding_size, 1.0f);
|
||||||
|
}
|
||||||
|
continue; // Skip to the next item after handling BREAK
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
|
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
|
||||||
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
|
|||||||
11
util.cpp
11
util.cpp
@ -5,6 +5,7 @@
|
|||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <locale>
|
#include <locale>
|
||||||
|
#include <regex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
@ -547,6 +548,8 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
|
|||||||
// (abc) - increases attention to abc by a multiplier of 1.1
|
// (abc) - increases attention to abc by a multiplier of 1.1
|
||||||
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
|
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||||
// [abc] - decreases attention to abc by a multiplier of 1.1
|
// [abc] - decreases attention to abc by a multiplier of 1.1
|
||||||
|
// BREAK - separates the prompt into conceptually distinct parts for sequential processing
|
||||||
|
// B - internal helper pattern; prevents 'B' in 'BREAK' from being consumed as normal text
|
||||||
// \( - literal character '('
|
// \( - literal character '('
|
||||||
// \[ - literal character '['
|
// \[ - literal character '['
|
||||||
// \) - literal character ')'
|
// \) - literal character ')'
|
||||||
@ -582,7 +585,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
|
|||||||
float round_bracket_multiplier = 1.1f;
|
float round_bracket_multiplier = 1.1f;
|
||||||
float square_bracket_multiplier = 1 / 1.1f;
|
float square_bracket_multiplier = 1 / 1.1f;
|
||||||
|
|
||||||
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
|
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:B]+|:|\bB)");
|
||||||
std::regex re_break(R"(\s*\bBREAK\b\s*)");
|
std::regex re_break(R"(\s*\bBREAK\b\s*)");
|
||||||
|
|
||||||
auto multiply_range = [&](int start_position, float multiplier) {
|
auto multiply_range = [&](int start_position, float multiplier) {
|
||||||
@ -591,7 +594,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::smatch m;
|
std::smatch m, m2;
|
||||||
std::string remaining_text = text;
|
std::string remaining_text = text;
|
||||||
|
|
||||||
while (std::regex_search(remaining_text, m, re_attention)) {
|
while (std::regex_search(remaining_text, m, re_attention)) {
|
||||||
@ -615,6 +618,8 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
|
|||||||
square_brackets.pop_back();
|
square_brackets.pop_back();
|
||||||
} else if (text == "\\(") {
|
} else if (text == "\\(") {
|
||||||
res.push_back({text.substr(1), 1.0f});
|
res.push_back({text.substr(1), 1.0f});
|
||||||
|
} else if (std::regex_search(text, m2, re_break)) {
|
||||||
|
res.push_back({"BREAK", -1.0f});
|
||||||
} else {
|
} else {
|
||||||
res.push_back({text, 1.0f});
|
res.push_back({text, 1.0f});
|
||||||
}
|
}
|
||||||
@ -645,4 +650,4 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
|
|||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user