From 34b3bf9e8060c73b78fdfeef273e00efa3feb207 Mon Sep 17 00:00:00 2001 From: Matt Wright Date: Fri, 1 Feb 2013 17:18:31 -0500 Subject: [PATCH] Fix CSRF functionality for LoginForm The login form was not respecting csrf validation. I've adjusted the tests as well to always send a CSRF token along. This now requires all requests to pass a csrf token. If performing plain AJAX requests the token will have to be extracted from the form in some way. Fixes #86 --- flask_security/forms.py | 11 ++++---- tests/__init__.py | 36 ++++++++++++++++--------- tests/configured_tests.py | 57 +++++++++++++++++---------------------- tests/signals_tests.py | 45 ++++++++++++------------------- 4 files changed, 72 insertions(+), 77 deletions(-) diff --git a/flask_security/forms.py b/flask_security/forms.py index 64ae022..27dd675 100644 --- a/flask_security/forms.py +++ b/flask_security/forms.py @@ -45,10 +45,7 @@ def valid_user_email(form, field): class Form(BaseForm): def __init__(self, *args, **kwargs): if current_app.testing: - csrf_enabled = False - else: - csrf_enabled = request.json is None - kwargs.setdefault('csrf_enabled', csrf_enabled) + self.TIME_LIMIT = None super(Form, self).__init__(*args, **kwargs) @@ -83,6 +80,7 @@ class NewPasswordFormMixin(): validators=[password_required, Length(min=6, max=128)]) + class PasswordConfirmFormMixin(): password_confirm = PasswordField("Retype Password", validators=[EqualTo('password', message="Passwords do not match")]) @@ -147,6 +145,7 @@ class PasswordlessLoginForm(Form, UserEmailFormMixin): class LoginForm(Form, NextFormMixin): """The default login form""" + email = TextField('Email Address') password = PasswordField('Password') remember = BooleanField("Remember Me") @@ -156,7 +155,8 @@ class LoginForm(Form, NextFormMixin): super(LoginForm, self).__init__(*args, **kwargs) def validate(self): - super(LoginForm, self).validate() + if not super(LoginForm, self).validate(): + return False if self.email.data.strip() == '': self.email.errors.append('Email not provided') @@ -187,6 +187,7 @@ class ConfirmRegisterForm(Form, RegisterFormMixin, UniqueEmailFormMixin, NewPasswordFormMixin): pass + class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin): pass diff --git a/tests/__init__.py b/tests/__init__.py index 53e0d9b..35ceb79 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,8 +1,13 @@ # -*- coding: utf-8 -*- +import hmac + +from hashlib import sha1 from unittest import TestCase + from tests.test_app.sqlalchemy import create_app + class SecurityTest(TestCase): APP_KWARGS = { @@ -21,6 +26,13 @@ class SecurityTest(TestCase): self.app = app self.client = app.test_client() + with self.client.session_transaction() as session: + session['csrf'] = 'csrf_token' + + csrf_hmac = hmac.new(self.app.config['SECRET_KEY'], + 'csrf_token'.encode('utf8'), digestmod=sha1) + self.csrf_token = '##' + csrf_hmac.hexdigest() + def _create_app(self, auth_config, **kwargs): return create_app(auth_config, **kwargs) @@ -30,30 +42,30 @@ class SecurityTest(TestCase): headers=headers) def _post(self, route, data=None, content_type=None, follow_redirects=True, headers=None): + if isinstance(data, dict): + data['csrf_token'] = self.csrf_token + return self.client.post(route, data=data, follow_redirects=follow_redirects, content_type=content_type or 'application/x-www-form-urlencoded', headers=headers) def register(self, email, password='password'): - data = dict(email=email, password=password) + data = dict(email=email, password=password, csrf_token=self.csrf_token) return self.client.post('/register', data=data, follow_redirects=True) def authenticate(self, email="matt@lp.com", password="password", endpoint=None, **kwargs): data = dict(email=email, password=password, remember='y') - r = self._post(endpoint or '/login', data=data, **kwargs) - return r + return self._post(endpoint or '/login', data=data, **kwargs) def json_authenticate(self, email="matt@lp.com", password="password", endpoint=None): - data = """ -{ - "email": "%s", - "password": "%s" -} -""" - return self._post(endpoint or '/login', - content_type="application/json", - data=data % (email, password)) + data = """{ + "email": "%s", + "password": "%s", + "csrf_token": "%s" + }""" + return self._post(endpoint or '/login', content_type="application/json", + data=data % (email, password, self.csrf_token)) def logout(self, endpoint=None): return self._get(endpoint or '/logout', follow_redirects=True) diff --git a/tests/configured_tests.py b/tests/configured_tests.py index f1e0d7d..5e9517a 100644 --- a/tests/configured_tests.py +++ b/tests/configured_tests.py @@ -67,7 +67,7 @@ class ConfiguredSecurityTests(SecurityTest): self.assertIn('Post Register', r.data) def test_register_json(self): - data = '{ "email": "dude@lp.com", "password": "password" }' + data = '{ "email": "dude@lp.com", "password": "password", "csrf_token":"%s" }' % self.csrf_token r = self._post('/register', data=data, content_type='application/json') data = json.loads(r.data) self.assertEquals(data['meta']['code'], 200) @@ -117,7 +117,7 @@ class RegisterableTests(SecurityTest): data = dict(email='dude@lp.com', password='password', password_confirm='password') - self.client.post('/register', data=data, follow_redirects=True) + self._post('/register', data=data, follow_redirects=True) r = self.authenticate('dude@lp.com') self.assertIn('Hello dude@lp.com', r.data) @@ -145,7 +145,7 @@ class ConfirmableTests(SecurityTest): self.client.get('/confirm/' + token, follow_redirects=True) self.logout() - r = self.client.post('/confirm', data=dict(email=e)) + r = self._post('/confirm', data=dict(email=e)) self.assertIn(self.get_message('ALREADY_CONFIRMED'), r.data) def test_register_sends_confirmation_email(self): @@ -231,7 +231,7 @@ class LoginWithoutImmediateConfirmTests(SecurityTest): e = 'dude@lp.com' p = 'password' data = dict(email=e, password=p, password_confirm=p) - r = self.client.post('/register', data=data, follow_redirects=True) + r = self._post('/register', data=data, follow_redirects=True) self.assertIn(e, r.data) @@ -245,7 +245,7 @@ class RecoverableTests(SecurityTest): def test_reset_view(self): with capture_reset_password_requests() as requests: - r = self.client.post('/reset', + r = self._post('/reset', data=dict(email='joe@lp.com'), follow_redirects=True) t = requests[0]['token'] @@ -255,23 +255,23 @@ class RecoverableTests(SecurityTest): def test_forgot_post_sends_email(self): with capture_reset_password_requests(): with self.app.extensions['mail'].record_messages() as outbox: - self.client.post('/reset', data=dict(email='joe@lp.com')) + self._post('/reset', data=dict(email='joe@lp.com')) self.assertEqual(len(outbox), 1) def test_forgot_password_json(self): - r = self.client.post('/reset', data='{"email": "matt@lp.com"}', + r = self._post('/reset', data='{"email": "matt@lp.com"}', content_type="application/json") self.assertEquals(r.status_code, 200) def test_forgot_password_invalid_email(self): - r = self.client.post('/reset', + r = self._post('/reset', data=dict(email='larry@lp.com'), follow_redirects=True) self.assertIn("Specified user does not exist", r.data) def test_reset_password_with_valid_token(self): with capture_reset_password_requests() as requests: - r = self.client.post('/reset', + r = self._post('/reset', data=dict(email='joe@lp.com'), follow_redirects=True) t = requests[0]['token'] @@ -303,14 +303,13 @@ class ExpiredResetPasswordTest(SecurityTest): def test_reset_password_with_expired_token(self): with capture_reset_password_requests() as requests: - r = self.client.post('/reset', - data=dict(email='joe@lp.com'), - follow_redirects=True) + r = self._post('/reset', data=dict(email='joe@lp.com'), + follow_redirects=True) t = requests[0]['token'] time.sleep(1) - r = self.client.post('/reset/' + t, data={ + r = self._post('/reset/' + t, data={ 'password': 'newpassword', 'password_confirm': 'newpassword' }, follow_redirects=True) @@ -348,20 +347,19 @@ class PasswordlessTests(SecurityTest): def test_login_request_for_inactive_user(self): msg = self.app.config['SECURITY_MSG_DISABLED_ACCOUNT'][0] - r = self.client.post('/login', - data=dict(email='tiya@lp.com'), - follow_redirects=True) + r = self._post('/login', data=dict(email='tiya@lp.com'), + follow_redirects=True) self.assertIn(msg, r.data) def test_request_login_token_with_json_and_valid_email(self): - data = '{"email": "matt@lp.com", "password": "password"}' - r = self.client.post('/login', data=data, content_type='application/json') + data = '{"email": "matt@lp.com", "password": "password", "csrf_token":"%s"}' % self.csrf_token + r = self._post('/login', data=data, content_type='application/json') self.assertEquals(r.status_code, 200) self.assertNotIn('error', r.data) def test_request_login_token_with_json_and_invalid_email(self): data = '{"email": "nobody@lp.com", "password": "password"}' - r = self.client.post('/login', data=data, content_type='application/json') + r = self._post('/login', data=data, content_type='application/json') self.assertIn('errors', r.data) def test_request_login_token_sends_email_and_can_login(self): @@ -370,9 +368,8 @@ class PasswordlessTests(SecurityTest): with capture_passwordless_login_requests() as requests: with self.app.extensions['mail'].record_messages() as outbox: - r = self.client.post('/login', - data=dict(email=e), - follow_redirects=True) + r = self._post('/login', data=dict(email=e), + follow_redirects=True) self.assertEqual(len(outbox), 1) @@ -401,9 +398,8 @@ class PasswordlessTests(SecurityTest): def test_token_login_when_already_authenticated(self): with capture_passwordless_login_requests() as requests: - self.client.post('/login', - data=dict(email='matt@lp.com'), - follow_redirects=True) + self._post('/login', data=dict(email='matt@lp.com'), + follow_redirects=True) token = requests[0]['login_token'] r = self.client.get('/login/' + token, follow_redirects=True) @@ -431,9 +427,7 @@ class ExpiredLoginTokenTests(SecurityTest): e = 'matt@lp.com' with capture_passwordless_login_requests() as requests: - self.client.post('/login', - data=dict(email=e), - follow_redirects=True) + self._post('/login', data=dict(email=e), follow_redirects=True) token = requests[0]['login_token'] time.sleep(1.25) @@ -467,7 +461,7 @@ class AsyncMailTaskTests(SecurityTest): def send_email(msg): self.mail_sent = True - self.client.post('/reset', data=dict(email='matt@lp.com')) + self._post('/reset', data=dict(email='matt@lp.com')) self.assertTrue(self.mail_sent) @@ -542,9 +536,8 @@ class RecoverableExtendFormsTest(SecurityTest): def test_reset_password(self): with capture_reset_password_requests() as requests: - self.client.post('/reset', - data=dict(email='joe@lp.com'), - follow_redirects=True) + self._post('/reset', data=dict(email='joe@lp.com'), + follow_redirects=True) token = requests[0]['token'] r = self._get('/reset/' + token) self.assertIn("My Reset Password Submit Field", r.data) diff --git a/tests/signals_tests.py b/tests/signals_tests.py index 51fde0f..35f067d 100644 --- a/tests/signals_tests.py +++ b/tests/signals_tests.py @@ -32,7 +32,7 @@ class RegisterableSignalsTests(SecurityTest): self.assertIn('confirm_token', args[0]) self.assertEqual(kwargs['app'], self.app) - def test_register(self): + def test_register_without_password(self): e = 'dude@lp.com' with capture_signals() as mocks: self.register(e, password='') @@ -111,9 +111,8 @@ class RecoverableSignalsTests(SecurityTest): def test_reset_password_request(self): with capture_signals() as mocks: - self.client.post('/reset', - data=dict(email='joe@lp.com'), - follow_redirects=True) + self._post('/reset', data=dict(email='joe@lp.com'), + follow_redirects=True) self.assertEqual(mocks.signals_sent(), set([reset_password_instructions_sent])) user = self.app.security.datastore.find_user(email='joe@lp.com') calls = mocks[reset_password_instructions_sent] @@ -125,15 +124,12 @@ class RecoverableSignalsTests(SecurityTest): def test_reset_password(self): with capture_reset_password_requests() as requests: - self.client.post('/reset', - data=dict(email='joe@lp.com'), - follow_redirects=True) + self._post('/reset', data=dict(email='joe@lp.com'), + follow_redirects=True) token = requests[0]['token'] with capture_signals() as mocks: - self.client.post('/reset/' + token, - data=dict(password='newpassword', - password_confirm='newpassword'), - follow_redirects=True) + data = dict(password='newpassword', password_confirm='newpassword') + self._post('/reset/' + token, data, follow_redirects=True) self.assertEqual(mocks.signals_sent(), set([password_reset])) user = self.app.security.datastore.find_user(email='joe@lp.com') calls = mocks[password_reset] @@ -144,17 +140,14 @@ class RecoverableSignalsTests(SecurityTest): def test_reset_password_invalid_emails(self): with capture_signals() as mocks: - self.client.post('/reset', - data=dict(email='nobody@lp.com'), - follow_redirects=True) + self._post('/reset', data=dict(email='nobody@lp.com'), + follow_redirects=True) self.assertEqual(mocks.signals_sent(), set()) def test_reset_password_invalid_token(self): with capture_signals() as mocks: - self.client.post('/reset/bogus', - data=dict(password='newpassword', - password_confirm='newpassword'), - follow_redirects=True) + data = dict(password='newpassword', password_confirm='newpassword') + self._post('/reset/bogus', data, follow_redirects=True) self.assertEqual(mocks.signals_sent(), set()) @@ -166,25 +159,22 @@ class PasswordlessTests(SecurityTest): def test_login_request_for_inactive_user(self): with capture_signals() as mocks: - self.client.post('/login', - data=dict(email='tiya@lp.com'), - follow_redirects=True) + self._post('/login', data=dict(email='tiya@lp.com'), + follow_redirects=True) self.assertEqual(mocks.signals_sent(), set()) def test_login_request_for_invalid_email(self): with capture_signals() as mocks: - self.client.post('/login', - data=dict(email='nobody@lp.com'), - follow_redirects=True) + self._post('/login', data=dict(email='nobody@lp.com'), + follow_redirects=True) self.assertEqual(mocks.signals_sent(), set()) def test_request_login_token_sends_email_and_can_login(self): e = 'matt@lp.com' with capture_signals() as mocks: - self.client.post('/login', - data=dict(email=e), - follow_redirects=True) + self._post('/login', data=dict(email=e), follow_redirects=True) + self.assertEqual(mocks.signals_sent(), set([login_instructions_sent])) user = self.app.security.datastore.find_user(email='matt@lp.com') calls = mocks[login_instructions_sent] @@ -193,4 +183,3 @@ class PasswordlessTests(SecurityTest): self.assertTrue(compare_user(args[0]['user'], user)) self.assertIn('login_token', args[0]) self.assertEqual(kwargs['app'], self.app) -