#!/usr/bin/env python
"""Convert an fp8 safetensors checkpoint with weight_scale tensors to bf16.

NOTE(review): the text this layout was restored from had all line breaks
collapsed and is missing two spans (each marked with a NOTE(review) comment
below).  Statement nesting was re-derived from syntax and should be checked
against an intact copy.  `build_output_plan`, `write_tensor_bytes` and the
header of `scale_view_for_chunk` are called or used below but are not present
in the surviving text.
"""
import argparse
import json
import math
import os
import struct
from collections import Counter
from pathlib import Path

import torch
from safetensors import safe_open

# Dtype names treated as floating point.  The code that consumes this set is
# not in the surviving text -- presumably it selects tensors to convert to
# bf16; TODO confirm.
FLOAT_DTYPES = {
    "BF16",
    "F16",
    "F32",
    "F64",
    "F8_E4M3",
    "F8_E4M3FN",
    "F8_E5M2",
}

# The fp8 subset of FLOAT_DTYPES.  Consumer not visible -- TODO confirm usage.
FP8_DTYPES = {
    "F8_E4M3",
    "F8_E4M3FN",
    "F8_E5M2",
}

# Dtype name -> integer size.  The values look like bytes per element, but the
# code that reads this table is not visible -- TODO confirm.
DTYPE_SIZES = {
    "BOOL": 1,
    "U8": 1,
    "I8": 1,
    "F8_E4M3": 1,
    "F8_E4M3FN": 1,
    "F8_E5M2": 1,
    "U16": 2,
    "I16": 2,
    "F16": 2,
    "BF16": 2,
    "U32": 4,
    "I32": 4,
    "F32": 4,
    "U64": 8,
    "I64": 8,
    "F64": 8,
}


def read_safetensors_header(path: Path):
    # Opens `path` in binary mode and starts unpacking a header length.
    # `convert` passes the return value straight to `build_output_plan`.
    with path.open("rb") as f:
        # NOTE(review): the source text is cut off inside the string literal on
        # the next line.  The remainder of this function, every definition
        # that followed it, and the header of `scale_view_for_chunk` are
        # missing.  The statement is kept exactly as found and is not valid
        # Python -- restore from an intact copy.
        header_len = struct.unpack(" 0 and scale.ndim == 1:
        # NOTE(review): from here to `return scale` the code reads `scale`,
        # `chunk`, `first_dim_start` and `first_dim_end`.  It looks like the
        # tail of `scale_view_for_chunk(scale, chunk, first_dim_start,
        # first_dim_end)` as called from `write_scaled_fp8_weight` -- TODO
        # confirm the signature and the missing first half of the condition.
        if first_dim_end is not None and scale.shape[0] >= first_dim_end:
            # The scale has at least `first_dim_end` entries: keep only the
            # entries for the requested row range.
            scale = scale[first_dim_start:first_dim_end]
        if scale.shape[0] == chunk.shape[0]:
            # One scale entry per chunk row: append singleton dims so the
            # scale broadcasts across the chunk's remaining dimensions.
            return scale.reshape((scale.shape[0],) + (1,) * (chunk.ndim - 1))
    # Otherwise the scale is returned without reshaping.
    return scale


def write_scaled_fp8_weight(out, weight, scale, chunk_rows):
    """Write ``weight * scale`` to ``out`` as bfloat16.

    The multiplication is done in float32.  A 0-dim ``weight`` is handled in
    one step; any other weight is processed in slices of at most
    ``chunk_rows`` along dim 0, with the matching view of ``scale`` obtained
    from ``scale_view_for_chunk`` for each slice.

    ``write_tensor_bytes`` (not in the surviving text) performs the actual
    write.
    """
    if weight.ndim == 0:
        # No leading dimension to slice: scale and write in one step.
        result = weight.to(torch.float32) * scale_view_for_chunk(scale, weight)
        write_tensor_bytes(out, result.to(torch.bfloat16))
        return
    rows = weight.shape[0]
    for start in range(0, rows, chunk_rows):
        end = min(start + chunk_rows, rows)
        # Only the current slice is upcast to float32.
        chunk = weight[start:end].to(torch.float32)
        scale_view = scale_view_for_chunk(scale, chunk, start, end)
        result = chunk * scale_view
        write_tensor_bytes(out, result.to(torch.bfloat16))


def write_float_as_bf16(out, tensor, chunk_rows):
    """Write ``tensor`` to ``out`` as bfloat16.

    A tensor that is already bfloat16 is passed to ``write_tensor_bytes``
    unchanged in a single call.  A 0-dim tensor is converted whole.  Any other
    tensor is converted and written in slices of at most ``chunk_rows`` along
    dim 0.
    """
    if tensor.dtype == torch.bfloat16:
        write_tensor_bytes(out, tensor)
        return
    if tensor.ndim == 0:
        write_tensor_bytes(out, tensor.to(torch.bfloat16))
        return
    rows = tensor.shape[0]
    for start in range(0, rows, chunk_rows):
        end = min(start + chunk_rows, rows)
        write_tensor_bytes(out, tensor[start:end].to(torch.bfloat16))


def convert(input_path: Path, output_path: Path, chunk_rows: int, dry_run: bool):
    """Convert ``input_path`` and write the result to ``output_path``.

    Reads the input header, builds the output plan, and prints a summary
    (tensor count, dtype histograms, expected output size).  With ``dry_run``
    set, returns after the summary without touching the filesystem.

    Otherwise the data is written to ``<output_path>.tmp`` and that file is
    renamed onto ``output_path`` once the byte count matches the expected
    size.

    Raises:
        FileExistsError: if ``output_path`` or the temporary path already
            exists.
        RuntimeError: if the number of bytes written differs from the
            expected size (the temporary file is removed first).
    """
    header = read_safetensors_header(input_path)
    # `build_output_plan` is not in the surviving text.  From its uses below,
    # `plan` is an iterable of dicts with "source_dtype", "output_dtype" and
    # "mode" keys (plus "scale_key" for "fp8_scaled_weight" items),
    # `output_header` is JSON-serialisable, and `data_size` is a byte count.
    plan, output_header, data_size = build_output_plan(header)
    source_counts = Counter(item["source_dtype"] for item in plan)
    output_counts = Counter(item["output_dtype"] for item in plan)
    scaled_count = sum(item["mode"] == "fp8_scaled_weight" for item in plan)
    # NOTE(review): same expression as `scaled_count`, so the two printed
    # numbers are always equal -- confirm that one dropped scale per scaled
    # weight is the intended meaning.
    dropped_scales = sum(item["mode"] == "fp8_scaled_weight" for item in plan)
    header_bytes = json.dumps(output_header, separators=(",", ":")).encode("utf-8")
    # 8 is added for a fixed-size prefix ahead of the JSON header --
    # presumably the header-length field written below; TODO confirm.
    expected_size = 8 + len(header_bytes) + data_size
    print(f"input: {input_path}")
    print(f"output: {output_path}")
    print(f"tensors written: {len(plan)}")
    print(f"scaled fp8 weights dequantized: {scaled_count}")
    print(f"weight_scale tensors dropped: {dropped_scales}")
    print(f"source dtypes: {dict(sorted(source_counts.items()))}")
    print(f"output dtypes: {dict(sorted(output_counts.items()))}")
    print(f"expected output size: {expected_size / (1024 ** 3):.2f} GiB")
    if dry_run:
        return
    if output_path.exists():
        raise FileExistsError(f"{output_path} already exists; pass --overwrite to replace it")
    # Write to a sibling "<name>.tmp" file and rename it at the end.
    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
    if tmp_path.exists():
        raise FileExistsError(f"{tmp_path} already exists; remove it or choose another output")
    # NOTE(review): the temporary file is only removed in the size-mismatch
    # branch below.  If anything inside this `with` raises, the file is left
    # behind and the next run stops at the FileExistsError check above.
    with safe_open(str(input_path), framework="pt", device="cpu") as sf, tmp_path.open("wb") as out:
        # NOTE(review): the source text is cut off inside the string literal on
        # the next line and resumes at the end of an f-string.  The missing
        # span includes whatever binds `name` and `item`, which the
        # statements below rely on.  Kept exactly as found; not valid Python
        # -- restore from an intact copy.
        out.write(struct.pack(" {item['output_dtype']}")
            tensor = sf.get_tensor(name)
            if item["mode"] == "fp8_scaled_weight":
                # The scale tensor is looked up by the plan item's "scale_key".
                scale = sf.get_tensor(item["scale_key"])
                write_scaled_fp8_weight(out, tensor, scale, chunk_rows)
            elif item["mode"] == "float_to_bf16":
                write_float_as_bf16(out, tensor, chunk_rows)
            else:
                # Any other mode: the tensor is written without conversion.
                write_tensor_bytes(out, tensor)
        # Read the position while the file is still open.
        actual_size = out.tell()
    if actual_size != expected_size:
        tmp_path.unlink(missing_ok=True)
        raise RuntimeError(f"wrote {actual_size} bytes, expected {expected_size} bytes")
    tmp_path.replace(output_path)
    print("done")


def main():
    """Parse command-line arguments, validate them and run ``convert``.

    Raises:
        ValueError: if ``--chunk-rows`` is smaller than 1.
        FileNotFoundError: if the input path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Convert an fp8 safetensors checkpoint with weight_scale tensors to bf16."
    )
    parser.add_argument("--input", default="ideogram4_fp8.safetensors", type=Path)
    parser.add_argument("--output", default="ideogram4_bf16.safetensors", type=Path)
    parser.add_argument("--chunk-rows", default=1024, type=int)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()
    input_path = args.input.resolve()
    output_path = args.output.resolve()
    if args.chunk_rows < 1:
        raise ValueError("--chunk-rows must be >= 1")
    if not input_path.exists():
        raise FileNotFoundError(input_path)
    # NOTE(review): the existing output is deleted here, before `convert`
    # runs.  With `--dry-run` set, `convert` returns before writing anything,
    # so `--overwrite --dry-run` removes the old output and produces no
    # replacement.  A failed conversion likewise leaves no output behind.
    if args.overwrite and output_path.exists():
        output_path.unlink()
    convert(input_path, output_path, args.chunk_rows, args.dry_run)


if __name__ == "__main__":
    main()