X Tutup
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions telegram/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
from functools import wraps
from inspect import getargspec
from threading import Thread, BoundedSemaphore, Lock, Event
from threading import Thread, BoundedSemaphore, Lock, Event, current_thread
from re import match
from time import sleep

Expand All @@ -33,7 +33,8 @@
logging.getLogger(__name__).addHandler(H)

semaphore = None
running_async = 0
async_threads = set()
""":type: set[Thread]"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does pycharm recognize """asyc_threads (set[Thread]): """? This format would be preferrable, since it's a google style docstring.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in order not to leave this "in the air"
it was agreed that we can use the Sphinx docstring comment style for this kind of case (i.e. "standalone" variable)
this was the most reliable form we found and the most friendly to pycharm

async_lock = Lock()


Expand All @@ -58,23 +59,21 @@ def pooled(*pargs, **kwargs):
"""
A wrapper to run a thread in a thread pool
"""
global running_async, async_lock
result = func(*pargs, **kwargs)
semaphore.release()
with async_lock:
running_async -= 1
async_threads.remove(current_thread())
return result

@wraps(func)
def async_func(*pargs, **kwargs):
"""
A wrapper to run a function in a thread
"""
global running_async, async_lock
thread = Thread(target=pooled, args=pargs, kwargs=kwargs)
semaphore.acquire()
with async_lock:
running_async += 1
async_threads.add(thread)
thread.start()
return thread

Expand Down
86 changes: 50 additions & 36 deletions telegram/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def __init__(self,
self.is_idle = False
self.httpd = None
self.__lock = Lock()
self.__threads = []
""":type: list[Thread]"""

def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2):
"""
Expand Down Expand Up @@ -120,9 +122,10 @@ def _init_thread(self, target, name, *args, **kwargs):
thr = Thread(target=self._thread_wrapper, name=name,
args=(target,) + args, kwargs=kwargs)
thr.start()
self.__threads.append(thr)

def _thread_wrapper(self, target, *args, **kwargs):
thr_name = current_thread()
thr_name = current_thread().name
self.logger.debug('{0} - started'.format(thr_name))
try:
target(*args, **kwargs)
Expand Down Expand Up @@ -160,20 +163,10 @@ def start_webhook(self,
if not self.running:
self.running = True

# Create Thread objects
dispatcher_thread = Thread(target=self.dispatcher.start,
name="dispatcher")
updater_thread = Thread(target=self._start_webhook,
name="updater",
args=(listen,
port,
url_path,
cert,
key))

# Start threads
dispatcher_thread.start()
updater_thread.start()
# Create & start threads
self._init_thread(self.dispatcher.start, "dispatcher"),
self._init_thread(self._start_webhook, "updater", listen,
port, url_path, cert, key)

# Return the update queue so the main thread can insert updates
return self.update_queue
Expand Down Expand Up @@ -221,8 +214,6 @@ def _start_polling(self, poll_interval, timeout, network_delay):

sleep(cur_interval)

self.logger.info('Updater thread stopped')

@staticmethod
def _increase_poll_interval(current_interval):
# increase waiting times on subsequent errors up to 30secs
Expand Down Expand Up @@ -266,7 +257,6 @@ def _start_webhook(self, listen, port, url_path, cert, key):
raise TelegramError('SSL Certificate invalid')

self.httpd.serve_forever(poll_interval=1)
self.logger.info('Updater thread stopped')

def stop(self):
"""
Expand All @@ -276,25 +266,49 @@ def stop(self):
self.job_queue.stop()
with self.__lock:
if self.running:
self.running = False
self.logger.info('Stopping Updater and Dispatcher...')
self.logger.debug('This might take a long time if you set a '
'high value as polling timeout.')

if self.httpd:
self.logger.info(
'Waiting for current webhook connection to be '
'closed... Send a Telegram message to the bot to exit '
'immediately.')
self.httpd.shutdown()
self.httpd = None

self.logger.debug("Requesting Dispatcher to stop...")
self.dispatcher.stop()
while dispatcher.running_async > 0:
sleep(1)

self.logger.debug("Dispatcher stopped.")

self.running = False

self._stop_httpd()
self._stop_dispatcher()
self._join_threads()
# async threads must be join()ed only after the dispatcher
# thread was joined, otherwise we can still have new async
# threads dispatched
self._join_async_threads()

def _stop_httpd(self):
if self.httpd:
self.logger.info(
'Waiting for current webhook connection to be '
'closed... Send a Telegram message to the bot to exit '
'immediately.')
self.httpd.shutdown()
self.httpd = None

def _stop_dispatcher(self):
self.logger.debug("Requesting Dispatcher to stop...")
self.dispatcher.stop()

def _join_async_threads(self):
with dispatcher.async_lock:
threads = list(dispatcher.async_threads)
total = len(threads)
for i, thr in enumerate(threads):
self.logger.info(
'Waiting for async thread {0}/{1} to end'.format(i, total))
thr.join()
self.logger.debug(
'async thread {0}/{1} has ended'.format(i, total))

def _join_threads(self):
for thr in self.__threads:
self.logger.info(
'Waiting for {0} thread to end'.format(thr.name))
thr.join()
self.logger.debug('{0} thread has ended'.format(thr.name))
self.__threads = []

def signal_handler(self, signum, frame):
self.is_idle = False
Expand Down
44 changes: 34 additions & 10 deletions telegram/utils/webhookhandler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from telegram import Update, NullHandler
from future.utils import bytes_to_native_str as n
from future.utils import bytes_to_native_str
from threading import Lock
import json
try:
Expand All @@ -14,6 +14,13 @@
logging.getLogger(__name__).addHandler(H)


class _InvalidPost(Exception):

def __init__(self, http_code):
self.http_code = http_code
super(_InvalidPost, self).__init__()


class WebhookServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, server_address, RequestHandlerClass, update_queue,
webhook_path):
Expand Down Expand Up @@ -63,12 +70,15 @@ def do_GET(self):

def do_POST(self):
self.logger.debug("Webhook triggered")
if self.path == self.server.webhook_path and \
'content-type' in self.headers and \
'content-length' in self.headers and \
self.headers['content-type'] == 'application/json':
json_string = \
n(self.rfile.read(int(self.headers['content-length'])))
try:
self._validate_post()
clen = self._get_content_len()
except _InvalidPost as e:
self.send_error(e.http_code)
self.end_headers()
else:
buf = self.rfile.read(clen)
json_string = bytes_to_native_str(buf)

self.send_response(200)
self.end_headers()
Expand All @@ -80,6 +90,20 @@ def do_POST(self):
update.update_id)
self.server.update_queue.put(update)

else:
self.send_error(403)
self.end_headers()
def _validate_post(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be missing something, but this is not called, is it?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your eyes are working properly :)
i forgot to call self._validate_post() instead of the if statement in do_POST()

if not (self.path == self.server.webhook_path and
'content-type' in self.headers and
self.headers['content-type'] == 'application/json'):
raise _InvalidPost(403)

def _get_content_len(self):
clen = self.headers.get('content-length')
if clen is None:
raise _InvalidPost(411)
try:
clen = int(clen)
except ValueError:
raise _InvalidPost(403)
if clen < 0:
raise _InvalidPost(403)
return clen
Loading
X Tutup