diff --git a/export.py b/export.py index 8d43156..ffcb506 100644 --- a/export.py +++ b/export.py @@ -124,13 +124,10 @@ def version1_export(model, filepath): out_file = open(filepath, 'wb') # first write out the header. the header will be 256 bytes - nbytes = 0 # 1) write magic, which will be uint32 of "ak42" in ASCII out_file.write(struct.pack('I', 0x616b3432)) - nbytes += 4 # 2) write version, which will be int out_file.write(struct.pack('i', version)) - nbytes += 4 # 3) write the params, which will be 7 ints p = model.params hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] @@ -138,12 +135,10 @@ def version1_export(model, filepath): header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) - nbytes += 7*4 # 4) write some other flags shared_classifier = 1 # we do share a classifier, write flag as a byte out_file.write(struct.pack('B', shared_classifier)) - nbytes += 1 - pad = 256 - nbytes # pad the rest with zeros + pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 out_file.write(b'\0' * pad) @@ -198,13 +193,10 @@ def version2_export(model, filepath, group_size=64): # write out_file = open(filepath, 'wb') # first write out the header. the header will be 256 bytes - nbytes = 0 # 1) write magic, which will be uint32 of "ak42" in ASCII out_file.write(struct.pack('I', 0x616b3432)) - nbytes += 4 # 2) write version, which will be int out_file.write(struct.pack('i', version)) - nbytes += 4 # 3) write the params, which will be 7 ints p = model.params hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] @@ -212,14 +204,11 @@ def version2_export(model, filepath, group_size=64): header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) - nbytes += 7*4 # 4) write some other flags shared_classifier = 1 # we do share a classifier, write flag as a byte out_file.write(struct.pack('B', shared_classifier)) - nbytes += 1 out_file.write(struct.pack('i', group_size)) # group size used for quantization - nbytes += 4 - pad = 256 - nbytes # pad the rest with zeros + pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 out_file.write(b'\0' * pad) # now that the header is done, let's write out the model