diff --git a/extmod/modbtree.c b/extmod/modbtree.c index 6613c06c8..4c58d17d3 100644 --- a/extmod/modbtree.c +++ b/extmod/modbtree.c @@ -31,6 +31,7 @@ #include "py/nlr.h" #include "py/runtime.h" +#include "py/stream.h" #if MICROPY_PY_BTREE @@ -314,15 +315,20 @@ STATIC const mp_obj_type_t btree_type = { .locals_dict = (void*)&btree_locals_dict, }; +STATIC FILEVTABLE btree_stream_fvtable = { + mp_stream_posix_read, + mp_stream_posix_write, + mp_stream_posix_lseek, + mp_stream_posix_fsync +}; + STATIC mp_obj_t mod_btree_open(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { static const mp_arg_t allowed_args[] = { { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, }; - const char *fname = NULL; - if (pos_args[0] != mp_const_none) { - fname = mp_obj_str_get_str(pos_args[0]); - } + // Make sure we got a stream object + mp_get_stream_raise(pos_args[0], MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL); struct { mp_arg_val_t server_side; @@ -330,7 +336,7 @@ STATIC mp_obj_t mod_btree_open(size_t n_args, const mp_obj_t *pos_args, mp_map_t mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t*)&args); - DB *db = __bt_open(fname, /*flags*/O_CREAT | O_RDWR, /*mode*/0770, /*openinfo*/NULL, /*dflags*/0); + DB *db = __bt_open(pos_args[0], &btree_stream_fvtable, /*openinfo*/NULL, /*dflags*/0); if (db == NULL) { nlr_raise(mp_obj_new_exception_arg1(&mp_type_OSError, MP_OBJ_NEW_SMALL_INT(errno))); }