diff --git a/zephyr/modusocket.c b/zephyr/modusocket.c index 00e0bc789..bd51e6f5d 100644 --- a/zephyr/modusocket.c +++ b/zephyr/modusocket.c @@ -28,6 +28,7 @@ #ifdef MICROPY_PY_USOCKET #include "py/runtime.h" +#include "py/stream.h" #include #include @@ -300,29 +301,44 @@ STATIC mp_obj_t socket_accept(mp_obj_t self_in) { } STATIC MP_DEFINE_CONST_FUN_OBJ_1(socket_accept_obj, socket_accept); -STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t buf_in) { +STATIC mp_uint_t sock_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errcode) { socket_obj_t *socket = self_in; - socket_check_closed(socket); - - mp_buffer_info_t bufinfo; - mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); + if (socket->ctx == NULL) { + // already closed + *errcode = EBADF; + return MP_STREAM_ERROR; + } struct net_buf *send_buf = net_nbuf_get_tx(socket->ctx, K_FOREVER); unsigned len = net_if_get_mtu(net_context_get_iface(socket->ctx)); // Arbitrary value to account for protocol headers len -= 64; - if (len > bufinfo.len) { - len = bufinfo.len; + if (len > size) { + len = size; } - if (!net_nbuf_append(send_buf, len, bufinfo.buf, K_FOREVER)) { + if (!net_nbuf_append(send_buf, len, buf, K_FOREVER)) { len = net_buf_frags_len(send_buf); - //mp_raise_OSError(ENOSPC); } - RAISE_ERRNO(net_context_send(send_buf, /*cb*/NULL, K_FOREVER, NULL, NULL)); + int err = net_context_send(send_buf, /*cb*/NULL, K_FOREVER, NULL, NULL); + if (err < 0) { + *errcode = -err; + return MP_STREAM_ERROR; + } + return len; +} + +STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t buf_in) { + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); + int err = 0; + mp_uint_t len = sock_write(self_in, bufinfo.buf, bufinfo.len, &err); + if (len == MP_STREAM_ERROR) { + mp_raise_OSError(err); + } return mp_obj_new_int_from_uint(len); } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_send_obj, socket_send); @@ -436,12 +452,18 @@ STATIC const mp_map_elem_t socket_locals_dict_table[] = { }; STATIC MP_DEFINE_CONST_DICT(socket_locals_dict, socket_locals_dict_table); +STATIC const mp_stream_p_t socket_stream_p = { + //.read = sock_read, + .write = sock_write, + //.ioctl = sock_ioctl, +}; + STATIC const mp_obj_type_t socket_type = { { &mp_type_type }, .name = MP_QSTR_socket, .print = socket_print, .make_new = socket_make_new, - //.protocol = &socket_stream_p, + .protocol = &socket_stream_p, .locals_dict = (mp_obj_t)&socket_locals_dict, };