mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-25 15:46:40 +00:00
1065 lines
43 KiB
C++
1065 lines
43 KiB
C++
#include "pickle_io.h"
|
|
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "binary_io.h"
|
|
#include "util.h"
|
|
|
|
// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
|
|
// 0: \x80 PROTO 2
|
|
// 2: } EMPTY_DICT
|
|
// 3: q BINPUT 0
|
|
// 5: ( MARK
|
|
// 6: X BINUNICODE 'epoch'
|
|
// 16: q BINPUT 1
|
|
// 18: K BININT1 6
|
|
// 20: X BINUNICODE 'global_step'
|
|
// 36: q BINPUT 2
|
|
// 38: J BININT 470000
|
|
// 43: X BINUNICODE 'pytorch-lightning_version'
|
|
// 73: q BINPUT 3
|
|
// 75: X BINUNICODE '1.4.2'
|
|
// 85: q BINPUT 4
|
|
// 87: X BINUNICODE 'state_dict'
|
|
// 102: q BINPUT 5
|
|
// 104: } EMPTY_DICT
|
|
// 105: q BINPUT 6
|
|
// 107: ( MARK
|
|
// 108: X BINUNICODE 'betas'
|
|
// 118: q BINPUT 7
|
|
// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2'
|
|
// 153: q BINPUT 8
|
|
// 155: ( MARK
|
|
// 156: ( MARK
|
|
// 157: X BINUNICODE 'storage'
|
|
// 169: q BINPUT 9
|
|
// 171: c GLOBAL 'torch FloatStorage'
|
|
// 191: q BINPUT 10
|
|
// 193: X BINUNICODE '0'
|
|
// 199: q BINPUT 11
|
|
// 201: X BINUNICODE 'cpu'
|
|
// 209: q BINPUT 12
|
|
// 211: M BININT2 1000
|
|
// 214: t TUPLE (MARK at 156)
|
|
// 215: q BINPUT 13
|
|
// 217: Q BINPERSID
|
|
// 218: K BININT1 0
|
|
// 220: M BININT2 1000
|
|
// ...............................
|
|
// 3201: q BINPUT 250
|
|
// 3203: R REDUCE
|
|
// 3204: q BINPUT 251
|
|
// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
|
|
// 3264: q BINPUT 252
|
|
// 3266: h BINGET 8
|
|
// 3268: ( MARK
|
|
// 3269: ( MARK
|
|
// 3270: h BINGET 9
|
|
// 3272: h BINGET 10
|
|
// 3274: X BINUNICODE '30'
|
|
// 3281: q BINPUT 253
|
|
// 3283: h BINGET 12
|
|
// 3285: J BININT 102400
|
|
// 3290: t TUPLE (MARK at 3269)
|
|
// 3291: q BINPUT 254
|
|
// 3293: Q BINPERSID
|
|
// 3294: K BININT1 0
|
|
// 3296: ( MARK
|
|
// 3297: M BININT2 320
|
|
// 3300: M BININT2 320
|
|
// 3303: K BININT1 1
|
|
// 3305: K BININT1 1
|
|
// 3307: t TUPLE (MARK at 3296)
|
|
// 3308: q BINPUT 255
|
|
// 3310: ( MARK
|
|
// 3311: M BININT2 320
|
|
// 3314: K BININT1 1
|
|
// 3316: K BININT1 1
|
|
// 3318: K BININT1 1
|
|
// 3320: t TUPLE (MARK at 3310)
|
|
// 3321: r LONG_BINPUT 256
|
|
// 3326: \x89 NEWFALSE
|
|
// 3327: h BINGET 16
|
|
// 3329: ) EMPTY_TUPLE
|
|
// 3330: R REDUCE
|
|
// 3331: r LONG_BINPUT 257
|
|
// 3336: t TUPLE (MARK at 3268)
|
|
// 3337: r LONG_BINPUT 258
|
|
// 3342: R REDUCE
|
|
// 3343: r LONG_BINPUT 259
|
|
// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
|
|
// 3404: r LONG_BINPUT 260
|
|
// 3409: h BINGET 8
|
|
// 3411: ( MARK
|
|
// 3412: ( MARK
|
|
// 3413: h BINGET 9
|
|
// 3415: h BINGET 10
|
|
// 3417: X BINUNICODE '31'
|
|
// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
|
|
// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
|
|
|
|
using model_io::find_char;
|
|
using model_io::read_int;
|
|
using model_io::read_short;
|
|
using model_io::read_u64;
|
|
|
|
// Stores `message` in `*error`; does nothing when the caller passed no error
// sink (`error == nullptr`).
static void set_error(std::string* error, const std::string& message) {
    if (error == nullptr) {
        return;
    }
    error->assign(message);
}
|
|
|
|
bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) {
|
|
const uint8_t* p = buffer;
|
|
const uint8_t* end = buffer + buffer_size;
|
|
|
|
while (p < end) {
|
|
uint8_t opcode = *p++;
|
|
switch (opcode) {
|
|
case '.': // STOP = b'.' # every pickle ends with STOP
|
|
*object_size = (size_t)(p - buffer);
|
|
return true;
|
|
case 0x80: // PROTO = b'\x80' # protocol version indicator
|
|
case 'K': // BININT1 = b'K' # push 1-byte unsigned int
|
|
case 'h': // BINGET = b'h' # read memo index, 1-byte arg
|
|
case 'q': // BINPUT = b'q' # write memo index, 1-byte arg
|
|
case 'C': // SHORT_BINBYTES = b'C' # push bytes; length < 256
|
|
case 0x82: // EXT1 = b'\x82' # extension code, 1-byte arg
|
|
p += 1;
|
|
break;
|
|
case 'M': // BININT2 = b'M' # push 2-byte unsigned int
|
|
case 0x83: // EXT2 = b'\x83' # extension code, 2-byte arg
|
|
p += 2;
|
|
break;
|
|
case 'J': // BININT = b'J' # push 4-byte signed int
|
|
case 'j': // LONG_BINGET = b'j' # read memo index, 4-byte arg
|
|
case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg
|
|
case 0x84: // EXT4 = b'\x84' # extension code, 4-byte arg
|
|
p += 4;
|
|
break;
|
|
case 'I': // INT = b'I' # push decimal integer line
|
|
case 'L': // LONG = b'L' # push decimal long integer line
|
|
case 'F': // FLOAT = b'F' # push decimal float line
|
|
case 'S': // STRING = b'S' # push quoted string line
|
|
case 'V': { // UNICODE = b'V' # push raw-unicode string line
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
} break;
|
|
case 'G': // BINFLOAT = b'G' # push 8-byte binary float
|
|
p += 8;
|
|
break;
|
|
case 0x8A: // LONG1 = b'\x8a' # push long integer; 1-byte length
|
|
if (p >= end) {
|
|
return false;
|
|
}
|
|
p += 1 + p[0];
|
|
break;
|
|
case 0x8B: { // LONG4 = b'\x8b' # push long integer; 4-byte length
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
uint32_t n = (uint32_t)read_int(p);
|
|
p += 4 + n;
|
|
} break;
|
|
case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
uint32_t n = (uint32_t)read_int(p);
|
|
p += 4 + n;
|
|
} break;
|
|
case 'T': // BINSTRING = b'T' # push string; 4-byte length
|
|
case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
uint32_t n = (uint32_t)read_int(p);
|
|
p += 4 + n;
|
|
} break;
|
|
case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length
|
|
case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length
|
|
case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length
|
|
if (p + 8 > end) {
|
|
return false;
|
|
}
|
|
uint64_t n = read_u64(p);
|
|
p += 8;
|
|
if (n > (uint64_t)(end - p)) {
|
|
return false;
|
|
}
|
|
p += n;
|
|
} break;
|
|
case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256
|
|
case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256
|
|
if (p >= end) {
|
|
return false;
|
|
}
|
|
p += 1 + p[0];
|
|
break;
|
|
case 'P': { // PERSID = b'P' # persistent id, newline-terminated
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
} break;
|
|
case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
|
|
p += 8;
|
|
break;
|
|
case 'c': { // GLOBAL = b'c' # push module/name global reference
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
} break;
|
|
case '}': // EMPTY_DICT = b'}' # push empty dict
|
|
case ']': // EMPTY_LIST = b']' # push empty list
|
|
case '(': // MARK = b'(' # push markobject
|
|
case 't': // TUPLE = b't' # build tuple from mark
|
|
case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack
|
|
case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack
|
|
case 0x87: // TUPLE3 = b'\x87' # build 3-tuple from stack
|
|
case ')': // EMPTY_TUPLE = b')' # push empty tuple
|
|
case 'l': // LIST = b'l' # build list from mark
|
|
case 'Q': // BINPERSID = b'Q' # persistent id from stack
|
|
case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo
|
|
case 0x88: // NEWTRUE = b'\x88' # push True
|
|
case 0x89: // NEWFALSE = b'\x89' # push False
|
|
case 'R': // REDUCE = b'R' # apply callable to args
|
|
case 'u': // SETITEMS = b'u' # add mark-delimited items to dict
|
|
case 's': // SETITEM = b's' # add key/value to dict
|
|
case 'e': // APPENDS = b'e' # extend list with mark-delimited items
|
|
case 'a': // APPEND = b'a' # append item to list
|
|
case 'b': // BUILD = b'b' # build object state
|
|
case 0x81: // NEWOBJ = b'\x81' # build object via __new__
|
|
case 0x8F: // EMPTY_SET = b'\x8f' # push empty set
|
|
case 0x90: // ADDITEMS = b'\x90' # add mark-delimited items to set
|
|
case 0x91: // FROZENSET = b'\x91' # build frozenset from mark
|
|
case 0x92: // NEWOBJ_EX = b'\x92' # build object with kwargs
|
|
case 0x93: // STACK_GLOBAL = b'\x93' # build global from module/name strings
|
|
case 0x97: // NEXT_BUFFER = b'\x97' # out-of-band buffer marker
|
|
case 0x98: // READONLY_BUFFER = b'\x98' # mark buffer readonly
|
|
case 'N': // NONE = b'N' # push None
|
|
case '0': // POP = b'0' # discard top stack item
|
|
case '1': // POP_MARK = b'1' # discard stack through topmost mark
|
|
case '2': // DUP = b'2' # duplicate top stack item
|
|
case 'o': // OBJ = b'o' # build class instance from mark
|
|
break;
|
|
case 'i': { // INST = b'i' # build class instance from module/name
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
} break;
|
|
default:
|
|
return false;
|
|
}
|
|
if (p > end) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// Reports whether `buffer` starts with the pickle
//   PROTO <version> LONG1 <len = 10> <10 magic bytes> STOP
// where the 10 payload bytes equal the torch magic constant below.
// The protocol version byte (offset 1) is not inspected, and bytes following
// STOP are ignored.
bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) {
    static const uint8_t magic[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19};

    const size_t opcode_at  = 2;                           // LONG1 opcode
    const size_t length_at  = 3;                           // LONG1 length byte
    const size_t payload_at = 4;                           // first magic byte
    const size_t stop_at    = payload_at + sizeof(magic);  // STOP opcode

    // The fixed-position header bytes must exist before they are read.
    if (buffer_size < 5) {
        return false;
    }
    if (buffer[0] != 0x80 || buffer[opcode_at] != 0x8A) {
        return false;
    }
    if (buffer[length_at] != sizeof(magic)) {
        return false;
    }
    // Payload and the STOP byte must both lie inside the buffer.
    if (stop_at >= buffer_size) {
        return false;
    }

    const bool payload_matches = memcmp(buffer + payload_at, magic, sizeof(magic)) == 0;
    return payload_matches && buffer[stop_at] == '.';
}
|
|
|
|
bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) {
|
|
if (buffer_size < 4 || buffer[0] != 0x80) {
|
|
return false;
|
|
}
|
|
|
|
size_t pos = 2;
|
|
if (pos >= buffer_size) {
|
|
return false;
|
|
}
|
|
|
|
uint8_t opcode = buffer[pos++];
|
|
switch (opcode) {
|
|
case 'K': // BININT1 = b'K' # push 1-byte unsigned int
|
|
if (pos + 1 >= buffer_size) {
|
|
return false;
|
|
}
|
|
*value = buffer[pos];
|
|
pos += 1;
|
|
break;
|
|
case 'M': // BININT2 = b'M' # push 2-byte unsigned int
|
|
if (pos + 2 >= buffer_size) {
|
|
return false;
|
|
}
|
|
*value = read_short(buffer + pos);
|
|
pos += 2;
|
|
break;
|
|
case 'J': // BININT = b'J' # push 4-byte signed int
|
|
if (pos + 4 >= buffer_size) {
|
|
return false;
|
|
}
|
|
*value = (uint32_t)read_int(buffer + pos);
|
|
pos += 4;
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
|
|
return pos < buffer_size && buffer[pos] == '.';
|
|
}
|
|
|
|
struct PickleStorageInfo {
|
|
std::string key;
|
|
ggml_type type = GGML_TYPE_COUNT;
|
|
bool is_f64 = false;
|
|
bool is_i64 = false;
|
|
uint64_t raw_element_nbytes = 0;
|
|
uint64_t nbytes = 0;
|
|
};
|
|
|
|
struct PickleTensorInfo {
|
|
TensorStorage tensor_storage;
|
|
int stride_n_dims = 0;
|
|
int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1};
|
|
};
|
|
|
|
struct PickleValue {
|
|
enum Kind {
|
|
MARK,
|
|
NONE,
|
|
BOOL,
|
|
INT,
|
|
STRING,
|
|
GLOBAL,
|
|
TUPLE,
|
|
LIST,
|
|
DICT,
|
|
ORDERED_DICT,
|
|
STORAGE,
|
|
TENSOR,
|
|
};
|
|
|
|
Kind kind = NONE;
|
|
int64_t int_value = 0;
|
|
bool bool_value = false;
|
|
std::string str_value;
|
|
std::vector<PickleValue> items;
|
|
std::vector<std::pair<PickleValue, PickleValue>> dict_items;
|
|
PickleStorageInfo storage;
|
|
PickleTensorInfo tensor;
|
|
};
|
|
|
|
// Creates the stack marker pushed by the MARK opcode.
static PickleValue make_mark_value() {
    PickleValue mark;
    mark.kind = PickleValue::MARK;
    return mark;
}
|
|
|
|
// Creates a NONE value (Python None / placeholder for ignored objects).
static PickleValue make_none_value() {
    PickleValue none;
    none.kind = PickleValue::NONE;
    return none;
}
|
|
|
|
// Creates a BOOL value holding `b`.
static PickleValue make_bool_value(bool b) {
    PickleValue flag;
    flag.bool_value = b;
    flag.kind       = PickleValue::BOOL;
    return flag;
}
|
|
|
|
// Creates an INT value holding `x`.
static PickleValue make_int_value(int64_t x) {
    PickleValue number;
    number.int_value = x;
    number.kind      = PickleValue::INT;
    return number;
}
|
|
|
|
// Creates a STRING value holding a copy of `s`.
static PickleValue make_string_value(const std::string& s) {
    PickleValue text;
    text.str_value.assign(s);
    text.kind = PickleValue::STRING;
    return text;
}
|
|
|
|
// Creates a GLOBAL value whose name (e.g. "module.name") is a copy of `s`.
static PickleValue make_global_value(const std::string& s) {
    PickleValue global;
    global.str_value.assign(s);
    global.kind = PickleValue::GLOBAL;
    return global;
}
|
|
|
|
// Creates a TUPLE value that takes ownership of `items`.
static PickleValue make_tuple_value(std::vector<PickleValue> items) {
    PickleValue tuple;
    tuple.kind = PickleValue::TUPLE;
    // `tuple.items` starts out empty, so swapping hands the elements over
    // without copying them.
    tuple.items.swap(items);
    return tuple;
}
|
|
|
|
// Creates an empty LIST value.
static PickleValue make_list_value() {
    PickleValue list;
    list.kind = PickleValue::LIST;
    return list;
}
|
|
|
|
// Creates an empty dictionary value: ORDERED_DICT when `ordered` is true,
// plain DICT otherwise.
static PickleValue make_dict_value(bool ordered) {
    PickleValue dict;
    if (ordered) {
        dict.kind = PickleValue::ORDERED_DICT;
    } else {
        dict.kind = PickleValue::DICT;
    }
    return dict;
}
|
|
|
|
// Creates a STORAGE value carrying a copy of `storage`.
static PickleValue make_storage_value(const PickleStorageInfo& storage) {
    PickleValue wrapped;
    wrapped.storage = storage;
    wrapped.kind    = PickleValue::STORAGE;
    return wrapped;
}
|
|
|
|
// Creates a TENSOR value carrying a copy of `tensor`.
static PickleValue make_tensor_value(const PickleTensorInfo& tensor) {
    PickleValue wrapped;
    wrapped.tensor = tensor;
    wrapped.kind   = PickleValue::TENSOR;
    return wrapped;
}
|
|
|
|
static std::string pickle_value_to_string(const PickleValue& value) {
|
|
if (value.kind == PickleValue::STRING) {
|
|
return value.str_value;
|
|
}
|
|
if (value.kind == PickleValue::INT) {
|
|
return std::to_string(value.int_value);
|
|
}
|
|
return "";
|
|
}
|
|
|
|
static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) {
|
|
if (global_name == "torch.FloatStorage") {
|
|
storage->type = GGML_TYPE_F32;
|
|
storage->raw_element_nbytes = 4;
|
|
return true;
|
|
}
|
|
if (global_name == "torch.DoubleStorage") {
|
|
storage->type = GGML_TYPE_F32;
|
|
storage->is_f64 = true;
|
|
storage->raw_element_nbytes = 8;
|
|
return true;
|
|
}
|
|
if (global_name == "torch.HalfStorage") {
|
|
storage->type = GGML_TYPE_F16;
|
|
storage->raw_element_nbytes = 2;
|
|
return true;
|
|
}
|
|
if (global_name == "torch.BFloat16Storage") {
|
|
storage->type = GGML_TYPE_BF16;
|
|
storage->raw_element_nbytes = 2;
|
|
return true;
|
|
}
|
|
if (global_name == "torch.IntStorage") {
|
|
storage->type = GGML_TYPE_I32;
|
|
storage->raw_element_nbytes = 4;
|
|
return true;
|
|
}
|
|
if (global_name == "torch.LongStorage") {
|
|
storage->type = GGML_TYPE_I32;
|
|
storage->is_i64 = true;
|
|
storage->raw_element_nbytes = 8;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
static bool tensor_is_contiguous(const PickleTensorInfo& tensor) {
|
|
if (tensor.tensor_storage.nelements() == 0) {
|
|
return true;
|
|
}
|
|
if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) {
|
|
return false;
|
|
}
|
|
|
|
int64_t expected_stride = 1;
|
|
for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) {
|
|
if (tensor.stride[i] != expected_stride) {
|
|
return false;
|
|
}
|
|
expected_stride *= tensor.tensor_storage.ne[i];
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static void collect_tensors_from_pickle_value(const PickleValue& value,
|
|
std::vector<TensorStorage>& tensor_storages) {
|
|
if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) {
|
|
return;
|
|
}
|
|
|
|
for (const auto& item : value.dict_items) {
|
|
if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) {
|
|
TensorStorage tensor_storage = item.second.tensor.tensor_storage;
|
|
tensor_storage.name = item.first.str_value;
|
|
tensor_storage.reverse_ne();
|
|
tensor_storages.push_back(tensor_storage);
|
|
} else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) {
|
|
collect_tensors_from_pickle_value(item.second, tensor_storages);
|
|
}
|
|
}
|
|
}
|
|
|
|
bool parse_torch_state_dict_pickle(const uint8_t* buffer,
|
|
size_t buffer_size,
|
|
std::vector<TensorStorage>& tensor_storages,
|
|
std::unordered_map<std::string, uint64_t>& storage_nbytes,
|
|
std::string* error) {
|
|
if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) {
|
|
set_error(error, "unsupported torch pickle protocol");
|
|
return false;
|
|
}
|
|
|
|
const uint8_t* p = buffer + 2;
|
|
const uint8_t* end = buffer + buffer_size;
|
|
std::vector<PickleValue> stack;
|
|
std::unordered_map<int32_t, PickleValue> memo;
|
|
|
|
while (p < end) {
|
|
uint8_t opcode = *p++;
|
|
switch (opcode) {
|
|
case '.': { // STOP = b'.' # every pickle ends with STOP
|
|
if (stack.empty()) {
|
|
set_error(error, "empty torch pickle stack");
|
|
return false;
|
|
}
|
|
size_t old_tensor_count = tensor_storages.size();
|
|
collect_tensors_from_pickle_value(stack.back(), tensor_storages);
|
|
if (tensor_storages.size() == old_tensor_count) {
|
|
set_error(error, "torch pickle does not contain a supported state_dict");
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
case '}': // EMPTY_DICT = b'}' # push empty dict
|
|
stack.push_back(make_dict_value(false));
|
|
break;
|
|
case ']': // EMPTY_LIST = b']' # push empty list
|
|
stack.push_back(make_list_value());
|
|
break;
|
|
case 'l': { // LIST = b'l' # build list from mark
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx < 0) {
|
|
set_error(error, "torch pickle list without mark");
|
|
return false;
|
|
}
|
|
std::vector<PickleValue> items(stack.begin() + mark_idx + 1, stack.end());
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
PickleValue list_value = make_list_value();
|
|
list_value.items = std::move(items);
|
|
stack.push_back(std::move(list_value));
|
|
} break;
|
|
case '(': // MARK = b'(' # push markobject
|
|
stack.push_back(make_mark_value());
|
|
break;
|
|
case ')': // EMPTY_TUPLE = b')' # push empty tuple
|
|
stack.push_back(make_tuple_value({}));
|
|
break;
|
|
case 'N': // NONE = b'N' # push None
|
|
stack.push_back(make_none_value());
|
|
break;
|
|
case 0x88: // NEWTRUE = b'\x88' # push True
|
|
stack.push_back(make_bool_value(true));
|
|
break;
|
|
case 0x89: // NEWFALSE = b'\x89' # push False
|
|
stack.push_back(make_bool_value(false));
|
|
break;
|
|
case 'K': // BININT1 = b'K' # push 1-byte unsigned int
|
|
if (p >= end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_int_value(*p++));
|
|
break;
|
|
case 'M': // BININT2 = b'M' # push 2-byte unsigned int
|
|
if (p + 2 > end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_int_value(read_short(p)));
|
|
p += 2;
|
|
break;
|
|
case 'J': // BININT = b'J' # push 4-byte signed int
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_int_value(read_int(p)));
|
|
p += 4;
|
|
break;
|
|
case 'I': { // INT = b'I' # push decimal integer line
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
std::string s((const char*)p, len);
|
|
p += len + 1;
|
|
if (s == "01") {
|
|
stack.push_back(make_bool_value(true));
|
|
} else if (s == "00") {
|
|
stack.push_back(make_bool_value(false));
|
|
} else {
|
|
stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10)));
|
|
}
|
|
} break;
|
|
case 'L': { // LONG = b'L' # push decimal long integer line
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
std::string s((const char*)p, len);
|
|
p += len + 1;
|
|
if (!s.empty() && s.back() == 'L') {
|
|
s.pop_back();
|
|
}
|
|
stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10)));
|
|
} break;
|
|
case 'F': { // FLOAT = b'F' # push decimal float line
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
p += len + 1;
|
|
stack.push_back(make_none_value());
|
|
} break;
|
|
case 'G': // BINFLOAT = b'G' # push 8-byte binary float
|
|
if (p + 8 > end) {
|
|
return false;
|
|
}
|
|
p += 8;
|
|
stack.push_back(make_none_value());
|
|
break;
|
|
case 0x8A: { // LONG1 = b'\x8a' # push long integer; 1-byte length
|
|
if (p >= end) {
|
|
return false;
|
|
}
|
|
uint8_t n = *p++;
|
|
if (p + n > end || n > 8) {
|
|
return false;
|
|
}
|
|
int64_t value = 0;
|
|
for (uint8_t i = 0; i < n; ++i) {
|
|
value |= (int64_t)p[i] << (i * 8);
|
|
}
|
|
p += n;
|
|
stack.push_back(make_int_value(value));
|
|
} break;
|
|
case 'C': { // SHORT_BINBYTES = b'C' # push bytes; length < 256
|
|
if (p >= end) {
|
|
return false;
|
|
}
|
|
uint8_t len = *p++;
|
|
if (p + len > end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_string_value(std::string((const char*)p, len)));
|
|
p += len;
|
|
} break;
|
|
case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
int32_t len = read_int(p);
|
|
p += 4;
|
|
if (len < 0 || p + len > end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_string_value(std::string((const char*)p, len)));
|
|
p += len;
|
|
} break;
|
|
case 'T': // BINSTRING = b'T' # push string; 4-byte length
|
|
case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
int32_t len = read_int(p);
|
|
p += 4;
|
|
if (len < 0 || p + len > end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_string_value(std::string((const char*)p, len)));
|
|
p += len;
|
|
} break;
|
|
case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length
|
|
case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length
|
|
case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length
|
|
if (p + 8 > end) {
|
|
return false;
|
|
}
|
|
uint64_t len = read_u64(p);
|
|
p += 8;
|
|
if (len > (uint64_t)(end - p)) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_string_value(std::string((const char*)p, (size_t)len)));
|
|
p += len;
|
|
} break;
|
|
case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256
|
|
case 0x8C: { // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256
|
|
if (p >= end) {
|
|
return false;
|
|
}
|
|
uint8_t len = *p++;
|
|
if (p + len > end) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_string_value(std::string((const char*)p, len)));
|
|
p += len;
|
|
} break;
|
|
case 'S': { // STRING = b'S' # push quoted string line
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
std::string s((const char*)p, len);
|
|
p += len + 1;
|
|
if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) {
|
|
s = s.substr(1, s.size() - 2);
|
|
}
|
|
stack.push_back(make_string_value(s));
|
|
} break;
|
|
case 'V': { // UNICODE = b'V' # push raw-unicode string line
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_string_value(std::string((const char*)p, len)));
|
|
p += len + 1;
|
|
} break;
|
|
case 'c': { // GLOBAL = b'c' # push module/name global reference
|
|
int len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
std::string module((const char*)p, len);
|
|
p += len + 1;
|
|
len = find_char(p, (int)(end - p), '\n');
|
|
if (len < 0) {
|
|
return false;
|
|
}
|
|
std::string name((const char*)p, len);
|
|
p += len + 1;
|
|
stack.push_back(make_global_value(module + "." + name));
|
|
} break;
|
|
case 0x93: { // STACK_GLOBAL = b'\x93' # build global from module/name strings
|
|
if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING ||
|
|
stack.back().kind != PickleValue::STRING) {
|
|
return false;
|
|
}
|
|
std::string name = stack.back().str_value;
|
|
stack.pop_back();
|
|
std::string module = stack.back().str_value;
|
|
stack.pop_back();
|
|
stack.push_back(make_global_value(module + "." + name));
|
|
} break;
|
|
case 'h': // BINGET = b'h' # read memo index, 1-byte arg
|
|
if (p >= end || !memo.count(*p)) {
|
|
return false;
|
|
}
|
|
stack.push_back(memo[*p++]);
|
|
break;
|
|
case 'j': { // LONG_BINGET = b'j' # read memo index, 4-byte arg
|
|
if (p + 4 > end) {
|
|
return false;
|
|
}
|
|
int32_t memo_idx = read_int(p);
|
|
if (!memo.count(memo_idx)) {
|
|
return false;
|
|
}
|
|
stack.push_back(memo[memo_idx]);
|
|
p += 4;
|
|
} break;
|
|
case 'q': // BINPUT = b'q' # write memo index, 1-byte arg
|
|
if (p >= end || stack.empty()) {
|
|
return false;
|
|
}
|
|
memo[*p++] = stack.back();
|
|
break;
|
|
case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg
|
|
if (p + 4 > end || stack.empty()) {
|
|
return false;
|
|
}
|
|
memo[read_int(p)] = stack.back();
|
|
p += 4;
|
|
break;
|
|
case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo
|
|
if (stack.empty()) {
|
|
return false;
|
|
}
|
|
memo[(int32_t)memo.size()] = stack.back();
|
|
break;
|
|
case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
|
|
if (p + 8 > end) {
|
|
return false;
|
|
}
|
|
p += 8;
|
|
break;
|
|
case '0': // POP = b'0' # discard top stack item
|
|
if (stack.empty()) {
|
|
return false;
|
|
}
|
|
stack.pop_back();
|
|
break;
|
|
case '1': { // POP_MARK = b'1' # discard stack through topmost mark
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx < 0) {
|
|
return false;
|
|
}
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
} break;
|
|
case '2': // DUP = b'2' # duplicate top stack item
|
|
if (stack.empty()) {
|
|
return false;
|
|
}
|
|
stack.push_back(stack.back());
|
|
break;
|
|
case 0x8F: // EMPTY_SET = b'\x8f' # push empty set
|
|
stack.push_back(make_list_value());
|
|
break;
|
|
case 0x90: { // ADDITEMS = b'\x90' # add mark-delimited items to set
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) {
|
|
return false;
|
|
}
|
|
PickleValue& set_value = stack[mark_idx - 1];
|
|
set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
} break;
|
|
case 0x91: { // FROZENSET = b'\x91' # build frozenset from mark
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx < 0) {
|
|
return false;
|
|
}
|
|
PickleValue set_value = make_list_value();
|
|
set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
stack.push_back(std::move(set_value));
|
|
} break;
|
|
case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack
|
|
case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack
|
|
case 0x87: { // TUPLE3 = b'\x87' # build 3-tuple from stack
|
|
int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 2 : 3);
|
|
if ((int)stack.size() < tuple_size) {
|
|
return false;
|
|
}
|
|
std::vector<PickleValue> items(stack.end() - tuple_size, stack.end());
|
|
stack.erase(stack.end() - tuple_size, stack.end());
|
|
stack.push_back(make_tuple_value(std::move(items)));
|
|
} break;
|
|
case 't': { // TUPLE = b't' # build tuple from mark
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx < 0) {
|
|
return false;
|
|
}
|
|
std::vector<PickleValue> items(stack.begin() + mark_idx + 1, stack.end());
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
stack.push_back(make_tuple_value(std::move(items)));
|
|
} break;
|
|
case 'Q': { // BINPERSID = b'Q' # persistent id from stack
|
|
if (stack.empty()) {
|
|
return false;
|
|
}
|
|
PickleValue pid = stack.back();
|
|
stack.pop_back();
|
|
if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING ||
|
|
pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT ||
|
|
pid.items[0].str_value != "storage") {
|
|
return false;
|
|
}
|
|
|
|
PickleStorageInfo storage;
|
|
storage.key = pickle_value_to_string(pid.items[2]);
|
|
if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) {
|
|
return false;
|
|
}
|
|
storage.nbytes = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes;
|
|
storage_nbytes[storage.key] = storage.nbytes;
|
|
stack.push_back(make_storage_value(storage));
|
|
} break;
|
|
case 'R': { // REDUCE = b'R' # apply callable to args
|
|
if (stack.size() < 2) {
|
|
return false;
|
|
}
|
|
PickleValue args = stack.back();
|
|
stack.pop_back();
|
|
PickleValue callable = stack.back();
|
|
stack.pop_back();
|
|
if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) {
|
|
stack.push_back(make_none_value());
|
|
break;
|
|
}
|
|
|
|
if (callable.str_value == "collections.OrderedDict" && args.items.empty()) {
|
|
stack.push_back(make_dict_value(true));
|
|
break;
|
|
}
|
|
|
|
if ((callable.str_value == "torch._utils._rebuild_tensor_v2" || callable.str_value == "torch._utils._rebuild_tensor") &&
|
|
args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE &&
|
|
args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE &&
|
|
args.items[3].kind == PickleValue::TUPLE) {
|
|
PickleTensorInfo tensor;
|
|
tensor.tensor_storage.type = args.items[0].storage.type;
|
|
tensor.tensor_storage.is_f64 = args.items[0].storage.is_f64;
|
|
tensor.tensor_storage.is_i64 = args.items[0].storage.is_i64;
|
|
tensor.tensor_storage.storage_key = args.items[0].storage.key;
|
|
tensor.tensor_storage.offset = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes;
|
|
|
|
for (const auto& item : args.items[2].items) {
|
|
if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) {
|
|
return false;
|
|
}
|
|
tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value;
|
|
}
|
|
|
|
for (const auto& item : args.items[3].items) {
|
|
if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) {
|
|
return false;
|
|
}
|
|
tensor.stride[tensor.stride_n_dims++] = item.int_value;
|
|
}
|
|
|
|
if (!tensor_is_contiguous(tensor)) {
|
|
return false;
|
|
}
|
|
stack.push_back(make_tensor_value(tensor));
|
|
break;
|
|
}
|
|
|
|
// Non-tensor checkpoint metadata can use REDUCE for arbitrary
|
|
// Python objects. Do not execute it; keep stack shape only.
|
|
stack.push_back(make_none_value());
|
|
break;
|
|
}
|
|
case 'b': // BUILD = b'b' # build object state
|
|
if (stack.size() < 2) {
|
|
return false;
|
|
}
|
|
stack.pop_back();
|
|
break;
|
|
case 'u': { // SETITEMS = b'u' # add mark-delimited items to dict
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx <= 0) {
|
|
return false;
|
|
}
|
|
PickleValue& dict = stack[mark_idx - 1];
|
|
if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) {
|
|
return false;
|
|
}
|
|
for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) {
|
|
dict.dict_items.emplace_back(stack[i], stack[i + 1]);
|
|
}
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
} break;
|
|
case 's': { // SETITEM = b's' # add key/value to dict
|
|
if (stack.size() < 3) {
|
|
return false;
|
|
}
|
|
PickleValue value = stack.back();
|
|
stack.pop_back();
|
|
PickleValue key = stack.back();
|
|
stack.pop_back();
|
|
PickleValue& dict = stack.back();
|
|
if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) {
|
|
return false;
|
|
}
|
|
dict.dict_items.emplace_back(key, value);
|
|
} break;
|
|
case 'e': { // APPENDS = b'e' # extend list with mark-delimited items
|
|
int mark_idx = (int)stack.size() - 1;
|
|
while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
|
|
--mark_idx;
|
|
}
|
|
if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) {
|
|
return false;
|
|
}
|
|
PickleValue& list_value = stack[mark_idx - 1];
|
|
list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
|
|
stack.erase(stack.begin() + mark_idx, stack.end());
|
|
} break;
|
|
case 'a': { // APPEND = b'a' # append item to list
|
|
if (stack.size() < 2) {
|
|
return false;
|
|
}
|
|
PickleValue item = stack.back();
|
|
stack.pop_back();
|
|
if (stack.back().kind != PickleValue::LIST) {
|
|
return false;
|
|
}
|
|
stack.back().items.push_back(item);
|
|
} break;
|
|
default:
|
|
set_error(error,
|
|
"unsupported torch pickle opcode 0x" + sd_format("%02X", opcode) +
|
|
" at offset " + std::to_string((p - buffer) - 1));
|
|
return false;
|
|
}
|
|
}
|
|
|
|
set_error(error, "unterminated torch state_dict pickle");
|
|
return false;
|
|
}
|