stable-diffusion.cpp/script/convert_fp8_scale_to_bf16.py
2026-06-06 16:34:16 +08:00

284 lines
8.7 KiB
Python

#!/usr/bin/env python
import argparse
import json
import math
import os
import struct
from collections import Counter
from pathlib import Path
import torch
from safetensors import safe_open
FLOAT_DTYPES = {
"BF16",
"F16",
"F32",
"F64",
"F8_E4M3",
"F8_E4M3FN",
"F8_E5M2",
}
FP8_DTYPES = {
"F8_E4M3",
"F8_E4M3FN",
"F8_E5M2",
}
DTYPE_SIZES = {
"BOOL": 1,
"U8": 1,
"I8": 1,
"F8_E4M3": 1,
"F8_E4M3FN": 1,
"F8_E5M2": 1,
"U16": 2,
"I16": 2,
"F16": 2,
"BF16": 2,
"U32": 4,
"I32": 4,
"F32": 4,
"U64": 8,
"I64": 8,
"F64": 8,
}
def read_safetensors_header(path: Path):
with path.open("rb") as f:
header_len = struct.unpack("<Q", f.read(8))[0]
header = f.read(header_len).decode("utf-8").rstrip()
return json.loads(header)
def numel(shape):
return math.prod(shape) if shape else 1
def scale_key_for_weight(name: str):
if name.endswith(".weight"):
return name[:-len(".weight")] + ".weight_scale"
if name.endswith("weight"):
return name + "_scale"
return None
def tensor_nbytes(dtype: str, shape):
return numel(shape) * DTYPE_SIZES[dtype]
def build_output_plan(header):
entries = {k: v for k, v in header.items() if k != "__metadata__"}
paired_scale_keys = set()
plan = []
for name, info in entries.items():
scale_key = scale_key_for_weight(name)
if info["dtype"] in FP8_DTYPES and scale_key in entries:
paired_scale_keys.add(scale_key)
for name, info in entries.items():
if name in paired_scale_keys:
continue
dtype = info["dtype"]
shape = info["shape"]
scale_key = scale_key_for_weight(name)
if dtype in FP8_DTYPES and scale_key in entries:
scale_info = entries[scale_key]
plan.append(
{
"name": name,
"source_dtype": dtype,
"output_dtype": "BF16",
"shape": shape,
"mode": "fp8_scaled_weight",
"scale_key": scale_key,
}
)
continue
if dtype in FLOAT_DTYPES:
plan.append(
{
"name": name,
"source_dtype": dtype,
"output_dtype": "BF16",
"shape": shape,
"mode": "float_to_bf16",
}
)
else:
plan.append(
{
"name": name,
"source_dtype": dtype,
"output_dtype": dtype,
"shape": shape,
"mode": "copy",
}
)
metadata = dict(header.get("__metadata__", {}) or {})
metadata["format"] = "pt"
metadata["conversion"] = "fp8_weight_scale_to_bf16"
output_header = {"__metadata__": metadata}
offset = 0
for item in plan:
size = tensor_nbytes(item["output_dtype"], item["shape"])
output_header[item["name"]] = {
"dtype": item["output_dtype"],
"shape": item["shape"],
"data_offsets": [offset, offset + size],
}
offset += size
return plan, output_header, offset
def write_tensor_bytes(out, tensor):
tensor = tensor.detach().cpu().contiguous()
if tensor.numel() == 0:
return
if tensor.dtype == torch.bfloat16:
tensor.view(torch.uint16).numpy().tofile(out)
elif tensor.dtype in (getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e5m2", None)):
tensor.view(torch.uint8).numpy().tofile(out)
else:
tensor.numpy().tofile(out)
def scale_view_for_chunk(scale, chunk, first_dim_start=0, first_dim_end=None):
scale = scale.to(torch.float32)
if scale.numel() == 1:
return scale.reshape((1,) * chunk.ndim)
if chunk.ndim > 0 and scale.ndim == 1:
if first_dim_end is not None and scale.shape[0] >= first_dim_end:
scale = scale[first_dim_start:first_dim_end]
if scale.shape[0] == chunk.shape[0]:
return scale.reshape((scale.shape[0],) + (1,) * (chunk.ndim - 1))
return scale
def write_scaled_fp8_weight(out, weight, scale, chunk_rows):
if weight.ndim == 0:
result = weight.to(torch.float32) * scale_view_for_chunk(scale, weight)
write_tensor_bytes(out, result.to(torch.bfloat16))
return
rows = weight.shape[0]
for start in range(0, rows, chunk_rows):
end = min(start + chunk_rows, rows)
chunk = weight[start:end].to(torch.float32)
scale_view = scale_view_for_chunk(scale, chunk, start, end)
result = chunk * scale_view
write_tensor_bytes(out, result.to(torch.bfloat16))
def write_float_as_bf16(out, tensor, chunk_rows):
if tensor.dtype == torch.bfloat16:
write_tensor_bytes(out, tensor)
return
if tensor.ndim == 0:
write_tensor_bytes(out, tensor.to(torch.bfloat16))
return
rows = tensor.shape[0]
for start in range(0, rows, chunk_rows):
end = min(start + chunk_rows, rows)
write_tensor_bytes(out, tensor[start:end].to(torch.bfloat16))
def convert(input_path: Path, output_path: Path, chunk_rows: int, dry_run: bool):
header = read_safetensors_header(input_path)
plan, output_header, data_size = build_output_plan(header)
source_counts = Counter(item["source_dtype"] for item in plan)
output_counts = Counter(item["output_dtype"] for item in plan)
scaled_count = sum(item["mode"] == "fp8_scaled_weight" for item in plan)
dropped_scales = sum(item["mode"] == "fp8_scaled_weight" for item in plan)
header_bytes = json.dumps(output_header, separators=(",", ":")).encode("utf-8")
expected_size = 8 + len(header_bytes) + data_size
print(f"input: {input_path}")
print(f"output: {output_path}")
print(f"tensors written: {len(plan)}")
print(f"scaled fp8 weights dequantized: {scaled_count}")
print(f"weight_scale tensors dropped: {dropped_scales}")
print(f"source dtypes: {dict(sorted(source_counts.items()))}")
print(f"output dtypes: {dict(sorted(output_counts.items()))}")
print(f"expected output size: {expected_size / (1024 ** 3):.2f} GiB")
if dry_run:
return
if output_path.exists():
raise FileExistsError(f"{output_path} already exists; pass --overwrite to replace it")
tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
if tmp_path.exists():
raise FileExistsError(f"{tmp_path} already exists; remove it or choose another output")
with safe_open(str(input_path), framework="pt", device="cpu") as sf, tmp_path.open("wb") as out:
out.write(struct.pack("<Q", len(header_bytes)))
out.write(header_bytes)
for index, item in enumerate(plan, 1):
name = item["name"]
print(f"[{index:04d}/{len(plan):04d}] {name} -> {item['output_dtype']}")
tensor = sf.get_tensor(name)
if item["mode"] == "fp8_scaled_weight":
scale = sf.get_tensor(item["scale_key"])
write_scaled_fp8_weight(out, tensor, scale, chunk_rows)
elif item["mode"] == "float_to_bf16":
write_float_as_bf16(out, tensor, chunk_rows)
else:
write_tensor_bytes(out, tensor)
actual_size = out.tell()
if actual_size != expected_size:
tmp_path.unlink(missing_ok=True)
raise RuntimeError(f"wrote {actual_size} bytes, expected {expected_size} bytes")
tmp_path.replace(output_path)
print("done")
def main():
parser = argparse.ArgumentParser(
description="Convert an fp8 safetensors checkpoint with weight_scale tensors to bf16."
)
parser.add_argument("--input", default="ideogram4_fp8.safetensors", type=Path)
parser.add_argument("--output", default="ideogram4_bf16.safetensors", type=Path)
parser.add_argument("--chunk-rows", default=1024, type=int)
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--overwrite", action="store_true")
args = parser.parse_args()
input_path = args.input.resolve()
output_path = args.output.resolve()
if args.chunk_rows < 1:
raise ValueError("--chunk-rows must be >= 1")
if not input_path.exists():
raise FileNotFoundError(input_path)
if args.overwrite and output_path.exists():
output_path.unlink()
convert(input_path, output_path, args.chunk_rows, args.dry_run)
if __name__ == "__main__":
main()