diff --git a/fs/io_uring.c b/fs/io_uring.c index dd094b387cab..aa8ac557493c 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -2768,6 +2768,38 @@ out: return submit; } +struct io_wait_queue { + struct wait_queue_entry wq; + struct io_ring_ctx *ctx; + unsigned to_wait; + unsigned nr_timeouts; +}; + +static inline bool io_should_wake(struct io_wait_queue *iowq) +{ + struct io_ring_ctx *ctx = iowq->ctx; + + /* + * Wake up if we have enough events, or if a timeout occured since we + * started waiting. For timeouts, we always want to return to userspace, + * regardless of event count. + */ + return io_cqring_events(ctx->rings) >= iowq->to_wait || + atomic_read(&ctx->cq_timeouts) != iowq->nr_timeouts; +} + +static int io_wake_function(struct wait_queue_entry *curr, unsigned int mode, + int wake_flags, void *key) +{ + struct io_wait_queue *iowq = container_of(curr, struct io_wait_queue, + wq); + + if (!io_should_wake(iowq)) + return -1; + + return autoremove_wake_function(curr, mode, wake_flags, key); +} + /* * Wait until events become available, if we don't already have some. The * application must reap them itself, as they reside on the shared cq ring. @@ -2775,8 +2807,16 @@ out: static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, const sigset_t __user *sig, size_t sigsz) { + struct io_wait_queue iowq = { + .wq = { + .private = current, + .func = io_wake_function, + .entry = LIST_HEAD_INIT(iowq.wq.entry), + }, + .ctx = ctx, + .to_wait = min_events, + }; struct io_rings *rings = ctx->rings; - unsigned nr_timeouts; int ret; if (io_cqring_events(rings) >= min_events) @@ -2795,15 +2835,21 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, return ret; } - nr_timeouts = atomic_read(&ctx->cq_timeouts); - /* - * Return if we have enough events, or if a timeout occured since - * we started waiting. For timeouts, we always want to return to - * userspace. - */ - ret = wait_event_interruptible(ctx->wait, - io_cqring_events(rings) >= min_events || - atomic_read(&ctx->cq_timeouts) != nr_timeouts); + ret = 0; + iowq.nr_timeouts = atomic_read(&ctx->cq_timeouts); + do { + prepare_to_wait_exclusive(&ctx->wait, &iowq.wq, + TASK_INTERRUPTIBLE); + if (io_should_wake(&iowq)) + break; + schedule(); + if (signal_pending(current)) { + ret = -ERESTARTSYS; + break; + } + } while (1); + finish_wait(&ctx->wait, &iowq.wq); + restore_saved_sigmask_unless(ret == -ERESTARTSYS); if (ret == -ERESTARTSYS) ret = -EINTR; @@ -3455,7 +3501,7 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait) if (READ_ONCE(ctx->rings->sq.tail) - ctx->cached_sq_head != ctx->rings->sq_ring_entries) mask |= EPOLLOUT | EPOLLWRNORM; - if (READ_ONCE(ctx->rings->sq.head) != ctx->cached_cq_tail) + if (READ_ONCE(ctx->rings->cq.head) != ctx->cached_cq_tail) mask |= EPOLLIN | EPOLLRDNORM; return mask;