D7net
Home
Console
Upload
information
Create File
Create Folder
About
Tools
:
/
opt
/
imunify360
/
venv
/
lib
/
python3.11
/
site-packages
/
im360
/
internals
/
core
/
Filename :
__init__.py
back
Copy
"""Core module for rules and sets managing."""
import logging
import math
import time
from typing import Dict, Iterable, List, Optional, Set, Tuple

from defence360agent.utils import await_for, retry_on, timeit
from im360.contracts.config import NetworkInterface, UnifiedAccessLogger
from im360.internals.core.ipset.port_deny import (
    InputPortBlockingDenyModeIPSet,
    OutputPortBlockingDenyModeIPSet,
)
from im360.utils.validate import IPVersion

from . import ip_versions
from .firewall import (
    FirewallRules,
    RuleDef,
    firewall_logging_enabled,
    is_nat_available,
)
from .ipset import IP_SET_PREFIX, libipset
from .ipset.country import IPSetCountry
from .ipset.ip import IPSet
from .ipset.libipset import IPSetCmdBuilder, IPSetRestoreCmd
from .ipset.port import IPSetIgnoredByPort, IPSetPort
from .ipset.redirect import (
    IPSetNoRedirectPort,
    IPSetWebshieldPort,
)
from .ipset.sync import IPSetSyncIPListPurpose, IPSetSyncIPListRecords

logger = logging.getLogger(__name__)


class RuleSet:
    """Managing iptables rules and ipsets.

    Holds the ipset entities used by Imunify360 and builds the firewall
    commands that create, check and destroy the related chains and rules.
    """

    _CHAINS = [
        FirewallRules.COUNTRY_WHITELIST_CHAIN,
        FirewallRules.COUNTRY_BLACKLIST_CHAIN,
        FirewallRules.BP_INPUT_CHAIN,
        FirewallRules.LOG_BLACKLIST_CHAIN,
        FirewallRules.LOG_GRAYLIST_CHAIN,
        FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
        FirewallRules.WEBSHIELD_PORTS_INPUT_CHAIN,
        FirewallRules.LOG_BLOCK_PORT_CHAIN,
    ]

    # Since DB and ipset are updated at different times,
    # check relative value instead of compare absolute values.
    # Use a large enough relative number to avoid false positives,
    # 20% difference looks reasonable for this.
    _IPSET_COUNT_TO_RECREATE_THRESHOLD = 0.2

    # Arguments of the raw/PREROUTING rule that jumps to the CT target.
    # Shared by create_commands, destroy_commands and check_commands so the
    # rule that is inserted, deleted and checked is always the same one.
    _CONNTRACK_RULE = (
        "-m",
        "comment",
        "--comment",
        '"Connection tracking for Imunify360."',
        "-j",
        "CT",
    )

    def __init__(self):
        self.entities = (
            InputPortBlockingDenyModeIPSet(),
            OutputPortBlockingDenyModeIPSet(),
            IPSetPort(),
            IPSet(),
            # Order is important here,
            # Ensure IPSetSyncIPListRecords is created before
            # IPSetSyncIPListPurpose
            IPSetSyncIPListRecords(),
            IPSetSyncIPListPurpose(),
            IPSetCountry(),
            IPSetIgnoredByPort(),
            IPSetNoRedirectPort(),
            IPSetWebshieldPort(),
        )

    @staticmethod
    def targets(ip_version: IPVersion) -> List[Tuple]:
        """
        Returns tables & chains that Imunify360 will use in firewall management

        :param ip_version: IPv4 or IPv6
        :return: List[Tuple]: (table, chain) pairs; PREROUTING lives in the
            nat table when NAT is available for *ip_version*, otherwise in
            the mangle table
        """
        return [
            (FirewallRules.FILTER, "INPUT"),
            (
                (FirewallRules.NAT, "PREROUTING")
                if is_nat_available(ip_version)
                else (FirewallRules.MANGLE, "PREROUTING")
            ),
        ]

    @staticmethod
    def _apply_ignored_interfaces(action, interface_conf, *args, **kwargs):
        """Yield *action* results for ACCEPT rules of skipped interfaces.

        For every interface listed under ``NetworkInterface.DEVICE_SKIP``,
        *action* is called with a rule matching that interface and jumping
        to ACCEPT, placed in ``IMUNIFY_INPUT_CHAIN``.

        :param Callable action: firewall method to apply
            (e.g. ``insert_rule`` or ``has_rule``)
        :param interface_conf: interface configuration
        """
        for interface in interface_conf[NetworkInterface.DEVICE_SKIP]:
            yield action(
                FirewallRules.compose_rule(
                    FirewallRules.interface(interface),
                    action=FirewallRules.compose_action(FirewallRules.ACCEPT),
                ),
                chain=FirewallRules.IMUNIFY_INPUT_CHAIN,
                priority=0,  # max priority for firewalld
                *args,
                **kwargs,
            )

    @staticmethod
    def _compose_rule(ip_version: IPVersion, interface_conf: dict) -> RuleDef:
        """Compose the jump-to-IMUNIFY_INPUT_CHAIN rule.

        When *interface_conf* names an interface for *ip_version*, the rule
        is restricted to that interface; otherwise it is the bare jump.
        """
        target_interface = interface_conf[ip_version]
        action = FirewallRules.compose_action(
            FirewallRules.IMUNIFY_INPUT_CHAIN
        )
        if target_interface:
            return FirewallRules.compose_rule(
                FirewallRules.interface(target_interface), action=action
            )
        return action

    async def ipset_create_commands(self, ip_version: IPVersion) -> List[str]:
        """Collect the create operations of every entity for *ip_version*."""
        ops = []  # type: List[str]
        for entity in self.entities:
            ops.extend(entity.gen_ipset_create_ops(ip_version))
        return ops

    async def ipset_flush_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to flush *existing* ipsets.

        Only entities that define ``gen_ipset_flush_cmds`` contribute.
        When *existing* is None, all Imunify360 ipsets currently present
        for *ip_version* are considered.
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        cmds = []  # type: List[IPSetRestoreCmd]
        for entity in self.entities:
            # flushing is optional: not every entity provides it
            if hasattr(entity, "gen_ipset_flush_cmds"):
                cmds.extend(entity.gen_ipset_flush_cmds(ip_version, existing))
        return cmds

    async def ipset_destroy_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to destroy *existing* ipsets.

        Entity-specific destroy commands take precedence; any remaining
        ipset gets a generic destroy command.
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        # get entity specific destroy commands
        cmds = {}  # type: Dict[str, IPSetRestoreCmd]
        for entity in self.entities:
            cmds.update(entity.gen_ipset_destroy_ops(ip_version, existing))
        # generic destroy
        for ipset_name in existing:
            if ipset_name not in cmds:
                # ipset is not special, remove using a generic destroy command
                cmds[ipset_name] = IPSetCmdBuilder.get_destroy_cmd(ipset_name)
        return cmds.values()

    async def create_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Return a list of firewall commands to create all required rules."""
        actions = []
        # input chains
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.create_chain(
                        table=table, chain=FirewallRules.IMUNIFY_INPUT_CHAIN
                    ),
                    firewall.insert_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.insert_rule, interface_conf, table=table
                    ),
                ]
            )
        actions.extend(
            firewall.create_chain(table=FirewallRules.FILTER, chain=name)
            for name in self._CHAINS
        )
        actions.extend(self._log_block_rules(firewall.append_rule, ip_version))

        # output chains
        actions.extend(
            [
                firewall.create_chain(
                    table=FirewallRules.FILTER,
                    chain=FirewallRules.IMUNIFY_OUTPUT_CHAIN,
                ),
                firewall.insert_rule(
                    FirewallRules.compose_action(
                        FirewallRules.IMUNIFY_OUTPUT_CHAIN
                    ),
                    chain="OUTPUT",
                ),
            ]
        )
        actions.extend(
            firewall.create_chain(table=FirewallRules.FILTER, chain=name)
            for name in [FirewallRules.BP_OUTPUT_CHAIN]
        )
        actions.extend(
            firewall.append_rule(**rule)
            for rule in await self._collect_ipset_rules(ip_version)
        )

        # Add connection tracking rule.
        actions.append(
            firewall.insert_rule(
                self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
            )
        )
        return actions

    def destroy_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> Iterable[list]:
        """Returns an iterable over list of commands to destroy firewall rules.

        Each list should be executed as a separate firewall commit operation.
        """
        # input chains
        for table, chain in self.targets(ip_version):
            yield [
                firewall.delete_rule(
                    self._compose_rule(ip_version, interface_conf),
                    table=table,
                    chain=chain,
                )
            ]
            yield [
                firewall.flush_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
                firewall.delete_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
            ]
        for name in self._CHAINS:
            yield [
                firewall.flush_chain(name, table=FirewallRules.FILTER),
                firewall.delete_chain(name, table=FirewallRules.FILTER),
            ]

        # output chains
        yield [
            firewall.delete_rule(
                FirewallRules.compose_action(
                    FirewallRules.IMUNIFY_OUTPUT_CHAIN
                ),
                chain="OUTPUT",
            )
        ]
        yield [
            firewall.flush_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
            firewall.delete_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
        ]
        for name in [FirewallRules.BP_OUTPUT_CHAIN]:
            yield [firewall.flush_chain(name), firewall.delete_chain(name)]

        # Delete connection tracking rule.
        yield [
            firewall.delete_rule(
                self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
            )
        ]

    def required_ipsets(self, ip_version: IPVersion) -> Set[str]:
        """Return the names of all ipsets the entities need for *ip_version*."""
        names = set()  # type: Set[str]
        for entity in self.entities:
            names.update(entity.get_all_ipsets(ip_version))
        return names

    async def check_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Returns a list of firewall commands to check for firewall rules."""
        actions = []
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.has_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.has_rule, interface_conf, table=table
                    ),
                ]
            )
        actions.extend(self._log_block_rules(firewall.has_rule, ip_version))
        actions.append(
            firewall.has_rule(
                FirewallRules.compose_action(
                    FirewallRules.IMUNIFY_OUTPUT_CHAIN
                ),
                table=FirewallRules.FILTER,
                chain="OUTPUT",
            )
        )
        actions.extend(
            firewall.has_rule(**rule)
            for rule in await self._collect_ipset_rules(ip_version)
        )
        actions.append(
            firewall.has_rule(
                self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
            )
        )
        return actions

    def _log_block_rules(self, predicate, ip_version: IPVersion):
        """Apply *predicate* to the (optional NFLOG +) block rules of each
        logging chain in the filter table and return the results."""
        rules = []
        for chain, prefix, action in (
            (
                FirewallRules.LOG_BLACKLIST_CHAIN,
                UnifiedAccessLogger.BLACKLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_GRAYLIST_CHAIN,
                UnifiedAccessLogger.GRAYLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
                UnifiedAccessLogger.BLACKLIST_COUNTRY,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLOCK_PORT_CHAIN,
                UnifiedAccessLogger.BLOCKED_BY_PORT,
                FirewallRules.compose_action(FirewallRules.REJECT),
            ),
        ):
            rules.extend(
                predicate(rule, table=FirewallRules.FILTER, chain=chain)
                for rule in self._log_drop_rules(ip_version, prefix, action)
            )
        return rules

    async def _collect_ipset_rules(self, ip_version: IPVersion) -> List[dict]:
        """Gather every entity's rules, sorted by (chain, priority)."""
        rules = []  # type: List[dict]
        for entity in self.entities:
            rules.extend(entity.get_rules(ip_version))
        rules.sort(key=lambda r: (r["chain"], r["priority"]))
        return rules

    async def fill_ipsets(
        self, ip_version: IPVersion, missing: Set[str]
    ) -> None:
        """Create the *missing* ipsets and fill them with required elements.

        All create and restore operations are collected first and applied
        with a single ``libipset.restore`` call.

        :param ip_version: IPv4 or IPv6
        :param missing: names of ipsets that must be created and filled
        """
        create_and_restore_cmds = []
        for entity in self.entities:
            for ip_set in entity.get_all_ipset_instances(ip_version):
                name = ip_set.gen_ipset_name_for_ip_version(ip_version)
                # NOTE(review): only ipsets in *missing* are restored here;
                # confirm no caller expects existing ipsets to be refilled.
                if name not in missing:
                    continue
                create_and_restore_cmds.extend(
                    ip_set.gen_ipset_create_ops(ip_version)
                )
                create_and_restore_cmds.extend(
                    await ip_set.gen_ipset_restore_ops(ip_version)
                )
        await libipset.restore(create_and_restore_cmds)
        logger.info("IP sets content restored from database")

    @staticmethod
    async def existing_ipsets(ip_version: IPVersion) -> Set[str]:
        """Return names of present ipsets prefixed with
        ``IP_SET_PREFIX.<ip_version>``."""
        prefix = ".".join([IP_SET_PREFIX, ip_version])
        all_sets = await libipset.list_set()
        return {name for name in all_sets if name.startswith(prefix)}

    async def destroy_ipsets(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> None:
        """Flush and destroy ipsets.

        Failed attempts (``IPSetNotFoundError`` /
        ``IPSetCannotBeDestroyedError``) are retried a bounded number of
        times; ipsets that could not be removed are logged as an error.

        :param ip_version: IPv4 or IPv6
        :param existing: names of ipsets to destroy; when None, every
            Imunify360 ipset currently present for *ip_version*
        """
        if existing is None:
            to_destroy = await self.existing_ipsets(ip_version)
        else:
            # copy: the set is narrowed in place below
            to_destroy = existing.copy()

        max_tries = 3
        attempt = 0
        # Bounded retry: stop once nothing is left or the tries run out.
        while to_destroy and attempt < max_tries:
            # remove absent ipsets
            to_destroy &= await self.existing_ipsets(ip_version)
            try:
                await libipset.restore(
                    await self.ipset_flush_commands(ip_version, to_destroy)
                )
                await libipset.restore(
                    await self.ipset_destroy_commands(ip_version, to_destroy)
                )
                return
            except (
                libipset.IPSetNotFoundError,
                libipset.IPSetCannotBeDestroyedError,
            ):
                attempt += 1
        if to_destroy:
            logger.error(
                "Failed to destroy ipsets: %s", ", ".join(sorted(to_destroy))
            )

    async def _recreate_ipsets(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ):
        """Reset all ipsets, create them again and fill with IPs for given
        ip version."""
        for entity in self.entities:
            await entity.reset(ip_version, existing)

    async def recreate_ipsets(
        self,
        ip_version: Optional[IPVersion] = None,
        existing: Optional[Set[str]] = None,
    ):
        """Recreate existing ipsets (or given).

        If *ip_version* is None, recreate ipsets for all enabled ip versions.
        """
        if ip_version:
            await self._recreate_ipsets(ip_version, existing)
            return
        for version in ip_versions.enabled():
            await self._recreate_ipsets(version, existing)

    @staticmethod
    def _log_drop_rules(ip_version: IPVersion, prefix, action):
        """Return the rules for one logging chain: an NFLOG rule (only when
        firewall logging is enabled) followed by *action*."""
        rules = []
        if firewall_logging_enabled():
            rules.append(
                FirewallRules.compose_rule(
                    action=FirewallRules.nflog_action(
                        group=FirewallRules.nflog_group(ip_version),
                        prefix=prefix,
                    )
                )
            )
        rules.append(action)
        return rules

    async def get_outdated_ipsets(self, ip_version: IPVersion) -> list:
        """
        Return list of ipsets the contents of which do not match the database

        An ipset is outdated when its element count differs from the DB
        count by more than ``_IPSET_COUNT_TO_RECREATE_THRESHOLD``
        (relative tolerance of ``math.isclose``).
        """
        outdated: list = []
        for entity in self.entities:
            counts = await entity.get_ipsets_count(ip_version)
            outdated.extend(
                info
                for info in counts
                if not math.isclose(
                    info.ipset_count,
                    info.db_count,
                    rel_tol=self._IPSET_COUNT_TO_RECREATE_THRESHOLD,
                )
            )
        return outdated