feat: add detailed tensor loading time stat (#793)

This commit is contained in:
leejet 2025-09-07 22:51:44 +08:00 committed by GitHub
parent c587a43c99
commit c648001030
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 6 deletions

View File

@ -1966,6 +1966,16 @@ std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& v
}
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
int64_t process_time_ms = 0;
int64_t read_time_ms = 0;
int64_t memcpy_time_ms = 0;
int64_t copy_to_backend_time_ms = 0;
int64_t convert_time_ms = 0;
int64_t prev_time_ms = 0;
int64_t curr_time_ms = 0;
int64_t start_time = ggml_time_ms();
prev_time_ms = start_time;
std::vector<TensorStorage> processed_tensor_storages;
for (auto& tensor_storage : tensor_storages) {
// LOG_DEBUG("%s", name.c_str());
@ -1978,6 +1988,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
}
std::vector<TensorStorage> dedup = remove_duplicates(processed_tensor_storages);
processed_tensor_storages = dedup;
curr_time_ms = ggml_time_ms();
process_time_ms = curr_time_ms - prev_time_ms;
prev_time_ms = curr_time_ms;
bool success = true;
for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
@ -2019,15 +2032,27 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
size_t entry_size = zip_entry_size(zip);
if (entry_size != n) {
read_buffer.resize(entry_size);
prev_time_ms = ggml_time_ms();
zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size);
curr_time_ms = ggml_time_ms();
read_time_ms += curr_time_ms - prev_time_ms;
prev_time_ms = curr_time_ms;
memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n);
curr_time_ms = ggml_time_ms();
memcpy_time_ms += curr_time_ms - prev_time_ms;
} else {
prev_time_ms = ggml_time_ms();
zip_entry_noallocread(zip, (void*)buf, n);
curr_time_ms = ggml_time_ms();
read_time_ms += curr_time_ms - prev_time_ms;
}
zip_entry_close(zip);
} else {
prev_time_ms = ggml_time_ms();
file.seekg(tensor_storage.offset);
file.read(buf, n);
curr_time_ms = ggml_time_ms();
read_time_ms += curr_time_ms - prev_time_ms;
if (!file) {
LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
return false;
@ -2072,6 +2097,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
}
prev_time_ms = ggml_time_ms();
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
@ -2086,10 +2112,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
} else if (tensor_storage.is_i64) {
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements());
}
curr_time_ms = ggml_time_ms();
convert_time_ms += curr_time_ms - prev_time_ms;
} else {
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
prev_time_ms = ggml_time_ms();
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
@ -2109,11 +2138,14 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
curr_time_ms = ggml_time_ms();
convert_time_ms += curr_time_ms - prev_time_ms;
}
} else {
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
prev_time_ms = ggml_time_ms();
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
@ -2133,14 +2165,24 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
if (tensor_storage.type == dst_tensor->type) {
// copy to device memory
curr_time_ms = ggml_time_ms();
convert_time_ms += curr_time_ms - prev_time_ms;
prev_time_ms = curr_time_ms;
ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
curr_time_ms = ggml_time_ms();
copy_to_backend_time_ms += curr_time_ms - prev_time_ms;
} else {
// convert first, then copy to device memory
convert_buffer.resize(ggml_nbytes(dst_tensor));
convert_tensor((void*)read_buffer.data(), tensor_storage.type,
(void*)convert_buffer.data(), dst_tensor->type,
(int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
curr_time_ms = ggml_time_ms();
convert_time_ms += curr_time_ms - prev_time_ms;
prev_time_ms = curr_time_ms;
ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
curr_time_ms = ggml_time_ms();
copy_to_backend_time_ms += curr_time_ms - prev_time_ms;
}
}
++tensor_count;
@ -2170,6 +2212,14 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
break;
}
}
int64_t end_time = ggml_time_ms();
LOG_INFO("loading tensors completed, taking %.2fs (process: %.2fs, read: %.2fs, memcpy: %.2fs, convert: %.2fs, copy_to_backend: %.2fs)",
(end_time - start_time) / 1000.f,
process_time_ms / 1000.f,
read_time_ms / 1000.f,
memcpy_time_ms / 1000.f,
convert_time_ms / 1000.f,
copy_to_backend_time_ms / 1000.f);
return success;
}

View File

@ -557,8 +557,6 @@ public:
// load weights
LOG_DEBUG("loading weights");
int64_t t0 = ggml_time_ms();
std::set<std::string> ignore_tensors;
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
if (use_tiny_autoencoder) {
@ -656,11 +654,7 @@ public:
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
}
int64_t t1 = ggml_time_ms();
LOG_INFO("loading model from '%s' completed, taking %.2fs", SAFE_STR(sd_ctx_params->model_path), (t1 - t0) * 1.0f / 1000);
// check is_using_v_parameterization_for_sd2
if (sd_version_is_sd2(version)) {
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true;