mirror of https://github.com/leejet/stable-diffusion.cpp.git
add inplace conversion support for f8_e4m3 (#359)
In the same way it is done for bf16: just as bf16 converts losslessly to fp32, f8_e4m3 converts losslessly to fp16.
This commit is contained in:
parent 29ec31644a
commit d8c65b4436
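Note (explanation added for context, not text from the commit): the losslessness claim follows from the bit widths. f8_e4m3 has 1 sign bit, 4 exponent bits (bias 7) and 3 mantissa bits; fp16 has a 5-bit exponent (bias 15) and a 10-bit mantissa, so every finite f8_e4m3 value is exactly representable in fp16, just as bf16 (8-bit exponent, 7-bit mantissa) fits exactly inside fp32 (8-bit exponent, 23-bit mantissa). Below is a minimal standalone sketch that decodes a few f8_e4m3 bytes by hand; the decoder is written only to illustrate the format and is not the f8_e4m3_to_f16() added by this commit.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative decoder: normals are (-1)^s * 2^(e-7) * (1 + m/8),
// subnormals (e == 0) are (-1)^s * 2^(-6) * (m/8); 0x7f and 0xff are NaN.
static float f8_e4m3_decode(uint8_t f8) {
    int sign     = (f8 & 0x80) ? -1 : 1;
    int exponent = (f8 & 0x78) >> 3;  // 4 exponent bits, bias 7
    int mantissa = f8 & 0x07;         // 3 mantissa bits
    if ((f8 & 0x7f) == 0x7f) {
        return NAN;  // E4M3 has NaN encodings but no infinities
    }
    if (exponent == 0) {
        return sign * std::ldexp((float)mantissa, -9);         // m/8 * 2^-6
    }
    return sign * std::ldexp(8.0f + mantissa, exponent - 10);  // (1 + m/8) * 2^(e-7)
}

int main() {
    const uint8_t samples[] = {0x38, 0x40, 0xC4, 0x7E, 0x01};
    for (uint8_t b : samples) {
        printf("0x%02X -> %g\n", (unsigned)b, f8_e4m3_decode(b));  // 1, 2, -3, 448, 2^-9
    }
    return 0;
}

Giving up infinities is what lets E4M3 reach a finite range of about ±448; 0x7f and 0xff are the only NaN patterns, which is why the committed conversion special-cases exactly those two bytes.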
model.cpp (66 changed lines)
@@ -554,6 +554,48 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
+uint16_t f8_e4m3_to_f16(uint8_t f8) {
+    // do we need to support uz?
+
+    const uint32_t exponent_bias = 7;
+    if (f8 == 0xff) {
+        return ggml_fp32_to_fp16(-NAN);
+    } else if (f8 == 0x7f) {
+        return ggml_fp32_to_fp16(NAN);
+    }
+
+    uint32_t sign = f8 & 0x80;
+    uint32_t exponent = (f8 & 0x78) >> 3;
+    uint32_t mantissa = f8 & 0x07;
+    uint32_t result = sign << 24;
+    if (exponent == 0) {
+        if (mantissa > 0) {
+            exponent = 0x7f - exponent_bias;
+
+            // yes, 2 times
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+
+            result |= (mantissa & 0x03) << 21;
+            result |= exponent << 23;
+        }
+    } else {
+        result |= mantissa << 20;
+        exponent += 0x7f - exponent_bias;
+        result |= exponent << 23;
+    }
+
+    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+}
+
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
@@ -561,6 +603,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     }
 }
 
+void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e4m3_to_f16(src[i]);
+    }
+}
+
 void convert_tensor(void* src,
                     ggml_type src_type,
                     void* dst,
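A note on the in-place pattern used by f8_e4m3_to_f16_vec (an explanatory sketch, not part of the commit): the buffer is sized for the 2-byte fp16 output while the packed 1-byte f8 input occupies only its front half, so iterating backwards guarantees that each 2-byte store never overwrites a source byte that has not yet been read. The following self-contained sketch shows the same widen-in-place idea, using a trivial byte-to-uint16 mapping as a stand-in for f8_e4m3_to_f16().

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for f8_e4m3_to_f16(): any 1-byte -> 2-byte mapping has the same
// aliasing behaviour, so plain widening is enough to show the pattern.
static uint16_t widen(uint8_t b) {
    return (uint16_t)b;
}

int main() {
    const int64_t n = 8;
    std::vector<uint8_t> buffer(n * 2);  // sized for the 2-byte-per-element output
    for (int64_t i = 0; i < n; i++) {
        buffer[i] = (uint8_t)(i + 1);    // packed 1-byte input at the front
    }

    uint8_t* src  = buffer.data();
    uint16_t* dst = (uint16_t*)buffer.data();  // same storage, wider view
    // Walking backwards, the store to dst[i] lands on bytes 2i and 2i+1, i.e. on
    // source positions >= i that were already consumed, so no unread input is lost.
    for (int64_t i = n - 1; i >= 0; i--) {
        dst[i] = widen(src[i]);
    }

    for (int64_t i = 0; i < n; i++) {
        printf("%u ", (unsigned)dst[i]);  // prints 1 2 3 4 5 6 7 8
    }
    printf("\n");
    return 0;
}

The same argument is why the existing bf16_to_f32_vec can widen 2-byte inputs into 4-byte outputs in the same buffer.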
@@ -794,6 +843,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E4M3") {
+        ttype = GGML_TYPE_F16;
     }
     return ttype;
 }
@@ -866,7 +917,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
 
         ggml_type type = str_to_ggml_type(dtype);
         if (type == GGML_TYPE_COUNT) {
-            LOG_ERROR("unsupported dtype '%s'", dtype.c_str());
+            LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
             return false;
         }
 
@@ -903,6 +954,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E4M3") {
+            tensor_storage.is_f8_e4m3 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else {
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }
@@ -1537,6 +1592,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
                 read_buffer.resize(tensor_storage.nbytes());
@@ -1545,6 +1603,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1557,6 +1618,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 if (tensor_storage.type == dst_tensor->type) {
model.h (7 changed lines)
@@ -32,6 +32,7 @@ struct TensorStorage {
     std::string name;
     ggml_type type = GGML_TYPE_F32;
     bool is_bf16 = false;
+    bool is_f8_e4m3 = false;
     int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
     int n_dims = 0;
 
@@ -61,7 +62,7 @@ struct TensorStorage {
     }
 
     int64_t nbytes_to_read() const {
-        if (is_bf16) {
+        if (is_bf16 || is_f8_e4m3) {
             return nbytes() / 2;
         } else {
             return nbytes();
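For context (an assumption-labelled note, not text from the commit): since str_to_ggml_type() maps F8_E4M3 to GGML_TYPE_F16, nbytes() counts 2 bytes per element, while the safetensors payload stores f8_e4m3 as 1 byte per element; reading nbytes() / 2 bytes therefore consumes exactly the stored data and leaves room for the in-place widening, mirroring the existing bf16 case (2 bytes on disk widened into a 4-byte-per-element F32 tensor). A tiny arithmetic sketch of that accounting, with a hypothetical element count:

#include <cstdint>
#include <cstdio>

int main() {
    int64_t nelements  = 1024;           // hypothetical tensor size
    int64_t f16_nbytes = nelements * 2;  // nbytes() of the GGML_TYPE_F16 tensor
    int64_t f8_on_disk = nelements * 1;  // f8_e4m3 payload in the safetensors file
    printf("read %lld of %lld allocated bytes, then widen in place\n",
           (long long)f8_on_disk, (long long)f16_nbytes);
    return (f8_on_disk == f16_nbytes / 2) ? 0 : 1;  // matches nbytes_to_read()
}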
@@ -109,6 +110,8 @@ struct TensorStorage {
         const char* type_name = ggml_type_name(type);
         if (is_bf16) {
             type_name = "bf16";
+        } else if (is_f8_e4m3) {
+            type_name = "f8_e4m3";
         }
         ss << name << " | " << type_name << " | ";
         ss << n_dims << " [";
@@ -160,4 +163,6 @@ public:
     static std::string load_merges();
     static std::string load_t5_tokenizer_json();
 };
+
 #endif // __MODEL_H__
+