Fix CSRF functionality for LoginForm

The login form was not respecting csrf validation. I've adjusted the tests as well to always send a CSRF token along. This now requires all requests to pass a csrf token. If performing plain AJAX requests the token will have to be extracted from the form in some way. Fixes #86
This commit is contained in:
Matt Wright
2013-02-01 17:18:31 -05:00
parent b82a8d681d
commit 34b3bf9e80
4 changed files with 72 additions and 77 deletions
+6 -5
View File
@@ -45,10 +45,7 @@ def valid_user_email(form, field):
class Form(BaseForm):
def __init__(self, *args, **kwargs):
if current_app.testing:
csrf_enabled = False
else:
csrf_enabled = request.json is None
kwargs.setdefault('csrf_enabled', csrf_enabled)
self.TIME_LIMIT = None
super(Form, self).__init__(*args, **kwargs)
@@ -83,6 +80,7 @@ class NewPasswordFormMixin():
validators=[password_required,
Length(min=6, max=128)])
class PasswordConfirmFormMixin():
password_confirm = PasswordField("Retype Password",
validators=[EqualTo('password', message="Passwords do not match")])
@@ -147,6 +145,7 @@ class PasswordlessLoginForm(Form, UserEmailFormMixin):
class LoginForm(Form, NextFormMixin):
"""The default login form"""
email = TextField('Email Address')
password = PasswordField('Password')
remember = BooleanField("Remember Me")
@@ -156,7 +155,8 @@ class LoginForm(Form, NextFormMixin):
super(LoginForm, self).__init__(*args, **kwargs)
def validate(self):
super(LoginForm, self).validate()
if not super(LoginForm, self).validate():
return False
if self.email.data.strip() == '':
self.email.errors.append('Email not provided')
@@ -187,6 +187,7 @@ class ConfirmRegisterForm(Form, RegisterFormMixin,
UniqueEmailFormMixin, NewPasswordFormMixin):
pass
class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin):
pass
+24 -12
View File
@@ -1,8 +1,13 @@
# -*- coding: utf-8 -*-
import hmac
from hashlib import sha1
from unittest import TestCase
from tests.test_app.sqlalchemy import create_app
class SecurityTest(TestCase):
APP_KWARGS = {
@@ -21,6 +26,13 @@ class SecurityTest(TestCase):
self.app = app
self.client = app.test_client()
with self.client.session_transaction() as session:
session['csrf'] = 'csrf_token'
csrf_hmac = hmac.new(self.app.config['SECRET_KEY'],
'csrf_token'.encode('utf8'), digestmod=sha1)
self.csrf_token = '##' + csrf_hmac.hexdigest()
def _create_app(self, auth_config, **kwargs):
return create_app(auth_config, **kwargs)
@@ -30,30 +42,30 @@ class SecurityTest(TestCase):
headers=headers)
def _post(self, route, data=None, content_type=None, follow_redirects=True, headers=None):
if isinstance(data, dict):
data['csrf_token'] = self.csrf_token
return self.client.post(route, data=data,
follow_redirects=follow_redirects,
content_type=content_type or 'application/x-www-form-urlencoded',
headers=headers)
def register(self, email, password='password'):
data = dict(email=email, password=password)
data = dict(email=email, password=password, csrf_token=self.csrf_token)
return self.client.post('/register', data=data, follow_redirects=True)
def authenticate(self, email="matt@lp.com", password="password", endpoint=None, **kwargs):
data = dict(email=email, password=password, remember='y')
r = self._post(endpoint or '/login', data=data, **kwargs)
return r
return self._post(endpoint or '/login', data=data, **kwargs)
def json_authenticate(self, email="matt@lp.com", password="password", endpoint=None):
data = """
{
"email": "%s",
"password": "%s"
}
"""
return self._post(endpoint or '/login',
content_type="application/json",
data=data % (email, password))
data = """{
"email": "%s",
"password": "%s",
"csrf_token": "%s"
}"""
return self._post(endpoint or '/login', content_type="application/json",
data=data % (email, password, self.csrf_token))
def logout(self, endpoint=None):
return self._get(endpoint or '/logout', follow_redirects=True)
+25 -32
View File
@@ -67,7 +67,7 @@ class ConfiguredSecurityTests(SecurityTest):
self.assertIn('Post Register', r.data)
def test_register_json(self):
data = '{ "email": "dude@lp.com", "password": "password" }'
data = '{ "email": "dude@lp.com", "password": "password", "csrf_token":"%s" }' % self.csrf_token
r = self._post('/register', data=data, content_type='application/json')
data = json.loads(r.data)
self.assertEquals(data['meta']['code'], 200)
@@ -117,7 +117,7 @@ class RegisterableTests(SecurityTest):
data = dict(email='dude@lp.com',
password='password',
password_confirm='password')
self.client.post('/register', data=data, follow_redirects=True)
self._post('/register', data=data, follow_redirects=True)
r = self.authenticate('dude@lp.com')
self.assertIn('Hello dude@lp.com', r.data)
@@ -145,7 +145,7 @@ class ConfirmableTests(SecurityTest):
self.client.get('/confirm/' + token, follow_redirects=True)
self.logout()
r = self.client.post('/confirm', data=dict(email=e))
r = self._post('/confirm', data=dict(email=e))
self.assertIn(self.get_message('ALREADY_CONFIRMED'), r.data)
def test_register_sends_confirmation_email(self):
@@ -231,7 +231,7 @@ class LoginWithoutImmediateConfirmTests(SecurityTest):
e = 'dude@lp.com'
p = 'password'
data = dict(email=e, password=p, password_confirm=p)
r = self.client.post('/register', data=data, follow_redirects=True)
r = self._post('/register', data=data, follow_redirects=True)
self.assertIn(e, r.data)
@@ -245,7 +245,7 @@ class RecoverableTests(SecurityTest):
def test_reset_view(self):
with capture_reset_password_requests() as requests:
r = self.client.post('/reset',
r = self._post('/reset',
data=dict(email='joe@lp.com'),
follow_redirects=True)
t = requests[0]['token']
@@ -255,23 +255,23 @@ class RecoverableTests(SecurityTest):
def test_forgot_post_sends_email(self):
with capture_reset_password_requests():
with self.app.extensions['mail'].record_messages() as outbox:
self.client.post('/reset', data=dict(email='joe@lp.com'))
self._post('/reset', data=dict(email='joe@lp.com'))
self.assertEqual(len(outbox), 1)
def test_forgot_password_json(self):
r = self.client.post('/reset', data='{"email": "matt@lp.com"}',
r = self._post('/reset', data='{"email": "matt@lp.com"}',
content_type="application/json")
self.assertEquals(r.status_code, 200)
def test_forgot_password_invalid_email(self):
r = self.client.post('/reset',
r = self._post('/reset',
data=dict(email='larry@lp.com'),
follow_redirects=True)
self.assertIn("Specified user does not exist", r.data)
def test_reset_password_with_valid_token(self):
with capture_reset_password_requests() as requests:
r = self.client.post('/reset',
r = self._post('/reset',
data=dict(email='joe@lp.com'),
follow_redirects=True)
t = requests[0]['token']
@@ -303,14 +303,13 @@ class ExpiredResetPasswordTest(SecurityTest):
def test_reset_password_with_expired_token(self):
with capture_reset_password_requests() as requests:
r = self.client.post('/reset',
data=dict(email='joe@lp.com'),
follow_redirects=True)
r = self._post('/reset', data=dict(email='joe@lp.com'),
follow_redirects=True)
t = requests[0]['token']
time.sleep(1)
r = self.client.post('/reset/' + t, data={
r = self._post('/reset/' + t, data={
'password': 'newpassword',
'password_confirm': 'newpassword'
}, follow_redirects=True)
@@ -348,20 +347,19 @@ class PasswordlessTests(SecurityTest):
def test_login_request_for_inactive_user(self):
msg = self.app.config['SECURITY_MSG_DISABLED_ACCOUNT'][0]
r = self.client.post('/login',
data=dict(email='tiya@lp.com'),
follow_redirects=True)
r = self._post('/login', data=dict(email='tiya@lp.com'),
follow_redirects=True)
self.assertIn(msg, r.data)
def test_request_login_token_with_json_and_valid_email(self):
data = '{"email": "matt@lp.com", "password": "password"}'
r = self.client.post('/login', data=data, content_type='application/json')
data = '{"email": "matt@lp.com", "password": "password", "csrf_token":"%s"}' % self.csrf_token
r = self._post('/login', data=data, content_type='application/json')
self.assertEquals(r.status_code, 200)
self.assertNotIn('error', r.data)
def test_request_login_token_with_json_and_invalid_email(self):
data = '{"email": "nobody@lp.com", "password": "password"}'
r = self.client.post('/login', data=data, content_type='application/json')
r = self._post('/login', data=data, content_type='application/json')
self.assertIn('errors', r.data)
def test_request_login_token_sends_email_and_can_login(self):
@@ -370,9 +368,8 @@ class PasswordlessTests(SecurityTest):
with capture_passwordless_login_requests() as requests:
with self.app.extensions['mail'].record_messages() as outbox:
r = self.client.post('/login',
data=dict(email=e),
follow_redirects=True)
r = self._post('/login', data=dict(email=e),
follow_redirects=True)
self.assertEqual(len(outbox), 1)
@@ -401,9 +398,8 @@ class PasswordlessTests(SecurityTest):
def test_token_login_when_already_authenticated(self):
with capture_passwordless_login_requests() as requests:
self.client.post('/login',
data=dict(email='matt@lp.com'),
follow_redirects=True)
self._post('/login', data=dict(email='matt@lp.com'),
follow_redirects=True)
token = requests[0]['login_token']
r = self.client.get('/login/' + token, follow_redirects=True)
@@ -431,9 +427,7 @@ class ExpiredLoginTokenTests(SecurityTest):
e = 'matt@lp.com'
with capture_passwordless_login_requests() as requests:
self.client.post('/login',
data=dict(email=e),
follow_redirects=True)
self._post('/login', data=dict(email=e), follow_redirects=True)
token = requests[0]['login_token']
time.sleep(1.25)
@@ -467,7 +461,7 @@ class AsyncMailTaskTests(SecurityTest):
def send_email(msg):
self.mail_sent = True
self.client.post('/reset', data=dict(email='matt@lp.com'))
self._post('/reset', data=dict(email='matt@lp.com'))
self.assertTrue(self.mail_sent)
@@ -542,9 +536,8 @@ class RecoverableExtendFormsTest(SecurityTest):
def test_reset_password(self):
with capture_reset_password_requests() as requests:
self.client.post('/reset',
data=dict(email='joe@lp.com'),
follow_redirects=True)
self._post('/reset', data=dict(email='joe@lp.com'),
follow_redirects=True)
token = requests[0]['token']
r = self._get('/reset/' + token)
self.assertIn("My Reset Password Submit Field", r.data)
+17 -28
View File
@@ -32,7 +32,7 @@ class RegisterableSignalsTests(SecurityTest):
self.assertIn('confirm_token', args[0])
self.assertEqual(kwargs['app'], self.app)
def test_register(self):
def test_register_without_password(self):
e = 'dude@lp.com'
with capture_signals() as mocks:
self.register(e, password='')
@@ -111,9 +111,8 @@ class RecoverableSignalsTests(SecurityTest):
def test_reset_password_request(self):
with capture_signals() as mocks:
self.client.post('/reset',
data=dict(email='joe@lp.com'),
follow_redirects=True)
self._post('/reset', data=dict(email='joe@lp.com'),
follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set([reset_password_instructions_sent]))
user = self.app.security.datastore.find_user(email='joe@lp.com')
calls = mocks[reset_password_instructions_sent]
@@ -125,15 +124,12 @@ class RecoverableSignalsTests(SecurityTest):
def test_reset_password(self):
with capture_reset_password_requests() as requests:
self.client.post('/reset',
data=dict(email='joe@lp.com'),
follow_redirects=True)
self._post('/reset', data=dict(email='joe@lp.com'),
follow_redirects=True)
token = requests[0]['token']
with capture_signals() as mocks:
self.client.post('/reset/' + token,
data=dict(password='newpassword',
password_confirm='newpassword'),
follow_redirects=True)
data = dict(password='newpassword', password_confirm='newpassword')
self._post('/reset/' + token, data, follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set([password_reset]))
user = self.app.security.datastore.find_user(email='joe@lp.com')
calls = mocks[password_reset]
@@ -144,17 +140,14 @@ class RecoverableSignalsTests(SecurityTest):
def test_reset_password_invalid_emails(self):
with capture_signals() as mocks:
self.client.post('/reset',
data=dict(email='nobody@lp.com'),
follow_redirects=True)
self._post('/reset', data=dict(email='nobody@lp.com'),
follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set())
def test_reset_password_invalid_token(self):
with capture_signals() as mocks:
self.client.post('/reset/bogus',
data=dict(password='newpassword',
password_confirm='newpassword'),
follow_redirects=True)
data = dict(password='newpassword', password_confirm='newpassword')
self._post('/reset/bogus', data, follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set())
@@ -166,25 +159,22 @@ class PasswordlessTests(SecurityTest):
def test_login_request_for_inactive_user(self):
with capture_signals() as mocks:
self.client.post('/login',
data=dict(email='tiya@lp.com'),
follow_redirects=True)
self._post('/login', data=dict(email='tiya@lp.com'),
follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set())
def test_login_request_for_invalid_email(self):
with capture_signals() as mocks:
self.client.post('/login',
data=dict(email='nobody@lp.com'),
follow_redirects=True)
self._post('/login', data=dict(email='nobody@lp.com'),
follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set())
def test_request_login_token_sends_email_and_can_login(self):
e = 'matt@lp.com'
with capture_signals() as mocks:
self.client.post('/login',
data=dict(email=e),
follow_redirects=True)
self._post('/login', data=dict(email=e), follow_redirects=True)
self.assertEqual(mocks.signals_sent(), set([login_instructions_sent]))
user = self.app.security.datastore.find_user(email='matt@lp.com')
calls = mocks[login_instructions_sent]
@@ -193,4 +183,3 @@ class PasswordlessTests(SecurityTest):
self.assertTrue(compare_user(args[0]['user'], user))
self.assertIn('login_token', args[0])
self.assertEqual(kwargs['app'], self.app)