Commit e87e614e authored by Sigmund Augdal's avatar Sigmund Augdal

New strategy for handling source security groups.

The code now uses a source security group and a destination security group and one iptables rule that matches both for each such case.
To get atomic updates all ipsets are recreated for each run with a generation number encoded in the name. Old ipsets are removed onces the iptables rules are updated
parent f2b4dc0f
......@@ -23,14 +23,28 @@ class Generator(object):
handler = logging.handlers.RotatingFileHandler(logfile, maxBytes=10*1024**3,
backupCount=5)
logging.getLogger("").addHandler(handler)
# maps security group id from etcd to the name of a ipset containing its members
self.group_members_groups = {}
# maps security group id from etcd to a ipset containing all rules with this group as source
self.by_source_groups = {}
self.generation = 0
self.serial = 0
def output(self, line):
self.output_file.write(line)
self.output_file.write("\n")
def create_ipset(self, name, set_type):
def set_name(self, prefix, family):
return "{}_{}_{}".format(prefix, family, self.generation)
def create_ipset(self, name, set_type, with_serial=False):
if with_serial:
name = "{}{}".format(name, self.serial)
self.serial += 1
for family in ("inet", "inet6"):
self.output("create {}_{} {} family {}".format(name, family, set_type, family))
setname = self.set_name(name, family)
self.output("create {} {} family {}".format(setname, set_type, family))
return name
def add_ipset_member(self, name, member, protocol="tcp", port=None, net=None, source=None):
suffix = ""
......@@ -40,28 +54,35 @@ class Generator(object):
suffix += ",{}".format(net)
member = member.lower()
if member in self.addresses_v4 and (net is None or "." in net):
setname = self.set_name(name, "inet")
if source is not None:
if source in self.addresses_v4:
suffix4 = "{},{}".format(suffix, self.addresses_v4[source])
self.output("add {}_inet {}{}".format(name, self.addresses_v4[member],
suffix4))
self.output("add {} {}{}".format(setname, self.addresses_v4[member],
suffix4))
else:
self.output("add {}_inet {}{}".format(name, self.addresses_v4[member],
suffix))
self.output("add {} {}{}".format(setname, self.addresses_v4[member],
suffix))
if member in self.addresses_v6 and (net is None or ":" in net):
setname = self.set_name(name, "inet6")
if source is not None:
if source in self.addresses_v6:
suffix6 = "{},{}".format(suffix, self.addresses_v6[source])
self.output("add {}_inet6 {}{}".format(name, self.addresses_v6[member],
suffix6))
self.output("add {} {}{}".format(setname, self.addresses_v6[member],
suffix6))
else:
self.output("add {}_inet6 {}{}".format(name, self.addresses_v6[member],
suffix))
self.output("add {} {}{}".format(setname, self.addresses_v6[member],
suffix))
def process_security_group(self, group_id, name):
rules = security_groups.get_group_rules(self.etcd_client, group_id)
_, members = security_groups.get_group_members(self.etcd_client, group_id)
members_name = self.create_ipset("source", "hash:ip", True)
for member in members:
self.add_ipset_member(members_name, member)
self.group_members_groups[group_id] = members_name
for rule in rules:
for member in members:
if rule["source_type"] == "any":
......@@ -73,9 +94,11 @@ class Generator(object):
rule["source_cidr"])
elif rule["source_type"] == "security_group":
source_group = rule["source_security_group"]
for source in security_groups.get_group_members(self.etcd_client, source_group)[1]:
self.add_ipset_member("rules_from_sg", member, rule["protocol"],
rule["destination_port"], source=source)
if not source_group in self.by_source_groups:
self.by_source_groups[source_group] = self.create_ipset("rules_by_source", "hash:ip,port", True)
group_name = self.by_source_groups[source_group]
self.add_ipset_member(group_name, member, rule["protocol"],
rule["destination_port"])
else:
logging.warning("Unhandled source type: %s", rule["source_type"])
......@@ -93,19 +116,56 @@ class Generator(object):
raise ex
return addresses
def import_iptables(self, filename, output):
infile = open(filename)
for line in infile.readlines():
if line == "COMMIT\n":
return
output.write(line)
def generate_all(self):
index = None
self.addresses_v4 = self.get_addresses("ipv4")
self.addresses_v6 = self.get_addresses("ipv6_public")
self.output_file = tempfile.TemporaryFile()
self.output("flush")
self.serial = 0
old_groups = subprocess.check_output(["ipset", "list", "-name"])
old_groups = old_groups.split("\n")
old_groups = [group for group in old_groups if group]
if len(old_groups) >= 6:
self.generation = max((int(group.split("_")[-1]) for group in old_groups)) + 1
else:
self.generation = 0
logging.debug("Building ipsets with generation %d", self.generation)
self.create_ipset("rules_from_any", "hash:ip,port")
self.create_ipset("rules_from_cidr", "hash:ip,port,net")
groups = security_groups.get_security_groups(self.etcd_client)
for group_id, name in groups.items():
self.process_security_group(group_id, name)
self.output_file.seek(0)
subprocess.call("ipset restore", stdin=self.output_file, shell=True)
# subprocess.call("cat", stdin=self.output_file, shell=True)
iptables_output = tempfile.TemporaryFile()
ip6tables_output = tempfile.TemporaryFile()
self.import_iptables("/etc/iptables.save", iptables_output)
self.import_iptables("/etc/ip6tables.save", ip6tables_output)
for family, output in (("inet", iptables_output), ("inet6", ip6tables_output)):
output.write("-A FORWARD -m set --match-set {} dst,dst -j ACCEPT\n".format(self.set_name("rules_from_any", family)))
output.write("-A FORWARD -m set --match-set {} dst,dst,src -j ACCEPT\n".format(self.set_name("rules_from_cidr", family)))
for group, ipset in self.by_source_groups.items():
output.write("-A FORWARD -m set --match-set {} src -m set --match-set {} dst,dst -j ACCEPT\n".format(self.set_name(self.group_members_groups[group], family),
self.set_name(ipset, family)))
output.write("COMMIT\n")
output.seek(0)
# subprocess.call("cat", stdin=iptables_output, shell=True)
# subprocess.call("cat", stdin=ip6tables_output, shell=True)
subprocess.call("iptables-restore", stdin=iptables_output, shell=True)
subprocess.call("ip6tables-restore", stdin=ip6tables_output, shell=True)
for ipset in old_groups:
subprocess.call("ipset destroy {}".format(ipset), shell=True)
return index
def main(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment