From 0ce03b48a04a7766f8694b1de8a88073542dcc20 Mon Sep 17 00:00:00 2001 From: "John R. Lenton" Date: Sun, 12 Jan 2014 15:17:42 +0000 Subject: [PATCH] make sets iterable --- py/objset.c | 39 ++++++++++++++++++++++++++++++++++ tests/basics/tests/set_iter.py | 5 +++++ 2 files changed, 44 insertions(+) create mode 100644 tests/basics/tests/set_iter.py diff --git a/py/objset.c b/py/objset.c index 67dab11df..5606c4751 100644 --- a/py/objset.c +++ b/py/objset.c @@ -15,6 +15,14 @@ typedef struct _mp_obj_set_t { mp_set_t set; } mp_obj_set_t; +typedef struct _mp_obj_set_it_t { + mp_obj_base_t base; + mp_obj_set_t *set; + machine_uint_t cur; +} mp_obj_set_it_t; + +static mp_obj_t set_it_iternext(mp_obj_t self_in); + void set_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in) { mp_obj_set_t *self = self_in; bool first = true; @@ -54,11 +62,42 @@ static mp_obj_t set_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) } } +const mp_obj_type_t set_it_type = { + { &mp_const_type }, + "set_iterator", + .iternext = set_it_iternext, +}; + +static mp_obj_t set_it_iternext(mp_obj_t self_in) { + assert(MP_OBJ_IS_TYPE(self_in, &set_it_type)); + mp_obj_set_it_t *self = self_in; + machine_uint_t max = self->set->set.alloc; + mp_obj_t *table = self->set->set.table; + + for (machine_uint_t i = self->cur; i < max; i++) { + if (table[i] != NULL) { + self->cur = i + 1; + return table[i]; + } + } + + return mp_const_stop_iteration; +} + +static mp_obj_t set_getiter(mp_obj_t set_in) { + mp_obj_set_it_t *o = m_new_obj(mp_obj_set_it_t); + o->base.type = &set_it_type; + o->set = (mp_obj_set_t *)set_in; + o->cur = 0; + return o; +} + const mp_obj_type_t set_type = { { &mp_const_type }, "set", .print = set_print, .make_new = set_make_new, + .getiter = set_getiter, }; mp_obj_t mp_obj_new_set(int n_args, mp_obj_t *items) { diff --git a/tests/basics/tests/set_iter.py b/tests/basics/tests/set_iter.py new file mode 100644 index 000000000..296017730 --- /dev/null +++ b/tests/basics/tests/set_iter.py @@ -0,0 +1,5 @@ +s = {1, 2, 3, 4} +l = list(s) +l.sort() +print(l) +