diff --git a/fs/io-wq.c b/fs/io-wq.c index 2b4276990571..31c5a10b0825 100644 --- a/fs/io-wq.c +++ b/fs/io-wq.c @@ -57,6 +57,7 @@ struct io_worker { struct rcu_head rcu; struct mm_struct *mm; + const struct cred *creds; struct files_struct *restore_files; }; @@ -111,6 +112,7 @@ struct io_wq { struct task_struct *manager; struct user_struct *user; + struct cred *creds; struct mm_struct *mm; refcount_t refs; struct completion done; @@ -136,6 +138,11 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker) { bool dropped_lock = false; + if (worker->creds) { + revert_creds(worker->creds); + worker->creds = NULL; + } + if (current->files != worker->restore_files) { __acquire(&wqe->lock); spin_unlock_irq(&wqe->lock); @@ -442,6 +449,8 @@ next: set_fs(USER_DS); worker->mm = wq->mm; } + if (!worker->creds) + worker->creds = override_creds(wq->creds); if (test_bit(IO_WQ_BIT_CANCEL, &wq->state)) work->flags |= IO_WQ_WORK_CANCEL; if (worker->mm) @@ -995,6 +1004,7 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data) /* caller must already hold a reference to this */ wq->user = data->user; + wq->creds = data->creds; i = 0; for_each_online_node(node) { diff --git a/fs/io-wq.h b/fs/io-wq.h index bb8f1c8f8e24..5cd8c7697e88 100644 --- a/fs/io-wq.h +++ b/fs/io-wq.h @@ -45,6 +45,7 @@ typedef void (put_work_fn)(struct io_wq_work *); struct io_wq_data { struct mm_struct *mm; struct user_struct *user; + struct cred *creds; get_work_fn *get_work; put_work_fn *put_work; diff --git a/fs/io_uring.c b/fs/io_uring.c index fabae84396bc..b6c6fdc12de7 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -237,6 +237,8 @@ struct io_ring_ctx { struct user_struct *user; + struct cred *creds; + /* 0 is for ctx quiesce/reinit/free, 1 is for sqo_thread started */ struct completion *completions; @@ -3267,6 +3269,7 @@ static int io_sq_thread(void *data) { struct io_ring_ctx *ctx = data; struct mm_struct *cur_mm = NULL; + const struct cred *old_cred; mm_segment_t old_fs; DEFINE_WAIT(wait); unsigned inflight; @@ -3277,6 +3280,7 @@ static int io_sq_thread(void *data) old_fs = get_fs(); set_fs(USER_DS); + old_cred = override_creds(ctx->creds); ret = timeout = inflight = 0; while (!kthread_should_park()) { @@ -3383,6 +3387,7 @@ static int io_sq_thread(void *data) unuse_mm(cur_mm); mmput(cur_mm); } + revert_creds(old_cred); kthread_parkme(); @@ -4009,6 +4014,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx, data.mm = ctx->sqo_mm; data.user = ctx->user; + data.creds = ctx->creds; data.get_work = io_get_work; data.put_work = io_put_work; @@ -4363,6 +4369,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx) io_unaccount_mem(ctx->user, ring_pages(ctx->sq_entries, ctx->cq_entries)); free_uid(ctx->user); + put_cred(ctx->creds); kfree(ctx->completions); kmem_cache_free(req_cachep, ctx->fallback_req); kfree(ctx); @@ -4715,6 +4722,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p) ctx->compat = in_compat_syscall(); ctx->account_mem = account_mem; ctx->user = user; + ctx->creds = prepare_creds(); ret = io_allocate_scq_urings(ctx, p); if (ret)