diff --git a/include/net/netns/xfrm.h b/include/net/netns/xfrm.h index 24cd3949a9a4..27bb9633c69d 100644 --- a/include/net/netns/xfrm.h +++ b/include/net/netns/xfrm.h @@ -11,7 +11,7 @@ struct ctl_table_header; struct xfrm_policy_hash { - struct hlist_head *table; + struct hlist_head __rcu *table; unsigned int hmask; u8 dbits4; u8 sbits4; @@ -38,14 +38,12 @@ struct netns_xfrm { * mode. Also, it can be used by ah/esp icmp error handler to find * offending SA. */ - struct hlist_head *state_bydst; - struct hlist_head *state_bysrc; - struct hlist_head *state_byspi; + struct hlist_head __rcu *state_bydst; + struct hlist_head __rcu *state_bysrc; + struct hlist_head __rcu *state_byspi; unsigned int state_hmask; unsigned int state_num; struct work_struct state_hash_work; - struct hlist_head state_gc_list; - struct work_struct state_gc_work; struct list_head policy_all; struct hlist_head *policy_byidx; @@ -73,7 +71,7 @@ struct netns_xfrm { struct dst_ops xfrm6_dst_ops; #endif spinlock_t xfrm_state_lock; - rwlock_t xfrm_policy_lock; + spinlock_t xfrm_policy_lock; struct mutex xfrm_cfg_mutex; /* flow cache part */ diff --git a/include/net/xfrm.h b/include/net/xfrm.h index adfebd6f243c..d2fdd6d70959 100644 --- a/include/net/xfrm.h +++ b/include/net/xfrm.h @@ -187,7 +187,7 @@ struct xfrm_state { struct xfrm_replay_state_esn *preplay_esn; /* The functions for replay detection. */ - struct xfrm_replay *repl; + const struct xfrm_replay *repl; /* internal flag that only holds state for delayed aevent at the * moment diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index b5e665b3cfb0..f7ce6265961a 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -49,6 +49,7 @@ static struct xfrm_policy_afinfo __rcu *xfrm_policy_afinfo[NPROTO] __read_mostly; static struct kmem_cache *xfrm_dst_cache __read_mostly; +static __read_mostly seqcount_t xfrm_policy_hash_generation; static void xfrm_init_pmtu(struct dst_entry *dst); static int stale_bundle(struct dst_entry *dst); @@ -59,6 +60,11 @@ static void __xfrm_policy_link(struct xfrm_policy *pol, int dir); static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol, int dir); +static inline bool xfrm_pol_hold_rcu(struct xfrm_policy *policy) +{ + return atomic_inc_not_zero(&policy->refcnt); +} + static inline bool __xfrm4_selector_match(const struct xfrm_selector *sel, const struct flowi *fl) { @@ -385,9 +391,11 @@ static struct hlist_head *policy_hash_bysel(struct net *net, __get_hash_thresh(net, family, dir, &dbits, &sbits); hash = __sel_hash(sel, family, hmask, dbits, sbits); - return (hash == hmask + 1 ? - &net->xfrm.policy_inexact[dir] : - net->xfrm.policy_bydst[dir].table + hash); + if (hash == hmask + 1) + return &net->xfrm.policy_inexact[dir]; + + return rcu_dereference_check(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash; } static struct hlist_head *policy_hash_direct(struct net *net, @@ -403,7 +411,8 @@ static struct hlist_head *policy_hash_direct(struct net *net, __get_hash_thresh(net, family, dir, &dbits, &sbits); hash = __addr_hash(daddr, saddr, family, hmask, dbits, sbits); - return net->xfrm.policy_bydst[dir].table + hash; + return rcu_dereference_check(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash; } static void xfrm_dst_hash_transfer(struct net *net, @@ -426,14 +435,14 @@ redo: h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr, pol->family, nhashmask, dbits, sbits); if (!entry0) { - hlist_del(&pol->bydst); - hlist_add_head(&pol->bydst, ndsttable+h); + hlist_del_rcu(&pol->bydst); + hlist_add_head_rcu(&pol->bydst, ndsttable + h); h0 = h; } else { if (h != h0) continue; - hlist_del(&pol->bydst); - hlist_add_behind(&pol->bydst, entry0); + hlist_del_rcu(&pol->bydst); + hlist_add_behind_rcu(&pol->bydst, entry0); } entry0 = &pol->bydst; } @@ -468,22 +477,32 @@ static void xfrm_bydst_resize(struct net *net, int dir) unsigned int hmask = net->xfrm.policy_bydst[dir].hmask; unsigned int nhashmask = xfrm_new_hash_mask(hmask); unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head); - struct hlist_head *odst = net->xfrm.policy_bydst[dir].table; struct hlist_head *ndst = xfrm_hash_alloc(nsize); + struct hlist_head *odst; int i; if (!ndst) return; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + write_seqcount_begin(&xfrm_policy_hash_generation); + + odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)); + + odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)); for (i = hmask; i >= 0; i--) xfrm_dst_hash_transfer(net, odst + i, ndst, nhashmask, dir); - net->xfrm.policy_bydst[dir].table = ndst; + rcu_assign_pointer(net->xfrm.policy_bydst[dir].table, ndst); net->xfrm.policy_bydst[dir].hmask = nhashmask; - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + write_seqcount_end(&xfrm_policy_hash_generation); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + synchronize_rcu(); xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head)); } @@ -500,7 +519,7 @@ static void xfrm_byidx_resize(struct net *net, int total) if (!nidx) return; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); for (i = hmask; i >= 0; i--) xfrm_idx_hash_transfer(oidx + i, nidx, nhashmask); @@ -508,7 +527,7 @@ static void xfrm_byidx_resize(struct net *net, int total) net->xfrm.policy_byidx = nidx; net->xfrm.policy_idx_hmask = nhashmask; - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); xfrm_hash_free(oidx, (hmask + 1) * sizeof(struct hlist_head)); } @@ -541,7 +560,6 @@ static inline int xfrm_byidx_should_resize(struct net *net, int total) void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si) { - read_lock_bh(&net->xfrm.xfrm_policy_lock); si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN]; si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT]; si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD]; @@ -550,7 +568,6 @@ void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si) si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX]; si->spdhcnt = net->xfrm.policy_idx_hmask; si->spdhmcnt = xfrm_policy_hashmax; - read_unlock_bh(&net->xfrm.xfrm_policy_lock); } EXPORT_SYMBOL(xfrm_spd_getinfo); @@ -600,7 +617,7 @@ static void xfrm_hash_rebuild(struct work_struct *work) rbits6 = net->xfrm.policy_hthresh.rbits6; } while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq)); - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); /* reset the bydst and inexact table in all directions */ for (dir = 0; dir < XFRM_POLICY_MAX; dir++) { @@ -642,7 +659,7 @@ static void xfrm_hash_rebuild(struct work_struct *work) hlist_add_head(&policy->bydst, chain); } - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); mutex_unlock(&hash_resize_mutex); } @@ -753,7 +770,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) struct hlist_head *chain; struct hlist_node *newpos; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); chain = policy_hash_bysel(net, &policy->selector, policy->family, dir); delpol = NULL; newpos = NULL; @@ -764,7 +781,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) xfrm_sec_ctx_match(pol->security, policy->security) && !WARN_ON(delpol)) { if (excl) { - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); return -EEXIST; } delpol = pol; @@ -800,7 +817,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) policy->curlft.use_time = 0; if (!mod_timer(&policy->timer, jiffies + HZ)) xfrm_pol_hold(policy); - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); if (delpol) xfrm_policy_kill(delpol); @@ -820,7 +837,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type, struct hlist_head *chain; *err = 0; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); chain = policy_hash_bysel(net, sel, sel->family, dir); ret = NULL; hlist_for_each_entry(pol, chain, bydst) { @@ -833,7 +850,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type, *err = security_xfrm_policy_delete( pol->security); if (*err) { - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); return pol; } __xfrm_policy_unlink(pol, dir); @@ -842,7 +859,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type, break; } } - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); if (ret && delete) xfrm_policy_kill(ret); @@ -861,7 +878,7 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type, return NULL; *err = 0; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); chain = net->xfrm.policy_byidx + idx_hash(net, id); ret = NULL; hlist_for_each_entry(pol, chain, byidx) { @@ -872,7 +889,7 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type, *err = security_xfrm_policy_delete( pol->security); if (*err) { - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); return pol; } __xfrm_policy_unlink(pol, dir); @@ -881,7 +898,7 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type, break; } } - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); if (ret && delete) xfrm_policy_kill(ret); @@ -939,7 +956,7 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid) { int dir, err = 0, cnt = 0; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); err = xfrm_policy_flush_secctx_check(net, type, task_valid); if (err) @@ -955,14 +972,14 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid) if (pol->type != type) continue; __xfrm_policy_unlink(pol, dir); - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); cnt++; xfrm_audit_policy_delete(pol, 1, task_valid); xfrm_policy_kill(pol); - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); goto again1; } @@ -974,13 +991,13 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid) if (pol->type != type) continue; __xfrm_policy_unlink(pol, dir); - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); cnt++; xfrm_audit_policy_delete(pol, 1, task_valid); xfrm_policy_kill(pol); - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); goto again2; } } @@ -989,7 +1006,7 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid) if (!cnt) err = -ESRCH; out: - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); return err; } EXPORT_SYMBOL(xfrm_policy_flush); @@ -1009,7 +1026,7 @@ int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk, if (list_empty(&walk->walk.all) && walk->seq != 0) return 0; - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); if (list_empty(&walk->walk.all)) x = list_first_entry(&net->xfrm.policy_all, struct xfrm_policy_walk_entry, all); else @@ -1037,7 +1054,7 @@ int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk, } list_del_init(&walk->walk.all); out: - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); return error; } EXPORT_SYMBOL(xfrm_policy_walk); @@ -1056,9 +1073,9 @@ void xfrm_policy_walk_done(struct xfrm_policy_walk *walk, struct net *net) if (list_empty(&walk->walk.all)) return; - write_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */ + spin_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */ list_del(&walk->walk.all); - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); } EXPORT_SYMBOL(xfrm_policy_walk_done); @@ -1096,17 +1113,24 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type, struct xfrm_policy *pol, *ret; const xfrm_address_t *daddr, *saddr; struct hlist_head *chain; - u32 priority = ~0U; + unsigned int sequence; + u32 priority; daddr = xfrm_flowi_daddr(fl, family); saddr = xfrm_flowi_saddr(fl, family); if (unlikely(!daddr || !saddr)) return NULL; - read_lock_bh(&net->xfrm.xfrm_policy_lock); - chain = policy_hash_direct(net, daddr, saddr, family, dir); + rcu_read_lock(); + retry: + do { + sequence = read_seqcount_begin(&xfrm_policy_hash_generation); + chain = policy_hash_direct(net, daddr, saddr, family, dir); + } while (read_seqcount_retry(&xfrm_policy_hash_generation, sequence)); + + priority = ~0U; ret = NULL; - hlist_for_each_entry(pol, chain, bydst) { + hlist_for_each_entry_rcu(pol, chain, bydst) { err = xfrm_policy_match(pol, fl, type, family, dir); if (err) { if (err == -ESRCH) @@ -1122,7 +1146,7 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type, } } chain = &net->xfrm.policy_inexact[dir]; - hlist_for_each_entry(pol, chain, bydst) { + hlist_for_each_entry_rcu(pol, chain, bydst) { if ((pol->priority >= priority) && ret) break; @@ -1140,9 +1164,13 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type, } } - xfrm_pol_hold(ret); + if (read_seqcount_retry(&xfrm_policy_hash_generation, sequence)) + goto retry; + + if (ret && !xfrm_pol_hold_rcu(ret)) + goto retry; fail: - read_unlock_bh(&net->xfrm.xfrm_policy_lock); + rcu_read_unlock(); return ret; } @@ -1219,10 +1247,9 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir, const struct flowi *fl) { struct xfrm_policy *pol; - struct net *net = sock_net(sk); rcu_read_lock(); - read_lock_bh(&net->xfrm.xfrm_policy_lock); + again: pol = rcu_dereference(sk->sk_policy[dir]); if (pol != NULL) { bool match = xfrm_selector_match(&pol->selector, fl, @@ -1237,8 +1264,8 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir, err = security_xfrm_policy_lookup(pol->security, fl->flowi_secid, policy_to_flow_dir(dir)); - if (!err) - xfrm_pol_hold(pol); + if (!err && !xfrm_pol_hold_rcu(pol)) + goto again; else if (err == -ESRCH) pol = NULL; else @@ -1247,7 +1274,6 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir, pol = NULL; } out: - read_unlock_bh(&net->xfrm.xfrm_policy_lock); rcu_read_unlock(); return pol; } @@ -1271,7 +1297,7 @@ static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol, /* Socket policies are not hashed. */ if (!hlist_unhashed(&pol->bydst)) { - hlist_del(&pol->bydst); + hlist_del_rcu(&pol->bydst); hlist_del(&pol->byidx); } @@ -1295,9 +1321,9 @@ int xfrm_policy_delete(struct xfrm_policy *pol, int dir) { struct net *net = xp_net(pol); - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); pol = __xfrm_policy_unlink(pol, dir); - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); if (pol) { xfrm_policy_kill(pol); return 0; @@ -1316,7 +1342,7 @@ int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol) return -EINVAL; #endif - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); old_pol = rcu_dereference_protected(sk->sk_policy[dir], lockdep_is_held(&net->xfrm.xfrm_policy_lock)); if (pol) { @@ -1334,7 +1360,7 @@ int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol) */ xfrm_sk_policy_unlink(old_pol, dir); } - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); if (old_pol) { xfrm_policy_kill(old_pol); @@ -1364,9 +1390,9 @@ static struct xfrm_policy *clone_policy(const struct xfrm_policy *old, int dir) newp->type = old->type; memcpy(newp->xfrm_vec, old->xfrm_vec, newp->xfrm_nr*sizeof(struct xfrm_tmpl)); - write_lock_bh(&net->xfrm.xfrm_policy_lock); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); xfrm_sk_policy_link(newp, dir); - write_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); xfrm_pol_put(newp); } return newp; @@ -3048,7 +3074,7 @@ static int __net_init xfrm_net_init(struct net *net) /* Initialize the per-net locks here */ spin_lock_init(&net->xfrm.xfrm_state_lock); - rwlock_init(&net->xfrm.xfrm_policy_lock); + spin_lock_init(&net->xfrm.xfrm_policy_lock); mutex_init(&net->xfrm.xfrm_cfg_mutex); return 0; @@ -3082,6 +3108,7 @@ static struct pernet_operations __net_initdata xfrm_net_ops = { void __init xfrm_init(void) { register_pernet_subsys(&xfrm_net_ops); + seqcount_init(&xfrm_policy_hash_generation); xfrm_input_init(); } @@ -3179,7 +3206,7 @@ static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector * struct hlist_head *chain; u32 priority = ~0U; - read_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME*/ + spin_lock_bh(&net->xfrm.xfrm_policy_lock); chain = policy_hash_direct(net, &sel->daddr, &sel->saddr, sel->family, dir); hlist_for_each_entry(pol, chain, bydst) { if (xfrm_migrate_selector_match(sel, &pol->selector) && @@ -3203,7 +3230,7 @@ static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector * xfrm_pol_hold(ret); - read_unlock_bh(&net->xfrm.xfrm_policy_lock); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); return ret; } diff --git a/net/xfrm/xfrm_replay.c b/net/xfrm/xfrm_replay.c index 4fd725a0c500..cdc2e2e71bff 100644 --- a/net/xfrm/xfrm_replay.c +++ b/net/xfrm/xfrm_replay.c @@ -558,7 +558,7 @@ static void xfrm_replay_advance_esn(struct xfrm_state *x, __be32 net_seq) x->repl->notify(x, XFRM_REPLAY_UPDATE); } -static struct xfrm_replay xfrm_replay_legacy = { +static const struct xfrm_replay xfrm_replay_legacy = { .advance = xfrm_replay_advance, .check = xfrm_replay_check, .recheck = xfrm_replay_check, @@ -566,7 +566,7 @@ static struct xfrm_replay xfrm_replay_legacy = { .overflow = xfrm_replay_overflow, }; -static struct xfrm_replay xfrm_replay_bmp = { +static const struct xfrm_replay xfrm_replay_bmp = { .advance = xfrm_replay_advance_bmp, .check = xfrm_replay_check_bmp, .recheck = xfrm_replay_check_bmp, @@ -574,7 +574,7 @@ static struct xfrm_replay xfrm_replay_bmp = { .overflow = xfrm_replay_overflow_bmp, }; -static struct xfrm_replay xfrm_replay_esn = { +static const struct xfrm_replay xfrm_replay_esn = { .advance = xfrm_replay_advance_esn, .check = xfrm_replay_check_esn, .recheck = xfrm_replay_recheck_esn, diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 9895a8c56d8c..ba8bf518ba14 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -28,6 +28,11 @@ #include "xfrm_hash.h" +#define xfrm_state_deref_prot(table, net) \ + rcu_dereference_protected((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock)) + +static void xfrm_state_gc_task(struct work_struct *work); + /* Each xfrm_state may be linked to two tables: 1. Hash table by (spi,daddr,ah/esp) to find SA by SPI. (input,ctl) @@ -36,6 +41,15 @@ */ static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024; +static __read_mostly seqcount_t xfrm_state_hash_generation = SEQCNT_ZERO(xfrm_state_hash_generation); + +static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task); +static HLIST_HEAD(xfrm_state_gc_list); + +static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x) +{ + return atomic_inc_not_zero(&x->refcnt); +} static inline unsigned int xfrm_dst_hash(struct net *net, const xfrm_address_t *daddr, @@ -76,18 +90,18 @@ static void xfrm_hash_transfer(struct hlist_head *list, h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr, x->props.reqid, x->props.family, nhashmask); - hlist_add_head(&x->bydst, ndsttable+h); + hlist_add_head_rcu(&x->bydst, ndsttable + h); h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr, x->props.family, nhashmask); - hlist_add_head(&x->bysrc, nsrctable+h); + hlist_add_head_rcu(&x->bysrc, nsrctable + h); if (x->id.spi) { h = __xfrm_spi_hash(&x->id.daddr, x->id.spi, x->id.proto, x->props.family, nhashmask); - hlist_add_head(&x->byspi, nspitable+h); + hlist_add_head_rcu(&x->byspi, nspitable + h); } } } @@ -122,25 +136,29 @@ static void xfrm_hash_resize(struct work_struct *work) } spin_lock_bh(&net->xfrm.xfrm_state_lock); + write_seqcount_begin(&xfrm_state_hash_generation); nhashmask = (nsize / sizeof(struct hlist_head)) - 1U; + odst = xfrm_state_deref_prot(net->xfrm.state_bydst, net); for (i = net->xfrm.state_hmask; i >= 0; i--) - xfrm_hash_transfer(net->xfrm.state_bydst+i, ndst, nsrc, nspi, - nhashmask); + xfrm_hash_transfer(odst + i, ndst, nsrc, nspi, nhashmask); - odst = net->xfrm.state_bydst; - osrc = net->xfrm.state_bysrc; - ospi = net->xfrm.state_byspi; + osrc = xfrm_state_deref_prot(net->xfrm.state_bysrc, net); + ospi = xfrm_state_deref_prot(net->xfrm.state_byspi, net); ohashmask = net->xfrm.state_hmask; - net->xfrm.state_bydst = ndst; - net->xfrm.state_bysrc = nsrc; - net->xfrm.state_byspi = nspi; + rcu_assign_pointer(net->xfrm.state_bydst, ndst); + rcu_assign_pointer(net->xfrm.state_bysrc, nsrc); + rcu_assign_pointer(net->xfrm.state_byspi, nspi); net->xfrm.state_hmask = nhashmask; + write_seqcount_end(&xfrm_state_hash_generation); spin_unlock_bh(&net->xfrm.xfrm_state_lock); osize = (ohashmask + 1) * sizeof(struct hlist_head); + + synchronize_rcu(); + xfrm_hash_free(odst, osize); xfrm_hash_free(osrc, osize); xfrm_hash_free(ospi, osize); @@ -355,15 +373,16 @@ static void xfrm_state_gc_destroy(struct xfrm_state *x) static void xfrm_state_gc_task(struct work_struct *work) { - struct net *net = container_of(work, struct net, xfrm.state_gc_work); struct xfrm_state *x; struct hlist_node *tmp; struct hlist_head gc_list; spin_lock_bh(&xfrm_state_gc_lock); - hlist_move_list(&net->xfrm.state_gc_list, &gc_list); + hlist_move_list(&xfrm_state_gc_list, &gc_list); spin_unlock_bh(&xfrm_state_gc_lock); + synchronize_rcu(); + hlist_for_each_entry_safe(x, tmp, &gc_list, gclist) xfrm_state_gc_destroy(x); } @@ -500,14 +519,12 @@ EXPORT_SYMBOL(xfrm_state_alloc); void __xfrm_state_destroy(struct xfrm_state *x) { - struct net *net = xs_net(x); - WARN_ON(x->km.state != XFRM_STATE_DEAD); spin_lock_bh(&xfrm_state_gc_lock); - hlist_add_head(&x->gclist, &net->xfrm.state_gc_list); + hlist_add_head(&x->gclist, &xfrm_state_gc_list); spin_unlock_bh(&xfrm_state_gc_lock); - schedule_work(&net->xfrm.state_gc_work); + schedule_work(&xfrm_state_gc_work); } EXPORT_SYMBOL(__xfrm_state_destroy); @@ -520,10 +537,10 @@ int __xfrm_state_delete(struct xfrm_state *x) x->km.state = XFRM_STATE_DEAD; spin_lock(&net->xfrm.xfrm_state_lock); list_del(&x->km.all); - hlist_del(&x->bydst); - hlist_del(&x->bysrc); + hlist_del_rcu(&x->bydst); + hlist_del_rcu(&x->bysrc); if (x->id.spi) - hlist_del(&x->byspi); + hlist_del_rcu(&x->byspi); net->xfrm.state_num--; spin_unlock(&net->xfrm.xfrm_state_lock); @@ -659,7 +676,7 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family); struct xfrm_state *x; - hlist_for_each_entry(x, net->xfrm.state_byspi+h, byspi) { + hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) { if (x->props.family != family || x->id.spi != spi || x->id.proto != proto || @@ -668,7 +685,8 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, if ((mark & x->mark.m) != x->mark.v) continue; - xfrm_state_hold(x); + if (!xfrm_state_hold_rcu(x)) + continue; return x; } @@ -683,7 +701,7 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark, unsigned int h = xfrm_src_hash(net, daddr, saddr, family); struct xfrm_state *x; - hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) { + hlist_for_each_entry_rcu(x, net->xfrm.state_bysrc + h, bysrc) { if (x->props.family != family || x->id.proto != proto || !xfrm_addr_equal(&x->id.daddr, daddr, family) || @@ -692,7 +710,8 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark, if ((mark & x->mark.m) != x->mark.v) continue; - xfrm_state_hold(x); + if (!xfrm_state_hold_rcu(x)) + continue; return x; } @@ -775,13 +794,16 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, struct xfrm_state *best = NULL; u32 mark = pol->mark.v & pol->mark.m; unsigned short encap_family = tmpl->encap_family; + unsigned int sequence; struct km_event c; to_put = NULL; - spin_lock_bh(&net->xfrm.xfrm_state_lock); + sequence = read_seqcount_begin(&xfrm_state_hash_generation); + + rcu_read_lock(); h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family); - hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) { if (x->props.family == encap_family && x->props.reqid == tmpl->reqid && (mark & x->mark.m) == x->mark.v && @@ -797,7 +819,7 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, goto found; h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family); - hlist_for_each_entry(x, net->xfrm.state_bydst+h_wildcard, bydst) { + hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h_wildcard, bydst) { if (x->props.family == encap_family && x->props.reqid == tmpl->reqid && (mark & x->mark.m) == x->mark.v && @@ -850,19 +872,21 @@ found: } if (km_query(x, tmpl, pol) == 0) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); x->km.state = XFRM_STATE_ACQ; list_add(&x->km.all, &net->xfrm.state_all); - hlist_add_head(&x->bydst, net->xfrm.state_bydst+h); + hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); h = xfrm_src_hash(net, daddr, saddr, encap_family); - hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h); + hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family); - hlist_add_head(&x->byspi, net->xfrm.state_byspi+h); + hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); } x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires; tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL); net->xfrm.state_num++; xfrm_hash_grow_check(net, x->bydst.next != NULL); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); } else { x->km.state = XFRM_STATE_DEAD; to_put = x; @@ -871,13 +895,26 @@ found: } } out: - if (x) - xfrm_state_hold(x); - else + if (x) { + if (!xfrm_state_hold_rcu(x)) { + *err = -EAGAIN; + x = NULL; + } + } else { *err = acquire_in_progress ? -EAGAIN : error; - spin_unlock_bh(&net->xfrm.xfrm_state_lock); + } + rcu_read_unlock(); if (to_put) xfrm_state_put(to_put); + + if (read_seqcount_retry(&xfrm_state_hash_generation, sequence)) { + *err = -EAGAIN; + if (x) { + xfrm_state_put(x); + x = NULL; + } + } + return x; } @@ -945,16 +982,16 @@ static void __xfrm_state_insert(struct xfrm_state *x) h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr, x->props.reqid, x->props.family); - hlist_add_head(&x->bydst, net->xfrm.state_bydst+h); + hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family); - hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h); + hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family); - hlist_add_head(&x->byspi, net->xfrm.state_byspi+h); + hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); } tasklet_hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL); @@ -1063,9 +1100,9 @@ static struct xfrm_state *__find_acq_core(struct net *net, xfrm_state_hold(x); tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL); list_add(&x->km.all, &net->xfrm.state_all); - hlist_add_head(&x->bydst, net->xfrm.state_bydst+h); + hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); h = xfrm_src_hash(net, daddr, saddr, family); - hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h); + hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); net->xfrm.state_num++; @@ -1581,7 +1618,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) if (x->id.spi) { spin_lock_bh(&net->xfrm.xfrm_state_lock); h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family); - hlist_add_head(&x->byspi, net->xfrm.state_byspi+h); + hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); spin_unlock_bh(&net->xfrm.xfrm_state_lock); err = 0; @@ -2099,8 +2136,6 @@ int __net_init xfrm_state_init(struct net *net) net->xfrm.state_num = 0; INIT_WORK(&net->xfrm.state_hash_work, xfrm_hash_resize); - INIT_HLIST_HEAD(&net->xfrm.state_gc_list); - INIT_WORK(&net->xfrm.state_gc_work, xfrm_state_gc_task); spin_lock_init(&net->xfrm.xfrm_state_lock); return 0; @@ -2118,7 +2153,7 @@ void xfrm_state_fini(struct net *net) flush_work(&net->xfrm.state_hash_work); xfrm_state_flush(net, IPSEC_PROTO_ANY, false); - flush_work(&net->xfrm.state_gc_work); + flush_work(&xfrm_state_gc_work); WARN_ON(!list_empty(&net->xfrm.state_all));