#!/usr/bin/env python3
"""Convert a SeFi-Image diffusers checkpoint into a single sd.cpp-compatible safetensors.

Operates on raw safetensors bytes so any dtype (BF16, F32, ...) is preserved
exactly. No numpy or torch dependency required.

Usage:
    python3 script/convert_sefi.py
"""

import argparse
import json
import os
import re
import struct
import sys

# First ".linear." component of a shared-modulation key becomes ".lin.".
_LINEAR_TO_LIN = re.compile(r"\.linear\.")

_SHARED_MOD_PREFIXES = (
    "double_stream_modulation_img",
    "double_stream_modulation_txt",
    "single_stream_modulation",
)
# Dotted forms, built once so a single startswith() call covers every prefix.
_SHARED_MOD_DOTTED = tuple(prefix + "." for prefix in _SHARED_MOD_PREFIXES)

# Precompiled patterns splitting a per-block key into (block index, tail).
_DOUBLE_BLOCK_RE = re.compile(r"transformer_blocks\.(\d+)\.(.*)$")
_SINGLE_BLOCK_RE = re.compile(r"single_transformer_blocks\.(\d+)\.(.*)$")

# Exact-match renames for tensors that live outside the per-block modules.
_TOP_LEVEL_RENAMES = {
    "context_embedder.weight": "txt_in.weight",
    "context_embedder.bias": "txt_in.bias",
    "x_embedder.weight": "img_in.weight",
    "x_embedder.bias": "img_in.bias",
    "proj_out.weight": "final_layer.linear.weight",
    "proj_out.bias": "final_layer.linear.bias",
    "norm_out.linear.weight": "final_layer.adaLN_modulation.1.weight",
    "norm_out.linear.bias": "final_layer.adaLN_modulation.1.bias",
}

# Source tail -> target tail inside a double-stream block. Both the
# "net.0.proj"/"net.2" and the "linear_in"/"linear_out" spellings of the
# feed-forward layers are accepted and land on the same target names.
_DOUBLE_STREAM_RENAMES = {
    "norm1.linear.weight": "img_mod.lin.weight",
    "norm1_context.linear.weight": "txt_mod.lin.weight",
    "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
    "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
    "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
    "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
    "attn.to_out.0.weight": "img_attn.proj.weight",
    "attn.to_add_out.weight": "txt_attn.proj.weight",
    "ff.net.0.proj.weight": "img_mlp.0.weight",
    "ff.net.2.weight": "img_mlp.2.weight",
    "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
    "ff_context.net.2.weight": "txt_mlp.2.weight",
    "ff.linear_in.weight": "img_mlp.0.weight",
    "ff.linear_out.weight": "img_mlp.2.weight",
    "ff_context.linear_in.weight": "txt_mlp.0.weight",
    "ff_context.linear_out.weight": "txt_mlp.2.weight",
}


def rewrite_transformer_key(key: str) -> str:
    """Translate one source tensor name into its converted name.

    Steps, in order:
      1. A leading "backbone." is stripped. A key that instead starts with
         "dual_time_embed." is returned untouched.
      2. For keys under one of the shared modulation prefixes, the first
         ".linear." component becomes ".lin.".
      3. Exact top-level names (embedders, output projection, output norm)
         are renamed via a lookup table.
      4. "transformer_blocks.N.*" and "single_transformer_blocks.N.*" keys
         are delegated to the per-block rewriters.

    Any key that matches none of the above is returned as it stands after
    steps 1-2.
    """
    if key.startswith("backbone."):
        key = key[len("backbone."):]
    elif key.startswith("dual_time_embed."):
        return key

    if key.startswith(_SHARED_MOD_DOTTED):
        key = _LINEAR_TO_LIN.sub(".lin.", key, count=1)

    renamed = _TOP_LEVEL_RENAMES.get(key)
    if renamed is not None:
        return renamed

    m = _DOUBLE_BLOCK_RE.match(key)
    if m:
        return _rewrite_double_stream(m.group(1), m.group(2))
    m = _SINGLE_BLOCK_RE.match(key)
    if m:
        return _rewrite_single_stream(m.group(1), m.group(2))
    return key


def _rewrite_double_stream(idx: str, tail: str) -> str:
    """Return the "double_blocks.<idx>." name for a double-stream block tail.

    Tails without an entry in the rename table are kept verbatim under the
    new prefix.
    """
    return f"double_blocks.{idx}." + _DOUBLE_STREAM_RENAMES.get(tail, tail)


# QKV triplets to fuse on output: source tails -> target fused tail.
# Each tuple is (q_tail, k_tail, v_tail, fused_target_tail).
QKV_DOUBLE_TRIPLETS = [
    ("attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "img_attn.qkv.weight"),
    ("attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight", "txt_attn.qkv.weight"),
]


def _rewrite_single_stream(idx: str, tail: str) -> str:
    dst = f"single_blocks.{idx}."
mapping = { "norm.linear.weight": "modulation.lin.weight", "attn.norm_q.weight": "norm.query_norm.scale", "attn.norm_k.weight": "norm.key_norm.scale", "attn.to_qkv_mlp_proj.weight": "linear1.weight", "attn.to_out.weight": "linear2.weight", } return dst + mapping.get(tail, tail) def read_safetensors_header(path: str): """Return (header dict, data start byte offset).""" with open(path, "rb") as f: hdr_len = struct.unpack(" {tail: (key, shard_path, data_off, info)} raw_others = [] for shard_path in collect_shard_paths(section_dir, "diffusion_pytorch_model"): hdr, data_off = read_safetensors_header(shard_path) for key, info in hdr.items(): if key == "__metadata__": continue m = re.match(r"backbone\.transformer_blocks\.(\d+)\.(.*)$", key) if m and any(m.group(2) in trip[:3] for trip in QKV_DOUBLE_TRIPLETS): idx = m.group(1) raw_by_block.setdefault(idx, {})[m.group(2)] = (key, shard_path, data_off, info) else: raw_others.append((key, shard_path, data_off, info)) for key, shard_path, data_off, info in raw_others: new_key = rewrite_fn(key) # Swap the (scale, shift) halves to (shift, scale) at conversion time so # the on-disk weight matches BFL flux ordering and the runtime stays # version-agnostic. norm_out.linear weight shape is [2*dim, dim] and bias # is [2*dim]; both split along axis 0 (outermost == row-major outer). if new_key in ("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"): info = dict(info) info["_chunk_swap_halves"] = True entries.append((new_key, shard_path, data_off, info)) for block_idx, tails in raw_by_block.items(): for q_tail, k_tail, v_tail, fused_tail in QKV_DOUBLE_TRIPLETS: if q_tail in tails and k_tail in tails and v_tail in tails: q = tails[q_tail]; k = tails[k_tail]; v = tails[v_tail] # Validate shapes match. 
q_shape = q[3]["shape"]; k_shape = k[3]["shape"]; v_shape = v[3]["shape"] if q_shape != k_shape or q_shape != v_shape: raise ValueError(f"qkv shape mismatch at block {block_idx} {q_tail}: q={q_shape} k={k_shape} v={v_shape}") fused_shape = [q_shape[0] * 3] + list(q_shape[1:]) fused_info = { "dtype": q[3]["dtype"], "shape": fused_shape, "_qkv_sources": [q, k, v], # pseudo field consumed by writer } entries.append((f"double_blocks.{block_idx}.{fused_tail}", None, None, fused_info)) del tails[q_tail]; del tails[k_tail]; del tails[v_tail] # Anything left in tails was an unmatched single - pass through. for tail, payload in tails.items(): entries.append((rewrite_fn(payload[0]),) + payload[1:]) return entries _DTYPE_BYTES = { "BF16": 2, "F16": 2, "F32": 4, "F64": 8, "U8": 1, "I8": 1, "I16": 2, "I32": 4, "I64": 8, "BOOL": 1, } def _total_bytes(info: dict) -> int: if "_qkv_sources" in info: elems = 1 for d in info["shape"]: elems *= d return elems * _DTYPE_BYTES[info["dtype"]] start, end = info["data_offsets"] return end - start def write_consolidated(out_path: str, entries): """Write a single safetensors file by streaming raw bytes from each shard. For qkv-fused entries, q/k/v are concatenated along axis 0 (row-major), so a simple byte-level concatenation produces the correct fused layout for any standard dtype. 
""" entries = sorted(entries, key=lambda e: e[0]) new_header = {} cur_offset = 0 for new_key, shard_path, data_off, info in entries: size = _total_bytes(info) new_header[new_key] = { "dtype": info["dtype"], "shape": info["shape"], "data_offsets": [cur_offset, cur_offset + size], } cur_offset += size header_json = json.dumps(new_header, separators=(",", ":")).encode("utf-8") pad = (-len(header_json)) % 8 header_json = header_json + (b" " * pad) def copy_range(src_path, src_data_off, src_info, out, byte_range=None): start, end = src_info["data_offsets"] if byte_range is not None: sub_start, sub_end = byte_range start, end = start + sub_start, start + sub_end with open(src_path, "rb") as src: src.seek(src_data_off + start) remaining = end - start while remaining > 0: chunk = src.read(min(8 * 1024 * 1024, remaining)) if not chunk: raise IOError(f"Truncated tensor in {src_path}") out.write(chunk) remaining -= len(chunk) with open(out_path, "wb") as out: out.write(struct.pack(" {args.output}") os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) write_consolidated(args.output, transformer_entries) print(f"Done. Output size: {os.path.getsize(args.output) / 1e9:.2f} GB") if __name__ == "__main__": main()