From b2d4fc06fc95d8e96eabd6ef470f0e871275fb82 Mon Sep 17 00:00:00 2001 From: Paul Sokolovsky Date: Sun, 11 May 2014 13:17:29 +0300 Subject: [PATCH] objstr: Make *strip() accept bytes. --- py/objstr.c | 9 ++++++--- tests/basics/string_strip.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/py/objstr.c b/py/objstr.c index 247cfde6d..c44e9ebf1 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -540,7 +540,8 @@ enum { LSTRIP, RSTRIP, STRIP }; STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) { assert(1 <= n_args && n_args <= 2); - assert(MP_OBJ_IS_STR(args[0])); + assert(is_str_or_bytes(args[0])); + const mp_obj_type_t *self_type = mp_obj_get_type(args[0]); const byte *chars_to_del; uint chars_to_del_len; @@ -550,7 +551,9 @@ STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) { chars_to_del = whitespace; chars_to_del_len = sizeof(whitespace); } else { - assert(MP_OBJ_IS_STR(args[1])); + if (mp_obj_get_type(args[1]) != self_type) { + arg_type_mixup(); + } GET_STR_DATA_LEN(args[1], s, l); chars_to_del = s; chars_to_del_len = l; @@ -594,7 +597,7 @@ STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) { assert(last_good_char_pos >= first_good_char_pos); //+1 to accomodate the last character machine_uint_t stripped_len = last_good_char_pos - first_good_char_pos + 1; - return mp_obj_new_str(orig_str + first_good_char_pos, stripped_len, false); + return str_new(self_type, orig_str + first_good_char_pos, stripped_len); } STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) { diff --git a/tests/basics/string_strip.py b/tests/basics/string_strip.py index 8e03eff93..4684c2a24 100644 --- a/tests/basics/string_strip.py +++ b/tests/basics/string_strip.py @@ -10,3 +10,13 @@ print('www.example.com'.lstrip('cmowz.')) print(' spacious '.rstrip()) print('mississippi'.rstrip('ipz')) + +print(b'mississippi'.rstrip(b'ipz')) +try: + print(b'mississippi'.rstrip('ipz')) +except TypeError: + print("TypeError") +try: + print('mississippi'.rstrip(b'ipz')) +except TypeError: + print("TypeError")