#include "pickle_io.h" #include #include #include #include #include #include #include "binary_io.h" #include "util.h" // $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 // 0: \x80 PROTO 2 // 2: } EMPTY_DICT // 3: q BINPUT 0 // 5: ( MARK // 6: X BINUNICODE 'epoch' // 16: q BINPUT 1 // 18: K BININT1 6 // 20: X BINUNICODE 'global_step' // 36: q BINPUT 2 // 38: J BININT 470000 // 43: X BINUNICODE 'pytorch-lightning_version' // 73: q BINPUT 3 // 75: X BINUNICODE '1.4.2' // 85: q BINPUT 4 // 87: X BINUNICODE 'state_dict' // 102: q BINPUT 5 // 104: } EMPTY_DICT // 105: q BINPUT 6 // 107: ( MARK // 108: X BINUNICODE 'betas' // 118: q BINPUT 7 // 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' // 153: q BINPUT 8 // 155: ( MARK // 156: ( MARK // 157: X BINUNICODE 'storage' // 169: q BINPUT 9 // 171: c GLOBAL 'torch FloatStorage' // 191: q BINPUT 10 // 193: X BINUNICODE '0' // 199: q BINPUT 11 // 201: X BINUNICODE 'cpu' // 209: q BINPUT 12 // 211: M BININT2 1000 // 214: t TUPLE (MARK at 156) // 215: q BINPUT 13 // 217: Q BINPERSID // 218: K BININT1 0 // 220: M BININT2 1000 // ............................... // 3201: q BINPUT 250 // 3203: R REDUCE // 3204: q BINPUT 251 // 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' // 3264: q BINPUT 252 // 3266: h BINGET 8 // 3268: ( MARK // 3269: ( MARK // 3270: h BINGET 9 // 3272: h BINGET 10 // 3274: X BINUNICODE '30' // 3281: q BINPUT 253 // 3283: h BINGET 12 // 3285: J BININT 102400 // 3290: t TUPLE (MARK at 3269) // 3291: q BINPUT 254 // 3293: Q BINPERSID // 3294: K BININT1 0 // 3296: ( MARK // 3297: M BININT2 320 // 3300: M BININT2 320 // 3303: K BININT1 1 // 3305: K BININT1 1 // 3307: t TUPLE (MARK at 3296) // 3308: q BINPUT 255 // 3310: ( MARK // 3311: M BININT2 320 // 3314: K BININT1 1 // 3316: K BININT1 1 // 3318: K BININT1 1 // 3320: t TUPLE (MARK at 3310) // 3321: r LONG_BINPUT 256 // 3326: \x89 NEWFALSE // 3327: h BINGET 16 // 3329: ) EMPTY_TUPLE // 3330: R REDUCE // 3331: r LONG_BINPUT 257 // 3336: t TUPLE (MARK at 3268) // 3337: r LONG_BINPUT 258 // 3342: R REDUCE // 3343: r LONG_BINPUT 259 // 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' // 3404: r LONG_BINPUT 260 // 3409: h BINGET 8 // 3411: ( MARK // 3412: ( MARK // 3413: h BINGET 9 // 3415: h BINGET 10 // 3417: X BINUNICODE '31' // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 using model_io::find_char; using model_io::read_int; using model_io::read_short; using model_io::read_u64; static void set_error(std::string* error, const std::string& message) { if (error != nullptr) { *error = message; } } bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) { const uint8_t* p = buffer; const uint8_t* end = buffer + buffer_size; while (p < end) { uint8_t opcode = *p++; switch (opcode) { case '.': // STOP = b'.' # every pickle ends with STOP *object_size = (size_t)(p - buffer); return true; case 0x80: // PROTO = b'\x80' # protocol version indicator case 'K': // BININT1 = b'K' # push 1-byte unsigned int case 'h': // BINGET = b'h' # read memo index, 1-byte arg case 'q': // BINPUT = b'q' # write memo index, 1-byte arg case 'C': // SHORT_BINBYTES = b'C' # push bytes; length < 256 case 0x82: // EXT1 = b'\x82' # extension code, 1-byte arg p += 1; break; case 'M': // BININT2 = b'M' # push 2-byte unsigned int case 0x83: // EXT2 = b'\x83' # extension code, 2-byte arg p += 2; break; case 'J': // BININT = b'J' # push 4-byte signed int case 'j': // LONG_BINGET = b'j' # read memo index, 4-byte arg case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg case 0x84: // EXT4 = b'\x84' # extension code, 4-byte arg p += 4; break; case 'I': // INT = b'I' # push decimal integer line case 'L': // LONG = b'L' # push decimal long integer line case 'F': // FLOAT = b'F' # push decimal float line case 'S': // STRING = b'S' # push quoted string line case 'V': { // UNICODE = b'V' # push raw-unicode string line int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; } break; case 'G': // BINFLOAT = b'G' # push 8-byte binary float p += 8; break; case 0x8A: // LONG1 = b'\x8a' # push long integer; 1-byte length if (p >= end) { return false; } p += 1 + p[0]; break; case 0x8B: { // LONG4 = b'\x8b' # push long integer; 4-byte length if (p + 4 > end) { return false; } uint32_t n = (uint32_t)read_int(p); p += 4 + n; } break; case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length if (p + 4 > end) { return false; } uint32_t n = (uint32_t)read_int(p); p += 4 + n; } break; case 'T': // BINSTRING = b'T' # push string; 4-byte length case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length if (p + 4 > end) { return false; } uint32_t n = (uint32_t)read_int(p); p += 4 + n; } break; case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length if (p + 8 > end) { return false; } uint64_t n = read_u64(p); p += 8; if (n > (uint64_t)(end - p)) { return false; } p += n; } break; case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 if (p >= end) { return false; } p += 1 + p[0]; break; case 'P': { // PERSID = b'P' # persistent id, newline-terminated int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; } break; case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame p += 8; break; case 'c': { // GLOBAL = b'c' # push module/name global reference int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; } break; case '}': // EMPTY_DICT = b'}' # push empty dict case ']': // EMPTY_LIST = b']' # push empty list case '(': // MARK = b'(' # push markobject case 't': // TUPLE = b't' # build tuple from mark case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack case 0x87: // TUPLE3 = b'\x87' # build 3-tuple from stack case ')': // EMPTY_TUPLE = b')' # push empty tuple case 'l': // LIST = b'l' # build list from mark case 'Q': // BINPERSID = b'Q' # persistent id from stack case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo case 0x88: // NEWTRUE = b'\x88' # push True case 0x89: // NEWFALSE = b'\x89' # push False case 'R': // REDUCE = b'R' # apply callable to args case 'u': // SETITEMS = b'u' # add mark-delimited items to dict case 's': // SETITEM = b's' # add key/value to dict case 'e': // APPENDS = b'e' # extend list with mark-delimited items case 'a': // APPEND = b'a' # append item to list case 'b': // BUILD = b'b' # build object state case 0x81: // NEWOBJ = b'\x81' # build object via __new__ case 0x8F: // EMPTY_SET = b'\x8f' # push empty set case 0x90: // ADDITEMS = b'\x90' # add mark-delimited items to set case 0x91: // FROZENSET = b'\x91' # build frozenset from mark case 0x92: // NEWOBJ_EX = b'\x92' # build object with kwargs case 0x93: // STACK_GLOBAL = b'\x93' # build global from module/name strings case 0x97: // NEXT_BUFFER = b'\x97' # out-of-band buffer marker case 0x98: // READONLY_BUFFER = b'\x98' # mark buffer readonly case 'N': // NONE = b'N' # push None case '0': // POP = b'0' # discard top stack item case '1': // POP_MARK = b'1' # discard stack through topmost mark case '2': // DUP = b'2' # duplicate top stack item case 'o': // OBJ = b'o' # build class instance from mark break; case 'i': { // INST = b'i' # build class instance from module/name int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; } break; default: return false; } if (p > end) { return false; } } return false; } bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) { static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19}; if (buffer_size < 5 || buffer[0] != 0x80) { return false; } size_t pos = 2; if (pos >= buffer_size) { return false; } uint8_t opcode = buffer[pos++]; if (opcode != 0x8A || pos >= buffer_size) { return false; } uint8_t len = buffer[pos++]; if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) { return false; } if (memcmp(buffer + pos, torch_magic_bytes, sizeof(torch_magic_bytes)) != 0) { return false; } pos += len; return pos < buffer_size && buffer[pos] == '.'; } bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) { if (buffer_size < 4 || buffer[0] != 0x80) { return false; } size_t pos = 2; if (pos >= buffer_size) { return false; } uint8_t opcode = buffer[pos++]; switch (opcode) { case 'K': // BININT1 = b'K' # push 1-byte unsigned int if (pos + 1 >= buffer_size) { return false; } *value = buffer[pos]; pos += 1; break; case 'M': // BININT2 = b'M' # push 2-byte unsigned int if (pos + 2 >= buffer_size) { return false; } *value = read_short(buffer + pos); pos += 2; break; case 'J': // BININT = b'J' # push 4-byte signed int if (pos + 4 >= buffer_size) { return false; } *value = (uint32_t)read_int(buffer + pos); pos += 4; break; default: return false; } return pos < buffer_size && buffer[pos] == '.'; } struct PickleStorageInfo { std::string key; ggml_type type = GGML_TYPE_COUNT; bool is_f64 = false; bool is_i64 = false; uint64_t raw_element_nbytes = 0; uint64_t nbytes = 0; }; struct PickleTensorInfo { TensorStorage tensor_storage; int stride_n_dims = 0; int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1}; }; struct PickleValue { enum Kind { MARK, NONE, BOOL, INT, STRING, GLOBAL, TUPLE, LIST, DICT, ORDERED_DICT, STORAGE, TENSOR, }; Kind kind = NONE; int64_t int_value = 0; bool bool_value = false; std::string str_value; std::vector items; std::vector> dict_items; PickleStorageInfo storage; PickleTensorInfo tensor; }; static PickleValue make_mark_value() { PickleValue value; value.kind = PickleValue::MARK; return value; } static PickleValue make_none_value() { PickleValue value; value.kind = PickleValue::NONE; return value; } static PickleValue make_bool_value(bool b) { PickleValue value; value.kind = PickleValue::BOOL; value.bool_value = b; return value; } static PickleValue make_int_value(int64_t x) { PickleValue value; value.kind = PickleValue::INT; value.int_value = x; return value; } static PickleValue make_string_value(const std::string& s) { PickleValue value; value.kind = PickleValue::STRING; value.str_value = s; return value; } static PickleValue make_global_value(const std::string& s) { PickleValue value; value.kind = PickleValue::GLOBAL; value.str_value = s; return value; } static PickleValue make_tuple_value(std::vector items) { PickleValue value; value.kind = PickleValue::TUPLE; value.items = std::move(items); return value; } static PickleValue make_list_value() { PickleValue value; value.kind = PickleValue::LIST; return value; } static PickleValue make_dict_value(bool ordered) { PickleValue value; value.kind = ordered ? PickleValue::ORDERED_DICT : PickleValue::DICT; return value; } static PickleValue make_storage_value(const PickleStorageInfo& storage) { PickleValue value; value.kind = PickleValue::STORAGE; value.storage = storage; return value; } static PickleValue make_tensor_value(const PickleTensorInfo& tensor) { PickleValue value; value.kind = PickleValue::TENSOR; value.tensor = tensor; return value; } static std::string pickle_value_to_string(const PickleValue& value) { if (value.kind == PickleValue::STRING) { return value.str_value; } if (value.kind == PickleValue::INT) { return std::to_string(value.int_value); } return ""; } static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) { if (global_name == "torch.FloatStorage") { storage->type = GGML_TYPE_F32; storage->raw_element_nbytes = 4; return true; } if (global_name == "torch.DoubleStorage") { storage->type = GGML_TYPE_F32; storage->is_f64 = true; storage->raw_element_nbytes = 8; return true; } if (global_name == "torch.HalfStorage") { storage->type = GGML_TYPE_F16; storage->raw_element_nbytes = 2; return true; } if (global_name == "torch.BFloat16Storage") { storage->type = GGML_TYPE_BF16; storage->raw_element_nbytes = 2; return true; } if (global_name == "torch.IntStorage") { storage->type = GGML_TYPE_I32; storage->raw_element_nbytes = 4; return true; } if (global_name == "torch.LongStorage") { storage->type = GGML_TYPE_I32; storage->is_i64 = true; storage->raw_element_nbytes = 8; return true; } return false; } static bool tensor_is_contiguous(const PickleTensorInfo& tensor) { if (tensor.tensor_storage.nelements() == 0) { return true; } if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) { return false; } int64_t expected_stride = 1; for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) { if (tensor.stride[i] != expected_stride) { return false; } expected_stride *= tensor.tensor_storage.ne[i]; } return true; } static void collect_tensors_from_pickle_value(const PickleValue& value, std::vector& tensor_storages) { if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) { return; } for (const auto& item : value.dict_items) { if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) { TensorStorage tensor_storage = item.second.tensor.tensor_storage; tensor_storage.name = item.first.str_value; tensor_storage.reverse_ne(); tensor_storages.push_back(tensor_storage); } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) { collect_tensors_from_pickle_value(item.second, tensor_storages); } } } bool parse_torch_state_dict_pickle(const uint8_t* buffer, size_t buffer_size, std::vector& tensor_storages, std::unordered_map& storage_nbytes, std::string* error) { if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) { set_error(error, "unsupported torch pickle protocol"); return false; } const uint8_t* p = buffer + 2; const uint8_t* end = buffer + buffer_size; std::vector stack; std::unordered_map memo; while (p < end) { uint8_t opcode = *p++; switch (opcode) { case '.': { // STOP = b'.' # every pickle ends with STOP if (stack.empty()) { set_error(error, "empty torch pickle stack"); return false; } size_t old_tensor_count = tensor_storages.size(); collect_tensors_from_pickle_value(stack.back(), tensor_storages); if (tensor_storages.size() == old_tensor_count) { set_error(error, "torch pickle does not contain a supported state_dict"); return false; } return true; } case '}': // EMPTY_DICT = b'}' # push empty dict stack.push_back(make_dict_value(false)); break; case ']': // EMPTY_LIST = b']' # push empty list stack.push_back(make_list_value()); break; case 'l': { // LIST = b'l' # build list from mark int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx < 0) { set_error(error, "torch pickle list without mark"); return false; } std::vector items(stack.begin() + mark_idx + 1, stack.end()); stack.erase(stack.begin() + mark_idx, stack.end()); PickleValue list_value = make_list_value(); list_value.items = std::move(items); stack.push_back(std::move(list_value)); } break; case '(': // MARK = b'(' # push markobject stack.push_back(make_mark_value()); break; case ')': // EMPTY_TUPLE = b')' # push empty tuple stack.push_back(make_tuple_value({})); break; case 'N': // NONE = b'N' # push None stack.push_back(make_none_value()); break; case 0x88: // NEWTRUE = b'\x88' # push True stack.push_back(make_bool_value(true)); break; case 0x89: // NEWFALSE = b'\x89' # push False stack.push_back(make_bool_value(false)); break; case 'K': // BININT1 = b'K' # push 1-byte unsigned int if (p >= end) { return false; } stack.push_back(make_int_value(*p++)); break; case 'M': // BININT2 = b'M' # push 2-byte unsigned int if (p + 2 > end) { return false; } stack.push_back(make_int_value(read_short(p))); p += 2; break; case 'J': // BININT = b'J' # push 4-byte signed int if (p + 4 > end) { return false; } stack.push_back(make_int_value(read_int(p))); p += 4; break; case 'I': { // INT = b'I' # push decimal integer line int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } std::string s((const char*)p, len); p += len + 1; if (s == "01") { stack.push_back(make_bool_value(true)); } else if (s == "00") { stack.push_back(make_bool_value(false)); } else { stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); } } break; case 'L': { // LONG = b'L' # push decimal long integer line int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } std::string s((const char*)p, len); p += len + 1; if (!s.empty() && s.back() == 'L') { s.pop_back(); } stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); } break; case 'F': { // FLOAT = b'F' # push decimal float line int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } p += len + 1; stack.push_back(make_none_value()); } break; case 'G': // BINFLOAT = b'G' # push 8-byte binary float if (p + 8 > end) { return false; } p += 8; stack.push_back(make_none_value()); break; case 0x8A: { // LONG1 = b'\x8a' # push long integer; 1-byte length if (p >= end) { return false; } uint8_t n = *p++; if (p + n > end || n > 8) { return false; } int64_t value = 0; for (uint8_t i = 0; i < n; ++i) { value |= (int64_t)p[i] << (i * 8); } p += n; stack.push_back(make_int_value(value)); } break; case 'C': { // SHORT_BINBYTES = b'C' # push bytes; length < 256 if (p >= end) { return false; } uint8_t len = *p++; if (p + len > end) { return false; } stack.push_back(make_string_value(std::string((const char*)p, len))); p += len; } break; case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length if (p + 4 > end) { return false; } int32_t len = read_int(p); p += 4; if (len < 0 || p + len > end) { return false; } stack.push_back(make_string_value(std::string((const char*)p, len))); p += len; } break; case 'T': // BINSTRING = b'T' # push string; 4-byte length case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length if (p + 4 > end) { return false; } int32_t len = read_int(p); p += 4; if (len < 0 || p + len > end) { return false; } stack.push_back(make_string_value(std::string((const char*)p, len))); p += len; } break; case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length if (p + 8 > end) { return false; } uint64_t len = read_u64(p); p += 8; if (len > (uint64_t)(end - p)) { return false; } stack.push_back(make_string_value(std::string((const char*)p, (size_t)len))); p += len; } break; case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 case 0x8C: { // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 if (p >= end) { return false; } uint8_t len = *p++; if (p + len > end) { return false; } stack.push_back(make_string_value(std::string((const char*)p, len))); p += len; } break; case 'S': { // STRING = b'S' # push quoted string line int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } std::string s((const char*)p, len); p += len + 1; if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) { s = s.substr(1, s.size() - 2); } stack.push_back(make_string_value(s)); } break; case 'V': { // UNICODE = b'V' # push raw-unicode string line int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } stack.push_back(make_string_value(std::string((const char*)p, len))); p += len + 1; } break; case 'c': { // GLOBAL = b'c' # push module/name global reference int len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } std::string module((const char*)p, len); p += len + 1; len = find_char(p, (int)(end - p), '\n'); if (len < 0) { return false; } std::string name((const char*)p, len); p += len + 1; stack.push_back(make_global_value(module + "." + name)); } break; case 0x93: { // STACK_GLOBAL = b'\x93' # build global from module/name strings if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING || stack.back().kind != PickleValue::STRING) { return false; } std::string name = stack.back().str_value; stack.pop_back(); std::string module = stack.back().str_value; stack.pop_back(); stack.push_back(make_global_value(module + "." + name)); } break; case 'h': // BINGET = b'h' # read memo index, 1-byte arg if (p >= end || !memo.count(*p)) { return false; } stack.push_back(memo[*p++]); break; case 'j': { // LONG_BINGET = b'j' # read memo index, 4-byte arg if (p + 4 > end) { return false; } int32_t memo_idx = read_int(p); if (!memo.count(memo_idx)) { return false; } stack.push_back(memo[memo_idx]); p += 4; } break; case 'q': // BINPUT = b'q' # write memo index, 1-byte arg if (p >= end || stack.empty()) { return false; } memo[*p++] = stack.back(); break; case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg if (p + 4 > end || stack.empty()) { return false; } memo[read_int(p)] = stack.back(); p += 4; break; case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo if (stack.empty()) { return false; } memo[(int32_t)memo.size()] = stack.back(); break; case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame if (p + 8 > end) { return false; } p += 8; break; case '0': // POP = b'0' # discard top stack item if (stack.empty()) { return false; } stack.pop_back(); break; case '1': { // POP_MARK = b'1' # discard stack through topmost mark int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx < 0) { return false; } stack.erase(stack.begin() + mark_idx, stack.end()); } break; case '2': // DUP = b'2' # duplicate top stack item if (stack.empty()) { return false; } stack.push_back(stack.back()); break; case 0x8F: // EMPTY_SET = b'\x8f' # push empty set stack.push_back(make_list_value()); break; case 0x90: { // ADDITEMS = b'\x90' # add mark-delimited items to set int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { return false; } PickleValue& set_value = stack[mark_idx - 1]; set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); stack.erase(stack.begin() + mark_idx, stack.end()); } break; case 0x91: { // FROZENSET = b'\x91' # build frozenset from mark int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx < 0) { return false; } PickleValue set_value = make_list_value(); set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); stack.erase(stack.begin() + mark_idx, stack.end()); stack.push_back(std::move(set_value)); } break; case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack case 0x87: { // TUPLE3 = b'\x87' # build 3-tuple from stack int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 2 : 3); if ((int)stack.size() < tuple_size) { return false; } std::vector items(stack.end() - tuple_size, stack.end()); stack.erase(stack.end() - tuple_size, stack.end()); stack.push_back(make_tuple_value(std::move(items))); } break; case 't': { // TUPLE = b't' # build tuple from mark int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx < 0) { return false; } std::vector items(stack.begin() + mark_idx + 1, stack.end()); stack.erase(stack.begin() + mark_idx, stack.end()); stack.push_back(make_tuple_value(std::move(items))); } break; case 'Q': { // BINPERSID = b'Q' # persistent id from stack if (stack.empty()) { return false; } PickleValue pid = stack.back(); stack.pop_back(); if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING || pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT || pid.items[0].str_value != "storage") { return false; } PickleStorageInfo storage; storage.key = pickle_value_to_string(pid.items[2]); if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) { return false; } storage.nbytes = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes; storage_nbytes[storage.key] = storage.nbytes; stack.push_back(make_storage_value(storage)); } break; case 'R': { // REDUCE = b'R' # apply callable to args if (stack.size() < 2) { return false; } PickleValue args = stack.back(); stack.pop_back(); PickleValue callable = stack.back(); stack.pop_back(); if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) { stack.push_back(make_none_value()); break; } if (callable.str_value == "collections.OrderedDict" && args.items.empty()) { stack.push_back(make_dict_value(true)); break; } if ((callable.str_value == "torch._utils._rebuild_tensor_v2" || callable.str_value == "torch._utils._rebuild_tensor") && args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE && args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE && args.items[3].kind == PickleValue::TUPLE) { PickleTensorInfo tensor; tensor.tensor_storage.type = args.items[0].storage.type; tensor.tensor_storage.is_f64 = args.items[0].storage.is_f64; tensor.tensor_storage.is_i64 = args.items[0].storage.is_i64; tensor.tensor_storage.storage_key = args.items[0].storage.key; tensor.tensor_storage.offset = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes; for (const auto& item : args.items[2].items) { if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) { return false; } tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value; } for (const auto& item : args.items[3].items) { if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) { return false; } tensor.stride[tensor.stride_n_dims++] = item.int_value; } if (!tensor_is_contiguous(tensor)) { return false; } stack.push_back(make_tensor_value(tensor)); break; } // Non-tensor checkpoint metadata can use REDUCE for arbitrary // Python objects. Do not execute it; keep stack shape only. stack.push_back(make_none_value()); break; } case 'b': // BUILD = b'b' # build object state if (stack.size() < 2) { return false; } stack.pop_back(); break; case 'u': { // SETITEMS = b'u' # add mark-delimited items to dict int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx <= 0) { return false; } PickleValue& dict = stack[mark_idx - 1]; if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { return false; } for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) { dict.dict_items.emplace_back(stack[i], stack[i + 1]); } stack.erase(stack.begin() + mark_idx, stack.end()); } break; case 's': { // SETITEM = b's' # add key/value to dict if (stack.size() < 3) { return false; } PickleValue value = stack.back(); stack.pop_back(); PickleValue key = stack.back(); stack.pop_back(); PickleValue& dict = stack.back(); if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { return false; } dict.dict_items.emplace_back(key, value); } break; case 'e': { // APPENDS = b'e' # extend list with mark-delimited items int mark_idx = (int)stack.size() - 1; while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { --mark_idx; } if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { return false; } PickleValue& list_value = stack[mark_idx - 1]; list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); stack.erase(stack.begin() + mark_idx, stack.end()); } break; case 'a': { // APPEND = b'a' # append item to list if (stack.size() < 2) { return false; } PickleValue item = stack.back(); stack.pop_back(); if (stack.back().kind != PickleValue::LIST) { return false; } stack.back().items.push_back(item); } break; default: set_error(error, "unsupported torch pickle opcode 0x" + sd_format("%02X", opcode) + " at offset " + std::to_string((p - buffer) - 1)); return false; } } set_error(error, "unterminated torch state_dict pickle"); return false; }