From bd611ac59f7c4db885a2f8631ef0bcdcd1901ca0 Mon Sep 17 00:00:00 2001 From: Johan Lundberg Date: Thu, 2 Apr 2015 10:43:33 +0200 Subject: Init --- .gitignore | 46 + COPYING | 202 ++ MANIFEST.in | 5 + README | 10 + aclcheck_cmdline.py | 58 + aclgen.py | 218 ++ def/NETWORK.net | 119 + def/SERVICES.svc | 62 + definate.py | 293 ++ definate/COPYING | 202 ++ definate/README | 90 + definate/__init__.py | 26 + definate/definate.yaml | 36 + definate/definition_filter.py | 164 ++ definate/dns_generator.py | 88 + definate/file_filter.py | 121 + definate/filter_factory.py | 104 + definate/generator.py | 56 + definate/generator_factory.py | 57 + definate/global_filter.py | 68 + definate/yaml_validator.py | 87 + doc/README.txt | 9 + doc/naming_definitions.txt | 40 + doc/policy_format.txt | 298 ++ doc/quick_start.txt | 48 + filters/.save | 0 filters/sample_srx.srx | 73 + filters/sample_tug_wlc_fw.acl | 49 + filters/sample_tug_wlc_fw.asa | 27 + filters/sample_tug_wlc_fw.demo | 55 + filters/sample_tug_wlc_fw.html | 55 + filters/sample_tug_wlc_fw.ipt | 28 + filters/sample_tug_wlc_fw.jcl | 62 + filters/sample_tug_wlc_fw.srx | 96 + lib/COPYING | 202 ++ lib/PKG-INFO | 18 + lib/README | 10 + lib/__init__.py | 31 + lib/aclcheck.py | 302 ++ lib/aclgenerator.py | 418 +++ lib/cisco.py | 744 +++++ lib/ciscoasa.py | 454 +++ lib/demo.py | 241 ++ lib/html.py | 233 ++ lib/ipset.py | 200 ++ lib/iptables.py | 789 +++++ lib/juniper.py | 727 +++++ lib/junipersrx.py | 448 +++ lib/nacaddr.py | 250 ++ lib/naming.py | 502 ++++ lib/packetfilter.py | 348 +++ lib/policy.py | 1821 ++++++++++++ lib/policyreader.py | 245 ++ lib/port.py | 55 + lib/setup.py | 39 + lib/speedway.py | 50 + make_dist.sh | 19 + policies/includes/untrusted-networks-blocking.inc | 18 + policies/sample_srx.pol | 26 + policies/sample_tug_wlc_fw.pol | 36 + setup.py | 43 + third_party/__init__.py | 0 third_party/ipaddr.py | 1951 ++++++++++++ third_party/ply/__init__.py | 4 + third_party/ply/lex.py | 1058 +++++++ third_party/ply/yacc.py | 3276 +++++++++++++++++++++ tools/cgrep.py | 80 + tools/get-country-zones.pl | 64 + 68 files changed, 17654 insertions(+) create mode 100644 .gitignore create mode 100644 COPYING create mode 100644 MANIFEST.in create mode 100644 README create mode 100755 aclcheck_cmdline.py create mode 100755 aclgen.py create mode 100644 def/NETWORK.net create mode 100644 def/SERVICES.svc create mode 100755 definate.py create mode 100644 definate/COPYING create mode 100644 definate/README create mode 100644 definate/__init__.py create mode 100644 definate/definate.yaml create mode 100755 definate/definition_filter.py create mode 100755 definate/dns_generator.py create mode 100755 definate/file_filter.py create mode 100755 definate/filter_factory.py create mode 100755 definate/generator.py create mode 100755 definate/generator_factory.py create mode 100755 definate/global_filter.py create mode 100755 definate/yaml_validator.py create mode 100644 doc/README.txt create mode 100644 doc/naming_definitions.txt create mode 100644 doc/policy_format.txt create mode 100644 doc/quick_start.txt create mode 100644 filters/.save create mode 100644 filters/sample_srx.srx create mode 100644 filters/sample_tug_wlc_fw.acl create mode 100644 filters/sample_tug_wlc_fw.asa create mode 100644 filters/sample_tug_wlc_fw.demo create mode 100644 filters/sample_tug_wlc_fw.html create mode 100644 filters/sample_tug_wlc_fw.ipt create mode 100644 filters/sample_tug_wlc_fw.jcl create mode 100644 filters/sample_tug_wlc_fw.srx create mode 100644 lib/COPYING create mode 100644 
lib/PKG-INFO create mode 100644 lib/README create mode 100644 lib/__init__.py create mode 100755 lib/aclcheck.py create mode 100755 lib/aclgenerator.py create mode 100644 lib/cisco.py create mode 100644 lib/ciscoasa.py create mode 100755 lib/demo.py create mode 100755 lib/html.py create mode 100644 lib/ipset.py create mode 100644 lib/iptables.py create mode 100644 lib/juniper.py create mode 100644 lib/junipersrx.py create mode 100644 lib/nacaddr.py create mode 100644 lib/naming.py create mode 100644 lib/packetfilter.py create mode 100644 lib/policy.py create mode 100644 lib/policyreader.py create mode 100755 lib/port.py create mode 100644 lib/setup.py create mode 100755 lib/speedway.py create mode 100755 make_dist.sh create mode 100644 policies/includes/untrusted-networks-blocking.inc create mode 100644 policies/sample_srx.pol create mode 100644 policies/sample_tug_wlc_fw.pol create mode 100755 setup.py create mode 100644 third_party/__init__.py create mode 100644 third_party/ipaddr.py create mode 100644 third_party/ply/__init__.py create mode 100644 third_party/ply/lex.py create mode 100644 third_party/ply/yacc.py create mode 100755 tools/cgrep.py create mode 100755 tools/get-country-zones.pl diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d23d1b5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +# C extensions +*.so +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +# Translations +*.mo +*.pot +# Django stuff: +*.log +# Sphinx documentation +docs/_build/ +# PyBuilder +target/ diff --git a/COPYING b/COPYING new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/COPYING @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..462b604 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +include COPYING +include filters/.save +include policies/sample.pol +include def/NETWORK.net +include def/SERVICES.svc diff --git a/README b/README new file mode 100644 index 0000000..6442579 --- /dev/null +++ b/README @@ -0,0 +1,10 @@ +Capirca is a system to develop and manage access control lists +for a variety of platforms. +It was developed by Google for internal use, and is now open source. + +Project home page: http://code.google.com/p/capirca/ + +Please send contributions to capirca-dev@googlegroups.com. + +Code should include unit tests and follow the Google Python style guide: +http://code.google.com/p/soc/wiki/PythonStyleGuide diff --git a/aclcheck_cmdline.py b/aclcheck_cmdline.py new file mode 100755 index 0000000..fca8bbf --- /dev/null +++ b/aclcheck_cmdline.py @@ -0,0 +1,58 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Command line interface to aclcheck library.""" + +__author__ = 'watson@google.com (Tony Watson)' + +from optparse import OptionParser +import sys +from lib import aclcheck +from lib import policy +from lib import naming + + +def main(): + usage = "usage: %prog [options] arg" + _parser = OptionParser(usage) + _parser.add_option('--definitions-directory', dest='definitions', + help='definitions directory', default='./def') + _parser.add_option('-p', '--policy-file', dest='pol', + help='policy file', default='./policies/sample.pol') + _parser.add_option('-d', '--destination', dest='dst', + help='destination IP', default='200.1.1.1') + _parser.add_option('-s', '--source', dest='src', + help='source IP', default='any') + _parser.add_option('--proto', '--protocol', dest='proto', + help='Protocol (tcp, udp, icmp, etc.)', default='tcp') + _parser.add_option('--dport', '--destination-port', dest='dport', + help='destination port', default='80') + _parser.add_option('--sport', '--source-port', dest='sport', + help='source port', default='1025') + (FLAGS, args) = _parser.parse_args() + #if FLAGS.help: + # print _parser.format_help() + + defs = naming.Naming(FLAGS.definitions) + policy_obj = policy.ParsePolicy(open(FLAGS.pol).read(), defs) + check = aclcheck.AclCheck(policy_obj, src=FLAGS.src, dst=FLAGS.dst, + sport=FLAGS.sport, dport=FLAGS.dport, + proto=FLAGS.proto) + print str(check) + +if __name__ == '__main__': + main() diff --git a/aclgen.py b/aclgen.py new file mode 100755 index 0000000..9743d9b --- /dev/null +++ b/aclgen.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This is a sample tool which will render policy +# files into usable iptables tables, cisco access lists or +# juniper firewall filters.
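+# +# Example usage, a sketch using the flag defaults defined below: +# +#   ./aclgen.py --poldir ./policies --def ./def -o ./filters +# +# This walks ./policies recursively, parses each .pol file against the +# naming definitions in ./def, and writes one filter file per platform +# named in the policy headers into ./filters.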
+ + +__author__ = 'watson@google.com (Tony Watson)' + +# system imports +import copy +import dircache +import datetime +from optparse import OptionParser +import os +import stat +import logging + +# compiler imports +from lib import naming +from lib import policy + +# renderers +from lib import cisco +from lib import ciscoasa +from lib import iptables +from lib import ipset +from lib import speedway +from lib import juniper +from lib import junipersrx +from lib import packetfilter +from lib import demo +from lib import html + +_parser = OptionParser() +_parser.add_option('-d', '--def', dest='definitions', + help='definitions directory', default='./def') +_parser.add_option('-o', dest='output_directory', help='output directory', + default='./filters') +_parser.add_option('', '--poldir', dest='policy_directory', + help='policy directory (incompatible with -p)', + default='./policies') +_parser.add_option('-p', '--pol', help='policy file (incompatible with poldir)', + dest='policy') +_parser.add_option('--debug', help='enable debug-level logging', dest='debug') +_parser.add_option('-s', '--shade_checking', help='Enable shade checking', + action="store_true", dest="shade_check", default=False) +_parser.add_option('-e', '--exp_info', type='int', action='store', + dest='exp_info', default=2, + help='Weeks in advance to notify that a term will expire') + +(FLAGS, args) = _parser.parse_args() + + +def load_and_render(base_dir, defs, shade_check, exp_info): + rendered = 0 + for dirfile in dircache.listdir(base_dir): + fname = os.path.join(base_dir, dirfile) + #logging.debug('load_and_render working with fname %s', fname) + if os.path.isdir(fname): + rendered += load_and_render(fname, defs, shade_check, exp_info) + elif fname.endswith('.pol'): + #logging.debug('attempting to render_filters on fname %s', fname) + rendered += render_filters(fname, defs, shade_check, exp_info) + return rendered + +def filter_name(source, suffix): + source = source.lstrip('./') + o_dir = '/'.join([FLAGS.output_directory] + source.split('/')[1:-1]) + fname = '%s%s' % (".".join(os.path.basename(source).split('.')[0:-1]), + suffix) + return os.path.join(o_dir, fname) + +def do_output_filter(filter_text, filter_file): + if not os.path.isdir(os.path.dirname(filter_file)): + os.makedirs(os.path.dirname(filter_file)) + output = open(filter_file, 'w') + if output: + filter_text = revision_tag_handler(filter_file, filter_text) + print 'writing %s' % filter_file + output.write(filter_text) + +def revision_tag_handler(fname, text): + # replace $Id:$ and $Date:$ tags with filename and date + timestamp = datetime.datetime.now().strftime('%Y/%m/%d') + new_text = [] + for line in text.split('\n'): + if '$Id:$' in line: + line = line.replace('$Id:$', '$Id: %s $' % fname) + if '$Date:$' in line: + line = line.replace('$Date:$', '$Date: %s $' % timestamp) + new_text.append(line) + return '\n'.join(new_text) + +def render_filters(source_file, definitions_obj, shade_check, exp_info): + count = 0 + [(jcl, acl, asa, ipt, ips, pf, spd, spk, srx, dem, htm)] = [ + (False, False, False, False, False, False, False, False, False, False, False)] + + pol = policy.ParsePolicy(open(source_file).read(), definitions_obj, + shade_check=shade_check) + + for header in pol.headers: + if 'juniper' in header.platforms: + jcl = copy.deepcopy(pol) + if 'cisco' in header.platforms: + acl = copy.deepcopy(pol) + if 'ciscoasa' in header.platforms: + asa = copy.deepcopy(pol) + if 'iptables' in header.platforms: + ipt = copy.deepcopy(pol) + if 'ipset' in 
header.platforms: + ips = copy.deepcopy(pol) + if 'packetfilter' in header.platforms: + pf = copy.deepcopy(pol) + if 'speedway' in header.platforms: + spd = copy.deepcopy(pol) + # SRX needs to be un-optimized for correct building of the address book + # entries. + if 'srx' in header.platforms: + unoptimized_pol = policy.ParsePolicy(open(source_file).read(), + definitions_obj, optimize=False) + srx = copy.deepcopy(unoptimized_pol) + if 'demo' in header.platforms: + dem = copy.deepcopy(pol) + if 'html' in header.platforms: + htm = copy.deepcopy(pol) + if jcl: + fw = juniper.Juniper(jcl, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if acl: + fw = cisco.Cisco(acl, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if asa: + fw = ciscoasa.CiscoASA(asa, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if ipt: + fw = iptables.Iptables(ipt, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if ips: + fw = ipset.Ipset(ips, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if pf: + fw = packetfilter.PacketFilter(pf, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if spd: + fw = speedway.Speedway(spd, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if srx: + fw = junipersrx.JuniperSRX(srx, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if dem: + fw = demo.Demo(dem, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + if htm: + fw = html.HTML(htm, exp_info) + do_output_filter(str(fw), filter_name(source_file, fw._SUFFIX)) + count += 1 + + return count + +def main(): + if not FLAGS.definitions: + _parser.error('no definitions supplied') + defs = naming.Naming(FLAGS.definitions) + if not defs: + print 'problem loading definitions' + return + + count = 0 + if FLAGS.policy_directory: + count = load_and_render(FLAGS.policy_directory, defs, FLAGS.shade_check, + FLAGS.exp_info) + + elif FLAGS.policy: + count = render_filters(FLAGS.policy, defs, FLAGS.shade_check, + FLAGS.exp_info) + + print '%d filters rendered' % count + + +if __name__ == '__main__': + # some sanity checking + if FLAGS.policy_directory and FLAGS.policy: + # When parsing a single file, ignore default path of policy_directory + FLAGS.policy_directory = False + if not (FLAGS.policy_directory or FLAGS.policy): + raise ValueError('must provide policy or policy_directory') + + # enable debugging + if FLAGS.debug: + logging.basicConfig(level=logging.DEBUG) + + # run run run run run away + main() diff --git a/def/NETWORK.net b/def/NETWORK.net new file mode 100644 index 0000000..726fa72 --- /dev/null +++ b/def/NETWORK.net @@ -0,0 +1,119 @@ +# +# Sample naming definitions for network objects +# +RFC1918 = 10.0.0.0/8 # non-public + 172.16.0.0/12 # non-public + 192.168.0.0/16 # non-public + +INTERNAL = RFC1918 + +LOOPBACK = 127.0.0.0/8 # loopback + ::1/128 # ipv6 loopback + +RFC_3330 = 169.254.0.0/16 # special use IPv4 addresses - netdeploy + +LINKLOCAL = FE80::/10 # IPv6 link-local + +SITELOCAL = FEC0::/10 # IPv6 site-local + +MULTICAST = 224.0.0.0/4 # IP multicast + FF00::/8 # IPv6 multicast + +CLASS-E = 240.0.0.0/4 + +RESERVED = 0.0.0.0/8 # reserved + RFC1918 + LOOPBACK + RFC_3330 + MULTICAST + CLASS-E + 0000::/8 # reserved by IETF + 0100::/8 # reserved by IETF + 0200::/7 #
reserved by IETF + 0400::/6 # reserved by IETF + 0800::/5 # reserved by IETF + 1000::/4 # reserved by IETF + 4000::/3 # reserved by IETF + 6000::/3 # reserved by IETF + 8000::/3 # reserved by IETF + A000::/3 # reserved by IETF + C000::/3 # reserved by IETF + E000::/4 # reserved by IETF + F000::/5 # reserved by IETF + F800::/6 # reserved by IETF + FC00::/7 # unique local unicast + FE00::/9 # reserved by IETF + LINKLOCAL # link local unicast + SITELOCAL # IPv6 site-local + +# http://www.team-cymru.org/Services/Bogons/bogon-bn-agg.txt +# 22-Apr-2011 +BOGON = 0.0.0.0/8 + 192.0.0.0/24 + 192.0.2.0/24 + 198.18.0.0/15 + 198.51.100.0/24 + 203.0.113.0/24 + MULTICAST + CLASS-E + 3FFE::/16 # 6bone + 5F00::/8 # 6bone + 2001:DB8::/32 # IPv6 documentation prefix + +GOOGLE_PUBLIC_DNS_ANYCAST = 8.8.4.4/32 # IPv4 Anycast + 8.8.8.8/32 # IPv4 Anycast + 2001:4860:4860::8844/128 # IPv6 Anycast + 2001:4860:4860::8888/128 # IPv6 Anycast +GOOGLE_DNS = GOOGLE_PUBLIC_DNS_ANYCAST + + +# The following are sample entries intended for use in the included +# sample policy file. These should be removed. + +DNS_SERVERS = 109.105.96.141/32 # resolver1.nordu.net + 109.105.96.142/32 # resolver2.nordu.net + +NTP_SERVERS = 109.105.96.132/32 # ntp1.nordu.net + 109.105.96.133/32 # ntp2.nordu.net + +SYSLOG_SERVERS = 109.105.113.13/32 # syslog1.nordu.net + 109.105.113.86/32 # syslog2.nordu.net + 2001:948:4:2::13/128 # syslog1.nordu.net + 2001:948:4:3::86/128 # syslog2.nordu.net + +TACACS_SERVERS = 109.105.113.42/32 # statler.nordu.net + 109.105.113.85/32 # waldorf.nordu.net + 2001:948:4:2::42/128 # statler.nordu.net + 2001:948:4:3::85/128 # waldorf.nordu.net + +RADIUS_SERVERS = 109.105.111.40/32 # radius1.nordu.net + 109.105.111.40/32 # radius1.nordu.net + 2001:948:4:6::40/128 # radius1.nordu.net + 2001:948:4:a::40/128 # radius2.nordu.net + +KERBEROS_SERVERS = 109.105.113.8/32 # kdc1.nordu.net + 109.105.113.10/32 # kdc2.nordu.net + 109.105.113.87/32 # kdc3.nordu.net + 2001:948:4:2::8/128 # kdc1.nordu.net + 2001:948:4:2::10/128 # kdc2.nordu.net + 2001:948:4:2::87/128 # kdc3.nordu.net + +NORDUNET_AGGREGATE = 109.105.96.0/19 + 193.10.252.0/24 + 193.10.254.0/24 + 193.11.3.0/24 + 194.68.13.0/24 + +NDN_TUG_WLC_NET = 109.105.104.16/28 # Wireless Controller net + +NDN_TUG_NET = 109.105.104.0/24 # Office net TUG + +NDN_KAS_WLC_NET = 109.105.106.16/28 # Wireless Controller net + +NDN_KAS_NET = 109.105.106.0/24 # Office net KAS + +SUNET_PILSNET = 192.36.125.0/24 # Pilsnet TUG + +SUNET_AP_STATICS = 130.242.82.30/32 # AP FRE POP + 130.242.121.137/32 # AP LULE POP + diff --git a/def/SERVICES.svc b/def/SERVICES.svc new file mode 100644 index 0000000..ce6d614 --- /dev/null +++ b/def/SERVICES.svc @@ -0,0 +1,62 @@ +# +# Sample naming service definitions +# +WHOIS = 43/udp +SSH = 22/tcp +TELNET = 23/tcp +SMTP = 25/tcp +MAIL_SERVICES = SMTP + ESMTP + SMTP_SSL +TIME = 37/tcp 37/udp +TACACS = 49/tcp +DNS = 53/tcp 53/udp +BOOTPS = 67/udp # BOOTP server +BOOTPC = 68/udp # BOOTP client +DHCP = BOOTPS + BOOTPC +TFTP = 69/tcp 69/udp +HTTP = 80/tcp +WEB_SERVICES = HTTP HTTPS +POP3 = 110/tcp +RPC = 111/udp +IDENT = 113/tcp 113/udp +NNTP = 119/tcp +NTP = 123/tcp 123/udp +MS_RPC_EPMAP = 135/udp 135/tcp +MS_137 = 137/udp +MS_138 = 138/udp +MS_139 = 139/tcp +IMAP = 143/tcp +SNMP = 161/udp +SNMP_TRAP = 162/udp +BGP = 179/tcp +IMAP3 = 220/tcp +LDAP = 389/tcp +LDAP_SERVICE = LDAP + LDAPS +HTTPS = 443/tcp +MS_445 = 445/tcp +SMTP_SSL = 465/tcp +IKE = 500/udp +SYSLOG = 514/udp +RTSP = 554/tcp +ESMTP = 587/tcp +LDAPS = 636/tcp +IMAPS = 993/tcp +POP_SSL = 995/tcp
+HIGH_PORTS = 1024-65535/tcp 1024-65535/udp +MSSQL = 1433/tcp +MSSQL_MONITOR = 1434/tcp +RADIUS = 1812/tcp 1812/udp +HSRP = 1985/udp +NFSD = 2049/tcp 2049/udp +NETFLOW = 2056/udp +SQUID_PROXY = 3128/tcp +MYSQL = 3306/tcp +RDP = 3389/tcp +IPSEC = 4500/udp +POSTGRESQL = 5432/tcp +TRACEROUTE = 33434-33534/udp + + diff --git a/definate.py b/definate.py new file mode 100755 index 0000000..abea8f6 --- /dev/null +++ b/definate.py @@ -0,0 +1,293 @@ +#!/usr/bin/python +# +# Copyright 2012 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates network definitions for use with Capirca. + +Definate is a framework to generate definitions for the automatic network policy +generation framework. The definitions are generated based on a configuration +file. +""" + +__author__ = 'msu@google.com (Martin Suess)' + +import logging +from optparse import OptionParser +import os +import yaml + +from definate import definition_filter +from definate import file_filter +from definate import filter_factory +from definate import generator_factory +from definate import global_filter +from definate import yaml_validator + + +_parser = OptionParser() +_parser.add_option('-c', '--config', dest='configuration', + help='configuration file', + default='./definate/definate.yaml') +_parser.add_option('-d', '--def', dest='definitions', + help='definitions directory', default='./def') +_parser.add_option('--debug', help='enable debug-level logging', dest='debug') +(FLAGS, args) = _parser.parse_args() + + +class Error(Exception): + """Base error class.""" + + +class DefinateGenerationError(Error): + """Exception to use when Definate fails generating output.""" + + +class Definate(object): + """Generates the network definition files.""" + + def __init__(self): + """Initializer.""" + self._generator_factory = generator_factory.GeneratorFactory() + self._filter_factory = filter_factory.FilterFactory() + self._yaml_validator = yaml_validator.YamlValidator() + + def _ReadConfiguration(self, conf_path): + """Reads the configuration from a YAML file. + + Args: + conf_path: String representing the path to the configuration file. + + Raises: + DefinateConfigError: The configuration cannot be read. + + Returns: + YAML generated configuration structure (lists and dictionaries containing + configuration values). + """ + try: + config_file = file(conf_path, 'r') + except IOError as e: + raise yaml_validator.DefinateConfigError('Unable to open config: %s' % e) + config = yaml.safe_load(config_file) + config_file.close() + return config + + def GenerateDefinitions(self, config_path, def_path): + """Generate all network definition files based on the config passed in. + + Args: + config_path: Full path to the YAML configuration file as a string. + See YAML configuration file for reference: README + def_path: Full path to the directory where the network definitions are + stored as string. + + Raises: + DefinateConfigError: The configuration that has been passed in is not + sane. 
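+ + Example (a sketch mirroring main() below, with the shipped flag + defaults): + + definate = Definate() + definate.GenerateDefinitions('./definate/definate.yaml', './def')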
+ """ + yaml_structure = { + 'global': {}, + 'files': [{ + 'path': 'str', + 'generators': [{ + 'name': 'str', + 'definitions': [{ + 'name': 'str', + 'networks': [], + }], + }], + }], + } + + config = self._ReadConfiguration(config_path) + self._yaml_validator.CheckConfiguration(config, yaml_structure) + logging.info('Configuration check: Done. Global config appears to be sane.') + + additional_args = {'def_path': def_path} + + global_config = config['global'] + + global_container = global_filter.Container() + # TODO(msu): Maybe add sanity check filter which is always run? + global_container = self._RunFilter( + filter_factory.GLOBAL_FILTER, filter_factory.PRE_FILTERS, + global_config.get('pre_filters', []), + global_container, filterargs=additional_args) + + for file_definition in config['files']: + relative_path = file_definition['path'] + file_path = os.path.join(def_path, relative_path) + logging.info('Generating file: %s', file_path) + + file_header = file_definition.get('file_header', []) + if file_header: + file_header = ['# %s' % line for line in file_header] + file_header.append('\n') + file_container = file_filter.Container( + lines=file_header, absolute_path=file_path, + relative_path=relative_path) + + file_container = self._RunFilter( + filter_factory.FILE_FILTER, filter_factory.PRE_FILTERS, + global_config.get('per_file_pre_filters', []), + file_container, filterargs=additional_args) + file_container = self._RunFilter( + filter_factory.FILE_FILTER, filter_factory.PRE_FILTERS, + file_definition.get('pre_filters', []), + file_container, filterargs=additional_args) + + file_container = self._GenerateFile( + file_definition['generators'], global_config, file_container) + + global_container.absolute_paths.append(file_path) + global_container.relative_paths.append(relative_path) + + # TODO(msu): Maybe add some sanity check filter which is always run? + file_container = self._RunFilter( + filter_factory.FILE_FILTER, filter_factory.POST_FILTERS, + file_definition.get('post_filters', []), + file_container, filterargs=additional_args) + file_container = self._RunFilter( + filter_factory.FILE_FILTER, filter_factory.POST_FILTERS, + global_config.get('per_file_post_filters', []), + file_container, filterargs=additional_args) + + # TODO(msu): Maybe add some sanity check filter which is always run? + global_container = self._RunFilter( + filter_factory.GLOBAL_FILTER, filter_factory.POST_FILTERS, + global_config.get('post_filters', []), + global_container, filterargs=additional_args) + + def _GenerateFile(self, generators, global_config, file_container): + """Generate one network definition file. + + Args: + generators: Configuration based on which the file is generated. + global_config: Global section of the configuration. + file_container: Dictionary representing the container used to hold all + information for one definition file. + + Returns: + Container dictionary as defined in file_filter module. + + Raises: + DefinateGenerationError: In case one of the generated definitions does not + contain any nodes. + DefinateConfigError: In case the configuration is not well formed.
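+ + Example generators value, shaped like the shipped definate.yaml (the + hostname here is illustrative): + + [{'name': 'DnsGenerator', + 'definitions': [{'name': 'WWW_AUTOGEN', + 'networks': [{'names': ['www.google.com'], + 'types': ['A', 'AAAA']}]}]}]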
+ """ + for generator_config in generators: + generator = self._generator_factory.GetGenerator( + generator_config['name']) + logging.info('Running generator \"%s\" now.', generator_config['name']) + + for definition in generator_config['definitions']: + def_header = definition.get('header', []) + if def_header: + def_header = ['# %s' % line for line in def_header] + def_container = definition_filter.Container( + header=def_header, name=definition.get('name')) + logging.info('Generating definition: %s', definition.get('name')) + + def_container = self._RunFilter( + filter_factory.DEFINITION_FILTER, filter_factory.PRE_FILTERS, + global_config.get('per_definition_pre_filters', []), + def_container) + def_container = self._RunFilter( + filter_factory.DEFINITION_FILTER, filter_factory.PRE_FILTERS, + definition.get('pre_filters', []), + def_container) + + def_container.entries_and_comments = generator.GenerateDefinition( + definition.get('networks', []), global_config) + + if not def_container.entries_and_comments: + raise DefinateGenerationError( + 'Generator returned no nodes for this definition: %s' % ( + definition.get('name'))) + + # TODO(msu): Maybe add sanity check filter which is always run? + def_container = self._RunFilter( + filter_factory.DEFINITION_FILTER, filter_factory.POST_FILTERS, + definition.get('post_filters', []), + def_container) + def_container = self._RunFilter( + filter_factory.DEFINITION_FILTER, filter_factory.POST_FILTERS, + global_config.get('per_definition_post_filters', []), + def_container) + + if not def_container.string_representation: + # TODO(msu): Define what should happen if no/wrong filters have been + # applied and no output is generated. Discard? Warn? Write warning to + # file? + pass + else: + file_container.lines.extend(def_container.header) + file_container.lines.append(def_container.string_representation) + file_container.lines.append('') + + return file_container + + def _RunFilter(self, filter_type, sequence, filter_config, container, + filterargs=None): + """Checks filter config and runs filters specified depending on type. + + Args: + filter_type: Integer defining the filter type to use. Valid values are + specified in the filter_factory module. + sequence: String identifier for when the filter is run. Valid values are + specified in the filter_factory module. + filter_config: Configuration structure as defined in the YAML + configuration. + container: Container dictionary as a bucket for all necessary information + to pass from filter to filter. + filterargs: Optional argument dictionary that is passed to a filter. Note + that these args update (and potentially overwrite) previously configured + arguments from the YAML configuration. + + Returns: + Container dictionary that has been passed in. + + Raises: + DefinateConfigError: In case the configuration is not well formed. 
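+ + Example filter_config value, shaped like the shipped definate.yaml (the + empty 'args' dictionary is illustrative; the shipped filters take no + arguments): + + [{'name': 'SortFilter'}, {'name': 'AlignFilter', 'args': {}}]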
+ """ + if not filter_config: + logging.debug('Filter config has not been specified.') + return container + + if not filterargs: + filterargs = {} + + for filter_def in filter_config: + self._yaml_validator.CheckConfigurationItem(filter_def, 'name') + filter_name = filter_def['name'] + filter_args = filter_def.get('args', {}) + filter_args.update(filterargs) + fltr = self._filter_factory.GetFilter( + filter_type, filter_name, sequence) + logging.debug('Running filter \"%s\".', filter_name) + container = fltr.Filter(container, filter_args) + + return container + + +def main(): + if FLAGS.debug: + logging.basicConfig(level=logging.DEBUG) + definate = Definate() + definate.GenerateDefinitions(FLAGS.configuration, + FLAGS.definitions) + +if __name__ == '__main__': + main() diff --git a/definate/COPYING b/definate/COPYING new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/definate/COPYING @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/definate/README b/definate/README new file mode 100644 index 0000000..5c0823f --- /dev/null +++ b/definate/README @@ -0,0 +1,90 @@ +Definate is part of Capirca and is a system to develop and manage network +definitions that can be used in Capirca policies. +It was developed by Google for internal use, and is now open source. + +Project home page: http://code.google.com/p/capirca/ + +Please send contributions to capirca-dev@googlegroups.com. + +Code should include unit tests and follow the Google Python style guide: +http://code.google.com/p/soc/wiki/PythonStyleGuide + +================================================================================ + +Definate Configuration File + +Global section + This section contains directives global in scope. + + [pre|post]_filters: Optional list of global filters to run before + (pre_filters) or after (post_filters) the definition file generation. + Each filter may define a set of filter specific arguments in the 'args' + attribute. + 'name': Name of the filter used to lookup the filter class to use. + Valid values: There are currently no implementations. + per_file_[pre|post]_filters: Optional list of file filters run before + (pre_filters) or after (post_filters) the generation of each file. + The pre_filters here are run BEFORE the individual filters specified in + the files section and the post_filters are run AFTER the individual + filters specified in the files section. + For a list of possible filters and arguments, see the "Files section". + per_definition_[pre|post]_filters: Optional list of definition filters run + before (pre_filters) or after (post_filters) the generation of each + definition. + The pre_filters here are run BEFORE the individual filters specified in + the definitions section and the post_filters are run AFTER the individual + filters specified in the definitions section. + For a list of possible filters and arguments, see the "Definitions + section". + +Files section + This section contains a list of settings and configurations for each + file that gets generated. + + path: Path to the definitions file to be generated, relative to the + def path defined in the global section. + [pre|post]_filters: Optional list of file level filters to run before + (pre_filters) or after (post_filters) the file has been generated. 
+ Each filter may define a set of filter specific arguments in the + 'args' attribute. + 'name': Name of the filter used to lookup the filter class to use. + Valid values: + 'PrintFilter': Does not modify input. Just prints it. + 'WriteFileFilter': Writes files out locally. + file_header: List of strings that get printed in the beginning of the file. + generators: List of generator blocks. + +Generators section + name: The generator defines the source of the information. + Valid values: + 'DnsGenerator': For definitions generated based on hostnames with a simple + DNS resolver. Note that the resolver might not return all addresses used + for one hostname depending on the implemented DNS load balancing. + definitions: List of definition blocks. + +Definitions section + name: Name of the definition that gets generated. This name is used in the + definitions file and can be used in policies to reference the definition. + header: Optional list of header strings that will be printed before the + definition. + [pre|post]_filters: Optional list of definition level filters to run + before (pre_filters) or after (post_filters) the definition has been + generated. Each filter may define a set of filter specific arguments + in the 'args' attribute. + 'name': Name of the filter used to lookup the filter class to use. + Valid values: + 'SortFilter': Sort the input list and return one list containing + IPv4 sorted, IPv6 sorted. + 'AlignFilter': Take the input list and definition name and create + nicely formatted output. + networks: Contains a list of descriptions about how to get a complete set + of networks/IPs for one definition. This section contains + generator-specific configuration directives. + +Network directives for DnsGenerator + +Networks section + names: List of hostnames that should be resolved. + types: List of types the output should be filtered for. Valid values: + 'A': Filter for IPv4 addresses. + 'AAAA': Filter for IPv6 addresses. diff --git a/definate/__init__.py b/definate/__init__.py new file mode 100644 index 0000000..e6069ec --- /dev/null +++ b/definate/__init__.py @@ -0,0 +1,26 @@ +# +# Network definition generator libraries +# +# definate/__init__.py +# +# This package is intended to provide functionality to generate lists of network +# definitions that can be used within other network definitions and policies of +# Capirca. +# +# from definate import generator +# from definate import generator_factory +# from definate import dns_generator +# from definate import filter_factory +# from definate import global_filter +# from definate import file_filter +# from definate import definition_filter +# from definate import yaml_validator +# + +__version__ = '1.0.0' + +__all__ = ['generator', 'generator_factory', 'dns_generator', + 'filter_factory', 'global_filter', 'file_filter', + 'definition_filter', 'yaml_validator'] + +__author__ = 'Martin Suess (msu@google.com)' diff --git a/definate/definate.yaml b/definate/definate.yaml new file mode 100644 index 0000000..9e9690e --- /dev/null +++ b/definate/definate.yaml @@ -0,0 +1,36 @@ +# Definate configuration +# For usage information, see README file. +global: + per_file_post_filters: + - name: 'WriteFileFilter' + per_definition_post_filters: + - name: 'SortFilter' + - name: 'AlignFilter' +files: + - path: 'AUTOGEN.net' + file_header: + - 'This file is autogenerated. Please do not edit it manually.'
+      - 'Instead run Definate: ./definate.py'
+    generators:
+      - name: 'DnsGenerator'
+        definitions:
+          - name: 'WWW_AUTOGEN'
+            header:
+              - 'WWW Clusters'
+              - 'Generated from DNS names (best effort)'
+            networks:
+              - names:
+                  - 'www.google.com'
+                  - 'www.gmail.com'
+                types: ['A', 'AAAA']
+          - name: 'NS_AUTOGEN'
+            header:
+              - 'NS Clusters'
+              - 'Generated from DNS names (best effort)'
+            networks:
+              - names:
+                  - 'ns1.google.com'
+                  - 'ns2.google.com'
+                  - 'ns3.google.com'
+                  - 'ns4.google.com'
+                types: ['A', 'AAAA']
diff --git a/definate/definition_filter.py b/definate/definition_filter.py
new file mode 100755
index 0000000..492afc7
--- /dev/null
+++ b/definate/definition_filter.py
@@ -0,0 +1,164 @@
+#!/usr/bin/python
+#
+# Copyright 2012 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module that holds all definition-level filter classes of Definate."""
+
+__author__ = 'msu@google.com (Martin Suess)'
+
+
+import logging
+
+
+class Container(object):
+  """Container class to hold all information to be passed between filters."""
+
+  def __init__(self, header=None, name='', entries_and_comments=None,
+               string_representation=''):
+    """Initializer.
+
+    Args:
+      header: Optional list of strings to be added as headers.
+      name: Optional string representing the name of the definition.
+      entries_and_comments: Optional list of tuples (entries, comments) which
+        hold all entries for one definition as well as comments.
+      string_representation: Optional string holding the string representation
+        of the definition (typically used as output e.g. in a file in the end).
+    """
+    self.header = header if header else []
+    self.name = name
+    self.entries_and_comments = (
+        entries_and_comments if entries_and_comments else [])
+    self.string_representation = string_representation
+
+
+class DefinitionFilter(object):
+  """Abstract class defining the interface for the filter chain objects."""
+
+  def Filter(self, container, args):
+    """Interface to filter or modify data passed into it.
+
+    Args:
+      container: Container object which holds all information for one
+        definition. See Container class for details.
+      args: Dictionary of arguments depending on the actual filter in use.
+
+    Raises:
+      NotImplementedError: In any case since this is not implemented and needs
+        to be defined by subclasses.
+    """
+    raise NotImplementedError(
+        'This is an interface only. Implemented by subclasses.')
+
+
+class SortFilter(DefinitionFilter):
+  """DefinitionFilter implementation which sorts all entries for nice output."""
+
+  def Filter(self, container, unused_args):
+    """Filter method that sorts all entries in a definition for nice output.
+
+    The filter sorts all entries in ascending order:
+    - IPv4 networks
+    - IPv6 networks
+
+    Args:
+      container: Container object which holds all information for one
+        definition. See Container class for details.
+      unused_args: No extra arguments required by this filter implementation.
+
+    Returns:
+      Container object that has been passed in.
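+
+    Example (illustrative values, sorted IPv4 first, then IPv6):
+      before: [(IPv6Network('::1/128'), 'v6 loopback'),
+               (IPv4Network('10.0.0.0/8'), 'rfc1918')]
+      after:  [(IPv4Network('10.0.0.0/8'), 'rfc1918'),
+               (IPv6Network('::1/128'), 'v6 loopback')]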
+ """ + ipv4_nodes = [] + ipv6_nodes = [] + + for node, comment in container.entries_and_comments: + if node.version == 4: + ipv4_nodes.append((node, comment)) + elif node.version == 6: + ipv6_nodes.append((node, comment)) + else: + logging.warn('Unsupported address version detected: %s', node.version) + + ipv4_nodes = self._RemoveDuplicateNetworks(ipv4_nodes) + ipv6_nodes = self._RemoveDuplicateNetworks(ipv6_nodes) + + ipv4_nodes.sort() + ipv6_nodes.sort() + + container.entries_and_comments = ipv4_nodes + ipv6_nodes + return container + + def _RemoveDuplicateNetworks(self, network_list): + """Method to remove duplicate networks from the network list. + + Args: + network_list: List of node/comment tuples where node is an IPNetwork + object and comment is a string. + + Returns: + The same list of networks and comments minus duplicate entries. + """ + result_list = [] + result_dict = {} + for node, comment in network_list: + result_dict[str(node)] = (node, comment) + for node in result_dict: + result_list.append(result_dict[node]) + return result_list + + +class AlignFilter(DefinitionFilter): + """DefinitionFilter implementation which generates nicely aligned output.""" + + def Filter(self, container, unused_args): + """Filter method that aligns the entries in the output nicely. + + This code formats the entries_and_comments by figuring out the + left-justification from the definition name ('name'), and padding the + left justification of the comments to 3 spaces after the longest entry + length. + + In order to do this succinctly, without adding strings together, we use a + format string that we replace twice. Once for the (left|right) + justification bounds, and again with the final values. + + Args: + container: Container object which holds all information for one + definition. See Container class for details. + unused_args: No extra arguments required by this filter implementation. + + Returns: + Container object that has been passed in. + """ + first_format_string = '%%s = %%%is# %%s' + format_string = '%%%is%%%is# %%s' + + max_len = max(len(str(e)) for e, _ in container.entries_and_comments) + value_justification = -1 * (max_len + 3) + column_justification = len(container.name) + 3 # 3 for ' = ' + + first_format_string %= value_justification + format_string %= (column_justification, value_justification) + + entry, comment = container.entries_and_comments[0] + first_string = first_format_string % (container.name, entry, comment) + output = [first_string] + + for entry, comment in container.entries_and_comments[1:]: + output.append(format_string % ('', entry, comment)) + + container.string_representation = '\n'.join(output) + return container diff --git a/definate/dns_generator.py b/definate/dns_generator.py new file mode 100755 index 0000000..bb17c71 --- /dev/null +++ b/definate/dns_generator.py @@ -0,0 +1,88 @@ +#!/usr/bin/python +# +# Copyright 2012 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Generator for DNS based network definitions.""" + +__author__ = 'msu@google.com (Martin Suess)' + + +import logging +import socket + +from third_party import ipaddr +import generator + + +class DnsGeneratorError(Exception): + """Exception to use when DnsGenerator fails.""" + + +class DnsGenerator(generator.Generator): + """Generator implementation for network definitions based on DNS.""" + + SUPPORTED_TYPES = ['A', 'AAAA'] + + def GenerateDefinition(self, config, unused_global_config): + """Generates a list of all nodes in a network definition. + + This method basically processes all the configuration which is + hierarchically below "networks" in the "definitions" section in the + configuration file to generate a list of all nodes in that definition. + + Args: + config: YAML configuration structure (dictionaries, lists and strings) + representing the "networks" section in "definitions" of the + configuration file. + unused_global_config: YAML configuration structure (dictionaries, lists + and strings) representing the "global" section of the configuration + file. + + Returns: + Tuples of IPNetwork objects and string comments representing all the nodes + in one definition. + + Raises: + DefinateConfigError: The configuration is not well formed. + DnsGeneratorError: There is a problem generating the output. + """ + nodes = [] + yaml_structure = { + 'names': ['str'], + 'types': ['str'], + } + for network in config: + self._yaml_validator.CheckConfiguration(network, yaml_structure) + for typ in network['types']: + if typ not in self.SUPPORTED_TYPES: + raise DnsGeneratorError('Unsupported DNS type found: %s' % typ) + for name in network['names']: + try: + addr_list = socket.getaddrinfo(name, None) + except socket.gaierror: + raise DnsGeneratorError('Hostname not found: %s' % name) + for family, _, _, _, sockaddr in addr_list: + ip_addr = None + if family == socket.AF_INET and 'A' in network['types']: + # sockaddr = (address, port) + ip_addr = ipaddr.IPv4Network(sockaddr[0]) + elif family == socket.AF_INET6 and 'AAAA' in network['types']: + # sockaddr = (address, port, flow info, scope id) + ip_addr = ipaddr.IPv6Network(sockaddr[0]) + else: + logging.debug('Skipping unknown AF \'%d\' for: %s', family, name) + if ip_addr: + nodes.append((ip_addr, name)) + return nodes diff --git a/definate/file_filter.py b/definate/file_filter.py new file mode 100755 index 0000000..021cc09 --- /dev/null +++ b/definate/file_filter.py @@ -0,0 +1,121 @@ +#!/usr/bin/python +# +# Copyright 2012 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Module that holds all file-level filter classes of Definate.""" + +__author__ = 'msu@google.com (Martin Suess)' + + +import logging + + +class Error(Exception): + """Base error class.""" + + +class FileError(Error): + """Exception to use when file handling files.""" + + +class Container(object): + """Container class to hold all information to be passed between filters.""" + + def __init__(self, lines=None, relative_path='', absolute_path=''): + """Initializer. + + Args: + lines: Optional list of strings which will be added as lines. + E.g. the file header can be added here directly. + relative_path: Optional string to specify the path of the file relative to + the location of the definition directory (e.g. 'AUTOGEN.net'). + absolute_path: Optional string to specify the absolute path of the local + file to be written (e.g. '/tmp/AUTOGEN.net') or if a SCM software is + used it can refer to the full path there + (e.g. '//depot/def/AUTOGEN.net'). + """ + self.lines = lines if lines else [] + self.absolute_path = absolute_path + self.relative_path = relative_path + + +class FileFilter(object): + """Abstract class defining the interface for the filter chain objects.""" + + def Filter(self, container, args): + """Interface to filter or modify data passed into it. + + Args: + container: Container object which holds all information for one definition + file. See Container class for details. + args: Dictionary of arguments depending on the actual filter in use. + + Raises: + NotImplementedError: In any case since this is not implemented an needs + to be defined by subclasses. + """ + raise NotImplementedError( + 'This is an interface only. Implemented by subclasses.') + + +class PrintFilter(FileFilter): + """FileFilter implementation which simply logs the file content.""" + + def Filter(self, container, unused_args): + """Filter method that prints the content of the file to stdout. + + Args: + container: Container object which holds all information for one definition + file. See Container class for details. + unused_args: No extra arguments required by this filter implementation. + + Returns: + Container object that has been passed in. + """ + print '# File "%s"' % container.absolute_path + print '\n'.join(container.lines) + return container + + +class WriteFileFilter(FileFilter): + """FileFilter implementation which writes the content into a file.""" + + def Filter(self, container, unused_args): + """Filter method that writes the content of the file into a file. + + Args: + container: Container object which holds all information for one definition + file. See Container class for details. + unused_args: No extra arguments required by this filter implementation. + + Returns: + Container object that has been passed in. + """ + try: + f = file(container.absolute_path, 'w') + except IOError as e: + raise FileError('File "%s" could not be opened: %s' % ( + container.absolute_path, e)) + + try: + f.write('\n'.join(container.lines)) + except IOError as e: + raise FileError('File "%s" could not be written: %s' % ( + container.absolute_path, e)) + else: + f.close() + + logging.info('Wrote file: %s', container.absolute_path) + return container diff --git a/definate/filter_factory.py b/definate/filter_factory.py new file mode 100755 index 0000000..1ad1d3f --- /dev/null +++ b/definate/filter_factory.py @@ -0,0 +1,104 @@ +#!/usr/bin/python +# +# Copyright 2012 Google Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functionality to allow easily retrieving certain filter objects.""" + +__author__ = 'msu@google.com (Martin Suess)' + + +import definition_filter +import file_filter + +DEFINITION_FILTER = 1 +FILE_FILTER = 2 +GLOBAL_FILTER = 3 + +PRE_FILTERS = 'PreFilters' +POST_FILTERS = 'PostFilters' + + +class Error(Exception): + """Base error class.""" + + +class FilterIdentificationError(Error): + """Exception to use when FilterFactory fails to identify the Filter.""" + + +class FilterFactory(object): + """Functionality to get a filter object easily based on its name. + + This class can be initialized and the GetFilter() method allows retrieving a + specific filter based on the name of the filter and the scope (global, file + and definition). + """ + + def __init__(self): + """Initializer.""" + self._filters = { + DEFINITION_FILTER: { + 'PostFilters': { + 'SortFilter': definition_filter.SortFilter, + 'AlignFilter': definition_filter.AlignFilter, + }, + }, + FILE_FILTER: { + 'PostFilters': { + 'PrintFilter': file_filter.PrintFilter, + 'WriteFileFilter': file_filter.WriteFileFilter, + }, + }, + GLOBAL_FILTER: { + 'PreFilters': { + }, + 'PostFilters': { + }, + }, + } + + def GetFilter(self, scope, identifier, sequence): + """Returns a specific filter instance based on the identifier. + + Args: + scope: Type of filter to be returned. Valid types are listed as globals + in the beginning of this module. + identifier: String identifier for the filter to get. + sequence: String identifier for the sequence information to determine + when the filter should be applied. Valid values: + - 'PreFilters': Filters that are applied before processing the data + (e.g. before the definition is created). + - 'PostFilters': Filters that are applied after processing the data + (e.g. after the definition has been created). + + Raises: + FilterIdentificationError: If the filter cannot be identified. + + Returns: + Filter instance based on the identifier passed in. + """ + if scope not in self._filters: + raise FilterIdentificationError( + 'Filter scope \'%d\' could not be found in filters.' % scope) + if sequence not in self._filters[scope]: + raise FilterIdentificationError( + 'Filter sequence \'%s\' is not applicable to scope \'%d\'.' % ( + sequence, scope)) + filters = self._filters[scope][sequence] + if identifier not in filters: + raise FilterIdentificationError( + 'Filter \'%s\' could not be identified. Wrong scope (%d) or sequence' + ' (%s)?' % (identifier, scope, sequence)) + return filters[identifier]() diff --git a/definate/generator.py b/definate/generator.py new file mode 100755 index 0000000..22cad08 --- /dev/null +++ b/definate/generator.py @@ -0,0 +1,56 @@ +#!/usr/bin/python +# +# Copyright 2012 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module holding the abstract definition generator class."""
+
+__author__ = 'msu@google.com (Martin Suess)'
+
+
+import yaml_validator
+
+
+class Error(Exception):
+  """Base error class."""
+
+
+class GeneratorError(Error):
+  """Base Generator error class to inherit from in specific generators."""
+
+
+class Generator(object):
+  """Abstract class defining the interface for the definition generation."""
+
+  def __init__(self):
+    """Initializer."""
+    self._yaml_validator = yaml_validator.YamlValidator()
+
+  def GenerateDefinition(self, config, global_config):
+    """Interface to generate definitions based on a configuration passed in.
+
+    Classes inheriting from Generator should implement this interface by
+    parsing the configuration and generating a network definition based on it.
+    For reference, have a look at the already implemented classes.
+
+    Args:
+      config: Configuration necessary to generate one full definition.
+      global_config: Global configuration section.
+
+    Raises:
+      NotImplementedError: In any case since this is not implemented and needs
+        to be defined by subclasses.
+    """
+    raise NotImplementedError(
+        'This is an interface only. Implemented by subclasses.')
diff --git a/definate/generator_factory.py b/definate/generator_factory.py
new file mode 100755
index 0000000..c887cc1
--- /dev/null
+++ b/definate/generator_factory.py
@@ -0,0 +1,57 @@
+#!/usr/bin/python
+#
+# Copyright 2012 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functionality to allow easily retrieving the right definition generator."""
+
+__author__ = 'msu@google.com (Martin Suess)'
+
+
+import dns_generator
+
+
+class Error(Exception):
+  """Base error class."""
+
+
+class GeneratorIdentificationError(Error):
+  """Exception to use when GeneratorFactory fails to identify the Generator."""
+
+
+class GeneratorFactory(object):
+  """Functionality to get a definition generator easily based on its name."""
+
+  def __init__(self):
+    """Initializer."""
+    self._generators = {
+        'DnsGenerator': dns_generator.DnsGenerator,
+    }
+
+  def GetGenerator(self, identifier):
+    """Returns a specific generator instance based on the identifier.
+
+    Args:
+      identifier: String identifier for the generator to get.
+
+    Raises:
+      GeneratorIdentificationError: If the generator cannot be identified.
+
+    Returns:
+      Generator instance based on the identifier passed in.
+    """
+    if identifier not in self._generators:
+      raise GeneratorIdentificationError(
+          'Generator \'%s\' could not be identified.'
% identifier)
+    return self._generators[identifier]()
diff --git a/definate/global_filter.py b/definate/global_filter.py
new file mode 100755
index 0000000..5969443
--- /dev/null
+++ b/definate/global_filter.py
@@ -0,0 +1,68 @@
+#!/usr/bin/python
+#
+# Copyright 2012 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module that holds all global-level filter classes of Definate."""
+
+__author__ = 'msu@google.com (Martin Suess)'
+
+
+import yaml_validator
+
+
+class Error(Exception):
+  """Base error class."""
+
+
+class Container(object):
+  """Container class to hold all information to be passed between filters."""
+
+  def __init__(self, absolute_paths=None, relative_paths=None):
+    """Initializer.
+
+    Args:
+      absolute_paths: Optional list of strings to specify the full path of the
+        generated files
+        (e.g. ['//depot/def/AUTOGEN1.net', '/tmp/AUTOGEN2.net']).
+      relative_paths: Optional list of strings to specify the paths of the
+        generated files relative to the location of the definition directory
+        (e.g. ['AUTOGEN1.net']).
+    """
+    self.absolute_paths = absolute_paths if absolute_paths else []
+    self.relative_paths = relative_paths if relative_paths else []
+    self.changelist = ''
+
+
+class GlobalFilter(object):
+  """Abstract class defining the interface for the filter chain objects."""
+
+  def __init__(self):
+    """Initializer."""
+    self._yaml_validator = yaml_validator.YamlValidator()
+
+  def Filter(self, container, args):
+    """Interface to filter or modify data passed into it.
+
+    Args:
+      container: Container object which holds all global information.
+        See Container class for details.
+      args: Dictionary of arguments depending on the actual filter in use.
+
+    Raises:
+      NotImplementedError: In any case since this is not implemented and needs
+        to be defined by subclasses.
+    """
+    raise NotImplementedError(
+        'This is an interface only. Implemented by subclasses.')
diff --git a/definate/yaml_validator.py b/definate/yaml_validator.py
new file mode 100755
index 0000000..5aae74a
--- /dev/null
+++ b/definate/yaml_validator.py
@@ -0,0 +1,87 @@
+#!/usr/bin/python
+#
+# Copyright 2012 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tools to allow the verification of the YAML configuration for Definate.""" + +__author__ = 'msu@google.com (Martin Suess)' + + +class Error(Exception): + """Base error class.""" + + +class DefinateConfigError(Error): + """Exception to use when Definate fails reading the configuration.""" + + +class YamlValidator(object): + """Class to verify the sanity of a YAML configuration.""" + + def CheckConfigurationItem(self, dictionary, item, typ=None): + """Checks for the presence of an item in a dictionary. + + Args: + dictionary: Configuration part that should be checked. + item: Name of the key to check as string. + typ: Type of the value to check. Default is to not check the type. + + Raises: + DefinateConfigError: The configuration is not sane. + """ + if not dictionary or item not in dictionary: + raise DefinateConfigError('"%s" is not defined in config: %s' + % (item, dictionary)) + if typ and type(dictionary[item]) is not typ: + raise DefinateConfigError('Type of "%s" is %s, expected %s.' % + (item, str(type(dictionary[item])), str(typ))) + + def CheckConfiguration(self, config, structure, max_recursion_depth=30): + """Recursively checks the sanity and structure of the configuration. + + This method checks the sanity of the configuration structure for Definate + and raises a DefinateConfigError if the configuration is not sane. + + Args: + config: Dictionary generated from the YAML configuration file which should + be checked. + structure: Structure of the configuration against which should be checked. + max_recursion_depth: Defines the maximum amount of recursion cycles before + checking is aborted. Default is 30. + + Raises: + DefinateConfigError: The configuration is not sane. + """ + max_depth = max_recursion_depth - 1 + if max_depth <= 0: + raise DefinateConfigError('Maximum recursion depth reached. Please check ' + 'configuration manually.') + if type(structure) in [dict, list]: + for item in structure: + value = item + if type(structure) is dict: + value = structure[item] + + self.CheckConfigurationItem(config, item, typ=type(value)) + if type(value) is dict: + self.CheckConfiguration(config[item], value, max_depth) + elif type(value) is list: + for (i, list_value) in enumerate(value): + self.CheckConfiguration(config[item][i], list_value, max_depth) + elif type(structure) is type(config): + return + else: + raise DefinateConfigError('Type of "%s" is %s, expected %s.' % ( + config, str(type(config)), str(structure))) diff --git a/doc/README.txt b/doc/README.txt new file mode 100644 index 0000000..e69f8b9 --- /dev/null +++ b/doc/README.txt @@ -0,0 +1,9 @@ +The authoritative documentation is maintained at: +http://code.google.com/p/capirca/wiki/ + +The documentation in this directory is intended to provide text copies of the +wiki documentation, but may be updates less frequently and as a result may not +be as current as the wiki. + +-- +Paul (Tony) Watson diff --git a/doc/naming_definitions.txt b/doc/naming_definitions.txt new file mode 100644 index 0000000..f1c19a7 --- /dev/null +++ b/doc/naming_definitions.txt @@ -0,0 +1,40 @@ +Introduction +The naming definitions provide the network and service "address books" that are used in the creation of policy files. Naming definitions are usually stored in a single directory, and consist of two or more files. Each file must end in either a '.net' or '.svc' extension, specifying a network or services definitions files. + +Multiple network and service definitions files may be created. 
Multiple files may be used to facilitate grouping of related definitions, or to utilize filesystem permissions to restrict or permit the editing of files by specific groups.
+
+The use of a revision control system, such as perforce or subversion, is a recommended way to ensure historical change control and tracking of contributor changes.
+
+Format of Files
+Each network or service definition file has a very simple structure. A token is defined, followed by an equal sign, then followed by a definition and optional description field.
+
+For example, here is a service definition:
+
+DNS = 53/tcp # transfers
+      53/udp # queries
+Likewise, here is a network definition:
+
+INTERNAL = 192.168.0.0/16 # company DMZ networks
+           172.16.0.0/12 # company remote offices
+           10.0.0.0/8 # company production networks
+Nesting of tokens is also permitted. Below are examples of nested service and network definitions:
+
+HTTP = 80/tcp # common web
+HTTPS = 443/tcp # SSL web
+HTTP_8080 = 8080/tcp # web on non-standard port
+WEB_SERVICES = HTTP HTTP_8080 HTTPS # all our web services
+DB_SERVICES = 3306/tcp # allow db access
+              HTTPS # and SSL access
+NYC_NETWORK = 200.1.1.0/24 # New York office
+ATL_NETWORK = 200.2.1.0/24 # Atlanta office
+DEN_NETWORK = 200.5.1.0/24 # Denver office
+REMOTE_OFFICES = NYC_NETWORK
+                 ATL_NETWORK
+                 DEN_NETWORK
+Network definitions can also contain a mix of both IPv4 and IPv6 addresses:
+
+LOOPBACK = 127.0.0.1/32 # loopback in IPv4
+           ::1/128 # loopback in IPv6
+LINKLOCAL = FE80::/10 # IPv6 link local address
+NYC_NETWORK = 172.16.1.0/24 # NYC IPv4
+              2620:0:10A1::/48 # NYC IPv6
diff --git a/doc/policy_format.txt b/doc/policy_format.txt
new file mode 100644
index 0000000..fd3f973
--- /dev/null
+++ b/doc/policy_format.txt
@@ -0,0 +1,298 @@
+Introduction
+The access control policy describes the desired network security policy through the use of a high-level language that uses keywords and tokens. Tokens are derived from the naming library's import of definition files.
+
+Basic Policy File Format
+A policy file consists of one or more filters, with each filter containing one or more terms. Each term specifies basic network filter information, such as addresses, ports, protocols and actions.
+
+A policy file consists of one or more header sections, with each header section being followed by one or more terms.
+
+A header section is typically used to specify a filter for a given direction, such as an INPUT filter on Iptables. A second header section will typically be included in the policy to specify the OUTPUT filter.
+
+In addition, the policy language supports "include files" which inject the text from the included file into the policy at the specified location. For more details, see the Includes section.
+
+Header Section
+Each filter is identified with a header section. The header section is used to define the type of filter, a descriptor or name, direction (if applicable) and format (ipv4/ipv6).
+
+For example, the following simple header defines a filter that can generate output for 'juniper', 'cisco' and 'iptables' formats.
+
+header {
+  comment:: "Example header for juniper and iptables filter."
+  target:: juniper edge-filter
+  target:: speedway INPUT
+  target:: iptables INPUT
+  target:: cisco edge-filter
+}
+Notice that the first target has two arguments: "juniper" and "edge-filter". The first argument specifies that the filter can be rendered for Juniper JCLs, and that the output filter should be called "edge-filter".
+
+The second target also has two arguments: "speedway" and "INPUT". Since Speedway/Iptables has specific inherent filters, such as INPUT, OUTPUT and FORWARD, the target specification for iptables usually points to one of these filters, although a custom chain can be specified (usually for combining with other filter rules through the use of a jump from one of the default filters).
+
+Likewise, the fourth target, "cisco", simply specifies the name of the access control list to be generated.
+
+Each target platform may have different possible arguments, which are detailed in the following subsections.
+
+Juniper
+The juniper header designation has the following format:
+
+target:: juniper [filter name] {inet|inet6|bridge}
+filter name: defines the name of the juniper filter.
+inet: specifies the output should be for IPv4 only filters. This is the default format.
+inet6: specifies the output be for IPv6 only filters.
+bridge: specifies the output should render a Juniper bridge filter.
+When inet or inet6 is specified, naming tokens with both IPv4 and IPv6 addresses will be rendered using only the specified addresses.
+
+The default format is inet, and is implied if no other argument is given.
+
+Cisco
+The cisco header designation has the following format:
+
+target:: cisco [filter name] {extended|standard|object-group|inet6|mixed}
+filter name: defines the name or number of the cisco filter.
+extended: specifies that the output should be an extended access list, and the filter name should be non-numeric. This is the default option.
+standard: specifies that the output should be a standard access list, and the filter name should be numeric and in the range of 1-99.
+object-group: specifies this is a cisco extended access list, and that object-groups should be used for ports and addresses.
+inet6: specifies the output be for IPv6 only filters.
+mixed: specifies output will include both IPv6 and IPv4 filters.
+When inet or inet6 is specified, naming tokens with both IPv4 and IPv6 addresses will be rendered using only the specified addresses.
+
+The default format is inet, and is implied if no other argument is given.
+
+Iptables
+NOTE: Iptables produces output that must be passed, line by line, to the 'iptables/ip6tables' command line. For 'iptables-restore' compatible output, please use the Speedway generator.
+
+The Iptables header designation has the following format:
+
+target:: iptables [INPUT|OUTPUT|FORWARD|custom] {ACCEPT|DROP} {truncatenames} {nostate} {inet|inet6}
+INPUT: apply the terms to the input filter.
+OUTPUT: apply the terms to the output filter.
+FORWARD: apply the terms to the forwarding filter.
+custom: create the terms under a custom filter name, which must then be linked/jumped to from one of the default filters (e.g. iptables -A INPUT -j custom)
+ACCEPT: specifies that the default policy on the filter should be 'accept'.
+DROP: specifies that the default policy on the filter should be to 'drop'.
+inet: specifies that the resulting filter should only render IPv4 addresses.
+inet6: specifies that the resulting filter should only render IPv6 addresses.
+truncatenames: specifies to abbreviate term names if necessary (see lib/iptables.py:CheckTermLength for abbreviation table)
+nostate: specifies to produce 'stateless' filter output (e.g. no connection tracking)
+Speedway
+NOTE: Speedway produces Iptables filtering output that is suitable for passing to the 'iptables-restore' command.
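+
+For example, using the sample Speedway output shipped in this repository (filters/sample_tug_wlc_fw.ipt), the generated file can be loaded in a single operation, whereas raw Iptables generator output is replayed one rule at a time:
+
+iptables-restore < filters/sample_tug_wlc_fw.ipt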
+
+The Speedway header designation has the following format:
+
+target:: speedway [INPUT|OUTPUT|FORWARD|custom] {ACCEPT|DROP} {truncatenames} {nostate} {inet|inet6}
+INPUT: apply the terms to the input filter.
+OUTPUT: apply the terms to the output filter.
+FORWARD: apply the terms to the forwarding filter.
+custom: create the terms under a custom filter name, which must then be linked/jumped to from one of the default filters (e.g. iptables -A INPUT -j custom)
+ACCEPT: specifies that the default policy on the filter should be 'accept'.
+DROP: specifies that the default policy on the filter should be to 'drop'.
+inet: specifies that the resulting filter should only render IPv4 addresses.
+inet6: specifies that the resulting filter should only render IPv6 addresses.
+truncatenames: specifies to abbreviate term names if necessary (see lib/iptables.py:CheckTermLength for abbreviation table)
+nostate: specifies to produce 'stateless' filter output (e.g. no connection tracking)
+Terms Section
+Terms define access control rules within a filter. Once the filter is defined in the header section, it is followed by one or more terms. Terms are enclosed in brackets and use keywords to specify the functionality of a specific access control.
+
+A term section begins with the keyword term, followed by a term name. Opening and closing brackets follow, which include the keywords and tokens to define the matching and action of the access control term.
+
+The keywords fall into two categories: those that are required to be supported by all output generators, and those that are optionally supported by each generator. Optional keywords are intended to provide additional flexibility when developing policies on a single target platform.
+
+NOTE: Some generators may silently ignore optional keyword tokens which they do not support.
+
+WARNING: When developing filters that are intended to be rendered across multiple generators (cisco, iptables & juniper for example) it is strongly recommended to only use required keyword tokens in the policy terms. This will help ensure each platform's rendered filter will contain compatible security policies.
+
+Keywords
+The following is a list of keywords that must be supported by all output generators:
+
+action:: the action to take when matched. [accept|deny|reject|next|reject-with-tcp-rst]
+comment:: a text comment enclosed in double-quotes. The comment can extend over multiple lines if desired, until a closing quote is encountered.
+destination-address:: one or more destination address tokens
+destination-exclude:: exclude one or more address tokens from the specified destination-address
+destination-port:: one or more service definition tokens
+icmp-type:: specify icmp-type code to match, see section ICMP TYPES for list of valid arguments
+option:: [established|tcp-established|sample|initial|rst|first-fragment]
+established - only permit established connections, implements tcp-established if protocol is tcp only, otherwise adds 1024-65535 to required destination-ports.
+tcp-established - only permit established tcp connections, usually checked based on TCP flag settings. If protocol UDP is included in term, only adds 1024-65535 to required destination-ports.
+sample - not supported by all generators. Samples traffic for netflow.
+initial - currently only supported by juniper generator. Appends tcp-initial to the term.
+rst - currently only supported by juniper generator. Appends "tcp-flags rst" to the term.
+first-fragment - currently only supported by juniper generator. Appends 'first-fragment' to the term.
+protocol:: the network protocols this term will match, such as tcp, udp, icmp, or a numeric value.
+protocol-except:: network protocols that should be excluded from the protocol specification. This is rarely used.
+source-address:: one or more source address tokens
+source-exclude:: exclude one or more address tokens from the specified source-address
+source-port:: one or more service definition tokens
+verbatim:: this specifies that the text enclosed within quotes should be rendered into the output without interpretation or modification. This is sometimes used as a temporary workaround while new required features are being added.
+Optionally Supported Keywords
+The following are keywords that can be optionally supported by output generators. It is important to note that these may or may not function properly on all generators.
+
+address:: one or more network address tokens
+counter:: juniper only, update a counter for matching packets
+destination-prefix:: juniper only, specify destination-prefix matching (e.g. destination-prefix:: configured-neighbors-only)
+ether-type:: juniper only, specify matching ether-type (e.g. ether-type:: arp)
+fragment-offset:: juniper only, specify a fragment offset of a fragmented packet
+logging:: supported by juniper and iptables/speedway, specify that this packet should be logged via syslog
+loss-priority:: juniper only, specify loss priority
+packet-length:: juniper only, specify packet length
+policer:: juniper only, specify which policer to apply to matching packets
+precedence:: juniper only, specify precedence
+qos:: juniper only, apply quality of service classification to matching packets (e.g. qos:: af4)
+routing-instance:: juniper only, specify routing instance for matching packets
+source-interface:: iptables and speedway only, specify the specific interface a term should apply to (e.g. source-interface:: eth3)
+source-prefix:: juniper only, specify source-prefix matching (e.g. source-prefix:: configured-neighbors-only)
+traffic-type:: juniper only, specify traffic-type
+Term Examples
+The following are examples of how to construct a term, and assume that the naming definition tokens used have been defined in the definitions files.
+
+Block incoming bogons and spoofed traffic
+
+term block-bogons {
+  source-address:: BOGONS RFC1918
+  source-address:: COMPANY_INTERNAL
+  action:: deny
+}
+Permit Public to Web Servers
+
+term permit-to-web-servers {
+  destination-address:: WEB_SERVERS
+  destination-port:: HTTP
+  protocol:: tcp
+  action:: accept
+}
+Permit Replies to DNS Servers From Primaries
+
+term permit-dns-tcp-replies {
+  source-address:: DNS_PRIMARIES
+  destination-address:: DNS_SECONDARIES
+  source-port:: DNS
+  protocol:: tcp
+  option:: tcp-established
+  action:: accept
+}
+Permit All Corporate Networks, Except New York, to FTP Server
+
+This will "subtract" the CORP_NYC_NETBLOCK from the CORP_NETBLOCKS token. For example, assume CORP_NETBLOCKS includes 200.0.0.0/20, and CORP_NYC_NETBLOCK is defined as 200.2.0.0/24. The source-exclude will remove the NYC netblock from the permitted source addresses. If the excluded address is not contained within the source address, nothing is changed.
+
+term allow-inbound-ftp-from-corp {
+  source-address:: CORP_NETBLOCKS
+  source-exclude:: CORP_NYC_NETBLOCK
+  destination-port:: FTP
+  protocol:: tcp
+  action:: accept
+}
+Includes
+The policy language supports the use of #include statements.
An include can be used to avoid duplication of commonly used text, such as a group of terms that permit or block specific types of traffic.
+
+An include directive will result in the contents of the included file being injected into the current policy at the exact location of the include directive.
+
+The include directive has the following format:
+
+...
+#include 'policies/includes/untrusted-networks-blocking.inc'
+...
+The .inc file extension and "include" directory path are not required, but are typically used to help differentiate from typical policy files.
+
+Example Policy File
+Below is an example policy file for a Juniper target platform. It contains two filters, each with a handful of terms. This example assumes that the network and service naming definition tokens have been defined.
+
+header {
+  comment:: "edge input filter for sample network."
+  target:: juniper edge-inbound
+}
+term discard-spoofs {
+  source-address:: RFC1918
+  action:: deny
+}
+term permit-ipsec-access {
+  source-address:: REMOTE_OFFICES
+  destination-address:: VPN_HUB
+  protocol:: 50
+  action:: accept
+}
+term permit-ike-access {
+  source-address:: REMOTE_OFFICES
+  destination-address:: VPN_HUB
+  protocol:: udp
+  destination-port:: IKE
+  action:: accept
+}
+term permit-public-web-access {
+  destination-address:: WEB_SERVERS
+  destination-port:: HTTP HTTPS HTTP_8080
+  protocol:: tcp
+  action:: accept
+}
+term permit-tcp-replies {
+  option:: tcp-established
+  action:: accept
+}
+term default-deny {
+  action:: deny
+}
+
+header {
+  comment:: "edge output filter for sample network."
+  target:: juniper edge-outbound
+}
+term drop-internal-sourced-outbound {
+  destination-address:: INTERNAL
+  destination-address:: RESERVED
+  action:: deny
+}
+term reject-internal {
+  source-address:: INTERNAL
+  action:: reject
+}
+term default-accept {
+  action:: accept
+}
+ICMP TYPES
+The following is the list of icmp-type specifications which can be used with the 'icmp-type::' policy token.
+
+IPv4
+echo-reply
+unreachable
+source-quench
+redirect
+alternate-address
+echo-request
+router-advertisement
+router-solicitation
+time-exceeded
+parameter-problem
+timestamp-request
+timestamp-reply
+information-request
+information-reply
+mask-request
+mask-reply
+conversion-error
+mobile-redirect
+IPv6
+destination-unreachable
+packet-too-big
+time-exceeded
+parameter-problem
+echo-request
+echo-reply
+multicast-listener-query
+multicast-listener-report
+multicast-listener-done
+router-solicit
+router-advertisement
+neighbor-solicit
+neighbor-advertisement
+redirect-message
+router-renumbering
+icmp-node-information-query
+icmp-node-information-response
+inverse-neighbor-discovery-solicitation
+inverse-neighbor-discovery-advertisement
+version-2-multicast-listener-report
+home-agent-address-discovery-request
+home-agent-address-discovery-reply
+mobile-prefix-solicitation
+mobile-prefix-advertisement
+certification-path-solicitation
+certification-path-advertisement
+multicast-router-advertisement
+multicast-router-solicitation
+multicast-router-termination
+
diff --git a/doc/quick_start.txt b/doc/quick_start.txt
new file mode 100644
index 0000000..374d492
--- /dev/null
+++ b/doc/quick_start.txt
@@ -0,0 +1,48 @@
+---------------
+Introduction
+---------------
+This page is intended to provide the information needed to install the libraries and files required to begin using capirca.
+
+This page is changing rapidly as the code is migrated from its Google roots to open source, and while the new structure of the code and its usage are finalized. Unfortunately, this page may be frequently out of date for short periods, but we will strive to keep it current.
+
+In its current form, this page is intended to provide a quick-start guide. See the other wiki pages for more details.
+
+---------------
+Details
+---------------
+Quick Start
+In the install directory, simply run:
+
+python aclgen.py
+
+This should generate sample output filters for cisco, juniper and iptables from the provided sample.pol policy file and the predefined network and service definitions.
+
+Optionally, you can provide arguments to the aclgen.py script that specify non-default locations for the naming definitions, policy files and filter output directory.
+
+python aclgen.py --help
+
+---------------
+Manually Generating Naming, Policy, and Platform Generator Output
+---------------
+The following commands can be run from the parent installation directory to manually create a naming definitions object, create a policy object, and render generator filter output.
+
+Import naming library and create naming object from definitions files
+
+  from lib import naming
+  defs = naming.Naming('./def')
+
+Import policy library, read in the policy data, and create a policy object
+
+  from lib import policy
+  conf = open('./policies/sample.pol').read()
+  pol = policy.ParsePolicy(conf, defs, optimize=True)
+
+Import a generator library (juniper in this case) and output a policy in the desired format
+
+  from lib import juniper
+  for header in pol.headers:
+    if 'juniper' in header.platforms:
+      jcl = True …
+  if jcl:
+    output = juniper.Juniper(pol)
+    print output
diff --git a/filters/.save b/filters/.save
new file mode 100644
index 0000000..e69de29
diff --git a/filters/sample_srx.srx b/filters/sample_srx.srx
new file mode 100644
index 0000000..3c3beaf
--- /dev/null
+++ b/filters/sample_srx.srx
@@ -0,0 +1,73 @@
+security {
+    zones {
+        security-zone DMZ {
+            replace: address-book {
+                address RFC1918_0 10.0.0.0/8;
+                address RFC1918_1 172.16.0.0/12;
+                address RFC1918_2 192.168.0.0/16;
+                address-set RFC1918 {
+                    address RFC1918_0;
+                    address RFC1918_1;
+                    address RFC1918_2;
+                }
+            }
+        }
+    }
+    replace: policies {
+        /*
+        $Id: ./filters/sample_srx.srx $
+        $Date: 2015/03/26 $
+        */
+        from-zone Untrust to-zone DMZ {
+            policy test-tcp {
+                match {
+                    source-address any;
+                    destination-address [ RFC1918 ];
+                    application test-tcp-app;
+                }
+                then {
+                    permit;
+                    log {
+                        session-init;
+                    }
+                }
+            }
+            policy test-icmp {
+                match {
+                    source-address any;
+                    destination-address [ RFC1918 ];
+                    application test-icmp-app;
+                }
+                then {
+                    permit;
+                }
+            }
+            policy default-deny {
+                match {
+                    source-address any;
+                    destination-address any;
+                    application any;
+                }
+                then {
+                    deny;
+                }
+            }
+        }
+    }
+}
+replace: applications {
+    application-set test-tcp-app {
+        application test-tcp-app1;
+        application test-tcp-app2;
+    }
+    application test-tcp-app1 {
+        term t1 protocol tcp;
+    }
+    application test-tcp-app2 {
+        term t2 protocol udp;
+    }
+    application test-icmp-app {
+        term t1 protocol icmp icmp-type 0 inactivity-timeout 60;
+        term t2 protocol icmp icmp-type 8 inactivity-timeout 60;
+    }
+}
\ No newline at end of file
diff --git a/filters/sample_tug_wlc_fw.acl b/filters/sample_tug_wlc_fw.acl
new file mode 100644
index 0000000..a8a8905
--- /dev/null
+++ b/filters/sample_tug_wlc_fw.acl
@@ -0,0 +1,49 @@
+! $Id: ./filters/sample_tug_wlc_fw.acl $
+!
$Date: 2015/03/26 $ +no ip access-list extended fw_tug_wlc_protect +ip access-list extended fw_tug_wlc_protect +remark $Id: ./filters/sample_tug_wlc_fw.acl $ +remark $Date: 2015/03/26 $ +remark this is a sample output filter that generates +remark multiplatform for tug wlc protection + + +remark permit-icmp + permit 1 any 109.105.104.16 0.0.0.15 + + +remark permit-traceroute + permit 17 any 109.105.104.16 0.0.0.15 range 33434 33534 + + +remark permit-NORDUnet + permit ip 109.105.96.0 0.0.31.255 109.105.104.16 0.0.0.15 + permit ip host 130.242.82.30 109.105.104.16 0.0.0.15 + permit ip host 130.242.121.137 109.105.104.16 0.0.0.15 + permit ip 193.10.252.0 0.0.0.255 109.105.104.16 0.0.0.15 + permit ip 193.10.254.0 0.0.0.255 109.105.104.16 0.0.0.15 + permit ip 193.11.3.0 0.0.0.255 109.105.104.16 0.0.0.15 + permit ip 194.68.13.0 0.0.0.255 109.105.104.16 0.0.0.15 + + +remark default-deny + deny ip any any + + +no ipv6 access-list fw_tug_wlc_protect +ipv6 access-list fw_tug_wlc_protect +remark $Id: ./filters/sample_tug_wlc_fw.acl $ +remark $Date: 2015/03/26 $ +remark this is a sample output filter that generates +remark multiplatform for tug wlc protection + + +remark Term permit-icmp +remark not rendered due to protocol/AF mismatch. + + +remark default-deny + deny ipv6 any any + + +end diff --git a/filters/sample_tug_wlc_fw.asa b/filters/sample_tug_wlc_fw.asa new file mode 100644 index 0000000..ece52c0 --- /dev/null +++ b/filters/sample_tug_wlc_fw.asa @@ -0,0 +1,27 @@ +clear configure access-list asa_in +access-list asa_in remark $Id: ./filters/sample_tug_wlc_fw.asa $ +access-list asa_in remark $Date: 2015/03/26 $ +access-list asa_in remark this is a sample output filter that generates +access-list asa_in remark multiplatform for tug wlc protection + + +access-list asa_in remark permit-icmp +access-list asa_in extended permit icmp any 109.105.104.16 255.255.255.240 + + +access-list asa_in remark permit-traceroute +access-list asa_in extended permit udp any 109.105.104.16 255.255.255.240 range 33434 33534 + + +access-list asa_in remark permit-NORDUnet +access-list asa_in extended permit ip 109.105.96.0 255.255.224.0 109.105.104.16 255.255.255.240 +access-list asa_in extended permit ip host 130.242.82.30 109.105.104.16 255.255.255.240 +access-list asa_in extended permit ip host 130.242.121.137 109.105.104.16 255.255.255.240 +access-list asa_in extended permit ip 193.10.252.0 255.255.255.0 109.105.104.16 255.255.255.240 +access-list asa_in extended permit ip 193.10.254.0 255.255.255.0 109.105.104.16 255.255.255.240 +access-list asa_in extended permit ip 193.11.3.0 255.255.255.0 109.105.104.16 255.255.255.240 +access-list asa_in extended permit ip 194.68.13.0 255.255.255.0 109.105.104.16 255.255.255.240 + + +access-list asa_in remark default-deny +access-list asa_in extended deny ip any any \ No newline at end of file diff --git a/filters/sample_tug_wlc_fw.demo b/filters/sample_tug_wlc_fw.demo new file mode 100644 index 0000000..c7a2d52 --- /dev/null +++ b/filters/sample_tug_wlc_fw.demo @@ -0,0 +1,55 @@ +Header { + Name: MUPP { + Type: inet + Comment: this is a sample output filter that generates + Comment: multiplatform for tug wlc protection + Family type: none + } + Term: permit-icmp{ + + Destination IP's + 109.105.104.16/28 + + Protocol + icmp + + Action: allow all traffic + } + + Term: permit-traceroute{ + + Destination IP's + 109.105.104.16/28 + + Destination Ports + 33434-33534 + + Protocol + udp + + Action: allow all traffic + } + + Term: permit-NORDUnet{ + + Source IP's + 109.105.96.0/19 + 
130.242.82.30/32 + 130.242.121.137/32 + 193.10.252.0/24 + 193.10.254.0/24 + 193.11.3.0/24 + 194.68.13.0/24 + + Destination IP's + 109.105.104.16/28 + + Action: allow all traffic + } + + Term: default-deny{ + + Action: discard all traffic + } + +} \ No newline at end of file diff --git a/filters/sample_tug_wlc_fw.html b/filters/sample_tug_wlc_fw.html new file mode 100644 index 0000000..c7a2d52 --- /dev/null +++ b/filters/sample_tug_wlc_fw.html @@ -0,0 +1,55 @@ +Header { + Name: MUPP { + Type: inet + Comment: this is a sample output filter that generates + Comment: multiplatform for tug wlc protection + Family type: none + } + Term: permit-icmp{ + + Destination IP's + 109.105.104.16/28 + + Protocol + icmp + + Action: allow all traffic + } + + Term: permit-traceroute{ + + Destination IP's + 109.105.104.16/28 + + Destination Ports + 33434-33534 + + Protocol + udp + + Action: allow all traffic + } + + Term: permit-NORDUnet{ + + Source IP's + 109.105.96.0/19 + 130.242.82.30/32 + 130.242.121.137/32 + 193.10.252.0/24 + 193.10.254.0/24 + 193.11.3.0/24 + 194.68.13.0/24 + + Destination IP's + 109.105.104.16/28 + + Action: allow all traffic + } + + Term: default-deny{ + + Action: discard all traffic + } + +} \ No newline at end of file diff --git a/filters/sample_tug_wlc_fw.ipt b/filters/sample_tug_wlc_fw.ipt new file mode 100644 index 0000000..54bf251 --- /dev/null +++ b/filters/sample_tug_wlc_fw.ipt @@ -0,0 +1,28 @@ +*filter +# Speedway INPUT Policy +# this is a sample output filter that generates +# multiplatform for tug wlc protection +# +# $Id: ./filters/sample_tug_wlc_fw.ipt $ +# $Date: 2015/03/26 $ +# inet +:INPUT DROP +-N I_permit-icmp +-A I_permit-icmp -p icmp -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A INPUT -j I_permit-icmp +-N I_permit-traceroute +-A I_permit-traceroute -p udp --dport 33434:33534 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A INPUT -j I_permit-traceroute +-N I_permit-NORDUnet +-A I_permit-NORDUnet -p all -s 109.105.96.0/19 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A I_permit-NORDUnet -p all -s 130.242.82.30/32 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A I_permit-NORDUnet -p all -s 130.242.121.137/32 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A I_permit-NORDUnet -p all -s 193.10.252.0/24 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A I_permit-NORDUnet -p all -s 193.10.254.0/24 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A I_permit-NORDUnet -p all -s 193.11.3.0/24 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A I_permit-NORDUnet -p all -s 194.68.13.0/24 -d 109.105.104.16/28 -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT +-A INPUT -j I_permit-NORDUnet +-N I_default-deny +-A I_default-deny -p all -j DROP +-A INPUT -j I_default-deny +COMMIT diff --git a/filters/sample_tug_wlc_fw.jcl b/filters/sample_tug_wlc_fw.jcl new file mode 100644 index 0000000..0c1e129 --- /dev/null +++ b/filters/sample_tug_wlc_fw.jcl @@ -0,0 +1,62 @@ +firewall { + family inet { + replace: + /* + ** $Id: ./filters/sample_tug_wlc_fw.jcl $ + ** $Date: 2015/03/26 $ + ** + ** this is a sample output filter that generates + ** multiplatform for tug wlc protection + */ + filter fw_tug_wlc_protect { + interface-specific; + term permit-icmp { + from { + destination-address { + 109.105.104.16/28; /* Wireless Controller net */ + } + protocol icmp; + } + then { + accept; + } 
+ } + term permit-traceroute { + from { + destination-address { + 109.105.104.16/28; /* Wireless Controller net */ + } + protocol udp; + destination-port 33434-33534; + } + then { + accept; + } + } + term permit-NORDUnet { + from { + source-address { + 109.105.96.0/19; + 130.242.82.30/32; /* AP FRE POP */ + 130.242.121.137/32; /* AP LULE POP */ + 193.10.252.0/24; + 193.10.254.0/24; + 193.11.3.0/24; + 194.68.13.0/24; + } + destination-address { + 109.105.104.16/28; /* Wireless Controller net */ + } + } + then { + accept; + } + } + term default-deny { + then { + discard; + } + } + } + } +} diff --git a/filters/sample_tug_wlc_fw.srx b/filters/sample_tug_wlc_fw.srx new file mode 100644 index 0000000..f86998c --- /dev/null +++ b/filters/sample_tug_wlc_fw.srx @@ -0,0 +1,96 @@ +security { + zones { + security-zone WLC_net { + replace: address-book { + address NDN_TUG_WLC_NET_0 109.105.104.16/28; + address-set NDN_TUG_WLC_NET { + address NDN_TUG_WLC_NET_0; + } + } + } + security-zone NORDUnet_nets { + replace: address-book { + address NORDUNET_AGGREGATE_0 109.105.96.0/19; + address NORDUNET_AGGREGATE_1 193.10.252.0/24; + address NORDUNET_AGGREGATE_2 193.10.254.0/24; + address NORDUNET_AGGREGATE_3 193.11.3.0/24; + address NORDUNET_AGGREGATE_4 194.68.13.0/24; + address SUNET_AP_STATICS_0 130.242.82.30/32; + address SUNET_AP_STATICS_1 130.242.121.137/32; + address-set NORDUNET_AGGREGATE { + address NORDUNET_AGGREGATE_0; + address NORDUNET_AGGREGATE_1; + address NORDUNET_AGGREGATE_2; + address NORDUNET_AGGREGATE_3; + address NORDUNET_AGGREGATE_4; + } + address-set SUNET_AP_STATICS { + address SUNET_AP_STATICS_0; + address SUNET_AP_STATICS_1; + } + } + } + } + replace: policies { + /* + $Id: ./filters/sample_tug_wlc_fw.srx $ + $Date: 2015/03/26 $ + */ + from-zone NORDUnet_nets to-zone WLC_net { + policy permit-icmp { + match { + source-address any; + destination-address [ NDN_TUG_WLC_NET ]; + application permit-icmp-app; + } + then { + permit; + } + } + policy permit-traceroute { + match { + source-address any; + destination-address [ NDN_TUG_WLC_NET ]; + application permit-traceroute-app; + } + then { + permit; + } + } + policy permit-NORDUnet { + match { + source-address [ NORDUNET_AGGREGATE SUNET_AP_STATICS ]; + destination-address [ NDN_TUG_WLC_NET ]; + application any; + } + then { + permit; + } + } + policy default-deny { + match { + source-address any; + destination-address any; + application any; + } + then { + deny; + } + } + } + } +} +replace: applications { + application-set permit-icmp-app { + application permit-icmp-app1; + } + application permit-icmp-app1 { + term t1 protocol icmp; + } + application-set permit-traceroute-app { + application permit-traceroute-app1; + } + application permit-traceroute-app1 { + term t1 protocol udp destination-port 33434-33534; + } +} \ No newline at end of file diff --git a/lib/COPYING b/lib/COPYING new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/lib/COPYING @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/lib/PKG-INFO b/lib/PKG-INFO new file mode 100644 index 0000000..9a074f5 --- /dev/null +++ b/lib/PKG-INFO @@ -0,0 +1,18 @@ +Metadata-Version: 1.0 +Name: capirca +Version: 1.0.0 +Summary: UNKNOWN +Home-page: http://code.google.com/p/capirca/ +Author: Google +Author-email: watson@gmail.com +License: Apache License, Version 2.0 +Description: UNKNOWN +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Topic :: Internet +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: System :: Networking +Classifier: Topic :: Security diff --git a/lib/README b/lib/README new file mode 100644 index 0000000..6442579 --- /dev/null +++ b/lib/README @@ -0,0 +1,10 @@ +Capirca is a system to develop and manage access control lists +for a variety of platforms. +It was developed by Google for internal use, and is now open source. + +Project home page: http://code.google.com/p/capirca/ + +Please send contributions to capirca-dev@googlegroups.com. + +Code should include unit tests and follow the Google Python style guide: +http://code.google.com/p/soc/wiki/PythonStyleGuide diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000..4d6ecb9 --- /dev/null +++ b/lib/__init__.py @@ -0,0 +1,31 @@ +# +# Network access control library and utilities +# +# capirca/__init__.py +# +# This package is intended to simplify the process of developing +# and working with large numbers of network access control lists +# for various platforms that share common network and service +# definitions. +# +# from capirca import naming +# from capirca import policy +# from capirca import cisco +# from capirca import juniper +# from capirca import iptables +# from capirca import policyreader +# from capirca import aclcheck +# from capirca import aclgenerator +# from capirca import nacaddr +# from capirca import packetfilter +# from capirca import port +# from capirca import speedway +# + +__version__ = '1.0.0' + +__all__ = ['naming', 'policy', 'cisco', 'juniper', 'iptables', + 'policyreader', 'aclcheck', 'aclgenerator', 'nacaddr', + 'packetfilter', 'port', 'speedway'] + +__author__ = 'Paul (Tony) Watson (watson@gmail.com / watson@google.com)' diff --git a/lib/aclcheck.py b/lib/aclcheck.py new file mode 100755 index 0000000..3e36a99 --- /dev/null +++ b/lib/aclcheck.py @@ -0,0 +1,302 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
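+#
+# Illustrative usage sketch (the policy path and addresses below are
+# hypothetical; see aclcheck_cmdline.py in this commit for the supported
+# invocation):
+#
+#   from capirca import naming, policy, aclcheck
+#   defs = naming.Naming('./def')
+#   pol = policy.ParsePolicy(open('./policies/sample_srx.pol').read(), defs)
+#   checker = aclcheck.AclCheck(pol, src='10.0.0.1', dst='109.105.104.17',
+#                               dport='33434', proto='udp')
+#   print checker.DescribeMatches()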
+# + +"""Check where hosts, ports and protocols are matched in a capirca policy.""" + +__author__ = 'watson@google.com (Tony Watson)' + +import logging +import sys +import nacaddr +import policy +import port + + +class Error(Exception): + """Base error class.""" + + +class AddressError(Error): + """Incorrect IP address or format.""" + + +class BadPolicy(Error): + """Item is not a valid policy object.""" + + +class NoTargetError(Error): + """Specified target platform not available in specified policy.""" + + +class AclCheck(object): + """Check where hosts, ports and protocols match in a NAC policy. + + Args: + pol: + policy.Policy object + src: + string, the source address + dst: + string: the destination address. + sport: + string, the source port. + dport: + string, the destination port. + proto: + string, the protocol. + + Returns: + An AclCheck Object + + Raises: + port.BarPortValue: An invalid source port is used + port.BadPortRange: A port is outside of the acceptable range 0-65535 + AddressError: Incorrect ip address or format + + """ + + def __init__(self, + pol, + src='any', + dst='any', + sport='any', + dport='any', + proto='any', + ): + + self.pol_obj = pol + self.proto = proto + + # validate source port + if sport == 'any': + self.sport = sport + else: + self.sport = port.Port(sport) + + # validate destination port + if dport == 'any': + self.dport = dport + else: + self.dport = port.Port(dport) + + # validate source address + if src == 'any': + self.src = src + else: + try: + self.src = nacaddr.IP(src) + except ValueError: + raise AddressError('bad source address: %s\n' % src) + + # validate destination address + if dst == 'any': + self.dst = dst + else: + try: + self.dst = nacaddr.IP(dst) + except ValueError: + raise AddressError('bad destination address: %s\n' % dst) + + if type(self.pol_obj) is not policy.Policy: + raise BadPolicy('Policy object is not valid.') + + self.matches = [] + self.exact_matches = [] + for header, terms in self.pol_obj.filters: + filtername = header.target[0].options[0] + for term in terms: + possible = [] + logging.debug('checking term: %s', term.name) + if not self._AddrInside(self.src, term.source_address): + logging.debug('srcaddr does not match') + continue + logging.debug('srcaddr matches: %s', self.src) + if not self._AddrInside(self.dst, term.destination_address): + logging.debug('dstaddr does not match') + continue + logging.debug('dstaddr matches: %s', self.dst) + if (self.sport != 'any' and term.source_port and not + self._PortInside(self.sport, term.source_port)): + logging.debug('sport does not match') + continue + logging.debug('sport matches: %s', self.sport) + if (self.dport != 'any' and term.destination_port and not + self._PortInside(self.dport, term.destination_port)): + logging.debug('dport does not match') + continue + logging.debug('dport matches: %s', self.dport) + if (self.proto != 'any' and term.protocol and + self.proto not in term.protocol): + logging.debug('proto does not match') + continue + logging.debug('proto matches: %s', self.proto) + if term.protocol_except and self.proto in term.protocol_except: + logging.debug('protocol excepted by term, no match.') + continue + logging.debug('proto not excepted: %s', self.proto) + if not term.action: # avoid any verbatim + logging.debug('term had no action (verbatim?), no match.') + continue + logging.debug('term has an action') + possible = self._PossibleMatch(term) + self.matches.append(Match(filtername, term.name, possible, term.action, + term.qos)) + if possible: + 
logging.debug('term has options: %s, not treating as exact match',
+                        possible)
+          continue
+
+        # if we get here then we have a match, and if the action isn't next and
+        # there are no possibles, then this is a "definite" match and we needn't
+        # look for any further matches (i.e. later terms may match, but since
+        # we'll never get there we shouldn't report them)
+        if 'next' not in term.action:
+          self.exact_matches.append(Match(filtername, term.name, [],
+                                          term.action, term.qos))
+          break
+
+  def Matches(self):
+    """Return list of matched terms."""
+    return self.matches
+
+  def ExactMatches(self):
+    """Return matched terms, but not terms with possibles or action next."""
+    return self.exact_matches
+
+  def ActionMatch(self, action='any'):
+    """Return list of matched terms with specified actions."""
+    match_list = []
+    for next in self.matches:
+      if next.action:
+        if not next.possibles:
+          if action == 'any' or action in next.action:
+            match_list.append(next)
+    return match_list
+
+  def DescribeMatches(self):
+    """Provide sentence descriptions of matches.
+
+    Returns:
+      ret_str: text sentences describing matches
+    """
+    ret_str = []
+    for next in self.matches:
+      text = str(next)
+      ret_str.append(text)
+    return '\n'.join(ret_str)
+
+  def __str__(self):
+    text = []
+    last_filter = ''
+    for next in self.matches:
+      if next.filter != last_filter:
+        last_filter = next.filter
+        text.append('  filter: ' + next.filter)
+      if next.possibles:
+        text.append(' ' * 10 + 'term: ' + next.term + ' (possible match)')
+      else:
+        text.append(' ' * 10 + 'term: ' + next.term)
+      if next.possibles:
+        text.append(' ' * 16 + next.action + ' if ' + str(next.possibles))
+      else:
+        text.append(' ' * 16 + next.action)
+    return '\n'.join(text)
+
+  def _PossibleMatch(self, term):
+    """Ignore some options and keywords that are edge cases.
+
+    Args:
+      term: term object to examine for edge-cases
+
+    Returns:
+      ret_str: a list of reasons this term may possibly match
+    """
+    ret_str = []
+    if 'first-fragment' in term.option:
+      ret_str.append('first-frag')
+    if term.fragment_offset:
+      ret_str.append('frag-offset')
+    if term.packet_length:
+      ret_str.append('packet-length')
+    if 'established' in term.option:
+      ret_str.append('est')
+    if 'tcp-established' in term.option and 'tcp' in term.protocol:
+      ret_str.append('tcp-est')
+    return ret_str
+
+  def _AddrInside(self, addr, addresses):
+    """Check if address is matched in another address or group of addresses.
+
+    Args:
+      addr: An ipaddr network or host address or text 'any'
+      addresses: A list of ipaddr network or host addresses
+
+    Returns:
+      bool: True or false
+    """
+    if addr == 'any': return True  # always true if we match for any addr
+    if not addresses: return True  # always true if term has nothing to match
+    for next in addresses:
+      # ipaddr can incorrectly report ipv4 as contained with ipv6 addrs
+      if type(addr) is type(next):
+        if addr in next:
+          return True
+    return False
+
+  def _PortInside(self, myport, port_list):
+    """Check if port matches in a port or group of ports.
+
+    Args:
+      myport: port number
+      port_list: list of ports
+
+    Returns:
+      bool: True or false
+    """
+    if myport == 'any': return True
+    if [x for x in port_list if x[0] <= myport <= x[1]]:
+      return True
+    return False
+
+
+class Match(object):
+  """A matching term and its associated values."""
+
+  def __init__(self, filtername, term, possibles, action, qos=None):
+    self.filter = filtername
+    self.term = term
+    self.possibles = possibles
+    self.action = action[0]
+    self.qos = qos
+
+  def __str__(self):
+    text = ''
+    if self.possibles:
+      text += 'possible ' + self.action
+    else:
+      text += self.action
+    text += ' in term ' + self.term + ' of filter ' + self.filter
+    if self.possibles:
+      text += ' with factors: ' + str(', '.join(self.possibles))
+    return text
+
+
+def main():
+  pass
+
+if __name__ == '__main__':
+  main()
diff --git a/lib/aclgenerator.py b/lib/aclgenerator.py
new file mode 100755
index 0000000..c5be343
--- /dev/null
+++ b/lib/aclgenerator.py
@@ -0,0 +1,418 @@
+#!/usr/bin/python2.4
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""ACL Generator base class."""
+
+import copy
+import re
+from string import Template
+
+import policy
+
+
+# generic error class
+class Error(Exception):
+  """Base error class."""
+  pass
+
+
+class NoPlatformPolicyError(Error):
+  """Raised when a policy is received that doesn't support this platform."""
+  pass
+
+
+class UnsupportedFilter(Error):
+  """Raised when we see an inappropriate filter."""
+  pass
+
+
+class UnknownIcmpTypeError(Error):
+  """Raised when we see an unknown icmp-type."""
+  pass
+
+
+class MismatchIcmpInetError(Error):
+  """Raised on a mismatch between icmp/icmpv6 and inet/inet6."""
+  pass
+
+
+class EstablishedError(Error):
+  """Raised when a term has the established option with an inappropriate protocol."""
+  pass
+
+
+class UnsupportedAF(Error):
+  """Raised when provided an unsupported address family."""
+  pass
+
+
+class DuplicateTermError(Error):
+  """Raised when duplicate term names are detected."""
+  pass
+
+
+class UnsupportedFilterError(Error):
+  """Raised when we see an inappropriate filter."""
+  pass
+
+
+class TermNameTooLongError(Error):
+  """Raised when a term name cannot be abbreviated sufficiently."""
+  pass
+
+
+class Term(object):
+  """Generic framework for a generator Term."""
+  ICMP_TYPE = policy.Term.ICMP_TYPE
+  PROTO_MAP = {'ip': 0,
+               'icmp': 1,
+               'igmp': 2,
+               'ggp': 3,
+               'ipencap': 4,
+               'tcp': 6,
+               'egp': 8,
+               'igp': 9,
+               'udp': 17,
+               'rdp': 27,
+               'ipv6': 41,
+               'ipv6-route': 43,
+               'ipv6-frag': 44,
+               'rsvp': 46,
+               'gre': 47,
+               'esp': 50,
+               'ah': 51,
+               'icmpv6': 58,
+               'ipv6-nonxt': 59,
+               'ipv6-opts': 60,
+               'ospf': 89,
+               'ipip': 94,
+               'pim': 103,
+               'vrrp': 112,
+               'l2tp': 115,
+               'sctp': 132,
+              }
+  AF_MAP = {'inet': 4,
+            'inet6': 6,
+            'bridge': 4  # if this doesn't exist, output includes v4 & v6
+           }
+  # provide flipped key/value dicts
+  PROTO_MAP_BY_NUMBER = dict([(v, k) for (k, v) in PROTO_MAP.iteritems()])
+  AF_MAP_BY_NUMBER = dict([(v, k) for (k, v) in AF_MAP.iteritems()])
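+  # For example: PROTO_MAP['tcp'] is 6, so PROTO_MAP_BY_NUMBER[6] yields
+  # 'tcp'; likewise AF_MAP_BY_NUMBER[6] yields 'inet6'.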
+ + NO_AF_LOG_FORMAT = Template('Term $term will not be rendered, as it has' + ' $direction address match specified but no' + ' $direction addresses of $af address family' + ' are present.') + + def NormalizeAddressFamily(self, af): + """Convert (if necessary) address family name to numeric value. + + Args: + af: Address family, can be either numeric or string (e.g. 4 or 'inet') + + Returns: + af: Numeric address family value + + Raises: + UnsupportedAF: Address family not in keys or values of our AF_MAP. + """ + # ensure address family (af) is valid + if af in self.AF_MAP_BY_NUMBER: + return af + elif af in self.AF_MAP: + # convert AF name to number (e.g. 'inet' becomes 4, 'inet6' becomes 6) + af = self.AF_MAP[af] + else: + raise UnsupportedAF('Address family %s is not supported, term %s.' % ( + af, self.term.name)) + return af + + def NormalizeIcmpTypes(self, icmp_types, protocols, af): + """Return verified list of appropriate icmp-types. + + Args: + icmp_types: list of icmp_types + protocols: list of protocols + af: address family of this term, either numeric or text (see self.AF_MAP) + + Returns: + sorted list of numeric icmp-type codes. + + Raises: + UnsupportedFilterError: icmp-types specified with non-icmp protocol. + MismatchIcmpInetError: mismatch between icmp protocol and address family. + UnknownIcmpTypeError: unknown icmp-type specified + """ + if not icmp_types: + return [''] + # only protocols icmp or icmpv6 can be used with icmp-types + if protocols != ['icmp'] and protocols != ['icmpv6']: + raise UnsupportedFilterError('%s %s' % ( + 'icmp-types specified for non-icmp protocols in term: ', + self.term.name)) + # make sure we have a numeric address family (4 or 6) + af = self.NormalizeAddressFamily(af) + # check that addr family and protocl are appropriate + if ((af != 4 and protocols == ['icmp']) or + (af != 6 and protocols == ['icmpv6'])): + raise MismatchIcmpInetError('%s %s' % ( + 'ICMP/ICMPv6 mismatch with address family IPv4/IPv6 in term', + self.term.name)) + # ensure all icmp types are valid + for icmptype in icmp_types: + if icmptype not in self.ICMP_TYPE[af]: + raise UnknownIcmpTypeError('%s %s %s %s' % ( + '\nUnrecognized ICMP-type (', icmptype, + ') specified in term ', self.term.name)) + rval = [] + rval.extend([self.ICMP_TYPE[af][x] for x in icmp_types]) + rval.sort() + return rval + + +class ACLGenerator(object): + """Generates platform specific filters and terms from a policy object. + + This class takes a policy object and renders the output into a syntax which + is understood by a specific platform (eg. iptables, cisco, etc). + """ + + _PLATFORM = None + # Default protocol to apply when no protocol is specified. + _DEFAULT_PROTOCOL = 'ip' + # Unsupported protocols by address family. + _SUPPORTED_AF = set(('inet', 'inet6')) + # Commonly misspelled protocols that the generator should reject. + _FILTER_BLACKLIST = {} + + # Set of required keywords that every generator must support. + _REQUIRED_KEYWORDS = set(['action', + 'comment', + 'destination_address', + 'destination_address_exclude', + 'destination_port', + 'icmp_type', + 'name', # obj attribute, not keyword + 'option', + 'protocol', + 'platform', + 'platform_exclude', + 'source_address', + 'source_address_exclude', + 'source_port', + 'translated', # obj attribute, not keyword + 'verbatim', + ]) + # Generators should redefine this in subclass as optional support is added + _OPTIONAL_SUPPORTED_KEYWORDS = set([]) + + # Abbreviation table used to automatically abbreviate terms that exceed + # specified limit. 
We use uppercase for abbreviations to distinguish + # from lowercase names. This is order list - we try the ones in the + # top of the list before the ones later in the list. Prefer clear + # or very-space-saving abbreviations by putting them early in the + # list. Abbreviations may be regular expressions or fixed terms; + # prefer fixed terms unless there's a clear benefit to regular + # expressions. + _ABBREVIATION_TABLE = [ + ('bogons', 'BGN'), + ('bogon', 'BGN'), + ('reserved', 'RSV'), + ('rfc1918', 'PRV'), + ('rfc-1918', 'PRV'), + ('internet', 'EXT'), + ('global', 'GBL'), + ('internal', 'INT'), + ('customer', 'CUST'), + ('google', 'GOOG'), + ('ballmer', 'ASS'), + ('microsoft', 'LOL'), + ('china', 'BAN'), + ('border', 'BDR'), + ('service', 'SVC'), + ('router', 'RTR'), + ('transit', 'TRNS'), + ('experiment', 'EXP'), + ('established', 'EST'), + ('unreachable', 'UNR'), + ('fragment', 'FRG'), + ('accept', 'OK'), + ('discard', 'DSC'), + ('reject', 'REJ'), + ('replies', 'ACK'), + ('request', 'REQ'), + ] + # Maximum term length. Can be overriden by generator to enforce + # platform specific restrictions. + _TERM_MAX_LENGTH = 62 + + def __init__(self, pol, exp_info): + """Initialise an ACLGenerator. Store policy structure for processing.""" + object.__init__(self) + + # The default list of valid keyword tokens for generators + self._VALID_KEYWORDS = self._REQUIRED_KEYWORDS.union( + self._OPTIONAL_SUPPORTED_KEYWORDS) + + self.policy = pol + + for header, terms in pol.filters: + if self._PLATFORM in header.platforms: + # Verify valid keywords + # error on unsupported optional keywords that could result + # in dangerous or unexpected results + for term in terms: + # Only verify optional keywords if the term is active on the platform. + err = [] + if term.platform: + if self._PLATFORM not in term.platform: + continue + if term.platform_exclude: + if self._PLATFORM in term.platform_exclude: + continue + for el, val in term.__dict__.items(): + # Private attributes do not need to be valid keywords. + if (val and el not in self._VALID_KEYWORDS + and not el.startswith('flatten')): + err.append(el) + if err: + raise UnsupportedFilterError('%s %s %s %s %s %s' % ('\n', term.name, + 'unsupported optional keywords for target', self._PLATFORM, + 'in policy:', ' '.join(err))) + continue + + self._TranslatePolicy(pol, exp_info) + + def _TranslatePolicy(self, pol, exp_info): + """Translate policy contents to platform specific data structures.""" + raise Error('%s does not implement _TranslatePolicies()' % self._PLATFORM) + + def FixHighPorts(self, term, af='inet', all_protocols_stateful=False): + """Evaluate protocol and ports of term, return sane version of term.""" + mod = term + + # Determine which protocols this term applies to. + if term.protocol: + protocols = set(term.protocol) + else: + protocols = set((self._DEFAULT_PROTOCOL,)) + + # Check that the address family matches the protocols. + if not af in self._SUPPORTED_AF: + raise UnsupportedAF('\nAddress family %s, found in %s, ' + 'unsupported by %s' % (af, term.name, self._PLATFORM)) + if af in self._FILTER_BLACKLIST: + unsupported_protocols = self._FILTER_BLACKLIST[af].intersection(protocols) + if unsupported_protocols: + raise UnsupportedFilter('\n%s targets do not support protocol(s) %s ' + 'with address family %s (in %s)' % + (self._PLATFORM, unsupported_protocols, + af, term.name)) + + # Many renders expect high ports for terms with the established option. 
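+    # For example: a tcp term carrying the 'established' option and
+    # destination-port 53 ends up with destination ports
+    # [(53, 53), (1024, 65535)] after the append/collapse below.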
+ for opt in [str(x) for x in term.option]: + if opt.find('established') == 0: + unstateful_protocols = protocols.difference(set(('tcp', 'udp'))) + if not unstateful_protocols: + # TCP/UDP: add in high ports then collapse to eliminate overlaps. + mod = copy.deepcopy(term) + mod.destination_port.append((1024, 65535)) + mod.destination_port = mod.CollapsePortList(mod.destination_port) + elif not all_protocols_stateful: + errmsg = 'Established option supplied with inappropriate protocol(s)' + raise EstablishedError('%s %s %s %s' % + (errmsg, unstateful_protocols, + 'in term', term.name)) + break + + return mod + + def FixTermLength(self, term_name, abbreviate=False, truncate=False): + """Return a term name which is equal or shorter than _TERM_MAX_LENGTH. + + New term is obtained in two steps. First, if allowed, automatic + abbreviation is performed using hardcoded abbreviation table. Second, + if allowed, term name is truncated to specified limit. + + Args: + term_name: Name to abbreviate if necessary. + abbreviate: Whether to allow abbreviations to shorten the length. + truncate: Whether to allow truncation to shorten the length. + Returns: + A string based on term_name, that is equal or shorter than + _TERM_MAX_LENGTH abbreviated and truncated as necessary. + Raises: + TermNameTooLongError: term_name cannot be abbreviated + to be shorter than _TERM_MAX_LENGTH, or truncation is disabled. + """ + new_term = term_name + if abbreviate: + for word, abbrev in self._ABBREVIATION_TABLE: + if len(new_term) <= self._TERM_MAX_LENGTH: + return new_term + new_term = re.sub(word, abbrev, new_term) + if truncate: + new_term = new_term[:self._TERM_MAX_LENGTH] + if len(new_term) <= self._TERM_MAX_LENGTH: + return new_term + raise TermNameTooLongError('Term %s (originally %s) is ' + 'too long. Limit is %d characters (vs. %d) ' + 'and no abbreviations remain or abbreviations ' + 'disabled.' % + (new_term, term_name, + self._TERM_MAX_LENGTH, + len(new_term))) + + +def AddRepositoryTags(prefix=''): + """Add repository tagging into the output. + + Args: + prefix: comment delimiter, if needed, to appear before tags + Returns: + list of text lines containing revision data + """ + tags = [] + p4_id = '%sId:%s' % ('$', '$') + p4_date = '%sDate:%s' % ('$', '$') + tags.append('%s%s' % (prefix, p4_id)) + tags.append('%s%s' % (prefix, p4_date)) + return tags + + +def WrapWords(textlist, size, joiner='\n'): + """Insert breaks into the listed strings at specified width. + + Args: + textlist: a list of text strings + size: width of reformated strings + joiner: text to insert at break. eg. '\n ' to add an indent. + Returns: + list of strings + """ + # \S*? is a non greedy match to collect words of len > size + # .{1,%d} collects words and spaces up to size in length. + # (?:\s|\Z) ensures that we break on spaces or at end of string. + rval = [] + linelength_re = re.compile(r'(\S*?.{1,%d}(?:\s|\Z))' % size) + for index in range(len(textlist)): + if len(textlist[index]) > size: + # insert joiner into the string at appropriate places. + textlist[index] = joiner.join(linelength_re.findall(textlist[index])) + # avoid empty comment lines + rval.extend(x.strip() for x in textlist[index].strip().split(joiner) if x) + return rval diff --git a/lib/cisco.py b/lib/cisco.py new file mode 100644 index 0000000..0156ce9 --- /dev/null +++ b/lib/cisco.py @@ -0,0 +1,744 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Cisco generator.""" + +__author__ = 'pmoody@google.com (Peter Moody)' +__author__ = 'watson@google.com (Tony Watson)' + +import datetime +import logging +import re + +from third_party import ipaddr +import aclgenerator +import nacaddr + + +_ACTION_TABLE = { + 'accept': 'permit', + 'deny': 'deny', + 'reject': 'deny', + 'next': '! next', + 'reject-with-tcp-rst': 'deny', # tcp rst not supported +} + + +# generic error class +class Error(Exception): + """Generic error class.""" + pass + + +class UnsupportedCiscoAccessListError(Error): + """Raised when we're give a non named access list.""" + pass + + +class StandardAclTermError(Error): + """Raised when there is a problem in a standard access list.""" + pass + + +class TermStandard(object): + """A single standard ACL Term.""" + + def __init__(self, term, filter_name): + self.term = term + self.filter_name = filter_name + self.options = [] + self.logstring = '' + # sanity checking for standard acls + if self.term.protocol: + raise StandardAclTermError( + 'Standard ACLs cannot specify protocols') + if self.term.icmp_type: + raise StandardAclTermError( + 'ICMP Type specifications are not permissible in standard ACLs') + if (self.term.source_address + or self.term.source_address_exclude + or self.term.destination_address + or self.term.destination_address_exclude): + raise StandardAclTermError( + 'Standard ACLs cannot use source or destination addresses') + if self.term.option: + raise StandardAclTermError( + 'Standard ACLs prohibit use of options') + if self.term.source_port or self.term.destination_port: + raise StandardAclTermError( + 'Standard ACLs prohibit use of port numbers') + if self.term.counter: + raise StandardAclTermError( + 'Counters are not implemented in standard ACLs') + if self.term.logging: + logging.warn( + 'WARNING: Standard ACL logging is set in filter %s, term %s and ' + 'may not implemented on all IOS versions', self.filter_name, + self.term.name) + self.logstring = ' log' + + def __str__(self): + # Verify platform specific terms. Skip whole term if platform does not + # match. + if self.term.platform: + if 'cisco' not in self.term.platform: + return '' + if self.term.platform_exclude: + if 'cisco' in self.term.platform_exclude: + return '' + + ret_str = [] + + # Term verbatim output - this will skip over normal term creation + # code by returning early. Warnings provided in policy.py. 
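+    # For example (illustrative ACL text): a policy term written as
+    #   verbatim:: cisco "access-list 99 permit 10.1.1.1"
+    # is copied through untouched for the cisco target, and processing of
+    # the term ends here.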
+ if self.term.verbatim: + for next_verbatim in self.term.verbatim: + if next_verbatim.value[0] == 'cisco': + ret_str.append(str(next_verbatim.value[1])) + return '\n'.join(ret_str) + + v4_addresses = [x for x in self.term.address if type(x) != nacaddr.IPv6] + if self.filter_name.isdigit(): + ret_str.append('access-list %s remark %s' % (self.filter_name, + self.term.name)) + + comment_max_width = 70 + comments = aclgenerator.WrapWords(self.term.comment, comment_max_width) + if comments and comments[0]: + for comment in comments: + ret_str.append('access-list %s remark %s' % (self.filter_name, + comment)) + + action = _ACTION_TABLE.get(str(self.term.action[0])) + if v4_addresses: + for addr in v4_addresses: + if addr.prefixlen == 32: + ret_str.append('access-list %s %s %s%s' % (self.filter_name, + action, + addr.ip, + self.logstring)) + else: + ret_str.append('access-list %s %s %s %s%s' % (self.filter_name, + action, + addr.network, + addr.hostmask, + self.logstring)) + else: + ret_str.append('access-list %s %s %s%s' % (self.filter_name, action, + 'any', self.logstring)) + + else: + ret_str.append('remark ' + self.term.name) + comment_max_width = 70 + comments = aclgenerator.WrapWords(self.term.comment, comment_max_width) + if comments and comments[0]: + for comment in comments: + ret_str.append('remark ' + str(comment)) + + action = _ACTION_TABLE.get(str(self.term.action[0])) + if v4_addresses: + for addr in v4_addresses: + if addr.prefixlen == 32: + ret_str.append(' %s %s%s' % (action, addr.ip, self.logstring)) + else: + ret_str.append(' %s %s %s%s' % (action, addr.network, + addr.hostmask, self.logstring)) + else: + ret_str.append(' %s %s%s' % (action, 'any', self.logstring)) + + return '\n'.join(ret_str) + + +class ObjectGroup(object): + """Used for printing out the object group definitions. + + since the ports don't store the token name information, we have + to fudge their names. ports will be written out like + + object-group ip port - + range + exit + + where as the addressess can be written as + + object-group ip address first-term-source-address + 172.16.0.0 + 172.20.0.0 255.255.0.0 + 172.22.0.0 255.128.0.0 + 172.24.0.0 + 172.28.0.0 + exit + """ + + def __init__(self): + self.filter_name = '' + self.terms = [] + + @property + def valid(self): + # pylint: disable-msg=C6411 + return len(self.terms) > 0 + # pylint: enable-msg=C6411 + + def AddTerm(self, term): + self.terms.append(term) + + def AddName(self, filter_name): + self.filter_name = filter_name + + def __str__(self): + ret_str = ['\n'] + addresses = {} + ports = {} + + for term in self.terms: + # I don't have an easy way get the token name used in the pol file + # w/o reading the pol file twice (with some other library) or doing + # some other ugly hackery. Instead, the entire block of source and dest + # addresses for a given term is given a unique, computable name which + # is not related to the NETWORK.net token name. that's what you get + # for using cisco, which has decided to implement its own meta language. + + # source address + saddrs = term.GetAddressOfVersion('source_address', 4) + # check to see if we've already seen this address. 
+ if saddrs and saddrs[0].parent_token not in addresses: + addresses[saddrs[0].parent_token] = True + ret_str.append('object-group ip address %s' % saddrs[0].parent_token) + for addr in saddrs: + ret_str.append(' %s %s' % (addr.ip, addr.netmask)) + ret_str.append('exit\n') + + # destination address + daddrs = term.GetAddressOfVersion('destination_address', 4) + # check to see if we've already seen this address + if daddrs and daddrs[0].parent_token not in addresses: + addresses[daddrs[0].parent_token] = True + ret_str.append('object-group ip address %s' % daddrs[0].parent_token) + for addr in term.GetAddressOfVersion('destination_address', 4): + ret_str.append(' %s %s' % (addr.ip, addr.netmask)) + ret_str.append('exit\n') + + # source port + for port in term.source_port + term.destination_port: + if not port: + continue + port_key = '%s-%s' % (port[0], port[1]) + if port_key not in ports.keys(): + ports[port_key] = True + ret_str.append('object-group ip port %s' % port_key) + if port[0] != port[1]: + ret_str.append(' range %d %d' % (port[0], port[1])) + else: + ret_str.append(' eq %d' % port[0]) + ret_str.append('exit\n') + + return '\n'.join(ret_str) + + +class ObjectGroupTerm(aclgenerator.Term): + """An individual term of an object-group'd acl. + + Object Group acls are very similar to extended acls in their + syntax except they use a meta language with address/service + definitions. + + eg: + + permit tcp first-term-source-address 179-179 ANY + + where first-term-source-address, ANY and 179-179 are defined elsewhere + in the acl. + """ + + def __init__(self, term, filter_name): + self.term = term + self.filter_name = filter_name + + def __str__(self): + # Verify platform specific terms. Skip whole term if platform does not + # match. + if self.term.platform: + if 'cisco' not in self.term.platform: + return '' + if self.term.platform_exclude: + if 'cisco' in self.term.platform_exclude: + return '' + + source_address_dict = {} + destination_address_dict = {} + + ret_str = ['\n'] + ret_str.append('remark %s' % self.term.name) + comment_max_width = 70 + comments = aclgenerator.WrapWords(self.term.comment, comment_max_width) + if comments and comments[0]: + for comment in comments: + ret_str.append('remark %s' % str(comment)) + + # Term verbatim output - this will skip over normal term creation + # code by returning early. Warnings provided in policy.py. 
+ if self.term.verbatim: + for next_verbatim in self.term.verbatim: + if next_verbatim.value[0] == 'cisco': + ret_str.append(str(next_verbatim.value[1])) + return '\n'.join(ret_str) + + # protocol + if not self.term.protocol: + protocol = ['ip'] + else: + # pylint: disable-msg=C6402 + protocol = map(self.PROTO_MAP.get, self.term.protocol, self.term.protocol) + # pylint: enable-msg=C6402 + + # addresses + source_address = self.term.source_address + if not self.term.source_address: + source_address = [nacaddr.IPv4('0.0.0.0/0', token='ANY')] + source_address_dict[source_address[0].parent_token] = True + + destination_address = self.term.destination_address + if not self.term.destination_address: + destination_address = [nacaddr.IPv4('0.0.0.0/0', token='ANY')] + destination_address_dict[destination_address[0].parent_token] = True + # ports + source_port = [()] + destination_port = [()] + if self.term.source_port: + source_port = self.term.source_port + if self.term.destination_port: + destination_port = self.term.destination_port + + for saddr in source_address: + for daddr in destination_address: + for sport in source_port: + for dport in destination_port: + for proto in protocol: + ret_str.append( + self._TermletToStr(_ACTION_TABLE.get(str( + self.term.action[0])), proto, saddr, sport, daddr, dport)) + + return '\n'.join(ret_str) + + def _TermletToStr(self, action, proto, saddr, sport, daddr, dport): + """Output a portion of a cisco term/filter only, based on the 5-tuple.""" + # fix addreses + if saddr: + saddr = 'addrgroup %s' % saddr + if daddr: + daddr = 'addrgroup %s' % daddr + # fix ports + if sport: + sport = 'portgroup %d-%d' % (sport[0], sport[1]) + else: + sport = '' + if dport: + dport = 'portgroup %d-%d' % (dport[0], dport[1]) + else: + dport = '' + + return ' %s %s %s %s %s %s' % ( + action, proto, saddr, sport, daddr, dport) + + +class Term(aclgenerator.Term): + """A single ACL Term.""" + + def __init__(self, term, af=4): + self.term = term + self.options = [] + # Our caller should have already verified the address family. + assert af in (4, 6) + self.af = af + self.text_af = self.AF_MAP_BY_NUMBER[self.af] + + def __str__(self): + # Verify platform specific terms. Skip whole term if platform does not + # match. + if self.term.platform: + if 'cisco' not in self.term.platform: + return '' + if self.term.platform_exclude: + if 'cisco' in self.term.platform_exclude: + return '' + + ret_str = ['\n'] + + # Don't render icmpv6 protocol terms under inet, or icmp under inet6 + if ((self.af == 6 and 'icmp' in self.term.protocol) or + (self.af == 4 and 'icmpv6' in self.term.protocol)): + ret_str.append('remark Term %s' % self.term.name) + ret_str.append('remark not rendered due to protocol/AF mismatch.') + return '\n'.join(ret_str) + + ret_str.append('remark ' + self.term.name) + if self.term.owner: + self.term.comment.append('Owner: %s' % self.term.owner) + for comment in self.term.comment: + for line in comment.split('\n'): + ret_str.append('remark ' + str(line)[:100]) + + # Term verbatim output - this will skip over normal term creation + # code by returning early. Warnings provided in policy.py. 
+ if self.term.verbatim: + for next_verbatim in self.term.verbatim: + if next_verbatim.value[0] == 'cisco': + ret_str.append(str(next_verbatim.value[1])) + return '\n'.join(ret_str) + + # protocol + if not self.term.protocol: + if self.af == 6: + protocol = ['ipv6'] + else: + protocol = ['ip'] + else: + # pylint: disable-msg=C6402 + protocol = map(self.PROTO_MAP.get, self.term.protocol, self.term.protocol) + # pylint: disable-msg=C6402 + + # source address + if self.term.source_address: + source_address = self.term.GetAddressOfVersion('source_address', self.af) + source_address_exclude = self.term.GetAddressOfVersion( + 'source_address_exclude', self.af) + if source_address_exclude: + source_address = nacaddr.ExcludeAddrs( + source_address, + source_address_exclude) + if not source_address: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + direction='source', + af=self.text_af)) + return '' + else: + # source address not set + source_address = ['any'] + + # destination address + if self.term.destination_address: + destination_address = self.term.GetAddressOfVersion( + 'destination_address', self.af) + destination_address_exclude = self.term.GetAddressOfVersion( + 'destination_address_exclude', self.af) + if destination_address_exclude: + destination_address = nacaddr.ExcludeAddrs( + destination_address, + destination_address_exclude) + if not destination_address: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + direction='destination', + af=self.text_af)) + return '' + else: + # destination address not set + destination_address = ['any'] + + # options + opts = [str(x) for x in self.term.option] + if self.PROTO_MAP['tcp'] in protocol and ('tcp-established' in opts or + 'established' in opts): + self.options.extend(['established']) + + # ports + source_port = [()] + destination_port = [()] + if self.term.source_port: + source_port = self.term.source_port + if self.term.destination_port: + destination_port = self.term.destination_port + + # logging + if self.term.logging: + self.options.append('log') + + # icmp-types + icmp_types = [''] + if self.term.icmp_type: + icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type, + self.term.protocol, self.af) + + for saddr in source_address: + for daddr in destination_address: + for sport in source_port: + for dport in destination_port: + for proto in protocol: + for icmp_type in icmp_types: + ret_str.extend(self._TermletToStr( + _ACTION_TABLE.get(str(self.term.action[0])), + proto, + saddr, + sport, + daddr, + dport, + icmp_type, + self.options)) + + return '\n'.join(ret_str) + + def _TermletToStr(self, action, proto, saddr, sport, daddr, dport, + icmp_type, option): + """Take the various compenents and turn them into a cisco acl line. + + Args: + action: str, action + proto: str, protocl + saddr: str or ipaddr, source address + sport: str list or none, the source port + daddr: str or ipaddr, the destination address + dport: str list or none, the destination port + icmp_type: icmp-type numeric specification (if any) + option: list or none, optional, eg. 'logging' tokens. + + Returns: + string of the cisco acl line, suitable for printing. 
+ + Raises: + UnsupportedCiscoAccessListError: When unknown icmp-types specified + """ + # inet4 + if type(saddr) is nacaddr.IPv4 or type(saddr) is ipaddr.IPv4Network: + if saddr.numhosts > 1: + saddr = '%s %s' % (saddr.ip, saddr.hostmask) + else: + saddr = 'host %s' % (saddr.ip) + if type(daddr) is nacaddr.IPv4 or type(daddr) is ipaddr.IPv4Network: + if daddr.numhosts > 1: + daddr = '%s %s' % (daddr.ip, daddr.hostmask) + else: + daddr = 'host %s' % (daddr.ip) + # inet6 + if type(saddr) is nacaddr.IPv6 or type(saddr) is ipaddr.IPv6Network: + if saddr.numhosts > 1: + saddr = '%s' % (saddr.with_prefixlen) + else: + saddr = 'host %s' % (saddr.ip) + if type(daddr) is nacaddr.IPv6 or type(daddr) is ipaddr.IPv6Network: + if daddr.numhosts > 1: + daddr = '%s' % (daddr.with_prefixlen) + else: + daddr = 'host %s' % (daddr.ip) + + # fix ports + if not sport: + sport = '' + elif sport[0] != sport[1]: + sport = 'range %d %d' % (sport[0], sport[1]) + else: + sport = 'eq %d' % (sport[0]) + + if not dport: + dport = '' + elif dport[0] != dport[1]: + dport = 'range %d %d' % (dport[0], dport[1]) + else: + dport = 'eq %d' % (dport[0]) + + if not option: + option = [''] + + # Prevent UDP from appending 'established' to ACL line + sane_options = list(option) + if proto == self.PROTO_MAP['udp'] and 'established' in sane_options: + sane_options.remove('established') + ret_lines = [] + + # str(icmp_type) is needed to ensure 0 maps to '0' instead of FALSE + icmp_type = str(icmp_type) + if icmp_type: + ret_lines.append(' %s %s %s %s %s %s %s %s' % (action, proto, saddr, + sport, daddr, dport, + icmp_type, + ' '.join(sane_options) + )) + else: + ret_lines.append(' %s %s %s %s %s %s %s' % (action, proto, saddr, + sport, daddr, dport, + ' '.join(sane_options) + )) + + # remove any trailing spaces and replace multiple spaces with singles + stripped_ret_lines = [re.sub(r'\s+', ' ', x).rstrip() for x in ret_lines] + return stripped_ret_lines + + +class Cisco(aclgenerator.ACLGenerator): + """A cisco policy object.""" + + _PLATFORM = 'cisco' + _DEFAULT_PROTOCOL = 'ip' + _SUFFIX = '.acl' + + _OPTIONAL_SUPPORTED_KEYWORDS = set(['address', + 'counter', + 'expiration', + 'logging', + 'loss_priority', + 'owner', + 'policer', + 'port', + 'qos', + 'routing_instance', + ]) + + def _TranslatePolicy(self, pol, exp_info): + self.cisco_policies = [] + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + + # a mixed filter outputs both ipv4 and ipv6 acls in the same output file + good_filters = ['extended', 'standard', 'object-group', 'inet6', + 'mixed'] + + for header, terms in pol.filters: + if self._PLATFORM not in header.platforms: + continue + + obj_target = ObjectGroup() + + filter_options = header.FilterOptions(self._PLATFORM) + filter_name = header.FilterName(self._PLATFORM) + + # extended is the most common filter type. 
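+      # (filter_options comes from the policy header, e.g.
+      #   target:: cisco NAME extended
+      # yields ['NAME', 'extended']; the trailing option, when present,
+      # selects one of good_filters above.)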
+ filter_type = 'extended' + if len(filter_options) > 1: + filter_type = filter_options[1] + + # check if filter type is renderable + if filter_type not in good_filters: + raise UnsupportedCiscoAccessListError( + 'access list type %s not supported by %s (good types: %s)' % ( + filter_type, self._PLATFORM, str(good_filters))) + + filter_list = [filter_type] + if filter_type == 'mixed': + # Loop through filter and generate output for inet and inet6 in sequence + filter_list = ['extended', 'inet6'] + + for next_filter in filter_list: + if next_filter == 'extended': + try: + if int(filter_name) in range(1, 100) + range(1300, 2000): + raise UnsupportedCiscoAccessListError( + 'Access lists between 1-99 and 1300-1999 are reserved for ' + 'standard ACLs') + except ValueError: + # Extended access list names do not have to be numbers. + pass + if next_filter == 'standard': + try: + if int(filter_name) not in range(1, 100) + range(1300, 2000): + raise UnsupportedCiscoAccessListError( + 'Standard access lists must be numeric in the range of 1-99' + ' or 1300-1999.') + except ValueError: + # Standard access list names do not have to be numbers either. + pass + + new_terms = [] + for term in terms: + term.name = self.FixTermLength(term.name) + af = 'inet' + if next_filter == 'inet6': + af = 'inet6' + term = self.FixHighPorts(term, af=af) + if not term: + continue + + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s expires ' + 'in less than two weeks.', term.name, filter_name) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s is expired and ' + 'will not be rendered.', term.name, filter_name) + continue + + # render terms based on filter type + if next_filter == 'standard': + # keep track of sequence numbers across terms + new_terms.append(TermStandard(term, filter_name)) + elif next_filter == 'extended': + new_terms.append(Term(term)) + elif next_filter == 'object-group': + obj_target.AddTerm(term) + new_terms.append(ObjectGroupTerm(term, filter_name)) + elif next_filter == 'inet6': + new_terms.append(Term(term, 6)) + + self.cisco_policies.append((header, filter_name, [next_filter], + new_terms, obj_target)) + + def __str__(self): + target_header = [] + target = [] + + # add the p4 tags + target.extend(aclgenerator.AddRepositoryTags('! ')) + + for (header, filter_name, filter_list, terms, obj_target + ) in self.cisco_policies: + for filter_type in filter_list: + if filter_type == 'standard': + if filter_name.isdigit(): + target.append('no access-list %s' % filter_name) + else: + target.append('no ip access-list standard %s' % filter_name) + target.append('ip access-list standard %s' % filter_name) + elif filter_type == 'extended': + target.append('no ip access-list extended %s' % filter_name) + target.append('ip access-list extended %s' % filter_name) + elif filter_type == 'object-group': + obj_target.AddName(filter_name) + target.append('no ip access-list extended %s' % filter_name) + target.append('ip access-list extended %s' % filter_name) + elif filter_type == 'inet6': + target.append('no ipv6 access-list %s' % filter_name) + target.append('ipv6 access-list %s' % filter_name) + else: + raise UnsupportedCiscoAccessListError( + 'access list type %s not supported by %s' % ( + filter_type, self._PLATFORM)) + + # Add the Perforce Id/Date tags, these must come after + # remove/re-create of the filter, otherwise config mode doesn't + # know where to place these remarks in the configuration. 
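+      # (AddRepositoryTags emits '$Id:$' and '$Date:$' lines behind the
+      # given prefix, e.g. 'access-list 100 remark $Id:$', which the
+      # revision-control system expands on submit.)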
+ if filter_name.isdigit(): + target.extend(aclgenerator.AddRepositoryTags('access-list %s remark ' + % filter_name)) + else: + target.extend(aclgenerator.AddRepositoryTags('remark ')) + + # add a header comment if one exists + for comment in header.comment: + for line in comment.split('\n'): + target.append('remark %s' % line) + + # now add the terms + for term in terms: + term_str = str(term) + if term_str: + target.append(term_str) + target.append('\n') + + if obj_target.valid: + target = [str(obj_target)] + target + # ensure that the header is always first + target = target_header + target + target += ['end', ''] + return '\n'.join(target) diff --git a/lib/ciscoasa.py b/lib/ciscoasa.py new file mode 100644 index 0000000..f3f92b5 --- /dev/null +++ b/lib/ciscoasa.py @@ -0,0 +1,454 @@ +#!/usr/bin/python + + + +"""Cisco ASA renderer.""" + +__author__ = 'antony@slac.stanford.edu (Antonio Ceseracciu)' + +import datetime +import socket +import logging +import re + +from third_party import ipaddr +import aclgenerator +import nacaddr + + +_ACTION_TABLE = { + 'accept': 'permit', + 'deny': 'deny', + 'reject': 'deny', + 'next': '! next', + 'reject-with-tcp-rst': 'deny', # tcp rst not supported + } + + +# generic error class +class Error(Exception): + """Generic error class.""" + pass + + +class UnsupportedCiscoAccessListError(Error): + """Raised when we're give a non named access list.""" + pass + + +class StandardAclTermError(Error): + """Raised when there is a problem in a standard access list.""" + pass + + +class NoCiscoPolicyError(Error): + """Raised when a policy is errantly passed to this module for rendering.""" + pass + + +class Term(aclgenerator.Term): + """A single ACL Term.""" + + + def __init__(self, term, filter_name, af=4): + self.term = term + self.filter_name = filter_name + self.options = [] + assert af in (4, 6) + self.af = af + + def __str__(self): + # Verify platform specific terms. Skip whole term if platform does not + # match. + if self.term.platform: + if 'ciscoasa' not in self.term.platform: + return '' + if self.term.platform_exclude: + if 'ciscoasa' in self.term.platform_exclude: + return '' + + ret_str = ['\n'] + + # Don't render icmpv6 protocol terms under inet, or icmp under inet6 + if ((self.af == 6 and 'icmp' in self.term.protocol) or + (self.af == 4 and 'icmpv6' in self.term.protocol)): + ret_str.append('remark Term %s' % self.term.name) + ret_str.append('remark not rendered due to protocol/AF mismatch.') + return '\n'.join(ret_str) + + ret_str.append('access-list %s remark %s' % (self.filter_name, + self.term.name)) + if self.term.owner: + self.term.comment.append('Owner: %s' % self.term.owner) + for comment in self.term.comment: + for line in comment.split('\n'): + ret_str.append('access-list %s remark %s' % (self.filter_name, + str(line)[:100])) + + # Term verbatim output - this will skip over normal term creation + # code by returning early. Warnings provided in policy.py. 
+    if self.term.verbatim:
+      for next_verbatim in self.term.verbatim:
+        if next_verbatim.value[0] == 'ciscoasa':
+          ret_str.append(str(next_verbatim.value[1]))
+      return '\n'.join(ret_str)
+
+    # protocol
+    if not self.term.protocol:
+      protocol = ['ip']
+    else:
+      # fix the protocol
+      protocol = self.term.protocol
+
+    # source address
+    if self.term.source_address:
+      source_address = self.term.GetAddressOfVersion('source_address', self.af)
+      source_address_exclude = self.term.GetAddressOfVersion(
+          'source_address_exclude', self.af)
+      if source_address_exclude:
+        source_address = nacaddr.ExcludeAddrs(
+            source_address,
+            source_address_exclude)
+    else:
+      # source address not set
+      source_address = ['any']
+
+    # destination address
+    if self.term.destination_address:
+      destination_address = self.term.GetAddressOfVersion(
+          'destination_address', self.af)
+      destination_address_exclude = self.term.GetAddressOfVersion(
+          'destination_address_exclude', self.af)
+      if destination_address_exclude:
+        destination_address = nacaddr.ExcludeAddrs(
+            destination_address,
+            destination_address_exclude)
+    else:
+      # destination address not set
+      destination_address = ['any']
+
+    # options
+    extra_options = []
+    for opt in [str(x) for x in self.term.option]:
+      if opt.find('tcp-established') == 0 and 6 in protocol:
+        extra_options.append('established')
+      elif opt.find('established') == 0 and 6 in protocol:
+        # only needed for TCP, for other protocols policy.py handles high-ports
+        extra_options.append('established')
+    self.options.extend(extra_options)
+
+    # ports
+    source_port = [()]
+    destination_port = [()]
+    if self.term.source_port:
+      source_port = self.term.source_port
+    if self.term.destination_port:
+      destination_port = self.term.destination_port
+
+    # logging
+    if self.term.logging:
+      self.options.append('log')
+      if 'disable' in [x.value for x in self.term.logging]:
+        self.options.append('disable')
+
+    # icmp-types
+    icmp_types = ['']
+    if self.term.icmp_type:
+      icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type,
+                                           self.term.protocol, self.af)
+
+    for saddr in source_address:
+      for daddr in destination_address:
+        for sport in source_port:
+          for dport in destination_port:
+            for proto in protocol:
+              for icmp_type in icmp_types:
+                # only output address family appropriate IP addresses
+                do_output = False
+                if self.af == 4:
+                  if (((type(saddr) is nacaddr.IPv4) or (saddr == 'any')) and
+                      ((type(daddr) is nacaddr.IPv4) or (daddr == 'any'))):
+                    do_output = True
+                if self.af == 6:
+                  if (((type(saddr) is nacaddr.IPv6) or (saddr == 'any')) and
+                      ((type(daddr) is nacaddr.IPv6) or (daddr == 'any'))):
+                    do_output = True
+                if do_output:
+                  ret_str.extend(self._TermletToStr(
+                      self.filter_name,
+                      _ACTION_TABLE.get(str(self.term.action[0])),
+                      proto,
+                      saddr,
+                      sport,
+                      daddr,
+                      dport,
+                      icmp_type,
+                      self.options))
+
+    return '\n'.join(ret_str)
+
+  def _TermPortToProtocol(self, port_number, proto):
+    """Map a numeric port or ICMP type to its ASA keyword, if one exists."""
+    _ASA_PORTS_TCP = {
+        5190: "aol", 179: "bgp", 19: "chargen", 1494: "citrix-ica",
+        514: "cmd", 2748: "ctiqbe", 13: "daytime", 9: "discard",
+        53: "domain", 7: "echo", 512: "exec", 79: "finger", 21: "ftp",
+        20: "ftp-data", 70: "gopher", 443: "https", 1720: "h323",
+        101: "hostname", 113: "ident", 143: "imap4", 194: "irc",
+        750: "kerberos", 543: "klogin", 544: "kshell", 389: "ldap",
+        636: "ldaps", 515: "lpd", 513: "login", 1352: "lotusnotes",
+        139: "netbios-ssn", 119: "nntp", 5631: "pcanywhere-data",
+        496: "pim-auto-rp", 109: "pop2", 110: "pop3", 1723: "pptp",
+        25: "smtp", 1521: "sqlnet", 22: "ssh", 111: "sunrpc",
+        49: "tacacs", 517: "talk", 23: "telnet", 540: "uucp",
+        43: "whois", 80: "www", 2049: "nfs"
+        }
+    _ASA_PORTS_UDP = {
+        512: "biff", 68: "bootpc", 67: "bootps", 9: "discard",
+        53: "domain", 195: "dnsix", 7: "echo", 500: "isakmp",
+        750: "kerberos", 434: "mobile-ip", 42: "nameserver",
+        137: "netbios-ns", 138: "netbios-dgm", 123: "ntp",
+        5632: "pcanywhere-status", 496: "pim-auto-rp", 1645: "radius",
+        1646: "radius-acct", 520: "rip", 5510: "secureid-udp",
+        161: "snmp", 162: "snmptrap", 111: "sunrpc", 514: "syslog",
+        49: "tacacs", 517: "talk", 69: "tftp", 37: "time", 513: "who",
+        177: "xdmcp", 2049: "nfs"
+        }
+    _ASA_TYPES_ICMP = {
+        6: "alternate-address", 31: "conversion-error", 8: "echo",
+        0: "echo-reply", 16: "information-reply", 15: "information-request",
+        18: "mask-reply", 17: "mask-request", 32: "mobile-redirect",
+        12: "parameter-problem", 5: "redirect", 9: "router-advertisement",
+        10: "router-solicitation", 4: "source-quench", 11: "time-exceeded",
+        14: "timestamp-reply", 13: "timestamp-request", 30: "traceroute",
+        3: "unreachable"
+        }
+
+    if proto == "tcp":
+      if port_number in _ASA_PORTS_TCP:
+        return _ASA_PORTS_TCP[port_number]
+    elif proto == "udp":
+      if port_number in _ASA_PORTS_UDP:
+        return _ASA_PORTS_UDP[port_number]
+    elif proto == "icmp":
+      if port_number in _ASA_TYPES_ICMP:
+        return _ASA_TYPES_ICMP[port_number]
+    return port_number
+
+  def _TermletToStr(self, filter_name, action, proto, saddr, sport, daddr,
+                    dport, icmp_type, option):
+    """Take the various components and turn them into a cisco acl line.
+
+    Args:
+      filter_name: name of the ACL to emit the line for
+      action: str, action
+      proto: str, protocol
+      saddr: str or ipaddr, source address
+      sport: str list or none, the source port
+      daddr: str or ipaddr, the destination address
+      dport: str list or none, the destination port
+      icmp_type: icmp-type numeric specification (if any)
+      option: list or none, optional, eg. 'logging' tokens.
+
+    Returns:
+      string of the cisco acl line, suitable for printing.
+ """ + + + # inet4 + if type(saddr) is nacaddr.IPv4 or type(saddr) is ipaddr.IPv4Network: + if saddr.numhosts > 1: + saddr = '%s %s' % (saddr.ip, saddr.netmask) + else: + saddr = 'host %s' % (saddr.ip) + if type(daddr) is nacaddr.IPv4 or type(daddr) is ipaddr.IPv4Network: + if daddr.numhosts > 1: + daddr = '%s %s' % (daddr.ip, daddr.netmask) + else: + daddr = 'host %s' % (daddr.ip) + # inet6 + if type(saddr) is nacaddr.IPv6 or type(saddr) is ipaddr.IPv6Network: + if saddr.numhosts > 1: + saddr = '%s/%s' % (saddr.ip, saddr.prefixlen) + else: + saddr = 'host %s' % (saddr.ip) + if type(daddr) is nacaddr.IPv6 or type(daddr) is ipaddr.IPv6Network: + if daddr.numhosts > 1: + daddr = '%s/%s' % (daddr.ip, daddr.prefixlen) + else: + daddr = 'host %s' % (daddr.ip) + + # fix ports + if not sport: + sport = '' + elif sport[0] != sport[1]: + sport = ' range %s %s' % (self._TermPortToProtocol(sport[0],proto), self._TermPortToProtocol(sport[1],proto)) + else: + sport = ' eq %s' % (self._TermPortToProtocol(sport[0],proto)) + + if not dport: + dport = '' + elif dport[0] != dport[1]: + dport = ' range %s %s' % (self._TermPortToProtocol(dport[0],proto), self._TermPortToProtocol(dport[1],proto)) + else: + dport = ' eq %s' % (self._TermPortToProtocol(dport[0],proto)) + + if not option: + option = [''] + + # Prevent UDP from appending 'established' to ACL line + sane_options = list(option) + if proto == 'udp' and 'established' in sane_options: + sane_options.remove('established') + + ret_lines = [] + + # str(icmp_type) is needed to ensure 0 maps to '0' instead of FALSE + icmp_type = str(self._TermPortToProtocol(icmp_type,"icmp")) + + ret_lines.append('access-list %s extended %s %s %s %s %s %s %s %s' % + (filter_name, action, proto, saddr, + sport, daddr, dport, + icmp_type, + ' '.join(sane_options) + )) + + # remove any trailing spaces and replace multiple spaces with singles + stripped_ret_lines = [re.sub('\s+', ' ', x).rstrip() for x in ret_lines] + return stripped_ret_lines + +# return 'access-list %s extended %s %s %s%s %s%s %s' % ( +# filter_name, action, proto, saddr, sport, daddr, dport, ' '.join(option)) + + +class CiscoASA(aclgenerator.ACLGenerator): + """A cisco ASA policy object.""" + + _PLATFORM = 'ciscoasa' + _DEFAULT_PROTOCOL = 'ip' + _SUFFIX = '.asa' + + _OPTIONAL_SUPPORTED_KEYWORDS = set(['expiration', + 'logging', + 'owner', + ]) + + def _TranslatePolicy(self, pol, exp_info): + self.ciscoasa_policies = [] + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + + for header, terms in self.policy.filters: + filter_options = header.FilterOptions('ciscoasa') + filter_name = header.FilterName('ciscoasa') + + new_terms = [] + # now add the terms + for term in terms: + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s expires ' + 'in less than two weeks.', term.name, filter_name) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s is expired and ' + 'will not be rendered.', term.name, filter_name) + continue + + new_terms.append(str(Term(term,filter_name))) + + self.ciscoasa_policies.append((header, filter_name, new_terms)) + + def __str__(self): + target_header = [] + target = [] + + for (header, filter_name, terms) in self.ciscoasa_policies: + + target.append('clear configure access-list %s' % filter_name) + + # add the p4 tags + target.extend(aclgenerator.AddRepositoryTags('access-list %s remark ' + % filter_name)) + + # add a header comment if one exists + for 
comment in header.comment:
+        for line in comment.split('\n'):
+          target.append('access-list %s remark %s' % (filter_name, line))
+
+      # now add the terms
+      for term in terms:
+        target.append(str(term))
+
+      # end for header, filter_name, filter_type...
+    return '\n'.join(target)
+
diff --git a/lib/demo.py b/lib/demo.py
new file mode 100755
index 0000000..9f35b72
--- /dev/null
+++ b/lib/demo.py
@@ -0,0 +1,241 @@
+#!/usr/bin/python
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Demo generator for capirca."""
+
+__author__ = 'robankeny@google.com (Robert Ankeny)'
+
+
+import datetime
+import logging
+from lib import aclgenerator
+
+
+class Term(aclgenerator.Term):
+  """Used to create an individual term.
+
+  The __str__ method must be implemented.
+
+  Args: term policy.Term object
+
+  This is created to be a demo.
+  """
+  _ACTIONS = {'accept': 'allow',
+              'deny': 'discard',
+              'reject': 'say go away to',
+              'next': 'pass it onto the next term',
+              'reject-with-tcp-rst': 'reset'
+             }
+
+  def __init__(self, term, term_type):
+    self.term = term
+    self.term_type = term_type
+
+  def __str__(self):
+    # Verify platform specific terms. Skip whole term if platform does not
+    # match.
+    if self.term.platform:
+      if 'demo' not in self.term.platform:
+        return ''
+    if self.term.platform_exclude:
+      if 'demo' in self.term.platform_exclude:
+        return ''
+
+    ret_str = []
+
+    #NAME
+    ret_str.append(' ' * 4 + 'Term: ' + self.term.name + '{')
+
+    #COMMENTS
+    if self.term.comment:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + '#COMMENTS')
+      for comment in self.term.comment:
+        for line in comment.split('\n'):
+          ret_str.append(' ' * 8 + '#' + line)
+
+    #SOURCE ADDRESS
+    source_address = self.term.GetAddressOfVersion(
+        'source_address', self.AF_MAP.get(self.term_type))
+    source_address_exclude = self.term.GetAddressOfVersion(
+        'source_address_exclude', self.AF_MAP.get(self.term_type))
+    if source_address:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Source IP\'s')
+      for saddr in source_address:
+        ret_str.append(' ' * 8 + str(saddr))
+
+    #SOURCE ADDRESS EXCLUDE
+    if source_address_exclude:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Excluded Source IP\'s')
+      # iterate the exclude list itself, not the source address list
+      for ex in source_address_exclude:
+        ret_str.append(' ' * 8 + str(ex))
+
+    #SOURCE PORT
+    if self.term.source_port:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Source ports')
+      ret_str.append(' ' * 8 + self._Group(self.term.source_port))
+
+    #DESTINATION
+    destination_address = self.term.GetAddressOfVersion(
+        'destination_address', self.AF_MAP.get(self.term_type))
+    destination_address_exclude = self.term.GetAddressOfVersion(
+        'destination_address_exclude', self.AF_MAP.get(self.term_type))
+    if destination_address:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Destination IP\'s')
+      for daddr in destination_address:
+        ret_str.append(' ' * 8 + str(daddr))
+
+    #DESTINATION ADDRESS EXCLUDE
+    if destination_address_exclude:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Excluded Destination IP\'s')
+      for ex in destination_address_exclude:
+        ret_str.append(' ' * 8 + str(ex))
+
+    #DESTINATION PORT
+    if self.term.destination_port:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Destination Ports')
+      ret_str.append(' ' * 8 + self._Group(self.term.destination_port))
+
+    #PROTOCOL
+    if self.term.protocol:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Protocol')
+      ret_str.append(' ' * 8 + self._Group(self.term.protocol))
+
+    #OPTION
+    if self.term.option:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Options')
+      for option in self.term.option:
+        ret_str.append(' ' * 8 + option)
+
+    #ACTION
+    for action in self.term.action:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Action: ' +
+                     self._ACTIONS.get(str(action)) + ' all traffic')
+    return '\n '.join(ret_str)
+
+  def _Group(self, group):
+    def _FormattedGroup(el):
+      if isinstance(el, str):
+        return el.lower()
+      elif isinstance(el, int):
+        return str(el)
+      elif el[0] == el[1]:
+        return '%d' % el[0]
+      else:
+        return '%d-%d' % (el[0], el[1])
+    if len(group) > 1:
+      # format every element; joining str(item[0]) would mangle port
+      # ranges and reduce string protocols to their first character
+      rval = ' '.join([_FormattedGroup(item) for item in group])
+    else:
+      rval = _FormattedGroup(group[0])
+    return rval
+
+
+class Demo(aclgenerator.ACLGenerator):
+  """Demo rendering class.
+
+  This class takes a policy object and renders output into
+  a syntax which is not usable by routers. This class should
+  only be used for testing and understanding how to create a
+  generator of your own.
+
+  Args:
+    pol: policy.Policy object
+
+  Steps to implement this library:
+    1) Import library in aclgen.py
+    2) Create a 3 letter entry in the table in the render_filters
+       function for the demo library and set it to False
+    3) In the "for header in policy.headers" loop, use the previous entry
+       to add an if statement to create a deep copy of the
+       policy object
+    4) Create an if statement that, when that specific
+       policy object is present, passes the policy file
+       onto the demo Class.
+ 5) The returned object can be then printed to a file using the + do_output_filter function + 6) Create a policy file with a target set to use demo + """ + _PLATFORM = 'demo' + _SUFFIX = '.demo' + + _OPTIONAL_SUPPORTED_KEYWORDS = set(['expiration',]) + + def _TranslatePolicy(self, pol, exp_info): + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + self.demo_policies = [] + for header, terms in pol.filters: + if not self._PLATFORM in header.platforms: + continue + filter_options = header.FilterOptions('demo') + filter_name = filter_options[0] + if len(filter_options) > 1: + interface_specific = filter_options[1] + else: + interface_specific = 'none' + filter_type = 'inet' + term_names = set() + new_terms = [] + for term in terms: + if term.name in term_names: + raise DemoFilterError('Duplicate term name') + term_names.add(term.name) + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s expires ' + 'in less than two weeks.', term.name, filter_name) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s is expired and ' + 'will not be rendered.', term.name, filter_name) + continue + new_terms.append(Term(term, filter_type)) + self.demo_policies.append((header, filter_name, filter_type, + interface_specific, new_terms)) + + def __str__(self): + target = [] + for (header, filter_name, filter_type, + interface_specific, terms) in self.demo_policies: + target.append('Header {') + target.append(' ' * 4 + 'Name: %s {' % filter_name) + target.append(' ' * 8 + 'Type: %s ' % filter_type) + for comment in header.comment: + for line in comment.split('\n'): + target.append(' ' * 8 + 'Comment: %s'%line) + target.append(' ' * 8 + 'Family type: %s'%interface_specific) + target.append(' ' * 4 +'}') + for term in terms: + target.append(str(term)) + target.append(' ' * 4 +'}') + target.append(' ') + target.append('}') + return '\n'.join(target) + + +class Error(Exception): + pass + +class DemoFilterError(Error): + pass diff --git a/lib/html.py b/lib/html.py new file mode 100755 index 0000000..5fb0bb1 --- /dev/null +++ b/lib/html.py @@ -0,0 +1,233 @@ +#!/usr/bin/python +# +# Copyright 2015 NORDUnet A/S All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""HTML generator for capirca.""" + +__author__ = 'lundberg@nordu.net (Johan Lundberg)' + + +import datetime +import logging +from lib import aclgenerator + + +class Term(aclgenerator.Term): + """Used to create an individual term. + + The __str__ method must be implemented. + + Args: term policy.Term object + + """ + _ACTIONS = {'accept': 'allow', + 'deny': 'discard', + 'reject': 'say go away to', + 'next': 'pass it onto the next term', + 'reject-with-tcp-rst': 'reset' + } + + def __init__ (self, term, term_type): + self.term = term + self.term_type = term_type + + def __str__(self): + # Verify platform specific terms. Skip whole term if platform does not + # match. 
+    if self.term.platform:
+      if 'html' not in self.term.platform:
+        return ''
+    if self.term.platform_exclude:
+      if 'html' in self.term.platform_exclude:
+        return ''
+
+    ret_str = []
+
+    #NAME
+    ret_str.append(' ' * 4 + 'Term: ' + self.term.name + '{')
+
+    #COMMENTS
+    if self.term.comment:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + '#COMMENTS')
+      for comment in self.term.comment:
+        for line in comment.split('\n'):
+          ret_str.append(' ' * 8 + '#' + line)
+
+    #SOURCE ADDRESS
+    source_address = self.term.GetAddressOfVersion(
+        'source_address', self.AF_MAP.get(self.term_type))
+    source_address_exclude = self.term.GetAddressOfVersion(
+        'source_address_exclude', self.AF_MAP.get(self.term_type))
+    if source_address:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Source IP\'s')
+      for saddr in source_address:
+        ret_str.append(' ' * 8 + str(saddr))
+
+    #SOURCE ADDRESS EXCLUDE
+    if source_address_exclude:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Excluded Source IP\'s')
+      # iterate the exclude list itself, not the source address list
+      for ex in source_address_exclude:
+        ret_str.append(' ' * 8 + str(ex))
+
+    #SOURCE PORT
+    if self.term.source_port:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Source ports')
+      ret_str.append(' ' * 8 + self._Group(self.term.source_port))
+
+    #DESTINATION
+    destination_address = self.term.GetAddressOfVersion(
+        'destination_address', self.AF_MAP.get(self.term_type))
+    destination_address_exclude = self.term.GetAddressOfVersion(
+        'destination_address_exclude', self.AF_MAP.get(self.term_type))
+    if destination_address:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Destination IP\'s')
+      for daddr in destination_address:
+        ret_str.append(' ' * 8 + str(daddr))
+
+    #DESTINATION ADDRESS EXCLUDE
+    if destination_address_exclude:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Excluded Destination IP\'s')
+      for ex in destination_address_exclude:
+        ret_str.append(' ' * 8 + str(ex))
+
+    #DESTINATION PORT
+    if self.term.destination_port:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Destination Ports')
+      ret_str.append(' ' * 8 + self._Group(self.term.destination_port))
+
+    #PROTOCOL
+    if self.term.protocol:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Protocol')
+      ret_str.append(' ' * 8 + self._Group(self.term.protocol))
+
+    #OPTION
+    if self.term.option:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Options')
+      for option in self.term.option:
+        ret_str.append(' ' * 8 + option)
+
+    #ACTION
+    for action in self.term.action:
+      ret_str.append(' ')
+      ret_str.append(' ' * 8 + 'Action: ' +
+                     self._ACTIONS.get(str(action)) + ' all traffic')
+    return '\n '.join(ret_str)
+
+  def _Group(self, group):
+    def _FormattedGroup(el):
+      if isinstance(el, str):
+        return el.lower()
+      elif isinstance(el, int):
+        return str(el)
+      elif el[0] == el[1]:
+        return '%d' % el[0]
+      else:
+        return '%d-%d' % (el[0], el[1])
+    if len(group) > 1:
+      # format every element; joining str(item[0]) would mangle port
+      # ranges and reduce string protocols to their first character
+      rval = ' '.join([_FormattedGroup(item) for item in group])
+    else:
+      rval = _FormattedGroup(group[0])
+    return rval
+
+
+class HTML(aclgenerator.ACLGenerator):
+  """HTML rendering class.
+
+  This class takes a policy object and renders output into
+  a syntax which is not usable by routers. This class should
+  only be used for visualizing or documenting policies.
+
+  Args:
+    pol: policy.Policy object
+
+  Steps to implement this library:
+    1) Import library in aclgen.py
+    2) Create a 3 letter entry in the table in the render_filters
+       function for the HTML library and set it to False
+    3) In the "for header in policy.headers" loop, use the previous entry
+       to add an if statement to create a deep copy of the
+       policy object
+    4) Create an if statement that, when that specific
+       policy object is present, passes the policy file
+       onto the HTML Class.
+    5) The returned object can then be printed to a file using the
+       do_output_filter function
+    6) Create a policy file with a target set to use HTML
+  """
+  _PLATFORM = 'html'
+  _SUFFIX = '.html'
+
+  _OPTIONAL_SUPPORTED_KEYWORDS = set(['expiration',])
+
+  def _TranslatePolicy(self, pol, exp_info):
+    current_date = datetime.date.today()
+    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)
+    self.html_policies = []
+    for header, terms in pol.filters:
+      if self._PLATFORM not in header.platforms:
+        continue
+      filter_options = header.FilterOptions('html')
+      filter_name = filter_options[0]
+      if len(filter_options) > 1:
+        interface_specific = filter_options[1]
+      else:
+        interface_specific = 'none'
+      filter_type = 'inet'
+      term_names = set()
+      new_terms = []
+      for term in terms:
+        if term.name in term_names:
+          raise HTMLFilterError('Duplicate term name')
+        term_names.add(term.name)
+
+        new_terms.append(Term(term, filter_type))
+      self.html_policies.append((header, filter_name, filter_type,
+                                 interface_specific, new_terms))
+
+  def __str__(self):
+    target = []
+    for (header, filter_name, filter_type,
+         interface_specific, terms) in self.html_policies:
+      target.append('Header {')
+      target.append(' ' * 4 + 'Name: %s {' % filter_name)
+      target.append(' ' * 8 + 'Type: %s ' % filter_type)
+      for comment in header.comment:
+        for line in comment.split('\n'):
+          target.append(' ' * 8 + 'Comment: %s' % line)
+      target.append(' ' * 8 + 'Family type: %s' % interface_specific)
+      target.append(' ' * 4 + '}')
+      for term in terms:
+        target.append(str(term))
+      target.append(' ' * 4 + '}')
+      target.append(' ')
+      target.append('}')
+    return '\n'.join(target)
+
+
+class Error(Exception):
+  pass
+
+
+class HTMLFilterError(Error):
+  pass
diff --git a/lib/ipset.py b/lib/ipset.py
new file mode 100644
index 0000000..2ff4fbb
--- /dev/null
+++ b/lib/ipset.py
@@ -0,0 +1,200 @@
+#!/usr/bin/python
+#
+# Copyright 2013 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Ipset iptables generator. This is a subclass of the Iptables generator.
+
+ipset is a system inside the Linux kernel which can very efficiently store
+and match IPv4 and IPv6 addresses. This can be used to dramatically increase
+the performance of an iptables firewall.
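+
+For illustration (editor's note, not part of the original patch): a set
+created here is later matched from the generated iptables output with a
+clause of the form
+  -m set --set <term-name>-src src
+mirroring what _GenerateAddressStatement below emits.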
+ +""" + +__author__ = 'vklimovs@google.com (Vjaceslavs Klimovs)' + +from string import Template + +import iptables +import nacaddr + + +class Error(Exception): + pass + + +class Term(iptables.Term): + """Single Ipset term representation.""" + + _PLATFORM = 'ipset' + _SET_MAX_LENGTH = 31 + _POSTJUMP_FORMAT = None + _PREJUMP_FORMAT = None + _TERM_FORMAT = None + _COMMENT_FORMAT = Template('-A $filter -m comment --comment "$comment"') + _FILTER_TOP_FORMAT = Template('-A $filter') + + def __init__(self, *args, **kwargs): + super(Term, self).__init__(*args, **kwargs) + # This stores tuples of set name and set contents, keyed by direction. + # For example: + # { 'src': ('term_name', [ipaddr object, ipaddr object]), + # 'dst': ('term_name', [ipaddr object, ipaddr object]) } + self.addr_sets = dict() + + def _CalculateAddresses(self, src_addr_list, src_ex_addr_list, + dst_addr_list, dst_ex_addr_list): + """Calculate source and destination address list for a term. + + Since ipset is very efficient at matching large number of + addresses, we never return eny exclude addresses. Instead + least positive match is calculated for both source and destination + addresses. + + For source and destination address list, three cases are possible. + First case is when there is no addresses. In that case we return + _all_ips. + Second case is when there is strictly one address. In that case, + we optimize by not generating a set, and it's then the only + element of returned set. + Third case case is when there is more than one address in a set. + In that case we generate a set and also return _all_ips. Note the + difference to the first case where no set is actually generated. + + Args: + src_addr_list: source address list of the term. + src_ex_addr_list: source address exclude list of the term. + dst_addr_list: destination address list of the term. + dst_ex_addr_list: destination address exclude list of the term. + + Returns: + tuple containing source address list, source exclude address list, + destination address list, destination exclude address list in + that order. + + """ + if not src_addr_list: + src_addr_list = [self._all_ips] + src_addr_list = [src_addr for src_addr in src_addr_list if + src_addr.version == self.AF_MAP[self.af]] + if src_ex_addr_list: + src_ex_addr_list = [src_ex_addr for src_ex_addr in src_ex_addr_list if + src_ex_addr.version == self.AF_MAP[self.af]] + src_addr_list = nacaddr.ExcludeAddrs(src_addr_list, src_ex_addr_list) + if len(src_addr_list) > 1: + set_name = self._GenerateSetName(self.term.name, 'src') + self.addr_sets['src'] = (set_name, src_addr_list) + src_addr_list = [self._all_ips] + + if not dst_addr_list: + dst_addr_list = [self._all_ips] + dst_addr_list = [dst_addr for dst_addr in dst_addr_list if + dst_addr.version == self.AF_MAP[self.af]] + if dst_ex_addr_list: + dst_ex_addr_list = [dst_ex_addr for dst_ex_addr in dst_ex_addr_list if + dst_ex_addr.version == self.AF_MAP[self.af]] + dst_addr_list = nacaddr.ExcludeAddrs(dst_addr_list, dst_ex_addr_list) + if len(dst_addr_list) > 1: + set_name = self._GenerateSetName(self.term.name, 'dst') + self.addr_sets['dst'] = (set_name, dst_addr_list) + dst_addr_list = [self._all_ips] + return (src_addr_list, [], dst_addr_list, []) + + def _GenerateAddressStatement(self, src_addr, dst_addr): + """Return the address section of an individual iptables rule. + + See _CalculateAddresses documentation. Three cases are possible here, + and they map directly to cases in _CalculateAddresses. 
+    The first case is when there are no addresses for a direction (the value
+    is then _all_ips); in that case we return an empty string.
+    The second case is when there is strictly one address; we then return a
+    single address match (-s or -d).
+    The third case is when the value is _all_ips but the set for that
+    direction is present. That is when we return a set match.
+
+    Args:
+      src_addr: source address of the rule.
+      dst_addr: destination address of the rule.
+
+    Returns:
+      tuple containing source and destination address statement, in
+      that order.
+
+    """
+    src_addr_stmt = ''
+    dst_addr_stmt = ''
+    if src_addr and dst_addr:
+      if src_addr == self._all_ips:
+        if 'src' in self.addr_sets:
+          src_addr_stmt = ('-m set --set %s src' % self.addr_sets['src'][0])
+      else:
+        src_addr_stmt = '-s %s/%d' % (src_addr.ip, src_addr.prefixlen)
+      if dst_addr == self._all_ips:
+        if 'dst' in self.addr_sets:
+          dst_addr_stmt = ('-m set --set %s dst' % self.addr_sets['dst'][0])
+      else:
+        dst_addr_stmt = '-d %s/%d' % (dst_addr.ip, dst_addr.prefixlen)
+    return (src_addr_stmt, dst_addr_stmt)
+
+  def _GenerateSetName(self, term_name, suffix):
+    if self.af == 'inet6':
+      suffix += '-v6'
+    if len(term_name) + len(suffix) + 1 > self._SET_MAX_LENGTH:
+      # the slice bound is negative here, so the term name is truncated by
+      # exactly the number of characters the combined name overshoots the limit
+      term_name = term_name[:self._SET_MAX_LENGTH -
+                            (len(term_name) + len(suffix) + 1)]
+    return term_name + '-' + suffix
+
+
+class Ipset(iptables.Iptables):
+  """Ipset generator."""
+
+  _PLATFORM = 'ipset'
+  _SET_TYPE = 'hash:net'
+  _SUFFIX = '.ips'
+  _TERM = Term
+
+  def __str__(self):
+    # Actual rendering happens in __str__, so it has to be called
+    # before we do the set-specific part.
+    iptables_output = iptables.Iptables.__str__(self)
+    output = []
+    for (_, _, _, _, terms) in self.iptables_policies:
+      for term in terms:
+        output.extend(self._GenerateSetConfig(term))
+    output.append(iptables_output)
+    return '\n'.join(output)
+
+  def _GenerateSetConfig(self, term):
+    """Generate set configuration for the supplied term.
+
+    Args:
+      term: input term.
+
+    Returns:
+      list of strings that make up the set configuration of the term.
+
+    """
+    output = []
+    for direction in sorted(term.addr_sets, reverse=True):
+      set_hashsize = 2 ** len(term.addr_sets[direction][1]).bit_length()
+      set_maxelem = 2 ** len(term.addr_sets[direction][1]).bit_length()
+      output.append('create %s %s family %s hashsize %i maxelem %i' %
+                    (term.addr_sets[direction][0],
+                     self._SET_TYPE,
+                     term.af,
+                     set_hashsize,
+                     set_maxelem))
+      for address in term.addr_sets[direction][1]:
+        output.append('add %s %s' % (term.addr_sets[direction][0], address))
+    return output
diff --git a/lib/iptables.py b/lib/iptables.py
new file mode 100644
index 0000000..d465c74
--- /dev/null
+++ b/lib/iptables.py
@@ -0,0 +1,789 @@
+#!/usr/bin/python
+#
+# Copyright 2010 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Iptables generator."""
+
+__author__ = 'watson@google.com (Tony Watson)'
+
+import datetime
+import logging
+import nacaddr
+import re
+from string import Template
+
+import aclgenerator
+
+
+class Term(aclgenerator.Term):
+  """Generate Iptables policy terms."""
+
+  # Validate that the term does not contain any fields we do not
+  # support. This prevents us from thinking that our output is
+  # correct in cases where we've omitted fields from the term.
+  _PLATFORM = 'iptables'
+  _POSTJUMP_FORMAT = None
+  _PREJUMP_FORMAT = Template('-A $filter -j $term')
+  _TERM_FORMAT = Template('-N $term')
+  _COMMENT_FORMAT = Template('-A $term -m comment --comment "$comment"')
+  _FILTER_TOP_FORMAT = Template('-A $term')
+  _ACTION_TABLE = {
+      'accept': '-j ACCEPT',
+      'deny': '-j DROP',
+      'reject': '-j REJECT --reject-with icmp-host-prohibited',
+      'reject-with-tcp-rst': '-j REJECT --reject-with tcp-reset',
+      'next': '-j RETURN'
+      }
+  _PROTO_TABLE = {
+      'icmpv6': '-p icmpv6',
+      'icmp': '-p icmp',
+      'tcp': '-p tcp',
+      'udp': '-p udp',
+      'all': '-p all',
+      'esp': '-p esp',
+      'ah': '-p ah',
+      'gre': '-p gre',
+      }
+  _TCP_FLAGS_TABLE = {
+      'syn': 'SYN',
+      'ack': 'ACK',
+      'fin': 'FIN',
+      'rst': 'RST',
+      'urg': 'URG',
+      'psh': 'PSH',
+      'all': 'ALL',
+      'none': 'NONE',
+      }
+  _KNOWN_OPTIONS_MATCHERS = {
+      # '! -f' also matches non-fragmented packets.
+      'first-fragment': '-m u32 --u32 4&0x3FFF=0x2000',
+      'initial': '--syn',
+      'tcp-initial': '--syn',
+      'sample': '',
+      }
+
+  def __init__(self, term, filter_name, trackstate, filter_action, af='inet'):
+    """Set up a new term.
+
+    Args:
+      term: A policy.Term object to represent in iptables.
+      filter_name: The name of the filter chain to attach the term to.
+      trackstate: Specifies if conntrack should be used for new connections.
+      filter_action: The default action of the filter.
+      af: Which address family ('inet' or 'inet6') to apply the term to.
+
+    Raises:
+      UnsupportedFilterError: Filter is not supported.
+    """
+    self.trackstate = trackstate
+    self.term = term  # term object
+    self.filter = filter_name  # actual name of filter
+    self.default_action = filter_action
+    self.options = []
+    self.af = af
+
+    if af == 'inet6':
+      self._all_ips = nacaddr.IPv6('::/0')
+      self._ACTION_TABLE['reject'] = '-j REJECT --reject-with adm-prohibited'
+    else:
+      self._all_ips = nacaddr.IPv4('0.0.0.0/0')
+      self._ACTION_TABLE['reject'] = ('-j REJECT --reject-with '
+                                      'icmp-host-prohibited')
+
+    self.term_name = '%s_%s' % (self.filter[:1], self.term.name)
+
+  def __str__(self):
+    # Verify platform specific terms. Skip whole term if platform does not
+    # match.
+    if self.term.platform:
+      if self._PLATFORM not in self.term.platform:
+        return ''
+    if self.term.platform_exclude:
+      if self._PLATFORM in self.term.platform_exclude:
+        return ''
+
+    ret_str = []
+
+    # Don't render icmpv6 protocol terms under inet, or icmp under inet6
+    if ((self.af == 'inet6' and 'icmp' in self.term.protocol) or
+        (self.af == 'inet' and 'icmpv6' in self.term.protocol)):
+      ret_str.append('# Term %s' % self.term.name)
+      ret_str.append('# not rendered due to protocol/AF mismatch.')
+      return '\n'.join(ret_str)
+
+    # Term verbatim output - this will skip over most normal term
+    # creation code by returning early.
Warnings provided in policy.py + if self.term.verbatim: + for next_verbatim in self.term.verbatim: + if next_verbatim.value[0] == self._PLATFORM: + ret_str.append(str(next_verbatim.value[1])) + return '\n'.join(ret_str) + + # We don't support these keywords for filtering, so unless users + # put in a "verbatim:: iptables" statement, any output we emitted + # would misleadingly suggest that we applied their filters. + # Instead, we fail loudly. + if self.term.ether_type: + raise UnsupportedFilterError('\n%s %s %s %s' % ( + 'ether_type unsupported by', self._PLATFORM, + '\nError in term', self.term.name)) + if self.term.address: + raise UnsupportedFilterError('\n%s %s %s %s %s' % ( + 'address unsupported by', self._PLATFORM, + '- specify source or dest', '\nError in term:', self.term.name)) + if self.term.port: + raise UnsupportedFilterError('\n%s %s %s %s %s' % ( + 'port unsupported by', self._PLATFORM, + '- specify source or dest', '\nError in term:', self.term.name)) + + # Create a new term + if self._TERM_FORMAT: + ret_str.append(self._TERM_FORMAT.substitute(term=self.term_name)) + + if self._PREJUMP_FORMAT: + ret_str.append(self._PREJUMP_FORMAT.substitute(filter=self.filter, + term=self.term_name)) + + if self.term.owner: + self.term.comment.append('Owner: %s' % self.term.owner) + # reformat long comments, if needed + # + # iptables allows individual comments up to 256 chars. + # But our generator will limit a single comment line to < 120, using: + # max = 119 - 27 (static chars in comment command) - [length of term name] + comment_max_width = 92 - len(self.term_name) + if comment_max_width < 40: + comment_max_width = 40 + comments = aclgenerator.WrapWords(self.term.comment, comment_max_width) + # append comments to output + if comments and comments[0]: + for line in comments: + if not line: + continue # iptables-restore does not like 0-length comments. + # term comments + ret_str.append(self._COMMENT_FORMAT.substitute(filter=self.filter, + term=self.term_name, + comment=str(line))) + + # if terms does not specify action, use filter default action + if not self.term.action: + self.term.action[0].value = self.default_action + + # Unsupported configuration; in the case of 'accept' or 'next', we + # skip the rule. In other cases, we blow up (raise an exception) + # to ensure that this is not considered valid configuration. 
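+    # Editor's illustrative note (not part of the original patch): a term
+    # with source_prefix and action 'deny' raises UnsupportedFilterError
+    # below, while the same term with action 'accept' is reduced to a single
+    # "# skipped ..." placeholder line in the output.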
+ if self.term.source_prefix or self.term.destination_prefix: + if str(self.term.action[0]) not in set(['accept', 'next']): + raise UnsupportedFilterError('%s %s %s %s %s %s %s %s' % ( + '\nTerm', self.term.name, 'has action', str(self.term.action[0]), + 'with source_prefix or destination_prefix,', + ' which is unsupported in', self._PLATFORM, 'iptables output.')) + return ('# skipped %s due to source or destination prefix rule' % + self.term.name) + + # protocol + if self.term.protocol: + protocol = self.term.protocol + else: + protocol = ['all'] + if self.term.protocol_except: + raise UnsupportedFilterError('%s %s %s' % ( + '\n', self.term.name, + 'protocol_except logic not currently supported.')) + + (term_saddr, exclude_saddr, + term_daddr, exclude_daddr) = self._CalculateAddresses( + self.term.source_address, self.term.source_address_exclude, + self.term.destination_address, self.term.destination_address_exclude) + if not term_saddr: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + direction='source', + af=self.af)) + return '' + if not term_daddr: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + direction='destination', + af=self.af)) + return '' + + # ports + source_port = [] + destination_port = [] + if self.term.source_port: + source_port = self.term.source_port + if self.term.destination_port: + destination_port = self.term.destination_port + + # icmp-types + icmp_types = [''] + if self.term.icmp_type: + icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type, protocol, + self.af) + + source_interface = '' + if self.term.source_interface: + source_interface = self.term.source_interface + + destination_interface = '' + if self.term.destination_interface: + destination_interface = self.term.destination_interface + + log_hits = False + if self.term.logging: + # Iptables sends logs to hosts configured syslog + log_hits = True + + # options + tcp_flags = [] + tcp_track_options = [] + for next_opt in [str(x) for x in self.term.option]: + # + # Sanity checking and high-ports are added as appropriate in + # pre-processing that is done in __str__ within class Iptables. + # Option established will add destination port high-ports if protocol + # contains only tcp, udp or both. This is done earlier in class Iptables. + # + if ((next_opt.find('established') == 0 or + next_opt.find('tcp-established') == 0) + and 'ESTABLISHED' not in [x.strip() for x in self.options]): + if next_opt.find('tcp-established') == 0 and protocol != ['tcp']: + raise TcpEstablishedError('%s %s %s' % ( + '\noption tcp-established can only be applied for proto tcp.', + '\nError in term:', self.term.name)) + + if self.trackstate: + # Use nf_conntrack to track state -- works with any proto + self.options.append('-m state --state ESTABLISHED,RELATED') + elif protocol == ['tcp']: + # Simple established-only rule for TCP: Must have ACK field + # (SYN/ACK or subsequent ACK), or RST and no other flags. 
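+          # Editor's illustration (not part of the original patch): each
+          # tuple is (check-flags, set-flags); _FormatPart renders them as
+          #   --tcp-flags ACK ACK
+          #   --tcp-flags ACK,FIN,RST,SYN RST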
+ tcp_track_options = [(['ACK'], ['ACK']), + (['SYN', 'FIN', 'ACK', 'RST'], ['RST'])] + + # Iterate through flags table, and create list of tcp-flags to append + for next_flag in self._TCP_FLAGS_TABLE: + if next_opt.find(next_flag) == 0: + tcp_flags.append(self._TCP_FLAGS_TABLE.get(next_flag)) + if next_opt in self._KNOWN_OPTIONS_MATCHERS: + self.options.append(self._KNOWN_OPTIONS_MATCHERS[next_opt]) + if self.term.packet_length: + # Policy format is "#-#", but iptables format is "#:#" + self.options.append('-m length --length %s' % + self.term.packet_length.replace('-', ':')) + if self.term.fragment_offset: + self.options.append('-m u32 --u32 4&0x1FFF=%s' % + self.term.fragment_offset.replace('-', ':')) + + for saddr in exclude_saddr: + ret_str.extend(self._FormatPart( + '', saddr, '', '', '', '', '', '', '', '', '', '', + self._ACTION_TABLE.get('next'))) + for daddr in exclude_daddr: + ret_str.extend(self._FormatPart( + '', '', '', daddr, '', '', '', '', '', '', '', '', + self._ACTION_TABLE.get('next'))) + + for saddr in term_saddr: + for daddr in term_daddr: + for icmp in icmp_types: + for proto in protocol: + for tcp_matcher in tcp_track_options or (([], []),): + ret_str.extend(self._FormatPart( + str(proto), + saddr, + source_port, + daddr, + destination_port, + self.options, + tcp_flags, + icmp, + tcp_matcher, + source_interface, + destination_interface, + log_hits, + self._ACTION_TABLE.get(str(self.term.action[0])) + )) + + if self._POSTJUMP_FORMAT: + ret_str.append(self._POSTJUMP_FORMAT.substitute(filter=self.filter, + term=self.term_name)) + + return '\n'.join(str(v) for v in ret_str if v is not '') + + def _CalculateAddresses(self, term_saddr, exclude_saddr, + term_daddr, exclude_daddr): + """Calculate source and destination address list for a term. + + Args: + term_saddr: source address list of the term + exclude_saddr: source address exclude list of the term + term_daddr: destination address list of the term + exclude_daddr: destination address exclude list of the term + + Returns: + tuple containing source address list, source exclude address list, + destination address list, destination exclude address list in + that order + + """ + # source address + term_saddr_excluded = [] + if not term_saddr: + term_saddr = [self._all_ips] + if exclude_saddr: + term_saddr_excluded.extend(nacaddr.ExcludeAddrs(term_saddr, + exclude_saddr)) + + # destination address + term_daddr_excluded = [] + if not term_daddr: + term_daddr = [self._all_ips] + if exclude_daddr: + term_daddr_excluded.extend(nacaddr.ExcludeAddrs(term_daddr, + exclude_daddr)) + + # Just to be safe, always have a result of at least 1 to avoid * by zero + # returning incorrect results (10src*10dst=100, but 10src*0dst=0, not 10) + bailout_count = len(exclude_saddr) + len(exclude_daddr) + ( + (len(self.term.source_address) or 1) * + (len(self.term.destination_address) or 1)) + exclude_count = ((len(term_saddr_excluded) or 1) * + (len(term_daddr_excluded) or 1)) + + # Use bailout jumps for excluded addresses if it results in fewer output + # lines than nacaddr.ExcludeAddrs() method. + if exclude_count < bailout_count: + exclude_saddr = [] + exclude_daddr = [] + if term_saddr_excluded: + term_saddr = term_saddr_excluded + if term_daddr_excluded: + term_daddr = term_daddr_excluded + + # With many sources and destinations, iptables needs to generate the + # cartesian product of sources and destinations. If there are no + # exclude rules, this can instead be written as exclude [0/0 - + # srcs], exclude [0/0 - dsts]. 
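+    # Editor's illustrative note (hypothetical numbers): 20 IPv4 sources and
+    # 20 IPv4 destinations would otherwise expand to 400 rules; inverting one
+    # side into bailout excludes can bring that down to roughly 20 rules plus
+    # the size of the inverted exclude list.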
+ v4_src_count = len([x for x in term_saddr if x.version == 4]) + v4_dst_count = len([x for x in term_daddr if x.version == 4]) + v6_src_count = len([x for x in term_saddr if x.version == 6]) + v6_dst_count = len([x for x in term_daddr if x.version == 6]) + num_pairs = v4_src_count * v4_dst_count + v6_src_count * v6_dst_count + if num_pairs > 100: + new_exclude_source = nacaddr.ExcludeAddrs([self._all_ips], term_saddr) + new_exclude_dest = nacaddr.ExcludeAddrs([self._all_ips], term_daddr) + # Invert the shortest list that does not already have exclude addresses + if len(new_exclude_source) < len(new_exclude_dest) and not exclude_saddr: + if len(new_exclude_source) + len(term_daddr) < num_pairs: + exclude_saddr = new_exclude_source + term_saddr = [self._all_ips] + elif not exclude_daddr: + if len(new_exclude_dest) + len(term_saddr) < num_pairs: + exclude_daddr = new_exclude_dest + term_daddr = [self._all_ips] + term_saddr = [x for x in term_saddr + if x.version == self.AF_MAP[self.af]] + exclude_saddr = [x for x in exclude_saddr + if x.version == self.AF_MAP[self.af]] + term_daddr = [x for x in term_daddr + if x.version == self.AF_MAP[self.af]] + exclude_daddr = [x for x in exclude_daddr + if x.version == self.AF_MAP[self.af]] + return (term_saddr, exclude_saddr, term_daddr, exclude_daddr) + + def _FormatPart(self, protocol, saddr, sport, daddr, dport, options, + tcp_flags, icmp_type, track_flags, sint, dint, log_hits, + action): + """Compose one iteration of the term parts into a string. + + Args: + protocol: The network protocol + saddr: Source IP address + sport: Source port numbers + daddr: Destination IP address + dport: Destination port numbers + options: Optional arguments to append to our rule + tcp_flags: Which tcp_flag arguments, if any, should be appended + icmp_type: What icmp protocol to allow, if any + track_flags: A tuple of ([check-flags], [set-flags]) arguments to tcp-flag + sint: Optional source interface + dint: Optional destination interface + log_hits: Boolean, to log matches or not + action: What should happen if this rule matches + Returns: + rval: A single iptables argument line + """ + src, dst = self._GenerateAddressStatement(saddr, daddr) + + filter_top = self._FILTER_TOP_FORMAT.substitute(filter=self.filter, + term=self.term_name) + + source_int = '' + if sint: + source_int = '-i %s' % sint + + destination_int = '' + if dint: + destination_int = '-o %s' % dint + + log_jump = '' + if log_hits: + log_jump = '-j LOG --log-prefix %s ' % self.term.name + + if not options: + options = [] + + proto = self._PROTO_TABLE.get(str(protocol)) + # Don't drop protocol if we don't recognize it + if protocol and not proto: + proto = '-p %s' % str(protocol) + + # set conntrack state to NEW, unless policy requested "nostate" + if self.trackstate: + already_stateful = False + # we will add new stateful arguments only if none already exist, such + # as from "option:: established" + for option in options: + if 'state' in option: + already_stateful = True + if not already_stateful: + if 'ACCEPT' in action: + # We have to permit established/related since a policy may not + # have an existing blank permit for established/related, which + # may be more efficient, but slightly less secure. 
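+          # Editor's illustration (not part of the original patch): the match
+          # appended here lands in the final rule as, e.g.,
+          #   ... -m state --state NEW,ESTABLISHED,RELATED -j ACCEPT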
+ options.append('-m state --state NEW,ESTABLISHED,RELATED') + + if tcp_flags or (track_flags and track_flags[0]): + check_fields = ','.join(sorted(set(tcp_flags + track_flags[0]))) + set_fields = ','.join(sorted(set(tcp_flags + track_flags[1]))) + flags = '--tcp-flags %s %s' % (check_fields, set_fields) + else: + flags = '' + + icmp_type = str(icmp_type) + if not icmp_type: + icmp = '' + elif str(protocol) == 'icmpv6': + icmp = '--icmpv6-type %s' % icmp_type + else: + icmp = '--icmp-type %s' % icmp_type + + # format tcp and udp ports + sports = dports = [''] + if sport: + sports = self._GeneratePortStatement(sport, source=True) + if dport: + dports = self._GeneratePortStatement(dport, dest=True) + + ret_lines = [] + for sport in sports: + for dport in dports: + rval = [filter_top] + if re.search('multiport', sport) and not re.search('multiport', dport): + # Due to bug in iptables, use of multiport module before a single + # port specification will result in multiport trying to consume it. + # this is a little hack to ensure single ports are listed before + # any multiport specification. + dport, sport = sport, dport + for value in (proto, flags, sport, dport, icmp, src, dst, + ' '.join(options), source_int, destination_int): + if value: + rval.append(str(value)) + if log_jump: + # -j LOG + ret_lines.append(' '.join(rval+[log_jump])) + # -j ACTION + ret_lines.append(' '.join(rval+[action])) + return ret_lines + + def _GenerateAddressStatement(self, saddr, daddr): + """Return the address section of an individual iptables rule. + + Args: + saddr: source address of the rule + daddr: destination address of the rule + + Returns: + tuple containing source and destination address statement, in + that order + + """ + src = '' + dst = '' + if not saddr or saddr == self._all_ips: + src = '' + else: + src = '-s %s/%d' % (saddr.ip, saddr.prefixlen) + if not daddr or daddr == self._all_ips: + dst = '' + else: + dst = '-d %s/%d' % (daddr.ip, daddr.prefixlen) + return (src, dst) + + def _GeneratePortStatement(self, ports, source=False, dest=False): + """Return the 'port' section of an individual iptables rule. + + Args: + ports: list of ports or port ranges (pairs) + source: (bool) generate a source port rule + dest: (bool) generate a dest port rule + + Returns: + list holding the 'port' sections of an iptables rule. + + Raises: + BadPortsError: if too many ports are passed in, or if both 'source' + and 'dest' are true. + NotImplementedError: if both 'source' and 'dest' are true. + """ + if not ports: + return '' + + direction = '' # default: no direction / '--port'. As yet, unused. + if source and dest: + raise BadPortsError('_GeneratePortStatement called ambiguously.') + elif source: + direction = 's' # source port / '--sport' + elif dest: + direction = 'd' # dest port / '--dport' + else: + raise NotImplementedError('--port support not yet implemented.') + + # Normalize ports and get accurate port count. 
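+    # Editor's illustration (not part of the original patch): a pair (80, 80)
+    # normalizes to "80" below, while (1024, 65535) becomes "1024:65535".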
+ # iptables multiport module limits to 15, but we use 14 to ensure a range + # doesn't tip us over the limit + max_ports = 14 + norm_ports = [] + portstrings = [] + count = 0 + for port in ports: + if port[0] == port[1]: + norm_ports.append(str(port[0])) + count += 1 + else: + norm_ports.append('%d:%d' % (port[0], port[1])) + count += 2 + if count >= max_ports: + count = 0 + portstrings.append('-m multiport --%sports %s' % (direction, + ','.join(norm_ports))) + norm_ports = [] + if len(norm_ports) == 1: + portstrings.append('--%sport %s' % (direction, norm_ports[0])) + else: + portstrings.append('-m multiport --%sports %s' % (direction, + ','.join(norm_ports))) + return portstrings + + +class Iptables(aclgenerator.ACLGenerator): + """Generates filters and terms from provided policy object.""" + + _PLATFORM = 'iptables' + _DEFAULT_PROTOCOL = 'all' + _SUFFIX = '' + _RENDER_PREFIX = None + _RENDER_SUFFIX = None + _DEFAULTACTION_FORMAT = '-P %s %s' + _DEFAULT_ACTION = 'DROP' + _TERM = Term + _TERM_MAX_LENGTH = 24 + _OPTIONAL_SUPPORTED_KEYWORDS = set(['counter', + 'destination_interface', + 'destination_prefix', # skips these terms + 'expiration', + 'fragment_offset', + 'logging', + 'owner', + 'packet_length', + 'policer', # safely ignored + 'qos', + 'routing_instance', # safe to skip + 'source_interface', + 'source_prefix', # skips these terms + ]) + + def _TranslatePolicy(self, pol, exp_info): + """Translate a policy from objects into strings.""" + self.iptables_policies = [] + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + + default_action = None + good_default_actions = ['ACCEPT', 'DROP'] + good_filters = ['INPUT', 'OUTPUT', 'FORWARD'] + good_afs = ['inet', 'inet6'] + good_options = ['nostate', 'abbreviateterms', 'truncateterms'] + all_protocols_stateful = True + + for header, terms in pol.filters: + filter_type = None + if self._PLATFORM not in header.platforms: + continue + + filter_options = header.FilterOptions(self._PLATFORM)[1:] + filter_name = header.FilterName(self._PLATFORM) + + if filter_name not in good_filters: + logging.warn('Filter is generating a non-standard chain that will not ' + 'apply to traffic unless linked from INPUT, OUTPUT or ' + 'FORWARD filters. New chain name is: %s', filter_name) + + # ensure all options after the filter name are expected + for opt in filter_options: + if opt not in good_default_actions + good_afs + good_options: + raise UnsupportedTargetOption('%s %s %s %s' % ( + '\nUnsupported option found in', self._PLATFORM, + 'target definition:', opt)) + + # disable stateful? + if 'nostate' in filter_options: + all_protocols_stateful = False + + # Check for matching af + for address_family in good_afs: + if address_family in filter_options: + # should not specify more than one AF in options + if filter_type is not None: + raise UnsupportedFilterError('%s %s %s %s' % ( + '\nMay only specify one of', good_afs, 'in filter options:', + filter_options)) + filter_type = address_family + if filter_type is None: + filter_type = 'inet' + + if self._PLATFORM == 'iptables' and filter_name == 'FORWARD': + default_action = 'DROP' + + # does this policy override the default filter actions? 
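+      # Editor's illustrative note (hypothetical policy header): a target
+      # line such as
+      #   target:: iptables INPUT DROP
+      # selects chain INPUT and sets default_action to DROP below.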
+ for next_target in header.target: + if next_target.platform == self._PLATFORM: + if len(next_target.options) > 1: + for arg in next_target.options: + if arg in good_default_actions: + default_action = arg + if default_action and default_action not in good_default_actions: + raise UnsupportedDefaultAction('%s %s %s %s %s' % ( + '\nOnly', ', '.join(good_default_actions), + 'default filter action allowed;', default_action, 'used.')) + + # add the terms + new_terms = [] + term_names = set() + for term in terms: + term.name = self.FixTermLength(term.name, + 'abbreviateterms' in filter_options, + 'truncateterms' in filter_options) + if term.name in term_names: + raise aclgenerator.DuplicateTermError( + 'You have a duplicate term: %s' % term.name) + term_names.add(term.name) + + term = self.FixHighPorts(term, af=filter_type, + all_protocols_stateful=all_protocols_stateful) + if not term: + continue + + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s expires ' + 'in less than two weeks.', term.name, filter_name) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s is expired and ' + 'will not be rendered.', term.name, filter_name) + continue + + new_terms.append(self._TERM(term, filter_name, all_protocols_stateful, + default_action, filter_type)) + + self.iptables_policies.append((header, filter_name, filter_type, + default_action, new_terms)) + + def __str__(self): + target = [] + pretty_platform = '%s%s' % (self._PLATFORM[0].upper(), self._PLATFORM[1:]) + + if self._RENDER_PREFIX: + target.append(self._RENDER_PREFIX) + + for (header, filter_name, filter_type, default_action, terms + ) in self.iptables_policies: + # Add comments for this filter + target.append('# %s %s Policy' % (pretty_platform, + header.FilterName(self._PLATFORM))) + + # reformat long text comments, if needed + comments = aclgenerator.WrapWords(header.comment, 70) + if comments and comments[0]: + for line in comments: + target.append('# %s' % line) + target.append('#') + # add the p4 tags + target.extend(aclgenerator.AddRepositoryTags('# ')) + target.append('# ' + filter_type) + + # always specify the default filter states for speedway, + # if default action policy not specified for iptables, do nothing. 
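+      # Editor's illustration (not part of the original patch): for speedway
+      # this emits e.g. "-P INPUT DROP" even when the header names no default.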
+ if self._PLATFORM == 'speedway': + if not default_action: + target.append(self._DEFAULTACTION_FORMAT % (filter_name, + self._DEFAULT_ACTION)) + if default_action: + target.append(self._DEFAULTACTION_FORMAT % (filter_name, + default_action)) + # add the terms + for term in terms: + term_str = str(term) + if term_str: + target.append(term_str) + + if self._RENDER_SUFFIX: + target.append(self._RENDER_SUFFIX) + + target.append('') + return '\n'.join(target) + + +class Error(Exception): + """Base error class.""" + + +class BadPortsError(Error): + """Too many ports for a single iptables statement.""" + + +class UnsupportedFilterError(Error): + """Raised when we see an inappropriate filter.""" + + +class NoIptablesPolicyError(Error): + """Raised when a policy is received that doesn't support iptables.""" + + +class TcpEstablishedError(Error): + """Raised when a term has tcp-established option but not proto tcp only.""" + + +class EstablishedError(Error): + """Raised when a term has established option with inappropriate protocol.""" + + +class UnsupportedDefaultAction(Error): + """Raised when a filter has an impermissible default action specified.""" + + +class UnsupportedTargetOption(Error): + """Raised when a filter has an impermissible default action specified.""" diff --git a/lib/juniper.py b/lib/juniper.py new file mode 100644 index 0000000..f793f34 --- /dev/null +++ b/lib/juniper.py @@ -0,0 +1,727 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__author__ = ['pmoody@google.com (Peter Moody)', + 'watson@google.com (Tony Watson)'] + + +import datetime +import logging + +import aclgenerator +import nacaddr + + +# generic error class +class Error(Exception): + pass + + +class JuniperTermPortProtocolError(Error): + pass + + +class TcpEstablishedWithNonTcp(Error): + pass + + +class JuniperDuplicateTermError(Error): + pass + + +class UnsupportedFilterError(Error): + pass + + +class PrecedenceError(Error): + pass + + +class JuniperIndentationError(Error): + pass + + +class Config(object): + """Config allows a configuration to be assembled easily. + + Configurations are automatically indented following Juniper's style. + A textual representation of the config can be extracted with str(). + + Attributes: + indent: The number of leading spaces on the current line. + tabstop: The number of spaces to indent for a new level. + """ + + def __init__(self, indent=0, tabstop=4): + self.indent = indent + self._initial_indent = indent + self.tabstop = tabstop + self.lines = [] + + def __str__(self): + if self.indent != self._initial_indent: + raise JuniperIndentationError( + 'Expected indent %d but got %d' % (self._initial_indent, self.indent)) + return '\n'.join(self.lines) + + def Append(self, line, verbatim=False): + """Append one line to the configuration. + + Args: + line: The string to append to the config. + verbatim: append line without adjusting indentation. Default False. 
+ Raises: + JuniperIndentationError: If the indentation would be further left + than the initial indent. e.g. too many close braces. + """ + if verbatim: + self.lines.append(line) + return + + if line.endswith('}'): + self.indent -= self.tabstop + if self.indent < self._initial_indent: + raise JuniperIndentationError('Too many close braces.') + spaces = ' ' * self.indent + self.lines.append(spaces + line.strip()) + if line.endswith(' {'): + self.indent += self.tabstop + + +class Term(aclgenerator.Term): + """Representation of an individual Juniper term. + + This is mostly useful for the __str__() method. + + Args: + term: policy.Term object + term_type: the address family for the term, one of "inet", "inet6", + or "bridge" + """ + _DEFAULT_INDENT = 12 + _ACTIONS = {'accept': 'accept', + 'deny': 'discard', + 'reject': 'reject', + 'next': 'next term', + 'reject-with-tcp-rst': 'reject tcp-reset'} + + # the following lookup table is used to map between the various types of + # filters the juniper generator can render. As new differences are + # encountered, they should be added to this table. Accessing members + # of this table looks like: + # self._TERM_TYPE('inet').get('saddr') -> 'source-address' + # + # it's critical that the members of each filter type be the same, that is + # to say that if _TERM_TYPE.get('inet').get('foo') returns something, + # _TERM_TYPE.get('inet6').get('foo') must return the inet6 equivalent. + _TERM_TYPE = {'inet': {'addr': 'address', + 'saddr': 'source-address', + 'daddr': 'destination-address', + 'protocol': 'protocol', + 'protocol-except': 'protocol-except', + 'tcp-est': 'tcp-established'}, + 'inet6': {'addr': 'address', + 'saddr': 'source-address', + 'daddr': 'destination-address', + 'protocol': 'next-header', + 'protocol-except': 'next-header-except', + 'tcp-est': 'tcp-established'}, + 'bridge': {'addr': 'ip-address', + 'saddr': 'ip-source-address', + 'daddr': 'ip-destination-address', + 'protocol': 'ip-protocol', + 'protocol-except': 'ip-protocol-except', + 'tcp-est': 'tcp-flags "(ack|rst)"'}} + + def __init__(self, term, term_type): + self.term = term + self.term_type = term_type + + if term_type not in self._TERM_TYPE: + raise ValueError('Unknown Filter Type: %s' % term_type) + + # some options need to modify the actions + self.extra_actions = [] + + # TODO(pmoody): get rid of all of the default string concatenation here. + # eg, indent(8) + 'foo;' -> '%s%s;' % (indent(8), 'foo'). pyglint likes this + # more. + def __str__(self): + # Verify platform specific terms. Skip whole term if platform does not + # match. + if self.term.platform: + if 'juniper' not in self.term.platform: + return '' + if self.term.platform_exclude: + if 'juniper' in self.term.platform_exclude: + return '' + + config = Config(indent=self._DEFAULT_INDENT) + from_str = [] + + # Don't render icmpv6 protocol terms under inet, or icmp under inet6 + if ((self.term_type == 'inet6' and 'icmp' in self.term.protocol) or + (self.term_type == 'inet' and 'icmpv6' in self.term.protocol)): + config.Append('/* Term %s' % self.term.name) + config.Append('** not rendered due to protocol/AF mismatch.') + config.Append('*/') + return str(config) + + # comment + # this deals just fine with multi line comments, but we could probably + # output them a little cleaner; do things like make sure the + # len(output) < 80, etc. 
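+    # For example, a term with comment 'Allow DNS.' and owner 'netops'
+    # is rendered by the block below as:
+    #
+    #   /*
+    #   ** Allow DNS.
+    #   ** Owner: netops
+    #   */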
+ if self.term.owner: + self.term.comment.append('Owner: %s' % self.term.owner) + if self.term.comment: + config.Append('/*') + for comment in self.term.comment: + for line in comment.split('\n'): + config.Append('** ' + line) + config.Append('*/') + + # Term verbatim output - this will skip over normal term creation + # code. Warning generated from policy.py if appropriate. + if self.term.verbatim: + for next_term in self.term.verbatim: + if next_term.value[0] == 'juniper': + config.Append(str(next_term.value[1]), verbatim=True) + return str(config) + + # Helper for per-address-family keywords. + family_keywords = self._TERM_TYPE.get(self.term_type) + + # option + # this is going to be a little ugly b/c there are a few little messed + # up options we can deal with. + if self.term.option: + for opt in [str(x) for x in self.term.option]: + # there should be a better way to search the array of protocols + if opt.startswith('sample'): + self.extra_actions.append('sample') + + # only append tcp-established for option established when + # tcp is the only protocol, otherwise other protos break on juniper + elif opt.startswith('established'): + if self.term.protocol == ['tcp']: + if 'tcp-established;' not in from_str: + from_str.append(family_keywords['tcp-est'] + ';') + + # if tcp-established specified, but more than just tcp is included + # in the protocols, raise an error + elif opt.startswith('tcp-established'): + flag = family_keywords['tcp-est'] + ';' + if self.term.protocol == ['tcp']: + if flag not in from_str: + from_str.append(flag) + else: + raise TcpEstablishedWithNonTcp( + 'tcp-established can only be used with tcp protocol in term %s' + % self.term.name) + elif opt.startswith('rst'): + from_str.append('tcp-flags "rst";') + elif opt.startswith('initial') and 'tcp' in self.term.protocol: + from_str.append('tcp-initial;') + elif opt.startswith('first-fragment'): + from_str.append('first-fragment;') + + # we don't have a special way of dealing with this, so we output it and + # hope the user knows what they're doing. 
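+        # Summary of the option handling above (illustrative):
+        #   'sample'                              -> 'sample' extra action
+        #   'established'     + protocol ['tcp']  -> tcp-established;
+        #   'tcp-established' + protocol ['tcp']  -> tcp-established;
+        #   'tcp-established' + other protocols   -> TcpEstablishedWithNonTcp
+        #   'rst'                                  -> tcp-flags "rst";
+        #   anything else falls through to the catch-all below.
+        # (under a 'bridge' filter the tcp-est keyword from _TERM_TYPE is
+        # 'tcp-flags "(ack|rst)"' instead)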
+ else: + from_str.append('%s;' % opt) + + # term name + config.Append('term %s {' % self.term.name) + + # a default action term doesn't have any from { clause + has_match_criteria = (self.term.address or + self.term.destination_address or + self.term.destination_prefix or + self.term.destination_port or + self.term.precedence or + self.term.protocol or + self.term.protocol_except or + self.term.port or + self.term.source_address or + self.term.source_prefix or + self.term.source_port or + self.term.ether_type or + self.term.traffic_type) + + if has_match_criteria: + config.Append('from {') + + term_af = self.AF_MAP.get(self.term_type) + + # address + address = self.term.GetAddressOfVersion('address', term_af) + if address: + config.Append('%s {' % family_keywords['addr']) + for addr in address: + config.Append('%s;%s' % (addr, self._Comment(addr))) + config.Append('}') + elif self.term.address: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + af=self.term_type)) + return '' + + # source address + source_address, source_address_exclude = self._MinimizePrefixes( + self.term.GetAddressOfVersion('source_address', term_af), + self.term.GetAddressOfVersion('source_address_exclude', term_af)) + + if source_address: + config.Append('%s {' % family_keywords['saddr']) + for addr in source_address: + config.Append('%s;%s' % (addr, self._Comment(addr))) + for addr in source_address_exclude: + config.Append('%s except;%s' % ( + addr, self._Comment(addr, exclude=True))) + config.Append('}') + elif self.term.source_address: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + direction='source', + af=self.term_type)) + return '' + + # destination address + destination_address, destination_address_exclude = self._MinimizePrefixes( + self.term.GetAddressOfVersion('destination_address', term_af), + self.term.GetAddressOfVersion('destination_address_exclude', term_af)) + + if destination_address: + config.Append('%s {' % family_keywords['daddr']) + for addr in destination_address: + config.Append('%s;%s' % (addr, self._Comment(addr))) + for addr in destination_address_exclude: + config.Append('%s except;%s' % ( + addr, self._Comment(addr, exclude=True))) + config.Append('}') + elif self.term.destination_address: + logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name, + direction='destination', + af=self.term_type)) + return '' + + # source prefix list + if self.term.source_prefix: + config.Append('source-prefix-list {') + for pfx in self.term.source_prefix: + config.Append(pfx + ';') + config.Append('}') + + # destination prefix list + if self.term.destination_prefix: + config.Append('destination-prefix-list {') + for pfx in self.term.destination_prefix: + config.Append(pfx + ';') + config.Append('}') + + # protocol + if self.term.protocol: + config.Append(family_keywords['protocol'] + + ' ' + self._Group(self.term.protocol)) + + # protocol + if self.term.protocol_except: + config.Append(family_keywords['protocol-except'] + ' ' + + self._Group(self.term.protocol_except)) + + # port + if self.term.port: + config.Append('port %s' % self._Group(self.term.port)) + + # source port + if self.term.source_port: + config.Append('source-port %s' % self._Group(self.term.source_port)) + + # destination port + if self.term.destination_port: + config.Append('destination-port %s' % + self._Group(self.term.destination_port)) + + # append any options beloging in the from {} section + for next_str in from_str: + config.Append(next_str) + + # packet length + if 
self.term.packet_length: + config.Append('packet-length %s;' % self.term.packet_length) + + # fragment offset + if self.term.fragment_offset: + config.Append('fragment-offset %s;' % self.term.fragment_offset) + + # icmp-types + icmp_types = [''] + if self.term.icmp_type: + icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type, + self.term.protocol, self.term_type) + if icmp_types != ['']: + config.Append('icmp-type %s' % self._Group(icmp_types)) + + if self.term.ether_type: + config.Append('ether-type %s' % + self._Group(self.term.ether_type)) + + if self.term.traffic_type: + config.Append('traffic-type %s' % + self._Group(self.term.traffic_type)) + + if self.term.precedence: + # precedence may be a single integer, or a space separated list + policy_precedences = set() + # precedence values may only be 0 through 7 + for precedence in self.term.precedence: + if int(precedence) in range(0, 8): + policy_precedences.add(precedence) + else: + raise PrecedenceError('Precedence value %s is out of bounds in %s' % + (precedence, self.term.name)) + config.Append('precedence %s' % self._Group(sorted(policy_precedences))) + + config.Append('}') # end from { ... } + + #### + # ACTIONS go below here + #### + config.Append('then {') + # logging + if self.term.logging: + for log_target in self.term.logging: + if str(log_target) == 'local': + config.Append('log;') + else: + config.Append('syslog;') + + if self.term.routing_instance: + config.Append('routing-instance %s;' % self.term.routing_instance) + + if self.term.counter: + config.Append('count %s;' % self.term.counter) + + if self.term.policer: + config.Append('policer %s;' % self.term.policer) + + if self.term.qos: + config.Append('forwarding-class %s;' % self.term.qos) + + if self.term.loss_priority: + config.Append('loss-priority %s;' % self.term.loss_priority) + + for action in self.extra_actions: + config.Append(action + ';') + + # If there is a routing-instance defined, skip reject/accept/etc actions. + if not self.term.routing_instance: + for action in self.term.action: + config.Append(self._ACTIONS.get(action) + ';') + + config.Append('}') # end then{...} + config.Append('}') # end term accept-foo-to-bar { ... } + + return str(config) + + def _MinimizePrefixes(self, include, exclude): + """Calculate a minimal set of prefixes for Juniper match conditions. + + Args: + include: Iterable of nacaddr objects, prefixes to match. + exclude: Iterable of nacaddr objects, prefixes to exclude. + Returns: + A tuple (I,E) where I and E are lists containing the minimized + versions of include and exclude, respectively. The order + of each input list is preserved. + """ + # Remove any included prefixes that have EXACT matches in the + # excluded list. Excluded prefixes take precedence on the router + # regardless of the order in which the include/exclude are applied. + exclude_set = set(exclude) + include_result = [ip for ip in include if ip not in exclude_set] + + # Every address match condition on a Juniper firewall filter + # contains an implicit "0/0 except" or "0::0/0 except". If an + # excluded prefix is not contained within any less-specific prefix + # in the included set, we can elide it. In other words, if the + # next-less-specific prefix is the implicit "default except", + # there is no need to configure the more specific "except". + # + # TODO(kbrint): this could be made more efficient with a Patricia trie. 
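+    #
+    # Worked example (illustrative):
+    #   include: [10.0.0.0/8, 192.168.0.0/24]
+    #   exclude: [10.1.0.0/16, 172.16.0.0/12, 192.168.0.0/24]
+    # yields
+    #   include_result: [10.0.0.0/8]   (192.168.0.0/24 exactly excluded)
+    #   exclude_result: [10.1.0.0/16]  (172.16.0.0/12 elided; no remaining
+    #                                   include prefix contains it)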
+ exclude_result = [] + for exclude_prefix in exclude: + for include_prefix in include_result: + if exclude_prefix in include_prefix: + exclude_result.append(exclude_prefix) + break + + return include_result, exclude_result + + def _Comment(self, addr, exclude=False, line_length=132): + """Returns address comment field if it exists. + + Args: + addr: nacaddr.IPv4 object (?) + exclude: bool - address excludes have different indentations + line_length: integer - this is the length to which a comment will be + truncated, no matter what. ie, a 1000 character comment will be + truncated to line_length, and then split. if 0, the whole comment + is kept. the current default of 132 is somewhat arbitrary. + + Returns: + string + + Notes: + This method tries to intelligently split long comments up. if we've + managed to summarize 4 /32's into a /30, each with a nacaddr text field + of something like 'foobar N', normal concatination would make the + resulting rendered comment look in mondrian like + + source-address { + ... + 1.1.1.0/30; /* foobar1, foobar2, foobar3, foo + bar4 */ + + b/c of the line splitting at 80 chars. this method will split the + comments at word breaks and make the previous example look like + + source-address { + .... + 1.1.1.0/30; /* foobar1, foobar2, foobar3, + ** foobar4 */ + much cleaner. + """ + rval = [] + # indentation, for multi-line comments, ensures that subsquent lines + # are correctly alligned with the first line of the comment. + indentation = 0 + if exclude: + # len('1.1.1.1/32 except;') == 21 + indentation = 21 + self._DEFAULT_INDENT + len(str(addr)) + else: + # len('1.1.1.1/32;') == 14 + indentation = 14 + self._DEFAULT_INDENT + len(str(addr)) + + # length_eol is the width of the line; b/c of the addition of the space + # and the /* characters, it needs to be a little less than the actual width + # to keep from wrapping + length_eol = 77 - indentation + + if isinstance(addr, (nacaddr.IPv4, nacaddr.IPv6)): + if addr.text: + + if line_length == 0: + # line_length of 0 means that we don't want to truncate the comment. + line_length = len(addr.text) + + # There should never be a /* or */, but be safe and ignore those + # comments + if addr.text.find('/*') >= 0 or addr.text.find('*/') >= 0: + logging.debug('Malformed comment [%s] ignoring', addr.text) + else: + + text = addr.text[:line_length] + + comment = ' /*' + while text: + # split the line + if len(text) > length_eol: + new_length_eol = text[:length_eol].rfind(' ') + if new_length_eol <= 0: + new_length_eol = length_eol + else: + new_length_eol = length_eol + + # what line am I gunna output? + line = comment + ' ' + text[:new_length_eol].strip() + # truncate what's left + text = text[new_length_eol:] + # setup the comment and indentation for the next go-round + comment = ' ' * indentation + '**' + + rval.append(line) + + rval[-1] += ' */' + else: + # should we be paying attention to any other addr type? + logging.debug('Ignoring non IPv4 or IPv6 address: %s', addr) + return '\n'.join(rval) + + def _Group(self, group): + """If 1 item return it, else return [ item1 item2 ]. + + Args: + group: a list. could be a list of strings (protocols) or a list of + tuples (ports) + + Returns: + rval: a string surrounded by '[' and '];' if len(group) > 1 + or with just ';' appended if len(group) == 1 + """ + + def _FormattedGroup(el): + """Return the actual formatting of an individual element. 
+ + Args: + el: either a string (protocol) or a tuple (ports) + + Returns: + string: either the lower()'ed string or the ports, hyphenated + if they're a range, or by itself if it's not. + """ + if isinstance(el, str): + return el.lower() + elif isinstance(el, int): + return str(el) + # type is a tuple below here + elif el[0] == el[1]: + return '%d' % el[0] + else: + return '%d-%d' % (el[0], el[1]) + + if len(group) > 1: + rval = '[ ' + ' '.join([_FormattedGroup(x) for x in group]) + ' ];' + else: + rval = _FormattedGroup(group[0]) + ';' + return rval + + +class Juniper(aclgenerator.ACLGenerator): + """JCL rendering class. + + This class takes a policy object and renders the output into a syntax + which is understood by juniper routers. + + Args: + pol: policy.Policy object + """ + + _PLATFORM = 'juniper' + _DEFAULT_PROTOCOL = 'ip' + _SUPPORTED_AF = set(('inet', 'inet6', 'bridge')) + _SUFFIX = '.jcl' + + _OPTIONAL_SUPPORTED_KEYWORDS = set(['address', + 'counter', + 'destination_prefix', + 'ether_type', + 'expiration', + 'fragment_offset', + 'logging', + 'loss_priority', + 'owner', + 'packet_length', + 'policer', + 'port', + 'precedence', + 'protocol_except', + 'qos', + 'routing_instance', + 'source_prefix', + 'traffic_type', + ]) + + def _TranslatePolicy(self, pol, exp_info): + self.juniper_policies = [] + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + + for header, terms in pol.filters: + if self._PLATFORM not in header.platforms: + continue + + filter_options = header.FilterOptions(self._PLATFORM) + filter_name = header.FilterName(self._PLATFORM) + + # Checks if the non-interface-specific option was specified. + # I'm assuming that it will be specified as maximum one time, and + # don't check for more appearances of the word in the options. + interface_specific = 'not-interface-specific' not in filter_options[1:] + + # Remove the option so that it is not confused with a filter type + if not interface_specific: + filter_options.remove('not-interface-specific') + + # default to inet4 filters + filter_type = 'inet' + if len(filter_options) > 1: + filter_type = filter_options[1] + + term_names = set() + new_terms = [] + for term in terms: + term.name = self.FixTermLength(term.name) + if term.name in term_names: + raise JuniperDuplicateTermError('You have multiple terms named: %s' % + term.name) + term_names.add(term.name) + + term = self.FixHighPorts(term, af=filter_type) + if not term: + continue + + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s expires ' + 'in less than two weeks.', term.name, filter_name) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s is expired and ' + 'will not be rendered.', term.name, filter_name) + continue + + new_terms.append(Term(term, filter_type)) + + self.juniper_policies.append((header, filter_name, filter_type, + interface_specific, new_terms)) + + def __str__(self): + config = Config() + + for (header, filter_name, filter_type, interface_specific, terms + ) in self.juniper_policies: + # add the header information + config.Append('firewall {') + config.Append('family %s {' % filter_type) + config.Append('replace:') + config.Append('/*') + + # we want the acl to contain id and date tags, but p4 will expand + # the tags here when we submit the generator, so we have to trick + # p4 into not knowing these words. like taking c-a-n-d-y from a + # baby. 
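+      # AddRepositoryTags('** ') returns the repository tag lines with the
+      # given prefix, e.g. (sketch) '** $Id:$' and '** $Date:$'; the
+      # revision control system expands them once the rendered filter
+      # itself is submitted.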
+ for line in aclgenerator.AddRepositoryTags('** '): + config.Append(line) + config.Append('**') + + for comment in header.comment: + for line in comment.split('\n'): + config.Append('** ' + line) + config.Append('*/') + + config.Append('filter %s {' % filter_name) + if interface_specific: + config.Append('interface-specific;') + + for term in terms: + term_str = str(term) + if term_str: + config.Append(term_str, verbatim=True) + + config.Append('}') # filter { ... } + config.Append('}') # family inet { ... } + config.Append('}') # firewall { ... } + + return str(config) + '\n' diff --git a/lib/junipersrx.py b/lib/junipersrx.py new file mode 100644 index 0000000..c2e0676 --- /dev/null +++ b/lib/junipersrx.py @@ -0,0 +1,448 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""SRX generator.""" +# pylint: disable-msg=W0231 + +__author__ = 'robankeny@google.com (Robert Ankeny)' + +import collections +import datetime +import logging + +import aclgenerator +import nacaddr + + +class Error(Exception): + """generic error class.""" + + +class UnsupportedFilterError(Error): + pass + + +class UnsupportedHeader(Error): + pass + + +class SRXDuplicateTermError(Error): + pass + + +class SRXVerbatimError(Error): + pass + + +class SRXOptionError(Error): + pass + + +class Term(aclgenerator.Term): + """Representation of an individual SRX term. + + This is mostly useful for the __str__() method. + + Args: + obj: a policy.Term object + term_type: type of filter to generate, e.g. inet or inet6 + filter_options: list of remaining target options (zones) + """ + + _ACTIONS = {'accept': 'permit', + 'deny': 'deny', + 'reject': 'reject', + 'count': 'count', + 'log': 'log'} + + def __init__(self, term, term_type, zones): + self.term = term + self.term_type = term_type + self.from_zone = zones[1] + self.to_zone = zones[3] + self.extra_actions = [] + + def __str__(self): + """Render config output from this term object.""" + # Verify platform specific terms. Skip whole term if platform does not + # match. 
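+    #
+    # Shape of the policy this method renders (illustrative names):
+    #
+    #   policy allow-dns {
+    #       match {
+    #           source-address any;
+    #           destination-address [ DNS_SERVERS ];
+    #           application allow-dns-app;
+    #       }
+    #       then {
+    #           permit;
+    #       }
+    #   }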
+ if self.term.platform: + if 'srx' not in self.term.platform: + return '' + if self.term.platform_exclude: + if 'srx' in self.term.platform_exclude: + return '' + ret_str = [] + + # COMMENTS + comment_max_width = 68 + if self.term.owner: + self.term.comment.append('Owner: %s' % self.term.owner) + comments = aclgenerator.WrapWords(self.term.comment, comment_max_width) + if comments and comments[0]: + ret_str.append(JuniperSRX.INDENT * 3 + '/*') + for line in comments: + ret_str.append(JuniperSRX.INDENT * 3 + line) + ret_str.append(JuniperSRX.INDENT * 3 + '*/') + + ret_str.append(JuniperSRX.INDENT * 3 + 'policy ' + self.term.name + ' {') + ret_str.append(JuniperSRX.INDENT * 4 + 'match {') + + # SOURCE-ADDRESS + if self.term.source_address: + saddr_check = set() + for saddr in self.term.source_address: + saddr_check.add(saddr.parent_token) + saddr_check = sorted(saddr_check) + source_address_string = '' + for addr in saddr_check: + source_address_string += addr + ' ' + ret_str.append(JuniperSRX.INDENT * 5 + 'source-address [ ' + + source_address_string + '];') + else: + ret_str.append(JuniperSRX.INDENT * 5 + 'source-address any;') + + # DESTINATION-ADDRESS + if self.term.destination_address: + daddr_check = [] + for daddr in self.term.destination_address: + daddr_check.append(daddr.parent_token) + daddr_check = set(daddr_check) + daddr_check = list(daddr_check) + daddr_check.sort() + destination_address_string = '' + for addr in daddr_check: + destination_address_string += addr + ' ' + ret_str.append(JuniperSRX.INDENT * 5 + 'destination-address [ ' + + destination_address_string + '];') + else: + ret_str.append(JuniperSRX.INDENT * 5 + 'destination-address any;') + + # APPLICATION + if (not self.term.source_port and not self.term.destination_port and not + self.term.icmp_type and not self.term.protocol): + ret_str.append(JuniperSRX.INDENT * 5 + 'application any;') + else: + ret_str.append(JuniperSRX.INDENT * 5 + 'application ' + self.term.name + + '-app;') + + ret_str.append(JuniperSRX.INDENT * 4 + '}') + + # ACTIONS + for action in self.term.action: + ret_str.append(JuniperSRX.INDENT * 4 + 'then {') + ret_str.append(JuniperSRX.INDENT * 5 + self._ACTIONS.get( + str(action)) + ';') + + # LOGGING + if self.term.logging: + ret_str.append(JuniperSRX.INDENT * 5 + 'log {') + ret_str.append(JuniperSRX.INDENT * 6 + 'session-init;') + ret_str.append(JuniperSRX.INDENT * 5 + '}') + ret_str.append(JuniperSRX.INDENT * 4 + '}') + + ret_str.append(JuniperSRX.INDENT * 3 + '}') + + # OPTIONS + if self.term.option: + raise SRXOptionError('Options are not implemented yet, please remove ' + + 'from term %s' % self.term.name) + + # VERBATIM + if self.term.verbatim: + raise SRXVerbatimError('Verbatim is not implemented, please remove ' + + 'the offending term %s.' % self.term.name) + return '\n'.join(ret_str) + + def _Group(self, group): + """If 1 item return it, else return [ item1 item2 ]. + + Args: + group: a list. could be a list of strings (protocols) or a list of + tuples (ports) + + Returns: + rval: a string surrounded by '[' and '];' if len(group) > 1 + or with just ';' appended if len(group) == 1 + """ + + def _FormattedGroup(el): + """Return the actual formatting of an individual element. + + Args: + el: either a string (protocol) or a tuple (ports) + + Returns: + string: either the lower()'ed string or the ports, hyphenated + if they're a range, or by itself if it's not. 
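+
+      Examples (illustrative): 'TCP' -> 'tcp', (80, 80) -> '80',
+      (1024, 65535) -> '1024-65535'.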
+ """ + if isinstance(el, str): + return el.lower() + elif isinstance(el, int): + return str(el) + # type is a tuple below here + elif el[0] == el[1]: + return '%d' % el[0] + else: + return '%d-%d' % (el[0], el[1]) + + if len(group) > 1: + rval = '[ ' + ' '.join([_FormattedGroup(x) for x in group]) + ' ];' + else: + rval = _FormattedGroup(group[0]) + ';' + return rval + + +class JuniperSRX(aclgenerator.ACLGenerator): + """SRX rendering class. + + This class takes a policy object and renders the output into a syntax + which is understood by SRX firewalls. + + Args: + pol: policy.Policy object + """ + + _PLATFORM = 'srx' + _SUFFIX = '.srx' + _SUPPORTED_AF = set(('inet',)) + _OPTIONAL_SUPPORTED_KEYWORDS = set(['expiration', + 'logging', + 'owner', + 'routing_instance', # safe to skip + 'timeout' + ]) + INDENT = ' ' + + def _TranslatePolicy(self, pol, exp_info): + """Transform a policy object into a JuniperSRX object. + + Args: + pol: policy.Policy object + exp_info: print a info message when a term is set to expire + in that many weeks + + Raises: + UnsupportedFilterError: An unsupported filter was specified + UnsupportedHeader: A header option exists that is not understood/usable + SRXDuplicateTermError: Two terms were found with same name in same filter + """ + self.srx_policies = [] + self.addressbook = collections.OrderedDict() + self.applications = [] + self.ports = [] + self.from_zone = '' + self.to_zone = '' + + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + + for header, terms in pol.filters: + if self._PLATFORM not in header.platforms: + continue + + filter_options = header.FilterOptions(self._PLATFORM) + + if (len(filter_options) < 4 or filter_options[0] != 'from-zone' or + filter_options[2] != 'to-zone'): + raise UnsupportedFilterError( + 'SRX filter arguments must specify from-zone and to-zone.') + self.from_zone = filter_options[1] + self.to_zone = filter_options[3] + + if len(filter_options) > 4: + filter_type = filter_options[4] + else: + filter_type = 'inet' + if filter_type not in self._SUPPORTED_AF: + raise UnsupportedHeader( + 'SRX Generator currently does not support %s as a header option' % + (filter_type)) + + term_dup_check = set() + new_terms = [] + for term in terms: + term.name = self.FixTermLength(term.name) + if term.name in term_dup_check: + raise SRXDuplicateTermError('You have a duplicate term: %s' + % term.name) + term_dup_check.add(term.name) + + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s>%s expires ' + 'in less than two weeks.', term.name, self.from_zone, + self.to_zone) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s>%s is expired.', + term.name, self.from_zone, self.to_zone) + + for i in term.source_address_exclude: + term.source_address = nacaddr.RemoveAddressFromList( + term.source_address, i) + for i in term.destination_address_exclude: + term.destination_address = nacaddr.RemoveAddressFromList( + term.destination_address, i) + + for addr in term.source_address: + self._BuildAddressBook(self.from_zone, addr) + for addr in term.destination_address: + self._BuildAddressBook(self.to_zone, addr) + + new_term = Term(term, filter_type, filter_options) + new_terms.append(new_term) + tmp_icmptype = new_term.NormalizeIcmpTypes( + term.icmp_type, term.protocol, filter_type) + # NormalizeIcmpTypes returns [''] for empty, convert to [] for eval + normalized_icmptype = tmp_icmptype if tmp_icmptype != [''] else [] + 
# rewrites the protocol icmpv6 to icmp6 + if 'icmpv6' in term.protocol: + protocol = list(term.protocol) + protocol[protocol.index('icmpv6')] = 'icmp6' + else: + protocol = term.protocol + self.applications.append({'sport': self._BuildPort(term.source_port), + 'dport': self._BuildPort( + term.destination_port), + 'name': term.name, + 'protocol': protocol, + 'icmp-type': normalized_icmptype, + 'timeout': term.timeout}) + self.srx_policies.append((header, new_terms, filter_options)) + + def _BuildAddressBook(self, zone, address): + """Create the address book configuration entries. + + Args: + zone: the zone these objects will reside in + address: a naming library address object + """ + if zone not in self.addressbook: + self.addressbook[zone] = collections.OrderedDict() + if address.parent_token not in self.addressbook[zone]: + self.addressbook[zone][address.parent_token] = [] + name = address.parent_token + for ip in self.addressbook[zone][name]: + if str(address) == str(ip[0]): + return + counter = len(self.addressbook[zone][address.parent_token]) + name = '%s_%s' % (name, str(counter)) + self.addressbook[zone][address.parent_token].append((address, name)) + + def _BuildPort(self, ports): + """Transform specified ports into list and ranges. + + Args: + ports: a policy terms list of ports + + Returns: + port_list: list of ports and port ranges + """ + port_list = [] + for i in ports: + if i[0] == i[1]: + port_list.append(str(i[0])) + else: + port_list.append('%s-%s' % (str(i[0]), str(i[1]))) + return port_list + + def __str__(self): + """Render the output of the JuniperSRX policy into config.""" + target = [] + target.append('security {') + target.append(self.INDENT + 'zones {') + for zone in self.addressbook: + target.append(self.INDENT * 2 + 'security-zone ' + zone + ' {') + target.append(self.INDENT * 3 + 'replace: address-book {') + for group in self.addressbook[zone]: + for address, name in self.addressbook[zone][group]: + target.append(self.INDENT * 4 + 'address ' + name + ' ' + + str(address) + ';') + for group in self.addressbook[zone]: + target.append(self.INDENT * 4 + 'address-set ' + group + ' {') + for address, name in self.addressbook[zone][group]: + target.append(self.INDENT * 5 + 'address ' + name + ';') + + target.append(self.INDENT * 4 + '}') + target.append(self.INDENT * 3 + '}') + target.append(self.INDENT * 2 + '}') + target.append(self.INDENT + '}') + + target.append(self.INDENT + 'replace: policies {') + + target.append(self.INDENT * 2 + '/*') + target.extend(aclgenerator.AddRepositoryTags(self.INDENT * 2)) + target.append(self.INDENT * 2 + '*/') + + for (_, terms, filter_options) in self.srx_policies: + target.append(self.INDENT * 2 + 'from-zone ' + filter_options[1] + + ' to-zone ' + filter_options[3] + ' {') + for term in terms: + target.append(str(term)) + target.append(self.INDENT * 2 +'}') + target.append(self.INDENT + '}') + target.append('}') + + # APPLICATIONS + target.append('replace: applications {') + done_apps = [] + for app in self.applications: + app_list = [] + if app in done_apps: + continue + if app['protocol'] or app['sport'] or app['dport'] or app['icmp-type']: + if app['icmp-type']: + target.append(self.INDENT + 'application ' + app['name'] + '-app {') + if app['timeout']: + timeout = app['timeout'] + else: + timeout = 60 + for i, code in enumerate(app['icmp-type']): + target.append( + self.INDENT * 2 + + 'term t%d protocol icmp icmp-type %s inactivity-timeout %d;' % + (i+1, str(code), int(timeout))) + else: + i = 1 + target.append(self.INDENT + + 
+                        'application-set ' + app['name'] + '-app {')
+
+          for proto in (app['protocol'] or ['']):
+            for sport in (app['sport'] or ['']):
+              for dport in (app['dport'] or ['']):
+                chunks = []
+                if proto: chunks.append(' protocol %s' % proto)
+                if sport: chunks.append(' source-port %s' % sport)
+                if dport: chunks.append(' destination-port %s' % dport)
+                if app['timeout']:
+                  chunks.append(' inactivity-timeout %d' % int(app['timeout']))
+                if chunks:
+                  target.append(self.INDENT * 2 +
+                                'application ' + app['name'] + '-app%d;' % i)
+                  app_list.append(self.INDENT + 'application ' + app['name'] +
+                                  '-app%d {' % i)
+                  app_list.append(self.INDENT * 2 + 'term t%d' % i +
+                                  ''.join(chunks) + ';')
+                  app_list.append(self.INDENT + '}')
+                  i += 1
+        target.append(self.INDENT + '}')
+        done_apps.append(app)
+        if app_list:
+          target.extend(app_list)
+
+    target.append('}')
+    return '\n'.join(target)
diff --git a/lib/nacaddr.py b/lib/nacaddr.py
new file mode 100644
index 0000000..fc06f17
--- /dev/null
+++ b/lib/nacaddr.py
@@ -0,0 +1,250 @@
+#!/usr/bin/python
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A subclass of the ipaddr library that includes comments for ipaddr objects."""
+
+__author__ = 'watson@google.com (Tony Watson)'
+
+from third_party import ipaddr
+
+def IP(ipaddress, comment='', token=''):
+  """Take an ip string and return an object of the correct type.
+
+  Args:
+    ipaddress: the ip address string.
+    comment: optional comment field.
+    token: optional token name where this address was extracted from.
+
+  Returns:
+    ipaddr.IPv4 or ipaddr.IPv6 object or raises ValueError.
+
+  Raises:
+    ValueError: if the string passed isn't either a v4 or a v6 address.
+
+  Notes:
+    this is sort of a poor-man's factory method.
+  """
+  a = ipaddr.IPNetwork(ipaddress)
+  if a.version == 4:
+    return IPv4(ipaddress, comment, token)
+  elif a.version == 6:
+    return IPv6(ipaddress, comment, token)
+
+class IPv4(ipaddr.IPv4Network):
+  """This subclass allows us to keep text comments related to each object."""
+
+  def __init__(self, ip_string, comment='', token=''):
+    ipaddr.IPv4Network.__init__(self, ip_string)
+    self.text = comment
+    self.token = token
+    self.parent_token = token
+
+  def AddComment(self, comment=''):
+    """Append comment to self.text, comma separated.
+
+    Don't add the comment if it's the same as self.text.
+
+    Args: comment
+    """
+    if self.text:
+      if comment and comment not in self.text:
+        self.text += ', ' + comment
+    else:
+      self.text = comment
+
+  def supernet(self, prefixlen_diff=1):
+    """Override ipaddr.IPv4 supernet so we can maintain comments.
+
+    See ipaddr.IPv4.Supernet for complete documentation.
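+
+    Example (illustrative):
+      IPv4('10.1.1.0/24', comment='corp').supernet() returns
+      IPv4('10.1.0.0/23') with the 'corp' comment and token carried over.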
+ """ + if self.prefixlen == 0: + return self + if self.prefixlen - prefixlen_diff < 0: + raise PrefixlenDiffInvalidError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % ( + self.prefixlen, prefixlen_diff)) + ret_addr = IPv4(ipaddr.IPv4Network.supernet(self, prefixlen_diff), + comment=self.text, token=self.token) + return ret_addr + + # Backwards compatibility name from v1. + Supernet = supernet + + +class IPv6(ipaddr.IPv6Network): + """This subclass allows us to keep text comments related to each object.""" + + def __init__(self, ip_string, comment='', token=''): + ipaddr.IPv6Network.__init__(self, ip_string) + self.text = comment + self.token = token + self.parent_token = token + + def supernet(self, prefixlen_diff=1): + """Override ipaddr.IPv6Network supernet so we can maintain comments. + + See ipaddr.IPv6Network.Supernet for complete documentation. + """ + if self.prefixlen == 0: + return self + if self.prefixlen - prefixlen_diff < 0: + raise PrefixlenDiffInvalidError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % ( + self.prefixlen, prefixlen_diff)) + ret_addr = IPv6(ipaddr.IPv6Network.supernet(self, prefixlen_diff), + comment=self.text, token=self.token) + return ret_addr + + # Backwards compatibility name from v1. + Supernet = supernet + + def AddComment(self, comment=''): + """Append comment to self.text, comma seperated. + + Don't add the comment if it's the same as self.text. + + Args: comment + """ + if self.text: + if comment and comment not in self.text: + self.text += ', ' + comment + else: + self.text = comment + + +def CollapseAddrListRecursive(addresses): + """Recursively loops through the addresses, collapsing concurent netblocks. + + Example: + + ip1 = ipaddr.IPv4Network('1.1.0.0/24') + ip2 = ipaddr.IPv4Network('1.1.1.0/24') + ip3 = ipaddr.IPv4Network('1.1.2.0/24') + ip4 = ipaddr.IPv4Network('1.1.3.0/24') + ip5 = ipaddr.IPv4Network('1.1.4.0/24') + ip6 = ipaddr.IPv4Network('1.1.0.1/22') + + CollapseAddrRecursive([ip1, ip2, ip3, ip4, ip5, ip6]) -> + [IPv4Network('1.1.0.0/22'), IPv4Network('1.1.4.0/24')] + + Note, this shouldn't be called directly, but is called via + CollapseAddr([]) + + Args: + addresses: List of IPv4 or IPv6 objects + + Returns: + List of IPv4 or IPv6 objects (depending on what we were passed) + """ + ret_array = [] + optimized = False + + for cur_addr in addresses: + if not ret_array: + ret_array.append(cur_addr) + continue + if ret_array[-1].Contains(cur_addr): + # save the comment from the subsumed address + ret_array[-1].AddComment(cur_addr.text) + optimized = True + elif cur_addr == ret_array[-1].Supernet().Subnet()[1]: + ret_array.append(ret_array.pop().Supernet()) + # save the text from the subsumed address + ret_array[-1].AddComment(cur_addr.text) + optimized = True + else: + ret_array.append(cur_addr) + + if optimized: + return CollapseAddrListRecursive(ret_array) + return ret_array + + +def CollapseAddrList(addresses): + """Collapse an array of IP objects. + + Example: CollapseAddr( + [IPv4('1.1.0.0/24'), IPv4('1.1.1.0/24')]) -> [IPv4('1.1.0.0/23')] + Note: this works just as well with IPv6 addresses too. 
+ + Args: + addresses: list of ipaddr.IPNetwork objects + + Returns: + list of ipaddr.IPNetwork objects + """ + return CollapseAddrListRecursive( + sorted(addresses, key=ipaddr._BaseNet._get_networks_key)) + + +def SortAddrList(addresses): + """Return a sorted list of nacaddr objects.""" + return sorted(addresses, key=ipaddr._BaseNet._get_networks_key) + + +def RemoveAddressFromList(superset, exclude): + """Remove a single address from a list of addresses. + + Args: + superset: a List of nacaddr IPv4 or IPv6 addresses + exclude: a single nacaddr IPv4 or IPv6 address + + Returns: + a List of nacaddr IPv4 or IPv6 addresses + """ + ret_array = [] + for addr in superset: + if exclude == addr or addr in exclude: + # this is a bug in ipaddr v1. IP('1.1.1.1').AddressExclude(IP('1.1.1.1')) + # raises an error. Not tested in v2 yet. + pass + elif exclude.version == addr.version and exclude in addr: + ret_array.extend([IP(x) for x in addr.AddressExclude(exclude)]) + else: + ret_array.append(addr) + return ret_array + + +def AddressListExclude(superset, excludes): + """Remove a list of addresses from another list of addresses. + + Args: + superset: a List of nacaddr IPv4 or IPv6 addresses + excludes: a List nacaddr IPv4 or IPv6 addresses + + Returns: + a List of nacaddr IPv4 or IPv6 addresses + """ + superset = CollapseAddrList(superset) + excludes = CollapseAddrList(excludes) + + ret_array = [] + + for ex in excludes: + superset = RemoveAddressFromList(superset, ex) + return CollapseAddrList(superset) + + +ExcludeAddrs = AddressListExclude + + +class PrefixlenDiffInvalidError(ipaddr.NetmaskValueError): + """Holdover from ipaddr v1.""" + + +if __name__ == '__main__': + pass diff --git a/lib/naming.py b/lib/naming.py new file mode 100644 index 0000000..40196bc --- /dev/null +++ b/lib/naming.py @@ -0,0 +1,502 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Parse naming definition files. + +Network access control applications use definition files which contain +information about networks and services. This naming class +will provide an easy interface into using these definitions. + +Sample usage with definition files contained in ./acl/defs: + defs = Naming('acl/defs/') + + services = defs.GetService('DNS') + returns ['53/tcp', '53/udp', ...] + + networks = defs.GetNet('INTERNAL') + returns a list of nacaddr.IPv4 object + +The definition files are contained in a single directory and +may consist of multiple files ending in .net or .svc extensions, +indicating network or service definitions respectively. 
The format of the files consists of a 'token' value, followed by a
+list of values and optional comments, such as:
+
+INTERNAL = 10.0.0.0/8     # RFC-1918
+           172.16.0.0/12  # RFC-1918
+           192.168.0.0/16 # RFC-1918
+or
+
+DNS = 53/tcp
+      53/udp
+
+"""
+
+__author__ = 'watson@google.com (Tony Watson)'
+
+import glob
+import os
+
+import nacaddr
+
+
+class Error(Exception):
+  """Create our own base error class to be inherited by other error classes."""
+
+
+class NamespaceCollisionError(Error):
+  """Used to report on duplicate symbol names found while parsing."""
+
+
+class BadNetmaskTypeError(Error):
+  """Used to report on netmask types that are not known or supported."""
+
+
+class NoDefinitionsError(Error):
+  """Raised if no definitions are found."""
+
+
+class ParseError(Error):
+  """Raised if an error occurs during parsing."""
+
+
+class UndefinedAddressError(Error):
+  """Raised if an address is referenced but not defined."""
+
+
+class UndefinedServiceError(Error):
+  """Raised if a service is referenced but not defined."""
+
+
+class UnexpectedDefinitionType(Error):
+  """An unexpected/unknown definition type was used."""
+
+
+class _ItemUnit(object):
+  """This class is a container for an index key and a list of associated values.
+
+  An ItemUnit will contain the name of either a service or network group,
+  and a list of the associated values separated by spaces.
+
+  Attributes:
+    name: A string representing a unique token value.
+    items: a list of strings containing values for the token.
+  """
+
+  def __init__(self, symbol):
+    self.name = symbol
+    self.items = []
+
+
+class Naming(object):
+  """Object to hold naming objects from NETWORK and SERVICES definition files.
+
+  Attributes:
+    current_symbol: The current token being handled while parsing data.
+    services: A collection of all of the current service item tokens.
+    networks: A collection of all the current network item tokens.
+  """
+
+  def __init__(self, naming_dir=None, naming_file=None, naming_type=None):
+    """Set the default values for a new Naming object."""
+    self.current_symbol = None
+    self.services = {}
+    self.networks = {}
+    self.unseen_services = {}
+    self.unseen_networks = {}
+    if naming_file and naming_type:
+      filename = os.path.join(naming_dir, naming_file)
+      file_handle = open(filename, 'r')
+      self._ParseFile(file_handle, naming_type)
+    elif naming_dir:
+      self._Parse(naming_dir, 'services')
+      self._CheckUnseen('services')
+
+      self._Parse(naming_dir, 'networks')
+      self._CheckUnseen('networks')
+
+  def _CheckUnseen(self, def_type):
+    if def_type == 'services':
+      if self.unseen_services:
+        raise UndefinedServiceError('%s %s' % (
+            'The following tokens were nested as values, but not defined',
+            self.unseen_services))
+    if def_type == 'networks':
+      if self.unseen_networks:
+        raise UndefinedAddressError('%s %s' % (
+            'The following tokens were nested as values, but not defined',
+            self.unseen_networks))
+
+  def GetIpParents(self, query):
+    """Return network tokens that contain IP in query.
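+
+    Example (illustrative): with INTERNAL = 10.0.0.0/8 defined in a .net
+    file, GetIpParents('10.1.1.1') returns ['INTERNAL'], plus any tokens
+    that reference INTERNAL, recursively.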
+ + Args: + query: an ip string ('10.1.1.1') or nacaddr.IP object + """ + base_parents = [] + recursive_parents = [] + # convert string to nacaddr, if arg is ipaddr then convert str() to nacaddr + if type(query) != nacaddr.IPv4 and type(query) != nacaddr.IPv6: + if query[:1].isdigit(): + query = nacaddr.IP(query) + # Get parent token for an IP + if type(query) == nacaddr.IPv4 or type(query) == nacaddr.IPv6: + for token in self.networks: + for item in self.networks[token].items: + item = item.split('#')[0].strip() + if item[:1].isdigit() and nacaddr.IP(item).Contains(query): + base_parents.append(token) + # Get parent token for another token + else: + for token in self.networks: + for item in self.networks[token].items: + item = item.split('#')[0].strip() + if item[:1].isalpha() and item == query: + base_parents.append(token) + # look for nested tokens + for bp in base_parents: + done = False + for token in self.networks: + if bp in self.networks[token].items: + # ignore IPs, only look at token values + if bp[:1].isalpha(): + if bp not in recursive_parents: + recursive_parents.append(bp) + recursive_parents.extend(self.GetIpParents(bp)) + done = True + # if no nested tokens, just append value + if not done: + if bp[:1].isalpha() and bp not in recursive_parents: + recursive_parents.append(bp) + return sorted(list(set(recursive_parents))) + + def GetServiceParents(self, query): + """Given a query token, return list of services definitions with that token. + + Args: + query: a service token name. + """ + return self._GetParents(query, self.services) + + def GetNetParents(self, query): + """Given a query token, return list of network definitions with that token. + + Args: + query: a network token name. + """ + return self._GetParents(query, self.networks) + + def _GetParents(self, query, query_group): + """Given a naming item dict, return any tokens containing the value. + + Args: + query: a service or token name, such as 53/tcp or DNS + query_group: either services or networks dict + """ + base_parents = [] + recursive_parents = [] + # collect list of tokens containing query + for token in query_group: + if query in query_group[token].items: + base_parents.append(token) + if not base_parents: + return [] + # iterate through tokens containing query, doing recursion if necessary + for bp in base_parents: + for token in query_group: + if bp in query_group[token].items and bp not in recursive_parents: + recursive_parents.append(bp) + recursive_parents.extend(self._GetParents(bp, query_group)) + if bp not in recursive_parents: + recursive_parents.append(bp) + return recursive_parents + + def GetService(self, query): + """Given a service name, return a list of associated ports and protocols. + + Args: + query: Service name symbol or token. + + Returns: + A list of service values such as ['80/tcp', '443/tcp', '161/udp', ...] + + Raises: + UndefinedServiceError: If the service name isn't defined. + """ + expandset = set() + already_done = set() + data = [] + service_name = '' + data = query.split('#') # Get the token keyword and remove any comment + service_name = data[0].split()[0] # strip and cast from list to string + if service_name not in self.services: + raise UndefinedServiceError('\nNo such service: %s' % query) + + already_done.add(service_name) + + for next_item in self.services[service_name].items: + # Remove any trailing comment. + service = next_item.split('#')[0].strip() + # Recognized token, not a value. + if not '/' in service: + # Make sure we are not descending into recursion hell. 
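+        # e.g. (sketch) with definitions
+        #   DNS  = 53/tcp 53/udp
+        #   MAIL = 25/tcp DNS
+        # GetService('MAIL') recurses through the DNS token and returns
+        # ['25/tcp', '53/tcp', '53/udp'].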
+ if service not in already_done: + already_done.add(service) + try: + expandset.update(self.GetService(service)) + except UndefinedServiceError as e: + # One of the services in query is undefined, refine the error msg. + raise UndefinedServiceError('%s (in %s)' % (e, query)) + else: + expandset.add(service) + return sorted(expandset) + + def GetServiceByProto(self, query, proto): + """Given a service name, return list of ports in the service by protocol. + + Args: + query: Service name to lookup. + proto: A particular protocol to restrict results by, such as 'tcp'. + + Returns: + A list of service values of type 'proto', such as ['80', '443', ...] + + Raises: + UndefinedServiceError: If the service name isn't defined. + """ + services_set = set() + proto = proto.upper() + data = [] + servicename = '' + data = query.split('#') # Get the token keyword and remove any comment + servicename = data[0].split()[0] # strip and cast from list to string + if servicename not in self.services: + raise UndefinedServiceError('%s %s' % ('\nNo such service,', servicename)) + + for service in self.GetService(servicename): + if service and '/' in service: + parts = service.split('/') + if parts[1].upper() == proto: + services_set.add(parts[0]) + return sorted(services_set) + + def GetNetAddr(self, token): + """Given a network token, return a list of netaddr.IPv4 objects. + + Args: + token: A name of a network definition, such as 'INTERNAL' + + Returns: + A list of netaddr.IPv4 objects. + + Raises: + UndefinedAddressError: if the network name isn't defined. + """ + return self.GetNet(token) + + def GetNet(self, query): + """Expand a network token into a list of nacaddr.IPv4 objects. + + Args: + query: Network definition token which may include comment text + + Raises: + BadNetmaskTypeError: Results when an unknown netmask_type is + specified. Acceptable values are 'cidr', 'netmask', and 'hostmask'. + + Returns: + List of nacaddr.IPv4 objects + + Raises: + UndefinedAddressError: for an undefined token value + """ + returnlist = [] + data = [] + token = '' + data = query.split('#') # Get the token keyword and remove any comment + token = data[0].split()[0] # Remove whitespace and cast from list to string + if token not in self.networks: + raise UndefinedAddressError('%s %s' % ('\nUNDEFINED:', str(token))) + + for next in self.networks[token].items: + comment = '' + if next.find('#') > -1: + (net, comment) = next.split('#', 1) + else: + net = next + try: + net = net.strip() + addr = nacaddr.IP(net) + # we want to make sure that we're storing the network addresses + # ie, FOO = 192.168.1.1/24 should actually return 192.168.1.0/24 + if addr.ip != addr.network: + addr = nacaddr.IP('%s/%d' % (addr.network, addr.prefixlen)) + + addr.text = comment.lstrip() + addr.token = token + returnlist.append(addr) + except ValueError: + # if net was something like 'FOO', or the name of another token which + # needs to be dereferenced, nacaddr.IP() will return a ValueError + returnlist.extend(self.GetNet(net)) + for next in returnlist: + next.parent_token = token + return returnlist + + def _Parse(self, defdirectory, def_type): + """Parse files of a particular type for tokens and values. + + Given a directory name and the type (services|networks) to + process, grab all the appropriate files in that directory + and parse them for definitions. + + Args: + defdirectory: Path to directory containing definition files. + def_type: Type of definitions to parse + + Raises: + NoDefinitionsError: if no definitions are found. 
+ """ + file_names = [] + get_files = {'services': lambda: glob.glob(defdirectory + '/*.svc'), + 'networks': lambda: glob.glob(defdirectory + '/*.net')} + + if def_type in get_files: + file_names = get_files[def_type]() + else: + raise NoDefinitionsError('Unknown definitions type.') + if not file_names: + raise NoDefinitionsError('No definition files for %s in %s found.' % + (def_type, defdirectory)) + + for current_file in file_names: + try: + file_handle = open(current_file, 'r').readlines() + for line in file_handle: + self._ParseLine(line, def_type) + except IOError as error_info: + raise NoDefinitionsError('%s', error_info) + + def _ParseFile(self, file_handle, def_type): + for line in file_handle: + self._ParseLine(line, def_type) + + def ParseServiceList(self, data): + """Take an array of service data and import into class. + + This method allows us to pass an array of data that contains service + definitions that are appended to any definitions read from files. + + Args: + data: array of text lines containing service definitions. + """ + for line in data: + self._ParseLine(line, 'services') + + def ParseNetworkList(self, data): + """Take an array of network data and import into class. + + This method allows us to pass an array of data that contains network + definitions that are appended to any definitions read from files. + + Args: + data: array of text lines containing net definitions. + + """ + for line in data: + self._ParseLine(line, 'networks') + + def _ParseLine(self, line, definition_type): + """Parse a single line of a service definition file. + + This routine is used to parse a single line of a service + definition file, building a list of 'self.services' objects + as each line of the file is iterated through. + + Args: + line: A single line from a service definition files. + definition_type: Either 'networks' or 'services' + + Raises: + UnexpectedDefinitionType: when called with unexpected type of defintions + NamespaceCollisionError: when overlapping tokens are found. + ParseError: If errors occur + """ + if definition_type not in ['services', 'networks']: + raise UnexpectedDefinitionType('%s %s' % ( + 'Received an unexpected defintion type:', definition_type)) + line = line.strip() + if not line or line.startswith('#'): # Skip comments and blanks. + return + comment = '' + if line.find('#') > -1: # if there is a comment, save it + (line, comment) = line.split('#', 1) + line_parts = line.split('=') # Split on var = val lines. + # the value field still has the comment at this point + # If there was '=', then do var and value + if len(line_parts) > 1: + self.current_symbol = line_parts[0].strip() # varname left of '=' + if definition_type == 'services': + if self.current_symbol in self.services: + raise NamespaceCollisionError('%s %s' % ( + '\nMultiple definitions found for service: ', + self.current_symbol)) + elif definition_type == 'networks': + if self.current_symbol in self.networks: + raise NamespaceCollisionError('%s %s' % ( + '\nMultiple definitions found for service: ', + self.current_symbol)) + + self.unit = _ItemUnit(self.current_symbol) + if definition_type == 'services': + self.services[self.current_symbol] = self.unit + # unseen_services is a list of service TOKENS found in the values + # of newly defined services, but not previously defined themselves. + # When we define a new service, we should remove it (if it exists) + # from the list of unseen_services. 
+ if self.current_symbol in self.unseen_services: + self.unseen_services.pop(self.current_symbol) + elif definition_type == 'networks': + self.networks[self.current_symbol] = self.unit + if self.current_symbol in self.unseen_networks: + self.unseen_networks.pop(self.current_symbol) + else: + raise ParseError('Unknown definitions type.') + values = line_parts[1] + # No '=', so this is a value only line + else: + values = line_parts[0] # values for previous var are continued this line + for value_piece in values.split(): + if not value_piece: + continue + if not self.current_symbol: + break + if comment: + self.unit.items.append(value_piece + ' # ' + comment) + else: + self.unit.items.append(value_piece) + # token? + if value_piece[0].isalpha() and ':' not in value_piece: + if definition_type == 'services': + # already in top definitions list? + if value_piece not in self.services: + # already have it as an unused value? + if value_piece not in self.unseen_services: + self.unseen_services[value_piece] = True + if definition_type == 'networks': + if value_piece not in self.networks: + if value_piece not in self.unseen_networks: + self.unseen_networks[value_piece] = True diff --git a/lib/packetfilter.py b/lib/packetfilter.py new file mode 100644 index 0000000..c9742b9 --- /dev/null +++ b/lib/packetfilter.py @@ -0,0 +1,348 @@ +#!/usr/bin/python +# +# Copyright 2012 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""PacketFilter (PF) generator.""" + +__author__ = 'msu@google.com (Martin Suess)' + +import aclgenerator +import datetime +import logging + + +class Error(Exception): + """Base error class.""" + + +class UnsupportedActionError(Error): + """Raised when we see an unsupported action.""" + + +class UnsupportedTargetOption(Error): + """Raised when we see an unsupported option.""" + + +class Term(aclgenerator.Term): + """Generate PacketFilter policy terms.""" + + # Validate that term does not contain any fields we do not + # support. This prevents us from thinking that our output is + # correct in cases where we've omitted fields from term. + _PLATFORM = 'packetfilter' + _ACTION_TABLE = { + 'accept': 'pass', + 'deny': 'block drop', + 'reject': 'block return', + } + _TCP_FLAGS_TABLE = { + 'syn': 'S', + 'ack': 'A', + 'fin': 'F', + 'rst': 'R', + 'urg': 'U', + 'psh': 'P', + 'all': 'ALL', + 'none': 'NONE', + } + + def __init__(self, term, filter_name, af='inet'): + """Setup a new term. + + Args: + term: A policy.Term object to represent in packetfilter. + filter_name: The name of the filter chan to attach the term to. + af: Which address family ('inet' or 'inet6') to apply the term to. + + Raises: + aclgenerator.UnsupportedFilterError: Filter is not supported. 
+    """
+    self.term = term  # term object
+    self.filter = filter_name  # actual name of filter
+    self.options = []
+    self.default_action = 'deny'
+    self.af = af
+
+  def __str__(self):
+    """Render config output from this term object."""
+    ret_str = []
+
+    # Create a new term
+    ret_str.append('\n# term %s' % self.term.name)
+    # append comments to output
+    for line in self.term.comment:
+      if not line:
+        continue
+      ret_str.append('# %s' % str(line))
+
+    # if the term does not specify an action, use the filter default action
+    if not self.term.action:
+      self.term.action.append(self.default_action)
+    if str(self.term.action[0]) not in self._ACTION_TABLE:
+      raise aclgenerator.UnsupportedFilterError('%s %s %s %s' % (
+          '\n', self.term.name, self.term.action[0],
+          'action not currently supported.'))
+
+    # protocol
+    if self.term.protocol:
+      protocol = self.term.protocol
+    else:
+      protocol = []
+    if self.term.protocol_except:
+      raise aclgenerator.UnsupportedFilterError('%s %s %s' % (
+          '\n', self.term.name,
+          'protocol_except logic not currently supported.'))
+
+    # source address
+    term_saddrs = self._CheckAddressAf(self.term.source_address)
+    if not term_saddrs:
+      logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name,
+                                                    direction='source',
+                                                    af=self.af))
+      return ''
+    term_saddr = self._GenerateAddrStatement(
+        term_saddrs, self.term.source_address_exclude)
+
+    # destination address
+    term_daddrs = self._CheckAddressAf(self.term.destination_address)
+    if not term_daddrs:
+      logging.warn(self.NO_AF_LOG_FORMAT.substitute(term=self.term.name,
+                                                    direction='destination',
+                                                    af=self.af))
+      return ''
+    term_daddr = self._GenerateAddrStatement(
+        term_daddrs, self.term.destination_address_exclude)
+
+    # ports
+    source_port = []
+    destination_port = []
+    if self.term.source_port:
+      source_port = self._GeneratePortStatement(self.term.source_port)
+    if self.term.destination_port:
+      destination_port = self._GeneratePortStatement(self.term.destination_port)
+
+    # icmp-type
+    icmp_types = ['']
+    if self.term.icmp_type:
+      if self.af != 'mixed':
+        af = self.af
+      elif protocol == ['icmp']:
+        af = 'inet'
+      elif protocol == ['icmp6']:
+        af = 'inet6'
+      else:
+        raise aclgenerator.UnsupportedFilterError('%s %s %s' % (
+            '\n', self.term.name,
+            'icmp protocol is not defined or not supported.'))
+      icmp_types = self.NormalizeIcmpTypes(
+          self.term.icmp_type, protocol, af)
+
+    # options
+    opts = [str(x) for x in self.term.option]
+    tcp_flags = []
+    for next_opt in opts:
+      # Iterate through flags table, and create list of tcp-flags to append
+      for next_flag in self._TCP_FLAGS_TABLE:
+        if next_opt.find(next_flag) == 0:
+          tcp_flags.append(self._TCP_FLAGS_TABLE.get(next_flag))
+
+    ret_str.extend(self._FormatPart(
+        self._ACTION_TABLE.get(str(self.term.action[0])),
+        self.term.logging,
+        self.af,
+        protocol,
+        term_saddr,
+        source_port,
+        term_daddr,
+        destination_port,
+        tcp_flags,
+        icmp_types,
+        self.options,
+        ))
+
+    return '\n'.join(str(v) for v in ret_str if v != '')
+
+  def _CheckAddressAf(self, addrs):
+    """Verify that the requested address-family matches the address's family."""
+    if not addrs:
+      return ['any']
+    if self.af == 'mixed':
+      return addrs
+    af_addrs = []
+    af = self.NormalizeAddressFamily(self.af)
+    for addr in addrs:
+      if addr.version == af:
+        af_addrs.append(addr)
+    return af_addrs
+
+  def _FormatPart(self, action, log, af, proto, src_addr, src_port,
+                  dst_addr, dst_port, tcp_flags, icmp_types, options):
+    """Format the string which will become a single PF entry."""
+    line = ['%s' % action]
+    if log and 
'true' in [str(l) for l in log]: + line.append('log') + + line.append('quick') + if af != 'mixed': + line.append(af) + + if proto: + line.append(self._GenerateProtoStatement(proto)) + + line.append('from %s' % src_addr) + if src_port: + line.append('port %s' % src_port) + + line.append('to %s' % dst_addr) + if dst_port: + line.append('port %s' % dst_port) + + if 'tcp' in proto and tcp_flags: + line.append('flags') + line.append('/'.join(tcp_flags)) + + if 'icmp' in proto and icmp_types: + type_strs = [str(icmp_type) for icmp_type in icmp_types] + type_strs = ', '.join(type_strs) + if type_strs: + line.append('icmp-type { %s }' % type_strs) + + if options: + line.extend(options) + + return [' '.join(line)] + + def _GenerateProtoStatement(self, protocols): + proto = '' + if protocols: + proto = 'proto { %s }' % ' '.join(protocols) + return proto + + def _GenerateAddrStatement(self, addrs, exclude_addrs): + addresses = [str(addr) for addr in addrs] + for exclude_addr in exclude_addrs: + addresses.append('!%s' % str(exclude_addr)) + return '{ %s }' % ', '.join(addresses) + + def _GeneratePortStatement(self, ports): + port_list = [] + for port_tuple in ports: + for port in port_tuple: + port_list.append(str(port)) + return '{ %s }' % ' '.join(list(set(port_list))) + + +class PacketFilter(aclgenerator.ACLGenerator): + """Generates filters and terms from provided policy object.""" + + _PLATFORM = 'packetfilter' + _DEFAULT_PROTOCOL = 'all' + _SUFFIX = '.pf' + _TERM = Term + _OPTIONAL_SUPPORTED_KEYWORDS = set(['expiration', + 'logging', + 'routing_instance', + ]) + + def _TranslatePolicy(self, pol, exp_info): + self.pf_policies = [] + current_date = datetime.date.today() + exp_info_date = current_date + datetime.timedelta(weeks=exp_info) + + good_afs = ['inet', 'inet6', 'mixed'] + good_options = [] + filter_type = None + + for header, terms in pol.filters: + if self._PLATFORM not in header.platforms: + continue + + filter_options = header.FilterOptions(self._PLATFORM)[1:] + filter_name = header.FilterName(self._PLATFORM) + + # ensure all options after the filter name are expected + for opt in filter_options: + if opt not in good_afs + good_options: + raise UnsupportedTargetOption('%s %s %s %s' % ( + '\nUnsupported option found in', self._PLATFORM, + 'target definition:', opt)) + + # Check for matching af + for address_family in good_afs: + if address_family in filter_options: + # should not specify more than one AF in options + if filter_type is not None: + raise aclgenerator.UnsupportedFilterError('%s %s %s %s' % ( + '\nMay only specify one of', good_afs, 'in filter options:', + filter_options)) + filter_type = address_family + if filter_type is None: + filter_type = 'mixed' + + # add the terms + new_terms = [] + term_names = set() + for term in terms: + term.name = self.FixTermLength(term.name) + if term.name in term_names: + raise aclgenerator.DuplicateTermError( + 'You have a duplicate term: %s' % term.name) + term_names.add(term.name) + + if not term: + continue + + if term.expiration: + if term.expiration <= exp_info_date: + logging.info('INFO: Term %s in policy %s expires ' + 'in less than two weeks.', term.name, filter_name) + if term.expiration <= current_date: + logging.warn('WARNING: Term %s in policy %s is expired and ' + 'will not be rendered.', term.name, filter_name) + continue + + new_terms.append(self._TERM(term, filter_name, filter_type)) + + self.pf_policies.append((header, filter_name, filter_type, new_terms)) + + def __str__(self): + """Render the output of the PF policy into 
config."""
+    target = []
+    pretty_platform = '%s%s' % (self._PLATFORM[0].upper(), self._PLATFORM[1:])
+
+    for (header, filter_name, filter_type, terms) in self.pf_policies:
+      # Add comments for this filter
+      target.append('# %s %s Policy' % (pretty_platform,
+                                        header.FilterName(self._PLATFORM)))
+
+      # reformat long text comments, if needed
+      comments = aclgenerator.WrapWords(header.comment, 70)
+      if comments and comments[0]:
+        for line in comments:
+          target.append('# %s' % line)
+        target.append('#')
+      # add the p4 tags
+      target.extend(aclgenerator.AddRepositoryTags('# '))
+      target.append('# ' + filter_type)
+
+      # add the terms
+      for term in terms:
+        term_str = str(term)
+        if term_str:
+          target.append(term_str)
+      target.append('')
+
+    return '\n'.join(target)
diff --git a/lib/policy.py b/lib/policy.py
new file mode 100644
index 0000000..a6b8ad6
--- /dev/null
+++ b/lib/policy.py
@@ -0,0 +1,1821 @@
+#!/usr/bin/python
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Parses the generic policy files and returns a policy object for acl
+rendering.
+"""
+
+import datetime
+import os
+import sys
+
+import logging
+import nacaddr
+import naming
+
+from third_party.ply import lex
+from third_party.ply import yacc
+
+
+DEFINITIONS = None
+DEFAULT_DEFINITIONS = './def'
+_ACTIONS = set(('accept', 'deny', 'reject', 'next', 'reject-with-tcp-rst'))
+_LOGGING = set(('true', 'True', 'syslog', 'local', 'disable'))
+_OPTIMIZE = True
+_SHADE_CHECK = False
+
+
+class Error(Exception):
+  """Generic error class."""
+
+
+class FileNotFoundError(Error):
+  """Policy file could not be found."""
+
+
+class FileReadError(Error):
+  """Policy file unable to be read."""
+
+
+class RecursionTooDeepError(Error):
+  """Included files exceed maximum recursion depth."""
+
+
+class ParseError(Error):
+  """ParseError in the input."""
+
+
+class TermAddressExclusionError(Error):
+  """Excluded address block is not contained in the accepted address block."""
+
+
+class TermObjectTypeError(Error):
+  """Error with an object passed to Term."""
+
+
+class TermPortProtocolError(Error):
+  """Error when a requested protocol doesn't have any of the requested ports."""
+
+
+class TermProtocolEtherTypeError(Error):
+  """Error when both ether-type & upper-layer protocol matches are requested."""
+
+
+class TermNoActionError(Error):
+  """Error when a term hasn't defined an action."""
+
+
+class TermInvalidIcmpType(Error):
+  """Error when a term has invalid icmp-types specified."""
+
+
+class InvalidTermActionError(Error):
+  """Error when an action is invalid."""
+
+
+class InvalidTermLoggingError(Error):
+  """Error when an unknown logging option is set."""
+
+
+class UndefinedAddressError(Error):
+  """Error when an undefined address is referenced."""
+
+
+class NoTermsError(Error):
+  """Error when no terms were found."""
+
+
+class ShadingError(Error):
+  """Error when a term is shaded by a prior term."""
+
+
+def TranslatePorts(ports, protocols, term_name):
+  """Return all ports of all protocols 
requested.
+
+  Args:
+    ports: list of ports, eg ['SMTP', 'DNS', 'HIGH_PORTS']
+    protocols: list of protocols, eg ['tcp', 'udp']
+    term_name: name of current term, used for warning messages
+
+  Returns:
+    ret_array: list of ports tuples such as [(25,25), (53,53), (1024,65535)]
+
+  Note:
+    Duplication will be taken care of in Term.CollapsePortList
+  """
+  ret_array = []
+  for proto in protocols:
+    for port in ports:
+      service_by_proto = DEFINITIONS.GetServiceByProto(port, proto)
+      if not service_by_proto:
+        logging.warn('%s %s %s %s %s %s%s %s', 'Term', term_name,
+                     'has service', port, 'which is not defined with protocol',
+                     proto,
+                     ', but will be permitted. Unless intended, you should',
+                     'consider splitting the protocols into separate terms!')
+
+      for p in [x.split('-') for x in service_by_proto]:
+        if len(p) == 1:
+          ret_array.append((int(p[0]), int(p[0])))
+        else:
+          ret_array.append((int(p[0]), int(p[1])))
+  return ret_array
+
+
+# classes for storing the object types in the policy files.
+class Policy(object):
+  """The policy object contains everything found in a given policy file."""
+
+  def __init__(self, header, terms):
+    """Initializer for the Policy object.
+
+    Args:
+      header: __main__.Header object. contains comments which should be passed
+        on to the rendered acls as well as the type of acls this policy file
+        should render to.
+
+      terms: list of __main__.Term objects which must be rendered in each of
+        the rendered acls.
+
+    Attributes:
+      filters: list of tuples containing (header, terms).
+    """
+    self.filters = []
+    self.AddFilter(header, terms)
+
+  def AddFilter(self, header, terms):
+    """Add another header & filter."""
+    self.filters.append((header, terms))
+    self._TranslateTerms(terms)
+    if _SHADE_CHECK:
+      self._DetectShading(terms)
+
+  def _TranslateTerms(self, terms):
+    """Translate ports, clean up addresses and sanity-check each term."""
+    if not terms:
+      raise NoTermsError('no terms found')
+    for term in terms:
+      # TODO(pmoody): this probably belongs in Term.SanityCheck(),
+      # or at the very least, in some method under class Term()
+      if term.translated:
+        continue
+      if term.port:
+        term.port = TranslatePorts(term.port, term.protocol, term.name)
+        if not term.port:
+          raise TermPortProtocolError(
+              'no ports of the correct protocol for term %s' % (
+                  term.name))
+      if term.source_port:
+        term.source_port = TranslatePorts(term.source_port, term.protocol,
+                                          term.name)
+        if not term.source_port:
+          raise TermPortProtocolError(
+              'no source ports of the correct protocol for term %s' % (
+                  term.name))
+      if term.destination_port:
+        term.destination_port = TranslatePorts(term.destination_port,
+                                               term.protocol, term.name)
+        if not term.destination_port:
+          raise TermPortProtocolError(
+              'no destination ports of the correct protocol for term %s' % (
+                  term.name))
+
+      # If argument is true, we optimize, otherwise just sort addresses
+      term.AddressCleanup(_OPTIMIZE)
+      # Reset _OPTIMIZE global to default value
+      globals()['_OPTIMIZE'] = True
+      term.SanityCheck()
+      term.translated = True
+
+  @property
+  def headers(self):
+    """Returns the headers from each of the configured filters.
+
+    Returns:
+      headers
+    """
+    return [x[0] for x in self.filters]
+
+  def _DetectShading(self, terms):
+    """Finds terms which are shaded (impossible to reach).
+
+    Iterate through each term, looking at each prior term. If a prior term
+    contains every component of the current term then the current term would
+    never be hit and is thus shaded. This can be a mistake.
+
+    Args:
+      terms: list of Term objects.
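+        For example (illustrative), a prior term matching all tcp traffic
+        shades a later term matching only tcp port 80, since no packet can
+        ever reach the later term.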
+
+    Raises:
+      ShadingError: When a term is impossible to reach.
+    """
+    # Reset _SHADE_CHECK global to default value
+    globals()['_SHADE_CHECK'] = False
+    shading_errors = []
+    for index, term in enumerate(terms):
+      for prior_index in xrange(index):
+        # Check each term that came before for shading. Terms with next as an
+        # action do not terminate evaluation, so cannot shade.
+        if (term in terms[prior_index]
+            and 'next' not in terms[prior_index].action):
+          shading_errors.append(
+              '  %s is shaded by %s.' % (
+                  term.name, terms[prior_index].name))
+    if shading_errors:
+      raise ShadingError('\n'.join(shading_errors))
+
+
+class Term(object):
+  """The Term object is used to store each of the terms.
+
+  Args:
+    obj: an object of type VarType or a list of objects of type VarType
+
+  members:
+    address/source_address/destination_address/: list of
+      VarType.(S|D)?ADDRESS's
+    address_exclude/source_address_exclude/destination_address_exclude: list of
+      VarType.(S|D)?ADDEXCLUDE's
+    port/source_port/destination_port: list of VarType.(S|D)?PORT's
+    options: list of VarType.OPTION's.
+    protocol: list of VarType.PROTOCOL's.
+    counter: VarType.COUNTER
+    action: list of VarType.ACTION's
+    comments: VarType.COMMENT
+    expiration: VarType.EXPIRATION
+    verbatim: VarType.VERBATIM
+    logging: VarType.LOGGING
+    qos: VarType.QOS
+    policer: VarType.POLICER
+  """
+  ICMP_TYPE = {4: {'echo-reply': 0,
+                   'unreachable': 3,
+                   'source-quench': 4,
+                   'redirect': 5,
+                   'alternate-address': 6,
+                   'echo-request': 8,
+                   'router-advertisement': 9,
+                   'router-solicitation': 10,
+                   'time-exceeded': 11,
+                   'parameter-problem': 12,
+                   'timestamp-request': 13,
+                   'timestamp-reply': 14,
+                   'information-request': 15,
+                   'information-reply': 16,
+                   'mask-request': 17,
+                   'mask-reply': 18,
+                   'conversion-error': 31,
+                   'mobile-redirect': 32,
+                  },
+               6: {'destination-unreachable': 1,
+                   'packet-too-big': 2,
+                   'time-exceeded': 3,
+                   'parameter-problem': 4,
+                   'echo-request': 128,
+                   'echo-reply': 129,
+                   'multicast-listener-query': 130,
+                   'multicast-listener-report': 131,
+                   'multicast-listener-done': 132,
+                   'router-solicit': 133,
+                   'router-advertisement': 134,
+                   'neighbor-solicit': 135,
+                   'neighbor-advertisement': 136,
+                   'redirect-message': 137,
+                   'router-renumbering': 138,
+                   'icmp-node-information-query': 139,
+                   'icmp-node-information-response': 140,
+                   'inverse-neighbor-discovery-solicitation': 141,
+                   'inverse-neighbor-discovery-advertisement': 142,
+                   'version-2-multicast-listener-report': 143,
+                   'home-agent-address-discovery-request': 144,
+                   'home-agent-address-discovery-reply': 145,
+                   'mobile-prefix-solicitation': 146,
+                   'mobile-prefix-advertisement': 147,
+                   'certification-path-solicitation': 148,
+                   'certification-path-advertisement': 149,
+                   'multicast-router-advertisement': 151,
+                   'multicast-router-solicitation': 152,
+                   'multicast-router-termination': 153,
+                  },
+              }
+
+  def __init__(self, obj):
+    self.name = None
+
+    self.action = []
+    self.address = []
+    self.address_exclude = []
+    self.comment = []
+    self.counter = None
+    self.expiration = None
+    self.destination_address = []
+    self.destination_address_exclude = []
+    self.destination_port = []
+    self.destination_prefix = []
+    self.logging = []
+    self.loss_priority = None
+    self.option = []
+    self.owner = None
+    self.policer = None
+    self.port = []
+    self.precedence = []
+    self.principals = []
+    self.protocol = []
+    self.protocol_except = []
+    self.qos = None
+    self.routing_instance = None
+    self.source_address = []
+    self.source_address_exclude = []
+    self.source_port = []
+    self.source_prefix = []
+    self.verbatim = []
+    # juniper specific.
+    self.packet_length = None
+    self.fragment_offset = None
+    self.icmp_type = []
+    self.ether_type = []
+    self.traffic_type = []
+    self.translated = False
+    # iptables specific
+    self.source_interface = None
+    self.destination_interface = None
+    self.platform = []
+    self.platform_exclude = []
+    self.timeout = None
+    self.AddObject(obj)
+    self.flattened = False
+    self.flattened_addr = None
+    self.flattened_saddr = None
+    self.flattened_daddr = None
+
+  def __contains__(self, other):
+    """Determine if other term is contained in this term."""
+    if self.verbatim or other.verbatim:
+      # short circuit these
+      if sorted(self.verbatim) != sorted(other.verbatim):
+        return False
+
+    # check protocols
+    # either protocol or protocol-except may be used, not both at the same time.
+    if self.protocol:
+      if other.protocol:
+        if not self.CheckProtocolIsContained(other.protocol, self.protocol):
+          return False
+      # this term has protocol, other has protocol_except.
+      elif other.protocol_except:
+        return False
+      else:
+        # other does not have protocol or protocol_except. since we do, other
+        # cannot be contained in self.
+        return False
+    elif self.protocol_except:
+      if other.protocol_except:
+        if self.CheckProtocolIsContained(
+            self.protocol_except, other.protocol_except):
+          return False
+      elif other.protocol:
+        for proto in other.protocol:
+          if proto in self.protocol_except:
+            return False
+      else:
+        return False
+
+    # combine addresses with exclusions for proper contains comparisons.
+    if not self.flattened:
+      self.FlattenAll()
+    if not other.flattened:
+      other.FlattenAll()
+
+    # flat 'address' is compared against other flat (saddr|daddr).
+    # if NONE of these evaluate to True other is not contained.
+    if not (
+        self.CheckAddressIsContained(
+            self.flattened_addr, other.flattened_addr)
+        or self.CheckAddressIsContained(
+            self.flattened_addr, other.flattened_saddr)
+        or self.CheckAddressIsContained(
+            self.flattened_addr, other.flattened_daddr)):
+      return False
+
+    # compare flat address from other to flattened self (saddr|daddr).
+    if not (
+        # other's flat address needs both self saddr & daddr to contain in order
+        # for the term to be contained. We already compared the flattened_addr
+        # attributes of both above, which was not contained.
+        self.CheckAddressIsContained(
+            other.flattened_addr, self.flattened_saddr)
+        and self.CheckAddressIsContained(
+            other.flattened_addr, self.flattened_daddr)):
+      return False
+
+    # basic saddr/daddr check.
+    if not (
+        self.CheckAddressIsContained(
+            self.flattened_saddr, other.flattened_saddr)):
+      return False
+    if not (
+        self.CheckAddressIsContained(
+            self.flattened_daddr, other.flattened_daddr)):
+      return False
+
+    if not (
+        self.CheckPrincipalsContained(
+            self.principals, other.principals)):
+      return False
+
+    # check ports
+    # like the address directive, the port directive is special in that it can
+    # be either source or destination.
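+    # Illustrative example (hypothetical values): a self.port of
+    # [(1024, 65535)] contains an other.port of [(8080, 8080)] but not
+    # one of [(80, 80)]; see CheckPortIsContained below.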
+    if self.port:
+      if not (self.CheckPortIsContained(self.port, other.port) or
+              self.CheckPortIsContained(self.port, other.source_port) or
+              self.CheckPortIsContained(self.port, other.destination_port)):
+        return False
+    if not self.CheckPortIsContained(self.source_port, other.source_port):
+      return False
+    if not self.CheckPortIsContained(self.destination_port,
+                                     other.destination_port):
+      return False
+
+    # prefix lists
+    if self.source_prefix:
+      if sorted(self.source_prefix) != sorted(other.source_prefix):
+        return False
+    if self.destination_prefix:
+      if sorted(self.destination_prefix) != sorted(
+          other.destination_prefix):
+        return False
+
+    # check precedence
+    if self.precedence:
+      if not other.precedence:
+        return False
+      for precedence in other.precedence:
+        if precedence not in self.precedence:
+          return False
+    # check various options
+    if self.option:
+      if not other.option:
+        return False
+      for opt in other.option:
+        if opt not in self.option:
+          return False
+    if self.fragment_offset:
+      # fragment_offset looks like 'integer-integer' or just, 'integer'
+      sfo = [int(x) for x in self.fragment_offset.split('-')]
+      if other.fragment_offset:
+        ofo = [int(x) for x in other.fragment_offset.split('-')]
+        if sfo[0] < ofo[0] or sorted(sfo[1:]) > sorted(ofo[1:]):
+          return False
+      else:
+        return False
+    if self.packet_length:
+      # packet_length looks like 'integer-integer' or just, 'integer'
+      spl = [int(x) for x in self.packet_length.split('-')]
+      if other.packet_length:
+        opl = [int(x) for x in other.packet_length.split('-')]
+        if spl[0] < opl[0] or sorted(spl[1:]) > sorted(opl[1:]):
+          return False
+      else:
+        return False
+    if self.icmp_type:
+      if sorted(self.icmp_type) != sorted(other.icmp_type):
+        return False
+
+    # check platform
+    if self.platform:
+      if sorted(self.platform) != sorted(other.platform):
+        return False
+    if self.platform_exclude:
+      if sorted(self.platform_exclude) != sorted(other.platform_exclude):
+        return False
+
+    # we have containment
+    return True
+
+  def __str__(self):
+    ret_str = []
+    ret_str.append(' name: %s' % self.name)
+    if self.address:
+      ret_str.append(' address: %s' % self.address)
+    if self.address_exclude:
+      ret_str.append(' address_exclude: %s' % self.address_exclude)
+    if self.source_address:
+      ret_str.append(' source_address: %s' % self.source_address)
+    if self.source_address_exclude:
+      ret_str.append(' source_address_exclude: %s' %
+                     self.source_address_exclude)
+    if self.destination_address:
+      ret_str.append(' destination_address: %s' % self.destination_address)
+    if self.destination_address_exclude:
+      ret_str.append(' destination_address_exclude: %s' %
+                     self.destination_address_exclude)
+    if self.source_prefix:
+      ret_str.append(' source_prefix: %s' % self.source_prefix)
+    if self.destination_prefix:
+      ret_str.append(' destination_prefix: %s' % self.destination_prefix)
+    if self.protocol:
+      ret_str.append(' protocol: %s' % self.protocol)
+    if self.protocol_except:
+      ret_str.append(' protocol-except: %s' % self.protocol_except)
+    if self.owner:
+      ret_str.append(' owner: %s' % self.owner)
+    if self.port:
+      ret_str.append(' port: %s' % self.port)
+    if self.source_port:
+      ret_str.append(' source_port: %s' % self.source_port)
+    if self.destination_port:
+      ret_str.append(' destination_port: %s' % self.destination_port)
+    if self.action:
+      ret_str.append(' action: %s' % self.action)
+    if self.option:
+      ret_str.append(' option: %s' % self.option)
+    if self.qos:
+      ret_str.append(' qos: %s' % self.qos)
+    if self.logging:
+      ret_str.append(' logging: %s' % 
self.logging) + if self.counter: + ret_str.append(' counter: %s' % self.counter) + if self.source_interface: + ret_str.append(' source_interface: %s' % self.source_interface) + if self.destination_interface: + ret_str.append(' destination_interface: %s' % self.destination_interface) + if self.expiration: + ret_str.append(' expiration: %s' % self.expiration) + if self.platform: + ret_str.append(' platform: %s' % self.platform) + if self.platform_exclude: + ret_str.append(' platform_exclude: %s' % self.platform_exclude) + if self.timeout: + ret_str.append(' timeout: %s' % self.timeout) + return '\n'.join(ret_str) + + def __eq__(self, other): + # action + if sorted(self.action) != sorted(other.action): + return False + + # addresses. + if not (sorted(self.address) == sorted(other.address) and + sorted(self.source_address) == sorted(other.source_address) and + sorted(self.source_address_exclude) == + sorted(other.source_address_exclude) and + sorted(self.destination_address) == + sorted(other.destination_address) and + sorted(self.destination_address_exclude) == + sorted(other.destination_address_exclude)): + return False + + # prefix lists + if not (sorted(self.source_prefix) == sorted(other.source_prefix) and + sorted(self.destination_prefix) == + sorted(other.destination_prefix)): + return False + + # ports + if not (sorted(self.port) == sorted(other.port) and + sorted(self.source_port) == sorted(other.source_port) and + sorted(self.destination_port) == sorted(other.destination_port)): + return False + + # protocol + if not (sorted(self.protocol) == sorted(other.protocol) and + sorted(self.protocol_except) == sorted(other.protocol_except)): + return False + + # option + if sorted(self.option) != sorted(other.option): + return False + + # qos + if self.qos != other.qos: + return False + + # verbatim + if self.verbatim != other.verbatim: + return False + + # policer + if self.policer != other.policer: + return False + + # interface + if self.source_interface != other.source_interface: + return False + + if self.destination_interface != other.destination_interface: + return False + + if sorted(self.logging) != sorted(other.logging): + return False + if self.qos != other.qos: + return False + if self.packet_length != other.packet_length: + return False + if self.fragment_offset != other.fragment_offset: + return False + if sorted(self.icmp_type) != sorted(other.icmp_type): + return False + if sorted(self.ether_type) != sorted(other.ether_type): + return False + if sorted(self.traffic_type) != sorted(other.traffic_type): + return False + + # platform + if not (sorted(self.platform) == sorted(other.platform) and + sorted(self.platform_exclude) == sorted(other.platform_exclude)): + return False + + # timeout + if self.timeout != other.timeout: + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def FlattenAll(self): + """Reduce source, dest, and address fields to their post-exclude state. + + Populates the self.flattened_addr, self.flattened_saddr, + self.flattened_daddr by removing excludes from includes. + """ + # No excludes, set flattened attributes and move along. 
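+    # Illustrative example (hypothetical prefixes): an include of 10.0.0.0/8
+    # with an exclude of 10.1.0.0/16 flattens to the set of subnets of
+    # 10.0.0.0/8 that do not overlap 10.1.0.0/16.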
+    self.flattened = True
+    if not (self.source_address_exclude or self.destination_address_exclude or
+            self.address_exclude):
+      self.flattened_saddr = self.source_address
+      self.flattened_daddr = self.destination_address
+      self.flattened_addr = self.address
+      return
+
+    if self.source_address_exclude:
+      self.flattened_saddr = self._FlattenAddresses(
+          self.source_address, self.source_address_exclude)
+    if self.destination_address_exclude:
+      self.flattened_daddr = self._FlattenAddresses(
+          self.destination_address, self.destination_address_exclude)
+    if self.address_exclude:
+      self.flattened_addr = self._FlattenAddresses(
+          self.address, self.address_exclude)
+
+
+  @staticmethod
+  def _FlattenAddresses(include, exclude):
+    """Reduce an include and exclude list to a single include list.
+
+    Using recursion, whittle away exclude addresses from address include
+    addresses which contain the exclusion.
+
+    Args:
+      include: list of include addresses.
+      exclude: list of exclude addresses.
+    Returns:
+      a single flattened list of nacaddr objects.
+    """
+    if not exclude:
+      return include
+
+    for index, in_addr in enumerate(include):
+      for ex_addr in exclude:
+        if ex_addr in in_addr:
+          reduced_list = in_addr.address_exclude(ex_addr)
+          include.pop(index)
+          include.extend(
+              Term._FlattenAddresses(reduced_list, exclude[1:]))
+    return include
+
+  def GetAddressOfVersion(self, addr_type, af=None):
+    """Returns addresses of the appropriate Address Family.
+
+    Args:
+      addr_type: string, this will be either
+        'source_address', 'source_address_exclude',
+        'destination_address' or 'destination_address_exclude'
+      af: int or None, either Term.INET4 or Term.INET6
+
+    Returns:
+      list of addresses of the correct family.
+    """
+    if not af:
+      return getattr(self, addr_type)
+
+    return filter(lambda x: x.version == af, getattr(self, addr_type))
+
+  def AddObject(self, obj):
+    """Add an object of unknown type to this term.
+
+    Args:
+      obj: single or list of either
+        [Address, Port, Option, Protocol, Counter, Action, Comment, Expiration]
+
+    Raises:
+      InvalidTermActionError: if the action defined isn't an accepted action.
+        eg, action:: godofoobar
+      TermObjectTypeError: if AddObject is called with an object it doesn't
+        understand.
+      InvalidTermLoggingError: when an unknown logging option is set.
+    """
+    if type(obj) is list:
+      for x in obj:
+        # do we have a list of addresses?
+        # expanded address fields consolidate naked address fields with
+        # saddr/daddr.
+        if x.var_type is VarType.SADDRESS:
+          saddr = DEFINITIONS.GetNetAddr(x.value)
+          self.source_address.extend(saddr)
+        elif x.var_type is VarType.DADDRESS:
+          daddr = DEFINITIONS.GetNetAddr(x.value)
+          self.destination_address.extend(daddr)
+        elif x.var_type is VarType.ADDRESS:
+          addr = DEFINITIONS.GetNetAddr(x.value)
+          self.address.extend(addr)
+        # do we have address excludes?
+        elif x.var_type is VarType.SADDREXCLUDE:
+          saddr_exclude = DEFINITIONS.GetNetAddr(x.value)
+          self.source_address_exclude.extend(saddr_exclude)
+        elif x.var_type is VarType.DADDREXCLUDE:
+          daddr_exclude = DEFINITIONS.GetNetAddr(x.value)
+          self.destination_address_exclude.extend(daddr_exclude)
+        elif x.var_type is VarType.ADDREXCLUDE:
+          addr_exclude = DEFINITIONS.GetNetAddr(x.value)
+          self.address_exclude.extend(addr_exclude)
+        # do we have a list of ports?
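+        # Illustrative example: a policy line 'destination-port:: HTTP'
+        # (hypothetical token) arrives here as VarType.DPORT objects and
+        # lands in self.destination_port.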
+ elif x.var_type is VarType.PORT: + self.port.append(x.value) + elif x.var_type is VarType.SPORT: + self.source_port.append(x.value) + elif x.var_type is VarType.DPORT: + self.destination_port.append(x.value) + # do we have a list of protocols? + elif x.var_type is VarType.PROTOCOL: + self.protocol.append(x.value) + # do we have a list of protocol-exceptions? + elif x.var_type is VarType.PROTOCOL_EXCEPT: + self.protocol_except.append(x.value) + # do we have a list of options? + elif x.var_type is VarType.OPTION: + self.option.append(x.value) + elif x.var_type is VarType.PRINCIPALS: + self.principals.append(x.value) + elif x.var_type is VarType.SPFX: + self.source_prefix.append(x.value) + elif x.var_type is VarType.DPFX: + self.destination_prefix.append(x.value) + elif x.var_type is VarType.ETHER_TYPE: + self.ether_type.append(x.value) + elif x.var_type is VarType.TRAFFIC_TYPE: + self.traffic_type.append(x.value) + elif x.var_type is VarType.PRECEDENCE: + self.precedence.append(x.value) + elif x.var_type is VarType.PLATFORM: + self.platform.append(x.value) + elif x.var_type is VarType.PLATFORMEXCLUDE: + self.platform_exclude.append(x.value) + else: + raise TermObjectTypeError( + '%s isn\'t a type I know how to deal with (contains \'%s\')' % ( + type(x), x.value)) + else: + # stupid no switch statement in python + if obj.var_type is VarType.COMMENT: + self.comment.append(str(obj)) + elif obj.var_type is VarType.OWNER: + self.owner = obj.value + elif obj.var_type is VarType.EXPIRATION: + self.expiration = obj.value + elif obj.var_type is VarType.LOSS_PRIORITY: + self.loss_priority = obj.value + elif obj.var_type is VarType.ROUTING_INSTANCE: + self.routing_instance = obj.value + elif obj.var_type is VarType.PRECEDENCE: + self.precedence = obj.value + elif obj.var_type is VarType.VERBATIM: + self.verbatim.append(obj) + elif obj.var_type is VarType.ACTION: + if str(obj) not in _ACTIONS: + raise InvalidTermActionError('%s is not a valid action' % obj) + self.action.append(obj.value) + elif obj.var_type is VarType.COUNTER: + self.counter = obj + elif obj.var_type is VarType.ICMP_TYPE: + self.icmp_type.extend(obj.value) + elif obj.var_type is VarType.LOGGING: + if str(obj) not in _LOGGING: + raise InvalidTermLoggingError('%s is not a valid logging option' % + obj) + self.logging.append(obj) + # police man, tryin'a take you jail + elif obj.var_type is VarType.POLICER: + self.policer = obj.value + # qos? + elif obj.var_type is VarType.QOS: + self.qos = obj.value + elif obj.var_type is VarType.PACKET_LEN: + self.packet_length = obj.value + elif obj.var_type is VarType.FRAGMENT_OFFSET: + self.fragment_offset = obj.value + elif obj.var_type is VarType.SINTERFACE: + self.source_interface = obj.value + elif obj.var_type is VarType.DINTERFACE: + self.destination_interface = obj.value + elif obj.var_type is VarType.TIMEOUT: + self.timeout = obj.value + else: + raise TermObjectTypeError( + '%s isn\'t a type I know how to deal with' % (type(obj))) + + def SanityCheck(self): + """Sanity check the definition of the term. + + Raises: + ParseError: if term has both verbatim and non-verbatim tokens + TermInvalidIcmpType: if term has invalid icmp-types specified + TermNoActionError: if the term doesn't have an action defined. + TermPortProtocolError: if the term has a service/protocol definition pair + which don't match up, eg. SNMP and tcp + TermAddressExclusionError: if one of the *-exclude directives is defined, + but that address isn't contained in the non *-exclude directive. 
eg:
+        source-address::CORP_INTERNAL source-exclude:: LOCALHOST
+      TermProtocolEtherTypeError: if the term has both ether-type and
+        upper-layer protocol restrictions
+      InvalidTermActionError: action and routing-instance both defined
+
+    This should be called when the term is fully formed, and
+    all of the options are set.
+
+    """
+    if self.verbatim:
+      if (self.action or self.source_port or self.destination_port or
+          self.port or self.protocol or self.option):
+        raise ParseError(
+            'term "%s" has both verbatim and non-verbatim tokens.' % self.name)
+    else:
+      if not self.action and not self.routing_instance:
+        raise TermNoActionError('no action specified for term %s' % self.name)
+      # have we specified a port with a protocol that doesn't support ports?
+      if self.source_port or self.destination_port or self.port:
+        if 'tcp' not in self.protocol and 'udp' not in self.protocol:
+          raise TermPortProtocolError(
+              'ports specified with a protocol that doesn\'t support ports. '
+              'Term: %s ' % self.name)
+      # TODO(pmoody): do we have mutually exclusive options?
+      # eg. tcp-established + tcp-initial?
+
+    if self.ether_type and (
+        self.protocol or
+        self.address or
+        self.destination_address or
+        self.destination_address_exclude or
+        self.destination_port or
+        self.destination_prefix or
+        self.source_address or
+        self.source_address_exclude or
+        self.source_port or
+        self.source_prefix):
+      raise TermProtocolEtherTypeError(
+          'ether-type not supported when used with upper-layer protocol '
+          'restrictions. Term: %s' % self.name)
+    # validate icmp-types if specified, but addr_family will have to be checked
+    # in the generators as policy module doesn't know about that at this point.
+    if self.icmp_type:
+      for icmptype in self.icmp_type:
+        if (icmptype not in self.ICMP_TYPE[4] and icmptype not in
+            self.ICMP_TYPE[6]):
+          raise TermInvalidIcmpType('Term %s contains an invalid icmp-type: '
+                                    '%s' % (self.name, icmptype))
+
+  def AddressCleanup(self, optimize=True):
+    """Do Address and Port collapsing.
+
+    Notes:
+      Collapses both the address definitions and the port definitions
+      to their smallest possible length.
+
+    Args:
+      optimize: boolean value indicating whether to optimize addresses
+    """
+    if optimize:
+      cleanup = nacaddr.CollapseAddrList
+    else:
+      cleanup = nacaddr.SortAddrList
+
+    # address collapsing.
+    if self.address:
+      self.address = cleanup(self.address)
+    if self.source_address:
+      self.source_address = cleanup(self.source_address)
+    if self.source_address_exclude:
+      self.source_address_exclude = cleanup(self.source_address_exclude)
+    if self.destination_address:
+      self.destination_address = cleanup(self.destination_address)
+    if self.destination_address_exclude:
+      self.destination_address_exclude = cleanup(
+          self.destination_address_exclude)
+
+    # port collapsing.
+    if self.port:
+      self.port = self.CollapsePortList(self.port)
+    if self.source_port:
+      self.source_port = self.CollapsePortList(self.source_port)
+    if self.destination_port:
+      self.destination_port = self.CollapsePortList(self.destination_port)
+
+  def CollapsePortListRecursive(self, ports):
+    """Given a sorted list of ports, collapse to the smallest required list.
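+
+    For example, [(10, 20), (15, 30), (31, 40)] collapses to [(10, 40)],
+    merging overlapping and adjacent ranges.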
+
+    Args:
+      ports: sorted list of port tuples
+
+    Returns:
+      ret_ports: collapsed list of ports
+    """
+    optimized = False
+    ret_ports = []
+    for port in ports:
+      if not ret_ports:
+        ret_ports.append(port)
+      # we should be able to count on ret_ports[-1][0] <= port[0]
+      elif ret_ports[-1][1] >= port[1]:
+        # (10, 20) and (12, 13) -> (10, 20)
+        optimized = True
+      elif port[0] < ret_ports[-1][1] < port[1]:
+        # (10, 20) and (15, 30) -> (10, 30)
+        ret_ports[-1] = (ret_ports[-1][0], port[1])
+        optimized = True
+      elif ret_ports[-1][1] + 1 == port[0]:
+        # (10, 20) and (21, 30) -> (10, 30)
+        ret_ports[-1] = (ret_ports[-1][0], port[1])
+        optimized = True
+      else:
+        # (10, 20) and (22, 30) -> (10, 20), (22, 30)
+        ret_ports.append(port)
+
+    if optimized:
+      return self.CollapsePortListRecursive(ret_ports)
+    return ret_ports
+
+  def CollapsePortList(self, ports):
+    """Given a list of ports, collapse to the smallest required.
+
+    Args:
+      ports: a list of port tuples eg: [(80,80), (53,53), (2000, 2009),
+        (1024,65535)]
+
+    Returns:
+      ret_array: the collapsed sorted list of ports, eg: [(53,53), (80,80),
+        (1024,65535)]
+    """
+    return self.CollapsePortListRecursive(sorted(ports))
+
+  def CheckPrincipalsContained(self, superset, subset):
+    """Check if the given list of principals is wholly contained.
+
+    Args:
+      superset: list of principals
+      subset: list of principals
+
+    Returns:
+      bool: True if subset is contained in superset, False otherwise.
+    """
+    # Skip set comparison if neither term has principals.
+    if not superset and not subset:
+      return True
+
+    # Convert these lists to sets to use set comparison.
+    sup = set(superset)
+    sub = set(subset)
+    return sub.issubset(sup)
+
+  def CheckProtocolIsContained(self, superset, subset):
+    """Check if the given list of protocols is wholly contained.
+
+    Args:
+      superset: list of protocols
+      subset: list of protocols
+
+    Returns:
+      bool: True if subset is contained in superset, False otherwise.
+    """
+    if not superset:
+      return True
+    if not subset:
+      return False
+
+    # Convert these lists to sets to use set comparison.
+    sup = set(superset)
+    sub = set(subset)
+    return sub.issubset(sup)
+
+  def CheckPortIsContained(self, superset, subset):
+    """Check if the given list of ports is wholly contained.
+
+    Args:
+      superset: list of port tuples
+      subset: list of port tuples
+
+    Returns:
+      bool: True if subset is contained in superset, False otherwise
+    """
+    if not superset:
+      return True
+    if not subset:
+      return False
+
+    for sub_port in subset:
+      not_contains = True
+      for sup_port in superset:
+        if (int(sub_port[0]) >= int(sup_port[0])
+            and int(sub_port[1]) <= int(sup_port[1])):
+          not_contains = False
+          break
+      if not_contains:
+        return False
+    return True
+
+  def CheckAddressIsContained(self, superset, subset):
+    """Check if subset is wholly contained by superset.
+
+    Args:
+      superset: list of the superset addresses
+      subset: list of the subset addresses
+
+    Returns:
+      True or False.
+    """
+    if not superset:
+      return True
+    if not subset:
+      return False
+
+    for sub_addr in subset:
+      sub_contained = False
+      for sup_addr in superset:
+        # ipaddr ensures that version numbers match for inclusion.
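+        # Illustrative example: 10.1.1.0/24 in 10.0.0.0/8 is True, while an
+        # IPv4 subnet is never contained in an IPv6 supernet, since the
+        # address versions differ.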
+        if sub_addr in sup_addr:
+          sub_contained = True
+          break
+      if not sub_contained:
+        return False
+    return True
+
+
+class VarType(object):
+  """Generic object meant to store lots of basic policy types."""
+
+  COMMENT = 0
+  COUNTER = 1
+  ACTION = 2
+  SADDRESS = 3
+  DADDRESS = 4
+  ADDRESS = 5
+  SPORT = 6
+  DPORT = 7
+  PROTOCOL_EXCEPT = 8
+  OPTION = 9
+  PROTOCOL = 10
+  SADDREXCLUDE = 11
+  DADDREXCLUDE = 12
+  LOGGING = 13
+  QOS = 14
+  POLICER = 15
+  PACKET_LEN = 16
+  FRAGMENT_OFFSET = 17
+  ICMP_TYPE = 18
+  SPFX = 19
+  DPFX = 20
+  ETHER_TYPE = 21
+  TRAFFIC_TYPE = 22
+  VERBATIM = 23
+  LOSS_PRIORITY = 24
+  ROUTING_INSTANCE = 25
+  PRECEDENCE = 26
+  SINTERFACE = 27
+  EXPIRATION = 28
+  DINTERFACE = 29
+  PLATFORM = 30
+  PLATFORMEXCLUDE = 31
+  PORT = 32
+  TIMEOUT = 33
+  OWNER = 34
+  PRINCIPALS = 35
+  ADDREXCLUDE = 36
+
+  def __init__(self, var_type, value):
+    self.var_type = var_type
+    if self.var_type == self.COMMENT:
+      # remove the double quotes
+      comment = value.strip('"')
+      # make all of the lines start w/o leading whitespace.
+      self.value = '\n'.join([x.lstrip() for x in comment.splitlines()])
+    else:
+      self.value = value
+
+  def __str__(self):
+    return self.value
+
+  def __eq__(self, other):
+    return self.var_type == other.var_type and self.value == other.value
+
+
+class Header(object):
+  """The header of the policy file contains the targets and a global comment."""
+
+  def __init__(self):
+    self.target = []
+    self.comment = []
+
+  def AddObject(self, obj):
+    """Add an object to the Header.
+
+    Args:
+      obj: of type VarType.COMMENT or Target
+    """
+    if type(obj) == Target:
+      self.target.append(obj)
+    elif obj.var_type == VarType.COMMENT:
+      self.comment.append(str(obj))
+
+  @property
+  def platforms(self):
+    """The platform targets of this particular header."""
+    return map(lambda x: x.platform, self.target)
+
+  def FilterOptions(self, platform):
+    """Given a platform return the options.
+
+    Args:
+      platform: string
+
+    Returns:
+      list or None
+    """
+    for target in self.target:
+      if target.platform == platform:
+        return target.options
+    return []
+
+  def FilterName(self, platform):
+    """Given a filter_type, return the filter name.
+
+    Args:
+      platform: string
+
+    Returns:
+      filter_name: string or None
+
+    Notes:
+      !! Deprecated in favor of Header.FilterOptions(platform) !!
+    """
+    for target in self.target:
+      if target.platform == platform:
+        if target.options:
+          return target.options[0]
+    return None
+
+
+# This could be a VarType object, but I'm keeping it as its own class
+# b/c we're almost certainly going to have to do something more exotic with
+# it shortly to account for various rendering options like default iptables
+# policies or output file names, etc. etc.
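+# Illustrative example (hypothetical filter name): a header line such as
+#   target:: juniper edge-filter inet
+# yields Target(['juniper', 'edge-filter', 'inet']) with platform 'juniper'
+# and options ['edge-filter', 'inet'].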
+class Target(object): + """The type of acl to be rendered from this policy file.""" + + def __init__(self, target): + self.platform = target[0] + if len(target) > 1: + self.options = target[1:] + else: + self.options = None + + def __str__(self): + return self.platform + + def __eq__(self, other): + return self.platform == other.platform and self.options == other.options + + def __ne__(self, other): + return not self.__eq__(other) + + +# Lexing/Parsing starts here +tokens = ( + 'ACTION', + 'ADDR', + 'ADDREXCLUDE', + 'COMMENT', + 'COUNTER', + 'DADDR', + 'DADDREXCLUDE', + 'DPFX', + 'DPORT', + 'DINTERFACE', + 'DQUOTEDSTRING', + 'ETHER_TYPE', + 'EXPIRATION', + 'FRAGMENT_OFFSET', + 'HEADER', + 'ICMP_TYPE', + 'INTEGER', + 'LOGGING', + 'LOSS_PRIORITY', + 'OPTION', + 'OWNER', + 'PACKET_LEN', + 'PLATFORM', + 'PLATFORMEXCLUDE', + 'POLICER', + 'PORT', + 'PRECEDENCE', + 'PRINCIPALS', + 'PROTOCOL', + 'PROTOCOL_EXCEPT', + 'QOS', + 'ROUTING_INSTANCE', + 'SADDR', + 'SADDREXCLUDE', + 'SINTERFACE', + 'SPFX', + 'SPORT', + 'STRING', + 'TARGET', + 'TERM', + 'TIMEOUT', + 'TRAFFIC_TYPE', + 'VERBATIM', +) + +literals = r':{},-' +t_ignore = ' \t' + +reserved = { + 'action': 'ACTION', + 'address': 'ADDR', + 'address-exclude': 'ADDREXCLUDE', + 'comment': 'COMMENT', + 'counter': 'COUNTER', + 'destination-address': 'DADDR', + 'destination-exclude': 'DADDREXCLUDE', + 'destination-interface': 'DINTERFACE', + 'destination-prefix': 'DPFX', + 'destination-port': 'DPORT', + 'ether-type': 'ETHER_TYPE', + 'expiration': 'EXPIRATION', + 'fragment-offset': 'FRAGMENT_OFFSET', + 'header': 'HEADER', + 'icmp-type': 'ICMP_TYPE', + 'logging': 'LOGGING', + 'loss-priority': 'LOSS_PRIORITY', + 'option': 'OPTION', + 'owner': 'OWNER', + 'packet-length': 'PACKET_LEN', + 'platform': 'PLATFORM', + 'platform-exclude': 'PLATFORMEXCLUDE', + 'policer': 'POLICER', + 'port': 'PORT', + 'precedence': 'PRECEDENCE', + 'principals': 'PRINCIPALS', + 'protocol': 'PROTOCOL', + 'protocol-except': 'PROTOCOL_EXCEPT', + 'qos': 'QOS', + 'routing-instance': 'ROUTING_INSTANCE', + 'source-address': 'SADDR', + 'source-exclude': 'SADDREXCLUDE', + 'source-interface': 'SINTERFACE', + 'source-prefix': 'SPFX', + 'source-port': 'SPORT', + 'target': 'TARGET', + 'term': 'TERM', + 'timeout': 'TIMEOUT', + 'traffic-type': 'TRAFFIC_TYPE', + 'verbatim': 'VERBATIM', +} + + +# disable linting warnings for lexx/yacc code +# pylint: disable-msg=W0613,C6102,C6104,C6105,C6108,C6409 + + +def t_IGNORE_COMMENT(t): + r'\#.*' + pass + + +def t_DQUOTEDSTRING(t): + r'"[^"]*?"' + t.lexer.lineno += str(t.value).count('\n') + return t + + +def t_newline(t): + r'\n+' + t.lexer.lineno += len(t.value) + + +def t_error(t): + print "Illegal character '%s' on line %s" % (t.value[0], t.lineno) + t.lexer.skip(1) + + +def t_INTEGER(t): + r'\d+' + return t + + +def t_STRING(t): + r'\w+([-_+.@]\w*)*' + # we have an identifier; let's check if it's a keyword or just a string. 
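+  # e.g. 'source-address' is mapped to the SADDR token by the lookup below,
+  # while a word not in 'reserved' (such as a network token name) stays a
+  # plain STRING.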
+  t.type = reserved.get(t.value, 'STRING')
+  return t
+
+
+###
+## parser starts here
+###
+def p_target(p):
+  """ target : target header terms
+             | """
+  if len(p) > 1:
+    if type(p[1]) is Policy:
+      p[1].AddFilter(p[2], p[3])
+      p[0] = p[1]
+    else:
+      p[0] = Policy(p[2], p[3])
+
+
+def p_header(p):
+  """ header : HEADER '{' header_spec '}' """
+  p[0] = p[3]
+
+
+def p_header_spec(p):
+  """ header_spec : header_spec target_spec
+                  | header_spec comment_spec
+                  | """
+  if len(p) > 1:
+    if type(p[1]) == Header:
+      p[1].AddObject(p[2])
+      p[0] = p[1]
+    else:
+      p[0] = Header()
+      p[0].AddObject(p[2])
+
+
+# we may want to change this at some point if we want to be clever with things
+# like being able to set a default input/output policy for iptables policies.
+def p_target_spec(p):
+  """ target_spec : TARGET ':' ':' strings_or_ints """
+  p[0] = Target(p[4])
+
+
+def p_terms(p):
+  """ terms : terms TERM STRING '{' term_spec '}'
+            | """
+  if len(p) > 1:
+    p[5].name = p[3]
+    if type(p[1]) == list:
+      p[1].append(p[5])
+      p[0] = p[1]
+    else:
+      p[0] = [p[5]]
+
+
+def p_term_spec(p):
+  """ term_spec : term_spec action_spec
+                | term_spec addr_spec
+                | term_spec comment_spec
+                | term_spec counter_spec
+                | term_spec ether_type_spec
+                | term_spec exclude_spec
+                | term_spec expiration_spec
+                | term_spec fragment_offset_spec
+                | term_spec icmp_type_spec
+                | term_spec interface_spec
+                | term_spec logging_spec
+                | term_spec losspriority_spec
+                | term_spec option_spec
+                | term_spec owner_spec
+                | term_spec packet_length_spec
+                | term_spec platform_spec
+                | term_spec policer_spec
+                | term_spec port_spec
+                | term_spec precedence_spec
+                | term_spec principals_spec
+                | term_spec prefix_list_spec
+                | term_spec protocol_spec
+                | term_spec qos_spec
+                | term_spec routinginstance_spec
+                | term_spec timeout_spec
+                | term_spec traffic_type_spec
+                | term_spec verbatim_spec
+                | """
+  if len(p) > 1:
+    if type(p[1]) == Term:
+      p[1].AddObject(p[2])
+      p[0] = p[1]
+    else:
+      p[0] = Term(p[2])
+
+
+def p_routinginstance_spec(p):
+  """ routinginstance_spec : ROUTING_INSTANCE ':' ':' STRING """
+  p[0] = VarType(VarType.ROUTING_INSTANCE, p[4])
+
+
+def p_losspriority_spec(p):
+  """ losspriority_spec : LOSS_PRIORITY ':' ':' STRING """
+  p[0] = VarType(VarType.LOSS_PRIORITY, p[4])
+
+
+def p_precedence_spec(p):
+  """ precedence_spec : PRECEDENCE ':' ':' one_or_more_ints """
+  p[0] = VarType(VarType.PRECEDENCE, p[4])
+
+
+def p_icmp_type_spec(p):
+  """ icmp_type_spec : ICMP_TYPE ':' ':' one_or_more_strings """
+  p[0] = VarType(VarType.ICMP_TYPE, p[4])
+
+
+def p_packet_length_spec(p):
+  """ packet_length_spec : PACKET_LEN ':' ':' INTEGER
+                         | PACKET_LEN ':' ':' INTEGER '-' INTEGER """
+  if len(p) == 5:
+    p[0] = VarType(VarType.PACKET_LEN, str(p[4]))
+  else:
+    p[0] = VarType(VarType.PACKET_LEN, str(p[4]) + '-' + str(p[6]))
+
+
+def p_fragment_offset_spec(p):
+  """ fragment_offset_spec : FRAGMENT_OFFSET ':' ':' INTEGER
+                           | FRAGMENT_OFFSET ':' ':' INTEGER '-' INTEGER """
+  if len(p) == 5:
+    p[0] = VarType(VarType.FRAGMENT_OFFSET, str(p[4]))
+  else:
+    p[0] = VarType(VarType.FRAGMENT_OFFSET, str(p[4]) + '-' + str(p[6]))
+
+
+def p_exclude_spec(p):
+  """ exclude_spec : SADDREXCLUDE ':' ':' one_or_more_strings
+                   | DADDREXCLUDE ':' ':' one_or_more_strings
+                   | ADDREXCLUDE ':' ':' one_or_more_strings
+                   | PROTOCOL_EXCEPT ':' ':' one_or_more_strings """
+
+  p[0] = []
+  for ex in p[4]:
+    if p[1].find('source-exclude') >= 0:
+      p[0].append(VarType(VarType.SADDREXCLUDE, ex))
+    elif p[1].find('destination-exclude') >= 0:
+      
p[0].append(VarType(VarType.DADDREXCLUDE, ex)) + elif p[1].find('address-exclude') >= 0: + p[0].append(VarType(VarType.ADDREXCLUDE, ex)) + elif p[1].find('protocol-except') >= 0: + p[0].append(VarType(VarType.PROTOCOL_EXCEPT, ex)) + + +def p_prefix_list_spec(p): + """ prefix_list_spec : DPFX ':' ':' one_or_more_strings + | SPFX ':' ':' one_or_more_strings """ + p[0] = [] + for pfx in p[4]: + if p[1].find('source-prefix') >= 0: + p[0].append(VarType(VarType.SPFX, pfx)) + elif p[1].find('destination-prefix') >= 0: + p[0].append(VarType(VarType.DPFX, pfx)) + + +def p_addr_spec(p): + """ addr_spec : SADDR ':' ':' one_or_more_strings + | DADDR ':' ':' one_or_more_strings + | ADDR ':' ':' one_or_more_strings """ + p[0] = [] + for addr in p[4]: + if p[1].find('source-address') >= 0: + p[0].append(VarType(VarType.SADDRESS, addr)) + elif p[1].find('destination-address') >= 0: + p[0].append(VarType(VarType.DADDRESS, addr)) + else: + p[0].append(VarType(VarType.ADDRESS, addr)) + + +def p_port_spec(p): + """ port_spec : SPORT ':' ':' one_or_more_strings + | DPORT ':' ':' one_or_more_strings + | PORT ':' ':' one_or_more_strings """ + p[0] = [] + for port in p[4]: + if p[1].find('source-port') >= 0: + p[0].append(VarType(VarType.SPORT, port)) + elif p[1].find('destination-port') >= 0: + p[0].append(VarType(VarType.DPORT, port)) + else: + p[0].append(VarType(VarType.PORT, port)) + + +def p_protocol_spec(p): + """ protocol_spec : PROTOCOL ':' ':' strings_or_ints """ + p[0] = [] + for proto in p[4]: + p[0].append(VarType(VarType.PROTOCOL, proto)) + + +def p_ether_type_spec(p): + """ ether_type_spec : ETHER_TYPE ':' ':' one_or_more_strings """ + p[0] = [] + for proto in p[4]: + p[0].append(VarType(VarType.ETHER_TYPE, proto)) + + +def p_traffic_type_spec(p): + """ traffic_type_spec : TRAFFIC_TYPE ':' ':' one_or_more_strings """ + p[0] = [] + for proto in p[4]: + p[0].append(VarType(VarType.TRAFFIC_TYPE, proto)) + + +def p_policer_spec(p): + """ policer_spec : POLICER ':' ':' STRING """ + p[0] = VarType(VarType.POLICER, p[4]) + + +def p_logging_spec(p): + """ logging_spec : LOGGING ':' ':' STRING """ + p[0] = VarType(VarType.LOGGING, p[4]) + + +def p_option_spec(p): + """ option_spec : OPTION ':' ':' one_or_more_strings """ + p[0] = [] + for opt in p[4]: + p[0].append(VarType(VarType.OPTION, opt)) + +def p_principals_spec(p): + """ principals_spec : PRINCIPALS ':' ':' one_or_more_strings """ + p[0] = [] + for opt in p[4]: + p[0].append(VarType(VarType.PRINCIPALS, opt)) + +def p_action_spec(p): + """ action_spec : ACTION ':' ':' STRING """ + p[0] = VarType(VarType.ACTION, p[4]) + + +def p_counter_spec(p): + """ counter_spec : COUNTER ':' ':' STRING """ + p[0] = VarType(VarType.COUNTER, p[4]) + + +def p_expiration_spec(p): + """ expiration_spec : EXPIRATION ':' ':' INTEGER '-' INTEGER '-' INTEGER """ + p[0] = VarType(VarType.EXPIRATION, datetime.date(int(p[4]), + int(p[6]), + int(p[8]))) + + +def p_comment_spec(p): + """ comment_spec : COMMENT ':' ':' DQUOTEDSTRING """ + p[0] = VarType(VarType.COMMENT, p[4]) + + +def p_owner_spec(p): + """ owner_spec : OWNER ':' ':' STRING """ + p[0] = VarType(VarType.OWNER, p[4]) + + +def p_verbatim_spec(p): + """ verbatim_spec : VERBATIM ':' ':' STRING DQUOTEDSTRING """ + p[0] = VarType(VarType.VERBATIM, [p[4], p[5].strip('"')]) + + +def p_qos_spec(p): + """ qos_spec : QOS ':' ':' STRING """ + p[0] = VarType(VarType.QOS, p[4]) + + +def p_interface_spec(p): + """ interface_spec : SINTERFACE ':' ':' STRING + | DINTERFACE ':' ':' STRING """ + if p[1].find('source-interface') >= 
0:
+    p[0] = VarType(VarType.SINTERFACE, p[4])
+  elif p[1].find('destination-interface') >= 0:
+    p[0] = VarType(VarType.DINTERFACE, p[4])
+
+
+def p_platform_spec(p):
+  """ platform_spec : PLATFORM ':' ':' one_or_more_strings
+                    | PLATFORMEXCLUDE ':' ':' one_or_more_strings """
+  p[0] = []
+  for platform in p[4]:
+    if p[1].find('platform-exclude') >= 0:
+      p[0].append(VarType(VarType.PLATFORMEXCLUDE, platform))
+    elif p[1].find('platform') >= 0:
+      p[0].append(VarType(VarType.PLATFORM, platform))
+
+
+def p_timeout_spec(p):
+  """ timeout_spec : TIMEOUT ':' ':' INTEGER """
+  p[0] = VarType(VarType.TIMEOUT, p[4])
+
+
+def p_one_or_more_strings(p):
+  """ one_or_more_strings : one_or_more_strings STRING
+                          | STRING
+                          | """
+  if len(p) > 1:
+    if type(p[1]) == type([]):
+      p[1].append(p[2])
+      p[0] = p[1]
+    else:
+      p[0] = [p[1]]
+
+
+def p_one_or_more_ints(p):
+  """ one_or_more_ints : one_or_more_ints INTEGER
+                       | INTEGER
+                       | """
+  if len(p) > 1:
+    if type(p[1]) == type([]):
+      p[1].append(p[2])
+      p[0] = p[1]
+    else:
+      p[0] = [p[1]]
+
+
+def p_strings_or_ints(p):
+  """ strings_or_ints : strings_or_ints STRING
+                      | strings_or_ints INTEGER
+                      | STRING
+                      | INTEGER
+                      | """
+  if len(p) > 1:
+    if type(p[1]) is list:
+      p[1].append(p[2])
+      p[0] = p[1]
+    else:
+      p[0] = [p[1]]
+
+
+def p_error(p):
+  """Report a syntax error in the input by raising ParseError."""
+  next_token = yacc.token()
+  if next_token is None:
+    use_token = 'EOF'
+  else:
+    use_token = repr(next_token.value)
+
+  if p:
+    raise ParseError(' ERROR on "%s" (type %s, line %d, Next %s)'
+                     % (p.value, p.type, p.lineno, use_token))
+  else:
+    raise ParseError(' ERROR you likely have unbalanced "{"\'s')
+
+# pylint: enable-msg=W0613,C6102,C6104,C6105,C6108,C6409
+
+
+def _ReadFile(filename):
+  """Read data from a file if it exists.
+
+  Args:
+    filename: str - Filename
+
+  Returns:
+    data: str contents of file.
+
+  Raises:
+    FileNotFoundError: if requested file does not exist.
+    FileReadError: Any error resulting from trying to open/read file.
+  """
+  if os.path.exists(filename):
+    try:
+      data = open(filename, 'r').read()
+      return data
+    except IOError:
+      raise FileReadError('Unable to open or read file %s' % filename)
+  else:
+    raise FileNotFoundError('Unable to open policy file %s' % filename)
+
+
+def _Preprocess(data, max_depth=5, base_dir=''):
+  """Search input for include statements and import specified include file.
+
+  Search input for include statements and if found, import specified file
+  and recursively search included data for includes as well up to max_depth.
+
+  Args:
+    data: A string of Policy file data.
+    max_depth: Maximum depth of included files
+    base_dir: Base path string where to look for policy or include files
+
+  Returns:
+    A list of the lines of the processed input data
+
+  Raises:
+    RecursionTooDeepError: nested include files exceed maximum
+  """
+  if not max_depth:
+    raise RecursionTooDeepError('%s' % (
+        'Included files exceed maximum recursion depth of %s.' 
% max_depth)) + rval = [] + lines = [x.rstrip() for x in data.splitlines()] + for index, line in enumerate(lines): + words = line.split() + if len(words) > 1 and words[0] == '#include': + # remove any quotes around included filename + include_file = words[1].strip('\'"') + data = _ReadFile(os.path.join(base_dir, include_file)) + # recursively handle includes in included data + inc_data = _Preprocess(data, max_depth - 1, base_dir=base_dir) + rval.extend(inc_data) + else: + rval.append(line) + return rval + + +def ParseFile(filename, definitions=None, optimize=True, base_dir='', + shade_check=False): + """Parse the policy contained in file, optionally provide a naming object. + + Read specified policy file and parse into a policy object. + + Args: + filename: Name of policy file to parse. + definitions: optional naming library definitions object. + optimize: bool - whether to summarize networks and services. + base_dir: base path string to look for acls or include files. + shade_check: bool - whether to raise an exception when a term is shaded. + + Returns: + policy object. + """ + data = _ReadFile(filename) + p = ParsePolicy(data, definitions, optimize, base_dir=base_dir, + shade_check=shade_check) + return p + + +def ParsePolicy(data, definitions=None, optimize=True, base_dir='', + shade_check=False): + """Parse the policy in 'data', optionally provide a naming object. + + Parse a blob of policy text into a policy object. + + Args: + data: a string blob of policy data to parse. + definitions: optional naming library definitions object. + optimize: bool - whether to summarize networks and services. + base_dir: base path string to look for acls or include files. + shade_check: bool - whether to raise an exception when a term is shaded. + + Returns: + policy object. + """ + try: + if definitions: + globals()['DEFINITIONS'] = definitions + else: + globals()['DEFINITIONS'] = naming.Naming(DEFAULT_DEFINITIONS) + if not optimize: + globals()['_OPTIMIZE'] = False + if shade_check: + globals()['_SHADE_CHECK'] = True + + lexer = lex.lex() + + preprocessed_data = '\n'.join(_Preprocess(data, base_dir=base_dir)) + p = yacc.yacc(write_tables=False, debug=0, errorlog=yacc.NullLogger()) + + return p.parse(preprocessed_data, lexer=lexer) + + except IndexError: + return False + + +# if you call this from the command line, you can specify a jcl file for it to +# read. +if __name__ == '__main__': + ret = 0 + if len(sys.argv) > 1: + try: + ret = ParsePolicy(open(sys.argv[1], 'r').read()) + except IOError: + print('ERROR: \'%s\' either does not exist or is not readable' % + (sys.argv[1])) + ret = 1 + else: + # default to reading stdin + ret = ParsePolicy(sys.stdin.read()) + sys.exit(ret) diff --git a/lib/policyreader.py b/lib/policyreader.py new file mode 100644 index 0000000..8124221 --- /dev/null +++ b/lib/policyreader.py @@ -0,0 +1,245 @@ +#!/usr/bin/python2.4 +# +# Copyright 2011 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+"""Utility to provide exploration of policy definition files.
+
+Allows read-only access to policy definition files. The library
+creates a Policy object, which has filters containing terms.
+
+Unlike policy.py, this library does not expand the naming tokens it reads.
+
+TODO: This library is currently incomplete, and does not allow access to
+      every argument of a policy term.
+"""
+
+__author__ = 'watson@google.com (Tony Watson)'
+
+from capirca import naming
+
+
+class FileOpenError(Exception):
+  """Trouble opening a file."""
+
+
+class Filter(object):
+  """Simple filter with a name and a list of terms."""
+
+  def __init__(self, filtername=''):
+    self.name = filtername
+    self.term = []
+
+  def __str__(self):
+    rval = []
+    title = 'Filter: %s' % str(self.name)
+    rval.append('\n%s' % title)
+    rval.append('-' * len(title))
+    for term in self.term:
+      rval.append(str(term))
+    return '\n\n'.join(rval)
+
+
+class Term(object):
+  """Simple term with a name and a list of attributes."""
+
+  def __init__(self, termname=''):
+    self.name = termname
+    self.source = []
+    self.destination = []
+    self.sport = []
+    self.dport = []
+    self.action = []
+    self.option = []
+    self.protocol = []
+
+  def __str__(self):
+    rval = []
+    rval.append('  Term: %s' % self.name)
+    rval.append('  Source-address:: %s' % ' '.join(self.source))
+    rval.append('  Destination-address:: %s' % ' '.join(self.destination))
+    rval.append('  Source-port:: %s' % ' '.join(self.sport))
+    rval.append('  Destination-port:: %s' % ' '.join(self.dport))
+    rval.append('  Protocol:: %s' % ' '.join(self.protocol))
+    rval.append('  Option:: %s' % ' '.join(self.option))
+    rval.append('  Action:: %s' % ' '.join(self.action))
+    return '\n'.join(rval)
+
+
+class Policy(object):
+  """Holds basic attributes of an unexpanded policy definition file."""
+
+  def __init__(self, filename, defs_data=None):
+    """Build policy object and naming definitions from provided filenames.
+
+    Args:
+      filename: location of a .pol file
+      defs_data: location of naming definitions directory, if any
+    """
+    self.defs = naming.Naming(defs_data)
+    self.filter = []
+    try:
+      self.data = open(filename, 'r').readlines()
+    except IOError, error_info:
+      info = str(filename) + ' cannot be opened'
+      raise FileOpenError('%s\n%s' % (info, error_info))
+
+    indent = 0
+    in_header = False
+    in_term = False
+    filt = Filter()
+    term = Term()
+    in_string = False
+
+    for line in self.data:
+      words = line.strip().split()
+      quotes = len(line.split('"')) - 1  # number of double quotes on the line
+      if quotes % 2:  # an odd count means the quoting state changes here
+        in_string = not in_string  # toggle in/out of a quoted string
+      if not in_string:
+        if '{' in words:
+          indent += 1
+        if words:
+          if words[0] == 'header':
+            in_header = True
+          if words[0] == 'term':
+            in_term = True
+            term = Term(words[1])
+          if in_header and words[0] == 'target::':
+            if filt.name != words[2]:   # avoid empty dupe filters due to
+              filt = Filter(words[2])   # multiple target header lines
+          if in_term:
+            if words[0] == 'source-address::':
+              term.source.extend(words[1:])
+            if words[0] == 'destination-address::':
+              term.destination.extend(words[1:])
+            if words[0] == 'source-port::':
+              term.sport.extend(words[1:])
+            if words[0] == 'destination-port::':
+              term.dport.extend(words[1:])
+            if words[0] == 'action::':
+              term.action.extend(words[1:])
+            if words[0] == 'protocol::':
+              term.protocol.extend(words[1:])
+            if words[0] == 'option::':
+              term.option.extend(words[1:])
+
+        if '}' in words:
+          indent -= 1
+          if in_header:
+            self.filter.append(filt)
+            in_header = False
+          if in_term:
+            filt.term.append(term)
+            in_term = False
+
+  def __str__(self):
+    return '\n'.join(str(filt) for filt in self.filter)
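+
+  # Illustrative usage only (not exercised anywhere in this change; the
+  # paths below are placeholders):
+  #
+  #   p = policyreader.Policy('./policies/sample_srx.pol', './def')
+  #   print p    # renders each Filter and its Terms via __str__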
+  def Matches(self, src=None, dst=None, dport=None, sport=None,
+              filtername=None):
+    """Return list of term names that match specific attributes.
+
+    Args:
+      src: source ip address '12.1.1.1'
+      dst: destination ip address '10.1.1.1'
+      dport: any port/protocol combo, such as '80/tcp' or '53/udp'
+      sport: any port/protocol combo, such as '80/tcp' or '53/udp'
+      filtername: a filter name or None to search all filters
+
+    Returns:
+      results: list of lists, each list is index to filter & term in the policy
+
+    Example:
+      p = policyreader.Policy('policy_path', 'definitions_path')
+
+      p.Matches(dst='209.85.216.5', dport='25/tcp')
+      [[0, 26]]
+      print p.filter[0].term[26].name
+
+      for match in p.Matches(dst='209.85.216.5'):
+        print p.filter[match[0]].term[match[1]].name
+
+    """
+    rval = []
+    results = []
+    filter_list = []
+    dport_parents = None
+    sport_parents = None
+    destination_parents = None
+    source_parents = None
+    if dport:
+      dport_parents = self.defs.GetServiceParents(dport)
+    if sport:
+      sport_parents = self.defs.GetServiceParents(sport)
+    if dst:
+      destination_parents = self.defs.GetIpParents(dst)
+      try:
+        destination_parents.remove('ANY')
+        destination_parents.remove('RESERVED')
+      except ValueError:
+        pass  # ignore and continue
+    if src:
+      source_parents = self.defs.GetIpParents(src)
+      try:
+        source_parents.remove('ANY')
+        source_parents.remove('RESERVED')
+      except ValueError:
+        pass  # ignore and continue
+    if not filtername:
+      filter_list = self.filter
+    else:
+      for idx, filt in enumerate(self.filter):
+        if filtername == filt.name:
+          filter_list = [self.filter[idx]]
+      if not filter_list:
+        raise ValueError('invalid filter name: %s' % filtername)
+
+    for findex, xfilter in enumerate(filter_list):
+      mterms = []
+      mterms.append(set())  # dport
+      mterms.append(set())  # sport
+      mterms.append(set())  # dst
+      mterms.append(set())  # src
+      for tindex, term in enumerate(xfilter.term):
+        if dport_parents:
+          for token in dport_parents:
+            if token in term.dport:
+              mterms[0].add(tindex)
+        else:
+          mterms[0].add(tindex)
+        if sport_parents:
+          for token in sport_parents:
+            if token in term.sport:
+              mterms[1].add(tindex)
+        else:
+          mterms[1].add(tindex)
+        if destination_parents:
+          for token in destination_parents:
+            if token in term.destination:
+              mterms[2].add(tindex)
+        else:
+          mterms[2].add(tindex)
+        if source_parents:
+          for token in source_parents:
+            if token in term.source:
+              mterms[3].add(tindex)
+        else:
+          mterms[3].add(tindex)
+      rval.append(list(mterms[0] & mterms[1] & mterms[2] & mterms[3]))
+    for findex, fresult in enumerate(rval):
+      for tindex in list(fresult):
+        results.append([findex, tindex])
+    return results
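A minimal sketch of how the reader above is typically consumed, assuming a
policy and definitions path like the samples shipped in this change (the
addresses and paths here are placeholders, not guaranteed to match):

  import policyreader

  p = policyreader.Policy('./policies/sample_srx.pol', './def')
  # Matches returns [filter_index, term_index] pairs; an unknown
  # filtername raises ValueError.
  for findex, tindex in p.Matches(dst='10.1.1.1', dport='25/tcp'):
    print p.filter[findex].term[tindex].name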
diff --git a/lib/port.py b/lib/port.py
new file mode 100755
index 0000000..f28ac52
--- /dev/null
+++ b/lib/port.py
@@ -0,0 +1,55 @@
+#!/usr/bin/python
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Common library for network ports and protocol handling."""
+
+__author__ = 'watson@google.com (Tony Watson)'
+
+
+class Error(Exception):
+  """Base error class."""
+
+
+class BadPortValue(Error):
+  """Invalid port format."""
+
+
+class BadPortRange(Error):
+  """Invalid port range."""
+
+
+def Port(port):
+  """Sanitize a port value.
+
+  Args:
+    port: a port value, as an int or string
+
+  Returns:
+    pval: the port as a validated integer
+
+  Raises:
+    BadPortValue: port is not a valid integer or string
+    BadPortRange: port is outside the valid range
+  """
+  pval = -1
+  try:
+    pval = int(port)
+  except ValueError:
+    raise BadPortValue('port %s is not valid.' % port)
+  if pval < 0 or pval > 65535:
+    raise BadPortRange('port %s is out of range 0-65535.' % port)
+  return pval
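The Port contract is simple: anything int() can parse and that lands in
0-65535 comes back as an integer, everything else raises. A quick sketch:

  import port

  print port.Port('8080')        # -> 8080
  try:
    port.Port(99999)
  except port.BadPortRange, err:
    print err                    # port 99999 is out of range 0-65535.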
diff --git a/lib/setup.py b/lib/setup.py
new file mode 100644
index 0000000..72ab5d0
--- /dev/null
+++ b/lib/setup.py
@@ -0,0 +1,39 @@
+#!/usr/bin/python
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from distutils.core import setup
+
+# NB: version kept in sync with the top-level setup.py.
+
+setup(name='capirca',
+      maintainer='Google',
+      maintainer_email='capirca-dev@googlegroups.com',
+      version='1.109',
+      url='http://code.google.com/p/capirca/',
+      license='Apache License, Version 2.0',
+      classifiers=[
+          'Development Status :: 5 - Production/Stable',
+          'Intended Audience :: Developers',
+          'License :: OSI Approved :: Apache Software License',
+          'Operating System :: OS Independent',
+          'Topic :: Internet',
+          'Topic :: Software Development :: Libraries',
+          'Topic :: System :: Networking',
+          'Topic :: Security'],
+      py_modules=['naming', 'policy', 'nacaddr', 'cisco', 'ciscoasa', 'juniper',
+                  'junipersrx', 'iptables', 'policyreader', 'aclcheck',
+                  'aclgenerator', 'port', 'packetfilter', 'speedway', 'demo'])
diff --git a/lib/speedway.py b/lib/speedway.py
new file mode 100755
index 0000000..233bbe0
--- /dev/null
+++ b/lib/speedway.py
@@ -0,0 +1,50 @@
+#!/usr/bin/python2.4
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Speedway iptables generator.
+
+   This is a subclass of the Iptables library.  The primary difference is
+   that this library produces 'iptables-restore' compatible output."""
+
+__author__ = 'watson@google.com (Tony Watson)'
+
+from string import Template
+import iptables
+
+
+class Error(Exception):
+  pass
+
+
+class Term(iptables.Term):
+  """Generate Iptables policy terms."""
+  _PLATFORM = 'speedway'
+  _PREJUMP_FORMAT = None
+  _POSTJUMP_FORMAT = Template('-A $filter -j $term')
+
+
+class Speedway(iptables.Iptables):
+  """Generates filters and terms from provided policy object."""
+
+  _PLATFORM = 'speedway'
+  _DEFAULT_PROTOCOL = 'all'
+  _SUFFIX = '.ipt'
+
+  _RENDER_PREFIX = '*filter'
+  _RENDER_SUFFIX = 'COMMIT'
+  _DEFAULTACTION_FORMAT = ':%s %s'
+
+  _TERM = Term
diff --git a/make_dist.sh b/make_dist.sh
new file mode 100755
index 0000000..c342eaa
--- /dev/null
+++ b/make_dist.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+# Author: watson@google.com (Tony Watson)
+
+rev=`svn up|awk '{print $3}'`
+archive="capirca-r$rev.tgz"
+filedir='./capirca'
+
+echo "Building: $archive"
+find . -name \*.pyc -exec rm {} \;
+pushd . > /dev/null
+cd ..
+tar -czf $archive --exclude-vcs $filedir
+mv $archive $filedir
+popd > /dev/null
+ls -al $archive
+echo "Done."
+
diff --git a/policies/includes/untrusted-networks-blocking.inc b/policies/includes/untrusted-networks-blocking.inc
new file mode 100644
index 0000000..c77d064
--- /dev/null
+++ b/policies/includes/untrusted-networks-blocking.inc
@@ -0,0 +1,18 @@
+term deny-from-bogons {
+  comment:: "this is a sample edge input filter with a very very very long and
+  multi-line comment that"
+  comment:: "also has multiple entries."
+  source-address:: BOGON
+  action:: deny
+}
+
+term deny-from-reserved {
+  source-address:: RESERVED
+  action:: deny
+}
+
+term deny-to-rfc1918 {
+  destination-address:: RFC1918
+  action:: deny
+}
+
diff --git a/policies/sample_srx.pol b/policies/sample_srx.pol
new file mode 100644
index 0000000..3649c47
--- /dev/null
+++ b/policies/sample_srx.pol
@@ -0,0 +1,26 @@
+#
+# This is an example policy for capirca
+#
+header {
+  comment:: "this is a sample policy to generate Juniper SRX filter"
+  comment:: "from zone Untrust to zone DMZ."
+ target:: srx from-zone Untrust to-zone DMZ +} + +term test-tcp { + destination-address:: RFC1918 + protocol:: tcp udp + logging:: true + action:: accept +} + +term test-icmp { + destination-address:: RFC1918 + protocol:: icmp + icmp-type:: echo-request echo-reply + action:: accept +} + +term default-deny { + action:: deny +} diff --git a/policies/sample_tug_wlc_fw.pol b/policies/sample_tug_wlc_fw.pol new file mode 100644 index 0000000..76da91c --- /dev/null +++ b/policies/sample_tug_wlc_fw.pol @@ -0,0 +1,36 @@ +# +# This is an example policy for capirca +# +header { + comment:: "this is a sample output filter that generates" + comment:: "multiplatform for tug wlc protection" + target:: juniper fw_tug_wlc_protect inet + target:: srx from-zone NORDUnet_nets to-zone WLC_net + target:: cisco fw_tug_wlc_protect mixed + target:: speedway INPUT + target:: ciscoasa asa_in + target:: html MUPP +} + +term permit-icmp { + destination-address:: NDN_TUG_WLC_NET + protocol:: icmp + action:: accept +} + +term permit-traceroute { + destination-address:: NDN_TUG_WLC_NET + protocol:: udp + destination-port:: TRACEROUTE + action:: accept +} + +term permit-NORDUnet { + source-address:: NORDUNET_AGGREGATE SUNET_AP_STATICS + destination-address:: NDN_TUG_WLC_NET + action:: accept +} + +term default-deny { + action:: deny +} diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..d406f5d --- /dev/null +++ b/setup.py @@ -0,0 +1,43 @@ +#!/usr/bin/python +# +# Copyright 2009 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distutils.core import setup + +setup(name='capirca', + maintainer='Google', + maintainer_email='capirca-dev@googlegroups.com', + version='1.109', + url='http://code.google.com/p/capirca', + license='Apache License, Version 2.0', + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Topic :: Internet', + 'Topic :: Software Development :: Libraries', + 'Topic :: Security'], + py_modules=['aclgen', 'definate', 'definate.generator', + 'definate.generator_factory', 'definate.dns_generator', + 'definate.filter_factory', 'definate.global_filter', + 'definate.file_filter', 'definate.definition_filter', + 'definate.yaml_validator', 'lib.cisco', 'lib.ciscoasa', + 'lib.iptables', 'lib.juniper', 'lib.junipersrx', + 'lib.nacaddr', 'lib.policy', 'lib.naming', 'lib.aclcheck', + 'lib.aclgenerator', 'lib.port', 'lib.demo', 'lib.speedway', + 'lib.ipset', 'lib.packetfilter', + 'third_party.ipaddr', 'third_party.ply.lex', + 'third_party.ply.yacc']) diff --git a/third_party/__init__.py b/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/third_party/ipaddr.py b/third_party/ipaddr.py new file mode 100644 index 0000000..f4060f6 --- /dev/null +++ b/third_party/ipaddr.py @@ -0,0 +1,1951 @@ +#!/usr/bin/python +# +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +__version__ = '2.1.7' + +import struct + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def IPAddress(address, version=None): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + version: An Integer, 4 or 6. If set, don't try to automatically + determine what the IP address type is. important for things + like IPAddress(1), which could be IPv4, '0.0.0.1', or IPv6, + '::1'. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + """ + if version: + if version == 4: + return IPv4Address(address) + elif version == 6: + return IPv6Address(address) + + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % + address) + + +def IPNetwork(address, version=None, strict=False): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + version: An Integer, if set, don't try to automatically + determine what the IP address type is. important for things + like IPNetwork(1), which could be IPv4, '0.0.0.1/32', or IPv6, + '::1/128'. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if a strict network was requested and a strict + network wasn't given. + + """ + if version: + if version == 4: + return IPv4Network(address, strict) + elif version == 6: + return IPv6Network(address, strict) + + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % + address) + + +def v4_int_to_packed(address): + """The binary representation of this address. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The binary representation of this address. + + Raises: + ValueError: If the integer is too large to be an IPv4 IP + address. 
+ """ + if address > _BaseV4._ALL_ONES: + raise ValueError('Address too large for IPv4') + return struct.pack('!I', address) + + +def v6_int_to_packed(address): + """The binary representation of this address. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The binary representation of this address. + """ + return struct.pack('!QQ', address >> 64, address & (2**64 - 1)) + + +def _find_address_range(addresses): + """Find a sequence of addresses. + + Args: + addresses: a list of IPv4 or IPv6 addresses. + + Returns: + A tuple containing the first and last IP addresses in the sequence. + + """ + first = last = addresses[0] + for ip in addresses[1:]: + if ip._ip == last._ip + 1: + last = ip + else: + break + return (first, last) + +def _get_prefix_length(number1, number2, bits): + """Get the number of leading bits that are same for two numbers. + + Args: + number1: an integer. + number2: another integer. + bits: the maximum number of bits to compare. + + Returns: + The number of leading bits that are the same for two numbers. + + """ + for i in range(bits): + if number1 >> i == number2 >> i: + return bits - i + return 0 + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + for i in range(bits): + if (number >> i) % 2: + return i + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> summarize_address_range(IPv4Address('1.1.1.0'), + IPv4Address('1.1.1.130')) + [IPv4Network('1.1.1.0/25'), IPv4Network('1.1.1.128/31'), + IPv4Network('1.1.1.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + The address range collapsed to a list of IPv4Network's or + IPv6Network's. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version is not 4 or 6. + + """ + if not (isinstance(first, _BaseIP) and isinstance(last, _BaseIP)): + raise TypeError('first and last must be IP addresses, not networks') + if first.version != last.version: + raise TypeError("%s and %s are not of the same version" % ( + str(self), str(other))) + if first > last: + raise ValueError('last IP address must be greater than first') + + networks = [] + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError('unknown IP version') + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = _count_righthand_zero_bits(first_int, ip_bits) + current = None + while nbits >= 0: + addend = 2**nbits - 1 + current = first_int + addend + nbits -= 1 + if current <= last_int: + break + prefix = _get_prefix_length(first_int, current, ip_bits) + net = ip('%s/%d' % (str(first), prefix)) + networks.append(net) + if current == ip._ALL_ONES: + break + first_int = current + 1 + first = IPAddress(first_int, version=first._version) + return networks + +def _collapse_address_list_recursive(addresses): + """Loops through the addresses, collapsing concurrent netblocks. 
+ + Example: + + ip1 = IPv4Network'1.1.0.0/24') + ip2 = IPv4Network'1.1.1.0/24') + ip3 = IPv4Network'1.1.2.0/24') + ip4 = IPv4Network'1.1.3.0/24') + ip5 = IPv4Network'1.1.4.0/24') + ip6 = IPv4Network'1.1.0.1/22') + + _collapse_address_list_recursive([ip1, ip2, ip3, ip4, ip5, ip6]) -> + [IPv4Network('1.1.0.0/22'), IPv4Network('1.1.4.0/24')] + + This shouldn't be called directly; it is called via + collapse_address_list([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + ret_array = [] + optimized = False + + for cur_addr in addresses: + if not ret_array: + ret_array.append(cur_addr) + continue + if cur_addr in ret_array[-1]: + optimized = True + elif cur_addr == ret_array[-1].supernet().subnet()[1]: + ret_array.append(ret_array.pop().supernet()) + optimized = True + else: + ret_array.append(cur_addr) + + if optimized: + return _collapse_address_list_recursive(ret_array) + + return ret_array + + +def collapse_address_list(addresses): + """Collapse a list of IP objects. + + Example: + collapse_address_list([IPv4('1.1.0.0/24'), IPv4('1.1.1.0/24')]) -> + [IPv4('1.1.0.0/23')] + + Args: + addresses: A list of IPv4Network or IPv6Network objects. + + Returns: + A list of IPv4Network or IPv6Network objects depending on what we + were passed. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + i = 0 + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseIP): + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + str(ip), str(ips[-1]))) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + str(ip), str(ips[-1]))) + ips.append(ip.ip) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + str(ip), str(ips[-1]))) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + nets = sorted(set(nets)) + + while i < len(ips): + (first, last) = _find_address_range(ips[i:]) + i = ips.index(last) + 1 + addrs.extend(summarize_address_range(first, last)) + + return _collapse_address_list_recursive(sorted( + addrs + nets, key=_BaseNet._get_networks_key)) + +# backwards compatibility +CollapseAddrList = collapse_address_list + +# Test whether this Python implementation supports byte objects that +# are not identical to str ones. +# We need to exclude platforms where bytes == str so that we can +# distinguish between packed representations and strings, for example +# b'12::' (the IPv4 address 49.50.58.58) and '12::' (an IPv6 address). +try: + _compat_has_real_bytes = bytes is not str +except NameError: # other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. 
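+    # Illustrative only (not in the upstream source): the offset is applied
+    # to the integer value of the address, so, for example,
+    # IPAddress('10.0.0.255') + 1 == IPAddress('10.0.1.0').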
+ def __add__(self, other): + if not isinstance(other, int): + return NotImplemented + return IPAddress(int(self) + other, version=self._version) + + def __sub__(self, other): + if not isinstance(other, int): + return NotImplemented + return IPAddress(int(self) - other, version=self._version) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, str(self)) + + def __str__(self): + return '%s' % self._string_from_ip_int(self._ip) + + def __hash__(self): + return hash(hex(long(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + @property + def version(self): + raise NotImplementedError('BaseIP has no version') + + +class _BaseNet(_IPAddrBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by networks. + + """ + + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, str(self)) + + def iterhosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + cur = int(self.network) + 1 + bcast = int(self.broadcast) - 1 + while cur <= bcast: + cur += 1 + yield IPAddress(cur - 1, version=self._version) + + def __iter__(self): + cur = int(self.network) + bcast = int(self.broadcast) + while cur <= bcast: + cur += 1 + yield IPAddress(cur - 1, version=self._version) + + def __getitem__(self, n): + network = int(self.network) + broadcast = int(self.broadcast) + if n >= 0: + if network + n > broadcast: + raise IndexError + return IPAddress(network + n, version=self._version) + else: + n += 1 + if broadcast + n < network: + raise IndexError + return IPAddress(broadcast + n, version=self._version) + + def __lt__(self, other): + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + str(self), str(other))) + if not isinstance(other, _BaseNet): + raise TypeError('%s and %s are not of the same type' % ( + str(self), str(other))) + if self.network != other.network: + return self.network < other.network + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __gt__(self, other): + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + str(self), str(other))) + if not isinstance(other, _BaseNet): + raise TypeError('%s and %s are not of the same type' % ( + str(self), str(other))) + if self.network != other.network: + return self.network > other.network + if self.netmask != other.netmask: + return self.netmask > other.netmask + return False + + def __le__(self, other): + gt = self.__gt__(other) + if gt is NotImplemented: + return NotImplemented + return not gt + + def __ge__(self, other): + lt = self.__lt__(other) + if lt is NotImplemented: + return NotImplemented + return not lt + + def __eq__(self, other): + try: + return (self._version == other._version + and self.network == other.network + and int(self.netmask) == int(other.netmask)) + except AttributeError: + return NotImplemented + + def __ne__(self, other): + eq = self.__eq__(other) + if eq is NotImplemented: + return NotImplemented + return not eq + + def __str__(self): + return '%s/%s' % (str(self.ip), + str(self._prefixlen)) + + def __hash__(self): + return hash(int(self.network) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. 
+ if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNet): + return (self.network <= other.network and + self.broadcast >= other.broadcast) + # dealing with another address + else: + return (int(self.network) <= int(other._ip) <= + int(self.broadcast)) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network in other or self.broadcast in other or ( + other.network in self or other.broadcast in self) + + @property + def network(self): + x = self._cache.get('network') + if x is None: + x = IPAddress(self._ip & int(self.netmask), version=self._version) + self._cache['network'] = x + return x + + @property + def broadcast(self): + x = self._cache.get('broadcast') + if x is None: + x = IPAddress(self._ip | int(self.hostmask), version=self._version) + self._cache['broadcast'] = x + return x + + @property + def hostmask(self): + x = self._cache.get('hostmask') + if x is None: + x = IPAddress(int(self.netmask) ^ self._ALL_ONES, + version=self._version) + self._cache['hostmask'] = x + return x + + @property + def with_prefixlen(self): + return '%s/%d' % (str(self.ip), self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (str(self.ip), str(self.netmask)) + + @property + def with_hostmask(self): + return '%s/%s' % (str(self.ip), str(self.hostmask)) + + @property + def numhosts(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast) - int(self.network) + 1 + + @property + def version(self): + raise NotImplementedError('BaseNet has no version') + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = IP('10.1.1.0/24') + addr2 = IP('10.1.1.0/26') + addr1.address_exclude(addr2) = + [IP('10.1.1.64/26'), IP('10.1.1.128/25')] + + or IPv6: + + addr1 = IP('::1/32') + addr2 = IP('::1/128') + addr1.address_exclude(addr2) = [IP('::0/128'), + IP('::2/127'), + IP('::4/126'), + IP('::8/125'), + ... + IP('0:0:8000::/33')] + + Args: + other: An IP object of the same type. + + Returns: + A sorted list of IP objects addresses which is self minus + other. + + Raises: + TypeError: If self and other are of difffering address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError("%s and %s are not of the same version" % ( + str(self), str(other))) + + if not isinstance(other, _BaseNet): + raise TypeError("%s is not a network object" % str(other)) + + if other not in self: + raise ValueError('%s not contained in %s' % (str(other), + str(self))) + if other == self: + return [] + + ret_addrs = [] + + # Make sure we're comparing the network of other. + other = IPNetwork('%s/%s' % (str(other.network), str(other.prefixlen)), + version=other._version) + + s1, s2 = self.subnet() + while s1 != other and s2 != other: + if other in s1: + ret_addrs.append(s2) + s1, s2 = s1.subnet() + elif other in s2: + ret_addrs.append(s1) + s1, s2 = s2.subnet() + else: + # If we got here, there's a bug somewhere. + assert True == False, ('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (str(s1), str(s2), str(other))) + if s1 == other: + ret_addrs.append(s2) + elif s2 == other: + ret_addrs.append(s1) + else: + # If we got here, there's a bug somewhere. 
+ assert True == False, ('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (str(s1), str(s2), str(other))) + + return sorted(ret_addrs, key=_BaseNet._get_networks_key) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4('1.1.1.0/24') < IPv4('1.1.2.0/24') + IPv6('1080::200C:417A') < IPv6('1080::200B:417B') + 0 if self == other + eg: IPv4('1.1.1.1/24') == IPv4('1.1.1.2/24') + IPv6('1080::200C:417A/96') == IPv6('1080::200C:417B/96') + 1 if self > other + eg: IPv4('1.1.1.0/24') > IPv4('1.1.0.0/24') + IPv6('1080::1:200C:417A/112') > + IPv6('1080::0:200C:417A/112') + + If the IP versions of self and other are different, returns: + + -1 if self._version < other._version + eg: IPv4('10.0.0.1/24') < IPv6('::1/128') + 1 if self._version > other._version + eg: IPv6('::1/128') > IPv4('255.255.255.0/24') + + """ + if self._version < other._version: + return -1 + if self._version > other._version: + return 1 + # self._version == other._version below here: + if self.network < other.network: + return -1 + if self.network > other.network: + return 1 + # self.network == other.network below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + # self.network == other.network and self.netmask == other.netmask + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network, self.netmask) + + def _ip_int_from_prefix(self, prefixlen=None): + """Turn the prefix length netmask into a int for comparison. + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + if not prefixlen and prefixlen != 0: + prefixlen = self._prefixlen + return self._ALL_ONES ^ (self._ALL_ONES >> prefixlen) + + def _prefix_from_ip_int(self, ip_int, mask=32): + """Return prefix length from the decimal netmask. + + Args: + ip_int: An integer, the IP address. + mask: The netmask. Defaults to 32. + + Returns: + An integer, the prefix length. + + """ + while mask: + if ip_int & 1 == 1: + break + ip_int >>= 1 + mask -= 1 + + return mask + + def _ip_string_from_prefix(self, prefixlen=None): + """Turn a prefix length into a dotted decimal string. + + Args: + prefixlen: An integer, the netmask prefix length. + + Returns: + A string, the dotted decimal netmask string. + + """ + if not prefixlen: + prefixlen = self._prefixlen + return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen)) + + def iter_subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), return a list with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. 
+ This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError('new prefix must be longer') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError('prefix length diff must be > 0') + new_prefixlen = self._prefixlen + prefixlen_diff + + if not self._is_valid_netmask(str(new_prefixlen)): + raise ValueError( + 'prefix length diff %d is invalid for netblock %s' % ( + new_prefixlen, str(self))) + + first = IPNetwork('%s/%s' % (str(self.network), + str(self._prefixlen + prefixlen_diff)), + version=self._version) + + yield first + current = first + while True: + broadcast = current.broadcast + if broadcast == self.broadcast: + return + new_addr = IPAddress(int(broadcast) + 1, version=self._version) + current = IPNetwork('%s/%s' % (str(new_addr), str(new_prefixlen)), + version=self._version) + + yield current + + def masked(self): + """Return the network object with the host bits masked out.""" + return IPNetwork('%s/%d' % (self.network, self._prefixlen), + version=self._version) + + def subnet(self, prefixlen_diff=1, new_prefix=None): + """Return a list of subnets, rather than an iterator.""" + return list(self.iter_subnets(prefixlen_diff, new_prefix)) + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have a + negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError('new prefix must be shorter') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = self._prefixlen - new_prefix + + + if self.prefixlen - prefixlen_diff < 0: + raise ValueError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % + (self.prefixlen, prefixlen_diff)) + return IPNetwork('%s/%s' % (str(self.network), + str(self.prefixlen - prefixlen_diff)), + version=self._version) + + # backwards compatibility + Subnet = subnet + Supernet = supernet + AddressExclude = address_exclude + CompareNetworks = compare_networks + Contains = __contains__ + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + # Equivalent to 255.255.255.255 or 32 bits of 1's. 
+ _ALL_ONES = (2**IPV4LENGTH) - 1 + + def __init__(self, address): + self._version = 4 + self._max_prefixlen = IPV4LENGTH + + def _explode_shorthand_ip_string(self, ip_str=None): + if not ip_str: + ip_str = str(self) + return ip_str + + def _ip_int_from_string(self, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if the string isn't a valid IP string. + + """ + packed_ip = 0 + octets = ip_str.split('.') + if len(octets) != 4: + raise AddressValueError(ip_str) + for oc in octets: + try: + packed_ip = (packed_ip << 8) | int(oc) + except ValueError: + raise AddressValueError(ip_str) + return packed_ip + + def _string_from_ip_int(self, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + octets = [] + for _ in xrange(4): + octets.insert(0, str(ip_int & 0xFF)) + ip_int >>= 8 + return '.'.join(octets) + + def _is_valid_ip(self, address): + """Validate the dotted decimal notation IP/netmask string. + + Args: + address: A string, either representing a quad-dotted ip + or an integer which is a valid IPv4 IP address. + + Returns: + A boolean, True if the string is a valid dotted decimal IP + string. + + """ + octets = address.split('.') + if len(octets) == 1: + # We have an integer rather than a dotted decimal IP. + try: + return int(address) >= 0 and int(address) <= self._ALL_ONES + except ValueError: + return False + + if len(octets) != 4: + return False + + for octet in octets: + try: + if not 0 <= int(octet) <= 255: + return False + except ValueError: + return False + return True + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def version(self): + return self._version + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in IPv4Network('240.0.0.0/4') + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per RFC 1918. + + """ + return (self in IPv4Network('10.0.0.0/8') or + self in IPv4Network('172.16.0.0/12') or + self in IPv4Network('192.168.0.0/16')) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in IPv4Network('224.0.0.0/4') + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self in IPv4Network('0.0.0.0') + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in IPv4Network('127.0.0.0/8') + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. 
+ + """ + return self in IPv4Network('169.254.0.0/16') + + +class IPv4Address(_BaseV4, _BaseIP): + + """Represent and manipulate single IPv4 Addresses.""" + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + '192.168.1.1' + + Additionally, an integer can be passed, so + IPv4Address('192.168.1.1') == IPv4Address(3232235777). + or, more generally + IPv4Address(int(IPv4Address('192.168.1.1'))) == + IPv4Address('192.168.1.1') + + Raises: + AddressValueError: If ipaddr isn't a valid IPv4 address. + + """ + _BaseIP.__init__(self, address) + _BaseV4.__init__(self, address) + + # Efficient constructor from integer. + if isinstance(address, (int, long)): + self._ip = address + if address < 0 or address > self._ALL_ONES: + raise AddressValueError(address) + return + + # Constructing from a packed address + if _compat_has_real_bytes: + if isinstance(address, bytes) and len(address) == 4: + self._ip = struct.unpack('!I', address)[0] + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = str(address) + if not self._is_valid_ip(addr_str): + raise AddressValueError(addr_str) + + self._ip = self._ip_int_from_string(addr_str) + + +class IPv4Network(_BaseV4, _BaseNet): + + """This class represents and manipulates 32-bit IPv4 networks. + + Attributes: [examples for IPv4Network('1.2.3.4/27')] + ._ip: 16909060 + .ip: IPv4Address('1.2.3.4') + .network: IPv4Address('1.2.3.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast: IPv4Address('1.2.3.31') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0)) + + def __init__(self, address, strict=False): + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.168.1.1/24' + '192.168.1.1/255.255.255.0' + '192.168.1.1/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.168.1.1' + '192.168.1.1/255.255.255.255' + '192.168.1.1/32' + are also functionaly equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.168.1.1') == IPv4Network(3232235777). + or, more generally + IPv4Network(int(IPv4Network('192.168.1.1'))) == + IPv4Network('192.168.1.1') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 192.168.1.0/24 and not an + IP address on a network, eg, 192.168.1.1/24. + + Raises: + AddressValueError: If ipaddr isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNet.__init__(self, address) + _BaseV4.__init__(self, address) + + # Efficient constructor from integer. 
+ if isinstance(address, (int, long)): + self._ip = address + self.ip = IPv4Address(self._ip) + self._prefixlen = self._max_prefixlen + self.netmask = IPv4Address(self._ALL_ONES) + if address < 0 or address > self._ALL_ONES: + raise AddressValueError(address) + return + + # Constructing from a packed address + if _compat_has_real_bytes: + if isinstance(address, bytes) and len(address) == 4: + self._ip = struct.unpack('!I', address)[0] + self.ip = IPv4Address(self._ip) + self._prefixlen = self._max_prefixlen + self.netmask = IPv4Address(self._ALL_ONES) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = str(address).split('/') + + if len(addr) > 2: + raise AddressValueError(address) + + if not self._is_valid_ip(addr[0]): + raise AddressValueError(addr[0]) + + self._ip = self._ip_int_from_string(addr[0]) + self.ip = IPv4Address(self._ip) + + if len(addr) == 2: + mask = addr[1].split('.') + if len(mask) == 4: + # We have dotted decimal netmask. + if self._is_valid_netmask(addr[1]): + self.netmask = IPv4Address(self._ip_int_from_string( + addr[1])) + elif self._is_hostmask(addr[1]): + self.netmask = IPv4Address( + self._ip_int_from_string(addr[1]) ^ self._ALL_ONES) + else: + raise NetmaskValueError('%s is not a valid netmask' + % addr[1]) + + self._prefixlen = self._prefix_from_ip_int(int(self.netmask)) + else: + # We have a netmask in prefix length form. + if not self._is_valid_netmask(addr[1]): + raise NetmaskValueError(addr[1]) + self._prefixlen = int(addr[1]) + self.netmask = IPv4Address(self._ip_int_from_prefix( + self._prefixlen)) + else: + self._prefixlen = self._max_prefixlen + self.netmask = IPv4Address(self._ip_int_from_prefix( + self._prefixlen)) + if strict: + if self.ip != self.network: + raise ValueError('%s has host bits set' % + self.ip) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split('.') + try: + parts = [int(x) for x in bits if int(x) in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _is_valid_netmask(self, netmask): + """Verify that the netmask is valid. + + Args: + netmask: A string, either a prefix or dotted decimal + netmask. + + Returns: + A boolean, True if the prefix represents a valid IPv4 + netmask. + + """ + mask = netmask.split('.') + if len(mask) == 4: + if [x for x in mask if int(x) not in self._valid_mask_octets]: + return False + if [y for idx, y in enumerate(mask) if idx > 0 and + y > mask[idx - 1]]: + return False + return True + try: + netmask = int(netmask) + except ValueError: + return False + return 0 <= netmask <= self._max_prefixlen + + # backwards compatibility + IsRFC1918 = lambda self: self.is_private + IsMulticast = lambda self: self.is_multicast + IsLoopback = lambda self: self.is_loopback + IsLinkLocal = lambda self: self.is_link_local + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + _ALL_ONES = (2**IPV6LENGTH) - 1 + + def __init__(self, address): + self._version = 6 + self._max_prefixlen = IPV6LENGTH + + def _ip_int_from_string(self, ip_str=None): + """Turn an IPv6 ip_str into an integer. 
+ + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + A long, the IPv6 ip_str. + + Raises: + AddressValueError: if ip_str isn't a valid IP Address. + + """ + if not ip_str: + ip_str = str(self.ip) + + ip_int = 0 + + # Do we have an IPv4 mapped (::ffff:a.b.c.d) or compact (::a.b.c.d) + # ip_str? + fields = ip_str.split(':') + if fields[-1].count('.') == 3: + ipv4_string = fields.pop() + ipv4_int = IPv4Network(ipv4_string)._ip + octets = [] + for _ in xrange(2): + octets.append(hex(ipv4_int & 0xFFFF).lstrip('0x').rstrip('L')) + ipv4_int >>= 16 + fields.extend(reversed(octets)) + ip_str = ':'.join(fields) + + fields = self._explode_shorthand_ip_string(ip_str).split(':') + for field in fields: + try: + ip_int = (ip_int << 16) + int(field or '0', 16) + except ValueError: + raise AddressValueError(ip_str) + + return ip_int + + def _compress_hextets(self, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index in range(len(hextets)): + if hextets[index] == '0': + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = (best_doublecolon_start + + best_doublecolon_len) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [''] + hextets[best_doublecolon_start:best_doublecolon_end] = [''] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [''] + hextets + + return hextets + + def _string_from_ip_int(self, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if not ip_int and ip_int != 0: + ip_int = int(self._ip) + + if ip_int > self._ALL_ONES: + raise ValueError('IPv6 address is too large') + + hex_str = '%032x' % ip_int + hextets = [] + for x in range(0, 32, 4): + hextets.append('%x' % int(hex_str[x:x+4], 16)) + + hextets = self._compress_hextets(hextets) + return ':'.join(hextets) + + def _explode_shorthand_ip_string(self, ip_str=None): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if not ip_str: + ip_str = str(self) + if isinstance(self, _BaseNet): + ip_str = str(self.ip) + + if self._is_shorthand_ip(ip_str): + new_ip = [] + hextet = ip_str.split('::') + sep = len(hextet[0].split(':')) + len(hextet[1].split(':')) + new_ip = hextet[0].split(':') + + for _ in xrange(8 - sep): + new_ip.append('0000') + new_ip += hextet[1].split(':') + + # Now need to make sure every hextet is 4 lower case characters. 
+ # If a hextet is < 4 characters, we've got missing leading 0's. + ret_ip = [] + for hextet in new_ip: + ret_ip.append(('0' * (4 - len(hextet)) + hextet).lower()) + return ':'.join(ret_ip) + # We've already got a longhand ip_str. + return ip_str + + def _is_valid_ip(self, ip_str): + """Ensure we have a valid IPv6 address. + + Probably not as exhaustive as it should be. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A boolean, True if this is a valid IPv6 address. + + """ + # We need to have at least one ':'. + if ':' not in ip_str: + return False + + # We can only have one '::' shortener. + if ip_str.count('::') > 1: + return False + + # '::' should be encompassed by start, digits or end. + if ':::' in ip_str: + return False + + # A single colon can neither start nor end an address. + if ((ip_str.startswith(':') and not ip_str.startswith('::')) or + (ip_str.endswith(':') and not ip_str.endswith('::'))): + return False + + # If we have no concatenation, we need to have 8 fields with 7 ':'. + if '::' not in ip_str and ip_str.count(':') != 7: + # We might have an IPv4 mapped address. + if ip_str.count('.') != 3: + return False + + ip_str = self._explode_shorthand_ip_string(ip_str) + + # Now that we have that all squared away, let's check that each of the + # hextets are between 0x0 and 0xFFFF. + for hextet in ip_str.split(':'): + if hextet.count('.') == 3: + # If we have an IPv4 mapped address, the IPv4 portion has to + # be at the end of the IPv6 portion. + if not ip_str.split(':')[-1] == hextet: + return False + try: + IPv4Network(hextet) + except AddressValueError: + return False + else: + try: + # a value error here means that we got a bad hextet, + # something like 0xzzzz + if int(hextet, 16) < 0x0 or int(hextet, 16) > 0xFFFF: + return False + except ValueError: + return False + return True + + def _is_shorthand_ip(self, ip_str=None): + """Determine if the address is shortened. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A boolean, True if the address is shortened. + + """ + if ip_str.count('::') == 1: + return True + return False + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def version(self): + return self._version + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in IPv6Network('ff00::/8') + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return (self in IPv6Network('::/8') or + self in IPv6Network('100::/8') or + self in IPv6Network('200::/7') or + self in IPv6Network('400::/6') or + self in IPv6Network('800::/5') or + self in IPv6Network('1000::/4') or + self in IPv6Network('4000::/3') or + self in IPv6Network('6000::/3') or + self in IPv6Network('8000::/3') or + self in IPv6Network('A000::/3') or + self in IPv6Network('C000::/3') or + self in IPv6Network('E000::/4') or + self in IPv6Network('F000::/5') or + self in IPv6Network('F800::/6') or + self in IPv6Network('FE00::/9')) + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. 
+ + """ + return (self == IPv6Network('::') or self == IPv6Address('::')) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return (self == IPv6Network('::1') or self == IPv6Address('::1')) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in IPv6Network('fe80::/10') + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in IPv6Network('fec0::/10') + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per RFC 4193. + + """ + return self in IPv6Network('fc00::/7') + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + hextets = self._explode_shorthand_ip_string().split(':') + if hextets[-3] != 'ffff': + return None + try: + return IPv4Address(int('%s%s' % (hextets[-2], hextets[-1]), 16)) + except AddressValueError: + return None + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001) + + """ + bits = self._explode_shorthand_ip_string().split(':') + if not bits[0] == '2001': + return None + return (IPv4Address(int(''.join(bits[2:4]), 16)), + IPv4Address(int(''.join(bits[6:]), 16) ^ 0xFFFFFFFF)) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + bits = self._explode_shorthand_ip_string().split(':') + if not bits[0] == '2002': + return None + return IPv4Address(int(''.join(bits[1:3]), 16)) + + +class IPv6Address(_BaseV6, _BaseIP): + + """Represent and manipulate single IPv6 Addresses. + """ + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:4860::') == + IPv6Address(42541956101370907050197289607612071936L). + or, more generally + IPv6Address(IPv6Address('2001:4860::')._ip) == + IPv6Address('2001:4860::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + _BaseIP.__init__(self, address) + _BaseV6.__init__(self, address) + + # Efficient constructor from integer. + if isinstance(address, (int, long)): + self._ip = address + if address < 0 or address > self._ALL_ONES: + raise AddressValueError(address) + return + + # Constructing from a packed address + if _compat_has_real_bytes: + if isinstance(address, bytes) and len(address) == 16: + tmp = struct.unpack('!QQ', address) + self._ip = (tmp[0] << 64) | tmp[1] + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. 
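+        # (Illustrative recap, reusing the docstring's example above: the
+        #  string '2001:4860::', the integer
+        #  42541956101370907050197289607612071936L and the equivalent packed
+        #  16-byte value all produce the same integer in self._ip.)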
+ addr_str = str(address) + if not addr_str: + raise AddressValueError('') + + if not self._is_valid_ip(addr_str): + raise AddressValueError(addr_str) + + self._ip = self._ip_int_from_string(addr_str) + + +class IPv6Network(_BaseV6, _BaseNet): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:658:22A:CAFE:200::1/64')] + .ip: IPv6Address('2001:658:22a:cafe:200::1') + .network: IPv6Address('2001:658:22a:cafe::') + .hostmask: IPv6Address('::ffff:ffff:ffff:ffff') + .broadcast: IPv6Address('2001:658:22a:cafe:ffff:ffff:ffff:ffff') + .netmask: IPv6Address('ffff:ffff:ffff:ffff::') + .prefixlen: 64 + + """ + + + def __init__(self, address, strict=False): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the IP + and prefix/netmask. + '2001:4860::/128' + '2001:4860:0000:0000:0000:0000:0000:0000/128' + '2001:4860::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:4860::') == + IPv6Network(42541956101370907050197289607612071936L). + or, more generally + IPv6Network(IPv6Network('2001:4860::')._ip) == + IPv6Network('2001:4860::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 192.168.1.0/24 and not an + IP address on a network, eg, 192.168.1.1/24. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNet.__init__(self, address) + _BaseV6.__init__(self, address) + + # Efficient constructor from integer. + if isinstance(address, (int, long)): + self._ip = address + self.ip = IPv6Address(self._ip) + self._prefixlen = self._max_prefixlen + self.netmask = IPv6Address(self._ALL_ONES) + if address < 0 or address > self._ALL_ONES: + raise AddressValueError(address) + return + + # Constructing from a packed address + if _compat_has_real_bytes: + if isinstance(address, bytes) and len(address) == 16: + tmp = struct.unpack('!QQ', address) + self._ip = (tmp[0] << 64) | tmp[1] + self.ip = IPv6Address(self._ip) + self._prefixlen = self._max_prefixlen + self.netmask = IPv6Address(self._ALL_ONES) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = str(address).split('/') + + if len(addr) > 2: + raise AddressValueError(address) + + if not self._is_valid_ip(addr[0]): + raise AddressValueError(addr[0]) + + if len(addr) == 2: + if self._is_valid_netmask(addr[1]): + self._prefixlen = int(addr[1]) + else: + raise NetmaskValueError(addr[1]) + else: + self._prefixlen = self._max_prefixlen + + self.netmask = IPv6Address(self._ip_int_from_prefix(self._prefixlen)) + + self._ip = self._ip_int_from_string(addr[0]) + self.ip = IPv6Address(self._ip) + + if strict: + if self.ip != self.network: + raise ValueError('%s has host bits set' % + self.ip) + + def _is_valid_netmask(self, prefixlen): + """Verify that the netmask/prefixlen is valid. + + Args: + prefixlen: A string, the netmask in prefix length format. + + Returns: + A boolean, True if the prefix represents a valid IPv6 + netmask. 
+ + """ + try: + prefixlen = int(prefixlen) + except ValueError: + return False + return 0 <= prefixlen <= self._max_prefixlen + + @property + def with_netmask(self): + return self.with_prefixlen diff --git a/third_party/ply/__init__.py b/third_party/ply/__init__.py new file mode 100644 index 0000000..853a985 --- /dev/null +++ b/third_party/ply/__init__.py @@ -0,0 +1,4 @@ +# PLY package +# Author: David Beazley (dave@dabeaz.com) + +__all__ = ['lex','yacc'] diff --git a/third_party/ply/lex.py b/third_party/ply/lex.py new file mode 100644 index 0000000..267ec10 --- /dev/null +++ b/third_party/ply/lex.py @@ -0,0 +1,1058 @@ +# ----------------------------------------------------------------------------- +# ply: lex.py +# +# Copyright (C) 2001-2009, +# David M. Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- + +__version__ = "3.3" +__tabversion__ = "3.2" # Version of table file used + +import re, sys, types, copy, os + +# This tuple contains known string types +try: + # Python 2.6 + StringTypes = (types.StringType, types.UnicodeType) +except AttributeError: + # Python 3.0 + StringTypes = (str, bytes) + +# Extract the code attribute of a function. Different implementations +# are for Python 2/3 compatibility. + +if sys.version_info[0] < 3: + def func_code(f): + return f.func_code +else: + def func_code(f): + return f.__code__ + +# This regular expression is used to match valid token names +_is_identifier = re.compile(r'^[a-zA-Z0-9_]+$') + +# Exception thrown when invalid token encountered and no default error +# handler is defined. + +class LexError(Exception): + def __init__(self,message,s): + self.args = (message,) + self.text = s + +# Token class. This class is used to represent the tokens produced. +class LexToken(object): + def __str__(self): + return "LexToken(%s,%r,%d,%d)" % (self.type,self.value,self.lineno,self.lexpos) + def __repr__(self): + return str(self) + +# This object is a stand-in for a logging object created by the +# logging module. 
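+# Any object that provides the same critical/warning/error/info/debug methods
+# can be substituted for it, e.g. (illustrative only)
+# lex.lex(errorlog=logging.getLogger('ply.lex')).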
+ +class PlyLogger(object): + def __init__(self,f): + self.f = f + def critical(self,msg,*args,**kwargs): + self.f.write((msg % args) + "\n") + + def warning(self,msg,*args,**kwargs): + self.f.write("WARNING: "+ (msg % args) + "\n") + + def error(self,msg,*args,**kwargs): + self.f.write("ERROR: " + (msg % args) + "\n") + + info = critical + debug = critical + +# Null logger is used when no output is generated. Does nothing. +class NullLogger(object): + def __getattribute__(self,name): + return self + def __call__(self,*args,**kwargs): + return self + +# ----------------------------------------------------------------------------- +# === Lexing Engine === +# +# The following Lexer class implements the lexer runtime. There are only +# a few public methods and attributes: +# +# input() - Store a new string in the lexer +# token() - Get the next token +# clone() - Clone the lexer +# +# lineno - Current line number +# lexpos - Current position in the input string +# ----------------------------------------------------------------------------- + +class Lexer: + def __init__(self): + self.lexre = None # Master regular expression. This is a list of + # tuples (re,findex) where re is a compiled + # regular expression and findex is a list + # mapping regex group numbers to rules + self.lexretext = None # Current regular expression strings + self.lexstatere = {} # Dictionary mapping lexer states to master regexs + self.lexstateretext = {} # Dictionary mapping lexer states to regex strings + self.lexstaterenames = {} # Dictionary mapping lexer states to symbol names + self.lexstate = "INITIAL" # Current lexer state + self.lexstatestack = [] # Stack of lexer states + self.lexstateinfo = None # State information + self.lexstateignore = {} # Dictionary of ignored characters for each state + self.lexstateerrorf = {} # Dictionary of error functions for each state + self.lexreflags = 0 # Optional re compile flags + self.lexdata = None # Actual input data (as a string) + self.lexpos = 0 # Current position in input text + self.lexlen = 0 # Length of the input text + self.lexerrorf = None # Error rule (if any) + self.lextokens = None # List of valid tokens + self.lexignore = "" # Ignored characters + self.lexliterals = "" # Literal characters that can be passed through + self.lexmodule = None # Module + self.lineno = 1 # Current line number + self.lexoptimize = 0 # Optimized mode + + def clone(self,object=None): + c = copy.copy(self) + + # If the object parameter has been supplied, it means we are attaching the + # lexer to a new object. In this case, we have to rebind all methods in + # the lexstatere and lexstateerrorf tables. 
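+        # (Illustrative: the no-argument form is the common case and yields a
+        #  lexer that shares the compiled rules but keeps an independent
+        #  position, e.g.  lx2 = lexer.clone(); lx2.input(other_text).)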
+ + if object: + newtab = { } + for key, ritem in self.lexstatere.items(): + newre = [] + for cre, findex in ritem: + newfindex = [] + for f in findex: + if not f or not f[0]: + newfindex.append(f) + continue + newfindex.append((getattr(object,f[0].__name__),f[1])) + newre.append((cre,newfindex)) + newtab[key] = newre + c.lexstatere = newtab + c.lexstateerrorf = { } + for key, ef in self.lexstateerrorf.items(): + c.lexstateerrorf[key] = getattr(object,ef.__name__) + c.lexmodule = object + return c + + # ------------------------------------------------------------ + # writetab() - Write lexer information to a table file + # ------------------------------------------------------------ + def writetab(self,tabfile,outputdir=""): + if isinstance(tabfile,types.ModuleType): + return + basetabfilename = tabfile.split(".")[-1] + filename = os.path.join(outputdir,basetabfilename)+".py" + tf = open(filename,"w") + tf.write("# %s.py. This file automatically created by PLY (version %s). Don't edit!\n" % (tabfile,__version__)) + tf.write("_tabversion = %s\n" % repr(__version__)) + tf.write("_lextokens = %s\n" % repr(self.lextokens)) + tf.write("_lexreflags = %s\n" % repr(self.lexreflags)) + tf.write("_lexliterals = %s\n" % repr(self.lexliterals)) + tf.write("_lexstateinfo = %s\n" % repr(self.lexstateinfo)) + + tabre = { } + # Collect all functions in the initial state + initial = self.lexstatere["INITIAL"] + initialfuncs = [] + for part in initial: + for f in part[1]: + if f and f[0]: + initialfuncs.append(f) + + for key, lre in self.lexstatere.items(): + titem = [] + for i in range(len(lre)): + titem.append((self.lexstateretext[key][i],_funcs_to_names(lre[i][1],self.lexstaterenames[key][i]))) + tabre[key] = titem + + tf.write("_lexstatere = %s\n" % repr(tabre)) + tf.write("_lexstateignore = %s\n" % repr(self.lexstateignore)) + + taberr = { } + for key, ef in self.lexstateerrorf.items(): + if ef: + taberr[key] = ef.__name__ + else: + taberr[key] = None + tf.write("_lexstateerrorf = %s\n" % repr(taberr)) + tf.close() + + # ------------------------------------------------------------ + # readtab() - Read lexer information from a tab file + # ------------------------------------------------------------ + def readtab(self,tabfile,fdict): + if isinstance(tabfile,types.ModuleType): + lextab = tabfile + else: + if sys.version_info[0] < 3: + exec("import %s as lextab" % tabfile) + else: + env = { } + exec("import %s as lextab" % tabfile, env,env) + lextab = env['lextab'] + + if getattr(lextab,"_tabversion","0.0") != __version__: + raise ImportError("Inconsistent PLY version") + + self.lextokens = lextab._lextokens + self.lexreflags = lextab._lexreflags + self.lexliterals = lextab._lexliterals + self.lexstateinfo = lextab._lexstateinfo + self.lexstateignore = lextab._lexstateignore + self.lexstatere = { } + self.lexstateretext = { } + for key,lre in lextab._lexstatere.items(): + titem = [] + txtitem = [] + for i in range(len(lre)): + titem.append((re.compile(lre[i][0],lextab._lexreflags | re.VERBOSE),_names_to_funcs(lre[i][1],fdict))) + txtitem.append(lre[i][0]) + self.lexstatere[key] = titem + self.lexstateretext[key] = txtitem + self.lexstateerrorf = { } + for key,ef in lextab._lexstateerrorf.items(): + self.lexstateerrorf[key] = fdict[ef] + self.begin('INITIAL') + + # ------------------------------------------------------------ + # input() - Push a new string into the lexer + # ------------------------------------------------------------ + def input(self,s): + # Pull off the first character to see if s looks 
like a string
+        c = s[:1]
+        if not isinstance(c,StringTypes):
+            raise ValueError("Expected a string")
+        self.lexdata = s
+        self.lexpos = 0
+        self.lexlen = len(s)
+
+    # ------------------------------------------------------------
+    # begin() - Changes the lexing state
+    # ------------------------------------------------------------
+    def begin(self,state):
+        if not state in self.lexstatere:
+            raise ValueError("Undefined state")
+        self.lexre = self.lexstatere[state]
+        self.lexretext = self.lexstateretext[state]
+        self.lexignore = self.lexstateignore.get(state,"")
+        self.lexerrorf = self.lexstateerrorf.get(state,None)
+        self.lexstate = state
+
+    # ------------------------------------------------------------
+    # push_state() - Changes the lexing state and saves old on stack
+    # ------------------------------------------------------------
+    def push_state(self,state):
+        self.lexstatestack.append(self.lexstate)
+        self.begin(state)
+
+    # ------------------------------------------------------------
+    # pop_state() - Restores the previous state
+    # ------------------------------------------------------------
+    def pop_state(self):
+        self.begin(self.lexstatestack.pop())
+
+    # ------------------------------------------------------------
+    # current_state() - Returns the current lexing state
+    # ------------------------------------------------------------
+    def current_state(self):
+        return self.lexstate
+
+    # ------------------------------------------------------------
+    # skip() - Skip ahead n characters
+    # ------------------------------------------------------------
+    def skip(self,n):
+        self.lexpos += n
+
+    # ------------------------------------------------------------
+    # token() - Return the next token from the Lexer
+    #
+    # Note: This function has been carefully implemented to be as fast
+    # as possible. Don't make changes unless you really know what
+    # you are doing
+    # ------------------------------------------------------------
+    def token(self):
+        # Make local copies of frequently referenced attributes
+        lexpos = self.lexpos
+        lexlen = self.lexlen
+        lexignore = self.lexignore
+        lexdata = self.lexdata
+
+        while lexpos < lexlen:
+            # This code provides some short-circuit code for whitespace, tabs, and other ignored characters
+            if lexdata[lexpos] in lexignore:
+                lexpos += 1
+                continue
+
+            # Look for a regular expression match
+            for lexre,lexindexfunc in self.lexre:
+                m = lexre.match(lexdata,lexpos)
+                if not m: continue
+
+                # Create a token for return
+                tok = LexToken()
+                tok.value = m.group()
+                tok.lineno = self.lineno
+                tok.lexpos = lexpos
+
+                i = m.lastindex
+                func,tok.type = lexindexfunc[i]
+
+                if not func:
+                    # If no token type was set, it's an ignored token
+                    if tok.type:
+                        self.lexpos = m.end()
+                        return tok
+                    else:
+                        lexpos = m.end()
+                        break
+
+                lexpos = m.end()
+
+                # If token is processed by a function, call it
+
+                tok.lexer = self # Set additional attributes useful in token rules
+                self.lexmatch = m
+                self.lexpos = lexpos
+
+                newtok = func(tok)
+
+                # Every function must return a token; if nothing is returned, we just move on to the next token
+                if not newtok:
+                    lexpos = self.lexpos # This is here in case user has updated lexpos.
+                    lexignore = self.lexignore # This is here in case there was a state change
+                    break
+
+                # Verify type of the token.
If not in the token map, raise an error + if not self.lexoptimize: + if not newtok.type in self.lextokens: + raise LexError("%s:%d: Rule '%s' returned an unknown token type '%s'" % ( + func_code(func).co_filename, func_code(func).co_firstlineno, + func.__name__, newtok.type),lexdata[lexpos:]) + + return newtok + else: + # No match, see if in literals + if lexdata[lexpos] in self.lexliterals: + tok = LexToken() + tok.value = lexdata[lexpos] + tok.lineno = self.lineno + tok.type = tok.value + tok.lexpos = lexpos + self.lexpos = lexpos + 1 + return tok + + # No match. Call t_error() if defined. + if self.lexerrorf: + tok = LexToken() + tok.value = self.lexdata[lexpos:] + tok.lineno = self.lineno + tok.type = "error" + tok.lexer = self + tok.lexpos = lexpos + self.lexpos = lexpos + newtok = self.lexerrorf(tok) + if lexpos == self.lexpos: + # Error method didn't change text position at all. This is an error. + raise LexError("Scanning error. Illegal character '%s'" % (lexdata[lexpos]), lexdata[lexpos:]) + lexpos = self.lexpos + if not newtok: continue + return newtok + + self.lexpos = lexpos + raise LexError("Illegal character '%s' at index %d" % (lexdata[lexpos],lexpos), lexdata[lexpos:]) + + self.lexpos = lexpos + 1 + if self.lexdata is None: + raise RuntimeError("No input string given with input()") + return None + + # Iterator interface + def __iter__(self): + return self + + def next(self): + t = self.token() + if t is None: + raise StopIteration + return t + + __next__ = next + +# ----------------------------------------------------------------------------- +# ==== Lex Builder === +# +# The functions and classes below are used to collect lexing information +# and build a Lexer object from it. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# get_caller_module_dict() +# +# This function returns a dictionary containing all of the symbols defined within +# a caller further down the call stack. This is used to get the environment +# associated with the yacc() call if none was provided. +# ----------------------------------------------------------------------------- + +def get_caller_module_dict(levels): + try: + raise RuntimeError + except RuntimeError: + e,b,t = sys.exc_info() + f = t.tb_frame + while levels > 0: + f = f.f_back + levels -= 1 + ldict = f.f_globals.copy() + if f.f_globals != f.f_locals: + ldict.update(f.f_locals) + + return ldict + +# ----------------------------------------------------------------------------- +# _funcs_to_names() +# +# Given a list of regular expression functions, this converts it to a list +# suitable for output to a table file +# ----------------------------------------------------------------------------- + +def _funcs_to_names(funclist,namelist): + result = [] + for f,name in zip(funclist,namelist): + if f and f[0]: + result.append((name, f[1])) + else: + result.append(f) + return result + +# ----------------------------------------------------------------------------- +# _names_to_funcs() +# +# Given a list of regular expression function names, this converts it back to +# functions. 
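+# (Illustrative round trip: _funcs_to_names() above stores pairs such as
+# ('t_NUMBER', 'NUMBER') in the generated lextab; here fdict['t_NUMBER'] looks
+# the callable back up, restoring the (function, token-name) pairs.)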
+# ----------------------------------------------------------------------------- + +def _names_to_funcs(namelist,fdict): + result = [] + for n in namelist: + if n and n[0]: + result.append((fdict[n[0]],n[1])) + else: + result.append(n) + return result + +# ----------------------------------------------------------------------------- +# _form_master_re() +# +# This function takes a list of all of the regex components and attempts to +# form the master regular expression. Given limitations in the Python re +# module, it may be necessary to break the master regex into separate expressions. +# ----------------------------------------------------------------------------- + +def _form_master_re(relist,reflags,ldict,toknames): + if not relist: return [] + regex = "|".join(relist) + try: + lexre = re.compile(regex,re.VERBOSE | reflags) + + # Build the index to function map for the matching engine + lexindexfunc = [ None ] * (max(lexre.groupindex.values())+1) + lexindexnames = lexindexfunc[:] + + for f,i in lexre.groupindex.items(): + handle = ldict.get(f,None) + if type(handle) in (types.FunctionType, types.MethodType): + lexindexfunc[i] = (handle,toknames[f]) + lexindexnames[i] = f + elif handle is not None: + lexindexnames[i] = f + if f.find("ignore_") > 0: + lexindexfunc[i] = (None,None) + else: + lexindexfunc[i] = (None, toknames[f]) + + return [(lexre,lexindexfunc)],[regex],[lexindexnames] + except Exception: + m = int(len(relist)/2) + if m == 0: m = 1 + llist, lre, lnames = _form_master_re(relist[:m],reflags,ldict,toknames) + rlist, rre, rnames = _form_master_re(relist[m:],reflags,ldict,toknames) + return llist+rlist, lre+rre, lnames+rnames + +# ----------------------------------------------------------------------------- +# def _statetoken(s,names) +# +# Given a declaration name s of the form "t_" and a dictionary whose keys are +# state names, this function returns a tuple (states,tokenname) where states +# is a tuple of state names and tokenname is the name of the token. For example, +# calling this with s = "t_foo_bar_SPAM" might return (('foo','bar'),'SPAM') +# ----------------------------------------------------------------------------- + +def _statetoken(s,names): + nonstate = 1 + parts = s.split("_") + for i in range(1,len(parts)): + if not parts[i] in names and parts[i] != 'ANY': break + if i > 1: + states = tuple(parts[1:i]) + else: + states = ('INITIAL',) + + if 'ANY' in states: + states = tuple(names) + + tokenname = "_".join(parts[i:]) + return (states,tokenname) + + +# ----------------------------------------------------------------------------- +# LexerReflect() +# +# This class represents information needed to build a lexer as extracted from a +# user's input file. 
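+# (Typical driver sequence, matching lex() further below -- illustrative:
+#     linfo = LexerReflect(ldict, log=errorlog, reflags=reflags)
+#     linfo.get_all()
+#     errors = linfo.validate_all())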
+# ----------------------------------------------------------------------------- +class LexerReflect(object): + def __init__(self,ldict,log=None,reflags=0): + self.ldict = ldict + self.error_func = None + self.tokens = [] + self.reflags = reflags + self.stateinfo = { 'INITIAL' : 'inclusive'} + self.files = {} + self.error = 0 + + if log is None: + self.log = PlyLogger(sys.stderr) + else: + self.log = log + + # Get all of the basic information + def get_all(self): + self.get_tokens() + self.get_literals() + self.get_states() + self.get_rules() + + # Validate all of the information + def validate_all(self): + self.validate_tokens() + self.validate_literals() + self.validate_rules() + return self.error + + # Get the tokens map + def get_tokens(self): + tokens = self.ldict.get("tokens",None) + if not tokens: + self.log.error("No token list is defined") + self.error = 1 + return + + if not isinstance(tokens,(list, tuple)): + self.log.error("tokens must be a list or tuple") + self.error = 1 + return + + if not tokens: + self.log.error("tokens is empty") + self.error = 1 + return + + self.tokens = tokens + + # Validate the tokens + def validate_tokens(self): + terminals = {} + for n in self.tokens: + if not _is_identifier.match(n): + self.log.error("Bad token name '%s'",n) + self.error = 1 + if n in terminals: + self.log.warning("Token '%s' multiply defined", n) + terminals[n] = 1 + + # Get the literals specifier + def get_literals(self): + self.literals = self.ldict.get("literals","") + + # Validate literals + def validate_literals(self): + try: + for c in self.literals: + if not isinstance(c,StringTypes) or len(c) > 1: + self.log.error("Invalid literal %s. Must be a single character", repr(c)) + self.error = 1 + continue + + except TypeError: + self.log.error("Invalid literals specification. literals must be a sequence of characters") + self.error = 1 + + def get_states(self): + self.states = self.ldict.get("states",None) + # Build statemap + if self.states: + if not isinstance(self.states,(tuple,list)): + self.log.error("states must be defined as a tuple or list") + self.error = 1 + else: + for s in self.states: + if not isinstance(s,tuple) or len(s) != 2: + self.log.error("Invalid state specifier %s. 
Must be a tuple (statename,'exclusive|inclusive')",repr(s)) + self.error = 1 + continue + name, statetype = s + if not isinstance(name,StringTypes): + self.log.error("State name %s must be a string", repr(name)) + self.error = 1 + continue + if not (statetype == 'inclusive' or statetype == 'exclusive'): + self.log.error("State type for state %s must be 'inclusive' or 'exclusive'",name) + self.error = 1 + continue + if name in self.stateinfo: + self.log.error("State '%s' already defined",name) + self.error = 1 + continue + self.stateinfo[name] = statetype + + # Get all of the symbols with a t_ prefix and sort them into various + # categories (functions, strings, error functions, and ignore characters) + + def get_rules(self): + tsymbols = [f for f in self.ldict if f[:2] == 't_' ] + + # Now build up a list of functions and a list of strings + + self.toknames = { } # Mapping of symbols to token names + self.funcsym = { } # Symbols defined as functions + self.strsym = { } # Symbols defined as strings + self.ignore = { } # Ignore strings by state + self.errorf = { } # Error functions by state + + for s in self.stateinfo: + self.funcsym[s] = [] + self.strsym[s] = [] + + if len(tsymbols) == 0: + self.log.error("No rules of the form t_rulename are defined") + self.error = 1 + return + + for f in tsymbols: + t = self.ldict[f] + states, tokname = _statetoken(f,self.stateinfo) + self.toknames[f] = tokname + + if hasattr(t,"__call__"): + if tokname == 'error': + for s in states: + self.errorf[s] = t + elif tokname == 'ignore': + line = func_code(t).co_firstlineno + file = func_code(t).co_filename + self.log.error("%s:%d: Rule '%s' must be defined as a string",file,line,t.__name__) + self.error = 1 + else: + for s in states: + self.funcsym[s].append((f,t)) + elif isinstance(t, StringTypes): + if tokname == 'ignore': + for s in states: + self.ignore[s] = t + if "\\" in t: + self.log.warning("%s contains a literal backslash '\\'",f) + + elif tokname == 'error': + self.log.error("Rule '%s' must be defined as a function", f) + self.error = 1 + else: + for s in states: + self.strsym[s].append((f,t)) + else: + self.log.error("%s not defined as a function or string", f) + self.error = 1 + + # Sort the functions by line number + for f in self.funcsym.values(): + if sys.version_info[0] < 3: + f.sort(lambda x,y: cmp(func_code(x[1]).co_firstlineno,func_code(y[1]).co_firstlineno)) + else: + # Python 3.0 + f.sort(key=lambda x: func_code(x[1]).co_firstlineno) + + # Sort the strings by regular expression length + for s in self.strsym.values(): + if sys.version_info[0] < 3: + s.sort(lambda x,y: (len(x[1]) < len(y[1])) - (len(x[1]) > len(y[1]))) + else: + # Python 3.0 + s.sort(key=lambda x: len(x[1]),reverse=True) + + # Validate all of the t_rules collected + def validate_rules(self): + for state in self.stateinfo: + # Validate all rules defined by functions + + + + for fname, f in self.funcsym[state]: + line = func_code(f).co_firstlineno + file = func_code(f).co_filename + self.files[file] = 1 + + tokname = self.toknames[fname] + if isinstance(f, types.MethodType): + reqargs = 2 + else: + reqargs = 1 + nargs = func_code(f).co_argcount + if nargs > reqargs: + self.log.error("%s:%d: Rule '%s' has too many arguments",file,line,f.__name__) + self.error = 1 + continue + + if nargs < reqargs: + self.log.error("%s:%d: Rule '%s' requires an argument", file,line,f.__name__) + self.error = 1 + continue + + if not f.__doc__: + self.log.error("%s:%d: No regular expression defined for rule '%s'",file,line,f.__name__) + self.error 
= 1 + continue + + try: + c = re.compile("(?P<%s>%s)" % (fname,f.__doc__), re.VERBOSE | self.reflags) + if c.match(""): + self.log.error("%s:%d: Regular expression for rule '%s' matches empty string", file,line,f.__name__) + self.error = 1 + except re.error: + _etype, e, _etrace = sys.exc_info() + self.log.error("%s:%d: Invalid regular expression for rule '%s'. %s", file,line,f.__name__,e) + if '#' in f.__doc__: + self.log.error("%s:%d. Make sure '#' in rule '%s' is escaped with '\\#'",file,line, f.__name__) + self.error = 1 + + # Validate all rules defined by strings + for name,r in self.strsym[state]: + tokname = self.toknames[name] + if tokname == 'error': + self.log.error("Rule '%s' must be defined as a function", name) + self.error = 1 + continue + + if not tokname in self.tokens and tokname.find("ignore_") < 0: + self.log.error("Rule '%s' defined for an unspecified token %s",name,tokname) + self.error = 1 + continue + + try: + c = re.compile("(?P<%s>%s)" % (name,r),re.VERBOSE | self.reflags) + if (c.match("")): + self.log.error("Regular expression for rule '%s' matches empty string",name) + self.error = 1 + except re.error: + _etype, e, _etrace = sys.exc_info() + self.log.error("Invalid regular expression for rule '%s'. %s",name,e) + if '#' in r: + self.log.error("Make sure '#' in rule '%s' is escaped with '\\#'",name) + self.error = 1 + + if not self.funcsym[state] and not self.strsym[state]: + self.log.error("No rules defined for state '%s'",state) + self.error = 1 + + # Validate the error function + efunc = self.errorf.get(state,None) + if efunc: + f = efunc + line = func_code(f).co_firstlineno + file = func_code(f).co_filename + self.files[file] = 1 + + if isinstance(f, types.MethodType): + reqargs = 2 + else: + reqargs = 1 + nargs = func_code(f).co_argcount + if nargs > reqargs: + self.log.error("%s:%d: Rule '%s' has too many arguments",file,line,f.__name__) + self.error = 1 + + if nargs < reqargs: + self.log.error("%s:%d: Rule '%s' requires an argument", file,line,f.__name__) + self.error = 1 + + for f in self.files: + self.validate_file(f) + + + # ----------------------------------------------------------------------------- + # validate_file() + # + # This checks to see if there are duplicated t_rulename() functions or strings + # in the parser input file. This is done using a simple regular expression + # match on each line in the given file. + # ----------------------------------------------------------------------------- + + def validate_file(self,filename): + import os.path + base,ext = os.path.splitext(filename) + if ext != '.py': return # No idea what the file is. Return OK + + try: + f = open(filename) + lines = f.readlines() + f.close() + except IOError: + return # Couldn't find the file. Don't worry about it + + fre = re.compile(r'\s*def\s+(t_[a-zA-Z_0-9]*)\(') + sre = re.compile(r'\s*(t_[a-zA-Z_0-9]*)\s*=') + + counthash = { } + linen = 1 + for l in lines: + m = fre.match(l) + if not m: + m = sre.match(l) + if m: + name = m.group(1) + prev = counthash.get(name) + if not prev: + counthash[name] = linen + else: + self.log.error("%s:%d: Rule %s redefined. 
Previously defined on line %d",filename,linen,name,prev) + self.error = 1 + linen += 1 + +# ----------------------------------------------------------------------------- +# lex(module) +# +# Build all of the regular expression rules from definitions in the supplied module +# ----------------------------------------------------------------------------- +def lex(module=None,object=None,debug=0,optimize=0,lextab="lextab",reflags=0,nowarn=0,outputdir="", debuglog=None, errorlog=None): + global lexer + ldict = None + stateinfo = { 'INITIAL' : 'inclusive'} + lexobj = Lexer() + lexobj.lexoptimize = optimize + global token,input + + if errorlog is None: + errorlog = PlyLogger(sys.stderr) + + if debug: + if debuglog is None: + debuglog = PlyLogger(sys.stderr) + + # Get the module dictionary used for the lexer + if object: module = object + + if module: + _items = [(k,getattr(module,k)) for k in dir(module)] + ldict = dict(_items) + else: + ldict = get_caller_module_dict(2) + + # Collect parser information from the dictionary + linfo = LexerReflect(ldict,log=errorlog,reflags=reflags) + linfo.get_all() + if not optimize: + if linfo.validate_all(): + raise SyntaxError("Can't build lexer") + + if optimize and lextab: + try: + lexobj.readtab(lextab,ldict) + token = lexobj.token + input = lexobj.input + lexer = lexobj + return lexobj + + except ImportError: + pass + + # Dump some basic debugging information + if debug: + debuglog.info("lex: tokens = %r", linfo.tokens) + debuglog.info("lex: literals = %r", linfo.literals) + debuglog.info("lex: states = %r", linfo.stateinfo) + + # Build a dictionary of valid token names + lexobj.lextokens = { } + for n in linfo.tokens: + lexobj.lextokens[n] = 1 + + # Get literals specification + if isinstance(linfo.literals,(list,tuple)): + lexobj.lexliterals = type(linfo.literals[0])().join(linfo.literals) + else: + lexobj.lexliterals = linfo.literals + + # Get the stateinfo dictionary + stateinfo = linfo.stateinfo + + regexs = { } + # Build the master regular expressions + for state in stateinfo: + regex_list = [] + + # Add rules defined by functions first + for fname, f in linfo.funcsym[state]: + line = func_code(f).co_firstlineno + file = func_code(f).co_filename + regex_list.append("(?P<%s>%s)" % (fname,f.__doc__)) + if debug: + debuglog.info("lex: Adding rule %s -> '%s' (state '%s')",fname,f.__doc__, state) + + # Now add all of the simple rules + for name,r in linfo.strsym[state]: + regex_list.append("(?P<%s>%s)" % (name,r)) + if debug: + debuglog.info("lex: Adding rule %s -> '%s' (state '%s')",name,r, state) + + regexs[state] = regex_list + + # Build the master regular expressions + + if debug: + debuglog.info("lex: ==== MASTER REGEXS FOLLOW ====") + + for state in regexs: + lexre, re_text, re_names = _form_master_re(regexs[state],reflags,ldict,linfo.toknames) + lexobj.lexstatere[state] = lexre + lexobj.lexstateretext[state] = re_text + lexobj.lexstaterenames[state] = re_names + if debug: + for i in range(len(re_text)): + debuglog.info("lex: state '%s' : regex[%d] = '%s'",state, i, re_text[i]) + + # For inclusive states, we need to add the regular expressions from the INITIAL state + for state,stype in stateinfo.items(): + if state != "INITIAL" and stype == 'inclusive': + lexobj.lexstatere[state].extend(lexobj.lexstatere['INITIAL']) + lexobj.lexstateretext[state].extend(lexobj.lexstateretext['INITIAL']) + lexobj.lexstaterenames[state].extend(lexobj.lexstaterenames['INITIAL']) + + lexobj.lexstateinfo = stateinfo + lexobj.lexre = lexobj.lexstatere["INITIAL"] + 
lexobj.lexretext = lexobj.lexstateretext["INITIAL"] + lexobj.lexreflags = reflags + + # Set up ignore variables + lexobj.lexstateignore = linfo.ignore + lexobj.lexignore = lexobj.lexstateignore.get("INITIAL","") + + # Set up error functions + lexobj.lexstateerrorf = linfo.errorf + lexobj.lexerrorf = linfo.errorf.get("INITIAL",None) + if not lexobj.lexerrorf: + errorlog.warning("No t_error rule is defined") + + # Check state information for ignore and error rules + for s,stype in stateinfo.items(): + if stype == 'exclusive': + if not s in linfo.errorf: + errorlog.warning("No error rule is defined for exclusive state '%s'", s) + if not s in linfo.ignore and lexobj.lexignore: + errorlog.warning("No ignore rule is defined for exclusive state '%s'", s) + elif stype == 'inclusive': + if not s in linfo.errorf: + linfo.errorf[s] = linfo.errorf.get("INITIAL",None) + if not s in linfo.ignore: + linfo.ignore[s] = linfo.ignore.get("INITIAL","") + + # Create global versions of the token() and input() functions + token = lexobj.token + input = lexobj.input + lexer = lexobj + + # If in optimize mode, we write the lextab + if lextab and optimize: + lexobj.writetab(lextab,outputdir) + + return lexobj + +# ----------------------------------------------------------------------------- +# runmain() +# +# This runs the lexer as a main program +# ----------------------------------------------------------------------------- + +def runmain(lexer=None,data=None): + if not data: + try: + filename = sys.argv[1] + f = open(filename) + data = f.read() + f.close() + except IndexError: + sys.stdout.write("Reading from standard input (type EOF to end):\n") + data = sys.stdin.read() + + if lexer: + _input = lexer.input + else: + _input = input + _input(data) + if lexer: + _token = lexer.token + else: + _token = token + + while 1: + tok = _token() + if not tok: break + sys.stdout.write("(%s,%r,%d,%d)\n" % (tok.type, tok.value, tok.lineno,tok.lexpos)) + +# ----------------------------------------------------------------------------- +# @TOKEN(regex) +# +# This decorator function can be used to set the regex expression on a function +# when its docstring might need to be set in an alternative way +# ----------------------------------------------------------------------------- + +def TOKEN(r): + def set_doc(f): + if hasattr(r,"__call__"): + f.__doc__ = r.__doc__ + else: + f.__doc__ = r + return f + return set_doc + +# Alternative spelling of the TOKEN decorator +Token = TOKEN + diff --git a/third_party/ply/yacc.py b/third_party/ply/yacc.py new file mode 100644 index 0000000..e9f5c65 --- /dev/null +++ b/third_party/ply/yacc.py @@ -0,0 +1,3276 @@ +# ----------------------------------------------------------------------------- +# ply: yacc.py +# +# Copyright (C) 2001-2009, +# David M. Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. 
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# -----------------------------------------------------------------------------
+#
+# This implements an LR parser that is constructed from grammar rules defined
+# as Python functions. The grammar is specified by supplying the BNF inside
+# Python documentation strings. The inspiration for this technique was borrowed
+# from John Aycock's Spark parsing system. PLY might be viewed as a cross between
+# Spark and the GNU bison utility.
+#
+# The current implementation is only somewhat object-oriented. The
+# LR parser itself is defined in terms of an object (which allows multiple
+# parsers to co-exist). However, most of the variables used during table
+# construction are defined in terms of global variables. Users shouldn't
+# notice unless they are trying to define multiple parsers at the same
+# time using threads (in which case they should have their head examined).
+#
+# This implementation supports both SLR and LALR(1) parsing. LALR(1)
+# support was originally implemented by Elias Ioup (ezioup@alumni.uchicago.edu),
+# using the algorithm found in Aho, Sethi, and Ullman "Compilers: Principles,
+# Techniques, and Tools" (The Dragon Book). LALR(1) has since been replaced
+# by the more efficient DeRemer and Pennello algorithm.
+#
+# :::::::: WARNING :::::::
+#
+# Construction of LR parsing tables is fairly complicated and expensive.
+# To make this module run fast, a *LOT* of work has been put into
+# optimization---often at the expense of readability and what one might
+# consider to be good Python "coding style." Modify the code at your
+# own risk!
+# ----------------------------------------------------------------------------
+
+__version__ = "3.3"
+__tabversion__ = "3.2" # Table version
+
+#-----------------------------------------------------------------------------
+# === User configurable parameters ===
+#
+# Change these to modify the default behavior of yacc (if you wish)
+#-----------------------------------------------------------------------------
+
+yaccdebug = 1 # Debugging mode. If set, yacc generates a
+ # 'parser.out' file in the current directory
+
+debug_file = 'parser.out' # Default name of the debugging file
+tab_module = 'parsetab' # Default name of the table module
+default_lr = 'LALR' # Default LR table generation method
+
+error_count = 3 # Number of symbols that must be shifted to leave recovery mode
+
+yaccdevel = 0 # Set to True if developing yacc. This turns off optimized
+ # implementations of certain functions.
+
+resultlimit = 40 # Size limit of results when running in debug mode.
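+
+# (Usage note, illustrative: these are only the defaults; callers normally
+# control behavior per build, e.g. yacc.yacc(debug=0) to skip writing
+# parser.out, or yacc.yacc(tabmodule='mytab', debugfile='my.out') to rename
+# the generated outputs.)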
+ +pickle_protocol = 0 # Protocol to use when writing pickle files + +import re, types, sys, os.path + +# Compatibility function for python 2.6/3.0 +if sys.version_info[0] < 3: + def func_code(f): + return f.func_code +else: + def func_code(f): + return f.__code__ + +# Compatibility +try: + MAXINT = sys.maxint +except AttributeError: + MAXINT = sys.maxsize + +# Python 2.x/3.0 compatibility. +def load_ply_lex(): + if sys.version_info[0] < 3: + import lex + else: + import ply.lex as lex + return lex + +# This object is a stand-in for a logging object created by the +# logging module. PLY will use this by default to create things +# such as the parser.out file. If a user wants more detailed +# information, they can create their own logging object and pass +# it into PLY. + +class PlyLogger(object): + def __init__(self,f): + self.f = f + def debug(self,msg,*args,**kwargs): + self.f.write((msg % args) + "\n") + info = debug + + def warning(self,msg,*args,**kwargs): + self.f.write("WARNING: "+ (msg % args) + "\n") + + def error(self,msg,*args,**kwargs): + self.f.write("ERROR: " + (msg % args) + "\n") + + critical = debug + +# Null logger is used when no output is generated. Does nothing. +class NullLogger(object): + def __getattribute__(self,name): + return self + def __call__(self,*args,**kwargs): + return self + +# Exception raised for yacc-related errors +class YaccError(Exception): pass + +# Format the result message that the parser produces when running in debug mode. +def format_result(r): + repr_str = repr(r) + if '\n' in repr_str: repr_str = repr(repr_str) + if len(repr_str) > resultlimit: + repr_str = repr_str[:resultlimit]+" ..." + result = "<%s @ 0x%x> (%s)" % (type(r).__name__,id(r),repr_str) + return result + + +# Format stack entries when the parser is running in debug mode +def format_stack_entry(r): + repr_str = repr(r) + if '\n' in repr_str: repr_str = repr(repr_str) + if len(repr_str) < 16: + return repr_str + else: + return "<%s @ 0x%x>" % (type(r).__name__,id(r)) + +#----------------------------------------------------------------------------- +# === LR Parsing Engine === +# +# The following classes are used for the LR parser itself. These are not +# used during table construction and are independent of the actual LR +# table generation algorithm +#----------------------------------------------------------------------------- + +# This class is used to hold non-terminal grammar symbols during parsing. +# It normally has the following attributes set: +# .type = Grammar symbol type +# .value = Symbol value +# .lineno = Starting line number +# .endlineno = Ending line number (optional, set automatically) +# .lexpos = Starting lex position +# .endlexpos = Ending lex position (optional, set automatically) + +class YaccSymbol: + def __str__(self): return self.type + def __repr__(self): return str(self) + +# This class is a wrapper around the objects actually passed to each +# grammar rule. Index lookup and assignment actually assign the +# .value attribute of the underlying YaccSymbol object. +# The lineno() method returns the line number of a given +# item (or 0 if not defined). The linespan() method returns +# a tuple of (startline,endline) representing the range of lines +# for a symbol. The lexspan() method returns a tuple (lexpos,endlexpos) +# representing the range of positional information for a symbol. 
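+#
+# A sketch of how a grammar rule sees this object (illustrative, not part of
+# this module):
+#
+#     def p_expression_plus(p):
+#         'expression : expression PLUS term'
+#         p[0] = p[1] + p[3]               # index lookup reads/writes .value
+#         print(p.lineno(1), p.lexspan(3))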
+ +class YaccProduction: + def __init__(self,s,stack=None): + self.slice = s + self.stack = stack + self.lexer = None + self.parser= None + def __getitem__(self,n): + if n >= 0: return self.slice[n].value + else: return self.stack[n].value + + def __setitem__(self,n,v): + self.slice[n].value = v + + def __getslice__(self,i,j): + return [s.value for s in self.slice[i:j]] + + def __len__(self): + return len(self.slice) + + def lineno(self,n): + return getattr(self.slice[n],"lineno",0) + + def set_lineno(self,n,lineno): + self.slice[n].lineno = lineno + + def linespan(self,n): + startline = getattr(self.slice[n],"lineno",0) + endline = getattr(self.slice[n],"endlineno",startline) + return startline,endline + + def lexpos(self,n): + return getattr(self.slice[n],"lexpos",0) + + def lexspan(self,n): + startpos = getattr(self.slice[n],"lexpos",0) + endpos = getattr(self.slice[n],"endlexpos",startpos) + return startpos,endpos + + def error(self): + raise SyntaxError + + +# ----------------------------------------------------------------------------- +# == LRParser == +# +# The LR Parsing engine. +# ----------------------------------------------------------------------------- + +class LRParser: + def __init__(self,lrtab,errorf): + self.productions = lrtab.lr_productions + self.action = lrtab.lr_action + self.goto = lrtab.lr_goto + self.errorfunc = errorf + + def errok(self): + self.errorok = 1 + + def restart(self): + del self.statestack[:] + del self.symstack[:] + sym = YaccSymbol() + sym.type = '$end' + self.symstack.append(sym) + self.statestack.append(0) + + def parse(self,input=None,lexer=None,debug=0,tracking=0,tokenfunc=None): + if debug or yaccdevel: + if isinstance(debug,int): + debug = PlyLogger(sys.stderr) + return self.parsedebug(input,lexer,debug,tracking,tokenfunc) + elif tracking: + return self.parseopt(input,lexer,debug,tracking,tokenfunc) + else: + return self.parseopt_notrack(input,lexer,debug,tracking,tokenfunc) + + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # parsedebug(). + # + # This is the debugging enabled version of parse(). All changes made to the + # parsing engine should be made here. For the non-debugging version, + # copy this code to a method parseopt() and delete all of the sections + # enclosed in: + # + # #--! DEBUG + # statements + # #--! DEBUG + # + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + def parsedebug(self,input=None,lexer=None,debug=None,tracking=0,tokenfunc=None): + lookahead = None # Current lookahead symbol + lookaheadstack = [ ] # Stack of lookahead symbols + actions = self.action # Local reference to action table (to avoid lookup on self.) + goto = self.goto # Local reference to goto table (to avoid lookup on self.) + prod = self.productions # Local reference to production list (to avoid lookup on self.) + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + # --! DEBUG + debug.info("PLY: PARSE DEBUG START") + # --! 
DEBUG + + # If no lexer was given, we will try to use the lex module + if not lexer: + lex = load_ply_lex() + lexer = lex.lexer + + # Set up the lexer and parser objects on pslice + pslice.lexer = lexer + pslice.parser = self + + # If input was supplied, pass to lexer + if input is not None: + lexer.input(input) + + if tokenfunc is None: + # Tokenize function + get_token = lexer.token + else: + get_token = tokenfunc + + # Set up the state and symbol stacks + + statestack = [ ] # Stack of parsing states + self.statestack = statestack + symstack = [ ] # Stack of grammar symbols + self.symstack = symstack + + pslice.stack = symstack # Put in the production + errtoken = None # Err token + + # The start state is assumed to be (0,$end) + + statestack.append(0) + sym = YaccSymbol() + sym.type = "$end" + symstack.append(sym) + state = 0 + while 1: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + + # --! DEBUG + debug.debug('') + debug.debug('State : %s', state) + # --! DEBUG + + if not lookahead: + if not lookaheadstack: + lookahead = get_token() # Get the next token + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = "$end" + + # --! DEBUG + debug.debug('Stack : %s', + ("%s . %s" % (" ".join([xx.type for xx in symstack][1:]), str(lookahead))).lstrip()) + # --! DEBUG + + # Check the action table + ltype = lookahead.type + t = actions[state].get(ltype) + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + state = t + + # --! DEBUG + debug.debug("Action : Shift and goto state %s", t) + # --! DEBUG + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: errorcount -=1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + p = prod[-t] + pname = p.name + plen = p.len + + # Get production function + sym = YaccSymbol() + sym.type = pname # Production name + sym.value = None + + # --! DEBUG + if plen: + debug.info("Action : Reduce rule [%s] with %s and goto state %d", p.str, "["+",".join([format_stack_entry(_v.value) for _v in symstack[-plen:]])+"]",-t) + else: + debug.info("Action : Reduce rule [%s] with %s and goto state %d", p.str, [],-t) + + # --! DEBUG + + if plen: + targ = symstack[-plen-1:] + targ[0] = sym + + # --! TRACKING + if tracking: + t1 = targ[1] + sym.lineno = t1.lineno + sym.lexpos = t1.lexpos + t1 = targ[-1] + sym.endlineno = getattr(t1,"endlineno",t1.lineno) + sym.endlexpos = getattr(t1,"endlexpos",t1.lexpos) + + # --! TRACKING + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # below as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + del symstack[-plen:] + del statestack[-plen:] + p.callable(pslice) + # --! DEBUG + debug.info("Result : %s", format_result(pslice[0])) + # --! DEBUG + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) + symstack.pop() + statestack.pop() + state = statestack[-1] + sym.type = 'error' + lookahead = sym + errorcount = error_count + self.errorok = 0 + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + else: + + # --! 
TRACKING + if tracking: + sym.lineno = lexer.lineno + sym.lexpos = lexer.lexpos + # --! TRACKING + + targ = [ sym ] + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # above as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + p.callable(pslice) + # --! DEBUG + debug.info("Result : %s", format_result(pslice[0])) + # --! DEBUG + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) + symstack.pop() + statestack.pop() + state = statestack[-1] + sym.type = 'error' + lookahead = sym + errorcount = error_count + self.errorok = 0 + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + if t == 0: + n = symstack[-1] + result = getattr(n,"value",None) + # --! DEBUG + debug.info("Done : Returning %s", format_result(result)) + debug.info("PLY: PARSE DEBUG END") + # --! DEBUG + return result + + if t == None: + + # --! DEBUG + debug.error('Error : %s', + ("%s . %s" % (" ".join([xx.type for xx in symstack][1:]), str(lookahead))).lstrip()) + # --! DEBUG + + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined p_error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = error_count + self.errorok = 0 + errtoken = lookahead + if errtoken.type == "$end": + errtoken = None # End of file! + if self.errorfunc: + global errok,token,restart + errok = self.errok # Set some special functions available in error recovery + token = get_token + restart = self.restart + if errtoken and not hasattr(errtoken,'lexer'): + errtoken.lexer = lexer + tok = self.errorfunc(errtoken) + del errok, token, restart # Delete special functions + + if self.errorok: + # User must have done some kind of panic + # mode recovery on their own. The + # returned token is the next lookahead + lookahead = tok + errtoken = None + continue + else: + if errtoken: + if hasattr(errtoken,"lineno"): lineno = lookahead.lineno + else: lineno = 0 + if lineno: + sys.stderr.write("yacc: Syntax error at line %d, token=%s\n" % (lineno, errtoken.type)) + else: + sys.stderr.write("yacc: Syntax error, token=%s" % errtoken.type) + else: + sys.stderr.write("yacc: Parse error in input. EOF\n") + return + + else: + errorcount = error_count + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. + + if len(statestack) <= 1 and lookahead.type != "$end": + lookahead = None + errtoken = None + state = 0 + # Nuke the pushback stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == "$end": + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. 
Error is on top of stack, we'll just nuke input + # symbol and continue + lookahead = None + continue + t = YaccSymbol() + t.type = 'error' + if hasattr(lookahead,"lineno"): + t.lineno = lookahead.lineno + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + symstack.pop() + statestack.pop() + state = statestack[-1] # Potential bug fix + + continue + + # Call an error function here + raise RuntimeError("yacc: internal parser error!!!\n") + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # parseopt(). + # + # Optimized version of parse() method. DO NOT EDIT THIS CODE DIRECTLY. + # Edit the debug version above, then copy any modifications to the method + # below while removing #--! DEBUG sections. + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + + def parseopt(self,input=None,lexer=None,debug=0,tracking=0,tokenfunc=None): + lookahead = None # Current lookahead symbol + lookaheadstack = [ ] # Stack of lookahead symbols + actions = self.action # Local reference to action table (to avoid lookup on self.) + goto = self.goto # Local reference to goto table (to avoid lookup on self.) + prod = self.productions # Local reference to production list (to avoid lookup on self.) + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + # If no lexer was given, we will try to use the lex module + if not lexer: + lex = load_ply_lex() + lexer = lex.lexer + + # Set up the lexer and parser objects on pslice + pslice.lexer = lexer + pslice.parser = self + + # If input was supplied, pass to lexer + if input is not None: + lexer.input(input) + + if tokenfunc is None: + # Tokenize function + get_token = lexer.token + else: + get_token = tokenfunc + + # Set up the state and symbol stacks + + statestack = [ ] # Stack of parsing states + self.statestack = statestack + symstack = [ ] # Stack of grammar symbols + self.symstack = symstack + + pslice.stack = symstack # Put in the production + errtoken = None # Err token + + # The start state is assumed to be (0,$end) + + statestack.append(0) + sym = YaccSymbol() + sym.type = '$end' + symstack.append(sym) + state = 0 + while 1: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + + if not lookahead: + if not lookaheadstack: + lookahead = get_token() # Get the next token + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = '$end' + + # Check the action table + ltype = lookahead.type + t = actions[state].get(ltype) + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + state = t + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: errorcount -=1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + p = prod[-t] + pname = p.name + plen = p.len + + # Get production function + sym = YaccSymbol() + sym.type = pname # Production name + sym.value = None + + if plen: + targ = symstack[-plen-1:] + targ[0] = sym + + # --! TRACKING + if tracking: + t1 = targ[1] + sym.lineno = t1.lineno + sym.lexpos = t1.lexpos + t1 = targ[-1] + sym.endlineno = getattr(t1,"endlineno",t1.lineno) + sym.endlexpos = getattr(t1,"endlexpos",t1.lexpos) + + # --! TRACKING + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
+ # The code enclosed in this section is duplicated + # below as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + del symstack[-plen:] + del statestack[-plen:] + p.callable(pslice) + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) + symstack.pop() + statestack.pop() + state = statestack[-1] + sym.type = 'error' + lookahead = sym + errorcount = error_count + self.errorok = 0 + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + else: + + # --! TRACKING + if tracking: + sym.lineno = lexer.lineno + sym.lexpos = lexer.lexpos + # --! TRACKING + + targ = [ sym ] + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # above as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + p.callable(pslice) + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) + symstack.pop() + statestack.pop() + state = statestack[-1] + sym.type = 'error' + lookahead = sym + errorcount = error_count + self.errorok = 0 + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + if t == 0: + n = symstack[-1] + return getattr(n,"value",None) + + if t == None: + + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined p_error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = error_count + self.errorok = 0 + errtoken = lookahead + if errtoken.type == '$end': + errtoken = None # End of file! + if self.errorfunc: + global errok,token,restart + errok = self.errok # Set some special functions available in error recovery + token = get_token + restart = self.restart + if errtoken and not hasattr(errtoken,'lexer'): + errtoken.lexer = lexer + tok = self.errorfunc(errtoken) + del errok, token, restart # Delete special functions + + if self.errorok: + # User must have done some kind of panic + # mode recovery on their own. The + # returned token is the next lookahead + lookahead = tok + errtoken = None + continue + else: + if errtoken: + if hasattr(errtoken,"lineno"): lineno = lookahead.lineno + else: lineno = 0 + if lineno: + sys.stderr.write("yacc: Syntax error at line %d, token=%s\n" % (lineno, errtoken.type)) + else: + sys.stderr.write("yacc: Syntax error, token=%s" % errtoken.type) + else: + sys.stderr.write("yacc: Parse error in input. EOF\n") + return + + else: + errorcount = error_count + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. 
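+                # (At this point every parser state has been popped without
+                #  finding one that can shift an 'error' token, so the
+                #  offending token is simply dropped and parsing resumes
+                #  with the next token from the input.)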
+ + if len(statestack) <= 1 and lookahead.type != '$end': + lookahead = None + errtoken = None + state = 0 + # Nuke the pushback stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == '$end': + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. Error is on top of stack, we'll just nuke input + # symbol and continue + lookahead = None + continue + t = YaccSymbol() + t.type = 'error' + if hasattr(lookahead,"lineno"): + t.lineno = lookahead.lineno + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + symstack.pop() + statestack.pop() + state = statestack[-1] # Potential bug fix + + continue + + # Call an error function here + raise RuntimeError("yacc: internal parser error!!!\n") + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # parseopt_notrack(). + # + # Optimized version of parseopt() with line number tracking removed. + # DO NOT EDIT THIS CODE DIRECTLY. Copy the optimized version and remove + # code in the #--! TRACKING sections + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + def parseopt_notrack(self,input=None,lexer=None,debug=0,tracking=0,tokenfunc=None): + lookahead = None # Current lookahead symbol + lookaheadstack = [ ] # Stack of lookahead symbols + actions = self.action # Local reference to action table (to avoid lookup on self.) + goto = self.goto # Local reference to goto table (to avoid lookup on self.) + prod = self.productions # Local reference to production list (to avoid lookup on self.) + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + # If no lexer was given, we will try to use the lex module + if not lexer: + lex = load_ply_lex() + lexer = lex.lexer + + # Set up the lexer and parser objects on pslice + pslice.lexer = lexer + pslice.parser = self + + # If input was supplied, pass to lexer + if input is not None: + lexer.input(input) + + if tokenfunc is None: + # Tokenize function + get_token = lexer.token + else: + get_token = tokenfunc + + # Set up the state and symbol stacks + + statestack = [ ] # Stack of parsing states + self.statestack = statestack + symstack = [ ] # Stack of grammar symbols + self.symstack = symstack + + pslice.stack = symstack # Put in the production + errtoken = None # Err token + + # The start state is assumed to be (0,$end) + + statestack.append(0) + sym = YaccSymbol() + sym.type = '$end' + symstack.append(sym) + state = 0 + while 1: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. 
Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + + if not lookahead: + if not lookaheadstack: + lookahead = get_token() # Get the next token + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = '$end' + + # Check the action table + ltype = lookahead.type + t = actions[state].get(ltype) + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + state = t + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: errorcount -=1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + p = prod[-t] + pname = p.name + plen = p.len + + # Get production function + sym = YaccSymbol() + sym.type = pname # Production name + sym.value = None + + if plen: + targ = symstack[-plen-1:] + targ[0] = sym + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # below as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + del symstack[-plen:] + del statestack[-plen:] + p.callable(pslice) + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) + symstack.pop() + statestack.pop() + state = statestack[-1] + sym.type = 'error' + lookahead = sym + errorcount = error_count + self.errorok = 0 + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + else: + + targ = [ sym ] + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # above as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + p.callable(pslice) + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) + symstack.pop() + statestack.pop() + state = statestack[-1] + sym.type = 'error' + lookahead = sym + errorcount = error_count + self.errorok = 0 + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + if t == 0: + n = symstack[-1] + return getattr(n,"value",None) + + if t == None: + + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined p_error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = error_count + self.errorok = 0 + errtoken = lookahead + if errtoken.type == '$end': + errtoken = None # End of file! + if self.errorfunc: + global errok,token,restart + errok = self.errok # Set some special functions available in error recovery + token = get_token + restart = self.restart + if errtoken and not hasattr(errtoken,'lexer'): + errtoken.lexer = lexer + tok = self.errorfunc(errtoken) + del errok, token, restart # Delete special functions + + if self.errorok: + # User must have done some kind of panic + # mode recovery on their own. 
The + # returned token is the next lookahead + lookahead = tok + errtoken = None + continue + else: + if errtoken: + if hasattr(errtoken,"lineno"): lineno = lookahead.lineno + else: lineno = 0 + if lineno: + sys.stderr.write("yacc: Syntax error at line %d, token=%s\n" % (lineno, errtoken.type)) + else: + sys.stderr.write("yacc: Syntax error, token=%s" % errtoken.type) + else: + sys.stderr.write("yacc: Parse error in input. EOF\n") + return + + else: + errorcount = error_count + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. + + if len(statestack) <= 1 and lookahead.type != '$end': + lookahead = None + errtoken = None + state = 0 + # Nuke the pushback stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == '$end': + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. Error is on top of stack, we'll just nuke input + # symbol and continue + lookahead = None + continue + t = YaccSymbol() + t.type = 'error' + if hasattr(lookahead,"lineno"): + t.lineno = lookahead.lineno + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + symstack.pop() + statestack.pop() + state = statestack[-1] # Potential bug fix + + continue + + # Call an error function here + raise RuntimeError("yacc: internal parser error!!!\n") + +# ----------------------------------------------------------------------------- +# === Grammar Representation === +# +# The following functions, classes, and variables are used to represent and +# manipulate the rules that make up a grammar. +# ----------------------------------------------------------------------------- + +import re + +# regex matching identifiers +_is_identifier = re.compile(r'^[a-zA-Z0-9_-]+$') + +# ----------------------------------------------------------------------------- +# class Production: +# +# This class stores the raw information about a single production or grammar rule. +# A grammar rule refers to a specification such as this: +# +# expr : expr PLUS term +# +# Here are the basic attributes defined on all productions +# +# name - Name of the production. For example 'expr' +# prod - A list of symbols on the right side ['expr','PLUS','term'] +# prec - Production precedence level +# number - Production number. +# func - Function that executes on reduce +# file - File where production function is defined +# lineno - Line number where production function is defined +# +# The following attributes are defined or optional. 
+#
+#       len       - Length of the production (number of symbols on right hand side)
+#       usyms     - Set of unique symbols found in the production
+# -----------------------------------------------------------------------------
+
+class Production(object):
+    reduced = 0
+    def __init__(self,number,name,prod,precedence=('right',0),func=None,file='',line=0):
+        self.name     = name
+        self.prod     = tuple(prod)
+        self.number   = number
+        self.func     = func
+        self.callable = None
+        self.file     = file
+        self.line     = line
+        self.prec     = precedence
+
+        # Internal settings used during table construction
+
+        self.len  = len(self.prod)   # Length of the production
+
+        # Create a list of unique production symbols used in the production
+        self.usyms = [ ]
+        for s in self.prod:
+            if s not in self.usyms:
+                self.usyms.append(s)
+
+        # List of all LR items for the production
+        self.lr_items = []
+        self.lr_next = None
+
+        # Create a string representation
+        if self.prod:
+            self.str = "%s -> %s" % (self.name," ".join(self.prod))
+        else:
+            self.str = "%s -> <empty>" % self.name
+
+    def __str__(self):
+        return self.str
+
+    def __repr__(self):
+        return "Production("+str(self)+")"
+
+    def __len__(self):
+        return len(self.prod)
+
+    def __nonzero__(self):
+        return 1
+
+    def __getitem__(self,index):
+        return self.prod[index]
+
+    # Return the nth lr_item from the production (or None if at the end)
+    def lr_item(self,n):
+        if n > len(self.prod): return None
+        p = LRItem(self,n)
+
+        # Precompute the list of productions immediately following.  Hack. Remove later
+        try:
+            p.lr_after = Prodnames[p.prod[n+1]]
+        except (IndexError,KeyError):
+            p.lr_after = []
+        try:
+            p.lr_before = p.prod[n-1]
+        except IndexError:
+            p.lr_before = None
+
+        return p
+
+    # Bind the production function name to a callable
+    def bind(self,pdict):
+        if self.func:
+            self.callable = pdict[self.func]
+
+# This class serves as a minimal stand-in for Production objects when
+# reading table data from files.  It only contains information
+# actually used by the LR parsing engine, plus some additional
+# debugging information.
+class MiniProduction(object):
+    def __init__(self,str,name,len,func,file,line):
+        self.name     = name
+        self.len      = len
+        self.func     = func
+        self.callable = None
+        self.file     = file
+        self.line     = line
+        self.str      = str
+    def __str__(self):
+        return self.str
+    def __repr__(self):
+        return "MiniProduction(%s)" % self.str
+
+    # Bind the production function name to a callable
+    def bind(self,pdict):
+        if self.func:
+            self.callable = pdict[self.func]
+
+
+# -----------------------------------------------------------------------------
+# class LRItem
+#
+# This class represents a specific stage of parsing a production rule.  For
+# example:
+#
+#       expr : expr . PLUS term
+#
+# In the above, the "." represents the current location of the parse.  Here are
+# the basic attributes:
+#
+#       name       - Name of the production.  For example 'expr'
+#       prod       - A list of symbols on the right side ['expr','.', 'PLUS','term']
+#       number     - Production number.
+#
+#       lr_next    - Next LR item. For example, if we are ' expr -> expr . PLUS term'
+#                    then lr_next refers to 'expr -> expr PLUS . term'
+#       lr_index   - LR item index (location of the ".") in the prod list.
+#       lookaheads - LALR lookahead symbols for this item
+#       len        - Length of the production (number of symbols on right hand side)
+#       lr_after   - List of all productions that immediately follow
+#       lr_before  - Grammar symbol immediately before
+# -----------------------------------------------------------------------------
+
+class LRItem(object):
+    def __init__(self,p,n):
+        self.name       = p.name
+        self.prod       = list(p.prod)
+        self.number     = p.number
+        self.lr_index   = n
+        self.lookaheads = { }
+        self.prod.insert(n,".")
+        self.prod       = tuple(self.prod)
+        self.len        = len(self.prod)
+        self.usyms      = p.usyms
+
+    def __str__(self):
+        if self.prod:
+            s = "%s -> %s" % (self.name," ".join(self.prod))
+        else:
+            s = "%s -> <empty>" % self.name
+        return s
+
+    def __repr__(self):
+        return "LRItem("+str(self)+")"
+
+# -----------------------------------------------------------------------------
+# rightmost_terminal()
+#
+# Return the rightmost terminal from a list of symbols.  Used in add_production()
+# -----------------------------------------------------------------------------
+def rightmost_terminal(symbols, terminals):
+    i = len(symbols) - 1
+    while i >= 0:
+        if symbols[i] in terminals:
+            return symbols[i]
+        i -= 1
+    return None
+
+# -----------------------------------------------------------------------------
+#                           === GRAMMAR CLASS ===
+#
+# The following class represents the contents of the specified grammar along
+# with various computed properties such as first sets, follow sets, LR items, etc.
+# This data is used for critical parts of the table generation process later.
+# -----------------------------------------------------------------------------
+
+class GrammarError(YaccError): pass
+
+class Grammar(object):
+    def __init__(self,terminals):
+        self.Productions  = [None]  # A list of all of the productions.  The first
+                                    # entry is always reserved for the purpose of
+                                    # building an augmented grammar
+
+        self.Prodnames    = { }     # A dictionary mapping the names of nonterminals to a list of all
+                                    # productions of that nonterminal.
+
+        self.Prodmap      = { }     # A dictionary that is only used to detect duplicate
+                                    # productions.
+
+        self.Terminals    = { }     # A dictionary mapping the names of terminal symbols to a
+                                    # list of the rules where they are used.
+
+        for term in terminals:
+            self.Terminals[term] = []
+
+        self.Terminals['error'] = []
+
+        self.Nonterminals = { }     # A dictionary mapping names of nonterminals to a list
+                                    # of rule numbers where they are used.
+
+        self.First        = { }     # A dictionary of precomputed FIRST(x) symbols
+
+        self.Follow       = { }     # A dictionary of precomputed FOLLOW(x) symbols
+
+        self.Precedence   = { }     # Precedence rules for each terminal. Contains tuples of the
+                                    # form ('right',level) or ('nonassoc', level) or ('left',level)
+
+        self.UsedPrecedence = { }   # Precedence rules that were actually used by the grammar.
+                                    # This is only used to provide error checking and to generate
+                                    # a warning about unused precedence rules.
+
+        self.Start = None           # Starting symbol for the grammar
+
+
+    def __len__(self):
+        return len(self.Productions)
+
+    def __getitem__(self,index):
+        return self.Productions[index]
+
+    # -----------------------------------------------------------------------------
+    # set_precedence()
+    #
+    # Sets the precedence for a given terminal. assoc is the associativity such as
+    # 'left','right', or 'nonassoc'.  level is a numeric level.
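+    #
+    # For example, a yacc precedence specification such as
+    #
+    #      precedence = (('left','PLUS','MINUS'),
+    #                    ('left','TIMES','DIVIDE'))
+    #
+    # is installed by calling set_precedence('PLUS','left',1),
+    # set_precedence('MINUS','left',1), set_precedence('TIMES','left',2),
+    # and so on, with the level growing by one for each row.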
+    #
+    # -----------------------------------------------------------------------------
+
+    def set_precedence(self,term,assoc,level):
+        assert self.Productions == [None],"Must call set_precedence() before add_production()"
+        if term in self.Precedence:
+            raise GrammarError("Precedence already specified for terminal '%s'" % term)
+        if assoc not in ['left','right','nonassoc']:
+            raise GrammarError("Associativity must be one of 'left','right', or 'nonassoc'")
+        self.Precedence[term] = (assoc,level)
+
+    # -----------------------------------------------------------------------------
+    # add_production()
+    #
+    # Given an action function, this function assembles a production rule and
+    # computes its precedence level.
+    #
+    # The production rule is supplied as a list of symbols.  For example,
+    # a rule such as 'expr : expr PLUS term' has a production name of 'expr' and
+    # symbols ['expr','PLUS','term'].
+    #
+    # Precedence is determined by the precedence of the right-most terminal
+    # or the precedence of a terminal specified by %prec.
+    #
+    # A variety of error checks are performed to make sure production symbols
+    # are valid and that %prec is used correctly.
+    # -----------------------------------------------------------------------------
+
+    def add_production(self,prodname,syms,func=None,file='',line=0):
+
+        if prodname in self.Terminals:
+            raise GrammarError("%s:%d: Illegal rule name '%s'. Already defined as a token" % (file,line,prodname))
+        if prodname == 'error':
+            raise GrammarError("%s:%d: Illegal rule name '%s'. error is a reserved word" % (file,line,prodname))
+        if not _is_identifier.match(prodname):
+            raise GrammarError("%s:%d: Illegal rule name '%s'" % (file,line,prodname))
+
+        # Look for literal tokens
+        for n,s in enumerate(syms):
+            if s[0] in "'\"":
+                try:
+                    c = eval(s)
+                    if (len(c) > 1):
+                        raise GrammarError("%s:%d: Literal token %s in rule '%s' may only be a single character" % (file,line,s, prodname))
+                    if not c in self.Terminals:
+                        self.Terminals[c] = []
+                    syms[n] = c
+                    continue
+                except SyntaxError:
+                    pass
+            if not _is_identifier.match(s) and s != '%prec':
+                raise GrammarError("%s:%d: Illegal name '%s' in rule '%s'" % (file,line,s, prodname))
+
+        # Determine the precedence level
+        if '%prec' in syms:
+            if syms[-1] == '%prec':
+                raise GrammarError("%s:%d: Syntax error. Nothing follows %%prec" % (file,line))
+            if syms[-2] != '%prec':
+                raise GrammarError("%s:%d: Syntax error. %%prec can only appear at the end of a grammar rule" % (file,line))
+            precname = syms[-1]
+            prodprec = self.Precedence.get(precname,None)
+            if not prodprec:
+                raise GrammarError("%s:%d: Nothing known about the precedence of '%s'" % (file,line,precname))
+            else:
+                self.UsedPrecedence[precname] = 1
+            del syms[-2:]     # Drop %prec from the rule
+        else:
+            # If no %prec, precedence is determined by the rightmost terminal symbol
+            precname = rightmost_terminal(syms,self.Terminals)
+            prodprec = self.Precedence.get(precname,('right',0))
+
+        # See if the rule is already in the rulemap
+        map = "%s -> %s" % (prodname,syms)
+        if map in self.Prodmap:
+            m = self.Prodmap[map]
+            raise GrammarError("%s:%d: Duplicate rule %s. " % (file,line, m) +
+                               "Previous definition at %s:%d" % (m.file, m.line))
+
+        # From this point on, everything is valid. 
Create a new Production instance + pnumber = len(self.Productions) + if not prodname in self.Nonterminals: + self.Nonterminals[prodname] = [ ] + + # Add the production number to Terminals and Nonterminals + for t in syms: + if t in self.Terminals: + self.Terminals[t].append(pnumber) + else: + if not t in self.Nonterminals: + self.Nonterminals[t] = [ ] + self.Nonterminals[t].append(pnumber) + + # Create a production and add it to the list of productions + p = Production(pnumber,prodname,syms,prodprec,func,file,line) + self.Productions.append(p) + self.Prodmap[map] = p + + # Add to the global productions list + try: + self.Prodnames[prodname].append(p) + except KeyError: + self.Prodnames[prodname] = [ p ] + return 0 + + # ----------------------------------------------------------------------------- + # set_start() + # + # Sets the starting symbol and creates the augmented grammar. Production + # rule 0 is S' -> start where start is the start symbol. + # ----------------------------------------------------------------------------- + + def set_start(self,start=None): + if not start: + start = self.Productions[1].name + if start not in self.Nonterminals: + raise GrammarError("start symbol %s undefined" % start) + self.Productions[0] = Production(0,"S'",[start]) + self.Nonterminals[start].append(0) + self.Start = start + + # ----------------------------------------------------------------------------- + # find_unreachable() + # + # Find all of the nonterminal symbols that can't be reached from the starting + # symbol. Returns a list of nonterminals that can't be reached. + # ----------------------------------------------------------------------------- + + def find_unreachable(self): + + # Mark all symbols that are reachable from a symbol s + def mark_reachable_from(s): + if reachable[s]: + # We've already reached symbol s. + return + reachable[s] = 1 + for p in self.Prodnames.get(s,[]): + for r in p.prod: + mark_reachable_from(r) + + reachable = { } + for s in list(self.Terminals) + list(self.Nonterminals): + reachable[s] = 0 + + mark_reachable_from( self.Productions[0].prod[0] ) + + return [s for s in list(self.Nonterminals) + if not reachable[s]] + + # ----------------------------------------------------------------------------- + # infinite_cycles() + # + # This function looks at the various parsing rules and tries to detect + # infinite recursion cycles (grammar rules where there is no possible way + # to derive a string of only terminals). + # ----------------------------------------------------------------------------- + + def infinite_cycles(self): + terminates = {} + + # Terminals: + for t in self.Terminals: + terminates[t] = 1 + + terminates['$end'] = 1 + + # Nonterminals: + + # Initialize to false: + for n in self.Nonterminals: + terminates[n] = 0 + + # Then propagate termination until no change: + while 1: + some_change = 0 + for (n,pl) in self.Prodnames.items(): + # Nonterminal n terminates iff any of its productions terminates. + for p in pl: + # Production p terminates iff all of its rhs symbols terminate. + for s in p.prod: + if not terminates[s]: + # The symbol s does not terminate, + # so production p does not terminate. + p_terminates = 0 + break + else: + # didn't break from the loop, + # so every symbol s terminates + # so production p terminates. + p_terminates = 1 + + if p_terminates: + # symbol n terminates! + if not terminates[n]: + terminates[n] = 1 + some_change = 1 + # Don't need to consider any more productions for this n. 
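+                        # (For example, 'list : list ITEM | ITEM' terminates
+                        #  because its second production already does.)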
+                        break
+
+            if not some_change:
+                break
+
+        infinite = []
+        for (s,term) in terminates.items():
+            if not term:
+                if not s in self.Prodnames and not s in self.Terminals and s != 'error':
+                    # s is used-but-not-defined, and we've already warned of that,
+                    # so it would be overkill to say that it's also non-terminating.
+                    pass
+                else:
+                    infinite.append(s)
+
+        return infinite
+
+
+    # -----------------------------------------------------------------------------
+    # undefined_symbols()
+    #
+    # Find all symbols that were used in the grammar, but not defined as tokens or
+    # grammar rules.  Returns a list of tuples (sym, prod) where sym is the symbol
+    # and prod is the production where the symbol was used.
+    # -----------------------------------------------------------------------------
+    def undefined_symbols(self):
+        result = []
+        for p in self.Productions:
+            if not p: continue
+
+            for s in p.prod:
+                if not s in self.Prodnames and not s in self.Terminals and s != 'error':
+                    result.append((s,p))
+        return result
+
+    # -----------------------------------------------------------------------------
+    # unused_terminals()
+    #
+    # Find all terminals that were defined, but not used by the grammar.  Returns
+    # a list of all symbols.
+    # -----------------------------------------------------------------------------
+    def unused_terminals(self):
+        unused_tok = []
+        for s,v in self.Terminals.items():
+            if s != 'error' and not v:
+                unused_tok.append(s)
+
+        return unused_tok
+
+    # ------------------------------------------------------------------------------
+    # unused_rules()
+    #
+    # Find all grammar rules that were defined, but not used (maybe not reachable).
+    # Returns a list of productions.
+    # ------------------------------------------------------------------------------
+
+    def unused_rules(self):
+        unused_prod = []
+        for s,v in self.Nonterminals.items():
+            if not v:
+                p = self.Prodnames[s][0]
+                unused_prod.append(p)
+        return unused_prod
+
+    # -----------------------------------------------------------------------------
+    # unused_precedence()
+    #
+    # Returns a list of tuples (term,precedence) corresponding to precedence
+    # rules that were never used by the grammar.  term is the name of the terminal
+    # on which precedence was applied and precedence is a string such as 'left' or
+    # 'right' corresponding to the type of precedence.
+    # -----------------------------------------------------------------------------
+
+    def unused_precedence(self):
+        unused = []
+        for termname in self.Precedence:
+            if not (termname in self.Terminals or termname in self.UsedPrecedence):
+                unused.append((termname,self.Precedence[termname][0]))
+
+        return unused
+
+    # -------------------------------------------------------------------------
+    # _first()
+    #
+    # Compute the value of FIRST1(beta) where beta is a tuple of symbols.
+    #
+    # During execution of compute_first(), the result may be incomplete.
+    # Afterward (e.g., when called from compute_follow()), it will be complete.
+    # -------------------------------------------------------------------------
+    def _first(self,beta):
+
+        # We are computing First(x1,x2,x3,...,xn)
+        result = [ ]
+        for x in beta:
+            x_produces_empty = 0
+
+            # Add all the non-<empty> symbols of First[x] to the result.
+            for f in self.First[x]:
+                if f == '<empty>':
+                    x_produces_empty = 1
+                else:
+                    if f not in result: result.append(f)
+
+            if x_produces_empty:
+                # We have to consider the next x in beta,
+                # i.e. stay in the loop.
+                pass
+            else:
+                # We don't have to consider any further symbols in beta.
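+                # (If x cannot derive <empty>, no later symbol of beta can
+                #  ever appear first, so the scan can stop at x.)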
+                break
+        else:
+            # There was no 'break' from the loop,
+            # so x_produces_empty was true for all x in beta,
+            # so beta produces empty as well.
+            result.append('<empty>')
+
+        return result
+
+    # -------------------------------------------------------------------------
+    # compute_first()
+    #
+    # Compute the value of FIRST1(X) for all symbols
+    # -------------------------------------------------------------------------
+    def compute_first(self):
+        if self.First:
+            return self.First
+
+        # Terminals:
+        for t in self.Terminals:
+            self.First[t] = [t]
+
+        self.First['$end'] = ['$end']
+
+        # Nonterminals:
+
+        # Initialize to the empty set:
+        for n in self.Nonterminals:
+            self.First[n] = []
+
+        # Then propagate symbols until no change:
+        while 1:
+            some_change = 0
+            for n in self.Nonterminals:
+                for p in self.Prodnames[n]:
+                    for f in self._first(p.prod):
+                        if f not in self.First[n]:
+                            self.First[n].append( f )
+                            some_change = 1
+            if not some_change:
+                break
+
+        return self.First
+
+    # ---------------------------------------------------------------------
+    # compute_follow()
+    #
+    # Computes all of the follow sets for every non-terminal symbol.  The
+    # follow set is the set of all symbols that might follow a given
+    # non-terminal.  See the Dragon book, 2nd Ed. p. 189.
+    # ---------------------------------------------------------------------
+    def compute_follow(self,start=None):
+        # If already computed, return the result
+        if self.Follow:
+            return self.Follow
+
+        # If first sets not computed yet, do that first.
+        if not self.First:
+            self.compute_first()
+
+        # Add '$end' to the follow list of the start symbol
+        for k in self.Nonterminals:
+            self.Follow[k] = [ ]
+
+        if not start:
+            start = self.Productions[1].name
+
+        self.Follow[start] = [ '$end' ]
+
+        while 1:
+            didadd = 0
+            for p in self.Productions[1:]:
+                # Here is the production set
+                for i in range(len(p.prod)):
+                    B = p.prod[i]
+                    if B in self.Nonterminals:
+                        # Okay. We got a non-terminal in a production
+                        fst = self._first(p.prod[i+1:])
+                        hasempty = 0
+                        for f in fst:
+                            if f != '<empty>' and f not in self.Follow[B]:
+                                self.Follow[B].append(f)
+                                didadd = 1
+                            if f == '<empty>':
+                                hasempty = 1
+                        if hasempty or i == (len(p.prod)-1):
+                            # Add elements of follow(p.name) to follow(B)
+                            for f in self.Follow[p.name]:
+                                if f not in self.Follow[B]:
+                                    self.Follow[B].append(f)
+                                    didadd = 1
+            if not didadd: break
+        return self.Follow
+
+
+    # -----------------------------------------------------------------------------
+    # build_lritems()
+    #
+    # This function walks the list of productions and builds a complete set of the
+    # LR items.  The LR items are stored in two ways:  First, they are uniquely
+    # numbered and placed in the list _lritems.  Second, a linked list of LR items
+    # is built for each production.  For example:
+    #
+    #   E -> E PLUS E
+    #
+    # Creates the list
+    #
+    #  [E -> . E PLUS E, E -> E . PLUS E, E -> E PLUS . E, E -> E PLUS E . 
] + # ----------------------------------------------------------------------------- + + def build_lritems(self): + for p in self.Productions: + lastlri = p + i = 0 + lr_items = [] + while 1: + if i > len(p): + lri = None + else: + lri = LRItem(p,i) + # Precompute the list of productions immediately following + try: + lri.lr_after = self.Prodnames[lri.prod[i+1]] + except (IndexError,KeyError): + lri.lr_after = [] + try: + lri.lr_before = lri.prod[i-1] + except IndexError: + lri.lr_before = None + + lastlri.lr_next = lri + if not lri: break + lr_items.append(lri) + lastlri = lri + i += 1 + p.lr_items = lr_items + +# ----------------------------------------------------------------------------- +# == Class LRTable == +# +# This basic class represents a basic table of LR parsing information. +# Methods for generating the tables are not defined here. They are defined +# in the derived class LRGeneratedTable. +# ----------------------------------------------------------------------------- + +class VersionError(YaccError): pass + +class LRTable(object): + def __init__(self): + self.lr_action = None + self.lr_goto = None + self.lr_productions = None + self.lr_method = None + + def read_table(self,module): + if isinstance(module,types.ModuleType): + parsetab = module + else: + if sys.version_info[0] < 3: + exec("import %s as parsetab" % module) + else: + env = { } + exec("import %s as parsetab" % module, env, env) + parsetab = env['parsetab'] + + if parsetab._tabversion != __tabversion__: + raise VersionError("yacc table file version is out of date") + + self.lr_action = parsetab._lr_action + self.lr_goto = parsetab._lr_goto + + self.lr_productions = [] + for p in parsetab._lr_productions: + self.lr_productions.append(MiniProduction(*p)) + + self.lr_method = parsetab._lr_method + return parsetab._lr_signature + + def read_pickle(self,filename): + try: + import cPickle as pickle + except ImportError: + import pickle + + in_f = open(filename,"rb") + + tabversion = pickle.load(in_f) + if tabversion != __tabversion__: + raise VersionError("yacc table file version is out of date") + self.lr_method = pickle.load(in_f) + signature = pickle.load(in_f) + self.lr_action = pickle.load(in_f) + self.lr_goto = pickle.load(in_f) + productions = pickle.load(in_f) + + self.lr_productions = [] + for p in productions: + self.lr_productions.append(MiniProduction(*p)) + + in_f.close() + return signature + + # Bind all production function names to callable objects in pdict + def bind_callables(self,pdict): + for p in self.lr_productions: + p.bind(pdict) + +# ----------------------------------------------------------------------------- +# === LR Generator === +# +# The following classes and functions are used to generate LR parsing tables on +# a grammar. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# digraph() +# traverse() +# +# The following two functions are used to compute set valued functions +# of the form: +# +# F(x) = F'(x) U U{F(y) | x R y} +# +# This is used to compute the values of Read() sets as well as FOLLOW sets +# in LALR(1) generation. 
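+#
+# For example, compute_follow_sets() below instantiates this machinery with
+#
+#      X  = the set of nonterminal transitions (state,A)
+#      R  = the INCLUDES relation
+#      FP = the previously computed Read() sets
+#
+# so that Follow(p,A) = Read(p,A) U U{Follow(p',B) | (p,A) INCLUDES (p',B)}.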
+#
+# Inputs:  X    - An input set
+#          R    - A relation
+#          FP   - Set-valued function
+# ------------------------------------------------------------------------------
+
+def digraph(X,R,FP):
+    N = { }
+    for x in X:
+        N[x] = 0
+    stack = []
+    F = { }
+    for x in X:
+        if N[x] == 0: traverse(x,N,stack,F,X,R,FP)
+    return F
+
+def traverse(x,N,stack,F,X,R,FP):
+    stack.append(x)
+    d = len(stack)
+    N[x] = d
+    F[x] = FP(x)             # F(X) <- F'(x)
+
+    rel = R(x)               # Get y's related to x
+    for y in rel:
+        if N[y] == 0:
+            traverse(y,N,stack,F,X,R,FP)
+        N[x] = min(N[x],N[y])
+        for a in F.get(y,[]):
+            if a not in F[x]: F[x].append(a)
+    if N[x] == d:
+        N[stack[-1]] = MAXINT
+        F[stack[-1]] = F[x]
+        element = stack.pop()
+        while element != x:
+            N[stack[-1]] = MAXINT
+            F[stack[-1]] = F[x]
+            element = stack.pop()
+
+class LALRError(YaccError): pass
+
+# -----------------------------------------------------------------------------
+#                             == LRGeneratedTable ==
+#
+# This class implements the LR table generation algorithm.  There are no
+# public methods except for write()
+# -----------------------------------------------------------------------------
+
+class LRGeneratedTable(LRTable):
+    def __init__(self,grammar,method='LALR',log=None):
+        if method not in ['SLR','LALR']:
+            raise LALRError("Unsupported method %s" % method)
+
+        self.grammar = grammar
+        self.lr_method = method
+
+        # Set up the logger
+        if not log:
+            log = NullLogger()
+        self.log = log
+
+        # Internal attributes
+        self.lr_action     = {}        # Action table
+        self.lr_goto       = {}        # Goto table
+        self.lr_productions  = grammar.Productions    # Copy of grammar Production array
+        self.lr_goto_cache = {}        # Cache of computed gotos
+        self.lr0_cidhash   = {}        # Cache of closures
+
+        self._add_count    = 0         # Internal counter used to detect cycles
+
+        # Diagnostic information filled in by the table generator
+        self.sr_conflict   = 0
+        self.rr_conflict   = 0
+        self.conflicts     = []        # List of conflicts
+
+        self.sr_conflicts  = []
+        self.rr_conflicts  = []
+
+        # Build the tables
+        self.grammar.build_lritems()
+        self.grammar.compute_first()
+        self.grammar.compute_follow()
+        self.lr_parse_table()
+
+    # Compute the LR(0) closure operation on I, where I is a set of LR(0) items.
+
+    def lr0_closure(self,I):
+        self._add_count += 1
+
+        # Add everything in I to J
+        J = I[:]
+        didadd = 1
+        while didadd:
+            didadd = 0
+            for j in J:
+                for x in j.lr_after:
+                    if getattr(x,"lr0_added",0) == self._add_count: continue
+                    # Add B --> .G to J
+                    J.append(x.lr_next)
+                    x.lr0_added = self._add_count
+                    didadd = 1
+
+        return J
+
+    # Compute the LR(0) goto function goto(I,X) where I is a set
+    # of LR(0) items and X is a grammar symbol.  This function is written
+    # in a way that guarantees uniqueness of the generated goto sets
+    # (i.e. the same goto set will never be returned as two different Python
+    # objects).  With uniqueness, we can later do fast set comparisons using
+    # id(obj) instead of element-wise comparison.
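+    #
+    # For example, if I contains the item  E -> E . PLUS T  then
+    # lr0_goto(I,'PLUS') yields the closure of  E -> E PLUS . T , and a
+    # second call with the same (I,'PLUS') returns the identical object.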
+
+    def lr0_goto(self,I,x):
+        # First we look for a previously cached entry
+        g = self.lr_goto_cache.get((id(I),x),None)
+        if g: return g
+
+        # Now we generate the goto set in a way that guarantees uniqueness
+        # of the result
+
+        s = self.lr_goto_cache.get(x,None)
+        if not s:
+            s = { }
+            self.lr_goto_cache[x] = s
+
+        gs = [ ]
+        for p in I:
+            n = p.lr_next
+            if n and n.lr_before == x:
+                s1 = s.get(id(n),None)
+                if not s1:
+                    s1 = { }
+                    s[id(n)] = s1
+                gs.append(n)
+                s = s1
+        g = s.get('$end',None)
+        if not g:
+            if gs:
+                g = self.lr0_closure(gs)
+                s['$end'] = g
+            else:
+                s['$end'] = gs
+        self.lr_goto_cache[(id(I),x)] = g
+        return g
+
+    # Compute the LR(0) sets-of-items function
+    def lr0_items(self):
+
+        C = [ self.lr0_closure([self.grammar.Productions[0].lr_next]) ]
+        i = 0
+        for I in C:
+            self.lr0_cidhash[id(I)] = i
+            i += 1
+
+        # Loop over the items in C and each grammar symbol
+        i = 0
+        while i < len(C):
+            I = C[i]
+            i += 1
+
+            # Collect all of the symbols that could possibly be in the goto(I,X) sets
+            asyms = { }
+            for ii in I:
+                for s in ii.usyms:
+                    asyms[s] = None
+
+            for x in asyms:
+                g = self.lr0_goto(I,x)
+                if not g: continue
+                if id(g) in self.lr0_cidhash: continue
+                self.lr0_cidhash[id(g)] = len(C)
+                C.append(g)
+
+        return C
+
+    # -----------------------------------------------------------------------------
+    #                       ==== LALR(1) Parsing ====
+    #
+    # LALR(1) parsing is almost exactly the same as SLR except that instead of
+    # relying upon Follow() sets when performing reductions, a more selective
+    # lookahead set that incorporates the state of the LR(0) machine is utilized.
+    # Thus, we mainly just have to focus on calculating the lookahead sets.
+    #
+    # The method used here is due to DeRemer and Pennello (1982).
+    #
+    # DeRemer, F. L., and T. J. Pennello: "Efficient Computation of LALR(1)
+    # Lookahead Sets", ACM Transactions on Programming Languages and Systems,
+    # Vol. 4, No. 4, Oct. 1982, pp. 615-649
+    #
+    # Further details can also be found in:
+    #
+    #  J. Tremblay and P. Sorenson, "The Theory and Practice of Compiler Writing",
+    #      McGraw-Hill Book Company, (1985).
+    #
+    # -----------------------------------------------------------------------------
+
+    # -----------------------------------------------------------------------------
+    # compute_nullable_nonterminals()
+    #
+    # Creates a dictionary containing all of the non-terminals that might produce
+    # an empty production.
+    # -----------------------------------------------------------------------------
+
+    def compute_nullable_nonterminals(self):
+        nullable = {}
+        num_nullable = 0
+        while 1:
+            for p in self.grammar.Productions[1:]:
+                if p.len == 0:
+                    nullable[p.name] = 1
+                    continue
+                for t in p.prod:
+                    if not t in nullable: break
+                else:
+                    nullable[p.name] = 1
+            if len(nullable) == num_nullable: break
+            num_nullable = len(nullable)
+        return nullable
+
+    # -----------------------------------------------------------------------------
+    # find_nonterminal_transitions(C)
+    #
+    # Given a set of LR(0) items, this function finds all of the non-terminal
+    # transitions.  These are transitions in which a dot appears immediately before
+    # a non-terminal.  Returns a list of tuples of the form (state,N) where state
+    # is the state number and N is the nonterminal symbol.
+    #
+    # The input C is the set of LR(0) items.
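+    #
+    # For example, an item  stmt -> . expr SEMI  in state 3 contributes
+    # the transition (3,'expr'), because the dot sits immediately before
+    # the nonterminal 'expr'.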
+ # ----------------------------------------------------------------------------- + + def find_nonterminal_transitions(self,C): + trans = [] + for state in range(len(C)): + for p in C[state]: + if p.lr_index < p.len - 1: + t = (state,p.prod[p.lr_index+1]) + if t[1] in self.grammar.Nonterminals: + if t not in trans: trans.append(t) + state = state + 1 + return trans + + # ----------------------------------------------------------------------------- + # dr_relation() + # + # Computes the DR(p,A) relationships for non-terminal transitions. The input + # is a tuple (state,N) where state is a number and N is a nonterminal symbol. + # + # Returns a list of terminals. + # ----------------------------------------------------------------------------- + + def dr_relation(self,C,trans,nullable): + dr_set = { } + state,N = trans + terms = [] + + g = self.lr0_goto(C[state],N) + for p in g: + if p.lr_index < p.len - 1: + a = p.prod[p.lr_index+1] + if a in self.grammar.Terminals: + if a not in terms: terms.append(a) + + # This extra bit is to handle the start state + if state == 0 and N == self.grammar.Productions[0].prod[0]: + terms.append('$end') + + return terms + + # ----------------------------------------------------------------------------- + # reads_relation() + # + # Computes the READS() relation (p,A) READS (t,C). + # ----------------------------------------------------------------------------- + + def reads_relation(self,C, trans, empty): + # Look for empty transitions + rel = [] + state, N = trans + + g = self.lr0_goto(C[state],N) + j = self.lr0_cidhash.get(id(g),-1) + for p in g: + if p.lr_index < p.len - 1: + a = p.prod[p.lr_index + 1] + if a in empty: + rel.append((j,a)) + + return rel + + # ----------------------------------------------------------------------------- + # compute_lookback_includes() + # + # Determines the lookback and includes relations + # + # LOOKBACK: + # + # This relation is determined by running the LR(0) state machine forward. + # For example, starting with a production "N : . A B C", we run it forward + # to obtain "N : A B C ." We then build a relationship between this final + # state and the starting state. These relationships are stored in a dictionary + # lookdict. + # + # INCLUDES: + # + # Computes the INCLUDE() relation (p,A) INCLUDES (p',B). + # + # This relation is used to determine non-terminal transitions that occur + # inside of other non-terminal transition states. (p,A) INCLUDES (p', B) + # if the following holds: + # + # B -> LAT, where T -> epsilon and p' -L-> p + # + # L is essentially a prefix (which may be empty), T is a suffix that must be + # able to derive an empty string. State p' must lead to state p with the string L. + # + # ----------------------------------------------------------------------------- + + def compute_lookback_includes(self,C,trans,nullable): + + lookdict = {} # Dictionary of lookback relations + includedict = {} # Dictionary of include relations + + # Make a dictionary of non-terminal transitions + dtrans = {} + for t in trans: + dtrans[t] = 1 + + # Loop over all transitions and compute lookbacks and includes + for state,N in trans: + lookb = [] + includes = [] + for p in C[state]: + if p.name != N: continue + + # Okay, we have a name match. We now follow the production all the way + # through the state machine until we get the . 
on the right hand side
+
+                lr_index = p.lr_index
+                j = state
+                while lr_index < p.len - 1:
+                    lr_index = lr_index + 1
+                    t = p.prod[lr_index]
+
+                    # Check to see if this symbol and state are a non-terminal transition
+                    if (j,t) in dtrans:
+                        # Yes.  Okay, there is some chance that this is an includes relation
+                        # the only way to know for certain is whether the rest of the
+                        # production derives empty
+
+                        li = lr_index + 1
+                        while li < p.len:
+                            if p.prod[li] in self.grammar.Terminals: break      # No, forget it
+                            if not p.prod[li] in nullable: break
+                            li = li + 1
+                        else:
+                            # Appears to be a relation between (j,t) and (state,N)
+                            includes.append((j,t))
+
+                    g = self.lr0_goto(C[j],t)              # Go to next set
+                    j = self.lr0_cidhash.get(id(g),-1)     # Go to next state
+
+                # When we get here, j is the final state, now we have to locate the production
+                for r in C[j]:
+                    if r.name != p.name: continue
+                    if r.len != p.len: continue
+                    i = 0
+                    # This loop is comparing a production ". A B C" with "A B C ."
+                    while i < r.lr_index:
+                        if r.prod[i] != p.prod[i+1]: break
+                        i = i + 1
+                    else:
+                        lookb.append((j,r))
+            for i in includes:
+                if not i in includedict: includedict[i] = []
+                includedict[i].append((state,N))
+            lookdict[(state,N)] = lookb
+
+        return lookdict,includedict
+
+    # -----------------------------------------------------------------------------
+    # compute_read_sets()
+    #
+    # Given a set of LR(0) items, this function computes the read sets.
+    #
+    # Inputs:  C        = Set of LR(0) items
+    #          ntrans   = Set of nonterminal transitions
+    #          nullable = Set of empty transitions
+    #
+    # Returns a set containing the read sets
+    # -----------------------------------------------------------------------------
+
+    def compute_read_sets(self,C, ntrans, nullable):
+        FP = lambda x: self.dr_relation(C,x,nullable)
+        R  = lambda x: self.reads_relation(C,x,nullable)
+        F = digraph(ntrans,R,FP)
+        return F
+
+    # -----------------------------------------------------------------------------
+    # compute_follow_sets()
+    #
+    # Given a set of LR(0) items, a set of non-terminal transitions, a readset,
+    # and an include set, this function computes the follow sets
+    #
+    # Follow(p,A) = Read(p,A) U U {Follow(p',B) | (p,A) INCLUDES (p',B)}
+    #
+    # Inputs:
+    #            ntrans   = Set of nonterminal transitions
+    #            readsets = Readset (previously computed)
+    #            inclsets = Include sets (previously computed)
+    #
+    # Returns a set containing the follow sets
+    # -----------------------------------------------------------------------------
+
+    def compute_follow_sets(self,ntrans,readsets,inclsets):
+        FP = lambda x: readsets[x]
+        R  = lambda x: inclsets.get(x,[])
+        F = digraph(ntrans,R,FP)
+        return F
+
+    # -----------------------------------------------------------------------------
+    # add_lookaheads()
+    #
+    # Attaches the lookahead symbols to grammar rules.
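+    #
+    # For example, if production p can be reduced in state 7 and the
+    # computed Follow set for that transition is ['PLUS','$end'], then
+    # p.lookaheads[7] becomes ['PLUS','$end'] and the table builder will
+    # reduce by p in state 7 only on those tokens.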
+ # + # Inputs: lookbacks - Set of lookback relations + # followset - Computed follow set + # + # This function directly attaches the lookaheads to productions contained + # in the lookbacks set + # ----------------------------------------------------------------------------- + + def add_lookaheads(self,lookbacks,followset): + for trans,lb in lookbacks.items(): + # Loop over productions in lookback + for state,p in lb: + if not state in p.lookaheads: + p.lookaheads[state] = [] + f = followset.get(trans,[]) + for a in f: + if a not in p.lookaheads[state]: p.lookaheads[state].append(a) + + # ----------------------------------------------------------------------------- + # add_lalr_lookaheads() + # + # This function does all of the work of adding lookahead information for use + # with LALR parsing + # ----------------------------------------------------------------------------- + + def add_lalr_lookaheads(self,C): + # Determine all of the nullable nonterminals + nullable = self.compute_nullable_nonterminals() + + # Find all non-terminal transitions + trans = self.find_nonterminal_transitions(C) + + # Compute read sets + readsets = self.compute_read_sets(C,trans,nullable) + + # Compute lookback/includes relations + lookd, included = self.compute_lookback_includes(C,trans,nullable) + + # Compute LALR FOLLOW sets + followsets = self.compute_follow_sets(trans,readsets,included) + + # Add all of the lookaheads + self.add_lookaheads(lookd,followsets) + + # ----------------------------------------------------------------------------- + # lr_parse_table() + # + # This function constructs the parse tables for SLR or LALR + # ----------------------------------------------------------------------------- + def lr_parse_table(self): + Productions = self.grammar.Productions + Precedence = self.grammar.Precedence + goto = self.lr_goto # Goto array + action = self.lr_action # Action array + log = self.log # Logger for output + + actionp = { } # Action production array (temporary) + + log.info("Parsing method: %s", self.lr_method) + + # Step 1: Construct C = { I0, I1, ... IN}, collection of LR(0) items + # This determines the number of states + + C = self.lr0_items() + + if self.lr_method == 'LALR': + self.add_lalr_lookaheads(C) + + # Build the parser table, state by state + st = 0 + for I in C: + # Loop over each production in I + actlist = [ ] # List of actions + st_action = { } + st_actionp = { } + st_goto = { } + log.info("") + log.info("state %d", st) + log.info("") + for p in I: + log.info(" (%d) %s", p.number, str(p)) + log.info("") + + for p in I: + if p.len == p.lr_index + 1: + if p.name == "S'": + # Start symbol. Accept! + st_action["$end"] = 0 + st_actionp["$end"] = p + else: + # We are at the end of a production. Reduce! + if self.lr_method == 'LALR': + laheads = p.lookaheads[st] + else: + laheads = self.grammar.Follow[p.name] + for a in laheads: + actlist.append((a,p,"reduce using rule %d (%s)" % (p.number,p))) + r = st_action.get(a,None) + if r is not None: + # Whoa. Have a shift/reduce or reduce/reduce conflict + if r > 0: + # Need to decide on shift or reduce here + # By default we favor shifting. Need to add + # some precedence rules here. + sprec,slevel = Productions[st_actionp[a].number].prec + rprec,rlevel = Precedence.get(a,('right',0)) + if (slevel < rlevel) or ((slevel == rlevel) and (rprec == 'left')): + # We really need to reduce here. + st_action[a] = -p.number + st_actionp[a] = p + if not slevel and not rlevel: + log.info(" ! 
shift/reduce conflict for %s resolved as reduce",a) + self.sr_conflicts.append((st,a,'reduce')) + Productions[p.number].reduced += 1 + elif (slevel == rlevel) and (rprec == 'nonassoc'): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the shift + if not rlevel: + log.info(" ! shift/reduce conflict for %s resolved as shift",a) + self.sr_conflicts.append((st,a,'shift')) + elif r < 0: + # Reduce/reduce conflict. In this case, we favor the rule + # that was defined first in the grammar file + oldp = Productions[-r] + pp = Productions[p.number] + if oldp.line > pp.line: + st_action[a] = -p.number + st_actionp[a] = p + chosenp,rejectp = pp,oldp + Productions[p.number].reduced += 1 + Productions[oldp.number].reduced -= 1 + else: + chosenp,rejectp = oldp,pp + self.rr_conflicts.append((st,chosenp,rejectp)) + log.info(" ! reduce/reduce conflict for %s resolved using rule %d (%s)", a,st_actionp[a].number, st_actionp[a]) + else: + raise LALRError("Unknown conflict in state %d" % st) + else: + st_action[a] = -p.number + st_actionp[a] = p + Productions[p.number].reduced += 1 + else: + i = p.lr_index + a = p.prod[i+1] # Get symbol right after the "." + if a in self.grammar.Terminals: + g = self.lr0_goto(I,a) + j = self.lr0_cidhash.get(id(g),-1) + if j >= 0: + # We are in a shift state + actlist.append((a,p,"shift and go to state %d" % j)) + r = st_action.get(a,None) + if r is not None: + # Whoa have a shift/reduce or shift/shift conflict + if r > 0: + if r != j: + raise LALRError("Shift/shift conflict in state %d" % st) + elif r < 0: + # Do a precedence check. + # - if precedence of reduce rule is higher, we reduce. + # - if precedence of reduce is same and left assoc, we reduce. + # - otherwise we shift + rprec,rlevel = Productions[st_actionp[a].number].prec + sprec,slevel = Precedence.get(a,('right',0)) + if (slevel > rlevel) or ((slevel == rlevel) and (rprec == 'right')): + # We decide to shift here... highest precedence to shift + Productions[st_actionp[a].number].reduced -= 1 + st_action[a] = j + st_actionp[a] = p + if not rlevel: + log.info(" ! shift/reduce conflict for %s resolved as shift",a) + self.sr_conflicts.append((st,a,'shift')) + elif (slevel == rlevel) and (rprec == 'nonassoc'): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the reduce + if not slevel and not rlevel: + log.info(" ! shift/reduce conflict for %s resolved as reduce",a) + self.sr_conflicts.append((st,a,'reduce')) + + else: + raise LALRError("Unknown conflict in state %d" % st) + else: + st_action[a] = j + st_actionp[a] = p + + # Print the actions associated with each terminal + _actprint = { } + for a,p,m in actlist: + if a in st_action: + if p is st_actionp[a]: + log.info(" %-15s %s",a,m) + _actprint[(a,m)] = 1 + log.info("") + # Print the actions that were not used. (debugging) + not_used = 0 + for a,p,m in actlist: + if a in st_action: + if p is not st_actionp[a]: + if not (a,m) in _actprint: + log.debug(" ! 
%-15s [ %s ]",a,m) + not_used = 1 + _actprint[(a,m)] = 1 + if not_used: + log.debug("") + + # Construct the goto table for this state + + nkeys = { } + for ii in I: + for s in ii.usyms: + if s in self.grammar.Nonterminals: + nkeys[s] = None + for n in nkeys: + g = self.lr0_goto(I,n) + j = self.lr0_cidhash.get(id(g),-1) + if j >= 0: + st_goto[n] = j + log.info(" %-30s shift and go to state %d",n,j) + + action[st] = st_action + actionp[st] = st_actionp + goto[st] = st_goto + st += 1 + + + # ----------------------------------------------------------------------------- + # write() + # + # This function writes the LR parsing tables to a file + # ----------------------------------------------------------------------------- + + def write_table(self,modulename,outputdir='',signature=""): + basemodulename = modulename.split(".")[-1] + filename = os.path.join(outputdir,basemodulename) + ".py" + try: + f = open(filename,"w") + + f.write(""" +# %s +# This file is automatically generated. Do not edit. +_tabversion = %r + +_lr_method = %r + +_lr_signature = %r + """ % (filename, __tabversion__, self.lr_method, signature)) + + # Change smaller to 0 to go back to original tables + smaller = 1 + + # Factor out names to try and make smaller + if smaller: + items = { } + + for s,nd in self.lr_action.items(): + for name,v in nd.items(): + i = items.get(name) + if not i: + i = ([],[]) + items[name] = i + i[0].append(s) + i[1].append(v) + + f.write("\n_lr_action_items = {") + for k,v in items.items(): + f.write("%r:([" % k) + for i in v[0]: + f.write("%r," % i) + f.write("],[") + for i in v[1]: + f.write("%r," % i) + + f.write("]),") + f.write("}\n") + + f.write(""" +_lr_action = { } +for _k, _v in _lr_action_items.items(): + for _x,_y in zip(_v[0],_v[1]): + if not _x in _lr_action: _lr_action[_x] = { } + _lr_action[_x][_k] = _y +del _lr_action_items +""") + + else: + f.write("\n_lr_action = { "); + for k,v in self.lr_action.items(): + f.write("(%r,%r):%r," % (k[0],k[1],v)) + f.write("}\n"); + + if smaller: + # Factor out names to try and make smaller + items = { } + + for s,nd in self.lr_goto.items(): + for name,v in nd.items(): + i = items.get(name) + if not i: + i = ([],[]) + items[name] = i + i[0].append(s) + i[1].append(v) + + f.write("\n_lr_goto_items = {") + for k,v in items.items(): + f.write("%r:([" % k) + for i in v[0]: + f.write("%r," % i) + f.write("],[") + for i in v[1]: + f.write("%r," % i) + + f.write("]),") + f.write("}\n") + + f.write(""" +_lr_goto = { } +for _k, _v in _lr_goto_items.items(): + for _x,_y in zip(_v[0],_v[1]): + if not _x in _lr_goto: _lr_goto[_x] = { } + _lr_goto[_x][_k] = _y +del _lr_goto_items +""") + else: + f.write("\n_lr_goto = { "); + for k,v in self.lr_goto.items(): + f.write("(%r,%r):%r," % (k[0],k[1],v)) + f.write("}\n"); + + # Write production table + f.write("_lr_productions = [\n") + for p in self.lr_productions: + if p.func: + f.write(" (%r,%r,%d,%r,%r,%d),\n" % (p.str,p.name, p.len, p.func,p.file,p.line)) + else: + f.write(" (%r,%r,%d,None,None,None),\n" % (str(p),p.name, p.len)) + f.write("]\n") + f.close() + + except IOError: + e = sys.exc_info()[1] + sys.stderr.write("Unable to create '%s'\n" % filename) + sys.stderr.write(str(e)+"\n") + return + + + # ----------------------------------------------------------------------------- + # pickle_table() + # + # This function pickles the LR parsing tables to a supplied file object + # ----------------------------------------------------------------------------- + + def pickle_table(self,filename,signature=""): + try: + 
import cPickle as pickle + except ImportError: + import pickle + outf = open(filename,"wb") + pickle.dump(__tabversion__,outf,pickle_protocol) + pickle.dump(self.lr_method,outf,pickle_protocol) + pickle.dump(signature,outf,pickle_protocol) + pickle.dump(self.lr_action,outf,pickle_protocol) + pickle.dump(self.lr_goto,outf,pickle_protocol) + + outp = [] + for p in self.lr_productions: + if p.func: + outp.append((p.str,p.name, p.len, p.func,p.file,p.line)) + else: + outp.append((str(p),p.name,p.len,None,None,None)) + pickle.dump(outp,outf,pickle_protocol) + outf.close() + +# ----------------------------------------------------------------------------- +# === INTROSPECTION === +# +# The following functions and classes are used to implement the PLY +# introspection features followed by the yacc() function itself. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# get_caller_module_dict() +# +# This function returns a dictionary containing all of the symbols defined within +# a caller further down the call stack. This is used to get the environment +# associated with the yacc() call if none was provided. +# ----------------------------------------------------------------------------- + +def get_caller_module_dict(levels): + try: + raise RuntimeError + except RuntimeError: + e,b,t = sys.exc_info() + f = t.tb_frame + while levels > 0: + f = f.f_back + levels -= 1 + ldict = f.f_globals.copy() + if f.f_globals != f.f_locals: + ldict.update(f.f_locals) + + return ldict + +# ----------------------------------------------------------------------------- +# parse_grammar() +# +# This takes a raw grammar rule string and parses it into production data +# ----------------------------------------------------------------------------- +def parse_grammar(doc,file,line): + grammar = [] + # Split the doc string into lines + pstrings = doc.splitlines() + lastp = None + dline = line + for ps in pstrings: + dline += 1 + p = ps.split() + if not p: continue + try: + if p[0] == '|': + # This is a continuation of a previous rule + if not lastp: + raise SyntaxError("%s:%d: Misplaced '|'" % (file,dline)) + prodname = lastp + syms = p[1:] + else: + prodname = p[0] + lastp = prodname + syms = p[2:] + assign = p[1] + if assign != ':' and assign != '::=': + raise SyntaxError("%s:%d: Syntax error. Expected ':'" % (file,dline)) + + grammar.append((file,dline,prodname,syms)) + except SyntaxError: + raise + except Exception: + raise SyntaxError("%s:%d: Syntax error in rule '%s'" % (file,dline,ps.strip())) + + return grammar + +# ----------------------------------------------------------------------------- +# ParserReflect() +# +# This class represents information extracted for building a parser including +# start symbol, error function, tokens, precedence list, action functions, +# etc. 
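+#
+# For illustration (the rule and token names here are hypothetical, added
+# for clarity): the p_ action functions this class collects carry their
+# grammar fragment in the docstring, in the form parse_grammar() above
+# accepts, e.g.
+#
+#     def p_expression(p):
+#         '''expression : expression PLUS term
+#                       | term'''
+#         p[0] = p[1] if len(p) == 2 else p[1] + p[3]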
+# ----------------------------------------------------------------------------- +class ParserReflect(object): + def __init__(self,pdict,log=None): + self.pdict = pdict + self.start = None + self.error_func = None + self.tokens = None + self.files = {} + self.grammar = [] + self.error = 0 + + if log is None: + self.log = PlyLogger(sys.stderr) + else: + self.log = log + + # Get all of the basic information + def get_all(self): + self.get_start() + self.get_error_func() + self.get_tokens() + self.get_precedence() + self.get_pfunctions() + + # Validate all of the information + def validate_all(self): + self.validate_start() + self.validate_error_func() + self.validate_tokens() + self.validate_precedence() + self.validate_pfunctions() + self.validate_files() + return self.error + + # Compute a signature over the grammar + def signature(self): + try: + from hashlib import md5 + except ImportError: + from md5 import md5 + try: + sig = md5() + if self.start: + sig.update(self.start.encode('latin-1')) + if self.prec: + sig.update("".join(["".join(p) for p in self.prec]).encode('latin-1')) + if self.tokens: + sig.update(" ".join(self.tokens).encode('latin-1')) + for f in self.pfuncs: + if f[3]: + sig.update(f[3].encode('latin-1')) + except (TypeError,ValueError): + pass + return sig.digest() + + # ----------------------------------------------------------------------------- + # validate_file() + # + # This method checks to see if there are duplicated p_rulename() functions + # in the parser module file. Without this function, it is really easy for + # users to make mistakes by cutting and pasting code fragments (and it's a real + # bugger to try and figure out why the resulting parser doesn't work). Therefore, + # we just do a little regular expression pattern matching of def statements + # to try and detect duplicates. + # ----------------------------------------------------------------------------- + + def validate_files(self): + # Match def p_funcname( + fre = re.compile(r'\s*def\s+(p_[a-zA-Z_0-9]*)\(') + + for filename in self.files.keys(): + base,ext = os.path.splitext(filename) + if ext != '.py': return 1 # No idea. Assume it's okay. + + try: + f = open(filename) + lines = f.readlines() + f.close() + except IOError: + continue + + counthash = { } + for linen,l in enumerate(lines): + linen += 1 + m = fre.match(l) + if m: + name = m.group(1) + prev = counthash.get(name) + if not prev: + counthash[name] = linen + else: + self.log.warning("%s:%d: Function %s redefined. 
Previously defined on line %d", filename,linen,name,prev) + + # Get the start symbol + def get_start(self): + self.start = self.pdict.get('start') + + # Validate the start symbol + def validate_start(self): + if self.start is not None: + if not isinstance(self.start,str): + self.log.error("'start' must be a string") + + # Look for error handler + def get_error_func(self): + self.error_func = self.pdict.get('p_error') + + # Validate the error function + def validate_error_func(self): + if self.error_func: + if isinstance(self.error_func,types.FunctionType): + ismethod = 0 + elif isinstance(self.error_func, types.MethodType): + ismethod = 1 + else: + self.log.error("'p_error' defined, but is not a function or method") + self.error = 1 + return + + eline = func_code(self.error_func).co_firstlineno + efile = func_code(self.error_func).co_filename + self.files[efile] = 1 + + if (func_code(self.error_func).co_argcount != 1+ismethod): + self.log.error("%s:%d: p_error() requires 1 argument",efile,eline) + self.error = 1 + + # Get the tokens map + def get_tokens(self): + tokens = self.pdict.get("tokens",None) + if not tokens: + self.log.error("No token list is defined") + self.error = 1 + return + + if not isinstance(tokens,(list, tuple)): + self.log.error("tokens must be a list or tuple") + self.error = 1 + return + + if not tokens: + self.log.error("tokens is empty") + self.error = 1 + return + + self.tokens = tokens + + # Validate the tokens + def validate_tokens(self): + # Validate the tokens. + if 'error' in self.tokens: + self.log.error("Illegal token name 'error'. Is a reserved word") + self.error = 1 + return + + terminals = {} + for n in self.tokens: + if n in terminals: + self.log.warning("Token '%s' multiply defined", n) + terminals[n] = 1 + + # Get the precedence map (if any) + def get_precedence(self): + self.prec = self.pdict.get("precedence",None) + + # Validate and parse the precedence map + def validate_precedence(self): + preclist = [] + if self.prec: + if not isinstance(self.prec,(list,tuple)): + self.log.error("precedence must be a list or tuple") + self.error = 1 + return + for level,p in enumerate(self.prec): + if not isinstance(p,(list,tuple)): + self.log.error("Bad precedence table") + self.error = 1 + return + + if len(p) < 2: + self.log.error("Malformed precedence entry %s. 
Must be (assoc, term, ..., term)",p) + self.error = 1 + return + assoc = p[0] + if not isinstance(assoc,str): + self.log.error("precedence associativity must be a string") + self.error = 1 + return + for term in p[1:]: + if not isinstance(term,str): + self.log.error("precedence items must be strings") + self.error = 1 + return + preclist.append((term,assoc,level+1)) + self.preclist = preclist + + # Get all p_functions from the grammar + def get_pfunctions(self): + p_functions = [] + for name, item in self.pdict.items(): + if name[:2] != 'p_': continue + if name == 'p_error': continue + if isinstance(item,(types.FunctionType,types.MethodType)): + line = func_code(item).co_firstlineno + file = func_code(item).co_filename + p_functions.append((line,file,name,item.__doc__)) + + # Sort all of the actions by line number + p_functions.sort() + self.pfuncs = p_functions + + + # Validate all of the p_functions + def validate_pfunctions(self): + grammar = [] + # Check for non-empty symbols + if len(self.pfuncs) == 0: + self.log.error("no rules of the form p_rulename are defined") + self.error = 1 + return + + for line, file, name, doc in self.pfuncs: + func = self.pdict[name] + if isinstance(func, types.MethodType): + reqargs = 2 + else: + reqargs = 1 + if func_code(func).co_argcount > reqargs: + self.log.error("%s:%d: Rule '%s' has too many arguments",file,line,func.__name__) + self.error = 1 + elif func_code(func).co_argcount < reqargs: + self.log.error("%s:%d: Rule '%s' requires an argument",file,line,func.__name__) + self.error = 1 + elif not func.__doc__: + self.log.warning("%s:%d: No documentation string specified in function '%s' (ignored)",file,line,func.__name__) + else: + try: + parsed_g = parse_grammar(doc,file,line) + for g in parsed_g: + grammar.append((name, g)) + except SyntaxError: + e = sys.exc_info()[1] + self.log.error(str(e)) + self.error = 1 + + # Looks like a valid grammar rule + # Mark the file in which defined. + self.files[file] = 1 + + # Secondary validation step that looks for p_ definitions that are not functions + # or functions that look like they might be grammar rules. 
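+        #
+        # For illustration (hypothetical name, added for clarity): a rule
+        # accidentally written as
+        #
+        #     def expression_rule(p):
+        #         'expression : expression PLUS term'
+        #
+        # is never collected as an action function because it lacks the p_
+        # prefix; the docstring check below warns about exactly this case.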
+ + for n,v in self.pdict.items(): + if n[0:2] == 'p_' and isinstance(v, (types.FunctionType, types.MethodType)): continue + if n[0:2] == 't_': continue + if n[0:2] == 'p_' and n != 'p_error': + self.log.warning("'%s' not defined as a function", n) + if ((isinstance(v,types.FunctionType) and func_code(v).co_argcount == 1) or + (isinstance(v,types.MethodType) and func_code(v).co_argcount == 2)): + try: + doc = v.__doc__.split(" ") + if doc[1] == ':': + self.log.warning("%s:%d: Possible grammar rule '%s' defined without p_ prefix", + func_code(v).co_filename, func_code(v).co_firstlineno,n) + except Exception: + pass + + self.grammar = grammar + +# ----------------------------------------------------------------------------- +# yacc(module) +# +# Build a parser +# ----------------------------------------------------------------------------- + +def yacc(method='LALR', debug=yaccdebug, module=None, tabmodule=tab_module, start=None, + check_recursion=1, optimize=0, write_tables=1, debugfile=debug_file,outputdir='', + debuglog=None, errorlog = None, picklefile=None): + + global parse # Reference to the parsing method of the last built parser + + # If pickling is enabled, table files are not created + + if picklefile: + write_tables = 0 + + if errorlog is None: + errorlog = PlyLogger(sys.stderr) + + # Get the module dictionary used for the parser + if module: + _items = [(k,getattr(module,k)) for k in dir(module)] + pdict = dict(_items) + else: + pdict = get_caller_module_dict(2) + + # Collect parser information from the dictionary + pinfo = ParserReflect(pdict,log=errorlog) + pinfo.get_all() + + if pinfo.error: + raise YaccError("Unable to build parser") + + # Check signature against table files (if any) + signature = pinfo.signature() + + # Read the tables + try: + lr = LRTable() + if picklefile: + read_signature = lr.read_pickle(picklefile) + else: + read_signature = lr.read_table(tabmodule) + if optimize or (read_signature == signature): + try: + lr.bind_callables(pinfo.pdict) + parser = LRParser(lr,pinfo.error_func) + parse = parser.parse + return parser + except Exception: + e = sys.exc_info()[1] + errorlog.warning("There was a problem loading the table file: %s", repr(e)) + except VersionError: + e = sys.exc_info() + errorlog.warning(str(e)) + except Exception: + pass + + if debuglog is None: + if debug: + debuglog = PlyLogger(open(debugfile,"w")) + else: + debuglog = NullLogger() + + debuglog.info("Created by PLY version %s (http://www.dabeaz.com/ply)", __version__) + + + errors = 0 + + # Validate the parser information + if pinfo.validate_all(): + raise YaccError("Unable to build parser") + + if not pinfo.error_func: + errorlog.warning("no p_error() function is defined") + + # Create a grammar object + grammar = Grammar(pinfo.tokens) + + # Set precedence level for terminals + for term, assoc, level in pinfo.preclist: + try: + grammar.set_precedence(term,assoc,level) + except GrammarError: + e = sys.exc_info()[1] + errorlog.warning("%s",str(e)) + + # Add productions to the grammar + for funcname, gram in pinfo.grammar: + file, line, prodname, syms = gram + try: + grammar.add_production(prodname,syms,funcname,file,line) + except GrammarError: + e = sys.exc_info()[1] + errorlog.error("%s",str(e)) + errors = 1 + + # Set the grammar start symbols + try: + if start is None: + grammar.set_start(pinfo.start) + else: + grammar.set_start(start) + except GrammarError: + e = sys.exc_info()[1] + errorlog.error(str(e)) + errors = 1 + + if errors: + raise YaccError("Unable to build parser") + + # Verify 
the grammar structure + undefined_symbols = grammar.undefined_symbols() + for sym, prod in undefined_symbols: + errorlog.error("%s:%d: Symbol '%s' used, but not defined as a token or a rule",prod.file,prod.line,sym) + errors = 1 + + unused_terminals = grammar.unused_terminals() + if unused_terminals: + debuglog.info("") + debuglog.info("Unused terminals:") + debuglog.info("") + for term in unused_terminals: + errorlog.warning("Token '%s' defined, but not used", term) + debuglog.info(" %s", term) + + # Print out all productions to the debug log + if debug: + debuglog.info("") + debuglog.info("Grammar") + debuglog.info("") + for n,p in enumerate(grammar.Productions): + debuglog.info("Rule %-5d %s", n, p) + + # Find unused non-terminals + unused_rules = grammar.unused_rules() + for prod in unused_rules: + errorlog.warning("%s:%d: Rule '%s' defined, but not used", prod.file, prod.line, prod.name) + + if len(unused_terminals) == 1: + errorlog.warning("There is 1 unused token") + if len(unused_terminals) > 1: + errorlog.warning("There are %d unused tokens", len(unused_terminals)) + + if len(unused_rules) == 1: + errorlog.warning("There is 1 unused rule") + if len(unused_rules) > 1: + errorlog.warning("There are %d unused rules", len(unused_rules)) + + if debug: + debuglog.info("") + debuglog.info("Terminals, with rules where they appear") + debuglog.info("") + terms = list(grammar.Terminals) + terms.sort() + for term in terms: + debuglog.info("%-20s : %s", term, " ".join([str(s) for s in grammar.Terminals[term]])) + + debuglog.info("") + debuglog.info("Nonterminals, with rules where they appear") + debuglog.info("") + nonterms = list(grammar.Nonterminals) + nonterms.sort() + for nonterm in nonterms: + debuglog.info("%-20s : %s", nonterm, " ".join([str(s) for s in grammar.Nonterminals[nonterm]])) + debuglog.info("") + + if check_recursion: + unreachable = grammar.find_unreachable() + for u in unreachable: + errorlog.warning("Symbol '%s' is unreachable",u) + + infinite = grammar.infinite_cycles() + for inf in infinite: + errorlog.error("Infinite recursion detected for symbol '%s'", inf) + errors = 1 + + unused_prec = grammar.unused_precedence() + for term, assoc in unused_prec: + errorlog.error("Precedence rule '%s' defined for unknown symbol '%s'", assoc, term) + errors = 1 + + if errors: + raise YaccError("Unable to build parser") + + # Run the LRGeneratedTable on the grammar + if debug: + errorlog.debug("Generating %s tables", method) + + lr = LRGeneratedTable(grammar,method,debuglog) + + if debug: + num_sr = len(lr.sr_conflicts) + + # Report shift/reduce and reduce/reduce conflicts + if num_sr == 1: + errorlog.warning("1 shift/reduce conflict") + elif num_sr > 1: + errorlog.warning("%d shift/reduce conflicts", num_sr) + + num_rr = len(lr.rr_conflicts) + if num_rr == 1: + errorlog.warning("1 reduce/reduce conflict") + elif num_rr > 1: + errorlog.warning("%d reduce/reduce conflicts", num_rr) + + # Write out conflicts to the output file + if debug and (lr.sr_conflicts or lr.rr_conflicts): + debuglog.warning("") + debuglog.warning("Conflicts:") + debuglog.warning("") + + for state, tok, resolution in lr.sr_conflicts: + debuglog.warning("shift/reduce conflict for %s in state %d resolved as %s", tok, state, resolution) + + already_reported = {} + for state, rule, rejected in lr.rr_conflicts: + if (state,id(rule),id(rejected)) in already_reported: + continue + debuglog.warning("reduce/reduce conflict in state %d resolved using rule (%s)", state, rule) + debuglog.warning("rejected rule (%s) in state 
%d", rejected,state) + errorlog.warning("reduce/reduce conflict in state %d resolved using rule (%s)", state, rule) + errorlog.warning("rejected rule (%s) in state %d", rejected, state) + already_reported[state,id(rule),id(rejected)] = 1 + + warned_never = [] + for state, rule, rejected in lr.rr_conflicts: + if not rejected.reduced and (rejected not in warned_never): + debuglog.warning("Rule (%s) is never reduced", rejected) + errorlog.warning("Rule (%s) is never reduced", rejected) + warned_never.append(rejected) + + # Write the table file if requested + if write_tables: + lr.write_table(tabmodule,outputdir,signature) + + # Write a pickled version of the tables + if picklefile: + lr.pickle_table(picklefile,signature) + + # Build the parser + lr.bind_callables(pinfo.pdict) + parser = LRParser(lr,pinfo.error_func) + + parse = parser.parse + return parser diff --git a/tools/cgrep.py b/tools/cgrep.py new file mode 100755 index 0000000..bc7a993 --- /dev/null +++ b/tools/cgrep.py @@ -0,0 +1,80 @@ +#!/usr/bin/python +# +# Simply util to grep through network definitions. +# Examples: +# To find out which tokens contain "10.4.3.1" use +# $ cgrep.py -i 10.4.3.1 +# +# To find out if token 'FOO' includes ip "1.2.3.4" use +# $ cgrep.py -t FOO -i 1.2.3.4 +# +# To find the difference and union of tokens 'FOO' and 'BAR' use +# $ cgrep.py -c FOO BAR +# +__author__ = "watson@google.com (Tony Watson)" + +import sys +sys.path.append('../') +from lib import naming +from lib import nacaddr +from optparse import OptionParser + +def main(argv): + parser = OptionParser() + + parser.add_option("-d", "--def", dest="defs", action="store", + help="Network Definitions directory location", + default="../def") + parser.add_option("-i", "--ip", dest="ip", action="store", + help="Return list of defintions containing this IP. " + "Multiple IPs permitted.") + + parser.add_option("-t", "--token", dest="token", action="store", + help="See if an IP is contained within this token." + "Must be used in conjunction with --ip [addr].") + + parser.add_option("-c", "--cmp", dest="cmp", action="store_true", + help="Compare two network definition tokens") + + (options, args) = parser.parse_args() + + db = naming.Naming(options.defs) + + if options.ip is not None and options.token is None: + for arg in sys.argv[2:]: + print "%s: " % arg + rval = db.GetIpParents(arg) + print rval + + if options.token is not None and options.ip is None: + print "You must specify and IP Address with --ip [addr] to check." 
+    sys.exit(0)
+
+  if options.token is not None and options.ip is not None:
+    token = options.token
+    ip = options.ip
+    rval = db.GetIpParents(ip)
+    if token in rval:
+      print '%s is in %s' % (ip, token)
+    else:
+      print '%s is not in %s' % (ip, token)
+
+  if options.cmp is not None:
+    t1 = argv[2]
+    t2 = argv[3]
+    d1 = db.GetNet(t1)
+    d2 = db.GetNet(t2)
+    union = list(set(d1 + d2))
+    print 'Union of %s and %s:\n %s\n' % (t1, t2, union)
+    print 'Diff of %s and %s:' % (t1, t2)
+    for el in set(d1 + d2):
+      el = nacaddr.IP(el)
+      if el in d1 and el in d2:
+        print '  %s' % el
+      elif el in d1:
+        print '+ %s' % el
+      elif el in d2:
+        print '- %s' % el
+
+if __name__ == '__main__':
+  main(sys.argv)
diff --git a/tools/get-country-zones.pl b/tools/get-country-zones.pl
new file mode 100755
index 0000000..93a0c48
--- /dev/null
+++ b/tools/get-country-zones.pl
@@ -0,0 +1,64 @@
+#!/usr/bin/perl
+#
+# Author: Paul Armstrong
+#
+# Downloads maps of countries to CIDR netblocks for the world and then turns
+# them into definition files usable by Capirca
+
+use strict;
+use warnings;
+use File::Find;
+
+my @files;
+my $destination = '../def/';
+my $extension = '.net';
+
+system("wget http://www.ipdeny.com/ipblocks/data/countries/all-zones.tar.gz")
+  == 0 or die "Unable to get all-zones.tar.gz: $?\n";
+
+system("tar -zxf all-zones.tar.gz") == 0
+  or die "Unable to untar all-zones.tar.gz: $?\n";
+
+# We don't need these lying around
+unlink("Copyrights.txt");
+unlink("MD5SUM");
+unlink("all-zones.tar.gz");
+
+sub zone_files
+{
+    push @files, $File::Find::name if(/\.zone$/i);
+}
+
+find(\&zone_files, $ENV{PWD});
+
+for my $file (@files)
+{
+    if($file =~ /^.*\/([a-z]{2})\.zone/)
+    {
+        my $country = $1;
+        my $new_name = "$destination$country$extension";
+        my $country_uc = uc($country);
+        die "$file is zero bytes\n" if(!-s $file);
+        open(OLDFILE, $file) or die "Unable to open $file: $!\n";
+        open(NEWFILE, ">$new_name")
+            or die "Unable to open $new_name: $!\n";
+        while(<OLDFILE>)
+        {
+            chomp;
+            if ($. == 1)
+            {
+                print NEWFILE "${country_uc}_NETBLOCKS = $_\n"
+                    or die "Unable to print to $new_name: $!\n";
+            }
+            else
+            {
+                print NEWFILE "                $_\n"
+                    or die "Unable to print to $new_name: $!\n";
+            }
+        }
+        close(NEWFILE) or die "$new_name didn't close properly: $!\n";
+        close(OLDFILE);
+        die "$new_name is zero bytes\n" if(!-s $new_name);
+        unlink($file);  # clean up the originals.
+    }
+}
--
cgit v1.1