diff --git a/py/binary.c b/py/binary.c index 927a42640..8b5c05ab3 100644 --- a/py/binary.c +++ b/py/binary.c @@ -32,6 +32,7 @@ #include "py/binary.h" #include "py/smallint.h" +#include "py/objint.h" // Helpers to work with binary-encoded data @@ -282,10 +283,13 @@ void mp_binary_set_val(char struct_type, char val_type, mp_obj_t val_in, byte ** } #endif default: - // we handle large ints here by calling the truncated accessor + #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE if (MP_OBJ_IS_TYPE(val_in, &mp_type_int)) { - val = mp_obj_int_get_truncated(val_in); - } else { + mp_obj_int_to_bytes_impl(val_in, struct_type == '>', size, p); + return; + } else + #endif + { val = mp_obj_get_int(val_in); } } diff --git a/py/mpz.c b/py/mpz.c index 241fa79be..3c20023bc 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -1425,6 +1425,40 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) { return true; } +// writes at most len bytes to buf (so buf should be zeroed before calling) +void mpz_as_bytes(const mpz_t *z, bool big_endian, mp_uint_t len, byte *buf) { + byte *b = buf; + if (big_endian) { + b += len; + } + mpz_dig_t *zdig = z->dig; + int bits = 0; + mpz_dbl_dig_t d = 0; + mpz_dbl_dig_t carry = 1; + for (mp_uint_t zlen = z->len; zlen > 0; --zlen) { + bits += DIG_SIZE; + d = (d << DIG_SIZE) | *zdig++; + for (; bits >= 8; bits -= 8, d >>= 8) { + mpz_dig_t val = d; + if (z->neg) { + d = (~d & 0xff) + carry; + carry = d >> 8; + } + if (big_endian) { + *--b = val; + if (b == buf) { + return; + } + } else { + *b++ = val; + if (b == buf + len) { + return; + } + } + } + } +} + #if MICROPY_PY_BUILTINS_FLOAT mp_float_t mpz_as_float(const mpz_t *i) { mp_float_t val = 0; diff --git a/py/mpz.h b/py/mpz.h index 71649aa7f..b00d2b655 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -125,6 +125,7 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m mp_int_t mpz_hash(const mpz_t *z); bool mpz_as_int_checked(const mpz_t *z, mp_int_t *value); bool mpz_as_uint_checked(const mpz_t *z, mp_uint_t *value); +void mpz_as_bytes(const mpz_t *z, bool big_endian, mp_uint_t len, byte *buf); #if MICROPY_PY_BUILTINS_FLOAT mp_float_t mpz_as_float(const mpz_t *z); #endif diff --git a/py/objint.c b/py/objint.c index 64faed636..7c527d4ae 100644 --- a/py/objint.c +++ b/py/objint.c @@ -35,6 +35,7 @@ #include "py/objstr.h" #include "py/runtime0.h" #include "py/runtime.h" +#include "py/binary.h" #if MICROPY_PY_BUILTINS_FLOAT #include @@ -398,12 +399,10 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(int_from_bytes_fun_obj, 2, 3, int_fro STATIC MP_DEFINE_CONST_CLASSMETHOD_OBJ(int_from_bytes_obj, (const mp_obj_t)&int_from_bytes_fun_obj); STATIC mp_obj_t int_to_bytes(mp_uint_t n_args, const mp_obj_t *args) { - // TODO: Support long ints // TODO: Support byteorder param (assumes 'little') // TODO: Support signed param (assumes signed=False) (void)n_args; - mp_int_t val = mp_obj_int_get_checked(args[0]); mp_uint_t len = MP_OBJ_SMALL_INT_VALUE(args[1]); vstr_t vstr; @@ -411,13 +410,14 @@ STATIC mp_obj_t int_to_bytes(mp_uint_t n_args, const mp_obj_t *args) { byte *data = (byte*)vstr.buf; memset(data, 0, len); - if (MP_ENDIANNESS_LITTLE) { - memcpy(data, &val, len < sizeof(mp_int_t) ? len : sizeof(mp_int_t)); - } else { - while (len--) { - *data++ = val; - val >>= 8; - } + #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE + if (!MP_OBJ_IS_SMALL_INT(args[0])) { + mp_obj_int_to_bytes_impl(args[0], false, len, data); + } else + #endif + { + mp_int_t val = MP_OBJ_SMALL_INT_VALUE(args[0]); + mp_binary_set_int(MIN((size_t)len, sizeof(val)), false, data, val); } return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); diff --git a/py/objint.h b/py/objint.h index daeb3c499..09cf7c86d 100644 --- a/py/objint.h +++ b/py/objint.h @@ -56,6 +56,7 @@ char *mp_obj_int_formatted(char **buf, mp_uint_t *buf_size, mp_uint_t *fmt_size, char *mp_obj_int_formatted_impl(char **buf, mp_uint_t *buf_size, mp_uint_t *fmt_size, mp_const_obj_t self_in, int base, const char *prefix, char base_char, char comma); mp_int_t mp_obj_int_hash(mp_obj_t self_in); +void mp_obj_int_to_bytes_impl(mp_obj_t self_in, bool big_endian, mp_uint_t len, byte *buf); bool mp_obj_int_is_positive(mp_obj_t self_in); mp_obj_t mp_obj_int_abs(mp_obj_t self_in); mp_obj_t mp_obj_int_unary_op(mp_uint_t op, mp_obj_t o_in); diff --git a/py/objint_longlong.c b/py/objint_longlong.c index 837889704..5b2c6d3f5 100644 --- a/py/objint_longlong.c +++ b/py/objint_longlong.c @@ -63,6 +63,24 @@ mp_int_t mp_obj_int_hash(mp_obj_t self_in) { return self->val; } +void mp_obj_int_to_bytes_impl(mp_obj_t self_in, bool big_endian, mp_uint_t len, byte *buf) { + assert(MP_OBJ_IS_TYPE(self_in, &mp_type_int)); + mp_obj_int_t *self = self_in; + long long val = self->val; + if (big_endian) { + byte *b = buf + len; + while (b > buf) { + *--b = val; + val >>= 8; + } + } else { + for (; len > 0; --len) { + *buf++ = val; + val >>= 8; + } + } +} + bool mp_obj_int_is_positive(mp_obj_t self_in) { if (MP_OBJ_IS_SMALL_INT(self_in)) { return MP_OBJ_SMALL_INT_VALUE(self_in) >= 0; diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 2746f4dff..369e5af32 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -104,6 +104,12 @@ mp_int_t mp_obj_int_hash(mp_obj_t self_in) { return mpz_hash(&self->mpz); } +void mp_obj_int_to_bytes_impl(mp_obj_t self_in, bool big_endian, mp_uint_t len, byte *buf) { + assert(MP_OBJ_IS_TYPE(self_in, &mp_type_int)); + mp_obj_int_t *self = self_in; + mpz_as_bytes(&self->mpz, big_endian, len, buf); +} + bool mp_obj_int_is_positive(mp_obj_t self_in) { if (MP_OBJ_IS_SMALL_INT(self_in)) { return MP_OBJ_SMALL_INT_VALUE(self_in) >= 0; diff --git a/tests/basics/int_bytes.py b/tests/basics/int_bytes.py index 45965ed46..2f468da44 100644 --- a/tests/basics/int_bytes.py +++ b/tests/basics/int_bytes.py @@ -1,6 +1,7 @@ print((10).to_bytes(1, "little")) print((111111).to_bytes(4, "little")) print((100).to_bytes(10, "little")) +print((2**64).to_bytes(9, "little")) print(int.from_bytes(b"\x00\x01\0\0\0\0\0\0", "little")) print(int.from_bytes(b"\x01\0\0\0\0\0\0\0", "little")) print(int.from_bytes(b"\x00\x01\0\0\0\0\0\0", "little")) diff --git a/tests/basics/struct1.py b/tests/basics/struct1.py index 09ecd20a6..c473fc0b0 100644 --- a/tests/basics/struct1.py +++ b/tests/basics/struct1.py @@ -30,9 +30,18 @@ print(v == (10, 100, 200, 300)) print(struct.pack("Q", 1234567890123456789)) +print(struct.pack(">q", -1234567890123456789)) +print(struct.unpack("Q", b"\x12\x34\x56\x78\x90\x12\x34\x56")) +print(struct.unpack("q", b"\xf2\x34\x56\x78\x90\x12\x34\x56")) # check maximum unpack print(struct.unpack("