-
-
Notifications
You must be signed in to change notification settings - Fork 9
/
ext_http_server.py
286 lines (243 loc) · 10 KB
/
ext_http_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#!/usr/bin/env python
"""A small set of improvements upon the Simple and BaseHTTPServers."""
import SocketServer
import base64
import errno
import os
import socket
import ssl
import sys
import threading
import time
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from SimpleHTTPServer import SimpleHTTPRequestHandler
from optparse import OptionParser
from warnings import warn
# Version string, reported by the --version option that main() configures.
__version__ = '0.2'
#
# Helpers
#
class RateLimitWriter(object):
    """A file-like wrapper that rate limits writes to the wrapped stream.

    The throttling state lives in class attributes, so every instance draws
    from one shared budget: the limit applies to the total written by all
    wrappers together. That state is guarded by a ``threading.Lock`` inside a
    single process, so this class only supports threading and not forking
    (multiprocessing).

    Time is split into intervals of ``INTERVAL_LEN`` seconds and at most
    ``block_size`` bytes are handed out per interval.
    """

    INTERVAL_LEN = .125      # length of one throttling interval, in seconds
    block_size = 16384       # bytes allowed per interval; see set_rate_limit()
    lock = threading.Lock()  # guards block_start and block_sent
    block_start = None       # time.time() at which the current interval began
    block_sent = 0           # bytes already handed out in the current interval

    @classmethod
    def bytes_to_write(cls, desired):
        """Reserve part of the shared budget and return how much may be written.

        When the current interval's budget is used up this sleeps until the
        interval ends. The sleep happens while holding ``lock``, so other
        writers wait as well.

        :param desired: number of bytes the caller would like to write.
        :returns: a positive byte count no larger than ``desired`` (given
            ``block_size >= 1``), or ``0`` when ``desired`` is not positive.
        """
        if desired <= 0:
            # No grant could ever be non-zero, so the loop would spin forever.
            return 0
        to_send = 0
        while not to_send:
            # ``with`` guarantees the lock is released even if sleep/time fail.
            with cls.lock:
                now = time.time()
                if cls.block_start is None:
                    # First data of a new interval: start its clock.
                    cls.block_start = now
                    to_send = min(desired, cls.block_size)
                    cls.block_sent = to_send
                elif cls.block_sent < cls.block_size:
                    # Interval budget not exhausted: hand out the remainder.
                    to_send = min(desired, cls.block_size - cls.block_sent)
                    cls.block_sent += to_send
                else:
                    # Budget exhausted: wait out the rest of the interval,
                    # then reset so the next pass starts a fresh interval.
                    sleep_time = cls.INTERVAL_LEN - (now - cls.block_start)
                    if sleep_time > 0:
                        time.sleep(sleep_time)
                    cls.block_start = None
                    cls.block_sent = 0
        return to_send

    @classmethod
    def set_rate_limit(cls, limit):
        """Set the rate limit in kilobytes per second.

        The per-interval budget is clamped to at least one byte: a budget of
        zero (e.g. ``limit == 0``) would make bytes_to_write() loop forever
        without ever granting data.
        """
        cls.block_size = max(1, int(1024 * limit * cls.INTERVAL_LEN))

    def __init__(self, to_wrap):
        """Store the output stream we are wrapping."""
        self.wrapped = to_wrap

    def __getattr__(self, attr):
        """Delegate any attribute not defined here to the wrapped stream."""
        return getattr(self.wrapped, attr)

    def write(self, message):
        """Write ``message`` to the wrapped stream in throttled chunks."""
        while message:
            to_send = RateLimitWriter.bytes_to_write(len(message))
            self.wrapped.write(message[:to_send])
            message = message[to_send:]
#
# HTTPServer extensions
#
class SecureHTTPServer(HTTPServer, object):
    """A HTTP server object that supports HTTPS.

    The listening socket created by ``HTTPServer`` is wrapped for
    server-side TLS using the certificate in ``cert_file``.
    """

    def __init__(self, address, handler, cert_file):
        """Bind via HTTPServer, then wrap the listening socket for TLS.

        :param address: address tuple passed to ``HTTPServer``.
        :param handler: request handler class passed to ``HTTPServer``.
        :param cert_file: path to the certificate file; no separate key file
            is passed, so the private key is expected in the same file.
        :raises: whatever the ssl module raises for an unusable certificate;
            the plain listening socket is closed first.
        """
        super(SecureHTTPServer, self).__init__(address, handler)
        try:
            self.socket = self._wrap_socket(self.socket, cert_file)
        except Exception:
            # Don't leave the bound, unencrypted socket open on failure.
            self.server_close()
            raise

    @staticmethod
    def _wrap_socket(sock, cert_file):
        """Return ``sock`` wrapped for server-side TLS using ``cert_file``."""
        create_context = getattr(ssl, 'create_default_context', None)
        if create_context is None:
            # Interpreters without SSLContext helpers (before Python 2.7.9)
            # only offer the module-level function.
            return ssl.wrap_socket(sock, certfile=cert_file, server_side=True)
        # A server-side default context uses the ssl module's recommended
        # protocol/cipher settings instead of the legacy wrap_socket defaults.
        context = create_context(ssl.Purpose.CLIENT_AUTH)
        context.load_cert_chain(certfile=cert_file)
        return context.wrap_socket(sock, server_side=True)
#
# BaseHTTPRequestHandler extensions
#
class AuthHandler(BaseHTTPRequestHandler, object):
    """A handler that supports basic HTTP authentication/authorization.

    Credentials are stored base64 encoded. This is the same form a client
    sends after ``Basic`` in its ``Authorization`` header, so checking a
    request is a set lookup.
    """

    message = 'Authentication required.'  # body of the 401 response
    realm = 'Something'                    # realm advertised in the challenge
    users = set()                          # base64('user:password') tokens

    @classmethod
    def add_user(cls, username, password):
        """Add a set of credentials.

        :param username: user name; the first ``:`` of the encoded
            ``username:password`` string separates it from the password.
        :param password: password for ``username``.
        """
        cls.users.add(base64.b64encode('{0}:{1}'.format(username, password)))

    def handle_auth(self, head=False):
        """Check the request's credentials, sending a 401 if they are invalid.

        :param head: True for HEAD requests; the 401 body is then omitted.
        :returns: True if the request may proceed, False if a 401 was sent.
        """
        # ``get`` works on both the Python 2 and Python 3 header objects.
        if self._has_valid_credentials(self.headers.get('Authorization')):
            return True
        # Send authentication header information
        self.send_response(401)
        self.send_header('WWW-Authenticate',
                         'Basic realm="{0}"'.format(self.realm))
        self.send_header('Content-Type', 'text/html')
        self.send_header('Content-Length', len(self.message))
        self.end_headers()
        if not head:
            self.wfile.write(self.message)
        return False

    def _has_valid_credentials(self, auth):
        """Return True if ``auth`` is ``Basic <token>`` with a registered token.

        The scheme is compared case-insensitively and surrounding whitespace
        is ignored. Any scheme other than Basic is rejected, even if its
        token happens to match.
        """
        if not auth:
            return False
        try:
            scheme, encoded = auth.strip().split(None, 1)
        except ValueError:
            # Only a scheme, no credentials.
            return False
        return scheme.lower() == 'basic' and encoded in self.users

    def do_GET(self):
        """Call the parent's do_GET function if the user is authorized."""
        if self.handle_auth():
            super(AuthHandler, self).do_GET()

    def do_HEAD(self):
        """Call the parent's do_HEAD function if the user is authorized."""
        if self.handle_auth(head=True):
            super(AuthHandler, self).do_HEAD()
class RangeHandler(SimpleHTTPRequestHandler, object):
    """A handler that supports HTTP requests with the Range header.

    The Range header allows for the resume download functionality. Only a
    single open-ended byte range (``Range: bytes=N-``) is honoured. Any other
    form is ignored (with a warning) and the whole file is served as usual.
    """

    def copyfile(self, source, outputfile):
        """Copy only the ranged part of the file when appropriate."""
        if self.is_ranged:
            # Skip the bytes the client already has.
            source.seek(self.range_begin)
        super(RangeHandler, self).copyfile(source, outputfile)

    def do_GET(self):
        """Serve a GET, honouring a valid Range header."""
        self.handle_range()
        if not self._reject_unsatisfiable_range():
            super(RangeHandler, self).do_GET()

    def do_HEAD(self):
        """Serve a HEAD, honouring a valid Range header."""
        self.handle_range()
        if not self._reject_unsatisfiable_range():
            super(RangeHandler, self).do_HEAD()

    def handle_range(self):
        """Parse the Range header into is_ranged / range_begin / range_end.

        ``is_ranged`` is reset on every call, so a previous request on the
        same keep-alive connection does not affect this one.
        """
        self.is_ranged = False
        if 'range' in self.headers:
            try:
                range_unit, other = self.headers['range'].split('=', 1)
                if range_unit == 'bytes':
                    if ',' in other:  # Handle only a single range
                        warn('Multiple ranges are not supported.')
                        return
                    begin, end = other.split('-', 1)
                    if end:
                        warn('Shortened ranges are not supported.')
                        return
                    self.range_begin = int(begin) if begin else 0
                    self.range_end = None
                    self.is_ranged = True
            except ValueError:
                # Malformed header: fall back to serving the whole file.
                pass

    def _reject_unsatisfiable_range(self):
        """Send 416 and return True when the range starts at or past EOF.

        Only regular files are checked. For anything else (directories,
        missing paths) this returns False and the normal handler runs.
        """
        if not self.is_ranged:
            return False
        path = self.translate_path(self.path)
        try:
            if (not os.path.isfile(path)
                    or self.range_begin < os.path.getsize(path)):
                return False
        except OSError:
            return False
        # The error itself is a plain, un-ranged response.
        self.is_ranged = False
        self.send_error(416)
        return True

    def send_header(self, key, value):
        """Adjust Content-Length and add Content-Range on 206 responses."""
        if (key.lower() == 'content-length'
                and getattr(self, '_partial_response', False)):
            length = int(value)
            if self.range_end is None:
                end = length - 1
            else:
                end = min(self.range_end, length - 1)
            value = str(1 + end - self.range_begin)
            self.send_header('Content-Range', 'bytes {0}-{1}/{2}'
                             .format(self.range_begin, end, length))
        super(RangeHandler, self).send_header(key, value)

    def send_response(self, status, *args, **kwargs):
        """Send 206 instead of 200 for ranged responses.

        Records whether this response is partial, so that send_header() only
        rewrites Content-Length for it. Error or 401 responses are left
        alone, even when ``is_ranged`` is still set.
        """
        self._partial_response = self.is_ranged and status == 200
        if self._partial_response:
            status = 206
        super(RangeHandler, self).send_response(status, *args, **kwargs)

    def setup(self):
        """Set HTTP/1.1 as Range is supported only on HTTP/1.1."""
        super(RangeHandler, self).setup()
        self.protocol_version = 'HTTP/1.1'
        self.is_ranged = False
        self._partial_response = False
class RateLimitHandler(BaseHTTPRequestHandler, object):
    """A handler that throttles data sent from the server to the client.

    Throttling is delegated to RateLimitWriter. Its shared state lives in a
    single process behind a threading lock, so this works well together with
    the ThreadingMixIn. It will not rate limit properly if a ForkingMixIn is
    used in the HTTPServer object.
    """

    def handle(self):
        """Swap wfile for a throttled wrapper, then handle the request."""
        throttled = RateLimitWriter(self.wfile)
        self.wfile = throttled
        super(RateLimitHandler, self).handle()
#
# Combined classes for use with the main functionality
#
class MyHandler(AuthHandler, RangeHandler, RateLimitHandler):
    """Request handler combining auth, download resuming and throttling.

    Base order matters here. AuthHandler's do_GET/do_HEAD come first in the
    MRO and delegate through super() to RangeHandler only once the user is
    authorized. RateLimitHandler contributes handle(), which wraps the
    output stream.
    """
    pass
class MyServer(SocketServer.ThreadingMixIn, SecureHTTPServer):
    """A threaded SecureHTTPServer with basic error filtering.

    Errors caused by a client going away are reported as a one-line message
    instead of a traceback. Everything else goes to the default handler.
    """

    # errno values of socket errors that mean the peer closed the connection.
    CLOSED_CONNECTION_ERRNOS = (errno.EPIPE, errno.ECONNRESET)

    def handle_error(self, request, client_address):
        """Disable tracebacks on connection close errors.

        Inspects the exception currently being handled via sys.exc_info().
        """
        exc_type, exc_value = sys.exc_info()[:2]
        # Check SSLError first: it derives from socket.error (OSError on
        # Python 3), but its errno is an SSL error code, not an OS errno.
        if exc_type is not None and issubclass(exc_type, ssl.SSLError):
            if exc_value.errno == 1:
                print('{0} SSL Error: bad write retry'.format(client_address))
                return
        elif exc_type is not None and issubclass(exc_type, socket.error):
            # Read .errno instead of indexing: exceptions are not indexable on
            # Python 3, and an argument-less socket.error has no [0].
            if exc_value.errno in self.CLOSED_CONNECTION_ERRNOS:
                print('{0} closed connection'.format(client_address))
                return
        super(MyServer, self).handle_error(request, client_address)
def main():
    """Run a secure threaded server with auth, resume and rate limit support.

    Steps:
    - read the command line;
    - register the users given with ``--add-auth`` and apply ``--ratelimit``;
    - check ``--cert`` and optionally change into ``--directory``;
    - serve until interrupted with Ctrl-C.

    Invalid arguments are reported through ``parser.error``, which exits.
    """
    parser = OptionParser(version='%prog {0}'.format(__version__))
    parser.add_option('-p', '--port', type='int', default=8000)
    parser.add_option('-c', '--cert', help='The TLS/SSL certificate file')
    parser.add_option('-d', '--directory', help='The directory to serve')
    parser.add_option('-r', '--ratelimit', help='The ratelimit in KBps',
                      type='int', default=128)
    parser.add_option('-a', '--add-auth', help='Add user:password combination',
                      action='append')
    options, _ = parser.parse_args()

    if not 0 <= options.port <= 65535:
        parser.error('--port must be between 0 and 65535')

    # Configure Services
    if not options.add_auth:
        parser.error('At least one user must be added via --add-auth')
    for auth in options.add_auth:
        try:
            username, password = auth.split(':', 1)
        except ValueError:
            # parser.error() exits, so the unbound names are never used.
            parser.error('{0!r} is not a valid username:password'.format(auth))
        AuthHandler.add_user(username, password)
    if options.ratelimit <= 0:
        # set_rate_limit() turns this into a per-interval byte budget; a
        # zero or negative budget leaves RateLimitWriter unable to progress.
        parser.error('--ratelimit must be a positive number of KBps')
    RateLimitWriter.set_rate_limit(options.ratelimit)

    # Verify cert file. It is resolved before any chdir below, so a relative
    # path refers to the directory the server was launched from.
    if not options.cert:
        parser.error('--cert must be provided')
    cert_path = os.path.abspath(options.cert)
    if not os.path.isfile(cert_path):
        parser.error('Invalid cert file')

    # Change into serving directory
    if options.directory:
        try:
            os.chdir(options.directory)
        except OSError:
            parser.error('Invalid --directory')

    server = MyServer(('', options.port), MyHandler, cert_path)
    print('Server listening on port %d' % options.port)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print('\nGoodbye')
    finally:
        # Release the listening socket however serving ends.
        server.server_close()
if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())