commit cf31fe9b4fb650b27e19f5d7ee7297e383660caf Author: The Android Open Source Project Date: Tue Oct 21 07:00:00 2008 -0700 Initial Contribution diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0d20b648 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.pyc diff --git a/COPYING b/COPYING new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/COPYING @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..0184e08a --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +# +# Copyright 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +GERRIT_SRC=../gerrit +GERRIT_MODULES=codereview froofle + +all: + +clean: + find . -name \*.pyc -type f | xargs rm -f + +update-pyclient: + $(MAKE) -C $(GERRIT_SRC) release-pyclient + rm -rf $(GERRIT_MODULES) + (cd $(GERRIT_SRC)/release/pyclient && \ + find . -type f \ + | cpio -pd $(abspath .)) diff --git a/codereview/__init__.py b/codereview/__init__.py new file mode 100644 index 00000000..e47bc94e --- /dev/null +++ b/codereview/__init__.py @@ -0,0 +1 @@ +__version__ = 'v1.0' diff --git a/codereview/need_retry_pb2.py b/codereview/need_retry_pb2.py new file mode 100644 index 00000000..3fab2d43 --- /dev/null +++ b/codereview/need_retry_pb2.py @@ -0,0 +1,32 @@ +#!/usr/bin/python2.4 +# Generated by the protocol buffer compiler. DO NOT EDIT! + +from froofle.protobuf import descriptor +from froofle.protobuf import message +from froofle.protobuf import reflection +from froofle.protobuf import service +from froofle.protobuf import service_reflection +from froofle.protobuf import descriptor_pb2 + + + +_RETRYREQUESTLATERRESPONSE = descriptor.Descriptor( + name='RetryRequestLaterResponse', + full_name='codereview.RetryRequestLaterResponse', + filename='need_retry.proto', + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + + +class RetryRequestLaterResponse(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _RETRYREQUESTLATERRESPONSE + diff --git a/codereview/proto_client.py b/codereview/proto_client.py new file mode 100755 index 00000000..e11beff0 --- /dev/null +++ b/codereview/proto_client.py @@ -0,0 +1,349 @@ +# Copyright 2007, 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import cookielib +import getpass +import logging +import md5 +import os +import random +import socket +import time +import urllib +import urllib2 +import urlparse + +from froofle.protobuf.service import RpcChannel +from froofle.protobuf.service import RpcController +from need_retry_pb2 import RetryRequestLaterResponse; + +class ClientLoginError(urllib2.HTTPError): + """Raised to indicate an error authenticating with ClientLogin.""" + + def __init__(self, url, code, msg, headers, args): + urllib2.HTTPError.__init__(self, url, code, msg, headers, None) + self.args = args + self.reason = args["Error"] + + +class Proxy(object): + class _ResultHolder(object): + def __call__(self, result): + self._result = result + + class _RemoteController(RpcController): + def Reset(self): + pass + + def Failed(self): + pass + + def ErrorText(self): + pass + + def StartCancel(self): + pass + + def SetFailed(self, reason): + raise RuntimeError, reason + + def IsCancelled(self): + pass + + def NotifyOnCancel(self, callback): + pass + + def __init__(self, stub): + self._stub = stub + + def __getattr__(self, key): + method = getattr(self._stub, key) + + def call(request): + done = self._ResultHolder() + method(self._RemoteController(), request, done) + return done._result + + return call + + +class HttpRpc(RpcChannel): + """Simple protobuf over HTTP POST implementation.""" + + def __init__(self, host, auth_function, + host_override=None, + extra_headers={}, + cookie_file=None): + """Creates a new HttpRpc. + + Args: + host: The host to send requests to. + auth_function: A function that takes no arguments and returns an + (email, password) tuple when called. Will be called if authentication + is required. + host_override: The host header to send to the server (defaults to host). + extra_headers: A dict of extra headers to append to every request. + cookie_file: If not None, name of the file in ~/ to save the + cookie jar into. Applications are encouraged to set this to + '.$appname_cookies' or some otherwise unique name. + """ + self.host = host.lower() + self.host_override = host_override + self.auth_function = auth_function + self.authenticated = False + self.extra_headers = extra_headers + self.xsrf_token = None + if cookie_file is None: + self.cookie_file = None + else: + self.cookie_file = os.path.expanduser("~/%s" % cookie_file) + self.opener = self._GetOpener() + if self.host_override: + logging.info("Server: %s; Host: %s", self.host, self.host_override) + else: + logging.info("Server: %s", self.host) + + def CallMethod(self, method, controller, request, response_type, done): + pat = "application/x-google-protobuf; name=%s" + + url = "/proto/%s/%s" % (method.containing_service.name, method.name) + reqbin = request.SerializeToString() + reqtyp = pat % request.DESCRIPTOR.full_name + reqmd5 = base64.b64encode(md5.new(reqbin).digest()) + + start = time.time() + while True: + t, b = self._Send(url, reqbin, reqtyp, reqmd5) + if t == (pat % RetryRequestLaterResponse.DESCRIPTOR.full_name): + if time.time() >= (start + 1800): + controller.SetFailed("timeout") + return + s = random.uniform(0.250, 2.000) + print "Busy, retrying in %.3f seconds ..." % s + time.sleep(s) + continue + + if t == (pat % response_type.DESCRIPTOR.full_name): + response = response_type() + response.ParseFromString(b) + done(response) + else: + controller.SetFailed("Unexpected %s response" % t) + break + + def _CreateRequest(self, url, data=None): + """Creates a new urllib request.""" + logging.debug("Creating request for: '%s' with payload:\n%s", url, data) + req = urllib2.Request(url, data=data) + if self.host_override: + req.add_header("Host", self.host_override) + for key, value in self.extra_headers.iteritems(): + req.add_header(key, value) + return req + + def _GetAuthToken(self, email, password): + """Uses ClientLogin to authenticate the user, returning an auth token. + + Args: + email: The user's email address + password: The user's password + + Raises: + ClientLoginError: If there was an error authenticating with ClientLogin. + HTTPError: If there was some other form of HTTP error. + + Returns: + The authentication token returned by ClientLogin. + """ + req = self._CreateRequest( + url="https://www.google.com/accounts/ClientLogin", + data=urllib.urlencode({ + "Email": email, + "Passwd": password, + "service": "ah", + "source": "gerrit-codereview-client", + "accountType": "HOSTED_OR_GOOGLE", + }) + ) + try: + response = self.opener.open(req) + response_body = response.read() + response_dict = dict(x.split("=") + for x in response_body.split("\n") if x) + return response_dict["Auth"] + except urllib2.HTTPError, e: + if e.code == 403: + body = e.read() + response_dict = dict(x.split("=", 1) for x in body.split("\n") if x) + raise ClientLoginError(req.get_full_url(), e.code, e.msg, + e.headers, response_dict) + else: + raise + + def _GetAuthCookie(self, auth_token): + """Fetches authentication cookies for an authentication token. + + Args: + auth_token: The authentication token returned by ClientLogin. + + Raises: + HTTPError: If there was an error fetching the authentication cookies. + """ + # This is a dummy value to allow us to identify when we're successful. + continue_location = "http://localhost/" + args = {"continue": continue_location, "auth": auth_token} + req = self._CreateRequest("http://%s/_ah/login?%s" % + (self.host, urllib.urlencode(args))) + try: + response = self.opener.open(req) + except urllib2.HTTPError, e: + response = e + if (response.code != 302 or + response.info()["location"] != continue_location): + raise urllib2.HTTPError(req.get_full_url(), response.code, response.msg, + response.headers, response.fp) + self.authenticated = True + + def _GetXsrfToken(self): + """Fetches /proto/_token for use in X-XSRF-Token HTTP header. + + Raises: + HTTPError: If there was an error fetching a new token. + """ + tries = 0 + while True: + url = "http://%s/proto/_token" % self.host + req = self._CreateRequest(url) + try: + response = self.opener.open(req) + self.xsrf_token = response.read() + return + except urllib2.HTTPError, e: + if tries > 3: + raise + elif e.code == 401: + self._Authenticate() + else: + raise + + def _Authenticate(self): + """Authenticates the user. + + The authentication process works as follows: + 1) We get a username and password from the user + 2) We use ClientLogin to obtain an AUTH token for the user + (see http://code.google.com/apis/accounts/AuthForInstalledApps.html). + 3) We pass the auth token to /_ah/login on the server to obtain an + authentication cookie. If login was successful, it tries to redirect + us to the URL we provided. + + If we attempt to access the upload API without first obtaining an + authentication cookie, it returns a 401 response and directs us to + authenticate ourselves with ClientLogin. + """ + for i in range(3): + credentials = self.auth_function() + auth_token = self._GetAuthToken(credentials[0], credentials[1]) + self._GetAuthCookie(auth_token) + if self.cookie_file is not None: + self.cookie_jar.save() + return + + def _Send(self, request_path, payload, content_type, content_md5): + """Sends an RPC and returns the response. + + Args: + request_path: The path to send the request to, eg /api/appversion/create. + payload: The body of the request, or None to send an empty request. + content_type: The Content-Type header to use. + content_md5: The Content-MD5 header to use. + + Returns: + The content type, as a string. + The response body, as a string. + """ + if not self.authenticated: + self._Authenticate() + if not self.xsrf_token: + self._GetXsrfToken() + + old_timeout = socket.getdefaulttimeout() + socket.setdefaulttimeout(None) + try: + tries = 0 + while True: + tries += 1 + url = "http://%s%s" % (self.host, request_path) + req = self._CreateRequest(url=url, data=payload) + req.add_header("Content-Type", content_type) + req.add_header("Content-MD5", content_md5) + req.add_header("X-XSRF-Token", self.xsrf_token) + try: + f = self.opener.open(req) + hdr = f.info() + type = hdr.getheader('Content-Type', + 'application/octet-stream') + response = f.read() + f.close() + return type, response + except urllib2.HTTPError, e: + if tries > 3: + raise + elif e.code == 401: + self._Authenticate() + elif e.code == 403: + if not hasattr(e, 'read'): + e.read = lambda self: '' + raise RuntimeError, '403\nxsrf: %s\n%s' \ + % (self.xsrf_token, e.read()) + else: + raise + finally: + socket.setdefaulttimeout(old_timeout) + + def _GetOpener(self): + """Returns an OpenerDirector that supports cookies and ignores redirects. + + Returns: + A urllib2.OpenerDirector object. + """ + opener = urllib2.OpenerDirector() + opener.add_handler(urllib2.ProxyHandler()) + opener.add_handler(urllib2.UnknownHandler()) + opener.add_handler(urllib2.HTTPHandler()) + opener.add_handler(urllib2.HTTPDefaultErrorHandler()) + opener.add_handler(urllib2.HTTPSHandler()) + opener.add_handler(urllib2.HTTPErrorProcessor()) + if self.cookie_file is not None: + self.cookie_jar = cookielib.MozillaCookieJar(self.cookie_file) + if os.path.exists(self.cookie_file): + try: + self.cookie_jar.load() + self.authenticated = True + except (cookielib.LoadError, IOError): + # Failed to load cookies - just ignore them. + pass + else: + # Create an empty cookie file with mode 600 + fd = os.open(self.cookie_file, os.O_CREAT, 0600) + os.close(fd) + # Always chmod the cookie file + os.chmod(self.cookie_file, 0600) + else: + # Don't save cookies across runs of update.py. + self.cookie_jar = cookielib.CookieJar() + opener.add_handler(urllib2.HTTPCookieProcessor(self.cookie_jar)) + return opener + diff --git a/codereview/review_pb2.py b/codereview/review_pb2.py new file mode 100644 index 00000000..0896feba --- /dev/null +++ b/codereview/review_pb2.py @@ -0,0 +1,48 @@ +#!/usr/bin/python2.4 +# Generated by the protocol buffer compiler. DO NOT EDIT! + +from froofle.protobuf import descriptor +from froofle.protobuf import message +from froofle.protobuf import reflection +from froofle.protobuf import service +from froofle.protobuf import service_reflection +from froofle.protobuf import descriptor_pb2 + + +import upload_bundle_pb2 + + + +_REVIEWSERVICE = descriptor.ServiceDescriptor( + name='ReviewService', + full_name='codereview.ReviewService', + index=0, + options=None, + methods=[ + descriptor.MethodDescriptor( + name='UploadBundle', + full_name='codereview.ReviewService.UploadBundle', + index=0, + containing_service=None, + input_type=upload_bundle_pb2._UPLOADBUNDLEREQUEST, + output_type=upload_bundle_pb2._UPLOADBUNDLERESPONSE, + options=None, + ), + descriptor.MethodDescriptor( + name='ContinueBundle', + full_name='codereview.ReviewService.ContinueBundle', + index=1, + containing_service=None, + input_type=upload_bundle_pb2._UPLOADBUNDLECONTINUE, + output_type=upload_bundle_pb2._UPLOADBUNDLERESPONSE, + options=None, + ), +]) + +class ReviewService(service.Service): + __metaclass__ = service_reflection.GeneratedServiceType + DESCRIPTOR = _REVIEWSERVICE +class ReviewService_Stub(ReviewService): + __metaclass__ = service_reflection.GeneratedServiceStubType + DESCRIPTOR = _REVIEWSERVICE + diff --git a/codereview/upload_bundle_pb2.py b/codereview/upload_bundle_pb2.py new file mode 100644 index 00000000..48c36512 --- /dev/null +++ b/codereview/upload_bundle_pb2.py @@ -0,0 +1,190 @@ +#!/usr/bin/python2.4 +# Generated by the protocol buffer compiler. DO NOT EDIT! + +from froofle.protobuf import descriptor +from froofle.protobuf import message +from froofle.protobuf import reflection +from froofle.protobuf import service +from froofle.protobuf import service_reflection +from froofle.protobuf import descriptor_pb2 + + +_UPLOADBUNDLERESPONSE_CODETYPE = descriptor.EnumDescriptor( + name='CodeType', + full_name='codereview.UploadBundleResponse.CodeType', + filename='CodeType', + values=[ + descriptor.EnumValueDescriptor( + name='RECEIVED', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='CONTINUE', index=1, number=4, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='UNAUTHORIZED_USER', index=2, number=7, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='UNKNOWN_PROJECT', index=3, number=2, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='UNKNOWN_BRANCH', index=4, number=3, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='UNKNOWN_BUNDLE', index=5, number=5, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='NOT_BUNDLE_OWNER', index=6, number=6, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='BUNDLE_CLOSED', index=7, number=8, + options=None, + type=None), + ], + options=None, +) + + +_UPLOADBUNDLEREQUEST = descriptor.Descriptor( + name='UploadBundleRequest', + full_name='codereview.UploadBundleRequest', + filename='upload_bundle.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='dest_project', full_name='codereview.UploadBundleRequest.dest_project', index=0, + number=10, type=9, cpp_type=9, label=2, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='dest_branch', full_name='codereview.UploadBundleRequest.dest_branch', index=1, + number=11, type=9, cpp_type=9, label=2, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='partial_upload', full_name='codereview.UploadBundleRequest.partial_upload', index=2, + number=12, type=8, cpp_type=7, label=2, + default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bundle_data', full_name='codereview.UploadBundleRequest.bundle_data', index=3, + number=13, type=12, cpp_type=9, label=2, + default_value="", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='contained_object', full_name='codereview.UploadBundleRequest.contained_object', index=4, + number=1, type=9, cpp_type=9, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_UPLOADBUNDLERESPONSE = descriptor.Descriptor( + name='UploadBundleResponse', + full_name='codereview.UploadBundleResponse', + filename='upload_bundle.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='status_code', full_name='codereview.UploadBundleResponse.status_code', index=0, + number=10, type=14, cpp_type=8, label=2, + default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bundle_id', full_name='codereview.UploadBundleResponse.bundle_id', index=1, + number=11, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + _UPLOADBUNDLERESPONSE_CODETYPE, + ], + options=None) + + +_UPLOADBUNDLECONTINUE = descriptor.Descriptor( + name='UploadBundleContinue', + full_name='codereview.UploadBundleContinue', + filename='upload_bundle.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='bundle_id', full_name='codereview.UploadBundleContinue.bundle_id', index=0, + number=10, type=9, cpp_type=9, label=2, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='segment_id', full_name='codereview.UploadBundleContinue.segment_id', index=1, + number=11, type=5, cpp_type=1, label=2, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='partial_upload', full_name='codereview.UploadBundleContinue.partial_upload', index=2, + number=12, type=8, cpp_type=7, label=2, + default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bundle_data', full_name='codereview.UploadBundleContinue.bundle_data', index=3, + number=13, type=12, cpp_type=9, label=1, + default_value="", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_UPLOADBUNDLERESPONSE.fields_by_name['status_code'].enum_type = _UPLOADBUNDLERESPONSE_CODETYPE + +class UploadBundleRequest(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _UPLOADBUNDLEREQUEST + +class UploadBundleResponse(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _UPLOADBUNDLERESPONSE + +class UploadBundleContinue(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _UPLOADBUNDLECONTINUE + diff --git a/color.py b/color.py new file mode 100644 index 00000000..b3a558cd --- /dev/null +++ b/color.py @@ -0,0 +1,154 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +import pager +from git_config import GitConfig + +COLORS = {None :-1, + 'normal' :-1, + 'black' : 0, + 'red' : 1, + 'green' : 2, + 'yellow' : 3, + 'blue' : 4, + 'magenta': 5, + 'cyan' : 6, + 'white' : 7} + +ATTRS = {None :-1, + 'bold' : 1, + 'dim' : 2, + 'ul' : 4, + 'blink' : 5, + 'reverse': 7} + +RESET = "\033[m" + +def is_color(s): return s in COLORS +def is_attr(s): return s in ATTRS + +def _Color(fg = None, bg = None, attr = None): + fg = COLORS[fg] + bg = COLORS[bg] + attr = ATTRS[attr] + + if attr >= 0 or fg >= 0 or bg >= 0: + need_sep = False + code = "\033[" + + if attr >= 0: + code += chr(ord('0') + attr) + need_sep = True + + if fg >= 0: + if need_sep: + code += ';' + need_sep = True + + if fg < 8: + code += '3%c' % (ord('0') + fg) + else: + code += '38;5;%d' % fg + + if bg >= 0: + if need_sep: + code += ';' + need_sep = True + + if bg < 8: + code += '4%c' % (ord('0') + bg) + else: + code += '48;5;%d' % bg + code += 'm' + else: + code = '' + return code + + +class Coloring(object): + def __init__(self, config, type): + self._section = 'color.%s' % type + self._config = config + self._out = sys.stdout + + on = self._config.GetString(self._section) + if on is None: + on = self._config.GetString('color.ui') + + if on == 'auto': + if pager.active or os.isatty(1): + self._on = True + else: + self._on = False + elif on in ('true', 'always'): + self._on = True + else: + self._on = False + + @property + def is_on(self): + return self._on + + def write(self, fmt, *args): + self._out.write(fmt % args) + + def nl(self): + self._out.write('\n') + + def printer(self, opt=None, fg=None, bg=None, attr=None): + s = self + c = self.colorer(opt, fg, bg, attr) + def f(fmt, *args): + s._out.write(c(fmt, *args)) + return f + + def colorer(self, opt=None, fg=None, bg=None, attr=None): + if self._on: + c = self._parse(opt, fg, bg, attr) + def f(fmt, *args): + str = fmt % args + return ''.join([c, str, RESET]) + return f + else: + def f(fmt, *args): + return fmt % args + return f + + def _parse(self, opt, fg, bg, attr): + if not opt: + return _Color(fg, bg, attr) + + v = self._config.GetString('%s.%s' % (self._section, opt)) + if v is None: + return _Color(fg, bg, attr) + + v = v.trim().lowercase() + if v == "reset": + return RESET + elif v == '': + return _Color(fg, bg, attr) + + have_fg = False + for a in v.split(' '): + if is_color(a): + if have_fg: bg = a + else: fg = a + elif is_attr(a): + attr = a + + return _Color(fg, bg, attr) diff --git a/command.py b/command.py new file mode 100644 index 00000000..516c2d9d --- /dev/null +++ b/command.py @@ -0,0 +1,116 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import optparse +import sys + +from error import NoSuchProjectError + +class Command(object): + """Base class for any command line action in repo. + """ + + common = False + manifest = None + _optparse = None + + @property + def OptionParser(self): + if self._optparse is None: + try: + me = 'repo %s' % self.NAME + usage = self.helpUsage.strip().replace('%prog', me) + except AttributeError: + usage = 'repo %s' % self.NAME + self._optparse = optparse.OptionParser(usage = usage) + self._Options(self._optparse) + return self._optparse + + def _Options(self, p): + """Initialize the option parser. + """ + + def Usage(self): + """Display usage and terminate. + """ + self.OptionParser.print_usage() + sys.exit(1) + + def Execute(self, opt, args): + """Perform the action, after option parsing is complete. + """ + raise NotImplementedError + + def GetProjects(self, args, missing_ok=False): + """A list of projects that match the arguments. + """ + all = self.manifest.projects + result = [] + + if not args: + for project in all.values(): + if missing_ok or project.Exists: + result.append(project) + else: + by_path = None + + for arg in args: + project = all.get(arg) + + if not project: + path = os.path.abspath(arg) + + if not by_path: + by_path = dict() + for p in all.values(): + by_path[p.worktree] = p + + if os.path.exists(path): + while path \ + and path != '/' \ + and path != self.manifest.topdir: + try: + project = by_path[path] + break + except KeyError: + path = os.path.dirname(path) + else: + try: + project = by_path[path] + except KeyError: + pass + + if not project: + raise NoSuchProjectError(arg) + if not missing_ok and not project.Exists: + raise NoSuchProjectError(arg) + + result.append(project) + + def _getpath(x): + return x.relpath + result.sort(key=_getpath) + return result + +class InteractiveCommand(Command): + """Command which requires user interaction on the tty and + must not run within a pager, even if the user asks to. + """ + +class PagedCommand(Command): + """Command which defaults to output in a pager, as its + display tends to be larger than one screen full. + """ diff --git a/editor.py b/editor.py new file mode 100644 index 00000000..4f22257f --- /dev/null +++ b/editor.py @@ -0,0 +1,85 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess +import tempfile + +from error import EditorError + +class Editor(object): + """Manages the user's preferred text editor.""" + + _editor = None + globalConfig = None + + @classmethod + def _GetEditor(cls): + if cls._editor is None: + cls._editor = cls._SelectEditor() + return cls._editor + + @classmethod + def _SelectEditor(cls): + e = os.getenv('GIT_EDITOR') + if e: + return e + + e = cls.globalConfig.GetString('core.editor') + if e: + return e + + e = os.getenv('VISUAL') + if e: + return e + + e = os.getenv('EDITOR') + if e: + return e + + if os.getenv('TERM') == 'dumb': + print >>sys.stderr,\ +"""No editor specified in GIT_EDITOR, core.editor, VISUAL or EDITOR. +Tried to fall back to vi but terminal is dumb. Please configure at +least one of these before using this command.""" + sys.exit(1) + + return 'vi' + + @classmethod + def EditString(cls, data): + """Opens an editor to edit the given content. + + Args: + data : the text to edit + + Returns: + new value of edited text; None if editing did not succeed + """ + editor = cls._GetEditor() + fd, path = tempfile.mkstemp() + try: + os.write(fd, data) + os.close(fd) + fd = None + + if subprocess.Popen([editor, path]).wait() != 0: + raise EditorError() + return open(path).read() + finally: + if fd: + os.close(fd) + os.remove(path) diff --git a/error.py b/error.py new file mode 100644 index 00000000..e3cf41c1 --- /dev/null +++ b/error.py @@ -0,0 +1,66 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ManifestParseError(Exception): + """Failed to parse the manifest file. + """ + +class EditorError(Exception): + """Unspecified error from the user's text editor. + """ + +class GitError(Exception): + """Unspecified internal error from git. + """ + def __init__(self, command): + self.command = command + + def __str__(self): + return self.command + +class ImportError(Exception): + """An import from a non-Git format cannot be performed. + """ + def __init__(self, reason): + self.reason = reason + + def __str__(self): + return self.reason + +class UploadError(Exception): + """A bundle upload to Gerrit did not succeed. + """ + def __init__(self, reason): + self.reason = reason + + def __str__(self): + return self.reason + +class NoSuchProjectError(Exception): + """A specified project does not exist in the work tree. + """ + def __init__(self, name=None): + self.name = name + + def __str__(self): + if self.Name is None: + return 'in current directory' + return self.name + +class RepoChangedException(Exception): + """Thrown if 'repo sync' results in repo updating its internal + repo or manifest repositories. In this special case we must + use exec to re-execute repo with the new code and manifest. + """ diff --git a/froofle/__init__.py b/froofle/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/froofle/protobuf/__init__.py b/froofle/protobuf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/froofle/protobuf/descriptor.py b/froofle/protobuf/descriptor.py new file mode 100644 index 00000000..e74cf25e --- /dev/null +++ b/froofle/protobuf/descriptor.py @@ -0,0 +1,433 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We probably need to provide deep-copy methods for +# descriptor types. When a FieldDescriptor is passed into +# Descriptor.__init__(), we should make a deep copy and then set +# containing_type on it. Alternatively, we could just get +# rid of containing_type (iit's not needed for reflection.py, at least). +# +# TODO(robinson): Print method? +# +# TODO(robinson): Useful __repr__? + +"""Descriptors essentially contain exactly the information found in a .proto +file, in types that make this information accessible in Python. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +class DescriptorBase(object): + + """Descriptors base class. + + This class is the base of all descriptor classes. It provides common options + related functionaility. + """ + + def __init__(self, options, options_class_name): + """Initialize the descriptor given its options message and the name of the + class of the options message. The name of the class is required in case + the options message is None and has to be created. + """ + self._options = options + self._options_class_name = options_class_name + + def GetOptions(self): + """Retrieves descriptor options. + + This method returns the options set or creates the default options for the + descriptor. + """ + if self._options: + return self._options + from froofle.protobuf import descriptor_pb2 + try: + options_class = getattr(descriptor_pb2, self._options_class_name) + except AttributeError: + raise RuntimeError('Unknown options class name %s!' % + (self._options_class_name)) + self._options = options_class() + return self._options + + +class Descriptor(DescriptorBase): + + """Descriptor for a protocol message type. + + A Descriptor instance has the following attributes: + + name: (str) Name of this protocol message type. + full_name: (str) Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + + filename: (str) Name of the .proto file containing this message. + + containing_type: (Descriptor) Reference to the descriptor of the + type containing us, or None if we have no containing type. + + fields: (list of FieldDescriptors) Field descriptors for all + fields in this type. + fields_by_number: (dict int -> FieldDescriptor) Same FieldDescriptor + objects as in |fields|, but indexed by "number" attribute in each + FieldDescriptor. + fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor + objects as in |fields|, but indexed by "name" attribute in each + FieldDescriptor. + + nested_types: (list of Descriptors) Descriptor references + for all protocol message types nested within this one. + nested_types_by_name: (dict str -> Descriptor) Same Descriptor + objects as in |nested_types|, but indexed by "name" attribute + in each Descriptor. + + enum_types: (list of EnumDescriptors) EnumDescriptor references + for all enums contained within this type. + enum_types_by_name: (dict str ->EnumDescriptor) Same EnumDescriptor + objects as in |enum_types|, but indexed by "name" attribute + in each EnumDescriptor. + enum_values_by_name: (dict str -> EnumValueDescriptor) Dict mapping + from enum value name to EnumValueDescriptor for that value. + + extensions: (list of FieldDescriptor) All extensions defined directly + within this message type (NOT within a nested type). + extensions_by_name: (dict, string -> FieldDescriptor) Same FieldDescriptor + objects as |extensions|, but indexed by "name" attribute of each + FieldDescriptor. + + options: (descriptor_pb2.MessageOptions) Protocol message options or None + to use default message options. + """ + + def __init__(self, name, full_name, filename, containing_type, + fields, nested_types, enum_types, extensions, options=None): + """Arguments to __init__() are as described in the description + of Descriptor fields above. + """ + super(Descriptor, self).__init__(options, 'MessageOptions') + self.name = name + self.full_name = full_name + self.filename = filename + self.containing_type = containing_type + + # We have fields in addition to fields_by_name and fields_by_number, + # so that: + # 1. Clients can index fields by "order in which they're listed." + # 2. Clients can easily iterate over all fields with the terse + # syntax: for f in descriptor.fields: ... + self.fields = fields + for field in self.fields: + field.containing_type = self + self.fields_by_number = dict((f.number, f) for f in fields) + self.fields_by_name = dict((f.name, f) for f in fields) + + self.nested_types = nested_types + self.nested_types_by_name = dict((t.name, t) for t in nested_types) + + self.enum_types = enum_types + for enum_type in self.enum_types: + enum_type.containing_type = self + self.enum_types_by_name = dict((t.name, t) for t in enum_types) + self.enum_values_by_name = dict( + (v.name, v) for t in enum_types for v in t.values) + + self.extensions = extensions + for extension in self.extensions: + extension.extension_scope = self + self.extensions_by_name = dict((f.name, f) for f in extensions) + + +# TODO(robinson): We should have aggressive checking here, +# for example: +# * If you specify a repeated field, you should not be allowed +# to specify a default value. +# * [Other examples here as needed]. +# +# TODO(robinson): for this and other *Descriptor classes, we +# might also want to lock things down aggressively (e.g., +# prevent clients from setting the attributes). Having +# stronger invariants here in general will reduce the number +# of runtime checks we must do in reflection.py... +class FieldDescriptor(DescriptorBase): + + """Descriptor for a single field in a .proto file. + + A FieldDescriptor instance has the following attriubtes: + + name: (str) Name of this field, exactly as it appears in .proto. + full_name: (str) Name of this field, including containing scope. This is + particularly relevant for extensions. + index: (int) Dense, 0-indexed index giving the order that this + field textually appears within its message in the .proto file. + number: (int) Tag number declared for this field in the .proto file. + + type: (One of the TYPE_* constants below) Declared type. + cpp_type: (One of the CPPTYPE_* constants below) C++ type used to + represent this field. + + label: (One of the LABEL_* constants below) Tells whether this + field is optional, required, or repeated. + default_value: (Varies) Default value of this field. Only + meaningful for non-repeated scalar fields. Repeated fields + should always set this to [], and non-repeated composite + fields should always set this to None. + + containing_type: (Descriptor) Descriptor of the protocol message + type that contains this field. Set by the Descriptor constructor + if we're passed into one. + Somewhat confusingly, for extension fields, this is the + descriptor of the EXTENDED message, not the descriptor + of the message containing this field. (See is_extension and + extension_scope below). + message_type: (Descriptor) If a composite field, a descriptor + of the message type contained in this field. Otherwise, this is None. + enum_type: (EnumDescriptor) If this field contains an enum, a + descriptor of that enum. Otherwise, this is None. + + is_extension: True iff this describes an extension field. + extension_scope: (Descriptor) Only meaningful if is_extension is True. + Gives the message that immediately contains this extension field. + Will be None iff we're a top-level (file-level) extension field. + + options: (descriptor_pb2.FieldOptions) Protocol message field options or + None to use default field options. + """ + + # Must be consistent with C++ FieldDescriptor::Type enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + TYPE_DOUBLE = 1 + TYPE_FLOAT = 2 + TYPE_INT64 = 3 + TYPE_UINT64 = 4 + TYPE_INT32 = 5 + TYPE_FIXED64 = 6 + TYPE_FIXED32 = 7 + TYPE_BOOL = 8 + TYPE_STRING = 9 + TYPE_GROUP = 10 + TYPE_MESSAGE = 11 + TYPE_BYTES = 12 + TYPE_UINT32 = 13 + TYPE_ENUM = 14 + TYPE_SFIXED32 = 15 + TYPE_SFIXED64 = 16 + TYPE_SINT32 = 17 + TYPE_SINT64 = 18 + MAX_TYPE = 18 + + # Must be consistent with C++ FieldDescriptor::CppType enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + CPPTYPE_INT32 = 1 + CPPTYPE_INT64 = 2 + CPPTYPE_UINT32 = 3 + CPPTYPE_UINT64 = 4 + CPPTYPE_DOUBLE = 5 + CPPTYPE_FLOAT = 6 + CPPTYPE_BOOL = 7 + CPPTYPE_ENUM = 8 + CPPTYPE_STRING = 9 + CPPTYPE_MESSAGE = 10 + MAX_CPPTYPE = 10 + + # Must be consistent with C++ FieldDescriptor::Label enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + LABEL_OPTIONAL = 1 + LABEL_REQUIRED = 2 + LABEL_REPEATED = 3 + MAX_LABEL = 3 + + def __init__(self, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None): + """The arguments are as described in the description of FieldDescriptor + attributes above. + + Note that containing_type may be None, and may be set later if necessary + (to deal with circular references between message types, for example). + Likewise for extension_scope. + """ + super(FieldDescriptor, self).__init__(options, 'FieldOptions') + self.name = name + self.full_name = full_name + self.index = index + self.number = number + self.type = type + self.cpp_type = cpp_type + self.label = label + self.default_value = default_value + self.containing_type = containing_type + self.message_type = message_type + self.enum_type = enum_type + self.is_extension = is_extension + self.extension_scope = extension_scope + + +class EnumDescriptor(DescriptorBase): + + """Descriptor for an enum defined in a .proto file. + + An EnumDescriptor instance has the following attributes: + + name: (str) Name of the enum type. + full_name: (str) Full name of the type, including package name + and any enclosing type(s). + filename: (str) Name of the .proto file in which this appears. + + values: (list of EnumValueDescriptors) List of the values + in this enum. + values_by_name: (dict str -> EnumValueDescriptor) Same as |values|, + but indexed by the "name" field of each EnumValueDescriptor. + values_by_number: (dict int -> EnumValueDescriptor) Same as |values|, + but indexed by the "number" field of each EnumValueDescriptor. + containing_type: (Descriptor) Descriptor of the immediate containing + type of this enum, or None if this is an enum defined at the + top level in a .proto file. Set by Descriptor's constructor + if we're passed into one. + options: (descriptor_pb2.EnumOptions) Enum options message or + None to use default enum options. + """ + + def __init__(self, name, full_name, filename, values, + containing_type=None, options=None): + """Arguments are as described in the attribute description above.""" + super(EnumDescriptor, self).__init__(options, 'EnumOptions') + self.name = name + self.full_name = full_name + self.filename = filename + self.values = values + for value in self.values: + value.type = self + self.values_by_name = dict((v.name, v) for v in values) + self.values_by_number = dict((v.number, v) for v in values) + self.containing_type = containing_type + + +class EnumValueDescriptor(DescriptorBase): + + """Descriptor for a single value within an enum. + + name: (str) Name of this value. + index: (int) Dense, 0-indexed index giving the order that this + value appears textually within its enum in the .proto file. + number: (int) Actual number assigned to this enum value. + type: (EnumDescriptor) EnumDescriptor to which this value + belongs. Set by EnumDescriptor's constructor if we're + passed into one. + options: (descriptor_pb2.EnumValueOptions) Enum value options message or + None to use default enum value options options. + """ + + def __init__(self, name, index, number, type=None, options=None): + """Arguments are as described in the attribute description above.""" + super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions') + self.name = name + self.index = index + self.number = number + self.type = type + + +class ServiceDescriptor(DescriptorBase): + + """Descriptor for a service. + + name: (str) Name of the service. + full_name: (str) Full name of the service, including package name. + index: (int) 0-indexed index giving the order that this services + definition appears withing the .proto file. + methods: (list of MethodDescriptor) List of methods provided by this + service. + options: (descriptor_pb2.ServiceOptions) Service options message or + None to use default service options. + """ + + def __init__(self, name, full_name, index, methods, options=None): + super(ServiceDescriptor, self).__init__(options, 'ServiceOptions') + self.name = name + self.full_name = full_name + self.index = index + self.methods = methods + # Set the containing service for each method in this service. + for method in self.methods: + method.containing_service = self + + def FindMethodByName(self, name): + """Searches for the specified method, and returns its descriptor.""" + for method in self.methods: + if name == method.name: + return method + return None + + +class MethodDescriptor(DescriptorBase): + + """Descriptor for a method in a service. + + name: (str) Name of the method within the service. + full_name: (str) Full name of method. + index: (int) 0-indexed index of the method inside the service. + containing_service: (ServiceDescriptor) The service that contains this + method. + input_type: The descriptor of the message that this method accepts. + output_type: The descriptor of the message that this method returns. + options: (descriptor_pb2.MethodOptions) Method options message or + None to use default method options. + """ + + def __init__(self, name, full_name, index, containing_service, + input_type, output_type, options=None): + """The arguments are as described in the description of MethodDescriptor + attributes above. + + Note that containing_service may be None, and may be set later if necessary. + """ + super(MethodDescriptor, self).__init__(options, 'MethodOptions') + self.name = name + self.full_name = full_name + self.index = index + self.containing_service = containing_service + self.input_type = input_type + self.output_type = output_type + + +def _ParseOptions(message, string): + """Parses serialized options. + + This helper function is used to parse serialized options in generated + proto2 files. It must not be used outside proto2. + """ + message.ParseFromString(string) + return message; diff --git a/froofle/protobuf/descriptor_pb2.py b/froofle/protobuf/descriptor_pb2.py new file mode 100644 index 00000000..16873834 --- /dev/null +++ b/froofle/protobuf/descriptor_pb2.py @@ -0,0 +1,950 @@ +#!/usr/bin/python2.4 +# Generated by the protocol buffer compiler. DO NOT EDIT! + +from froofle.protobuf import descriptor +from froofle.protobuf import message +from froofle.protobuf import reflection +from froofle.protobuf import service +from froofle.protobuf import service_reflection + + +_FIELDDESCRIPTORPROTO_TYPE = descriptor.EnumDescriptor( + name='Type', + full_name='froofle.protobuf.FieldDescriptorProto.Type', + filename='Type', + values=[ + descriptor.EnumValueDescriptor( + name='TYPE_DOUBLE', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_FLOAT', index=1, number=2, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_INT64', index=2, number=3, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_UINT64', index=3, number=4, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_INT32', index=4, number=5, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_FIXED64', index=5, number=6, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_FIXED32', index=6, number=7, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_BOOL', index=7, number=8, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_STRING', index=8, number=9, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_GROUP', index=9, number=10, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_MESSAGE', index=10, number=11, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_BYTES', index=11, number=12, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_UINT32', index=12, number=13, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_ENUM', index=13, number=14, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_SFIXED32', index=14, number=15, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_SFIXED64', index=15, number=16, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_SINT32', index=16, number=17, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='TYPE_SINT64', index=17, number=18, + options=None, + type=None), + ], + options=None, +) + +_FIELDDESCRIPTORPROTO_LABEL = descriptor.EnumDescriptor( + name='Label', + full_name='froofle.protobuf.FieldDescriptorProto.Label', + filename='Label', + values=[ + descriptor.EnumValueDescriptor( + name='LABEL_OPTIONAL', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='LABEL_REQUIRED', index=1, number=2, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='LABEL_REPEATED', index=2, number=3, + options=None, + type=None), + ], + options=None, +) + +_FILEOPTIONS_OPTIMIZEMODE = descriptor.EnumDescriptor( + name='OptimizeMode', + full_name='froofle.protobuf.FileOptions.OptimizeMode', + filename='OptimizeMode', + values=[ + descriptor.EnumValueDescriptor( + name='SPEED', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='CODE_SIZE', index=1, number=2, + options=None, + type=None), + ], + options=None, +) + +_FIELDOPTIONS_CTYPE = descriptor.EnumDescriptor( + name='CType', + full_name='froofle.protobuf.FieldOptions.CType', + filename='CType', + values=[ + descriptor.EnumValueDescriptor( + name='CORD', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='STRING_PIECE', index=1, number=2, + options=None, + type=None), + ], + options=None, +) + + +_FILEDESCRIPTORSET = descriptor.Descriptor( + name='FileDescriptorSet', + full_name='froofle.protobuf.FileDescriptorSet', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='file', full_name='froofle.protobuf.FileDescriptorSet.file', index=0, + number=1, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_FILEDESCRIPTORPROTO = descriptor.Descriptor( + name='FileDescriptorProto', + full_name='froofle.protobuf.FileDescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.FileDescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='package', full_name='froofle.protobuf.FileDescriptorProto.package', index=1, + number=2, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='dependency', full_name='froofle.protobuf.FileDescriptorProto.dependency', index=2, + number=3, type=9, cpp_type=9, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='message_type', full_name='froofle.protobuf.FileDescriptorProto.message_type', index=3, + number=4, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='enum_type', full_name='froofle.protobuf.FileDescriptorProto.enum_type', index=4, + number=5, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='service', full_name='froofle.protobuf.FileDescriptorProto.service', index=5, + number=6, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='extension', full_name='froofle.protobuf.FileDescriptorProto.extension', index=6, + number=7, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.FileDescriptorProto.options', index=7, + number=8, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_DESCRIPTORPROTO_EXTENSIONRANGE = descriptor.Descriptor( + name='ExtensionRange', + full_name='froofle.protobuf.DescriptorProto.ExtensionRange', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='start', full_name='froofle.protobuf.DescriptorProto.ExtensionRange.start', index=0, + number=1, type=5, cpp_type=1, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='end', full_name='froofle.protobuf.DescriptorProto.ExtensionRange.end', index=1, + number=2, type=5, cpp_type=1, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + +_DESCRIPTORPROTO = descriptor.Descriptor( + name='DescriptorProto', + full_name='froofle.protobuf.DescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.DescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='field', full_name='froofle.protobuf.DescriptorProto.field', index=1, + number=2, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='extension', full_name='froofle.protobuf.DescriptorProto.extension', index=2, + number=6, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='nested_type', full_name='froofle.protobuf.DescriptorProto.nested_type', index=3, + number=3, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='enum_type', full_name='froofle.protobuf.DescriptorProto.enum_type', index=4, + number=4, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='extension_range', full_name='froofle.protobuf.DescriptorProto.extension_range', index=5, + number=5, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.DescriptorProto.options', index=6, + number=7, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_FIELDDESCRIPTORPROTO = descriptor.Descriptor( + name='FieldDescriptorProto', + full_name='froofle.protobuf.FieldDescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.FieldDescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='number', full_name='froofle.protobuf.FieldDescriptorProto.number', index=1, + number=3, type=5, cpp_type=1, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='label', full_name='froofle.protobuf.FieldDescriptorProto.label', index=2, + number=4, type=14, cpp_type=8, label=1, + default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='type', full_name='froofle.protobuf.FieldDescriptorProto.type', index=3, + number=5, type=14, cpp_type=8, label=1, + default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='type_name', full_name='froofle.protobuf.FieldDescriptorProto.type_name', index=4, + number=6, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='extendee', full_name='froofle.protobuf.FieldDescriptorProto.extendee', index=5, + number=2, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='default_value', full_name='froofle.protobuf.FieldDescriptorProto.default_value', index=6, + number=7, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.FieldDescriptorProto.options', index=7, + number=8, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + _FIELDDESCRIPTORPROTO_TYPE, + _FIELDDESCRIPTORPROTO_LABEL, + ], + options=None) + + +_ENUMDESCRIPTORPROTO = descriptor.Descriptor( + name='EnumDescriptorProto', + full_name='froofle.protobuf.EnumDescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.EnumDescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='value', full_name='froofle.protobuf.EnumDescriptorProto.value', index=1, + number=2, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.EnumDescriptorProto.options', index=2, + number=3, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_ENUMVALUEDESCRIPTORPROTO = descriptor.Descriptor( + name='EnumValueDescriptorProto', + full_name='froofle.protobuf.EnumValueDescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.EnumValueDescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='number', full_name='froofle.protobuf.EnumValueDescriptorProto.number', index=1, + number=2, type=5, cpp_type=1, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.EnumValueDescriptorProto.options', index=2, + number=3, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_SERVICEDESCRIPTORPROTO = descriptor.Descriptor( + name='ServiceDescriptorProto', + full_name='froofle.protobuf.ServiceDescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.ServiceDescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='method', full_name='froofle.protobuf.ServiceDescriptorProto.method', index=1, + number=2, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.ServiceDescriptorProto.options', index=2, + number=3, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_METHODDESCRIPTORPROTO = descriptor.Descriptor( + name='MethodDescriptorProto', + full_name='froofle.protobuf.MethodDescriptorProto', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.MethodDescriptorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='input_type', full_name='froofle.protobuf.MethodDescriptorProto.input_type', index=1, + number=2, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='output_type', full_name='froofle.protobuf.MethodDescriptorProto.output_type', index=2, + number=3, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='options', full_name='froofle.protobuf.MethodDescriptorProto.options', index=3, + number=4, type=11, cpp_type=10, label=1, + default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_FILEOPTIONS = descriptor.Descriptor( + name='FileOptions', + full_name='froofle.protobuf.FileOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='java_package', full_name='froofle.protobuf.FileOptions.java_package', index=0, + number=1, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='java_outer_classname', full_name='froofle.protobuf.FileOptions.java_outer_classname', index=1, + number=8, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='java_multiple_files', full_name='froofle.protobuf.FileOptions.java_multiple_files', index=2, + number=10, type=8, cpp_type=7, label=1, + default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='optimize_for', full_name='froofle.protobuf.FileOptions.optimize_for', index=3, + number=9, type=14, cpp_type=8, label=1, + default_value=2, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.FileOptions.uninterpreted_option', index=4, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + _FILEOPTIONS_OPTIMIZEMODE, + ], + options=None) + + +_MESSAGEOPTIONS = descriptor.Descriptor( + name='MessageOptions', + full_name='froofle.protobuf.MessageOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='message_set_wire_format', full_name='froofle.protobuf.MessageOptions.message_set_wire_format', index=0, + number=1, type=8, cpp_type=7, label=1, + default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.MessageOptions.uninterpreted_option', index=1, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_FIELDOPTIONS = descriptor.Descriptor( + name='FieldOptions', + full_name='froofle.protobuf.FieldOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='ctype', full_name='froofle.protobuf.FieldOptions.ctype', index=0, + number=1, type=14, cpp_type=8, label=1, + default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='experimental_map_key', full_name='froofle.protobuf.FieldOptions.experimental_map_key', index=1, + number=9, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.FieldOptions.uninterpreted_option', index=2, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + _FIELDOPTIONS_CTYPE, + ], + options=None) + + +_ENUMOPTIONS = descriptor.Descriptor( + name='EnumOptions', + full_name='froofle.protobuf.EnumOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.EnumOptions.uninterpreted_option', index=0, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_ENUMVALUEOPTIONS = descriptor.Descriptor( + name='EnumValueOptions', + full_name='froofle.protobuf.EnumValueOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.EnumValueOptions.uninterpreted_option', index=0, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_SERVICEOPTIONS = descriptor.Descriptor( + name='ServiceOptions', + full_name='froofle.protobuf.ServiceOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.ServiceOptions.uninterpreted_option', index=0, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_METHODOPTIONS = descriptor.Descriptor( + name='MethodOptions', + full_name='froofle.protobuf.MethodOptions', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='uninterpreted_option', full_name='froofle.protobuf.MethodOptions.uninterpreted_option', index=0, + number=999, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_UNINTERPRETEDOPTION_NAMEPART = descriptor.Descriptor( + name='NamePart', + full_name='froofle.protobuf.UninterpretedOption.NamePart', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name_part', full_name='froofle.protobuf.UninterpretedOption.NamePart.name_part', index=0, + number=1, type=9, cpp_type=9, label=2, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='is_extension', full_name='froofle.protobuf.UninterpretedOption.NamePart.is_extension', index=1, + number=2, type=8, cpp_type=7, label=2, + default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + +_UNINTERPRETEDOPTION = descriptor.Descriptor( + name='UninterpretedOption', + full_name='froofle.protobuf.UninterpretedOption', + filename='froofle/protobuf/descriptor.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='name', full_name='froofle.protobuf.UninterpretedOption.name', index=0, + number=2, type=11, cpp_type=10, label=3, + default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='identifier_value', full_name='froofle.protobuf.UninterpretedOption.identifier_value', index=1, + number=3, type=9, cpp_type=9, label=1, + default_value=unicode("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='positive_int_value', full_name='froofle.protobuf.UninterpretedOption.positive_int_value', index=2, + number=4, type=4, cpp_type=4, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='negative_int_value', full_name='froofle.protobuf.UninterpretedOption.negative_int_value', index=3, + number=5, type=3, cpp_type=2, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='double_value', full_name='froofle.protobuf.UninterpretedOption.double_value', index=4, + number=6, type=1, cpp_type=5, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='string_value', full_name='froofle.protobuf.UninterpretedOption.string_value', index=5, + number=7, type=12, cpp_type=9, label=1, + default_value="", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], # TODO(robinson): Implement. + enum_types=[ + ], + options=None) + + +_FILEDESCRIPTORSET.fields_by_name['file'].message_type = _FILEDESCRIPTORPROTO +_FILEDESCRIPTORPROTO.fields_by_name['message_type'].message_type = _DESCRIPTORPROTO +_FILEDESCRIPTORPROTO.fields_by_name['enum_type'].message_type = _ENUMDESCRIPTORPROTO +_FILEDESCRIPTORPROTO.fields_by_name['service'].message_type = _SERVICEDESCRIPTORPROTO +_FILEDESCRIPTORPROTO.fields_by_name['extension'].message_type = _FIELDDESCRIPTORPROTO +_FILEDESCRIPTORPROTO.fields_by_name['options'].message_type = _FILEOPTIONS +_DESCRIPTORPROTO.fields_by_name['field'].message_type = _FIELDDESCRIPTORPROTO +_DESCRIPTORPROTO.fields_by_name['extension'].message_type = _FIELDDESCRIPTORPROTO +_DESCRIPTORPROTO.fields_by_name['nested_type'].message_type = _DESCRIPTORPROTO +_DESCRIPTORPROTO.fields_by_name['enum_type'].message_type = _ENUMDESCRIPTORPROTO +_DESCRIPTORPROTO.fields_by_name['extension_range'].message_type = _DESCRIPTORPROTO_EXTENSIONRANGE +_DESCRIPTORPROTO.fields_by_name['options'].message_type = _MESSAGEOPTIONS +_FIELDDESCRIPTORPROTO.fields_by_name['label'].enum_type = _FIELDDESCRIPTORPROTO_LABEL +_FIELDDESCRIPTORPROTO.fields_by_name['type'].enum_type = _FIELDDESCRIPTORPROTO_TYPE +_FIELDDESCRIPTORPROTO.fields_by_name['options'].message_type = _FIELDOPTIONS +_ENUMDESCRIPTORPROTO.fields_by_name['value'].message_type = _ENUMVALUEDESCRIPTORPROTO +_ENUMDESCRIPTORPROTO.fields_by_name['options'].message_type = _ENUMOPTIONS +_ENUMVALUEDESCRIPTORPROTO.fields_by_name['options'].message_type = _ENUMVALUEOPTIONS +_SERVICEDESCRIPTORPROTO.fields_by_name['method'].message_type = _METHODDESCRIPTORPROTO +_SERVICEDESCRIPTORPROTO.fields_by_name['options'].message_type = _SERVICEOPTIONS +_METHODDESCRIPTORPROTO.fields_by_name['options'].message_type = _METHODOPTIONS +_FILEOPTIONS.fields_by_name['optimize_for'].enum_type = _FILEOPTIONS_OPTIMIZEMODE +_FILEOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_MESSAGEOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_FIELDOPTIONS.fields_by_name['ctype'].enum_type = _FIELDOPTIONS_CTYPE +_FIELDOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_ENUMOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_ENUMVALUEOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_SERVICEOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_METHODOPTIONS.fields_by_name['uninterpreted_option'].message_type = _UNINTERPRETEDOPTION +_UNINTERPRETEDOPTION.fields_by_name['name'].message_type = _UNINTERPRETEDOPTION_NAMEPART + +class FileDescriptorSet(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _FILEDESCRIPTORSET + +class FileDescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _FILEDESCRIPTORPROTO + +class DescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + + class ExtensionRange(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _DESCRIPTORPROTO_EXTENSIONRANGE + DESCRIPTOR = _DESCRIPTORPROTO + +class FieldDescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _FIELDDESCRIPTORPROTO + +class EnumDescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _ENUMDESCRIPTORPROTO + +class EnumValueDescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _ENUMVALUEDESCRIPTORPROTO + +class ServiceDescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _SERVICEDESCRIPTORPROTO + +class MethodDescriptorProto(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _METHODDESCRIPTORPROTO + +class FileOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _FILEOPTIONS + +class MessageOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _MESSAGEOPTIONS + +class FieldOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _FIELDOPTIONS + +class EnumOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _ENUMOPTIONS + +class EnumValueOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _ENUMVALUEOPTIONS + +class ServiceOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _SERVICEOPTIONS + +class MethodOptions(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _METHODOPTIONS + +class UninterpretedOption(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + + class NamePart(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = _UNINTERPRETEDOPTION_NAMEPART + DESCRIPTOR = _UNINTERPRETEDOPTION + diff --git a/froofle/protobuf/internal/__init__.py b/froofle/protobuf/internal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/froofle/protobuf/internal/decoder.py b/froofle/protobuf/internal/decoder.py new file mode 100644 index 00000000..2dd4c96e --- /dev/null +++ b/froofle/protobuf/internal/decoder.py @@ -0,0 +1,209 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Class for decoding protocol buffer primitives. + +Contains the logic for decoding every logical protocol field type +from one of the 5 physical wire types. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from froofle.protobuf import message +from froofle.protobuf.internal import input_stream +from froofle.protobuf.internal import wire_format + + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by WireFormat from the C++ proto2 +# implementation. + + +class Decoder(object): + + """Decodes logical protocol buffer fields from the wire.""" + + def __init__(self, s): + """Initializes the decoder to read from s. + + Args: + s: An immutable sequence of bytes, which must be accessible + via the Python buffer() primitive (i.e., buffer(s)). + """ + self._stream = input_stream.InputStream(s) + + def EndOfStream(self): + """Returns true iff we've reached the end of the bytes we're reading.""" + return self._stream.EndOfStream() + + def Position(self): + """Returns the 0-indexed position in |s|.""" + return self._stream.Position() + + def ReadFieldNumberAndWireType(self): + """Reads a tag from the wire. Returns a (field_number, wire_type) pair.""" + tag_and_type = self.ReadUInt32() + return wire_format.UnpackTag(tag_and_type) + + def SkipBytes(self, bytes): + """Skips the specified number of bytes on the wire.""" + self._stream.SkipBytes(bytes) + + # Note that the Read*() methods below are not exactly symmetrical with the + # corresponding Encoder.Append*() methods. Those Encoder methods first + # encode a tag, but the Read*() methods below assume that the tag has already + # been read, and that the client wishes to read a field of the specified type + # starting at the current position. + + def ReadInt32(self): + """Reads and returns a signed, varint-encoded, 32-bit integer.""" + return self._stream.ReadVarint32() + + def ReadInt64(self): + """Reads and returns a signed, varint-encoded, 64-bit integer.""" + return self._stream.ReadVarint64() + + def ReadUInt32(self): + """Reads and returns an signed, varint-encoded, 32-bit integer.""" + return self._stream.ReadVarUInt32() + + def ReadUInt64(self): + """Reads and returns an signed, varint-encoded,64-bit integer.""" + return self._stream.ReadVarUInt64() + + def ReadSInt32(self): + """Reads and returns a signed, zigzag-encoded, varint-encoded, + 32-bit integer.""" + return wire_format.ZigZagDecode(self._stream.ReadVarUInt32()) + + def ReadSInt64(self): + """Reads and returns a signed, zigzag-encoded, varint-encoded, + 64-bit integer.""" + return wire_format.ZigZagDecode(self._stream.ReadVarUInt64()) + + def ReadFixed32(self): + """Reads and returns an unsigned, fixed-width, 32-bit integer.""" + return self._stream.ReadLittleEndian32() + + def ReadFixed64(self): + """Reads and returns an unsigned, fixed-width, 64-bit integer.""" + return self._stream.ReadLittleEndian64() + + def ReadSFixed32(self): + """Reads and returns a signed, fixed-width, 32-bit integer.""" + value = self._stream.ReadLittleEndian32() + if value >= (1 << 31): + value -= (1 << 32) + return value + + def ReadSFixed64(self): + """Reads and returns a signed, fixed-width, 64-bit integer.""" + value = self._stream.ReadLittleEndian64() + if value >= (1 << 63): + value -= (1 << 64) + return value + + def ReadFloat(self): + """Reads and returns a 4-byte floating-point number.""" + serialized = self._stream.ReadBytes(4) + return struct.unpack('f', serialized)[0] + + def ReadDouble(self): + """Reads and returns an 8-byte floating-point number.""" + serialized = self._stream.ReadBytes(8) + return struct.unpack('d', serialized)[0] + + def ReadBool(self): + """Reads and returns a bool.""" + i = self._stream.ReadVarUInt32() + return bool(i) + + def ReadEnum(self): + """Reads and returns an enum value.""" + return self._stream.ReadVarUInt32() + + def ReadString(self): + """Reads and returns a length-delimited string.""" + bytes = self.ReadBytes() + return unicode(bytes, 'utf-8') + + def ReadBytes(self): + """Reads and returns a length-delimited byte sequence.""" + length = self._stream.ReadVarUInt32() + return self._stream.ReadBytes(length) + + def ReadMessageInto(self, msg): + """Calls msg.MergeFromString() to merge + length-delimited serialized message data into |msg|. + + REQUIRES: The decoder must be positioned at the serialized "length" + prefix to a length-delmiited serialized message. + + POSTCONDITION: The decoder is positioned just after the + serialized message, and we have merged those serialized + contents into |msg|. + """ + length = self._stream.ReadVarUInt32() + sub_buffer = self._stream.GetSubBuffer(length) + num_bytes_used = msg.MergeFromString(sub_buffer) + if num_bytes_used != length: + raise message.DecodeError( + 'Submessage told to deserialize from %d-byte encoding, ' + 'but used only %d bytes' % (length, num_bytes_used)) + self._stream.SkipBytes(num_bytes_used) + + def ReadGroupInto(self, expected_field_number, group): + """Calls group.MergeFromString() to merge + END_GROUP-delimited serialized message data into |group|. + We'll raise an exception if we don't find an END_GROUP + tag immediately after the serialized message contents. + + REQUIRES: The decoder is positioned just after the START_GROUP + tag for this group. + + POSTCONDITION: The decoder is positioned just after the + END_GROUP tag for this group, and we have merged + the contents of the group into |group|. + """ + sub_buffer = self._stream.GetSubBuffer() # No a priori length limit. + num_bytes_used = group.MergeFromString(sub_buffer) + if num_bytes_used < 0: + raise message.DecodeError('Group message reported negative bytes read.') + self._stream.SkipBytes(num_bytes_used) + field_number, field_type = self.ReadFieldNumberAndWireType() + if field_type != wire_format.WIRETYPE_END_GROUP: + raise message.DecodeError('Group message did not end with an END_GROUP.') + if field_number != expected_field_number: + raise message.DecodeError('END_GROUP tag had field ' + 'number %d, was expecting field number %d' % ( + field_number, expected_field_number)) + # We're now positioned just after the END_GROUP tag. Perfect. diff --git a/froofle/protobuf/internal/encoder.py b/froofle/protobuf/internal/encoder.py new file mode 100644 index 00000000..8b924b3e --- /dev/null +++ b/froofle/protobuf/internal/encoder.py @@ -0,0 +1,206 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Class for encoding protocol message primitives. + +Contains the logic for encoding every logical protocol field type +into one of the 5 physical wire types. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from froofle.protobuf import message +from froofle.protobuf.internal import wire_format +from froofle.protobuf.internal import output_stream + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by WireFormat from the C++ proto2 +# implementation. + + +class Encoder(object): + + """Encodes logical protocol buffer fields to the wire format.""" + + def __init__(self): + self._stream = output_stream.OutputStream() + + def ToString(self): + """Returns all values encoded in this object as a string.""" + return self._stream.ToString() + + # All the Append*() methods below first append a tag+type pair to the buffer + # before appending the specified value. + + def AppendInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarint32(value) + + def AppendInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarint64(value) + + def AppendUInt32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarUInt32(unsigned_value) + + def AppendUInt64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarUInt64(unsigned_value) + + def AppendSInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + zigzag_value = wire_format.ZigZagEncode(value) + self._stream.AppendVarUInt32(zigzag_value) + + def AppendSInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + zigzag_value = wire_format.ZigZagEncode(value) + self._stream.AppendVarUInt64(zigzag_value) + + def AppendFixed32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, in little-endian + byte-order. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self._stream.AppendLittleEndian32(unsigned_value) + + def AppendFixed64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, in little-endian + byte-order. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self._stream.AppendLittleEndian64(unsigned_value) + + def AppendSFixed32(self, field_number, value): + """Appends a signed 32-bit integer to our buffer, in little-endian + byte-order. + """ + sign = (value & 0x80000000) and -1 or 0 + if value >> 32 != sign: + raise message.EncodeError('SFixed32 out of range: %d' % value) + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self._stream.AppendLittleEndian32(value & 0xffffffff) + + def AppendSFixed64(self, field_number, value): + """Appends a signed 64-bit integer to our buffer, in little-endian + byte-order. + """ + sign = (value & 0x8000000000000000) and -1 or 0 + if value >> 64 != sign: + raise message.EncodeError('SFixed64 out of range: %d' % value) + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self._stream.AppendLittleEndian64(value & 0xffffffffffffffff) + + def AppendFloat(self, field_number, value): + """Appends a floating-point number to our buffer.""" + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self._stream.AppendRawBytes(struct.pack('f', value)) + + def AppendDouble(self, field_number, value): + """Appends a double-precision floating-point number to our buffer.""" + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self._stream.AppendRawBytes(struct.pack('d', value)) + + def AppendBool(self, field_number, value): + """Appends a boolean to our buffer.""" + self.AppendInt32(field_number, value) + + def AppendEnum(self, field_number, value): + """Appends an enum value to our buffer.""" + self.AppendInt32(field_number, value) + + def AppendString(self, field_number, value): + """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer, + with the length varint-encoded. + """ + self.AppendBytes(field_number, value.encode('utf-8')) + + def AppendBytes(self, field_number, value): + """Appends a length-prefixed sequence of bytes to our buffer, with the + length varint-encoded. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self._stream.AppendVarUInt32(len(value)) + self._stream.AppendRawBytes(value) + + # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to + # avoid the extra string copy here. We can do so if we widen the Message + # interface to be able to serialize to a stream in addition to a string. The + # challenge when thinking ahead to the Python/C API implementation of Message + # is finding a stream-like Python thing to which we can write raw bytes + # from C. I'm not sure such a thing exists(?). (array.array is pretty much + # what we want, but it's not directly exposed in the Python/C API). + + def AppendGroup(self, field_number, group): + """Appends a group to our buffer. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_START_GROUP) + self._stream.AppendRawBytes(group.SerializeToString()) + self._AppendTag(field_number, wire_format.WIRETYPE_END_GROUP) + + def AppendMessage(self, field_number, msg): + """Appends a nested message to our buffer. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self._stream.AppendVarUInt32(msg.ByteSize()) + self._stream.AppendRawBytes(msg.SerializeToString()) + + def AppendMessageSetItem(self, field_number, msg): + """Appends an item using the message set wire format. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + self._AppendTag(1, wire_format.WIRETYPE_START_GROUP) + self.AppendInt32(2, field_number) + self.AppendMessage(3, msg) + self._AppendTag(1, wire_format.WIRETYPE_END_GROUP) + + def _AppendTag(self, field_number, wire_type): + """Appends a tag containing field number and wire type information.""" + self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type)) diff --git a/froofle/protobuf/internal/input_stream.py b/froofle/protobuf/internal/input_stream.py new file mode 100644 index 00000000..26a26dcf --- /dev/null +++ b/froofle/protobuf/internal/input_stream.py @@ -0,0 +1,326 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""InputStream is the primitive interface for reading bits from the wire. + +All protocol buffer deserialization can be expressed in terms of +the InputStream primitives provided here. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from array import array +from froofle.protobuf import message +from froofle.protobuf.internal import wire_format + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by CodedInputStream from the C++ +# proto2 implementation. + + +class InputStreamBuffer(object): + + """Contains all logic for reading bits, and dealing with stream position. + + If an InputStream method ever raises an exception, the stream is left + in an indeterminate state and is not safe for further use. + """ + + def __init__(self, s): + # What we really want is something like array('B', s), where elements we + # read from the array are already given to us as one-byte integers. BUT + # using array() instead of buffer() would force full string copies to result + # from each GetSubBuffer() call. + # + # So, if the N serialized bytes of a single protocol buffer object are + # split evenly between 2 child messages, and so on recursively, using + # array('B', s) instead of buffer() would incur an additional N*logN bytes + # copied during deserialization. + # + # The higher constant overhead of having to ord() for every byte we read + # from the buffer in _ReadVarintHelper() could definitely lead to worse + # performance in many real-world scenarios, even if the asymptotic + # complexity is better. However, our real answer is that the mythical + # Python/C extension module output mode for the protocol compiler will + # be blazing-fast and will eliminate most use of this class anyway. + self._buffer = buffer(s) + self._pos = 0 + + def EndOfStream(self): + """Returns true iff we're at the end of the stream. + If this returns true, then a call to any other InputStream method + will raise an exception. + """ + return self._pos >= len(self._buffer) + + def Position(self): + """Returns the current position in the stream, or equivalently, the + number of bytes read so far. + """ + return self._pos + + def GetSubBuffer(self, size=None): + """Returns a sequence-like object that represents a portion of our + underlying sequence. + + Position 0 in the returned object corresponds to self.Position() + in this stream. + + If size is specified, then the returned object ends after the + next "size" bytes in this stream. If size is not specified, + then the returned object ends at the end of this stream. + + We guarantee that the returned object R supports the Python buffer + interface (and thus that the call buffer(R) will work). + + Note that the returned buffer is read-only. + + The intended use for this method is for nested-message and nested-group + deserialization, where we want to make a recursive MergeFromString() + call on the portion of the original sequence that contains the serialized + nested message. (And we'd like to do so without making unnecessary string + copies). + + REQUIRES: size is nonnegative. + """ + # Note that buffer() doesn't perform any actual string copy. + if size is None: + return buffer(self._buffer, self._pos) + else: + if size < 0: + raise message.DecodeError('Negative size %d' % size) + return buffer(self._buffer, self._pos, size) + + def SkipBytes(self, num_bytes): + """Skip num_bytes bytes ahead, or go to the end of the stream, whichever + comes first. + + REQUIRES: num_bytes is nonnegative. + """ + if num_bytes < 0: + raise message.DecodeError('Negative num_bytes %d' % num_bytes) + self._pos += num_bytes + self._pos = min(self._pos, len(self._buffer)) + + def ReadBytes(self, size): + """Reads up to 'size' bytes from the stream, stopping early + only if we reach the end of the stream. Returns the bytes read + as a string. + """ + if size < 0: + raise message.DecodeError('Negative size %d' % size) + s = (self._buffer[self._pos : self._pos + size]) + self._pos += len(s) # Only advance by the number of bytes actually read. + return s + + def ReadLittleEndian32(self): + """Interprets the next 4 bytes of the stream as a little-endian + encoded, unsiged 32-bit integer, and returns that integer. + """ + try: + i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 4]) + self._pos += 4 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadLittleEndian64(self): + """Interprets the next 8 bytes of the stream as a little-endian + encoded, unsiged 64-bit integer, and returns that integer. + """ + try: + i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 8]) + self._pos += 8 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadVarint32(self): + """Reads a varint from the stream, interprets this varint + as a signed, 32-bit integer, and returns the integer. + """ + i = self.ReadVarint64() + if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: + raise message.DecodeError('Value out of range for int32: %d' % i) + return int(i) + + def ReadVarUInt32(self): + """Reads a varint from the stream, interprets this varint + as an unsigned, 32-bit integer, and returns the integer. + """ + i = self.ReadVarUInt64() + if i > wire_format.UINT32_MAX: + raise message.DecodeError('Value out of range for uint32: %d' % i) + return i + + def ReadVarint64(self): + """Reads a varint from the stream, interprets this varint + as a signed, 64-bit integer, and returns the integer. + """ + i = self.ReadVarUInt64() + if i > wire_format.INT64_MAX: + i -= (1 << 64) + return i + + def ReadVarUInt64(self): + """Reads a varint from the stream, interprets this varint + as an unsigned, 64-bit integer, and returns the integer. + """ + i = self._ReadVarintHelper() + if not 0 <= i <= wire_format.UINT64_MAX: + raise message.DecodeError('Value out of range for uint64: %d' % i) + return i + + def _ReadVarintHelper(self): + """Helper for the various varint-reading methods above. + Reads an unsigned, varint-encoded integer from the stream and + returns this integer. + + Does no bounds checking except to ensure that we read at most as many bytes + as could possibly be present in a varint-encoded 64-bit number. + """ + result = 0 + shift = 0 + while 1: + if shift >= 64: + raise message.DecodeError('Too many bytes when decoding varint.') + try: + b = ord(self._buffer[self._pos]) + except IndexError: + raise message.DecodeError('Truncated varint.') + self._pos += 1 + result |= ((b & 0x7f) << shift) + shift += 7 + if not (b & 0x80): + return result + +class InputStreamArray(object): + def __init__(self, s): + self._buffer = array('B', s) + self._pos = 0 + + def EndOfStream(self): + return self._pos >= len(self._buffer) + + def Position(self): + return self._pos + + def GetSubBuffer(self, size=None): + if size is None: + return self._buffer[self._pos : ].tostring() + else: + if size < 0: + raise message.DecodeError('Negative size %d' % size) + return self._buffer[self._pos : self._pos + size].tostring() + + def SkipBytes(self, num_bytes): + if num_bytes < 0: + raise message.DecodeError('Negative num_bytes %d' % num_bytes) + self._pos += num_bytes + self._pos = min(self._pos, len(self._buffer)) + + def ReadBytes(self, size): + if size < 0: + raise message.DecodeError('Negative size %d' % size) + s = self._buffer[self._pos : self._pos + size].tostring() + self._pos += len(s) # Only advance by the number of bytes actually read. + return s + + def ReadLittleEndian32(self): + try: + i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 4]) + self._pos += 4 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadLittleEndian64(self): + try: + i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 8]) + self._pos += 8 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadVarint32(self): + i = self.ReadVarint64() + if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: + raise message.DecodeError('Value out of range for int32: %d' % i) + return int(i) + + def ReadVarUInt32(self): + i = self.ReadVarUInt64() + if i > wire_format.UINT32_MAX: + raise message.DecodeError('Value out of range for uint32: %d' % i) + return i + + def ReadVarint64(self): + i = self.ReadVarUInt64() + if i > wire_format.INT64_MAX: + i -= (1 << 64) + return i + + def ReadVarUInt64(self): + i = self._ReadVarintHelper() + if not 0 <= i <= wire_format.UINT64_MAX: + raise message.DecodeError('Value out of range for uint64: %d' % i) + return i + + def _ReadVarintHelper(self): + result = 0 + shift = 0 + while 1: + if shift >= 64: + raise message.DecodeError('Too many bytes when decoding varint.') + try: + b = self._buffer[self._pos] + except IndexError: + raise message.DecodeError('Truncated varint.') + self._pos += 1 + result |= ((b & 0x7f) << shift) + shift += 7 + if not (b & 0x80): + return result + +try: + buffer("") + InputStream = InputStreamBuffer +except NotImplementedError: + # Google App Engine: dev_appserver.py + InputStream = InputStreamArray +except RuntimeError: + # Google App Engine: production + InputStream = InputStreamArray diff --git a/froofle/protobuf/internal/message_listener.py b/froofle/protobuf/internal/message_listener.py new file mode 100644 index 00000000..43978952 --- /dev/null +++ b/froofle/protobuf/internal/message_listener.py @@ -0,0 +1,69 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Defines a listener interface for observing certain +state transitions on Message objects. + +Also defines a null implementation of this interface. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + + +class MessageListener(object): + + """Listens for transitions to nonempty and for invalidations of cached + byte sizes. Meant to be registered via Message._SetListener(). + """ + + def TransitionToNonempty(self): + """Called the *first* time that this message becomes nonempty. + Implementations are free (but not required) to call this method multiple + times after the message has become nonempty. + """ + raise NotImplementedError + + def ByteSizeDirty(self): + """Called *every* time the cached byte size value + for this object is invalidated (transitions from being + "clean" to "dirty"). + """ + raise NotImplementedError + + +class NullMessageListener(object): + + """No-op MessageListener implementation.""" + + def TransitionToNonempty(self): + pass + + def ByteSizeDirty(self): + pass diff --git a/froofle/protobuf/internal/output_stream.py b/froofle/protobuf/internal/output_stream.py new file mode 100644 index 00000000..f62cd1c4 --- /dev/null +++ b/froofle/protobuf/internal/output_stream.py @@ -0,0 +1,125 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""OutputStream is the primitive interface for sticking bits on the wire. + +All protocol buffer serialization can be expressed in terms of +the OutputStream primitives provided here. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import array +import struct +from froofle.protobuf import message +from froofle.protobuf.internal import wire_format + + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by CodedOutputStream from the C++ +# proto2 implementation. + + +class OutputStream(object): + + """Contains all logic for writing bits, and ToString() to get the result.""" + + def __init__(self): + self._buffer = array.array('B') + + def AppendRawBytes(self, raw_bytes): + """Appends raw_bytes to our internal buffer.""" + self._buffer.fromstring(raw_bytes) + + def AppendLittleEndian32(self, unsigned_value): + """Appends an unsigned 32-bit integer to the internal buffer, + in little-endian byte order. + """ + if not 0 <= unsigned_value <= wire_format.UINT32_MAX: + raise message.EncodeError( + 'Unsigned 32-bit out of range: %d' % unsigned_value) + self._buffer.fromstring(struct.pack( + wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value)) + + def AppendLittleEndian64(self, unsigned_value): + """Appends an unsigned 64-bit integer to the internal buffer, + in little-endian byte order. + """ + if not 0 <= unsigned_value <= wire_format.UINT64_MAX: + raise message.EncodeError( + 'Unsigned 64-bit out of range: %d' % unsigned_value) + self._buffer.fromstring(struct.pack( + wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value)) + + def AppendVarint32(self, value): + """Appends a signed 32-bit integer to the internal buffer, + encoded as a varint. (Note that a negative varint32 will + always require 10 bytes of space.) + """ + if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX: + raise message.EncodeError('Value out of range: %d' % value) + self.AppendVarint64(value) + + def AppendVarUInt32(self, value): + """Appends an unsigned 32-bit integer to the internal buffer, + encoded as a varint. + """ + if not 0 <= value <= wire_format.UINT32_MAX: + raise message.EncodeError('Value out of range: %d' % value) + self.AppendVarUInt64(value) + + def AppendVarint64(self, value): + """Appends a signed 64-bit integer to the internal buffer, + encoded as a varint. + """ + if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX: + raise message.EncodeError('Value out of range: %d' % value) + if value < 0: + value += (1 << 64) + self.AppendVarUInt64(value) + + def AppendVarUInt64(self, unsigned_value): + """Appends an unsigned 64-bit integer to the internal buffer, + encoded as a varint. + """ + if not 0 <= unsigned_value <= wire_format.UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % unsigned_value) + while True: + bits = unsigned_value & 0x7f + unsigned_value >>= 7 + if not unsigned_value: + self._buffer.append(bits) + break + self._buffer.append(0x80|bits) + + def ToString(self): + """Returns a string containing the bytes in our internal buffer.""" + return self._buffer.tostring() diff --git a/froofle/protobuf/internal/type_checkers.py b/froofle/protobuf/internal/type_checkers.py new file mode 100644 index 00000000..aaf7a844 --- /dev/null +++ b/froofle/protobuf/internal/type_checkers.py @@ -0,0 +1,268 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides type checking routines. + +This module defines type checking utilities in the forms of dictionaries: + +VALUE_CHECKERS: A dictionary of field types and a value validation object. +TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing + function. +TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization + function. +FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their + coresponding wire types. +TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization + function. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +from froofle.protobuf.internal import decoder +from froofle.protobuf.internal import encoder +from froofle.protobuf.internal import wire_format +from froofle.protobuf import descriptor + +_FieldDescriptor = descriptor.FieldDescriptor + + +def GetTypeChecker(cpp_type, field_type): + """Returns a type checker for a message field of the specified types. + + Args: + cpp_type: C++ type of the field (see descriptor.py). + field_type: Protocol message field type (see descriptor.py). + + Returns: + An instance of TypeChecker which can be used to verify the types + of values assigned to a field of the specified type. + """ + if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field_type == _FieldDescriptor.TYPE_STRING): + return UnicodeValueChecker() + return _VALUE_CHECKERS[cpp_type] + + +# None of the typecheckers below make any attempt to guard against people +# subclassing builtin types and doing weird things. We're not trying to +# protect against malicious clients here, just people accidentally shooting +# themselves in the foot in obvious ways. + +class TypeChecker(object): + + """Type checker used to catch type errors as early as possible + when the client is setting scalar fields in protocol messages. + """ + + def __init__(self, *acceptable_types): + self._acceptable_types = acceptable_types + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, self._acceptable_types): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), self._acceptable_types)) + raise TypeError(message) + + +# IntValueChecker and its subclasses perform integer type-checks +# and bounds-checks. +class IntValueChecker(object): + + """Checker used for integer fields. Performs type-check and range check.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if not self._MIN <= proposed_value <= self._MAX: + raise ValueError('Value out of range: %d' % proposed_value) + + +class UnicodeValueChecker(object): + + """Checker used for string fields.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (str, unicode)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (str, unicode))) + raise TypeError(message) + + # If the value is of type 'str' make sure that it is in 7-bit ASCII + # encoding. + if isinstance(proposed_value, str): + try: + unicode(proposed_value, 'ascii') + except UnicodeDecodeError: + raise ValueError('%.1024r isn\'t in 7-bit ASCII encoding.' + % (proposed_value)) + + +class Int32ValueChecker(IntValueChecker): + # We're sure to use ints instead of longs here since comparison may be more + # efficient. + _MIN = -2147483648 + _MAX = 2147483647 + + +class Uint32ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 32) - 1 + + +class Int64ValueChecker(IntValueChecker): + _MIN = -(1 << 63) + _MAX = (1 << 63) - 1 + + +class Uint64ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 64) - 1 + + +# Type-checkers for all scalar CPPTYPEs. +_VALUE_CHECKERS = { + _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), + _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + } + + +# Map from field type to a function F, such that F(field_num, value) +# gives the total byte size for a value of the given type. This +# byte size includes tag information and any other additional space +# associated with serializing "value". +TYPE_TO_BYTE_SIZE_FN = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize, + _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize, + _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize, + _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize, + _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize, + _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize, + _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize, + _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize, + _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize, + _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize, + _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize, + _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize, + _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize, + _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize, + _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize, + _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize, + _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize, + _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize + } + + +# Maps from field type to an unbound Encoder method F, such that +# F(encoder, field_number, value) will append the serialization +# of a value of this type to the encoder. +_Encoder = encoder.Encoder +TYPE_TO_SERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble, + _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat, + _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64, + _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64, + _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32, + _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64, + _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32, + _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool, + _FieldDescriptor.TYPE_STRING: _Encoder.AppendString, + _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup, + _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage, + _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes, + _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32, + _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum, + _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32, + _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64, + _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32, + _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64, + } + + +# Maps from field type to expected wiretype. +FIELD_TYPE_TO_WIRE_TYPE = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_STRING: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP, + _FieldDescriptor.TYPE_MESSAGE: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_BYTES: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, + } + + +# Maps from field type to an unbound Decoder method F, +# such that F(decoder) will read a field of the requested type. +# +# Note that Message and Group are intentionally missing here. +# They're handled by _RecursivelyMerge(). +_Decoder = decoder.Decoder +TYPE_TO_DESERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble, + _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat, + _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64, + _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64, + _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32, + _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64, + _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32, + _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool, + _FieldDescriptor.TYPE_STRING: _Decoder.ReadString, + _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes, + _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32, + _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum, + _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32, + _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64, + _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32, + _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64, + } diff --git a/froofle/protobuf/internal/wire_format.py b/froofle/protobuf/internal/wire_format.py new file mode 100644 index 00000000..4d823c8d --- /dev/null +++ b/froofle/protobuf/internal/wire_format.py @@ -0,0 +1,236 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Constants and static functions to support protocol buffer wire format.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from froofle.protobuf import message + + +TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. +_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 + +# These numbers identify the wire type of a protocol buffer value. +# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded +# tag-and-type to store one of these WIRETYPE_* constants. +# These values must match WireType enum in //net/proto2/public/wire_format.h. +WIRETYPE_VARINT = 0 +WIRETYPE_FIXED64 = 1 +WIRETYPE_LENGTH_DELIMITED = 2 +WIRETYPE_START_GROUP = 3 +WIRETYPE_END_GROUP = 4 +WIRETYPE_FIXED32 = 5 +_WIRETYPE_MAX = 5 + + +# Bounds for various integer types. +INT32_MAX = int((1 << 31) - 1) +INT32_MIN = int(-(1 << 31)) +UINT32_MAX = (1 << 32) - 1 + +INT64_MAX = (1 << 63) - 1 +INT64_MIN = -(1 << 63) +UINT64_MAX = (1 << 64) - 1 + +# "struct" format strings that will encode/decode the specified formats. +FORMAT_UINT32_LITTLE_ENDIAN = '> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK) + + +def ZigZagEncode(value): + """ZigZag Transform: Encodes signed integers so that they can be + effectively used with varint encoding. See wire_format.h for + more details. + """ + if value >= 0: + return value << 1 + return (value << 1) ^ (~0) + + +def ZigZagDecode(value): + """Inverse of ZigZagEncode().""" + if not value & 0x1: + return value >> 1 + return (value >> 1) ^ (~0) + + + +# The *ByteSize() functions below return the number of bytes required to +# serialize "field number + type" information and then serialize the value. + + +def Int32ByteSize(field_number, int32): + return Int64ByteSize(field_number, int32) + + +def Int64ByteSize(field_number, int64): + # Have to convert to uint before calling UInt64ByteSize(). + return UInt64ByteSize(field_number, 0xffffffffffffffff & int64) + + +def UInt32ByteSize(field_number, uint32): + return UInt64ByteSize(field_number, uint32) + + +def UInt64ByteSize(field_number, uint64): + return _TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) + + +def SInt32ByteSize(field_number, int32): + return UInt32ByteSize(field_number, ZigZagEncode(int32)) + + +def SInt64ByteSize(field_number, int64): + return UInt64ByteSize(field_number, ZigZagEncode(int64)) + + +def Fixed32ByteSize(field_number, fixed32): + return _TagByteSize(field_number) + 4 + + +def Fixed64ByteSize(field_number, fixed64): + return _TagByteSize(field_number) + 8 + + +def SFixed32ByteSize(field_number, sfixed32): + return _TagByteSize(field_number) + 4 + + +def SFixed64ByteSize(field_number, sfixed64): + return _TagByteSize(field_number) + 8 + + +def FloatByteSize(field_number, flt): + return _TagByteSize(field_number) + 4 + + +def DoubleByteSize(field_number, double): + return _TagByteSize(field_number) + 8 + + +def BoolByteSize(field_number, b): + return _TagByteSize(field_number) + 1 + + +def EnumByteSize(field_number, enum): + return UInt32ByteSize(field_number, enum) + + +def StringByteSize(field_number, string): + return BytesByteSize(field_number, string.encode('utf-8')) + + +def BytesByteSize(field_number, b): + return (_TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(len(b)) + + len(b)) + + +def GroupByteSize(field_number, message): + return (2 * _TagByteSize(field_number) # START and END group. + + message.ByteSize()) + + +def MessageByteSize(field_number, message): + return (_TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(message.ByteSize()) + + message.ByteSize()) + + +def MessageSetItemByteSize(field_number, msg): + # First compute the sizes of the tags. + # There are 2 tags for the beginning and ending of the repeated group, that + # is field number 1, one with field number 2 (type_id) and one with field + # number 3 (message). + total_size = (2 * _TagByteSize(1) + _TagByteSize(2) + _TagByteSize(3)) + + # Add the number of bytes for type_id. + total_size += _VarUInt64ByteSizeNoTag(field_number) + + message_size = msg.ByteSize() + + # The number of bytes for encoding the length of the message. + total_size += _VarUInt64ByteSizeNoTag(message_size) + + # The size of the message. + total_size += message_size + return total_size + + +# Private helper functions for the *ByteSize() functions above. + + +def _TagByteSize(field_number): + """Returns the bytes required to serialize a tag with this field number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) + + +def _VarUInt64ByteSizeNoTag(uint64): + """Returns the bytes required to serialize a single varint. + uint64 must be unsigned. + """ + if uint64 > UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % uint64) + bytes = 1 + while uint64 > 0x7f: + bytes += 1 + uint64 >>= 7 + return bytes diff --git a/froofle/protobuf/message.py b/froofle/protobuf/message.py new file mode 100644 index 00000000..ed714853 --- /dev/null +++ b/froofle/protobuf/message.py @@ -0,0 +1,246 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We should just make these methods all "pure-virtual" and move +# all implementation out, into reflection.py for now. + + +"""Contains an abstract base class for protocol messages.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +from froofle.protobuf import text_format + +class Error(Exception): pass +class DecodeError(Error): pass +class EncodeError(Error): pass + + +class Message(object): + + """Abstract base class for protocol messages. + + Protocol message classes are almost always generated by the protocol + compiler. These generated types subclass Message and implement the methods + shown below. + + TODO(robinson): Link to an HTML document here. + + TODO(robinson): Document that instances of this class will also + have an Extensions attribute with __getitem__ and __setitem__. + Again, not sure how to best convey this. + + TODO(robinson): Document that the class must also have a static + RegisterExtension(extension_field) method. + Not sure how to best express at this point. + """ + + # TODO(robinson): Document these fields and methods. + + __slots__ = [] + + DESCRIPTOR = None + + def __eq__(self, other_msg): + raise NotImplementedError + + def __ne__(self, other_msg): + # Can't just say self != other_msg, since that would infinitely recurse. :) + return not self == other_msg + + def __str__(self): + return text_format.MessageToString(self) + + def MergeFrom(self, other_msg): + """Merges the contents of the specified message into current message. + + This method merges the contents of the specified message into the current + message. Singular fields that are set in the specified message overwrite + the corresponding fields in the current message. Repeated fields are + appended. Singular sub-messages and groups are recursively merged. + + Args: + other_msg: Message to merge into the current message. + """ + raise NotImplementedError + + def CopyFrom(self, other_msg): + """Copies the content of the specified message into the current message. + + The method clears the current message and then merges the specified + message using MergeFrom. + + Args: + other_msg: Message to copy into the current one. + """ + if self == other_msg: + return + self.Clear() + self.MergeFrom(other_msg) + + def Clear(self): + """Clears all data that was set in the message.""" + raise NotImplementedError + + def IsInitialized(self): + """Checks if the message is initialized. + + Returns: + The method returns True if the message is initialized (i.e. all of its + required fields are set). + """ + raise NotImplementedError + + # TODO(robinson): MergeFromString() should probably return None and be + # implemented in terms of a helper that returns the # of bytes read. Our + # deserialization routines would use the helper when recursively + # deserializing, but the end user would almost always just want the no-return + # MergeFromString(). + + def MergeFromString(self, serialized): + """Merges serialized protocol buffer data into this message. + + When we find a field in |serialized| that is already present + in this message: + - If it's a "repeated" field, we append to the end of our list. + - Else, if it's a scalar, we overwrite our field. + - Else, (it's a nonrepeated composite), we recursively merge + into the existing composite. + + TODO(robinson): Document handling of unknown fields. + + Args: + serialized: Any object that allows us to call buffer(serialized) + to access a string of bytes using the buffer interface. + + TODO(robinson): When we switch to a helper, this will return None. + + Returns: + The number of bytes read from |serialized|. + For non-group messages, this will always be len(serialized), + but for messages which are actually groups, this will + generally be less than len(serialized), since we must + stop when we reach an END_GROUP tag. Note that if + we *do* stop because of an END_GROUP tag, the number + of bytes returned does not include the bytes + for the END_GROUP tag information. + """ + raise NotImplementedError + + def ParseFromString(self, serialized): + """Like MergeFromString(), except we clear the object first.""" + self.Clear() + self.MergeFromString(serialized) + + def SerializeToString(self): + """Serializes the protocol message to a binary string. + + Returns: + A binary string representation of the message if all of the required + fields in the message are set (i.e. the message is initialized). + + Raises: + message.EncodeError if the message isn't initialized. + """ + raise NotImplementedError + + def SerializePartialToString(self): + """Serializes the protocol message to a binary string. + + This method is similar to SerializeToString but doesn't check if the + message is initialized. + + Returns: + A string representation of the partial message. + """ + raise NotImplementedError + + # TODO(robinson): Decide whether we like these better + # than auto-generated has_foo() and clear_foo() methods + # on the instances themselves. This way is less consistent + # with C++, but it makes reflection-type access easier and + # reduces the number of magically autogenerated things. + # + # TODO(robinson): Be sure to document (and test) exactly + # which field names are accepted here. Are we case-sensitive? + # What do we do with fields that share names with Python keywords + # like 'lambda' and 'yield'? + # + # nnorwitz says: + # """ + # Typically (in python), an underscore is appended to names that are + # keywords. So they would become lambda_ or yield_. + # """ + def ListFields(self, field_name): + """Returns a list of (FieldDescriptor, value) tuples for all + fields in the message which are not empty. A singular field is non-empty + if HasField() would return true, and a repeated field is non-empty if + it contains at least one element. The fields are ordered by field + number""" + raise NotImplementedError + + def HasField(self, field_name): + raise NotImplementedError + + def ClearField(self, field_name): + raise NotImplementedError + + def HasExtension(self, extension_handle): + raise NotImplementedError + + def ClearExtension(self, extension_handle): + raise NotImplementedError + + def ByteSize(self): + """Returns the serialized size of this message. + Recursively calls ByteSize() on all contained messages. + """ + raise NotImplementedError + + def _SetListener(self, message_listener): + """Internal method used by the protocol message implementation. + Clients should not call this directly. + + Sets a listener that this message will call on certain state transitions. + + The purpose of this method is to register back-edges from children to + parents at runtime, for the purpose of setting "has" bits and + byte-size-dirty bits in the parent and ancestor objects whenever a child or + descendant object is modified. + + If the client wants to disconnect this Message from the object tree, she + explicitly sets callback to None. + + If message_listener is None, unregisters any existing listener. Otherwise, + message_listener must implement the MessageListener interface in + internal/message_listener.py, and we discard any listener registered + via a previous _SetListener() call. + """ + raise NotImplementedError diff --git a/froofle/protobuf/reflection.py b/froofle/protobuf/reflection.py new file mode 100644 index 00000000..e2abff04 --- /dev/null +++ b/froofle/protobuf/reflection.py @@ -0,0 +1,1653 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import heapq +import threading +import weakref +# We use "as" to avoid name collisions with variables. +from froofle.protobuf.internal import decoder +from froofle.protobuf.internal import encoder +from froofle.protobuf.internal import message_listener as message_listener_mod +from froofle.protobuf.internal import type_checkers +from froofle.protobuf.internal import wire_format +from froofle.protobuf import descriptor as descriptor_mod +from froofle.protobuf import message as message_mod + +_FieldDescriptor = descriptor_mod.FieldDescriptor + + +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. + + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + _AddSlots(descriptor, dictionary) + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + superclass = super(GeneratedProtocolMessageType, cls) + return superclass.__new__(cls, name, bases, dictionary) + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + # We act as a "friend" class of the descriptor, setting + # its _concrete_class attribute the first time we use a + # given descriptor to initialize a concrete protocol message + # class. + concrete_class_attr_name = '_concrete_class' + if not hasattr(descriptor, concrete_class_attr_name): + setattr(descriptor, concrete_class_attr_name, cls) + cls._known_extensions = [] + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(cls) + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(cls, name, bases, dictionary) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + return proto_field_name + + +def _ValueFieldName(proto_field_name): + """Returns the name of the (internal) instance attribute which objects + should use to store the current value for a given protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + return '_value_' + proto_field_name + + +def _HasFieldName(proto_field_name): + """Returns the name of the (internal) instance attribute which + objects should use to store a boolean telling whether this field + is explicitly set or not. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + return '_has_' + proto_field_name + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields] + field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields + if f.label != _FieldDescriptor.LABEL_REPEATED) + field_names.extend(('Extensions', + '_cached_byte_size', + '_cached_byte_size_dirty', + '_called_transition_to_nonempty', + '_listener', + '_lock', '__weakref__')) + dictionary['__slots__'] = field_names + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _DefaultValueForField(message, field): + """Returns a default value for a field. + + Args: + message: Message instance containing this field, or a weakref proxy + of same. + field: FieldDescriptor object for this field. + + Returns: A default value for this field. May refer back to |message| + via a weak reference. + """ + # TODO(robinson): Only the repeated fields need a reference to 'message' (so + # that they can set the 'has' bit on the containing Message when someone + # append()s a value). We could special-case this, and avoid an extra + # function call on __init__() and Clear() for non-repeated fields. + + # TODO(robinson): Find a better place for the default value assertion in this + # function. No need to repeat them every time the client calls Clear('foo'). + # (We should probably just assert these things once and as early as possible, + # by tightening checking in the descriptor classes.) + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + listener = _Listener(message, None) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + return _RepeatedCompositeFieldContainer(listener, field.message_type) + else: + return _RepeatedScalarFieldContainer( + listener, type_checkers.GetTypeChecker(field.cpp_type, field.type)) + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + assert field.default_value is None + + return field.default_value + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + fields = message_descriptor.fields + def init(self): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = False + self._listener = message_listener_mod.NullMessageListener() + self._called_transition_to_nonempty = False + # TODO(robinson): We should only create a lock if we really need one + # in this class. + self._lock = threading.Lock() + for field in fields: + default_value = _DefaultValueForField(self, field) + python_field_name = _ValueFieldName(field.name) + setattr(self, python_field_name, default_value) + if field.label != _FieldDescriptor.LABEL_REPEATED: + setattr(self, _HasFieldName(field.name), False) + self.Extensions = _ExtensionDict(self, cls._known_extensions) + + init.__module__ = None + init.__doc__ = None + cls.__init__ = init + + +def _AddPropertiesForFields(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + for field in descriptor.fields: + _AddPropertiesForField(field, cls) + + +def _AddPropertiesForField(field, cls): + """Adds a public property for a protocol message field. + Clients can use this property to get and (in the case + of non-repeated scalar fields) directly set the value + of a protocol message field. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # Catch it if we add other types that we should + # handle specially here. + assert _FieldDescriptor.MAX_CPPTYPE == 10 + + if field.label == _FieldDescriptor.LABEL_REPEATED: + _AddPropertiesForRepeatedField(field, cls) + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + _AddPropertiesForNonRepeatedCompositeField(field, cls) + else: + _AddPropertiesForNonRepeatedScalarField(field, cls) + + +def _AddPropertiesForRepeatedField(field, cls): + """Adds a public property for a "repeated" protocol message field. Clients + can use this property to get the value of the field, which will be either a + _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see + below). + + Note that when clients add values to these containers, we perform + type-checking in the case of repeated scalar fields, and we also set any + necessary "has" bits as a side-effect. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + property_name = _PropertyName(proto_field_name) + + def getter(self): + return getattr(self, python_field_name) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % proto_field_name) + + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedScalarField(field, cls): + """Adds a public property for a nonrepeated, scalar protocol message field. + Clients can use this property to get and directly set the value of the field. + Note that when the client sets the value of a field by using this property, + all necessary "has" bits are set as a side-effect, and we also perform + type-checking. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + property_name = _PropertyName(proto_field_name) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + + def getter(self): + return getattr(self, python_field_name) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + def setter(self, new_value): + type_checker.CheckValue(new_value) + setattr(self, has_field_name, True) + self._MarkByteSizeDirty() + self._MaybeCallTransitionToNonemptyCallback() + setattr(self, python_field_name, new_value) + setter.__module__ = None + setter.__doc__ = 'Setter for %s.' % proto_field_name + + # Add a property to encapsulate the getter/setter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedCompositeField(field, cls): + """Adds a public property for a nonrepeated, composite protocol message field. + A composite field is a "group" or "message" field. + + Clients can use this property to get the value of the field, but cannot + assign to the property directly. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # TODO(robinson): Remove duplication with similar method + # for non-repeated scalars. + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + property_name = _PropertyName(proto_field_name) + message_type = field.message_type + + def getter(self): + # TODO(robinson): Appropriately scary note about double-checked locking. + field_value = getattr(self, python_field_name) + if field_value is None: + self._lock.acquire() + try: + field_value = getattr(self, python_field_name) + if field_value is None: + field_class = message_type._concrete_class + field_value = field_class() + field_value._SetListener(_Listener(self, has_field_name)) + setattr(self, python_field_name, field_value) + finally: + self._lock.release() + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name) + + # Add a property to encapsulate the getter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddStaticMethods(cls): + # TODO(robinson): This probably needs to be thread-safe(?) + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + cls._known_extensions.append(extension_handle) + cls.RegisterExtension = staticmethod(RegisterExtension) + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + # Ensure that we always list in ascending field-number order. + # For non-extension fields, we can do the sort once, here, at import-time. + # For extensions, we sort on each ListFields() call, though + # we could do better if we have to. + fields = sorted(message_descriptor.fields, key=lambda f: f.number) + has_field_names = (_HasFieldName(f.name) for f in fields) + value_field_names = (_ValueFieldName(f.name) for f in fields) + triplets = zip(has_field_names, value_field_names, fields) + + def ListFields(self): + # We need to list all extension and non-extension fields + # together, in sorted order by field number. + + # Step 0: Get an iterator over all "set" non-extension fields, + # sorted by field number. + # This iterator yields (field_number, field_descriptor, value) tuples. + def SortedSetFieldsIter(): + # Note that triplets is already sorted by field number. + for has_field_name, value_field_name, field_descriptor in triplets: + if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: + value = getattr(self, _ValueFieldName(field_descriptor.name)) + if len(value) > 0: + yield (field_descriptor.number, field_descriptor, value) + elif getattr(self, _HasFieldName(field_descriptor.name)): + value = getattr(self, _ValueFieldName(field_descriptor.name)) + yield (field_descriptor.number, field_descriptor, value) + sorted_fields = SortedSetFieldsIter() + + # Step 1: Get an iterator over all "set" extension fields, + # sorted by field number. + # This iterator ALSO yields (field_number, field_descriptor, value) tuples. + # TODO(robinson): It's not necessary to repeat this with each + # serialization call. We can do better. + sorted_extension_fields = sorted( + [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()]) + + # Step 2: Create a composite iterator that merges the extension- + # and non-extension fields, and that still yields fields in + # sorted order. + all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields) + + # Step 3: Strip off the field numbers and return. + return [field[1:] for field in all_set_fields] + + cls.ListFields = ListFields + +def _AddHasFieldMethod(cls): + """Helper for _AddMessageMethods().""" + def HasField(self, field_name): + try: + return getattr(self, _HasFieldName(field_name)) + except AttributeError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + cls.HasField = HasField + + +def _AddClearFieldMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + try: + field = self.DESCRIPTOR.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + default_value = _DefaultValueForField(self, field) + if field.label == _FieldDescriptor.LABEL_REPEATED: + self._MarkByteSizeDirty() + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + old_field_value = getattr(self, python_field_name) + if old_field_value is not None: + # Snip the old object out of the object tree. + old_field_value._SetListener(None) + if getattr(self, has_field_name): + setattr(self, has_field_name, False) + # Set dirty bit on ourself and parents only if + # we're actually changing state. + self._MarkByteSizeDirty() + setattr(self, python_field_name, default_value) + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + self.Extensions._ClearExtension(extension_handle) + cls.ClearExtension = ClearExtension + + +def _AddClearMethod(cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + fields = self.DESCRIPTOR.fields + for field in fields: + self.ClearField(field.name) + # Clear extensions. + extensions = self.Extensions._ListSetExtensions() + for extension in extensions: + self.ClearExtension(extension[0]) + cls.Clear = Clear + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + return self.Extensions._HasExtension(extension_handle) + cls.HasExtension = HasExtension + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if self is other: + return True + + # Compare all fields contained directly in this message. + for field_descriptor in message_descriptor.fields: + label = field_descriptor.label + property_name = _PropertyName(field_descriptor.name) + # Non-repeated field equality requires matching "has" bits as well + # as having an equal value. + if label != _FieldDescriptor.LABEL_REPEATED: + self_has = self.HasField(property_name) + other_has = other.HasField(property_name) + if self_has != other_has: + return False + if not self_has: + # If the "has" bit for this field is False, we must stop here. + # Otherwise we will recurse forever on recursively-defined protos. + continue + if getattr(self, property_name) != getattr(other, property_name): + return False + + # Compare the extensions present in both messages. + return self.Extensions == other.Extensions + cls.__eq__ = __eq__ + + +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def BytesForField(message, field, value): + """Returns the number of bytes required to serialize a single field + in message. The field may be repeated or not, composite or not. + + Args: + message: The Message instance containing a field of the given type. + field: A FieldDescriptor describing the field of interest. + value: The value whose byte size we're interested in. + + Returns: The number of bytes required to serialize the current value + of "field" in "message", including space for tags and any other + necessary information. + """ + + if _MessageSetField(field): + return wire_format.MessageSetItemByteSize(field.number, value) + + field_number, field_type = field.number, field.type + + # Repeated fields. + if field.label == _FieldDescriptor.LABEL_REPEATED: + elements = value + else: + elements = [value] + + size = sum(_BytesForNonRepeatedElement(element, field_number, field_type) + for element in elements) + return size + + fields = message_descriptor.fields + has_field_names = (_HasFieldName(f.name) for f in fields) + zipped = zip(has_field_names, fields) + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + # Hardcoded fields first. + for has_field_name, field in zipped: + if (field.label == _FieldDescriptor.LABEL_REPEATED + or getattr(self, has_field_name)): + value = getattr(self, _ValueFieldName(field.name)) + size += BytesForField(self, field, value) + # Extensions next. + for field, value in self.Extensions._ListSetExtensions(): + size += BytesForField(self, field, value) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + return size + cls.ByteSize = ByteSize + + +def _MessageSetField(field_descriptor): + """Checks if a field should be serialized using the message set wire format. + + Args: + field_descriptor: Descriptor of the field. + + Returns: + True if the field should be serialized using the message set wire format, + false otherwise. + """ + return (field_descriptor.is_extension and + field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and + field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + field_descriptor.containing_type.GetOptions().message_set_wire_format) + + +def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): + """Appends the serialization of a single value to encoder. + + Args: + value: Value to serialize. + field_number: Field number of this value. + field_descriptor: Descriptor of the field to serialize. + encoder: encoder.Encoder object to which we should serialize this value. + """ + if _MessageSetField(field_descriptor): + encoder.AppendMessageSetItem(field_number, value) + return + + try: + method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] + method(encoder, field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % + field_descriptor.type) + + +def _ImergeSorted(*streams): + """Merges N sorted iterators into a single sorted iterator. + Each element in streams must be an iterable that yields + its elements in sorted order, and the elements contained + in each stream must all be comparable. + + There may be repeated elements in the component streams or + across the streams; the repeated elements will all be repeated + in the merged iterator as well. + + I believe that the heapq module at HEAD in the Python + sources has a method like this, but for now we roll our own. + """ + iters = [iter(stream) for stream in streams] + heap = [] + for index, it in enumerate(iters): + try: + heap.append((it.next(), index)) + except StopIteration: + pass + heapq.heapify(heap) + + while heap: + smallest_value, idx = heap[0] + yield smallest_value + try: + next_element = iters[idx].next() + heapq.heapreplace(heap, (next_element, idx)) + except StopIteration: + heapq.heappop(heap) + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializeToString(self): + # Check if the message has all of its required fields set. + errors = [] + if not _InternalIsInitialized(self, errors): + raise message_mod.EncodeError('\n'.join(errors)) + return self.SerializePartialToString() + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Encoder = encoder.Encoder + + def SerializePartialToString(self): + encoder = Encoder() + # We need to serialize all extension and non-extension fields + # together, in sorted order by field number. + for field_descriptor, field_value in self.ListFields(): + if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: + repeated_value = field_value + else: + repeated_value = [field_value] + for element in repeated_value: + _SerializeValueToEncoder(element, field_descriptor.number, + field_descriptor, encoder) + return encoder.ToString() + cls.SerializePartialToString = SerializePartialToString + + +def _WireTypeForFieldType(field_type): + """Given a field type, returns the expected wire type.""" + try: + return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type] + except KeyError: + raise message_mod.DecodeError('Unknown field type: %d' % field_type) + + +def _RecursivelyMerge(field_number, field_type, decoder, message): + """Decodes a message from decoder into message. + message is either a group or a nested message within some containing + protocol message. If it's a group, we use the group protocol to + deserialize, and if it's a nested message, we use the nested-message + protocol. + + Args: + field_number: The field number of message in its enclosing protocol buffer. + field_type: The field type of message. Must be either TYPE_MESSAGE + or TYPE_GROUP. + decoder: Decoder to read from. + message: Message to deserialize into. + """ + if field_type == _FieldDescriptor.TYPE_MESSAGE: + decoder.ReadMessageInto(message) + elif field_type == _FieldDescriptor.TYPE_GROUP: + decoder.ReadGroupInto(field_number, message) + else: + raise message_mod.DecodeError('Unexpected field type: %d' % field_type) + + +def _DeserializeScalarFromDecoder(field_type, decoder): + """Deserializes a scalar of the requested type from decoder. field_type must + be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant. + """ + try: + method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type] + return method(decoder) + except KeyError: + raise message_mod.DecodeError('Unrecognized field type: %d' % field_type) + + +def _SkipField(field_number, wire_type, decoder): + """Skips a field with the specified wire type. + + Args: + field_number: Tag number of the field to skip. + wire_type: Wire type of the field to skip. + decoder: Decoder used to deserialize the messsage. It must be positioned + just after reading the the tag and wire type of the field. + """ + if wire_type == wire_format.WIRETYPE_VARINT: + decoder.ReadUInt64() + elif wire_type == wire_format.WIRETYPE_FIXED64: + decoder.ReadFixed64() + elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + decoder.SkipBytes(decoder.ReadInt32()) + elif wire_type == wire_format.WIRETYPE_START_GROUP: + _SkipGroup(field_number, decoder) + elif wire_type == wire_format.WIRETYPE_END_GROUP: + pass + elif wire_type == wire_format.WIRETYPE_FIXED32: + decoder.ReadFixed32() + else: + raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type) + + +def _SkipGroup(group_number, decoder): + """Skips a nested group from the decoder. + + Args: + group_number: Tag number of the group to skip. + decoder: Decoder used to deserialize the message. It must be positioned + exactly at the beginning of the message that should be skipped. + """ + while True: + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if (wire_type == wire_format.WIRETYPE_END_GROUP and + field_number == group_number): + return + _SkipField(field_number, wire_type, decoder) + + +def _DeserializeMessageSetItem(message, decoder): + """Deserializes a message using the message set wire format. + + Args: + message: Message to be parsed to. + decoder: The decoder to be used to deserialize encoded data. Note that the + decoder should be positioned just after reading the START_GROUP tag that + began the messageset item. + """ + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2: + raise message_mod.DecodeError( + 'Incorrect message set wire format. ' + 'wire_type: %d, field_number: %d' % (wire_type, field_number)) + + type_id = decoder.ReadInt32() + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3: + raise message_mod.DecodeError( + 'Incorrect message set wire format. ' + 'wire_type: %d, field_number: %d' % (wire_type, field_number)) + + extension_dict = message.Extensions + extensions_by_number = extension_dict._AllExtensionsByNumber() + if type_id not in extensions_by_number: + _SkipField(field_number, wire_type, decoder) + return + + field_descriptor = extensions_by_number[type_id] + value = extension_dict[field_descriptor] + decoder.ReadMessageInto(value) + # Read the END_GROUP tag. + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1: + raise message_mod.DecodeError( + 'Incorrect message set wire format. ' + 'wire_type: %d, field_number: %d' % (wire_type, field_number)) + + +def _DeserializeOneEntity(message_descriptor, message, decoder): + """Deserializes the next wire entity from decoder into message. + The next wire entity is either a scalar or a nested message, + and may also be an element in a repeated field (the wire encoding + is the same). + + Args: + message_descriptor: A Descriptor instance describing all fields + in message. + message: The Message instance into which we're decoding our fields. + decoder: The Decoder we're using to deserialize encoded data. + + Returns: The number of bytes read from decoder during this method. + """ + initial_position = decoder.Position() + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + extension_dict = message.Extensions + extensions_by_number = extension_dict._AllExtensionsByNumber() + if field_number in message_descriptor.fields_by_number: + # Non-extension field. + field_descriptor = message_descriptor.fields_by_number[field_number] + value = getattr(message, _PropertyName(field_descriptor.name)) + def nonextension_setter_fn(scalar): + setattr(message, _PropertyName(field_descriptor.name), scalar) + scalar_setter_fn = nonextension_setter_fn + elif field_number in extensions_by_number: + # Extension field. + field_descriptor = extensions_by_number[field_number] + value = extension_dict[field_descriptor] + def extension_setter_fn(scalar): + extension_dict[field_descriptor] = scalar + scalar_setter_fn = extension_setter_fn + elif wire_type == wire_format.WIRETYPE_END_GROUP: + # We assume we're being parsed as the group that's ended. + return 0 + elif (wire_type == wire_format.WIRETYPE_START_GROUP and + field_number == 1 and + message_descriptor.GetOptions().message_set_wire_format): + # A Message Set item. + _DeserializeMessageSetItem(message, decoder) + return decoder.Position() - initial_position + else: + _SkipField(field_number, wire_type, decoder) + return decoder.Position() - initial_position + + # If we reach this point, we've identified the field as either + # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|, + # and |value| appropriately. Now actually deserialize the thing. + # + # field_descriptor: Describes the field we're deserializing. + # value: The value currently stored in the field to deserialize. + # Used only if the field is composite and/or repeated. + # scalar_setter_fn: A function F such that F(scalar) will + # set a nonrepeated scalar value for this field. Used only + # if this field is a nonrepeated scalar. + + field_number = field_descriptor.number + field_type = field_descriptor.type + expected_wire_type = _WireTypeForFieldType(field_type) + if wire_type != expected_wire_type: + # Need to fill in uninterpreted_bytes. Work for the next CL. + raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') + + property_name = _PropertyName(field_descriptor.name) + label = field_descriptor.label + cpp_type = field_descriptor.cpp_type + + # Nonrepeated scalar. Just set the field directly. + if (label != _FieldDescriptor.LABEL_REPEATED + and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + + # Nonrepeated composite. Recursively deserialize. + if label != _FieldDescriptor.LABEL_REPEATED: + composite = value + _RecursivelyMerge(field_number, field_type, decoder, composite) + return decoder.Position() - initial_position + + # Now we know we're dealing with a repeated field of some kind. + element_list = value + + if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated scalar. + element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + else: + # Repeated composite. + composite = element_list.add() + _RecursivelyMerge(field_number, field_type, decoder, composite) + return decoder.Position() - initial_position + + +def _FieldOrExtensionValues(message, field_or_extension): + """Retrieves the list of values for the specified field or extension. + + The target field or extension can be optional, required or repeated, but it + must have value(s) set. The assumption is that the target field or extension + is set (e.g. _HasFieldOrExtension holds true). + + Args: + message: Message which contains the target field or extension. + field_or_extension: Field or extension for which the list of values is + required. Must be an instance of FieldDescriptor. + + Returns: + A list of values for the specified field or extension. This list will only + contain a single element if the field is non-repeated. + """ + if field_or_extension.is_extension: + value = message.Extensions[field_or_extension] + else: + value = getattr(message, _ValueFieldName(field_or_extension.name)) + if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED: + return [value] + else: + # In this case value is a list or repeated values. + return value + + +def _HasFieldOrExtension(message, field_or_extension): + """Checks if a message has the specified field or extension set. + + The field or extension specified can be optional, required or repeated. If + it is repeated, this function returns True. Otherwise it checks the has bit + of the field or extension. + + Args: + message: Message which contains the target field or extension. + field_or_extension: Field or extension to check. This must be a + FieldDescriptor instance. + + Returns: + True if the message has a value set for the specified field or extension, + or if the field or extension is repeated. + """ + if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED: + return True + if field_or_extension.is_extension: + return message.HasExtension(field_or_extension) + else: + return message.HasField(field_or_extension.name) + + +def _IsFieldOrExtensionInitialized(message, field, errors=None): + """Checks if a message field or extension is initialized. + + Args: + message: The message which contains the field or extension. + field: Field or extension to check. This must be a FieldDescriptor instance. + errors: Errors will be appended to it, if set to a meaningful value. + + Returns: + True if the field/extension can be considered initialized. + """ + # If the field is required and is not set, it isn't initialized. + if field.label == _FieldDescriptor.LABEL_REQUIRED: + if not _HasFieldOrExtension(message, field): + if errors is not None: + errors.append('Required field %s is not set.' % field.full_name) + return False + + # If the field is optional and is not set, or if it + # isn't a submessage then the field is initialized. + if field.label == _FieldDescriptor.LABEL_OPTIONAL: + if not _HasFieldOrExtension(message, field): + return True + if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: + return True + + # The field is set and is either a single or a repeated submessage. + messages = _FieldOrExtensionValues(message, field) + # If all submessages in this field are initialized, the field is + # considered initialized. + for message in messages: + if not _InternalIsInitialized(message, errors): + return False + return True + + +def _InternalIsInitialized(message, errors=None): + """Checks if all required fields of a message are set. + + Args: + message: The message to check. + errors: If set, initialization errors will be appended to it. + + Returns: + True iff the specified message has all required fields set. + """ + fields_and_extensions = [] + fields_and_extensions.extend(message.DESCRIPTOR.fields) + fields_and_extensions.extend( + [extension[0] for extension in message.Extensions._ListSetExtensions()]) + for field_or_extension in fields_and_extensions: + if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors): + return False + return True + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Decoder = decoder.Decoder + def MergeFromString(self, serialized): + decoder = Decoder(serialized) + byte_count = 0 + while not decoder.EndOfStream(): + bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder) + if not bytes_read: + break + byte_count += bytes_read + return byte_count + cls.MergeFromString = MergeFromString + + +def _AddIsInitializedMethod(cls): + """Adds the IsInitialized method to the protocol message class.""" + cls.IsInitialized = _InternalIsInitialized + + +def _MergeFieldOrExtension(destination_msg, field, value): + """Merges a specified message field into another message.""" + property_name = _PropertyName(field.name) + is_extension = field.is_extension + + if not is_extension: + destination = getattr(destination_msg, property_name) + elif (field.label == _FieldDescriptor.LABEL_REPEATED or + field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + destination = destination_msg.Extensions[field] + + # Case 1 - a composite field. + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for v in value: + destination.add().MergeFrom(v) + else: + destination.MergeFrom(value) + return + + # Case 2 - a repeated field. + if field.label == _FieldDescriptor.LABEL_REPEATED: + for v in value: + destination.append(v) + return + + # Case 3 - a singular field. + if is_extension: + destination_msg.Extensions[field] = value + else: + setattr(destination_msg, property_name, value) + + +def _AddMergeFromMethod(cls): + def MergeFrom(self, msg): + assert msg is not self + for field in msg.ListFields(): + _MergeFieldOrExtension(self, field[0], field[1]) + cls.MergeFrom = MergeFrom + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + _AddListFieldsMethod(message_descriptor, cls) + _AddHasFieldMethod(cls) + _AddClearFieldMethod(cls) + _AddClearExtensionMethod(cls) + _AddClearMethod(cls) + _AddHasExtensionMethod(cls) + _AddEqualsMethod(message_descriptor, cls) + _AddSetListenerMethod(cls) + _AddByteSizeMethod(message_descriptor, cls) + _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) + _AddMergeFromStringMethod(message_descriptor, cls) + _AddIsInitializedMethod(cls) + _AddMergeFromMethod(cls) + + +def _AddPrivateHelperMethods(cls): + """Adds implementation of private helper methods to cls.""" + + def MaybeCallTransitionToNonemptyCallback(self): + """Calls self._listener.TransitionToNonempty() the first time this + method is called. On all subsequent calls, this is a no-op. + """ + if not self._called_transition_to_nonempty: + self._listener.TransitionToNonempty() + self._called_transition_to_nonempty = True + cls._MaybeCallTransitionToNonemptyCallback = ( + MaybeCallTransitionToNonemptyCallback) + + def MarkByteSizeDirty(self): + """Sets the _cached_byte_size_dirty bit to true, + and propagates this to our listener iff this was a state change. + """ + if not self._cached_byte_size_dirty: + self._cached_byte_size_dirty = True + self._listener.ByteSizeDirty() + cls._MarkByteSizeDirty = MarkByteSizeDirty + + +class _Listener(object): + + """MessageListener implementation that a parent message registers with its + child message. + + In order to support semantics like: + + foo.bar.baz = 23 + assert foo.HasField('bar') + + ...child objects must have back references to their parents. + This helper class is at the heart of this support. + """ + + def __init__(self, parent_message, has_field_name): + """Args: + parent_message: The message whose _MaybeCallTransitionToNonemptyCallback() + and _MarkByteSizeDirty() methods we should call when we receive + TransitionToNonempty() and ByteSizeDirty() messages. + has_field_name: The name of the "has" field that we should set in + the parent message when we receive a TransitionToNonempty message, + or None if there's no "has" field to set. (This will be the case + for child objects in "repeated" fields). + """ + # This listener establishes a back reference from a child (contained) object + # to its parent (containing) object. We make this a weak reference to avoid + # creating cyclic garbage when the client finishes with the 'parent' object + # in the tree. + if isinstance(parent_message, weakref.ProxyType): + self._parent_message_weakref = parent_message + else: + self._parent_message_weakref = weakref.proxy(parent_message) + self._has_field_name = has_field_name + + def TransitionToNonempty(self): + try: + if self._has_field_name is not None: + setattr(self._parent_message_weakref, self._has_field_name, True) + # Propagate the signal to our parents iff this is the first field set. + self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback() + except ReferenceError: + # We can get here if a client has kept a reference to a child object, + # and is now setting a field on it, but the child's parent has been + # garbage-collected. This is not an error. + pass + + def ByteSizeDirty(self): + try: + self._parent_message_weakref._MarkByteSizeDirty() + except ReferenceError: + # Same as above. + pass + + +# TODO(robinson): Move elsewhere? +# TODO(robinson): Provide a clear() method here in addition to ClearField()? +class _RepeatedScalarFieldContainer(object): + + """Simple, type-checked, list-like container for holding repeated scalars.""" + + # Minimizes memory usage and disallows assignment to other attributes. + __slots__ = ['_message_listener', '_type_checker', '_values'] + + def __init__(self, message_listener, type_checker): + """ + Args: + message_listener: A MessageListener implementation. + The _RepeatedScalarFieldContaininer will call this object's + TransitionToNonempty() method when it transitions from being empty to + being nonempty. + type_checker: A _ValueChecker instance to run on elements inserted + into this container. + """ + self._message_listener = message_listener + self._type_checker = type_checker + self._values = [] + + def append(self, elem): + self._type_checker.CheckValue(elem) + self._values.append(elem) + self._message_listener.ByteSizeDirty() + if len(self._values) == 1: + self._message_listener.TransitionToNonempty() + + def remove(self, elem): + self._values.remove(elem) + self._message_listener.ByteSizeDirty() + + # List-like __getitem__() support also makes us iterable (via "iter(foo)" + # or implicitly via "for i in mylist:") for free. + def __getitem__(self, key): + return self._values[key] + + def __setitem__(self, key, value): + # No need to call TransitionToNonempty(), since if we're able to + # set the element at this index, we were already nonempty before + # this method was called. + self._message_listener.ByteSizeDirty() + self._type_checker.CheckValue(value) + self._values[key] = value + + def __len__(self): + return len(self._values) + + def __eq__(self, other): + if self is other: + return True + # Special case for the same type which should be common and fast. + if isinstance(other, self.__class__): + return other._values == self._values + # We are presumably comparing against some other sequence type. + return other == self._values + + def __ne__(self, other): + # Can't use != here since it would infinitely recurse. + return not self == other + + +# TODO(robinson): Move elsewhere? +# TODO(robinson): Provide a clear() method here in addition to ClearField()? +# TODO(robinson): Unify common functionality with +# _RepeatedScalarFieldContaininer? +class _RepeatedCompositeFieldContainer(object): + + """Simple, list-like container for holding repeated composite fields.""" + + # Minimizes memory usage and disallows assignment to other attributes. + __slots__ = ['_values', '_message_descriptor', '_message_listener'] + + def __init__(self, message_listener, message_descriptor): + """Note that we pass in a descriptor instead of the generated directly, + since at the time we construct a _RepeatedCompositeFieldContainer we + haven't yet necessarily initialized the type that will be contained in the + container. + + Args: + message_listener: A MessageListener implementation. + The _RepeatedCompositeFieldContainer will call this object's + TransitionToNonempty() method when it transitions from being empty to + being nonempty. + message_descriptor: A Descriptor instance describing the protocol type + that should be present in this container. We'll use the + _concrete_class field of this descriptor when the client calls add(). + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._values = [] + + def add(self): + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values.append(new_element) + self._message_listener.ByteSizeDirty() + self._message_listener.TransitionToNonempty() + return new_element + + def __delitem__(self, key): + self._message_listener.ByteSizeDirty() + del self._values[key] + + # List-like __getitem__() support also makes us iterable (via "iter(foo)" + # or implicitly via "for i in mylist:") for free. + def __getitem__(self, key): + return self._values[key] + + def __len__(self): + return len(self._values) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + return self._values == other._values + + def __ne__(self, other): + # Can't use != here since it would infinitely recurse. + return not self == other + + # TODO(robinson): Implement, document, and test slicing support. + + +# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): There's so much similarity between the way that +# extensions behave and the way that normal fields behave that it would +# be really nice to unify more code. It's not immediately obvious +# how to do this, though, and I'd rather get the full functionality +# implemented (and, crucially, get all the tests and specs fleshed out +# and passing), and then come back to this thorny unification problem. +# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for supporting an indexable "Extensions" + field on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + class _ExtensionListener(object): + + """Adapts an _ExtensionDict to behave as a MessageListener.""" + + def __init__(self, extension_dict, handle_id): + self._extension_dict = extension_dict + self._handle_id = handle_id + + def TransitionToNonempty(self): + self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id) + + def ByteSizeDirty(self): + self._extension_dict._SubmessageByteSizeBecameDirty() + + # TODO(robinson): Somewhere, we need to blow up if people + # try to register two extensions with the same field number. + # (And we need a test for this of course). + + def __init__(self, extended_message, known_extensions): + """extended_message: Message instance for which we are the Extensions dict. + known_extensions: Iterable of known extension handles. + These must be FieldDescriptors. + """ + # We keep a weak reference to extended_message, since + # it has a reference to this instance in turn. + self._extended_message = weakref.proxy(extended_message) + # We make a deep copy of known_extensions to avoid any + # thread-safety concerns, since the argument passed in + # is the global (class-level) dict of known extensions for + # this type of message, which could be modified at any time + # via a RegisterExtension() call. + # + # This dict maps from handle id to handle (a FieldDescriptor). + # + # XXX + # TODO(robinson): This isn't good enough. The client could + # instantiate an object in module A, then afterward import + # module B and pass the instance to B.Foo(). If B imports + # an extender of this proto and then tries to use it, B + # will get a KeyError, even though the extension *is* registered + # at the time of use. + # XXX + self._known_extensions = dict((id(e), e) for e in known_extensions) + # Read lock around self._values, which may be modified by multiple + # concurrent readers in the conceptually "const" __getitem__ method. + # So, we grab this lock in every "read-only" method to ensure + # that concurrent read access is safe without external locking. + self._lock = threading.Lock() + # Maps from extension handle ID to current value of that extension. + self._values = {} + # Maps from extension handle ID to a boolean "has" bit, but only + # for non-repeated extension fields. + keys = (id for id, extension in self._known_extensions.iteritems() + if extension.label != _FieldDescriptor.LABEL_REPEATED) + self._has_bits = dict.fromkeys(keys, False) + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + # We don't care as much about keeping critical sections short in the + # extension support, since it's presumably much less of a common case. + self._lock.acquire() + try: + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + if handle_id not in self._values: + self._AddMissingHandle(extension_handle, handle_id) + return self._values[handle_id] + finally: + self._lock.release() + + def __eq__(self, other): + # We have to grab read locks since we're accessing _values + # in a "const" method. See the comment in the constructor. + if self is other: + return True + self._lock.acquire() + try: + other._lock.acquire() + try: + if self._has_bits != other._has_bits: + return False + # If there's a "has" bit, then only compare values where it is true. + for k, v in self._values.iteritems(): + if self._has_bits.get(k, False) and v != other._values[k]: + return False + return True + finally: + other._lock.release() + finally: + self._lock.release() + + def __ne__(self, other): + return not self == other + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call + # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field + # this way, to set any necssary "has" bits in the ancestors of the extended + # message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + field = extension_handle # Just shorten the name. + if (field.label == _FieldDescriptor.LABEL_OPTIONAL + and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker.CheckValue(value) + self._values[handle_id] = value + self._has_bits[handle_id] = True + self._extended_message._MarkByteSizeDirty() + self._extended_message._MaybeCallTransitionToNonemptyCallback() + else: + raise TypeError('Extension is repeated and/or a composite type.') + + def _AddMissingHandle(self, extension_handle, handle_id): + """Helper internal to ExtensionDict.""" + # Special handling for non-repeated message extensions, which (like + # normal fields of this kind) are initialized lazily. + # REQUIRES: _lock already held. + cpp_type = extension_handle.cpp_type + label = extension_handle.label + if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + and label != _FieldDescriptor.LABEL_REPEATED): + self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id) + else: + self._values[handle_id] = _DefaultValueForField( + self._extended_message, extension_handle) + + def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id): + """Helper internal to ExtensionDict.""" + # REQUIRES: _lock already held. + value = extension_handle.message_type._concrete_class() + value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id)) + self._values[handle_id] = value + + def _SubmessageTransitionedToNonempty(self, handle_id): + """Called when a submessage with a given handle id first transitions to + being nonempty. Called by _ExtensionListener. + """ + assert handle_id in self._has_bits + self._has_bits[handle_id] = True + self._extended_message._MaybeCallTransitionToNonemptyCallback() + + def _SubmessageByteSizeBecameDirty(self): + """Called whenever a submessage's cached byte size becomes invalid + (goes from being "clean" to being "dirty"). Called by _ExtensionListener. + """ + self._extended_message._MarkByteSizeDirty() + + # We may wish to widen the public interface of Message.Extensions + # to expose some of this private functionality in the future. + # For now, we make all this functionality module-private and just + # implement what we need for serialization/deserialization, + # HasField()/ClearField(), etc. + + def _HasExtension(self, extension_handle): + """Method for internal use by this module. + Returns true iff we "have" this extension in the sense of the + "has" bit being set. + """ + handle_id = id(extension_handle) + # Note that this is different from the other checks. + if handle_id not in self._has_bits: + raise KeyError('Extension not known to this class, or is repeated field.') + return self._has_bits[handle_id] + + # Intentionally pretty similar to ClearField() above. + def _ClearExtension(self, extension_handle): + """Method for internal use by this module. + Clears the specified extension, unsetting its "has" bit. + """ + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + default_value = _DefaultValueForField(self._extended_message, + extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + self._extended_message._MarkByteSizeDirty() + else: + cpp_type = extension_handle.cpp_type + if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if handle_id in self._values: + # Future modifications to this object shouldn't set any + # "has" bits here. + self._values[handle_id]._SetListener(None) + if self._has_bits[handle_id]: + self._has_bits[handle_id] = False + self._extended_message._MarkByteSizeDirty() + if handle_id in self._values: + del self._values[handle_id] + + def _ListSetExtensions(self): + """Method for internal use by this module. + + Returns an sequence of all extensions that are currently "set" + in this extension dict. A "set" extension is a repeated extension, + or a non-repeated extension with its "has" bit set. + + The returned sequence contains (field_descriptor, value) pairs, + where value is the current value of the extension with the given + field descriptor. + + The sequence values are in arbitrary order. + """ + self._lock.acquire() # Read-only methods must lock around self._values. + try: + set_extensions = [] + for handle_id, value in self._values.iteritems(): + handle = self._known_extensions[handle_id] + if (handle.label == _FieldDescriptor.LABEL_REPEATED + or self._has_bits[handle_id]): + set_extensions.append((handle, value)) + return set_extensions + finally: + self._lock.release() + + def _AllExtensionsByNumber(self): + """Method for internal use by this module. + + Returns: A dict mapping field_number to (handle, field_descriptor), + for *all* registered extensions for this dict. + """ + # TODO(robinson): Precompute and store this away. Note that we'll have to + # be careful when we move away from having _known_extensions as a + # deep-copied member of this object. + return dict((f.number, f) for f in self._known_extensions.itervalues()) diff --git a/froofle/protobuf/service.py b/froofle/protobuf/service.py new file mode 100644 index 00000000..3989216a --- /dev/null +++ b/froofle/protobuf/service.py @@ -0,0 +1,208 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Declares the RPC service interfaces. + +This module declares the abstract interfaces underlying proto2 RPC +services. These are intented to be independent of any particular RPC +implementation, so that proto2 services can be used on top of a variety +of implementations. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class Service(object): + + """Abstract base interface for protocol-buffer-based RPC services. + + Services themselves are abstract classes (implemented either by servers or as + stubs), but they subclass this base interface. The methods of this + interface can be used to call the methods of the service without knowing + its exact type at compile time (analogous to the Message interface). + """ + + def GetDescriptor(self): + """Retrieves this service's descriptor.""" + raise NotImplementedError + + def CallMethod(self, method_descriptor, rpc_controller, + request, done): + """Calls a method of the service specified by method_descriptor. + + Preconditions: + * method_descriptor.service == GetDescriptor + * request is of the exact same classes as returned by + GetRequestClass(method). + * After the call has started, the request must not be modified. + * "rpc_controller" is of the correct type for the RPC implementation being + used by this Service. For stubs, the "correct type" depends on the + RpcChannel which the stub is using. + + Postconditions: + * "done" will be called when the method is complete. This may be + before CallMethod() returns or it may be at some point in the future. + """ + raise NotImplementedError + + def GetRequestClass(self, method_descriptor): + """Returns the class of the request message for the specified method. + + CallMethod() requires that the request is of a particular subclass of + Message. GetRequestClass() gets the default instance of this required + type. + + Example: + method = service.GetDescriptor().FindMethodByName("Foo") + request = stub.GetRequestClass(method)() + request.ParseFromString(input) + service.CallMethod(method, request, callback) + """ + raise NotImplementedError + + def GetResponseClass(self, method_descriptor): + """Returns the class of the response message for the specified method. + + This method isn't really needed, as the RpcChannel's CallMethod constructs + the response protocol message. It's provided anyway in case it is useful + for the caller to know the response type in advance. + """ + raise NotImplementedError + + +class RpcController(object): + + """An RpcController mediates a single method call. + + The primary purpose of the controller is to provide a way to manipulate + settings specific to the RPC implementation and to find out about RPC-level + errors. The methods provided by the RpcController interface are intended + to be a "least common denominator" set of features which we expect all + implementations to support. Specific implementations may provide more + advanced features (e.g. deadline propagation). + """ + + # Client-side methods below + + def Reset(self): + """Resets the RpcController to its initial state. + + After the RpcController has been reset, it may be reused in + a new call. Must not be called while an RPC is in progress. + """ + raise NotImplementedError + + def Failed(self): + """Returns true if the call failed. + + After a call has finished, returns true if the call failed. The possible + reasons for failure depend on the RPC implementation. Failed() must not + be called before a call has finished. If Failed() returns true, the + contents of the response message are undefined. + """ + raise NotImplementedError + + def ErrorText(self): + """If Failed is true, returns a human-readable description of the error.""" + raise NotImplementedError + + def StartCancel(self): + """Initiate cancellation. + + Advises the RPC system that the caller desires that the RPC call be + canceled. The RPC system may cancel it immediately, may wait awhile and + then cancel it, or may not even cancel the call at all. If the call is + canceled, the "done" callback will still be called and the RpcController + will indicate that the call failed at that time. + """ + raise NotImplementedError + + # Server-side methods below + + def SetFailed(self, reason): + """Sets a failure reason. + + Causes Failed() to return true on the client side. "reason" will be + incorporated into the message returned by ErrorText(). If you find + you need to return machine-readable information about failures, you + should incorporate it into your response protocol buffer and should + NOT call SetFailed(). + """ + raise NotImplementedError + + def IsCanceled(self): + """Checks if the client cancelled the RPC. + + If true, indicates that the client canceled the RPC, so the server may + as well give up on replying to it. The server should still call the + final "done" callback. + """ + raise NotImplementedError + + def NotifyOnCancel(self, callback): + """Sets a callback to invoke on cancel. + + Asks that the given callback be called when the RPC is canceled. The + callback will always be called exactly once. If the RPC completes without + being canceled, the callback will be called after completion. If the RPC + has already been canceled when NotifyOnCancel() is called, the callback + will be called immediately. + + NotifyOnCancel() must be called no more than once per request. + """ + raise NotImplementedError + + +class RpcChannel(object): + + """Abstract interface for an RPC channel. + + An RpcChannel represents a communication line to a service which can be used + to call that service's methods. The service may be running on another + machine. Normally, you should not use an RpcChannel directly, but instead + construct a stub {@link Service} wrapping it. Example: + + Example: + RpcChannel channel = rpcImpl.Channel("remotehost.example.com:1234") + RpcController controller = rpcImpl.Controller() + MyService service = MyService_Stub(channel) + service.MyMethod(controller, request, callback) + """ + + def CallMethod(self, method_descriptor, rpc_controller, + request, response_class, done): + """Calls the method identified by the descriptor. + + Call the given method of the remote service. The signature of this + procedure looks the same as Service.CallMethod(), but the requirements + are less strict in one important way: the request object doesn't have to + be of any specific class as long as its descriptor is method.input_type. + """ + raise NotImplementedError diff --git a/froofle/protobuf/service_reflection.py b/froofle/protobuf/service_reflection.py new file mode 100644 index 00000000..bdd6bad5 --- /dev/null +++ b/froofle/protobuf/service_reflection.py @@ -0,0 +1,289 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains metaclasses used to create protocol service and service stub +classes from ServiceDescriptor objects at runtime. + +The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to +inject all useful functionality into the classes output by the protocol +compiler at compile-time. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class GeneratedServiceType(type): + + """Metaclass for service classes created at runtime from ServiceDescriptors. + + Implementations for all methods described in the Service class are added here + by this class. We also create properties to allow getting/setting all fields + in the protocol message. + + The protocol compiler currently uses this metaclass to create protocol service + classes at runtime. Clients can also manually create their own classes at + runtime, as in this example: + + mydescriptor = ServiceDescriptor(.....) + class MyProtoService(service.Service): + __metaclass__ = GeneratedServiceType + DESCRIPTOR = mydescriptor + myservice_instance = MyProtoService() + ... + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service class. + + Args: + name: Name of the class (ignored, but required by the metaclass + protocol). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service class is subclassed. + if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY] + service_builder = _ServiceBuilder(descriptor) + service_builder.BuildService(cls) + + +class GeneratedServiceStubType(GeneratedServiceType): + + """Metaclass for service stubs created at runtime from ServiceDescriptors. + + This class has similar responsibilities as GeneratedServiceType, except that + it creates the service stub classes. + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service stub class. + + Args: + name: Name of the class (ignored, here). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary) + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service stub is subclassed. + if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY] + service_stub_builder = _ServiceStubBuilder(descriptor) + service_stub_builder.BuildServiceStub(cls) + + +class _ServiceBuilder(object): + + """This class constructs a protocol service class using a service descriptor. + + Given a service descriptor, this class constructs a class that represents + the specified service descriptor. One service builder instance constructs + exactly one service class. That means all instances of that class share the + same builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + service class. + """ + self.descriptor = service_descriptor + + def BuildService(self, cls): + """Constructs the service class. + + Args: + cls: The class that will be constructed. + """ + + # CallMethod needs to operate with an instance of the Service class. This + # internal wrapper function exists only to be able to pass the service + # instance to the method that does the real CallMethod work. + def _WrapCallMethod(srvc, method_descriptor, + rpc_controller, request, callback): + self._CallMethod(srvc, method_descriptor, + rpc_controller, request, callback) + self.cls = cls + cls.CallMethod = _WrapCallMethod + cls.GetDescriptor = self._GetDescriptor + cls.GetRequestClass = self._GetRequestClass + cls.GetResponseClass = self._GetResponseClass + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateNonImplementedMethod(method)) + + def _GetDescriptor(self): + """Retrieves the service descriptor. + + Returns: + The descriptor of the service (of type ServiceDescriptor). + """ + return self.descriptor + + def _CallMethod(self, srvc, method_descriptor, + rpc_controller, request, callback): + """Calls the method described by a given method descriptor. + + Args: + srvc: Instance of the service for which this method is called. + method_descriptor: Descriptor that represent the method to call. + rpc_controller: RPC controller to use for this method's execution. + request: Request protocol message. + callback: A callback to invoke after the method has completed. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'CallMethod() given method descriptor for wrong service type.') + method = getattr(srvc, method_descriptor.name) + method(rpc_controller, request, callback) + + def _GetRequestClass(self, method_descriptor): + """Returns the class of the request protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + request protocol message class. + + Returns: + A class that represents the input protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetRequestClass() given method descriptor for wrong service type.') + return method_descriptor.input_type._concrete_class + + def _GetResponseClass(self, method_descriptor): + """Returns the class of the response protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + response protocol message class. + + Returns: + A class that represents the output protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetResponseClass() given method descriptor for wrong service type.') + return method_descriptor.output_type._concrete_class + + def _GenerateNonImplementedMethod(self, method): + """Generates and returns a method that can be set for a service methods. + + Args: + method: Descriptor of the service method for which a method is to be + generated. + + Returns: + A method that can be added to the service class. + """ + return lambda inst, rpc_controller, request, callback: ( + self._NonImplementedMethod(method.name, rpc_controller, callback)) + + def _NonImplementedMethod(self, method_name, rpc_controller, callback): + """The body of all methods in the generated service class. + + Args: + method_name: Name of the method being executed. + rpc_controller: RPC controller used to execute this method. + callback: A callback which will be invoked when the method finishes. + """ + rpc_controller.SetFailed('Method %s not implemented.' % method_name) + callback(None) + + +class _ServiceStubBuilder(object): + + """Constructs a protocol service stub class using a service descriptor. + + Given a service descriptor, this class constructs a suitable stub class. + A stub is just a type-safe wrapper around an RpcChannel which emulates a + local implementation of the service. + + One service stub builder instance constructs exactly one class. It means all + instances of that class share the same service stub builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service stub class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + stub class. + """ + self.descriptor = service_descriptor + + def BuildServiceStub(self, cls): + """Constructs the stub class. + + Args: + cls: The class that will be constructed. + """ + + def _ServiceStubInit(stub, rpc_channel): + stub.rpc_channel = rpc_channel + self.cls = cls + cls.__init__ = _ServiceStubInit + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateStubMethod(method)) + + def _GenerateStubMethod(self, method): + return lambda inst, rpc_controller, request, callback: self._StubMethod( + inst, method, rpc_controller, request, callback) + + def _StubMethod(self, stub, method_descriptor, + rpc_controller, request, callback): + """The body of all service methods in the generated stub class. + + Args: + stub: Stub instance. + method_descriptor: Descriptor of the invoked method. + rpc_controller: Rpc controller to execute the method. + request: Request protocol message. + callback: A callback to execute when the method finishes. + """ + stub.rpc_channel.CallMethod( + method_descriptor, rpc_controller, request, + method_descriptor.output_type._concrete_class, callback) diff --git a/froofle/protobuf/text_format.py b/froofle/protobuf/text_format.py new file mode 100644 index 00000000..1c4cadfc --- /dev/null +++ b/froofle/protobuf/text_format.py @@ -0,0 +1,125 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in text format.""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import cStringIO + +from froofle.protobuf import descriptor + +__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue' ] + +def MessageToString(message): + out = cStringIO.StringIO() + PrintMessage(message, out) + result = out.getvalue() + out.close() + return result + +def PrintMessage(message, out, indent = 0): + for field, value in message.ListFields(): + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + PrintField(field, element, out, indent) + else: + PrintField(field, value, out, indent) + +def PrintField(field, value, out, indent = 0): + """Print a single field name/value pair. For repeated fields, the value + should be a single element.""" + + out.write(' ' * indent); + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') + + PrintFieldValue(field, value, out, indent) + out.write('\n') + +def PrintFieldValue(field, value, out, indent = 0): + """Print a single field value (not including name). For repeated fields, + the value should be a single element.""" + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + out.write(' {\n') + PrintMessage(value, out, indent + 2) + out.write(' ' * indent + '}') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + out.write(field.enum_type.values_by_number[value].name) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + out.write(_CEscape(value)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write("true") + else: + out.write("false") + else: + out.write(str(value)) + +# text.encode('string_escape') does not seem to satisfy our needs as it +# encodes unprintable characters using two-digit hex escapes whereas our +# C++ unescaping function allows hex escapes to be any length. So, +# "\0011".encode('string_escape') ends up being "\\x011", which will be +# decoded in C++ as a single-character string with char code 0x11. +def _CEscape(text): + def escape(c): + o = ord(c) + if o == 10: return r"\n" # optional escape + if o == 13: return r"\r" # optional escape + if o == 9: return r"\t" # optional escape + if o == 39: return r"\'" # optional escape + + if o == 34: return r'\"' # necessary escape + if o == 92: return r"\\" # necessary escape + + if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes + return c + return "".join([escape(c) for c in text]) diff --git a/gerrit_upload.py b/gerrit_upload.py new file mode 100755 index 00000000..a49fb232 --- /dev/null +++ b/gerrit_upload.py @@ -0,0 +1,156 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import getpass +import os +import subprocess +import sys +from tempfile import mkstemp + +from codereview.proto_client import HttpRpc, Proxy +from codereview.review_pb2 import ReviewService_Stub +from codereview.upload_bundle_pb2 import * +from git_command import GitCommand +from error import UploadError + +try: + import readline +except ImportError: + pass + +MAX_SEGMENT_SIZE = 1020 * 1024 + +def _GetRpcServer(email, server, save_cookies): + """Returns an RpcServer. + + Returns: + A new RpcServer, on which RPC calls can be made. + """ + + def GetUserCredentials(): + """Prompts the user for a username and password.""" + e = email + if e is None: + e = raw_input("Email: ").strip() + password = getpass.getpass("Password for %s: " % e) + return (e, password) + + # If this is the dev_appserver, use fake authentication. + lc_server = server.lower() + if lc_server == "localhost" or lc_server.startswith("localhost:"): + if email is None: + email = "test@example.com" + server = HttpRpc( + server, + lambda: (email, "password"), + extra_headers={"Cookie": + 'dev_appserver_login="%s:False"' % email}) + # Don't try to talk to ClientLogin. + server.authenticated = True + return server + + if save_cookies: + cookie_file = ".gerrit_cookies" + else: + cookie_file = None + + return HttpRpc(server, GetUserCredentials, + cookie_file=cookie_file) + +def UploadBundle(project, + server, + email, + dest_project, + dest_branch, + src_branch, + bases, + save_cookies=True): + + srv = _GetRpcServer(email, server, save_cookies) + review = Proxy(ReviewService_Stub(srv)) + tmp_fd, tmp_bundle = mkstemp(".bundle", ".gpq") + os.close(tmp_fd) + + srcid = project.bare_git.rev_parse(src_branch) + revlist = project._revlist(src_branch, *bases) + + if srcid not in revlist: + # This can happen if src_branch is an annotated tag + # + revlist.append(srcid) + revlist_size = len(revlist) * 42 + + try: + cmd = ['bundle', 'create', tmp_bundle, src_branch] + cmd.extend(bases) + if GitCommand(project, cmd).Wait() != 0: + raise UploadError('cannot create bundle') + fd = open(tmp_bundle, "rb") + + bundle_id = None + segment_id = 0 + next_data = fd.read(MAX_SEGMENT_SIZE - revlist_size) + + while True: + this_data = next_data + next_data = fd.read(MAX_SEGMENT_SIZE) + segment_id += 1 + + if bundle_id is None: + req = UploadBundleRequest() + req.dest_project = str(dest_project) + req.dest_branch = str(dest_branch) + for c in revlist: + req.contained_object.append(c) + else: + req = UploadBundleContinue() + req.bundle_id = bundle_id + req.segment_id = segment_id + + req.bundle_data = this_data + if len(next_data) > 0: + req.partial_upload = True + else: + req.partial_upload = False + + if bundle_id is None: + rsp = review.UploadBundle(req) + else: + rsp = review.ContinueBundle(req) + + if rsp.status_code == UploadBundleResponse.CONTINUE: + bundle_id = rsp.bundle_id + elif rsp.status_code == UploadBundleResponse.RECEIVED: + bundle_id = rsp.bundle_id + return bundle_id + else: + if rsp.status_code == UploadBundleResponse.UNKNOWN_PROJECT: + reason = 'unknown project "%s"' % dest_project + elif rsp.status_code == UploadBundleResponse.UNKNOWN_BRANCH: + reason = 'unknown branch "%s"' % dest_branch + elif rsp.status_code == UploadBundleResponse.UNKNOWN_BUNDLE: + reason = 'unknown bundle' + elif rsp.status_code == UploadBundleResponse.NOT_BUNDLE_OWNER: + reason = 'not bundle owner' + elif rsp.status_code == UploadBundleResponse.BUNDLE_CLOSED: + reason = 'bundle closed' + elif rsp.status_code == UploadBundleResponse.UNAUTHORIZED_USER: + reason = ('Unauthorized user. Visit http://%s/hello to sign up.' + % server) + else: + reason = 'unknown error ' + str(rsp.status_code) + raise UploadError(reason) + finally: + os.unlink(tmp_bundle) diff --git a/git_command.py b/git_command.py new file mode 100644 index 00000000..a3bd9192 --- /dev/null +++ b/git_command.py @@ -0,0 +1,164 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess +from error import GitError + +GIT = 'git' +MIN_GIT_VERSION = (1, 5, 4) +GIT_DIR = 'GIT_DIR' +REPO_TRACE = 'REPO_TRACE' + +LAST_GITDIR = None +LAST_CWD = None +try: + TRACE = os.environ[REPO_TRACE] == '1' +except KeyError: + TRACE = False + + +class _GitCall(object): + def version(self): + p = GitCommand(None, ['--version'], capture_stdout=True) + if p.Wait() == 0: + return p.stdout + return None + + def __getattr__(self, name): + name = name.replace('_','-') + def fun(*cmdv): + command = [name] + command.extend(cmdv) + return GitCommand(None, command).Wait() == 0 + return fun +git = _GitCall() + +class GitCommand(object): + def __init__(self, + project, + cmdv, + bare = False, + provide_stdin = False, + capture_stdout = False, + capture_stderr = False, + disable_editor = False, + cwd = None, + gitdir = None): + env = dict(os.environ) + + for e in [REPO_TRACE, + GIT_DIR, + 'GIT_ALTERNATE_OBJECT_DIRECTORIES', + 'GIT_OBJECT_DIRECTORY', + 'GIT_WORK_TREE', + 'GIT_GRAFT_FILE', + 'GIT_INDEX_FILE']: + if e in env: + del env[e] + + if disable_editor: + env['GIT_EDITOR'] = ':' + + if project: + if not cwd: + cwd = project.worktree + if not gitdir: + gitdir = project.gitdir + + command = [GIT] + if bare: + if gitdir: + env[GIT_DIR] = gitdir + cwd = None + command.extend(cmdv) + + if provide_stdin: + stdin = subprocess.PIPE + else: + stdin = None + + if capture_stdout: + stdout = subprocess.PIPE + else: + stdout = None + + if capture_stderr: + stderr = subprocess.PIPE + else: + stderr = None + + if TRACE: + global LAST_CWD + global LAST_GITDIR + + dbg = '' + + if cwd and LAST_CWD != cwd: + if LAST_GITDIR or LAST_CWD: + dbg += '\n' + dbg += ': cd %s\n' % cwd + LAST_CWD = cwd + + if GIT_DIR in env and LAST_GITDIR != env[GIT_DIR]: + if LAST_GITDIR or LAST_CWD: + dbg += '\n' + dbg += ': export GIT_DIR=%s\n' % env[GIT_DIR] + LAST_GITDIR = env[GIT_DIR] + + dbg += ': ' + dbg += ' '.join(command) + if stdin == subprocess.PIPE: + dbg += ' 0<|' + if stdout == subprocess.PIPE: + dbg += ' 1>|' + if stderr == subprocess.PIPE: + dbg += ' 2>|' + print >>sys.stderr, dbg + + try: + p = subprocess.Popen(command, + cwd = cwd, + env = env, + stdin = stdin, + stdout = stdout, + stderr = stderr) + except Exception, e: + raise GitError('%s: %s' % (command[1], e)) + + self.process = p + self.stdin = p.stdin + + def Wait(self): + p = self.process + + if p.stdin: + p.stdin.close() + self.stdin = None + + if p.stdout: + self.stdout = p.stdout.read() + p.stdout.close() + else: + p.stdout = None + + if p.stderr: + self.stderr = p.stderr.read() + p.stderr.close() + else: + p.stderr = None + + return self.process.wait() diff --git a/git_config.py b/git_config.py new file mode 100644 index 00000000..f6c5bd1e --- /dev/null +++ b/git_config.py @@ -0,0 +1,344 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import sys +from error import GitError +from git_command import GitCommand + +R_HEADS = 'refs/heads/' +R_TAGS = 'refs/tags/' +ID_RE = re.compile('^[0-9a-f]{40}$') + +def IsId(rev): + return ID_RE.match(rev) + + +class GitConfig(object): + @classmethod + def ForUser(cls): + return cls(file = os.path.expanduser('~/.gitconfig')) + + @classmethod + def ForRepository(cls, gitdir, defaults=None): + return cls(file = os.path.join(gitdir, 'config'), + defaults = defaults) + + def __init__(self, file, defaults=None): + self.file = file + self.defaults = defaults + self._cache_dict = None + self._remotes = {} + self._branches = {} + + def Has(self, name, include_defaults = True): + """Return true if this configuration file has the key. + """ + name = name.lower() + if name in self._cache: + return True + if include_defaults and self.defaults: + return self.defaults.Has(name, include_defaults = True) + return False + + def GetBoolean(self, name): + """Returns a boolean from the configuration file. + None : The value was not defined, or is not a boolean. + True : The value was set to true or yes. + False: The value was set to false or no. + """ + v = self.GetString(name) + if v is None: + return None + v = v.lower() + if v in ('true', 'yes'): + return True + if v in ('false', 'no'): + return False + return None + + def GetString(self, name, all=False): + """Get the first value for a key, or None if it is not defined. + + This configuration file is used first, if the key is not + defined or all = True then the defaults are also searched. + """ + name = name.lower() + + try: + v = self._cache[name] + except KeyError: + if self.defaults: + return self.defaults.GetString(name, all = all) + v = [] + + if not all: + if v: + return v[0] + return None + + r = [] + r.extend(v) + if self.defaults: + r.extend(self.defaults.GetString(name, all = True)) + return r + + def SetString(self, name, value): + """Set the value(s) for a key. + Only this configuration file is modified. + + The supplied value should be either a string, + or a list of strings (to store multiple values). + """ + name = name.lower() + + try: + old = self._cache[name] + except KeyError: + old = [] + + if value is None: + if old: + del self._cache[name] + self._do('--unset-all', name) + + elif isinstance(value, list): + if len(value) == 0: + self.SetString(name, None) + + elif len(value) == 1: + self.SetString(name, value[0]) + + elif old != value: + self._cache[name] = list(value) + self._do('--replace-all', name, value[0]) + for i in xrange(1, len(value)): + self._do('--add', name, value[i]) + + elif len(old) != 1 or old[0] != value: + self._cache[name] = [value] + self._do('--replace-all', name, value) + + def GetRemote(self, name): + """Get the remote.$name.* configuration values as an object. + """ + try: + r = self._remotes[name] + except KeyError: + r = Remote(self, name) + self._remotes[r.name] = r + return r + + def GetBranch(self, name): + """Get the branch.$name.* configuration values as an object. + """ + try: + b = self._branches[name] + except KeyError: + b = Branch(self, name) + self._branches[b.name] = b + return b + + @property + def _cache(self): + if self._cache_dict is None: + self._cache_dict = self._Read() + return self._cache_dict + + def _Read(self): + d = self._do('--null', '--list') + c = {} + while d: + lf = d.index('\n') + nul = d.index('\0', lf + 1) + + key = d[0:lf] + val = d[lf + 1:nul] + + if key in c: + c[key].append(val) + else: + c[key] = [val] + + d = d[nul + 1:] + return c + + def _do(self, *args): + command = ['config', '--file', self.file] + command.extend(args) + + p = GitCommand(None, + command, + capture_stdout = True, + capture_stderr = True) + if p.Wait() == 0: + return p.stdout + else: + GitError('git config %s: %s' % (str(args), p.stderr)) + + +class RefSpec(object): + """A Git refspec line, split into its components: + + forced: True if the line starts with '+' + src: Left side of the line + dst: Right side of the line + """ + + @classmethod + def FromString(cls, rs): + lhs, rhs = rs.split(':', 2) + if lhs.startswith('+'): + lhs = lhs[1:] + forced = True + else: + forced = False + return cls(forced, lhs, rhs) + + def __init__(self, forced, lhs, rhs): + self.forced = forced + self.src = lhs + self.dst = rhs + + def SourceMatches(self, rev): + if self.src: + if rev == self.src: + return True + if self.src.endswith('/*') and rev.startswith(self.src[:-1]): + return True + return False + + def DestMatches(self, ref): + if self.dst: + if ref == self.dst: + return True + if self.dst.endswith('/*') and ref.startswith(self.dst[:-1]): + return True + return False + + def MapSource(self, rev): + if self.src.endswith('/*'): + return self.dst[:-1] + rev[len(self.src) - 1:] + return self.dst + + def __str__(self): + s = '' + if self.forced: + s += '+' + if self.src: + s += self.src + if self.dst: + s += ':' + s += self.dst + return s + + +class Remote(object): + """Configuration options related to a remote. + """ + def __init__(self, config, name): + self._config = config + self.name = name + self.url = self._Get('url') + self.review = self._Get('review') + self.fetch = map(lambda x: RefSpec.FromString(x), + self._Get('fetch', all=True)) + + def ToLocal(self, rev): + """Convert a remote revision string to something we have locally. + """ + if IsId(rev): + return rev + if rev.startswith(R_TAGS): + return rev + + if not rev.startswith('refs/'): + rev = R_HEADS + rev + + for spec in self.fetch: + if spec.SourceMatches(rev): + return spec.MapSource(rev) + raise GitError('remote %s does not have %s' % (self.name, rev)) + + def WritesTo(self, ref): + """True if the remote stores to the tracking ref. + """ + for spec in self.fetch: + if spec.DestMatches(ref): + return True + return False + + def ResetFetch(self): + """Set the fetch refspec to its default value. + """ + self.fetch = [RefSpec(True, + 'refs/heads/*', + 'refs/remotes/%s/*' % self.name)] + + def Save(self): + """Save this remote to the configuration. + """ + self._Set('url', self.url) + self._Set('review', self.review) + self._Set('fetch', map(lambda x: str(x), self.fetch)) + + def _Set(self, key, value): + key = 'remote.%s.%s' % (self.name, key) + return self._config.SetString(key, value) + + def _Get(self, key, all=False): + key = 'remote.%s.%s' % (self.name, key) + return self._config.GetString(key, all = all) + + +class Branch(object): + """Configuration options related to a single branch. + """ + def __init__(self, config, name): + self._config = config + self.name = name + self.merge = self._Get('merge') + + r = self._Get('remote') + if r: + self.remote = self._config.GetRemote(r) + else: + self.remote = None + + @property + def LocalMerge(self): + """Convert the merge spec to a local name. + """ + if self.remote and self.merge: + return self.remote.ToLocal(self.merge) + return None + + def Save(self): + """Save this branch back into the configuration. + """ + self._Set('merge', self.merge) + if self.remote: + self._Set('remote', self.remote.name) + else: + self._Set('remote', None) + + def _Set(self, key, value): + key = 'branch.%s.%s' % (self.name, key) + return self._config.SetString(key, value) + + def _Get(self, key, all=False): + key = 'branch.%s.%s' % (self.name, key) + return self._config.GetString(key, all = all) diff --git a/import_ext.py b/import_ext.py new file mode 100644 index 00000000..2a1ebf88 --- /dev/null +++ b/import_ext.py @@ -0,0 +1,422 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import stat +import sys +import urllib2 +import StringIO + +from error import GitError, ImportError +from git_command import GitCommand + +class ImportExternal(object): + """Imports a single revision from a non-git data source. + Suitable for use to import a tar or zip based snapshot. + """ + def __init__(self): + self._marks = 0 + self._files = {} + self._tempref = 'refs/repo-external/import' + + self._urls = [] + self._remap = [] + self.parent = None + self._user_name = 'Upstream' + self._user_email = 'upstream-import@none' + self._user_when = 1000000 + + self.commit = None + + def Clone(self): + r = self.__class__() + + r.project = self.project + for u in self._urls: + r._urls.append(u) + for p in self._remap: + r._remap.append(_PathMap(r, p._old, p._new)) + + return r + + def SetProject(self, project): + self.project = project + + def SetVersion(self, version): + self.version = version + + def AddUrl(self, url): + self._urls.append(url) + + def SetParent(self, commit_hash): + self.parent = commit_hash + + def SetCommit(self, commit_hash): + self.commit = commit_hash + + def RemapPath(self, old, new, replace_version=True): + self._remap.append(_PathMap(self, old, new)) + + @property + def TagName(self): + v = '' + for c in self.version: + if c >= '0' and c <= '9': + v += c + elif c >= 'A' and c <= 'Z': + v += c + elif c >= 'a' and c <= 'z': + v += c + elif c in ('-', '_', '.', '/', '+', '@'): + v += c + return 'upstream/%s' % v + + @property + def PackageName(self): + n = self.project.name + if n.startswith('platform/'): + # This was not my finest moment... + # + n = n[len('platform/'):] + return n + + def Import(self): + self._need_graft = False + if self.parent: + try: + self.project.bare_git.cat_file('-e', self.parent) + except GitError: + self._need_graft = True + + gfi = GitCommand(self.project, + ['fast-import', '--force', '--quiet'], + bare = True, + provide_stdin = True) + try: + self._out = gfi.stdin + + try: + self._UnpackFiles() + self._MakeCommit() + self._out.flush() + finally: + rc = gfi.Wait() + if rc != 0: + raise ImportError('fast-import failed') + + if self._need_graft: + id = self._GraftCommit() + else: + id = self.project.bare_git.rev_parse('%s^0' % self._tempref) + + if self.commit and self.commit != id: + raise ImportError('checksum mismatch: %s expected,' + ' %s imported' % (self.commit, id)) + + self._MakeTag(id) + return id + finally: + try: + self.project.bare_git.DeleteRef(self._tempref) + except GitError: + pass + + def _PickUrl(self, failed): + u = map(lambda x: x.replace('%version%', self.version), self._urls) + for f in failed: + if f in u: + u.remove(f) + if len(u) == 0: + return None + return random.choice(u) + + def _OpenUrl(self): + failed = {} + while True: + url = self._PickUrl(failed.keys()) + if url is None: + why = 'Cannot download %s' % self.project.name + + if failed: + why += ': one or more mirrors are down\n' + bad_urls = list(failed.keys()) + bad_urls.sort() + for url in bad_urls: + why += ' %s: %s\n' % (url, failed[url]) + else: + why += ': no mirror URLs' + raise ImportError(why) + + print >>sys.stderr, "Getting %s ..." % url + try: + return urllib2.urlopen(url), url + except urllib2.HTTPError, e: + failed[url] = e.code + except urllib2.URLError, e: + failed[url] = e.reason[1] + except OSError, e: + failed[url] = e.strerror + + def _UnpackFiles(self): + raise NotImplementedError + + def _NextMark(self): + self._marks += 1 + return self._marks + + def _UnpackOneFile(self, mode, size, name, fd): + if stat.S_ISDIR(mode): # directory + return + else: + mode = self._CleanMode(mode, name) + + old_name = name + name = self._CleanName(name) + + if stat.S_ISLNK(mode) and self._remap: + # The link is relative to the old_name, and may need to + # be rewritten according to our remap rules if it goes + # up high enough in the tree structure. + # + dest = self._RewriteLink(fd.read(size), old_name, name) + fd = StringIO.StringIO(dest) + size = len(dest) + + fi = _File(mode, name, self._NextMark()) + + self._out.write('blob\n') + self._out.write('mark :%d\n' % fi.mark) + self._out.write('data %d\n' % size) + while size > 0: + n = min(2048, size) + self._out.write(fd.read(n)) + size -= n + self._out.write('\n') + self._files[fi.name] = fi + + def _SetFileMode(self, name, mode): + if not stat.S_ISDIR(mode): + mode = self._CleanMode(mode, name) + name = self._CleanName(name) + try: + fi = self._files[name] + except KeyError: + raise ImportError('file %s was not unpacked' % name) + fi.mode = mode + + def _RewriteLink(self, dest, relto_old, relto_new): + # Drop the last components of the symlink itself + # as the dest is relative to the directory its in. + # + relto_old = _TrimPath(relto_old) + relto_new = _TrimPath(relto_new) + + # Resolve the link to be absolute from the top of + # the archive, so we can remap its destination. + # + while dest.find('/./') >= 0 or dest.find('//') >= 0: + dest = dest.replace('/./', '/') + dest = dest.replace('//', '/') + + if dest.startswith('../') or dest.find('/../') > 0: + dest = _FoldPath('%s/%s' % (relto_old, dest)) + + for pm in self._remap: + if pm.Matches(dest): + dest = pm.Apply(dest) + break + + dest, relto_new = _StripCommonPrefix(dest, relto_new) + while relto_new: + i = relto_new.find('/') + if i > 0: + relto_new = relto_new[i + 1:] + else: + relto_new = '' + dest = '../' + dest + return dest + + def _CleanMode(self, mode, name): + if stat.S_ISREG(mode): # regular file + if (mode & 0111) == 0: + return 0644 + else: + return 0755 + elif stat.S_ISLNK(mode): # symlink + return stat.S_IFLNK + else: + raise ImportError('invalid mode %o in %s' % (mode, name)) + + def _CleanName(self, name): + old_name = name + for pm in self._remap: + if pm.Matches(name): + name = pm.Apply(name) + break + while name.startswith('/'): + name = name[1:] + if not name: + raise ImportError('path %s is empty after remap' % old_name) + if name.find('/./') >= 0 or name.find('/../') >= 0: + raise ImportError('path %s contains relative parts' % name) + return name + + def _MakeCommit(self): + msg = '%s %s\n' % (self.PackageName, self.version) + + self._out.write('commit %s\n' % self._tempref) + self._out.write('committer %s <%s> %d +0000\n' % ( + self._user_name, + self._user_email, + self._user_when)) + self._out.write('data %d\n' % len(msg)) + self._out.write(msg) + self._out.write('\n') + if self.parent and not self._need_graft: + self._out.write('from %s^0\n' % self.parent) + self._out.write('deleteall\n') + + for f in self._files.values(): + self._out.write('M %o :%d %s\n' % (f.mode, f.mark, f.name)) + self._out.write('\n') + + def _GraftCommit(self): + raw = self.project.bare_git.cat_file('commit', self._tempref) + raw = raw.split("\n") + while raw[1].startswith('parent '): + del raw[1] + raw.insert(1, 'parent %s' % self.parent) + id = self._WriteObject('commit', "\n".join(raw)) + + graft_file = os.path.join(self.project.gitdir, 'info/grafts') + if os.path.exists(graft_file): + graft_list = open(graft_file, 'rb').read().split("\n") + if graft_list and graft_list[-1] == '': + del graft_list[-1] + else: + graft_list = [] + + exists = False + for line in graft_list: + if line == id: + exists = True + break + + if not exists: + graft_list.append(id) + graft_list.append('') + fd = open(graft_file, 'wb') + fd.write("\n".join(graft_list)) + fd.close() + + return id + + def _MakeTag(self, id): + name = self.TagName + + raw = [] + raw.append('object %s' % id) + raw.append('type commit') + raw.append('tag %s' % name) + raw.append('tagger %s <%s> %d +0000' % ( + self._user_name, + self._user_email, + self._user_when)) + raw.append('') + raw.append('%s %s\n' % (self.PackageName, self.version)) + + tagid = self._WriteObject('tag', "\n".join(raw)) + self.project.bare_git.UpdateRef('refs/tags/%s' % name, tagid) + + def _WriteObject(self, type, data): + wo = GitCommand(self.project, + ['hash-object', '-t', type, '-w', '--stdin'], + bare = True, + provide_stdin = True, + capture_stdout = True, + capture_stderr = True) + wo.stdin.write(data) + if wo.Wait() != 0: + raise GitError('cannot create %s from (%s)' % (type, data)) + return wo.stdout[:-1] + + +def _TrimPath(path): + i = path.rfind('/') + if i > 0: + path = path[0:i] + return '' + +def _StripCommonPrefix(a, b): + while True: + ai = a.find('/') + bi = b.find('/') + if ai > 0 and bi > 0 and a[0:ai] == b[0:bi]: + a = a[ai + 1:] + b = b[bi + 1:] + else: + break + return a, b + +def _FoldPath(path): + while True: + if path.startswith('../'): + return path + + i = path.find('/../') + if i <= 0: + if path.startswith('/'): + return path[1:] + return path + + lhs = path[0:i] + rhs = path[i + 4:] + + i = lhs.rfind('/') + if i > 0: + path = lhs[0:i + 1] + rhs + else: + path = rhs + +class _File(object): + def __init__(self, mode, name, mark): + self.mode = mode + self.name = name + self.mark = mark + + +class _PathMap(object): + def __init__(self, imp, old, new): + self._imp = imp + self._old = old + self._new = new + + def _r(self, p): + return p.replace('%version%', self._imp.version) + + @property + def old(self): + return self._r(self._old) + + @property + def new(self): + return self._r(self._new) + + def Matches(self, name): + return name.startswith(self.old) + + def Apply(self, name): + return self.new + name[len(self.old):] diff --git a/import_tar.py b/import_tar.py new file mode 100644 index 00000000..d7ce14de --- /dev/null +++ b/import_tar.py @@ -0,0 +1,206 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bz2 +import stat +import tarfile +import zlib +import StringIO + +from import_ext import ImportExternal +from error import ImportError + +class ImportTar(ImportExternal): + """Streams a (optionally compressed) tar file from the network + directly into a Project's Git repository. + """ + @classmethod + def CanAccept(cls, url): + """Can this importer read and unpack the data stored at url? + """ + if url.endswith('.tar.gz') or url.endswith('.tgz'): + return True + if url.endswith('.tar.bz2'): + return True + if url.endswith('.tar'): + return True + return False + + def _UnpackFiles(self): + url_fd, url = self._OpenUrl() + try: + if url.endswith('.tar.gz') or url.endswith('.tgz'): + tar_fd = _Gzip(url_fd) + elif url.endswith('.tar.bz2'): + tar_fd = _Bzip2(url_fd) + elif url.endswith('.tar'): + tar_fd = _Raw(url_fd) + else: + raise ImportError('non-tar file extension: %s' % url) + + try: + tar = tarfile.TarFile(name = url, + mode = 'r', + fileobj = tar_fd) + try: + for entry in tar: + mode = entry.mode + + if (mode & 0170000) == 0: + if entry.isdir(): + mode |= stat.S_IFDIR + elif entry.isfile() or entry.islnk(): # hard links as files + mode |= stat.S_IFREG + elif entry.issym(): + mode |= stat.S_IFLNK + + if stat.S_ISLNK(mode): # symlink + data_fd = StringIO.StringIO(entry.linkname) + data_sz = len(entry.linkname) + elif stat.S_ISDIR(mode): # directory + data_fd = StringIO.StringIO('') + data_sz = 0 + else: + data_fd = tar.extractfile(entry) + data_sz = entry.size + + self._UnpackOneFile(mode, data_sz, entry.name, data_fd) + finally: + tar.close() + finally: + tar_fd.close() + finally: + url_fd.close() + + + +class _DecompressStream(object): + """file like object to decompress a tar stream + """ + def __init__(self, fd): + self._fd = fd + self._pos = 0 + self._buf = None + + def tell(self): + return self._pos + + def seek(self, offset): + d = offset - self._pos + if d > 0: + self.read(d) + elif d == 0: + pass + else: + raise NotImplementedError, 'seek backwards' + + def close(self): + self._fd = None + + def read(self, size = -1): + if not self._fd: + raise EOFError, 'Reached EOF' + + r = [] + try: + if size >= 0: + self._ReadChunk(r, size) + else: + while True: + self._ReadChunk(r, 2048) + except EOFError: + pass + + if len(r) == 1: + r = r[0] + else: + r = ''.join(r) + self._pos += len(r) + return r + + def _ReadChunk(self, r, size): + b = self._buf + try: + while size > 0: + if b is None or len(b) == 0: + b = self._Decompress(self._fd.read(2048)) + continue + + use = min(size, len(b)) + r.append(b[:use]) + b = b[use:] + size -= use + finally: + self._buf = b + + def _Decompress(self, b): + raise NotImplementedError, '_Decompress' + + +class _Raw(_DecompressStream): + """file like object for an uncompressed stream + """ + def __init__(self, fd): + _DecompressStream.__init__(self, fd) + + def _Decompress(self, b): + return b + + +class _Bzip2(_DecompressStream): + """file like object to decompress a .bz2 stream + """ + def __init__(self, fd): + _DecompressStream.__init__(self, fd) + self._bz = bz2.BZ2Decompressor() + + def _Decompress(self, b): + return self._bz.decompress(b) + + +_FHCRC, _FEXTRA, _FNAME, _FCOMMENT = 2, 4, 8, 16 +class _Gzip(_DecompressStream): + """file like object to decompress a .gz stream + """ + def __init__(self, fd): + _DecompressStream.__init__(self, fd) + self._z = zlib.decompressobj(-zlib.MAX_WBITS) + + magic = fd.read(2) + if magic != '\037\213': + raise IOError, 'Not a gzipped file' + + method = ord(fd.read(1)) + if method != 8: + raise IOError, 'Unknown compression method' + + flag = ord(fd.read(1)) + fd.read(6) + + if flag & _FEXTRA: + xlen = ord(fd.read(1)) + xlen += 256 * ord(fd.read(1)) + fd.read(xlen) + if flag & _FNAME: + while fd.read(1) != '\0': + pass + if flag & _FCOMMENT: + while fd.read(1) != '\0': + pass + if flag & _FHCRC: + fd.read(2) + + def _Decompress(self, b): + return self._z.decompress(b) diff --git a/import_zip.py b/import_zip.py new file mode 100644 index 00000000..08aff326 --- /dev/null +++ b/import_zip.py @@ -0,0 +1,345 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import stat +import struct +import zlib +import cStringIO + +from import_ext import ImportExternal +from error import ImportError + +class ImportZip(ImportExternal): + """Streams a zip file from the network directly into a Project's + Git repository. + """ + @classmethod + def CanAccept(cls, url): + """Can this importer read and unpack the data stored at url? + """ + if url.endswith('.zip') or url.endswith('.jar'): + return True + return False + + def _UnpackFiles(self): + url_fd, url = self._OpenUrl() + try: + if not self.__class__.CanAccept(url): + raise ImportError('non-zip file extension: %s' % url) + + zip = _ZipFile(url_fd) + for entry in zip.FileRecords(): + data = zip.Open(entry).read() + sz = len(data) + + if data and _SafeCRLF(data): + data = data.replace('\r\n', '\n') + sz = len(data) + + fd = cStringIO.StringIO(data) + self._UnpackOneFile(entry.mode, sz, entry.name, fd) + zip.Close(entry) + + for entry in zip.CentralDirectory(): + self._SetFileMode(entry.name, entry.mode) + + zip.CheckTail() + finally: + url_fd.close() + + +def _SafeCRLF(data): + """Is it reasonably safe to perform a CRLF->LF conversion? + + If the stream contains a NUL byte it is likely binary, + and thus a CRLF->LF conversion may damage the stream. + + If the only NUL is in the last position of the stream, + but it otherwise can do a CRLF<->LF conversion we do + the CRLF conversion anyway. At least one source ZIP + file has this structure in its source code. + + If every occurrance of a CR and LF is paired up as a + CRLF pair then the conversion is safely bi-directional. + s/\r\n/\n/g == s/\n/\r\\n/g can convert between them. + """ + nul = data.find('\0') + if 0 <= nul and nul < (len(data) - 1): + return False + + n_lf = 0 + last = 0 + while True: + lf = data.find('\n', last) + if lf < 0: + break + if lf == 0 or data[lf - 1] != '\r': + return False + last = lf + 1 + n_lf += 1 + return n_lf > 0 + +class _ZipFile(object): + """Streaming iterator to parse a zip file on the fly. + """ + def __init__(self, fd): + self._fd = _UngetStream(fd) + + def FileRecords(self): + return _FileIter(self._fd) + + def CentralDirectory(self): + return _CentIter(self._fd) + + def CheckTail(self): + type_buf = self._fd.read(4) + type = struct.unpack('> 16 + else: + self.mode = stat.S_IFREG | 0644 + + +class _UngetStream(object): + """File like object to read and rewind a stream. + """ + def __init__(self, fd): + self._fd = fd + self._buf = None + + def read(self, size = -1): + r = [] + try: + if size >= 0: + self._ReadChunk(r, size) + else: + while True: + self._ReadChunk(r, 2048) + except EOFError: + pass + + if len(r) == 1: + return r[0] + return ''.join(r) + + def unread(self, buf): + b = self._buf + if b is None or len(b) == 0: + self._buf = buf + else: + self._buf = buf + b + + def _ReadChunk(self, r, size): + b = self._buf + try: + while size > 0: + if b is None or len(b) == 0: + b = self._Inflate(self._fd.read(2048)) + if not b: + raise EOFError() + continue + + use = min(size, len(b)) + r.append(b[:use]) + b = b[use:] + size -= use + finally: + self._buf = b + + def _Inflate(self, b): + return b + + +class _FixedLengthStream(_UngetStream): + """File like object to read a fixed length stream. + """ + def __init__(self, fd, have): + _UngetStream.__init__(self, fd) + self._have = have + + def _Inflate(self, b): + n = self._have + if n == 0: + self._fd.unread(b) + return None + + if len(b) > n: + self._fd.unread(b[n:]) + b = b[:n] + self._have -= len(b) + return b + + +class _InflateStream(_UngetStream): + """Inflates the stream as it reads input. + """ + def __init__(self, fd): + _UngetStream.__init__(self, fd) + self._z = zlib.decompressobj(-zlib.MAX_WBITS) + + def _Inflate(self, b): + z = self._z + if not z: + self._fd.unread(b) + return None + + b = z.decompress(b) + if z.unconsumed_tail != '': + self._fd.unread(z.unconsumed_tail) + elif z.unused_data != '': + self._fd.unread(z.unused_data) + self._z = None + return b diff --git a/main.py b/main.py new file mode 100755 index 00000000..56092990 --- /dev/null +++ b/main.py @@ -0,0 +1,198 @@ +#!/bin/sh +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +magic='--calling-python-from-/bin/sh--' +"""exec" python2.4 -E "$0" "$@" """#$magic" +if __name__ == '__main__': + import sys + if sys.argv[-1] == '#%s' % magic: + del sys.argv[-1] +del magic + +import optparse +import os +import re +import sys + +from command import InteractiveCommand, PagedCommand +from error import NoSuchProjectError +from error import RepoChangedException +from manifest import Manifest +from pager import RunPager + +from subcmds import all as all_commands + +global_options = optparse.OptionParser( + usage="repo [-p|--paginate|--no-pager] COMMAND [ARGS]" + ) +global_options.add_option('-p', '--paginate', + dest='pager', action='store_true', + help='display command output in the pager') +global_options.add_option('--no-pager', + dest='no_pager', action='store_true', + help='disable the pager') + +class _Repo(object): + def __init__(self, repodir): + self.repodir = repodir + self.commands = all_commands + + def _Run(self, argv): + name = None + glob = [] + + for i in xrange(0, len(argv)): + if not argv[i].startswith('-'): + name = argv[i] + if i > 0: + glob = argv[:i] + argv = argv[i + 1:] + break + if not name: + glob = argv + name = 'help' + argv = [] + gopts, gargs = global_options.parse_args(glob) + + try: + cmd = self.commands[name] + except KeyError: + print >>sys.stderr,\ + "repo: '%s' is not a repo command. See 'repo help'."\ + % name + sys.exit(1) + + cmd.repodir = self.repodir + cmd.manifest = Manifest(cmd.repodir) + + if not gopts.no_pager and not isinstance(cmd, InteractiveCommand): + config = cmd.manifest.globalConfig + if gopts.pager: + use_pager = True + else: + use_pager = config.GetBoolean('pager.%s' % name) + if use_pager is None: + use_pager = isinstance(cmd, PagedCommand) + if use_pager: + RunPager(config) + + copts, cargs = cmd.OptionParser.parse_args(argv) + try: + cmd.Execute(copts, cargs) + except NoSuchProjectError, e: + if e.name: + print >>sys.stderr, 'error: project %s not found' % e.name + else: + print >>sys.stderr, 'error: no project in current directory' + sys.exit(1) + +def _MyWrapperPath(): + return os.path.join(os.path.dirname(__file__), 'repo') + +def _CurrentWrapperVersion(): + VERSION = None + pat = re.compile(r'^VERSION *=') + fd = open(_MyWrapperPath()) + for line in fd: + if pat.match(line): + fd.close() + exec line + return VERSION + raise NameError, 'No VERSION in repo script' + +def _CheckWrapperVersion(ver, repo_path): + if not repo_path: + repo_path = '~/bin/repo' + + if not ver: + print >>sys.stderr, 'no --wrapper-version argument' + sys.exit(1) + + exp = _CurrentWrapperVersion() + ver = tuple(map(lambda x: int(x), ver.split('.'))) + if len(ver) == 1: + ver = (0, ver[0]) + + if exp[0] > ver[0] or ver < (0, 4): + exp_str = '.'.join(map(lambda x: str(x), exp)) + print >>sys.stderr, """ +!!! A new repo command (%5s) is available. !!! +!!! You must upgrade before you can continue: !!! + + cp %s %s +""" % (exp_str, _MyWrapperPath(), repo_path) + sys.exit(1) + + if exp > ver: + exp_str = '.'.join(map(lambda x: str(x), exp)) + print >>sys.stderr, """ +... A new repo command (%5s) is available. +... You should upgrade soon: + + cp %s %s +""" % (exp_str, _MyWrapperPath(), repo_path) + +def _CheckRepoDir(dir): + if not dir: + print >>sys.stderr, 'no --repo-dir argument' + sys.exit(1) + +def _PruneOptions(argv, opt): + i = 0 + while i < len(argv): + a = argv[i] + if a == '--': + break + if a.startswith('--'): + eq = a.find('=') + if eq > 0: + a = a[0:eq] + if not opt.has_option(a): + del argv[i] + continue + i += 1 + +def _Main(argv): + opt = optparse.OptionParser(usage="repo wrapperinfo -- ...") + opt.add_option("--repo-dir", dest="repodir", + help="path to .repo/") + opt.add_option("--wrapper-version", dest="wrapper_version", + help="version of the wrapper script") + opt.add_option("--wrapper-path", dest="wrapper_path", + help="location of the wrapper script") + _PruneOptions(argv, opt) + opt, argv = opt.parse_args(argv) + + _CheckWrapperVersion(opt.wrapper_version, opt.wrapper_path) + _CheckRepoDir(opt.repodir) + + repo = _Repo(opt.repodir) + try: + repo._Run(argv) + except KeyboardInterrupt: + sys.exit(1) + except RepoChangedException: + # If the repo or manifest changed, re-exec ourselves. + # + try: + os.execv(__file__, sys.argv) + except OSError, e: + print >>sys.stderr, 'fatal: cannot restart repo after upgrade' + print >>sys.stderr, 'fatal: %s' % e + sys.exit(128) + +if __name__ == '__main__': + _Main(sys.argv[1:]) diff --git a/manifest.py b/manifest.py new file mode 100644 index 00000000..45b0f9a5 --- /dev/null +++ b/manifest.py @@ -0,0 +1,338 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import xml.dom.minidom + +from editor import Editor +from git_config import GitConfig, IsId +from import_tar import ImportTar +from import_zip import ImportZip +from project import Project, MetaProject, R_TAGS +from remote import Remote +from error import ManifestParseError + +MANIFEST_FILE_NAME = 'manifest.xml' + +class _Default(object): + """Project defaults within the manifest.""" + + revision = None + remote = None + + +class Manifest(object): + """manages the repo configuration file""" + + def __init__(self, repodir): + self.repodir = os.path.abspath(repodir) + self.topdir = os.path.dirname(self.repodir) + self.manifestFile = os.path.join(self.repodir, MANIFEST_FILE_NAME) + + self.globalConfig = GitConfig.ForUser() + Editor.globalConfig = self.globalConfig + + self.repoProject = MetaProject(self, 'repo', + gitdir = os.path.join(repodir, 'repo/.git'), + worktree = os.path.join(repodir, 'repo')) + + wt = os.path.join(repodir, 'manifests') + gd_new = os.path.join(repodir, 'manifests.git') + gd_old = os.path.join(wt, '.git') + if os.path.exists(gd_new) or not os.path.exists(gd_old): + gd = gd_new + else: + gd = gd_old + self.manifestProject = MetaProject(self, 'manifests', + gitdir = gd, + worktree = wt) + + self._Unload() + + def Link(self, name): + """Update the repo metadata to use a different manifest. + """ + path = os.path.join(self.manifestProject.worktree, name) + if not os.path.isfile(path): + raise ManifestParseError('manifest %s not found' % name) + + old = self.manifestFile + try: + self.manifestFile = path + self._Unload() + self._Load() + finally: + self.manifestFile = old + + try: + if os.path.exists(self.manifestFile): + os.remove(self.manifestFile) + os.symlink('manifests/%s' % name, self.manifestFile) + except OSError, e: + raise ManifestParseError('cannot link manifest %s' % name) + + @property + def projects(self): + self._Load() + return self._projects + + @property + def remotes(self): + self._Load() + return self._remotes + + @property + def default(self): + self._Load() + return self._default + + def _Unload(self): + self._loaded = False + self._projects = {} + self._remotes = {} + self._default = None + self.branch = None + + def _Load(self): + if not self._loaded: + self._ParseManifest() + self._loaded = True + + def _ParseManifest(self): + root = xml.dom.minidom.parse(self.manifestFile) + if not root or not root.childNodes: + raise ManifestParseError, \ + "no root node in %s" % \ + self.manifestFile + + config = root.childNodes[0] + if config.nodeName != 'manifest': + raise ManifestParseError, \ + "no in %s" % \ + self.manifestFile + + self.branch = config.getAttribute('branch') + if not self.branch: + self.branch = 'default' + + for node in config.childNodes: + if node.nodeName == 'remote': + remote = self._ParseRemote(node) + if self._remotes.get(remote.name): + raise ManifestParseError, \ + 'duplicate remote %s in %s' % \ + (remote.name, self.manifestFile) + self._remotes[remote.name] = remote + + for node in config.childNodes: + if node.nodeName == 'default': + if self._default is not None: + raise ManifestParseError, \ + 'duplicate default in %s' % \ + (self.manifestFile) + self._default = self._ParseDefault(node) + if self._default is None: + self._default = _Default() + + for node in config.childNodes: + if node.nodeName == 'project': + project = self._ParseProject(node) + if self._projects.get(project.name): + raise ManifestParseError, \ + 'duplicate project %s in %s' % \ + (project.name, self.manifestFile) + self._projects[project.name] = project + + def _ParseRemote(self, node): + """ + reads a element from the manifest file + """ + name = self._reqatt(node, 'name') + fetch = self._reqatt(node, 'fetch') + review = node.getAttribute('review') + + r = Remote(name=name, + fetch=fetch, + review=review) + + for n in node.childNodes: + if n.nodeName == 'require': + r.requiredCommits.append(self._reqatt(n, 'commit')) + + return r + + def _ParseDefault(self, node): + """ + reads a element from the manifest file + """ + d = _Default() + d.remote = self._get_remote(node) + d.revision = node.getAttribute('revision') + return d + + def _ParseProject(self, node): + """ + reads a element from the manifest file + """ + name = self._reqatt(node, 'name') + + remote = self._get_remote(node) + if remote is None: + remote = self._default.remote + if remote is None: + raise ManifestParseError, \ + "no remote for project %s within %s" % \ + (name, self.manifestFile) + + revision = node.getAttribute('revision') + if not revision: + revision = self._default.revision + if not revision: + raise ManifestParseError, \ + "no revision for project %s within %s" % \ + (name, self.manifestFile) + + path = node.getAttribute('path') + if not path: + path = name + if path.startswith('/'): + raise ManifestParseError, \ + "project %s path cannot be absolute in %s" % \ + (name, self.manifestFile) + + worktree = os.path.join(self.topdir, path) + gitdir = os.path.join(self.repodir, 'projects/%s.git' % path) + + project = Project(manifest = self, + name = name, + remote = remote, + gitdir = gitdir, + worktree = worktree, + relpath = path, + revision = revision) + + for n in node.childNodes: + if n.nodeName == 'remote': + r = self._ParseRemote(n) + if project.extraRemotes.get(r.name) \ + or project.remote.name == r.name: + raise ManifestParseError, \ + 'duplicate remote %s in project %s in %s' % \ + (r.name, project.name, self.manifestFile) + project.extraRemotes[r.name] = r + elif n.nodeName == 'copyfile': + self._ParseCopyFile(project, n) + + to_resolve = [] + by_version = {} + + for n in node.childNodes: + if n.nodeName == 'import': + self._ParseImport(project, n, to_resolve, by_version) + + for pair in to_resolve: + sn, pr = pair + try: + sn.SetParent(by_version[pr].commit) + except KeyError: + raise ManifestParseError, \ + 'snapshot %s not in project %s in %s' % \ + (pr, project.name, self.manifestFile) + + return project + + def _ParseImport(self, project, import_node, to_resolve, by_version): + first_url = None + for node in import_node.childNodes: + if node.nodeName == 'mirror': + first_url = self._reqatt(node, 'url') + break + if not first_url: + raise ManifestParseError, \ + 'mirror url required for project %s in %s' % \ + (project.name, self.manifestFile) + + imp = None + for cls in [ImportTar, ImportZip]: + if cls.CanAccept(first_url): + imp = cls() + break + if not imp: + raise ManifestParseError, \ + 'snapshot %s unsupported for project %s in %s' % \ + (first_url, project.name, self.manifestFile) + + imp.SetProject(project) + + for node in import_node.childNodes: + if node.nodeName == 'remap': + old = node.getAttribute('strip') + new = node.getAttribute('insert') + imp.RemapPath(old, new) + + elif node.nodeName == 'mirror': + imp.AddUrl(self._reqatt(node, 'url')) + + for node in import_node.childNodes: + if node.nodeName == 'snapshot': + sn = imp.Clone() + sn.SetVersion(self._reqatt(node, 'version')) + sn.SetCommit(node.getAttribute('check')) + + pr = node.getAttribute('prior') + if pr: + if IsId(pr): + sn.SetParent(pr) + else: + to_resolve.append((sn, pr)) + + rev = R_TAGS + sn.TagName + + if rev in project.snapshots: + raise ManifestParseError, \ + 'duplicate snapshot %s for project %s in %s' % \ + (sn.version, project.name, self.manifestFile) + project.snapshots[rev] = sn + by_version[sn.version] = sn + + def _ParseCopyFile(self, project, node): + src = self._reqatt(node, 'src') + dest = self._reqatt(node, 'dest') + # src is project relative, and dest is relative to the top of the tree + project.AddCopyFile(src, os.path.join(self.topdir, dest)) + + def _get_remote(self, node): + name = node.getAttribute('remote') + if not name: + return None + + v = self._remotes.get(name) + if not v: + raise ManifestParseError, \ + "remote %s not defined in %s" % \ + (name, self.manifestFile) + return v + + def _reqatt(self, node, attname): + """ + reads a required attribute from the node. + """ + v = node.getAttribute(attname) + if not v: + raise ManifestParseError, \ + "no %s in <%s> within %s" % \ + (attname, node.nodeName, self.manifestFile) + return v diff --git a/pager.py b/pager.py new file mode 100755 index 00000000..320131cd --- /dev/null +++ b/pager.py @@ -0,0 +1,84 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import select +import sys + +active = False + +def RunPager(globalConfig): + global active + + if not os.isatty(0): + return + pager = _SelectPager(globalConfig) + if pager == '' or pager == 'cat': + return + + # This process turns into the pager; a child it forks will + # do the real processing and output back to the pager. This + # is necessary to keep the pager in control of the tty. + # + try: + r, w = os.pipe() + pid = os.fork() + if not pid: + os.dup2(w, 1) + os.dup2(w, 2) + os.close(r) + os.close(w) + active = True + return + + os.dup2(r, 0) + os.close(r) + os.close(w) + + _BecomePager(pager) + except Exception: + print >>sys.stderr, "fatal: cannot start pager '%s'" % pager + os.exit(255) + +def _SelectPager(globalConfig): + try: + return os.environ['GIT_PAGER'] + except KeyError: + pass + + pager = globalConfig.GetString('core.pager') + if pager: + return pager + + try: + return os.environ['PAGER'] + except KeyError: + pass + + return 'less' + +def _BecomePager(pager): + # Delaying execution of the pager until we have output + # ready works around a long-standing bug in popularly + # available versions of 'less', a better 'more'. + # + a, b, c = select.select([0], [], [0]) + + os.environ['LESS'] = 'FRSX' + + try: + os.execvp(pager, [pager]) + except OSError, e: + os.execv('/bin/sh', ['sh', '-c', pager]) diff --git a/project.py b/project.py new file mode 100644 index 00000000..7c0c58f9 --- /dev/null +++ b/project.py @@ -0,0 +1,1058 @@ +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import filecmp +import os +import re +import shutil +import stat +import sys +import urllib2 + +from color import Coloring +from git_command import GitCommand +from git_config import GitConfig, IsId +from gerrit_upload import UploadBundle +from error import GitError, ImportError, UploadError +from remote import Remote +from codereview import proto_client + +HEAD = 'HEAD' +R_HEADS = 'refs/heads/' +R_TAGS = 'refs/tags/' +R_PUB = 'refs/published/' +R_M = 'refs/remotes/m/' + +def _warn(fmt, *args): + msg = fmt % args + print >>sys.stderr, 'warn: %s' % msg + +def _info(fmt, *args): + msg = fmt % args + print >>sys.stderr, 'info: %s' % msg + +def not_rev(r): + return '^' + r + +class ReviewableBranch(object): + _commit_cache = None + + def __init__(self, project, branch, base): + self.project = project + self.branch = branch + self.base = base + + @property + def name(self): + return self.branch.name + + @property + def commits(self): + if self._commit_cache is None: + self._commit_cache = self.project.bare_git.rev_list( + '--abbrev=8', + '--abbrev-commit', + '--pretty=oneline', + '--reverse', + '--date-order', + not_rev(self.base), + R_HEADS + self.name, + '--') + return self._commit_cache + + @property + def date(self): + return self.project.bare_git.log( + '--pretty=format:%cd', + '-n', '1', + R_HEADS + self.name, + '--') + + def UploadForReview(self): + self.project.UploadForReview(self.name) + + @property + def tip_url(self): + me = self.project.GetBranch(self.name) + commit = self.project.bare_git.rev_parse(R_HEADS + self.name) + return 'http://%s/r/%s' % (me.remote.review, commit[0:12]) + + +class StatusColoring(Coloring): + def __init__(self, config): + Coloring.__init__(self, config, 'status') + self.project = self.printer('header', attr = 'bold') + self.branch = self.printer('header', attr = 'bold') + self.nobranch = self.printer('nobranch', fg = 'red') + + self.added = self.printer('added', fg = 'green') + self.changed = self.printer('changed', fg = 'red') + self.untracked = self.printer('untracked', fg = 'red') + + +class DiffColoring(Coloring): + def __init__(self, config): + Coloring.__init__(self, config, 'diff') + self.project = self.printer('header', attr = 'bold') + + +class _CopyFile: + def __init__(self, src, dest): + self.src = src + self.dest = dest + + def _Copy(self): + src = self.src + dest = self.dest + # copy file if it does not exist or is out of date + if not os.path.exists(dest) or not filecmp.cmp(src, dest): + try: + # remove existing file first, since it might be read-only + if os.path.exists(dest): + os.remove(dest) + shutil.copy(src, dest) + # make the file read-only + mode = os.stat(dest)[stat.ST_MODE] + mode = mode & ~(stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH) + os.chmod(dest, mode) + except IOError: + print >>sys.stderr, \ + 'error: Cannot copy file %s to %s' \ + % (src, dest) + + +class Project(object): + def __init__(self, + manifest, + name, + remote, + gitdir, + worktree, + relpath, + revision): + self.manifest = manifest + self.name = name + self.remote = remote + self.gitdir = gitdir + self.worktree = worktree + self.relpath = relpath + self.revision = revision + self.snapshots = {} + self.extraRemotes = {} + self.copyfiles = [] + self.config = GitConfig.ForRepository( + gitdir = self.gitdir, + defaults = self.manifest.globalConfig) + + self.work_git = self._GitGetByExec(self, bare=False) + self.bare_git = self._GitGetByExec(self, bare=True) + + @property + def Exists(self): + return os.path.isdir(self.gitdir) + + @property + def CurrentBranch(self): + """Obtain the name of the currently checked out branch. + The branch name omits the 'refs/heads/' prefix. + None is returned if the project is on a detached HEAD. + """ + try: + b = self.work_git.GetHead() + except GitError: + return None + if b.startswith(R_HEADS): + return b[len(R_HEADS):] + return None + + def IsDirty(self, consider_untracked=True): + """Is the working directory modified in some way? + """ + self.work_git.update_index('-q', + '--unmerged', + '--ignore-missing', + '--refresh') + if self.work_git.DiffZ('diff-index','-M','--cached',HEAD): + return True + if self.work_git.DiffZ('diff-files'): + return True + if consider_untracked and self.work_git.LsOthers(): + return True + return False + + _userident_name = None + _userident_email = None + + @property + def UserName(self): + """Obtain the user's personal name. + """ + if self._userident_name is None: + self._LoadUserIdentity() + return self._userident_name + + @property + def UserEmail(self): + """Obtain the user's email address. This is very likely + to be their Gerrit login. + """ + if self._userident_email is None: + self._LoadUserIdentity() + return self._userident_email + + def _LoadUserIdentity(self): + u = self.bare_git.var('GIT_COMMITTER_IDENT') + m = re.compile("^(.*) <([^>]*)> ").match(u) + if m: + self._userident_name = m.group(1) + self._userident_email = m.group(2) + else: + self._userident_name = '' + self._userident_email = '' + + def GetRemote(self, name): + """Get the configuration for a single remote. + """ + return self.config.GetRemote(name) + + def GetBranch(self, name): + """Get the configuration for a single branch. + """ + return self.config.GetBranch(name) + + +## Status Display ## + + def PrintWorkTreeStatus(self): + """Prints the status of the repository to stdout. + """ + if not os.path.isdir(self.worktree): + print '' + print 'project %s/' % self.relpath + print ' missing (run "repo sync")' + return + + self.work_git.update_index('-q', + '--unmerged', + '--ignore-missing', + '--refresh') + di = self.work_git.DiffZ('diff-index', '-M', '--cached', HEAD) + df = self.work_git.DiffZ('diff-files') + do = self.work_git.LsOthers() + if not di and not df and not do: + return + + out = StatusColoring(self.config) + out.project('project %-40s', self.relpath + '/') + + branch = self.CurrentBranch + if branch is None: + out.nobranch('(*** NO BRANCH ***)') + else: + out.branch('branch %s', branch) + out.nl() + + paths = list() + paths.extend(di.keys()) + paths.extend(df.keys()) + paths.extend(do) + + paths = list(set(paths)) + paths.sort() + + for p in paths: + try: i = di[p] + except KeyError: i = None + + try: f = df[p] + except KeyError: f = None + + if i: i_status = i.status.upper() + else: i_status = '-' + + if f: f_status = f.status.lower() + else: f_status = '-' + + if i and i.src_path: + line = ' %s%s\t%s => (%s%%)' % (i_status, f_status, + i.src_path, p, i.level) + else: + line = ' %s%s\t%s' % (i_status, f_status, p) + + if i and not f: + out.added('%s', line) + elif (i and f) or (not i and f): + out.changed('%s', line) + elif not i and not f: + out.untracked('%s', line) + else: + out.write('%s', line) + out.nl() + + def PrintWorkTreeDiff(self): + """Prints the status of the repository to stdout. + """ + out = DiffColoring(self.config) + cmd = ['diff'] + if out.is_on: + cmd.append('--color') + cmd.append(HEAD) + cmd.append('--') + p = GitCommand(self, + cmd, + capture_stdout = True, + capture_stderr = True) + has_diff = False + for line in p.process.stdout: + if not has_diff: + out.nl() + out.project('project %s/' % self.relpath) + out.nl() + has_diff = True + print line[:-1] + p.Wait() + + +## Publish / Upload ## + + def WasPublished(self, branch): + """Was the branch published (uploaded) for code review? + If so, returns the SHA-1 hash of the last published + state for the branch. + """ + try: + return self.bare_git.rev_parse(R_PUB + branch) + except GitError: + return None + + def CleanPublishedCache(self): + """Prunes any stale published refs. + """ + heads = set() + canrm = {} + for name, id in self._allrefs.iteritems(): + if name.startswith(R_HEADS): + heads.add(name) + elif name.startswith(R_PUB): + canrm[name] = id + + for name, id in canrm.iteritems(): + n = name[len(R_PUB):] + if R_HEADS + n not in heads: + self.bare_git.DeleteRef(name, id) + + def GetUploadableBranches(self): + """List any branches which can be uploaded for review. + """ + heads = {} + pubed = {} + + for name, id in self._allrefs.iteritems(): + if name.startswith(R_HEADS): + heads[name[len(R_HEADS):]] = id + elif name.startswith(R_PUB): + pubed[name[len(R_PUB):]] = id + + ready = [] + for branch, id in heads.iteritems(): + if branch in pubed and pubed[branch] == id: + continue + + branch = self.GetBranch(branch) + base = branch.LocalMerge + if branch.LocalMerge: + rb = ReviewableBranch(self, branch, base) + if rb.commits: + ready.append(rb) + return ready + + def UploadForReview(self, branch=None): + """Uploads the named branch for code review. + """ + if branch is None: + branch = self.CurrentBranch + if branch is None: + raise GitError('not currently on a branch') + + branch = self.GetBranch(branch) + if not branch.LocalMerge: + raise GitError('branch %s does not track a remote' % branch.name) + if not branch.remote.review: + raise GitError('remote %s has no review url' % branch.remote.name) + + dest_branch = branch.merge + if not dest_branch.startswith(R_HEADS): + dest_branch = R_HEADS + dest_branch + + base_list = [] + for name, id in self._allrefs.iteritems(): + if branch.remote.WritesTo(name): + base_list.append(not_rev(name)) + if not base_list: + raise GitError('no base refs, cannot upload %s' % branch.name) + + print >>sys.stderr, '' + _info("Uploading %s to %s:", branch.name, self.name) + try: + UploadBundle(project = self, + server = branch.remote.review, + email = self.UserEmail, + dest_project = self.name, + dest_branch = dest_branch, + src_branch = R_HEADS + branch.name, + bases = base_list) + except proto_client.ClientLoginError: + raise UploadError('Login failure') + except urllib2.HTTPError, e: + raise UploadError('HTTP error %d' % e.code) + + msg = "posted to %s for %s" % (branch.remote.review, dest_branch) + self.bare_git.UpdateRef(R_PUB + branch.name, + R_HEADS + branch.name, + message = msg) + + +## Sync ## + + def Sync_NetworkHalf(self): + """Perform only the network IO portion of the sync process. + Local working directory/branch state is not affected. + """ + if not self.Exists: + print >>sys.stderr + print >>sys.stderr, 'Initializing project %s ...' % self.name + self._InitGitDir() + self._InitRemote() + for r in self.extraRemotes.values(): + if not self._RemoteFetch(r.name): + return False + if not self._SnapshotDownload(): + return False + if not self._RemoteFetch(): + return False + self._InitMRef() + return True + + def _CopyFiles(self): + for file in self.copyfiles: + file._Copy() + + def Sync_LocalHalf(self): + """Perform only the local IO portion of the sync process. + Network access is not required. + + Return: + True: the sync was successful + False: the sync requires user input + """ + self._InitWorkTree() + self.CleanPublishedCache() + + rem = self.GetRemote(self.remote.name) + rev = rem.ToLocal(self.revision) + branch = self.CurrentBranch + + if branch is None: + # Currently on a detached HEAD. The user is assumed to + # not have any local modifications worth worrying about. + # + lost = self._revlist(not_rev(rev), HEAD) + if lost: + _info("[%s] Discarding %d commits", self.name, len(lost)) + try: + self._Checkout(rev, quiet=True) + except GitError: + return False + self._CopyFiles() + return True + + branch = self.GetBranch(branch) + merge = branch.LocalMerge + + if not merge: + # The current branch has no tracking configuration. + # Jump off it to a deatched HEAD. + # + _info("[%s] Leaving %s" + " (does not track any upstream)", + self.name, + branch.name) + try: + self._Checkout(rev, quiet=True) + except GitError: + return False + self._CopyFiles() + return True + + upstream_gain = self._revlist(not_rev(HEAD), rev) + pub = self.WasPublished(branch.name) + if pub: + not_merged = self._revlist(not_rev(rev), pub) + if not_merged: + if upstream_gain: + # The user has published this branch and some of those + # commits are not yet merged upstream. We do not want + # to rewrite the published commits so we punt. + # + _info("[%s] Branch %s is published," + " but is now %d commits behind.", + self.name, branch.name, len(upstream_gain)) + _info("[%s] Consider merging or rebasing the" + " unpublished commits.", self.name) + return True + + if merge == rev: + try: + old_merge = self.bare_git.rev_parse('%s@{1}' % merge) + except GitError: + old_merge = merge + else: + # The upstream switched on us. Time to cross our fingers + # and pray that the old upstream also wasn't in the habit + # of rebasing itself. + # + _info("[%s] Manifest switched from %s to %s", + self.name, merge, rev) + old_merge = merge + + if rev == old_merge: + upstream_lost = [] + else: + upstream_lost = self._revlist(not_rev(rev), old_merge) + + if not upstream_lost and not upstream_gain: + # Trivially no changes caused by the upstream. + # + return True + + if self.IsDirty(consider_untracked=False): + _warn('[%s] commit (or discard) uncommitted changes' + ' before sync', self.name) + return False + + if upstream_lost: + # Upstream rebased. Not everything in HEAD + # may have been caused by the user. + # + _info("[%s] Discarding %d commits removed from upstream", + self.name, len(upstream_lost)) + + branch.remote = rem + branch.merge = self.revision + branch.Save() + + my_changes = self._revlist(not_rev(old_merge), HEAD) + if my_changes: + try: + self._Rebase(upstream = old_merge, onto = rev) + except GitError: + return False + elif upstream_lost: + try: + self._ResetHard(rev) + except GitError: + return False + else: + try: + self._FastForward(rev) + except GitError: + return False + + self._CopyFiles() + return True + + def _SnapshotDownload(self): + if self.snapshots: + have = set(self._allrefs.keys()) + need = [] + + for tag, sn in self.snapshots.iteritems(): + if tag not in have: + need.append(sn) + + if need: + print >>sys.stderr, """ + *** Downloading source(s) from a mirror site. *** + *** If the network hangs, kill and restart repo. *** +""" + for sn in need: + try: + sn.Import() + except ImportError, e: + print >>sys.stderr, \ + 'error: Cannot import %s: %s' \ + % (self.name, e) + return False + cmd = ['repack', '-a', '-d', '-f', '-l'] + if GitCommand(self, cmd, bare = True).Wait() != 0: + return False + return True + + def AddCopyFile(self, src, dest): + # dest should already be an absolute path, but src is project relative + # make src an absolute path + src = os.path.join(self.worktree, src) + self.copyfiles.append(_CopyFile(src, dest)) + + +## Branch Management ## + + def StartBranch(self, name): + """Create a new branch off the manifest's revision. + """ + branch = self.GetBranch(name) + branch.remote = self.GetRemote(self.remote.name) + branch.merge = self.revision + + rev = branch.LocalMerge + cmd = ['checkout', '-b', branch.name, rev] + if GitCommand(self, cmd).Wait() == 0: + branch.Save() + else: + raise GitError('%s checkout %s ' % (self.name, rev)) + + def PruneHeads(self): + """Prune any topic branches already merged into upstream. + """ + cb = self.CurrentBranch + kill = [] + for name in self._allrefs.keys(): + if name.startswith(R_HEADS): + name = name[len(R_HEADS):] + if cb is None or name != cb: + kill.append(name) + + rev = self.GetRemote(self.remote.name).ToLocal(self.revision) + if cb is not None \ + and not self._revlist(HEAD + '...' + rev) \ + and not self.IsDirty(consider_untracked = False): + self.work_git.DetachHead(HEAD) + kill.append(cb) + + deleted = set() + if kill: + try: + old = self.bare_git.GetHead() + except GitError: + old = 'refs/heads/please_never_use_this_as_a_branch_name' + + rm_re = re.compile(r"^Deleted branch (.*)\.$") + try: + self.bare_git.DetachHead(rev) + + b = ['branch', '-d'] + b.extend(kill) + b = GitCommand(self, b, bare=True, + capture_stdout=True, + capture_stderr=True) + b.Wait() + finally: + self.bare_git.SetHead(old) + + for line in b.stdout.split("\n"): + m = rm_re.match(line) + if m: + deleted.add(m.group(1)) + + if deleted: + self.CleanPublishedCache() + + if cb and cb not in kill: + kill.append(cb) + kill.sort() + + kept = [] + for branch in kill: + if branch not in deleted: + branch = self.GetBranch(branch) + base = branch.LocalMerge + if not base: + base = rev + kept.append(ReviewableBranch(self, branch, base)) + return kept + + +## Direct Git Commands ## + + def _RemoteFetch(self, name=None): + if not name: + name = self.remote.name + + hide_errors = False + if self.extraRemotes or self.snapshots: + hide_errors = True + + proc = GitCommand(self, + ['fetch', name], + bare = True, + capture_stderr = hide_errors) + if hide_errors: + err = proc.process.stderr.fileno() + buf = '' + while True: + b = os.read(err, 256) + if b: + buf += b + while buf: + r = buf.find('remote: error: unable to find ') + if r >= 0: + lf = buf.find('\n') + if lf < 0: + break + buf = buf[lf + 1:] + continue + + cr = buf.find('\r') + if cr < 0: + break + os.write(2, buf[0:cr + 1]) + buf = buf[cr + 1:] + if not b: + if buf: + os.write(2, buf) + break + return proc.Wait() == 0 + + def _Checkout(self, rev, quiet=False): + cmd = ['checkout'] + if quiet: + cmd.append('-q') + cmd.append(rev) + cmd.append('--') + if GitCommand(self, cmd).Wait() != 0: + if self._allrefs: + raise GitError('%s checkout %s ' % (self.name, rev)) + + def _ResetHard(self, rev, quiet=True): + cmd = ['reset', '--hard'] + if quiet: + cmd.append('-q') + cmd.append(rev) + if GitCommand(self, cmd).Wait() != 0: + raise GitError('%s reset --hard %s ' % (self.name, rev)) + + def _Rebase(self, upstream, onto = None): + cmd = ['rebase', '-i'] + if onto is not None: + cmd.extend(['--onto', onto]) + cmd.append(upstream) + if GitCommand(self, cmd, disable_editor=True).Wait() != 0: + raise GitError('%s rebase %s ' % (self.name, upstream)) + + def _FastForward(self, head): + cmd = ['merge', head] + if GitCommand(self, cmd).Wait() != 0: + raise GitError('%s merge %s ' % (self.name, head)) + + def _InitGitDir(self): + if not os.path.exists(self.gitdir): + os.makedirs(self.gitdir) + self.bare_git.init() + self.config.SetString('core.bare', None) + + hooks = self._gitdir_path('hooks') + for old_hook in os.listdir(hooks): + os.remove(os.path.join(hooks, old_hook)) + + # TODO(sop) install custom repo hooks + + m = self.manifest.manifestProject.config + for key in ['user.name', 'user.email']: + if m.Has(key, include_defaults = False): + self.config.SetString(key, m.GetString(key)) + + def _InitRemote(self): + if self.remote.fetchUrl: + remote = self.GetRemote(self.remote.name) + + url = self.remote.fetchUrl + while url.endswith('/'): + url = url[:-1] + url += '/%s.git' % self.name + remote.url = url + remote.review = self.remote.reviewUrl + + remote.ResetFetch() + remote.Save() + + for r in self.extraRemotes.values(): + remote = self.GetRemote(r.name) + remote.url = r.fetchUrl + remote.review = r.reviewUrl + remote.ResetFetch() + remote.Save() + + def _InitMRef(self): + if self.manifest.branch: + msg = 'manifest set to %s' % self.revision + ref = R_M + self.manifest.branch + + if IsId(self.revision): + dst = self.revision + '^0', + self.bare_git.UpdateRef(ref, dst, message = msg, detach = True) + else: + remote = self.GetRemote(self.remote.name) + dst = remote.ToLocal(self.revision) + self.bare_git.symbolic_ref('-m', msg, ref, dst) + + def _InitWorkTree(self): + dotgit = os.path.join(self.worktree, '.git') + if not os.path.exists(dotgit): + os.makedirs(dotgit) + + topdir = os.path.commonprefix([self.gitdir, dotgit]) + if topdir.endswith('/'): + topdir = topdir[:-1] + else: + topdir = os.path.dirname(topdir) + + tmpdir = dotgit + relgit = '' + while topdir != tmpdir: + relgit += '../' + tmpdir = os.path.dirname(tmpdir) + relgit += self.gitdir[len(topdir) + 1:] + + for name in ['config', + 'description', + 'hooks', + 'info', + 'logs', + 'objects', + 'packed-refs', + 'refs', + 'rr-cache', + 'svn']: + os.symlink(os.path.join(relgit, name), + os.path.join(dotgit, name)) + + rev = self.GetRemote(self.remote.name).ToLocal(self.revision) + rev = self.bare_git.rev_parse('%s^0' % rev) + + f = open(os.path.join(dotgit, HEAD), 'wb') + f.write("%s\n" % rev) + f.close() + + cmd = ['read-tree', '--reset', '-u'] + cmd.append('-v') + cmd.append('HEAD') + if GitCommand(self, cmd).Wait() != 0: + raise GitError("cannot initialize work tree") + + def _gitdir_path(self, path): + return os.path.join(self.gitdir, path) + + def _revlist(self, *args): + cmd = [] + cmd.extend(args) + cmd.append('--') + return self.work_git.rev_list(*args) + + @property + def _allrefs(self): + return self.bare_git.ListRefs() + + class _GitGetByExec(object): + def __init__(self, project, bare): + self._project = project + self._bare = bare + + def ListRefs(self, *args): + cmdv = ['for-each-ref', '--format=%(objectname) %(refname)'] + cmdv.extend(args) + p = GitCommand(self._project, + cmdv, + bare = self._bare, + capture_stdout = True, + capture_stderr = True) + r = {} + for line in p.process.stdout: + id, name = line[:-1].split(' ', 2) + r[name] = id + if p.Wait() != 0: + raise GitError('%s for-each-ref %s: %s' % ( + self._project.name, + str(args), + p.stderr)) + return r + + def LsOthers(self): + p = GitCommand(self._project, + ['ls-files', + '-z', + '--others', + '--exclude-standard'], + bare = False, + capture_stdout = True, + capture_stderr = True) + if p.Wait() == 0: + out = p.stdout + if out: + return out[:-1].split("\0") + return [] + + def DiffZ(self, name, *args): + cmd = [name] + cmd.append('-z') + cmd.extend(args) + p = GitCommand(self._project, + cmd, + bare = False, + capture_stdout = True, + capture_stderr = True) + try: + out = p.process.stdout.read() + r = {} + if out: + out = iter(out[:-1].split('\0')) + while out: + info = out.next() + path = out.next() + + class _Info(object): + def __init__(self, path, omode, nmode, oid, nid, state): + self.path = path + self.src_path = None + self.old_mode = omode + self.new_mode = nmode + self.old_id = oid + self.new_id = nid + + if len(state) == 1: + self.status = state + self.level = None + else: + self.status = state[:1] + self.level = state[1:] + while self.level.startswith('0'): + self.level = self.level[1:] + + info = info[1:].split(' ') + info =_Info(path, *info) + if info.status in ('R', 'C'): + info.src_path = info.path + info.path = out.next() + r[info.path] = info + return r + finally: + p.Wait() + + def GetHead(self): + return self.symbolic_ref(HEAD) + + def SetHead(self, ref, message=None): + cmdv = [] + if message is not None: + cmdv.extend(['-m', message]) + cmdv.append(HEAD) + cmdv.append(ref) + self.symbolic_ref(*cmdv) + + def DetachHead(self, new, message=None): + cmdv = ['--no-deref'] + if message is not None: + cmdv.extend(['-m', message]) + cmdv.append(HEAD) + cmdv.append(new) + self.update_ref(*cmdv) + + def UpdateRef(self, name, new, old=None, + message=None, + detach=False): + cmdv = [] + if message is not None: + cmdv.extend(['-m', message]) + if detach: + cmdv.append('--no-deref') + cmdv.append(name) + cmdv.append(new) + if old is not None: + cmdv.append(old) + self.update_ref(*cmdv) + + def DeleteRef(self, name, old=None): + if not old: + old = self.rev_parse(name) + self.update_ref('-d', name, old) + + def rev_list(self, *args): + cmdv = ['rev-list'] + cmdv.extend(args) + p = GitCommand(self._project, + cmdv, + bare = self._bare, + capture_stdout = True, + capture_stderr = True) + r = [] + for line in p.process.stdout: + r.append(line[:-1]) + if p.Wait() != 0: + raise GitError('%s rev-list %s: %s' % ( + self._project.name, + str(args), + p.stderr)) + return r + + def __getattr__(self, name): + name = name.replace('_', '-') + def runner(*args): + cmdv = [name] + cmdv.extend(args) + p = GitCommand(self._project, + cmdv, + bare = self._bare, + capture_stdout = True, + capture_stderr = True) + if p.Wait() != 0: + raise GitError('%s %s: %s' % ( + self._project.name, + name, + p.stderr)) + r = p.stdout + if r.endswith('\n') and r.index('\n') == len(r) - 1: + return r[:-1] + return r + return runner + + +class MetaProject(Project): + """A special project housed under .repo. + """ + def __init__(self, manifest, name, gitdir, worktree): + repodir = manifest.repodir + Project.__init__(self, + manifest = manifest, + name = name, + gitdir = gitdir, + worktree = worktree, + remote = Remote('origin'), + relpath = '.repo/%s' % name, + revision = 'refs/heads/master') + + def PreSync(self): + if self.Exists: + cb = self.CurrentBranch + if cb: + base = self.GetBranch(cb).merge + if base: + self.revision = base + + @property + def HasChanges(self): + """Has the remote received new commits not yet checked out? + """ + rev = self.GetRemote(self.remote.name).ToLocal(self.revision) + if self._revlist(not_rev(HEAD), rev): + return True + return False diff --git a/remote.py b/remote.py new file mode 100644 index 00000000..27a8f7a7 --- /dev/null +++ b/remote.py @@ -0,0 +1,21 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class Remote(object): + def __init__(self, name, fetch=None, review=None): + self.name = name + self.fetchUrl = fetch + self.reviewUrl = review + self.requiredCommits = [] diff --git a/repo b/repo new file mode 100755 index 00000000..d5f69fb2 --- /dev/null +++ b/repo @@ -0,0 +1,587 @@ +#!/bin/sh + +## repo default configuration +## +REPO_URL='git://android.kernel.org/tools/repo.git' +REPO_REV='stable' + +# Copyright (C) 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +magic='--calling-python-from-/bin/sh--' +"""exec" python2.4 -E "$0" "$@" """#$magic" +if __name__ == '__main__': + import sys + if sys.argv[-1] == '#%s' % magic: + del sys.argv[-1] +del magic + +# increment this whenever we make important changes to this script +VERSION = (1, 4) + +# increment this if the MAINTAINER_KEYS block is modified +KEYRING_VERSION = (1,0) +MAINTAINER_KEYS = """ + + Repo Maintainer +-----BEGIN PGP PUBLIC KEY BLOCK----- +Version: GnuPG v1.4.2.2 (GNU/Linux) + +mQGiBEj3ugERBACrLJh/ZPyVSKeClMuznFIrsQ+hpNnmJGw1a9GXKYKk8qHPhAZf +WKtrBqAVMNRLhL85oSlekRz98u41H5si5zcuv+IXJDF5MJYcB8f22wAy15lUqPWi +VCkk1l8qqLiuW0fo+ZkPY5qOgrvc0HW1SmdH649uNwqCbcKb6CxaTxzhOwCgj3AP +xI1WfzLqdJjsm1Nq98L0cLcD/iNsILCuw44PRds3J75YP0pze7YF/6WFMB6QSFGu +aUX1FsTTztKNXGms8i5b2l1B8JaLRWq/jOnZzyl1zrUJhkc0JgyZW5oNLGyWGhKD +Fxp5YpHuIuMImopWEMFIRQNrvlg+YVK8t3FpdI1RY0LYqha8pPzANhEYgSfoVzOb +fbfbA/4ioOrxy8ifSoga7ITyZMA+XbW8bx33WXutO9N7SPKS/AK2JpasSEVLZcON +ae5hvAEGVXKxVPDjJBmIc2cOe7kOKSi3OxLzBqrjS2rnjiP4o0ekhZIe4+ocwVOg +e0PLlH5avCqihGRhpoqDRsmpzSHzJIxtoeb+GgGEX8KkUsVAhbQpUmVwbyBNYWlu +dGFpbmVyIDxyZXBvQGFuZHJvaWQua2VybmVsLm9yZz6IYAQTEQIAIAUCSPe6AQIb +AwYLCQgHAwIEFQIIAwQWAgMBAh4BAheAAAoJEBZTDV6SD1xl1GEAn0x/OKQpy7qI +6G73NJviU0IUMtftAKCFMUhGb/0bZvQ8Rm3QCUpWHyEIu7kEDQRI97ogEBAA2wI6 +5fs9y/rMwD6dkD/vK9v4C9mOn1IL5JCPYMJBVSci+9ED4ChzYvfq7wOcj9qIvaE0 +GwCt2ar7Q56me5J+byhSb32Rqsw/r3Vo5cZMH80N4cjesGuSXOGyEWTe4HYoxnHv +gF4EKI2LK7xfTUcxMtlyn52sUpkfKsCpUhFvdmbAiJE+jCkQZr1Z8u2KphV79Ou+ +P1N5IXY/XWOlq48Qf4MWCYlJFrB07xjUjLKMPDNDnm58L5byDrP/eHysKexpbakL +xCmYyfT6DV1SWLblpd2hie0sL3YejdtuBMYMS2rI7Yxb8kGuqkz+9l1qhwJtei94 +5MaretDy/d/JH/pRYkRf7L+ke7dpzrP+aJmcz9P1e6gq4NJsWejaALVASBiioqNf +QmtqSVzF1wkR5avZkFHuYvj6V/t1RrOZTXxkSk18KFMJRBZrdHFCWbc5qrVxUB6e +N5pja0NFIUCigLBV1c6I2DwiuboMNh18VtJJh+nwWeez/RueN4ig59gRTtkcc0PR +35tX2DR8+xCCFVW/NcJ4PSePYzCuuLvp1vEDHnj41R52Fz51hgddT4rBsp0nL+5I +socSOIIezw8T9vVzMY4ArCKFAVu2IVyBcahTfBS8q5EM63mONU6UVJEozfGljiMw +xuQ7JwKcw0AUEKTKG7aBgBaTAgT8TOevpvlw91cAAwUP/jRkyVi/0WAb0qlEaq/S +ouWxX1faR+vU3b+Y2/DGjtXQMzG0qpetaTHC/AxxHpgt/dCkWI6ljYDnxgPLwG0a +Oasm94BjZc6vZwf1opFZUKsjOAAxRxNZyjUJKe4UZVuMTk6zo27Nt3LMnc0FO47v +FcOjRyquvgNOS818irVHUf12waDx8gszKxQTTtFxU5/ePB2jZmhP6oXSe4K/LG5T ++WBRPDrHiGPhCzJRzm9BP0lTnGCAj3o9W90STZa65RK7IaYpC8TB35JTBEbrrNCp +w6lzd74LnNEp5eMlKDnXzUAgAH0yzCQeMl7t33QCdYx2hRs2wtTQSjGfAiNmj/WW +Vl5Jn+2jCDnRLenKHwVRFsBX2e0BiRWt/i9Y8fjorLCXVj4z+7yW6DawdLkJorEo +p3v5ILwfC7hVx4jHSnOgZ65L9s8EQdVr1ckN9243yta7rNgwfcqb60ILMFF1BRk/ +0V7wCL+68UwwiQDvyMOQuqkysKLSDCLb7BFcyA7j6KG+5hpsREstFX2wK1yKeraz +5xGrFy8tfAaeBMIQ17gvFSp/suc9DYO0ICK2BISzq+F+ZiAKsjMYOBNdH/h0zobQ +HTHs37+/QLMomGEGKZMWi0dShU2J5mNRQu3Hhxl3hHDVbt5CeJBb26aQcQrFz69W +zE3GNvmJosh6leayjtI9P2A6iEkEGBECAAkFAkj3uiACGwwACgkQFlMNXpIPXGWp +TACbBS+Up3RpfYVfd63c1cDdlru13pQAn3NQy/SN858MkxN+zym86UBgOad2 +=CMiZ +-----END PGP PUBLIC KEY BLOCK----- +""" + +GIT = 'git' # our git command +MIN_GIT_VERSION = (1, 5, 4) # minimum supported git version +repodir = '.repo' # name of repo's private directory +S_repo = 'repo' # special repo reposiory +S_manifests = 'manifests' # special manifest repository +REPO_MAIN = S_repo + '/main.py' # main script + + +import optparse +import os +import re +import readline +import subprocess +import sys + +home_dot_repo = os.path.expanduser('~/.repoconfig') +gpg_dir = os.path.join(home_dot_repo, 'gnupg') + +extra_args = [] +init_optparse = optparse.OptionParser(usage="repo init -u url [options]") + +# Logging +group = init_optparse.add_option_group('Logging options') +group.add_option('-q', '--quiet', + dest="quiet", action="store_true", default=False, + help="be quiet") + +# Manifest +group = init_optparse.add_option_group('Manifest options') +group.add_option('-u', '--manifest-url', + dest='manifest_url', + help='manifest repository location', metavar='URL') +group.add_option('-b', '--manifest-branch', + dest='manifest_branch', + help='manifest branch or revision', metavar='REVISION') +group.add_option('-m', '--manifest-name', + dest='manifest_name', + help='initial manifest file', metavar='NAME.xml') + +# Tool +group = init_optparse.add_option_group('Version options') +group.add_option('--repo-url', + dest='repo_url', + help='repo repository location', metavar='URL') +group.add_option('--repo-branch', + dest='repo_branch', + help='repo branch or revision', metavar='REVISION') +group.add_option('--no-repo-verify', + dest='no_repo_verify', action='store_true', + help='do not verify repo source code') + + +class CloneFailure(Exception): + """Indicate the remote clone of repo itself failed. + """ + + +def _Init(args): + """Installs repo by cloning it over the network. + """ + opt, args = init_optparse.parse_args(args) + if args or not opt.manifest_url: + init_optparse.print_usage() + sys.exit(1) + + url = opt.repo_url + if not url: + url = REPO_URL + extra_args.append('--repo-url=%s' % url) + + branch = opt.repo_branch + if not branch: + branch = REPO_REV + extra_args.append('--repo-branch=%s' % branch) + + if branch.startswith('refs/heads/'): + branch = branch[len('refs/heads/'):] + if branch.startswith('refs/'): + print >>sys.stderr, "fatal: invalid branch name '%s'" % branch + raise CloneFailure() + + if not os.path.isdir(repodir): + try: + os.mkdir(repodir) + except OSError, e: + print >>sys.stderr, \ + 'fatal: cannot make %s directory: %s' % ( + repodir, e.strerror) + # Don't faise CloneFailure; that would delete the + # name. Instead exit immediately. + # + sys.exit(1) + + _CheckGitVersion() + try: + if _NeedSetupGnuPG(): + can_verify = _SetupGnuPG(opt.quiet) + else: + can_verify = True + + if not opt.quiet: + print >>sys.stderr, 'Getting repo ...' + print >>sys.stderr, ' from %s' % url + + dst = os.path.abspath(os.path.join(repodir, S_repo)) + _Clone(url, dst, opt.quiet) + + if can_verify and not opt.no_repo_verify: + rev = _Verify(dst, branch, opt.quiet) + else: + rev = 'refs/remotes/origin/%s^0' % branch + + _Checkout(dst, branch, rev, opt.quiet) + except CloneFailure: + if opt.quiet: + print >>sys.stderr, \ + 'fatal: repo init failed; run without --quiet to see why' + raise + + +def _CheckGitVersion(): + cmd = [GIT, '--version'] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE) + ver_str = proc.stdout.read().strip() + proc.stdout.close() + + if not ver_str.startswith('git version '): + print >>sys.stderr, 'error: "%s" unsupported' % ver_str + raise CloneFailure() + + ver_str = ver_str[len('git version '):].strip() + ver_act = tuple(map(lambda x: int(x), ver_str.split('.')[0:3])) + if ver_act < MIN_GIT_VERSION: + need = '.'.join(map(lambda x: str(x), MIN_GIT_VERSION)) + print >>sys.stderr, 'fatal: git %s or later required' % need + raise CloneFailure() + + +def _NeedSetupGnuPG(): + if not os.path.isdir(home_dot_repo): + return True + + kv = os.path.join(home_dot_repo, 'keyring-version') + if not os.path.exists(kv): + return True + + kv = open(kv).read() + if not kv: + return True + + kv = tuple(map(lambda x: int(x), kv.split('.'))) + if kv < KEYRING_VERSION: + return True + return False + + +def _SetupGnuPG(quiet): + if not os.path.isdir(home_dot_repo): + try: + os.mkdir(home_dot_repo) + except OSError, e: + print >>sys.stderr, \ + 'fatal: cannot make %s directory: %s' % ( + home_dot_repo, e.strerror) + sys.exit(1) + + if not os.path.isdir(gpg_dir): + try: + os.mkdir(gpg_dir, 0700) + except OSError, e: + print >>sys.stderr, \ + 'fatal: cannot make %s directory: %s' % ( + gpg_dir, e.strerror) + sys.exit(1) + + env = dict(os.environ) + env['GNUPGHOME'] = gpg_dir + + cmd = ['gpg', '--import'] + try: + proc = subprocess.Popen(cmd, + env = env, + stdin = subprocess.PIPE) + except OSError, e: + if not quiet: + print >>sys.stderr, 'warning: gpg (GnuPG) is not available.' + print >>sys.stderr, 'warning: Installing it is strongly encouraged.' + print >>sys.stderr + return False + + proc.stdin.write(MAINTAINER_KEYS) + proc.stdin.close() + + if proc.wait() != 0: + print >>sys.stderr, 'fatal: registering repo maintainer keys failed' + sys.exit(1) + print + + fd = open(os.path.join(home_dot_repo, 'keyring-version'), 'w') + fd.write('.'.join(map(lambda x: str(x), KEYRING_VERSION)) + '\n') + fd.close() + return True + + +def _SetConfig(local, name, value): + """Set a git configuration option to the specified value. + """ + cmd = [GIT, 'config', name, value] + if subprocess.Popen(cmd, cwd = local).wait() != 0: + raise CloneFailure() + + +def _Fetch(local, quiet, *args): + cmd = [GIT, 'fetch'] + if quiet: + cmd.append('--quiet') + err = subprocess.PIPE + else: + err = None + cmd.extend(args) + cmd.append('origin') + + proc = subprocess.Popen(cmd, cwd = local, stderr = err) + if err: + proc.stderr.read() + proc.stderr.close() + if proc.wait() != 0: + raise CloneFailure() + + +def _Clone(url, local, quiet): + """Clones a git repository to a new subdirectory of repodir + """ + try: + os.mkdir(local) + except OSError, e: + print >>sys.stderr, \ + 'fatal: cannot make %s directory: %s' \ + % (local, e.strerror) + raise CloneFailure() + + cmd = [GIT, 'init', '--quiet'] + try: + proc = subprocess.Popen(cmd, cwd = local) + except OSError, e: + print >>sys.stderr + print >>sys.stderr, "fatal: '%s' is not available" % GIT + print >>sys.stderr, 'fatal: %s' % e + print >>sys.stderr + print >>sys.stderr, 'Please make sure %s is installed'\ + ' and in your path.' % GIT + raise CloneFailure() + if proc.wait() != 0: + print >>sys.stderr, 'fatal: could not create %s' % local + raise CloneFailure() + + _SetConfig(local, 'remote.origin.url', url) + _SetConfig(local, 'remote.origin.fetch', + '+refs/heads/*:refs/remotes/origin/*') + _Fetch(local, quiet) + _Fetch(local, quiet, '--tags') + + +def _Verify(cwd, branch, quiet): + """Verify the branch has been signed by a tag. + """ + cmd = [GIT, 'describe', 'origin/%s' % branch] + proc = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd = cwd) + cur = proc.stdout.read().strip() + proc.stdout.close() + + proc.stderr.read() + proc.stderr.close() + + if proc.wait() != 0 or not cur: + print >>sys.stderr + print >>sys.stderr,\ + "fatal: branch '%s' has not been signed" \ + % branch + raise CloneFailure() + + m = re.compile(r'^(.*)-[0-9]{1,}-g[0-9a-f]{1,}$').match(cur) + if m: + cur = m.group(1) + if not quiet: + print >>sys.stderr + print >>sys.stderr, \ + "info: Ignoring branch '%s'; using tagged release '%s'" \ + % (branch, cur) + print >>sys.stderr + + env = dict(os.environ) + env['GNUPGHOME'] = gpg_dir + + cmd = [GIT, 'tag', '-v', cur] + proc = subprocess.Popen(cmd, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE, + cwd = cwd, + env = env) + out = proc.stdout.read() + proc.stdout.close() + + err = proc.stderr.read() + proc.stderr.close() + + if proc.wait() != 0: + print >>sys.stderr + print >>sys.stderr, out + print >>sys.stderr, err + print >>sys.stderr + raise CloneFailure() + return '%s^0' % cur + + +def _Checkout(cwd, branch, rev, quiet): + """Checkout an upstream branch into the repository and track it. + """ + cmd = [GIT, 'update-ref', 'refs/heads/default', rev] + if subprocess.Popen(cmd, cwd = cwd).wait() != 0: + raise CloneFailure() + + _SetConfig(cwd, 'branch.default.remote', 'origin') + _SetConfig(cwd, 'branch.default.merge', 'refs/heads/%s' % branch) + + cmd = [GIT, 'symbolic-ref', 'HEAD', 'refs/heads/default'] + if subprocess.Popen(cmd, cwd = cwd).wait() != 0: + raise CloneFailure() + + cmd = [GIT, 'read-tree', '--reset', '-u'] + if not quiet: + cmd.append('-v') + cmd.append('HEAD') + if subprocess.Popen(cmd, cwd = cwd).wait() != 0: + raise CloneFailure() + + +def _FindRepo(): + """Look for a repo installation, starting at the current directory. + """ + dir = os.getcwd() + repo = None + + while dir != '/' and not repo: + repo = os.path.join(dir, repodir, REPO_MAIN) + if not os.path.isfile(repo): + repo = None + dir = os.path.dirname(dir) + return (repo, os.path.join(dir, repodir)) + + +class _Options: + help = False + + +def _ParseArguments(args): + cmd = None + opt = _Options() + arg = [] + + for i in xrange(0, len(args)): + a = args[i] + if a == '-h' or a == '--help': + opt.help = True + + elif not a.startswith('-'): + cmd = a + arg = args[i + 1:] + break + return cmd, opt, arg + + +def _Usage(): + print >>sys.stderr,\ +"""usage: repo COMMAND [ARGS] + +repo is not yet installed. Use "repo init" to install it here. + +The most commonly used repo commands are: + + init Install repo in the current working directory + help Display detailed help on a command + +For access to the full online help, install repo ("repo init"). +""" + sys.exit(1) + + +def _Help(args): + if args: + if args[0] == 'init': + init_optparse.print_help() + else: + print >>sys.stderr,\ + "error: '%s' is not a bootstrap command.\n"\ + ' For access to online help, install repo ("repo init").'\ + % args[0] + else: + _Usage() + sys.exit(1) + + +def _NotInstalled(): + print >>sys.stderr,\ +'error: repo is not installed. Use "repo init" to install it here.' + sys.exit(1) + + +def _NoCommands(cmd): + print >>sys.stderr,\ +"""error: command '%s' requires repo to be installed first. + Use "repo init" to install it here.""" % cmd + sys.exit(1) + + +def _RunSelf(wrapper_path): + my_dir = os.path.dirname(wrapper_path) + my_main = os.path.join(my_dir, 'main.py') + my_git = os.path.join(my_dir, '.git') + + if os.path.isfile(my_main) and os.path.isdir(my_git): + for name in ['manifest.py', + 'project.py', + 'subcmds']: + if not os.path.exists(os.path.join(my_dir, name)): + return None, None + return my_main, my_git + return None, None + + +def _SetDefaultsTo(gitdir): + global REPO_URL + global REPO_REV + + REPO_URL = gitdir + proc = subprocess.Popen([GIT, + '--git-dir=%s' % gitdir, + 'symbolic-ref', + 'HEAD'], + stdout = subprocess.PIPE, + stderr = subprocess.PIPE) + REPO_REV = proc.stdout.read().strip() + proc.stdout.close() + + proc.stderr.read() + proc.stderr.close() + + if proc.wait() != 0: + print >>sys.stderr, 'fatal: %s has no current branch' % gitdir + sys.exit(1) + + +def main(orig_args): + main, dir = _FindRepo() + cmd, opt, args = _ParseArguments(orig_args) + + wrapper_path = os.path.abspath(__file__) + my_main, my_git = _RunSelf(wrapper_path) + + if not main: + if opt.help: + _Usage() + if cmd == 'help': + _Help(args) + if not cmd: + _NotInstalled() + if cmd == 'init': + if my_git: + _SetDefaultsTo(my_git) + try: + _Init(args) + except CloneFailure: + for root, dirs, files in os.walk(repodir, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.rmdir(repodir) + sys.exit(1) + main, dir = _FindRepo() + else: + _NoCommands(cmd) + + if my_main: + main = my_main + + ver_str = '.'.join(map(lambda x: str(x), VERSION)) + me = [main, + '--repo-dir=%s' % dir, + '--wrapper-version=%s' % ver_str, + '--wrapper-path=%s' % wrapper_path, + '--'] + me.extend(orig_args) + me.extend(extra_args) + try: + os.execv(main, me) + except OSError, e: + print >>sys.stderr, "fatal: unable to start %s" % main + print >>sys.stderr, "fatal: %s" % e + sys.exit(148) + + +if __name__ == '__main__': + main(sys.argv[1:]) + diff --git a/subcmds/__init__.py b/subcmds/__init__.py new file mode 100644 index 00000000..a2286e78 --- /dev/null +++ b/subcmds/__init__.py @@ -0,0 +1,49 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +all = {} + +my_dir = os.path.dirname(__file__) +for py in os.listdir(my_dir): + if py == '__init__.py': + continue + + if py.endswith('.py'): + name = py[:-3] + + clsn = name.capitalize() + while clsn.find('_') > 0: + h = clsn.index('_') + clsn = clsn[0:h] + clsn[h + 1:].capitalize() + + mod = __import__(__name__, + globals(), + locals(), + ['%s' % name]) + mod = getattr(mod, name) + try: + cmd = getattr(mod, clsn)() + except AttributeError: + raise SyntaxError, '%s/%s does not define class %s' % ( + __name__, py, clsn) + + name = name.replace('_', '-') + cmd.NAME = name + all[name] = cmd + +if 'help' in all: + all['help'].commands = all diff --git a/subcmds/compute_snapshot_check.py b/subcmds/compute_snapshot_check.py new file mode 100644 index 00000000..82db359a --- /dev/null +++ b/subcmds/compute_snapshot_check.py @@ -0,0 +1,169 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tempfile + +from command import Command +from error import GitError, NoSuchProjectError +from git_config import IsId +from import_tar import ImportTar +from import_zip import ImportZip +from project import Project +from remote import Remote + +def _ToCommit(project, rev): + return project.bare_git.rev_parse('--verify', '%s^0' % rev) + +def _Missing(project, rev): + return project._revlist('--objects', rev, '--not', '--all') + + +class ComputeSnapshotCheck(Command): + common = False + helpSummary = "Compute the check value for a new snapshot" + helpUsage = """ +%prog -p NAME -v VERSION -s FILE [options] +""" + helpDescription = """ +%prog computes and then displays the proper check value for a +snapshot, so it can be pasted into the manifest file for a project. +""" + + def _Options(self, p): + g = p.add_option_group('Snapshot description options') + g.add_option('-p', '--project', + dest='project', metavar='NAME', + help='destination project name') + g.add_option('-v', '--version', + dest='version', metavar='VERSION', + help='upstream version/revision identifier') + g.add_option('-s', '--snapshot', + dest='snapshot', metavar='PATH', + help='local tarball path') + g.add_option('--new-project', + dest='new_project', action='store_true', + help='destinition is a new project') + g.add_option('--keep', + dest='keep_git', action='store_true', + help='keep the temporary git repository') + + g = p.add_option_group('Base revision grafting options') + g.add_option('--prior', + dest='prior', metavar='COMMIT', + help='prior revision checksum') + + g = p.add_option_group('Path mangling options') + g.add_option('--strip-prefix', + dest='strip_prefix', metavar='PREFIX', + help='remove prefix from all paths on import') + g.add_option('--insert-prefix', + dest='insert_prefix', metavar='PREFIX', + help='insert prefix before all paths on import') + + + def _Compute(self, opt): + try: + real_project = self.GetProjects([opt.project])[0] + except NoSuchProjectError: + if opt.new_project: + print >>sys.stderr, \ + "warning: project '%s' does not exist" % opt.project + else: + raise NoSuchProjectError(opt.project) + + self._tmpdir = tempfile.mkdtemp() + project = Project(manifest = self.manifest, + name = opt.project, + remote = Remote('origin'), + gitdir = os.path.join(self._tmpdir, '.git'), + worktree = self._tmpdir, + relpath = opt.project, + revision = 'refs/heads/master') + project._InitGitDir() + + url = 'file://%s' % os.path.abspath(opt.snapshot) + + imp = None + for cls in [ImportTar, ImportZip]: + if cls.CanAccept(url): + imp = cls() + break + if not imp: + print >>sys.stderr, 'error: %s unsupported' % opt.snapshot + sys.exit(1) + + imp.SetProject(project) + imp.SetVersion(opt.version) + imp.AddUrl(url) + + if opt.prior: + if opt.new_project: + if not IsId(opt.prior): + print >>sys.stderr, 'error: --prior=%s not valid' % opt.prior + sys.exit(1) + else: + try: + opt.prior = _ToCommit(real_project, opt.prior) + missing = _Missing(real_project, opt.prior) + except GitError, e: + print >>sys.stderr,\ + 'error: --prior=%s not valid\n%s' \ + % (opt.prior, e) + sys.exit(1) + if missing: + print >>sys.stderr,\ + 'error: --prior=%s is valid, but is not reachable' \ + % opt.prior + sys.exit(1) + imp.SetParent(opt.prior) + + src = opt.strip_prefix + dst = opt.insert_prefix + if src or dst: + if src is None: + src = '' + if dst is None: + dst = '' + imp.RemapPath(src, dst) + commitId = imp.Import() + + print >>sys.stderr,"%s\t%s" % (commitId, imp.version) + return project + + def Execute(self, opt, args): + if args \ + or not opt.project \ + or not opt.version \ + or not opt.snapshot: + self.Usage() + + success = False + project = None + try: + self._tmpdir = None + project = self._Compute(opt) + finally: + if project and opt.keep_git: + print 'GIT_DIR = %s' % (project.gitdir) + elif self._tmpdir: + for root, dirs, files in os.walk(self._tmpdir, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.rmdir(self._tmpdir) + diff --git a/subcmds/diff.py b/subcmds/diff.py new file mode 100644 index 00000000..e0247140 --- /dev/null +++ b/subcmds/diff.py @@ -0,0 +1,27 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from command import PagedCommand + +class Diff(PagedCommand): + common = True + helpSummary = "Show changes between commit and working tree" + helpUsage = """ +%prog [...] +""" + + def Execute(self, opt, args): + for project in self.GetProjects(args): + project.PrintWorkTreeDiff() diff --git a/subcmds/forall.py b/subcmds/forall.py new file mode 100644 index 00000000..b22e22a1 --- /dev/null +++ b/subcmds/forall.py @@ -0,0 +1,82 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import os +import sys +import subprocess +from command import Command + +class Forall(Command): + common = False + helpSummary = "Run a shell command in each project" + helpUsage = """ +%prog [...] -c [...] +""" + helpDescription = """ +Executes the same shell command in each project. + +Environment +----------- +pwd is the project's working directory. + +REPO_PROJECT is set to the unique name of the project. + +shell positional arguments ($1, $2, .., $#) are set to any arguments +following . + +stdin, stdout, stderr are inherited from the terminal and are +not redirected. +""" + + def _Options(self, p): + def cmd(option, opt_str, value, parser): + setattr(parser.values, option.dest, list(parser.rargs)) + while parser.rargs: + del parser.rargs[0] + p.add_option('-c', '--command', + help='Command (and arguments) to execute', + dest='command', + action='callback', + callback=cmd) + + def Execute(self, opt, args): + if not opt.command: + self.Usage() + + cmd = [opt.command[0]] + + shell = True + if re.compile(r'^[a-z0-9A-Z_/\.-]+$').match(cmd[0]): + shell = False + + if shell: + cmd.append(cmd[0]) + cmd.extend(opt.command[1:]) + + rc = 0 + for project in self.GetProjects(args): + env = dict(os.environ.iteritems()) + env['REPO_PROJECT'] = project.name + + p = subprocess.Popen(cmd, + cwd = project.worktree, + shell = shell, + env = env) + r = p.wait() + if r != 0 and r != rc: + rc = r + if rc != 0: + sys.exit(rc) diff --git a/subcmds/help.py b/subcmds/help.py new file mode 100644 index 00000000..6e0238a0 --- /dev/null +++ b/subcmds/help.py @@ -0,0 +1,147 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from formatter import AbstractFormatter, DumbWriter + +from color import Coloring +from command import PagedCommand + +class Help(PagedCommand): + common = False + helpSummary = "Display detailed help on a command" + helpUsage = """ +%prog [--all|command] +""" + helpDescription = """ +Displays detailed usage information about a command. +""" + + def _PrintAllCommands(self): + print 'usage: repo COMMAND [ARGS]' + print """ +The complete list of recognized repo commands are: +""" + commandNames = self.commands.keys() + commandNames.sort() + + maxlen = 0 + for name in commandNames: + maxlen = max(maxlen, len(name)) + fmt = ' %%-%ds %%s' % maxlen + + for name in commandNames: + command = self.commands[name] + try: + summary = command.helpSummary.strip() + except AttributeError: + summary = '' + print fmt % (name, summary) + print """ +See 'repo help ' for more information on a specific command. +""" + + def _PrintCommonCommands(self): + print 'usage: repo COMMAND [ARGS]' + print """ +The most commonly used repo commands are: +""" + commandNames = [name + for name in self.commands.keys() + if self.commands[name].common] + commandNames.sort() + + maxlen = 0 + for name in commandNames: + maxlen = max(maxlen, len(name)) + fmt = ' %%-%ds %%s' % maxlen + + for name in commandNames: + command = self.commands[name] + try: + summary = command.helpSummary.strip() + except AttributeError: + summary = '' + print fmt % (name, summary) + print """ +See 'repo help ' for more information on a specific command. +""" + + def _PrintCommandHelp(self, cmd): + class _Out(Coloring): + def __init__(self, gc): + Coloring.__init__(self, gc, 'help') + self.heading = self.printer('heading', attr='bold') + + self.wrap = AbstractFormatter(DumbWriter()) + + def _PrintSection(self, heading, bodyAttr): + try: + body = getattr(cmd, bodyAttr) + except AttributeError: + return + + self.nl() + + self.heading('%s', heading) + self.nl() + + self.heading('%s', ''.ljust(len(heading), '-')) + self.nl() + + me = 'repo %s' % cmd.NAME + body = body.strip() + body = body.replace('%prog', me) + + for para in body.split("\n\n"): + if para.startswith(' '): + self.write('%s', para) + self.nl() + self.nl() + else: + self.wrap.add_flowing_data(para) + self.wrap.end_paragraph(1) + self.wrap.end_paragraph(0) + + out = _Out(self.manifest.globalConfig) + cmd.OptionParser.print_help() + out._PrintSection('Summary', 'helpSummary') + out._PrintSection('Description', 'helpDescription') + + def _Options(self, p): + p.add_option('-a', '--all', + dest='show_all', action='store_true', + help='show the complete list of commands') + + def Execute(self, opt, args): + if len(args) == 0: + if opt.show_all: + self._PrintAllCommands() + else: + self._PrintCommonCommands() + + elif len(args) == 1: + name = args[0] + + try: + cmd = self.commands[name] + except KeyError: + print >>sys.stderr, "repo: '%s' is not a repo command." % name + sys.exit(1) + + self._PrintCommandHelp(cmd) + + else: + self._PrintCommandHelp(self) diff --git a/subcmds/init.py b/subcmds/init.py new file mode 100644 index 00000000..03f358d1 --- /dev/null +++ b/subcmds/init.py @@ -0,0 +1,193 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +from color import Coloring +from command import InteractiveCommand +from error import ManifestParseError +from remote import Remote +from git_command import git, MIN_GIT_VERSION + +class Init(InteractiveCommand): + common = True + helpSummary = "Initialize repo in the current directory" + helpUsage = """ +%prog [options] +""" + helpDescription = """ +The '%prog' command is run once to install and initialize repo. +The latest repo source code and manifest collection is downloaded +from the server and is installed in the .repo/ directory in the +current working directory. + +The optional argument can be used to specify an alternate +manifest to be used. If no manifest is specified, the manifest +default.xml will be used. +""" + + def _Options(self, p): + # Logging + g = p.add_option_group('Logging options') + g.add_option('-q', '--quiet', + dest="quiet", action="store_true", default=False, + help="be quiet") + + # Manifest + g = p.add_option_group('Manifest options') + g.add_option('-u', '--manifest-url', + dest='manifest_url', + help='manifest repository location', metavar='URL') + g.add_option('-b', '--manifest-branch', + dest='manifest_branch', + help='manifest branch or revision', metavar='REVISION') + g.add_option('-m', '--manifest-name', + dest='manifest_name', default='default.xml', + help='initial manifest file', metavar='NAME.xml') + + # Tool + g = p.add_option_group('Version options') + g.add_option('--repo-url', + dest='repo_url', + help='repo repository location', metavar='URL') + g.add_option('--repo-branch', + dest='repo_branch', + help='repo branch or revision', metavar='REVISION') + g.add_option('--no-repo-verify', + dest='no_repo_verify', action='store_true', + help='do not verify repo source code') + + def _CheckGitVersion(self): + ver_str = git.version() + if not ver_str.startswith('git version '): + print >>sys.stderr, 'error: "%s" unsupported' % ver_str + sys.exit(1) + + ver_str = ver_str[len('git version '):].strip() + ver_act = tuple(map(lambda x: int(x), ver_str.split('.')[0:3])) + if ver_act < MIN_GIT_VERSION: + need = '.'.join(map(lambda x: str(x), MIN_GIT_VERSION)) + print >>sys.stderr, 'fatal: git %s or later required' % need + sys.exit(1) + + def _SyncManifest(self, opt): + m = self.manifest.manifestProject + + if not m.Exists: + if not opt.manifest_url: + print >>sys.stderr, 'fatal: manifest url (-u) is required.' + sys.exit(1) + + if not opt.quiet: + print >>sys.stderr, 'Getting manifest ...' + print >>sys.stderr, ' from %s' % opt.manifest_url + m._InitGitDir() + + if opt.manifest_branch: + m.revision = opt.manifest_branch + else: + m.revision = 'refs/heads/master' + else: + if opt.manifest_branch: + m.revision = opt.manifest_branch + else: + m.PreSync() + + if opt.manifest_url: + r = m.GetRemote(m.remote.name) + r.url = opt.manifest_url + r.ResetFetch() + r.Save() + + m.Sync_NetworkHalf() + m.Sync_LocalHalf() + m.StartBranch('default') + + def _LinkManifest(self, name): + if not name: + print >>sys.stderr, 'fatal: manifest name (-m) is required.' + sys.exit(1) + + try: + self.manifest.Link(name) + except ManifestParseError, e: + print >>sys.stderr, "fatal: manifest '%s' not available" % name + print >>sys.stderr, 'fatal: %s' % str(e) + sys.exit(1) + + def _PromptKey(self, prompt, key, value): + mp = self.manifest.manifestProject + + sys.stdout.write('%-10s [%s]: ' % (prompt, value)) + a = sys.stdin.readline().strip() + if a != '' and a != value: + mp.config.SetString(key, a) + + def _ConfigureUser(self): + mp = self.manifest.manifestProject + + print '' + self._PromptKey('Your Name', 'user.name', mp.UserName) + self._PromptKey('Your Email', 'user.email', mp.UserEmail) + + def _HasColorSet(self, gc): + for n in ['ui', 'diff', 'status']: + if gc.Has('color.%s' % n): + return True + return False + + def _ConfigureColor(self): + gc = self.manifest.globalConfig + if self._HasColorSet(gc): + return + + class _Test(Coloring): + def __init__(self): + Coloring.__init__(self, gc, 'test color display') + self._on = True + out = _Test() + + print '' + print "Testing colorized output (for 'repo diff', 'repo status'):" + + for c in ['black','red','green','yellow','blue','magenta','cyan']: + out.write(' ') + out.printer(fg=c)(' %-6s ', c) + out.write(' ') + out.printer(fg='white', bg='black')(' %s ' % 'white') + out.nl() + + for c in ['bold','dim','ul','reverse']: + out.write(' ') + out.printer(fg='black', attr=c)(' %-6s ', c) + out.nl() + + sys.stdout.write('Enable color display in this user account (y/n)? ') + a = sys.stdin.readline().strip().lower() + if a in ('y', 'yes', 't', 'true', 'on'): + gc.SetString('color.ui', 'auto') + + def Execute(self, opt, args): + self._CheckGitVersion() + self._SyncManifest(opt) + self._LinkManifest(opt.manifest_name) + + if os.isatty(0) and os.isatty(1): + self._ConfigureUser() + self._ConfigureColor() + + print '' + print 'repo initialized in %s' % self.manifest.topdir diff --git a/subcmds/prune.py b/subcmds/prune.py new file mode 100644 index 00000000..f412bd48 --- /dev/null +++ b/subcmds/prune.py @@ -0,0 +1,59 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from color import Coloring +from command import PagedCommand + +class Prune(PagedCommand): + common = True + helpSummary = "Prune (delete) already merged topics" + helpUsage = """ +%prog [...] +""" + + def Execute(self, opt, args): + all = [] + for project in self.GetProjects(args): + all.extend(project.PruneHeads()) + + if not all: + return + + class Report(Coloring): + def __init__(self, config): + Coloring.__init__(self, config, 'status') + self.project = self.printer('header', attr='bold') + + out = Report(all[0].project.config) + out.project('Pending Branches') + out.nl() + + project = None + + for branch in all: + if project != branch.project: + project = branch.project + out.nl() + out.project('project %s/' % project.relpath) + out.nl() + + commits = branch.commits + date = branch.date + print '%s %-33s (%2d commit%s, %s)' % ( + branch.name == project.CurrentBranch and '*' or ' ', + branch.name, + len(commits), + len(commits) != 1 and 's' or ' ', + date) diff --git a/subcmds/stage.py b/subcmds/stage.py new file mode 100644 index 00000000..c451cd6d --- /dev/null +++ b/subcmds/stage.py @@ -0,0 +1,108 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from color import Coloring +from command import InteractiveCommand +from git_command import GitCommand + +class _ProjectList(Coloring): + def __init__(self, gc): + Coloring.__init__(self, gc, 'interactive') + self.prompt = self.printer('prompt', fg='blue', attr='bold') + self.header = self.printer('header', attr='bold') + self.help = self.printer('help', fg='red', attr='bold') + +class Stage(InteractiveCommand): + common = True + helpSummary = "Stage file(s) for commit" + helpUsage = """ +%prog -i [...] +""" + helpDescription = """ +The '%prog' command stages files to prepare the next commit. +""" + + def _Options(self, p): + p.add_option('-i', '--interactive', + dest='interactive', action='store_true', + help='use interactive staging') + + def Execute(self, opt, args): + if opt.interactive: + self._Interactive(opt, args) + else: + self.Usage() + + def _Interactive(self, opt, args): + all = filter(lambda x: x.IsDirty(), self.GetProjects(args)) + if not all: + print >>sys.stderr,'no projects have uncommitted modifications' + return + + out = _ProjectList(self.manifest.manifestProject.config) + while True: + out.header(' %-20s %s', 'project', 'path') + out.nl() + + for i in xrange(0, len(all)): + p = all[i] + out.write('%3d: %-20s %s', i + 1, p.name, p.relpath + '/') + out.nl() + out.nl() + + out.write('%3d: (', 0) + out.prompt('q') + out.write('uit)') + out.nl() + + out.prompt('project> ') + try: + a = sys.stdin.readline() + except KeyboardInterrupt: + out.nl() + break + if a == '': + out.nl() + break + + a = a.strip() + if a.lower() in ('q', 'quit', 'exit'): + break + if not a: + continue + + try: + a_index = int(a) + except ValueError: + a_index = None + + if a_index is not None: + if a_index == 0: + break + if 0 < a_index and a_index <= len(all): + _AddI(all[a_index - 1]) + continue + + p = filter(lambda x: x.name == a or x.relpath == a, all) + if len(p) == 1: + _AddI(p[0]) + continue + print 'Bye.' + +def _AddI(project): + p = GitCommand(project, ['add', '--interactive'], bare=False) + p.Wait() diff --git a/subcmds/start.py b/subcmds/start.py new file mode 100644 index 00000000..4eb3e476 --- /dev/null +++ b/subcmds/start.py @@ -0,0 +1,51 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from command import Command +from git_command import git + +class Start(Command): + common = True + helpSummary = "Start a new branch for development" + helpUsage = """ +%prog [...] + +This subcommand starts a new branch of development that is automatically +pulled from a remote branch. + +It is equivalent to the following git commands: + +"git branch --track m/", +or +"git checkout --track -b m/". + +All three forms set up the config entries that repo bases some of its +processing on. Use %prog or git branch or checkout with --track to ensure +the configuration data is set up properly. + +""" + + def Execute(self, opt, args): + if not args: + self.Usage() + + nb = args[0] + if not git.check_ref_format('heads/%s' % nb): + print >>sys.stderr, "error: '%s' is not a valid name" % nb + sys.exit(1) + + for project in self.GetProjects(args[1:]): + project.StartBranch(nb) diff --git a/subcmds/status.py b/subcmds/status.py new file mode 100644 index 00000000..1615b423 --- /dev/null +++ b/subcmds/status.py @@ -0,0 +1,27 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from command import PagedCommand + +class Status(PagedCommand): + common = True + helpSummary = "Show the working tree status" + helpUsage = """ +%prog [...] +""" + + def Execute(self, opt, args): + for project in self.GetProjects(args): + project.PrintWorkTreeStatus() diff --git a/subcmds/sync.py b/subcmds/sync.py new file mode 100644 index 00000000..3eb44edf --- /dev/null +++ b/subcmds/sync.py @@ -0,0 +1,150 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import subprocess +import sys + +from git_command import GIT +from command import Command +from error import RepoChangedException, GitError +from project import R_HEADS + +class Sync(Command): + common = True + helpSummary = "Update working tree to the latest revision" + helpUsage = """ +%prog [...] +""" + helpDescription = """ +The '%prog' command synchronizes local project directories +with the remote repositories specified in the manifest. If a local +project does not yet exist, it will clone a new local directory from +the remote repository and set up tracking branches as specified in +the manifest. If the local project already exists, '%prog' +will update the remote branches and rebase any new local changes +on top of the new remote changes. + +'%prog' will synchronize all projects listed at the command +line. Projects can be specified either by name, or by a relative +or absolute path to the project's local directory. If no projects +are specified, '%prog' will synchronize all projects listed in +the manifest. +""" + + def _Options(self, p): + p.add_option('--no-repo-verify', + dest='no_repo_verify', action='store_true', + help='do not verify repo source code') + + def _Fetch(self, *projects): + fetched = set() + for project in projects: + if project.Sync_NetworkHalf(): + fetched.add(project.gitdir) + else: + print >>sys.stderr, 'error: Cannot fetch %s' % project.name + sys.exit(1) + return fetched + + def Execute(self, opt, args): + rp = self.manifest.repoProject + rp.PreSync() + + mp = self.manifest.manifestProject + mp.PreSync() + + all = self.GetProjects(args, missing_ok=True) + fetched = self._Fetch(rp, mp, *all) + + if rp.HasChanges: + print >>sys.stderr, 'info: A new version of repo is available' + print >>sys.stderr, '' + if opt.no_repo_verify or _VerifyTag(rp): + if not rp.Sync_LocalHalf(): + sys.exit(1) + print >>sys.stderr, 'info: Restarting repo with latest version' + raise RepoChangedException() + else: + print >>sys.stderr, 'warning: Skipped upgrade to unverified version' + + if mp.HasChanges: + if not mp.Sync_LocalHalf(): + sys.exit(1) + + self.manifest._Unload() + all = self.GetProjects(args, missing_ok=True) + missing = [] + for project in all: + if project.gitdir not in fetched: + missing.append(project) + self._Fetch(*missing) + + for project in all: + if not project.Sync_LocalHalf(): + sys.exit(1) + + +def _VerifyTag(project): + gpg_dir = os.path.expanduser('~/.repoconfig/gnupg') + if not os.path.exists(gpg_dir): + print >>sys.stderr,\ +"""warning: GnuPG was not available during last "repo init" +warning: Cannot automatically authenticate repo.""" + return True + + remote = project.GetRemote(project.remote.name) + ref = remote.ToLocal(project.revision) + + try: + cur = project.bare_git.describe(ref) + except GitError: + cur = None + + if not cur \ + or re.compile(r'^.*-[0-9]{1,}-g[0-9a-f]{1,}$').match(cur): + rev = project.revision + if rev.startswith(R_HEADS): + rev = rev[len(R_HEADS):] + + print >>sys.stderr + print >>sys.stderr,\ + "warning: project '%s' branch '%s' is not signed" \ + % (project.name, rev) + return False + + env = dict(os.environ) + env['GIT_DIR'] = project.gitdir + env['GNUPGHOME'] = gpg_dir + + cmd = [GIT, 'tag', '-v', cur] + proc = subprocess.Popen(cmd, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE, + env = env) + out = proc.stdout.read() + proc.stdout.close() + + err = proc.stderr.read() + proc.stderr.close() + + if proc.wait() != 0: + print >>sys.stderr + print >>sys.stderr, out + print >>sys.stderr, err + print >>sys.stderr + return False + return True diff --git a/subcmds/upload.py b/subcmds/upload.py new file mode 100644 index 00000000..ad05050e --- /dev/null +++ b/subcmds/upload.py @@ -0,0 +1,180 @@ +# +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys + +from command import InteractiveCommand +from editor import Editor +from error import UploadError + +def _die(fmt, *args): + msg = fmt % args + print >>sys.stderr, 'error: %s' % msg + sys.exit(1) + +class Upload(InteractiveCommand): + common = True + helpSummary = "Upload changes for code review" + helpUsage=""" +%prog []... +""" + helpDescription = """ +The '%prog' command is used to send changes to the Gerrit code +review system. It searches for changes in local projects that do +not yet exist in the corresponding remote repository. If multiple +changes are found, '%prog' opens an editor to allow the +user to choose which change to upload. After a successful upload, +repo prints the URL for the change in the Gerrit code review system. + +'%prog' searches for uploadable changes in all projects listed +at the command line. Projects can be specified either by name, or +by a relative or absolute path to the project's local directory. If +no projects are specified, '%prog' will search for uploadable +changes in all projects listed in the manifest. +""" + + def _SingleBranch(self, branch): + project = branch.project + name = branch.name + date = branch.date + list = branch.commits + + print 'Upload project %s/:' % project.relpath + print ' branch %s (%2d commit%s, %s):' % ( + name, + len(list), + len(list) != 1 and 's' or '', + date) + for commit in list: + print ' %s' % commit + + sys.stdout.write('(y/n)? ') + answer = sys.stdin.readline().strip() + if answer in ('y', 'Y', 'yes', '1', 'true', 't'): + self._UploadAndReport([branch]) + else: + _die("upload aborted by user") + + def _MultipleBranches(self, pending): + projects = {} + branches = {} + + script = [] + script.append('# Uncomment the branches to upload:') + for project, avail in pending: + script.append('#') + script.append('# project %s/:' % project.relpath) + + b = {} + for branch in avail: + name = branch.name + date = branch.date + list = branch.commits + + if b: + script.append('#') + script.append('# branch %s (%2d commit%s, %s):' % ( + name, + len(list), + len(list) != 1 and 's' or '', + date)) + for commit in list: + script.append('# %s' % commit) + b[name] = branch + + projects[project.relpath] = project + branches[project.name] = b + script.append('') + + script = Editor.EditString("\n".join(script)).split("\n") + + project_re = re.compile(r'^#?\s*project\s*([^\s]+)/:$') + branch_re = re.compile(r'^\s*branch\s*([^\s(]+)\s*\(.*') + + project = None + todo = [] + + for line in script: + m = project_re.match(line) + if m: + name = m.group(1) + project = projects.get(name) + if not project: + _die('project %s not available for upload', name) + continue + + m = branch_re.match(line) + if m: + name = m.group(1) + if not project: + _die('project for branch %s not in script', name) + branch = branches[project.name].get(name) + if not branch: + _die('branch %s not in %s', name, project.relpath) + todo.append(branch) + if not todo: + _die("nothing uncommented for upload") + self._UploadAndReport(todo) + + def _UploadAndReport(self, todo): + have_errors = False + for branch in todo: + try: + branch.UploadForReview() + branch.uploaded = True + except UploadError, e: + branch.error = e + branch.uploaded = False + have_errors = True + + print >>sys.stderr, '' + print >>sys.stderr, '--------------------------------------------' + + if have_errors: + for branch in todo: + if not branch.uploaded: + print >>sys.stderr, '[FAILED] %-15s %-15s (%s)' % ( + branch.project.relpath + '/', \ + branch.name, \ + branch.error) + print >>sys.stderr, '' + + for branch in todo: + if branch.uploaded: + print >>sys.stderr, '[OK ] %-15s %s' % ( + branch.project.relpath + '/', + branch.name) + print >>sys.stderr, '%s' % branch.tip_url + print >>sys.stderr, '' + + if have_errors: + sys.exit(1) + + def Execute(self, opt, args): + project_list = self.GetProjects(args) + pending = [] + + for project in project_list: + avail = project.GetUploadableBranches() + if avail: + pending.append((project, avail)) + + if not pending: + print >>sys.stdout, "no branches ready for upload" + elif len(pending) == 1 and len(pending[0][1]) == 1: + self._SingleBranch(pending[0][1][0]) + else: + self._MultipleBranches(pending)