diff --git a/docs/api.rst b/docs/api.rst index efad2ef..4b86c60 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -52,6 +52,45 @@ Datastores Signals ------- +See the `Flask documentation on signals`_ for information on how to use these +signals in your code. + See the documentation for the signals provided by the Flask-Login and -Flask-Principal extensions. Flask-Security does not provide any additional -signals. \ No newline at end of file +Flask-Principal extensions. In addition to those signals, Flask-Security +sends the following signals. + +.. data:: user_registered + + Sent when a user registers on the site. It is passed a dict with + the `user` and `confirm_token`, the user being logged in and the + (if so configured) the confirmation token issued. + +.. data:: user_confirmed + + Sent when a user is confirmed. It is passed `user`, which is the + user being confirmed. + +.. data:: confirm_instructions_sent + + Sent when a user requests confirmation instructions. It is passed + the `user`. + +.. data:: login_instructions_sent + + Sent when passwordless login is used and user logs in. It is passed + a dict with the `user` and `login_token`, the user being logged in + and the (if so configured) the login token issued. + +.. data:: password_reset + + Sent when a user completes a password. It is passed the `user`. + +.. data:: reset_password_instructions_sent + + Sent when a user requests a password reset. It is passed a dict + with the `user` and `token`, the user being logged in and + the (if so configured) the reset token issued. + +All signals are also passed a `app` keyword argument, which is the current application. + +.. _Flask documentation on signals: http://flask.pocoo.org/docs/signals/ diff --git a/flask_security/utils.py b/flask_security/utils.py index 7bc1902..03becf6 100644 --- a/flask_security/utils.py +++ b/flask_security/utils.py @@ -10,6 +10,8 @@ """ import base64 +import blinker +import functools import hashlib import hmac from contextlib import contextmanager @@ -23,9 +25,9 @@ from flask.ext.principal import Identity, AnonymousIdentity, identity_changed from itsdangerous import BadSignature, SignatureExpired from werkzeug.local import LocalProxy -from .signals import user_registered, reset_password_instructions_sent, \ - login_instructions_sent - +from .signals import user_registered, user_confirmed, \ + confirm_instructions_sent, login_instructions_sent, \ + password_reset, reset_password_instructions_sent # Convenient references _security = LocalProxy(lambda: current_app.extensions['security']) @@ -311,3 +313,57 @@ def capture_reset_password_requests(reset_password_sent_at=None): yield reset_requests finally: reset_password_instructions_sent.disconnect(_on) + + +class CaptureSignals(object): + """Testing utility for capturing blinker signals. + + Context manager which mocks out selected signals and registers which are `sent` on and what + arguments were sent. Instantiate with a list of blinker `NamedSignals` to patch. Each signal + has it's `send` mocked out. + """ + def __init__(self, signals): + """Patch all given signals and make them available as attributes. + + :param signals: list of signals + """ + self._records = {} + self._receivers = {} + for signal in signals: + self._records[signal] = [] + self._receivers[signal] = functools.partial(self._record, signal) + + def __getitem__(self, signal): + """All captured signals are available via `ctxt[signal]`. + """ + if isinstance(signal, blinker.base.NamedSignal): + return self._records[signal] + else: + super(CaptureSignals, self).__setitem__(signal) + + def _record(self, signal, *args, **kwargs): + self._records[signal].append((args, kwargs)) + + def __enter__(self): + for signal, receiver in self._receivers.iteritems(): + signal.connect(receiver) + return self + + def __exit__(self, type, value, traceback): + for signal, receiver in self._receivers.iteritems(): + signal.disconnect(receiver) + + def signals_sent(self): + """Return a set of the signals sent. + :rtype: list of blinker `NamedSignals`. + """ + return set([signal for signal, _ in self._records.iteritems() if self._records[signal]]) + + +def capture_signals(): + """Factory method that creates a `CaptureSignals` with all the flask_security signals.""" + return CaptureSignals([user_registered, user_confirmed, + confirm_instructions_sent, login_instructions_sent, + password_reset, reset_password_instructions_sent]) + + diff --git a/tests/signals_tests.py b/tests/signals_tests.py new file mode 100644 index 0000000..38fcf4b --- /dev/null +++ b/tests/signals_tests.py @@ -0,0 +1,198 @@ +from __future__ import with_statement + +from mock import (patch, call) + +from flask_security.utils import (capture_registrations, capture_reset_password_requests, capture_signals) +from flask_security.signals import (user_registered, user_confirmed, + confirm_instructions_sent, login_instructions_sent, + password_reset, reset_password_instructions_sent) +from tests import SecurityTest + + +def compare_user(a, b): + """Helper to compare two users.""" + return a.id == b.id and a.email == b.email and a.password == b.password + + +class RegisterableSignalsTests(SecurityTest): + + AUTH_CONFIG = { + 'SECURITY_CONFIRMABLE': True, + 'SECURITY_REGISTERABLE': True, + } + + def test_register(self): + e = 'dude@lp.com' + with capture_signals() as mocks: + self.register(e) + user = self.app.security.datastore.find_user(email='dude@lp.com') + self.assertEqual(mocks.signals_sent(), set([user_registered])) + calls = mocks[user_registered] + self.assertEqual(len(calls), 1) + args, kwargs = calls[0] + self.assertTrue(compare_user(args[0]['user'], user)) + self.assertIn('confirm_token', args[0]) + self.assertEqual(kwargs['app'], self.app) + + def test_register(self): + e = 'dude@lp.com' + with capture_signals() as mocks: + self.register(e, password='') + self.assertEqual(mocks.signals_sent(), set()) + + +class ConfirmableSignalsTests(SecurityTest): + + AUTH_CONFIG = { + 'SECURITY_CONFIRMABLE': True, + 'SECURITY_REGISTERABLE': True, + } + + def test_confirm(self): + e = 'dude@lp.com' + with capture_registrations() as registrations: + self.register(e) + token = registrations[0]['confirm_token'] + with capture_signals() as mocks: + self.client.get('/confirm/' + token, follow_redirects=True) + user = self.app.security.datastore.find_user(email='dude@lp.com') + self.assertTrue(mocks.signals_sent(), set([user_confirmed])) + calls = mocks[user_confirmed] + self.assertEqual(len(calls), 1) + args, kwargs = calls[0] + self.assertEqual(args[0].id, user.id) + self.assertEqual(kwargs['app'], self.app) + + def test_confirm_bad_token(self): + e = 'dude@lp.com' + with capture_registrations() as registrations: + self.register(e) + token = registrations[0]['confirm_token'] + with capture_signals() as mocks: + self.client.get('/confirm/bogus', follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set()) + + def test_confirm_twice(self): + e = 'dude@lp.com' + with capture_registrations() as registrations: + self.register(e) + token = registrations[0]['confirm_token'] + self.client.get('/confirm/' + token, follow_redirects=True) + self.logout() + with capture_signals() as mocks: + self.client.get('/confirm/' + token, follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set([user_confirmed])) + # TODO: is that the desired behaviour? + + def test_resend_confirmation(self): + e = 'dude@lp.com' + self.register(e) + with capture_signals() as mocks: + self._post('/confirm', data={'email': e}) + user = self.app.security.datastore.find_user(email='dude@lp.com') + self.assertEqual(mocks.signals_sent(), set([confirm_instructions_sent])) + calls = mocks[confirm_instructions_sent] + self.assertEqual(len(calls), 1) + args, kwargs = calls[0] + self.assertTrue(compare_user(args[0], user)) + self.assertEqual(kwargs['app'], self.app) + + def test_send_confirmation_bad_email(self): + with capture_signals() as mocks: + self._post('/confirm', data=dict(email='bogus@bogus.com')) + self.assertEqual(mocks.signals_sent(), set()) + + +class RecoverableSignalsTests(SecurityTest): + + AUTH_CONFIG = { + 'SECURITY_RECOVERABLE': True, + 'SECURITY_RESET_PASSWORD_ERROR_VIEW': '/', + 'SECURITY_POST_FORGOT_VIEW': '/' + } + + def test_reset_password_request(self): + with capture_signals() as mocks: + self.client.post('/reset', + data=dict(email='joe@lp.com'), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set([reset_password_instructions_sent])) + user = self.app.security.datastore.find_user(email='joe@lp.com') + calls = mocks[reset_password_instructions_sent] + self.assertEqual(len(calls), 1) + args, kwargs = calls[0] + self.assertTrue(compare_user(args[0]['user'], user)) + self.assertIn('token', args[0]) + self.assertEqual(kwargs['app'], self.app) + + def test_reset_password(self): + with capture_reset_password_requests() as requests: + self.client.post('/reset', + data=dict(email='joe@lp.com'), + follow_redirects=True) + token = requests[0]['token'] + with capture_signals() as mocks: + self.client.post('/reset/' + token, + data=dict(password='newpassword', + password_confirm='newpassword'), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set([password_reset])) + user = self.app.security.datastore.find_user(email='joe@lp.com') + calls = mocks[password_reset] + self.assertEqual(len(calls), 1) + args, kwargs = calls[0] + self.assertTrue(compare_user(args[0], user)) + self.assertEqual(kwargs['app'], self.app) + + def test_reset_password_invalid_emails(self): + with capture_signals() as mocks: + self.client.post('/reset', + data=dict(email='nobody@lp.com'), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set()) + + def test_reset_password_invalid_token(self): + with capture_signals() as mocks: + self.client.post('/reset/bogus', + data=dict(password='newpassword', + password_confirm='newpassword'), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set()) + + +class PasswordlessTests(SecurityTest): + + AUTH_CONFIG = { + 'SECURITY_PASSWORDLESS': True + } + + def test_login_request_for_inactive_user(self): + with capture_signals() as mocks: + self.client.post('/login', + data=dict(email='tiya@lp.com'), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set()) + + def test_login_request_for_invalid_email(self): + with capture_signals() as mocks: + self.client.post('/login', + data=dict(email='nobody@lp.com'), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set()) + + def test_request_login_token_sends_email_and_can_login(self): + e = 'matt@lp.com' + + with capture_signals() as mocks: + self.client.post('/login', + data=dict(email=e), + follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set([login_instructions_sent])) + user = self.app.security.datastore.find_user(email='matt@lp.com') + calls = mocks[login_instructions_sent] + self.assertEqual(len(calls), 1) + args, kwargs = calls[0] + self.assertTrue(compare_user(args[0]['user'], user)) + self.assertIn('login_token', args[0]) + self.assertEqual(kwargs['app'], self.app) +