save more memory

leejet 2025-11-12 23:10:51 +08:00
parent 7195efa9a2
commit 4008102a93


@@ -525,19 +525,30 @@ struct LoraModel : public GGMLRunner {
         ggml_tensor* lora_mid = nullptr;
         ggml_tensor* lora_down = nullptr;
+        bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
         auto iter = lora_tensors.find(lora_up_name);
         if (iter != lora_tensors.end()) {
             lora_up = iter->second;
+            if (is_conv2d && lora_up->type != GGML_TYPE_F16) {
+                lora_up = ggml_cast(ctx, lora_up, GGML_TYPE_F16);
+            }
         }
         iter = lora_tensors.find(lora_mid_name);
         if (iter != lora_tensors.end()) {
             lora_mid = iter->second;
+            if (is_conv2d && lora_mid->type != GGML_TYPE_F16) {
+                lora_mid = ggml_cast(ctx, lora_mid, GGML_TYPE_F16);
+            }
         }
         iter = lora_tensors.find(lora_down_name);
         if (iter != lora_tensors.end()) {
             lora_down = iter->second;
+            if (is_conv2d && lora_down->type != GGML_TYPE_F16) {
+                lora_down = ggml_cast(ctx, lora_down, GGML_TYPE_F16);
+            }
         }
         if (lora_up == nullptr || lora_down == nullptr) {
@@ -570,7 +581,7 @@ struct LoraModel : public GGMLRunner {
         scale_value *= multiplier;
         ggml_tensor* lx;
-        if (forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_LINEAR) {
+        if (!is_conv2d) {
             lx = ggml_ext_linear(ctx, x, lora_down, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
             if (lora_mid) {
                 lx = ggml_ext_linear(ctx, lx, lora_mid, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
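
Aside (editor's sketch, not part of the commit): the first hunk casts the conv2d LoRA tensors to F16 right after lookup, presumably so the conv2d path operates on half-precision copies instead of wider types, and the second hunk routes the linear-vs-conv2d branch through the new is_conv2d flag. Below is a minimal standalone illustration of the same guard-then-cast pattern using the plain ggml API; the tensor shape, the buffer size, and all variable names are invented for this example.

    // Assumed example, not code from this repo: cast to F16 only when needed.
    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params params = {
            /* mem_size   */ 16u * 1024 * 1024,  // scratch size, made up for the demo
            /* mem_buffer */ nullptr,
            /* no_alloc   */ false,
        };
        struct ggml_context* ctx = ggml_init(params);

        // Pretend this is a LoRA down-projection weight loaded as F32.
        struct ggml_tensor* lora_down = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3 * 3 * 320, 32);

        // Same guard as in the diff: only add a cast node when the tensor is not
        // already F16, so tensors that need no conversion get no extra copy.
        struct ggml_tensor* down_for_conv = lora_down;
        if (down_for_conv->type != GGML_TYPE_F16) {
            down_for_conv = ggml_cast(ctx, down_for_conv, GGML_TYPE_F16);
        }

        // The F16 node occupies half the bytes of the F32 original; in the real
        // code the cast is evaluated as part of the LoRA compute graph.
        printf("f32: %zu bytes, f16: %zu bytes\n",
               ggml_nbytes(lora_down), ggml_nbytes(down_for_conv));

        ggml_free(ctx);
        return 0;
    }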