feat: show tensor loading progress in MB/s or GB/s (#1380)

This commit is contained in:
leejet 2026-03-31 23:06:44 +08:00 committed by GitHub
parent 4fe7a35939
commit bf0216765a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 63 additions and 15 deletions

View File

@ -1311,6 +1311,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
std::atomic<int64_t> memcpy_time_ms(0);
std::atomic<int64_t> copy_to_backend_time_ms(0);
std::atomic<int64_t> convert_time_ms(0);
std::atomic<uint64_t> bytes_processed(0);
int num_threads_to_use = n_threads_p > 0 ? n_threads_p : sd_get_num_physical_cores();
LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
@ -1522,6 +1523,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
t1 = ggml_time_ms();
copy_to_backend_time_ms.fetch_add(t1 - t0);
}
bytes_processed.fetch_add((uint64_t)nbytes_to_read);
}
if (zip != nullptr) {
zip_close(zip);
@ -1535,7 +1538,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
break;
}
size_t curr_num = total_tensors_processed + current_idx;
pretty_progress(static_cast<int>(curr_num), static_cast<int>(total_tensors_to_process), (ggml_time_ms() - t_start) / 1000.0f / (curr_num + 1e-6f));
float elapsed_seconds = (ggml_time_ms() - t_start) / 1000.0f;
pretty_bytes_progress(static_cast<int>(curr_num),
static_cast<int>(total_tensors_to_process),
bytes_processed.load(),
elapsed_seconds);
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
@ -1548,7 +1555,10 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
break;
}
total_tensors_processed += file_tensors.size();
pretty_progress(static_cast<int>(total_tensors_processed), static_cast<int>(total_tensors_to_process), (ggml_time_ms() - t_start) / 1000.0f / (total_tensors_processed + 1e-6f));
pretty_bytes_progress(static_cast<int>(total_tensors_processed),
static_cast<int>(total_tensors_to_process),
bytes_processed.load(),
(ggml_time_ms() - t_start) / 1000.0f);
if (total_tensors_processed < total_tensors_to_process) {
printf("\n");
}

View File

@ -337,17 +337,13 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
return result;
}
void pretty_progress(int step, int steps, float time) {
if (sd_progress_cb) {
sd_progress_cb(step, steps, time, sd_progress_cb_data);
return;
}
if (step == 0) {
return;
}
static std::string build_progress_bar(int step, int steps) {
std::string progress = " |";
int max_progress = 50;
int32_t current = (int32_t)(step * 1.f * max_progress / steps);
int32_t current = 0;
if (steps > 0) {
current = (int32_t)(step * 1.f * max_progress / steps);
}
for (int i = 0; i < 50; i++) {
if (i > current) {
progress += " ";
@ -358,16 +354,57 @@ void pretty_progress(int step, int steps, float time) {
}
}
progress += "|";
return progress;
}
// Draws a single, in-place progress line on stdout:
//   <bar> step/steps - <speed_text>
// The line begins with '\r' so it overwrites the previous one, and "\033[K"
// clears whatever is left of a longer earlier line. A newline is emitted only
// when the final step is reached (step == steps). Nothing is printed for step 0.
static void print_progress_line(int step, int steps, const std::string& speed_text) {
    if (step != 0) {
        const std::string bar = build_progress_bar(step, steps);
        const bool finished   = (step == steps);
        printf("\r%s %i/%i - %s\033[K%s", bar.c_str(), step, steps, speed_text.c_str(), finished ? "\n" : "");
        fflush(stdout);  // for linux
    }
}
// Reports step-based progress.
//
// step  - number of steps completed so far
// steps - total number of steps
// time  - average seconds per step; shown as "s/it", or inverted and shown as
//         "it/s" when it is strictly between 0 and 1
//
// If a progress callback (sd_progress_cb) is installed, the values are handed
// to it unchanged and nothing is printed. Otherwise a progress line is drawn
// on stdout via print_progress_line(); step 0 prints nothing.
void pretty_progress(int step, int steps, float time) {
    if (sd_progress_cb) {
        sd_progress_cb(step, steps, time, sd_progress_cb_data);
        return;
    }
    if (step == 0) {
        return;
    }
    const char* unit = "s/it";
    float speed      = time;
    // Sub-second steps read better as a rate; time == 0 is left as "0.00s/it"
    // to avoid dividing by zero.
    if (speed < 1.0f && speed > 0.f) {
        speed = 1.0f / speed;
        unit  = "it/s";
    }
    // All terminal output goes through print_progress_line(); the bar and the
    // trailing line feed are built there, so no direct printf is needed here.
    print_progress_line(step, steps, sd_format("%.2f%s", speed, unit));
}
// Reports step-based progress with a data-throughput readout.
//
// step            - number of items completed so far
// steps           - total number of items
// bytes_processed - total bytes handled so far
// elapsed_seconds - wall time spent so far, in seconds
//
// If a progress callback (sd_progress_cb) is installed it receives
// elapsed_seconds / step as its per-step time and nothing is printed.
// Otherwise the line shows bytes_processed / elapsed_seconds as "MB/s", or as
// "GB/s" once the rate reaches 1024 MB/s (1024-based units). Step 0 prints
// nothing, and a non-positive elapsed time is reported as a rate of 0.
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds) {
    if (sd_progress_cb) {
        // The small epsilon keeps the division defined when step is 0.
        const float seconds_per_step = elapsed_seconds / (step + 1e-6f);
        sd_progress_cb(step, steps, seconds_per_step, sd_progress_cb_data);
        return;
    }
    if (step == 0) {
        return;
    }

    constexpr double kBytesPerMB = 1024.0 * 1024.0;
    const double throughput =
        elapsed_seconds > 0.0f ? bytes_processed / static_cast<double>(elapsed_seconds) : 0.0;
    const double mb_per_second = throughput / kBytesPerMB;

    if (mb_per_second < 1024.0) {
        print_progress_line(step, steps, sd_format("%.2fMB/s", mb_per_second));
        return;
    }
    print_progress_line(step, steps, sd_format("%.2fGB/s", mb_per_second / 1024.0));
}
std::string ltrim(const std::string& s) {

View File

@ -64,6 +64,7 @@ protected:
std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> split_string(const std::string& str, char delimiter);
// Reports progress for `step` of `steps`. `time` is the per-step duration in
// seconds. Forwards to the installed progress callback if there is one;
// otherwise prints a progress bar with the speed shown as s/it or it/s.
void pretty_progress(int step, int steps, float time);
// Like pretty_progress, but the printed speed is the data throughput
// (bytes_processed / elapsed_seconds) shown as MB/s or GB/s. When a progress
// callback is installed it receives elapsed_seconds / step as its time value.
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);