mirror of
https://github.com/wassname/flask-security.git
synced 2026-06-27 16:10:11 +08:00
Fix CSRF functionality for LoginForm
The login form was not respecting csrf validation. I've adjusted the tests as well to always send a CSRF token along. This now requires all requests to pass a csrf token. If performing plain AJAX requests the token will have to be extracted from the form in some way. Fixes #86
This commit is contained in:
@@ -45,10 +45,7 @@ def valid_user_email(form, field):
|
||||
class Form(BaseForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if current_app.testing:
|
||||
csrf_enabled = False
|
||||
else:
|
||||
csrf_enabled = request.json is None
|
||||
kwargs.setdefault('csrf_enabled', csrf_enabled)
|
||||
self.TIME_LIMIT = None
|
||||
super(Form, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -83,6 +80,7 @@ class NewPasswordFormMixin():
|
||||
validators=[password_required,
|
||||
Length(min=6, max=128)])
|
||||
|
||||
|
||||
class PasswordConfirmFormMixin():
|
||||
password_confirm = PasswordField("Retype Password",
|
||||
validators=[EqualTo('password', message="Passwords do not match")])
|
||||
@@ -147,6 +145,7 @@ class PasswordlessLoginForm(Form, UserEmailFormMixin):
|
||||
|
||||
class LoginForm(Form, NextFormMixin):
|
||||
"""The default login form"""
|
||||
|
||||
email = TextField('Email Address')
|
||||
password = PasswordField('Password')
|
||||
remember = BooleanField("Remember Me")
|
||||
@@ -156,7 +155,8 @@ class LoginForm(Form, NextFormMixin):
|
||||
super(LoginForm, self).__init__(*args, **kwargs)
|
||||
|
||||
def validate(self):
|
||||
super(LoginForm, self).validate()
|
||||
if not super(LoginForm, self).validate():
|
||||
return False
|
||||
|
||||
if self.email.data.strip() == '':
|
||||
self.email.errors.append('Email not provided')
|
||||
@@ -187,6 +187,7 @@ class ConfirmRegisterForm(Form, RegisterFormMixin,
|
||||
UniqueEmailFormMixin, NewPasswordFormMixin):
|
||||
pass
|
||||
|
||||
|
||||
class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin):
|
||||
pass
|
||||
|
||||
|
||||
+24
-12
@@ -1,8 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import hmac
|
||||
|
||||
from hashlib import sha1
|
||||
from unittest import TestCase
|
||||
|
||||
from tests.test_app.sqlalchemy import create_app
|
||||
|
||||
|
||||
class SecurityTest(TestCase):
|
||||
|
||||
APP_KWARGS = {
|
||||
@@ -21,6 +26,13 @@ class SecurityTest(TestCase):
|
||||
self.app = app
|
||||
self.client = app.test_client()
|
||||
|
||||
with self.client.session_transaction() as session:
|
||||
session['csrf'] = 'csrf_token'
|
||||
|
||||
csrf_hmac = hmac.new(self.app.config['SECRET_KEY'],
|
||||
'csrf_token'.encode('utf8'), digestmod=sha1)
|
||||
self.csrf_token = '##' + csrf_hmac.hexdigest()
|
||||
|
||||
def _create_app(self, auth_config, **kwargs):
|
||||
return create_app(auth_config, **kwargs)
|
||||
|
||||
@@ -30,30 +42,30 @@ class SecurityTest(TestCase):
|
||||
headers=headers)
|
||||
|
||||
def _post(self, route, data=None, content_type=None, follow_redirects=True, headers=None):
|
||||
if isinstance(data, dict):
|
||||
data['csrf_token'] = self.csrf_token
|
||||
|
||||
return self.client.post(route, data=data,
|
||||
follow_redirects=follow_redirects,
|
||||
content_type=content_type or 'application/x-www-form-urlencoded',
|
||||
headers=headers)
|
||||
|
||||
def register(self, email, password='password'):
|
||||
data = dict(email=email, password=password)
|
||||
data = dict(email=email, password=password, csrf_token=self.csrf_token)
|
||||
return self.client.post('/register', data=data, follow_redirects=True)
|
||||
|
||||
def authenticate(self, email="matt@lp.com", password="password", endpoint=None, **kwargs):
|
||||
data = dict(email=email, password=password, remember='y')
|
||||
r = self._post(endpoint or '/login', data=data, **kwargs)
|
||||
return r
|
||||
return self._post(endpoint or '/login', data=data, **kwargs)
|
||||
|
||||
def json_authenticate(self, email="matt@lp.com", password="password", endpoint=None):
|
||||
data = """
|
||||
{
|
||||
"email": "%s",
|
||||
"password": "%s"
|
||||
}
|
||||
"""
|
||||
return self._post(endpoint or '/login',
|
||||
content_type="application/json",
|
||||
data=data % (email, password))
|
||||
data = """{
|
||||
"email": "%s",
|
||||
"password": "%s",
|
||||
"csrf_token": "%s"
|
||||
}"""
|
||||
return self._post(endpoint or '/login', content_type="application/json",
|
||||
data=data % (email, password, self.csrf_token))
|
||||
|
||||
def logout(self, endpoint=None):
|
||||
return self._get(endpoint or '/logout', follow_redirects=True)
|
||||
|
||||
+25
-32
@@ -67,7 +67,7 @@ class ConfiguredSecurityTests(SecurityTest):
|
||||
self.assertIn('Post Register', r.data)
|
||||
|
||||
def test_register_json(self):
|
||||
data = '{ "email": "dude@lp.com", "password": "password" }'
|
||||
data = '{ "email": "dude@lp.com", "password": "password", "csrf_token":"%s" }' % self.csrf_token
|
||||
r = self._post('/register', data=data, content_type='application/json')
|
||||
data = json.loads(r.data)
|
||||
self.assertEquals(data['meta']['code'], 200)
|
||||
@@ -117,7 +117,7 @@ class RegisterableTests(SecurityTest):
|
||||
data = dict(email='dude@lp.com',
|
||||
password='password',
|
||||
password_confirm='password')
|
||||
self.client.post('/register', data=data, follow_redirects=True)
|
||||
self._post('/register', data=data, follow_redirects=True)
|
||||
r = self.authenticate('dude@lp.com')
|
||||
self.assertIn('Hello dude@lp.com', r.data)
|
||||
|
||||
@@ -145,7 +145,7 @@ class ConfirmableTests(SecurityTest):
|
||||
|
||||
self.client.get('/confirm/' + token, follow_redirects=True)
|
||||
self.logout()
|
||||
r = self.client.post('/confirm', data=dict(email=e))
|
||||
r = self._post('/confirm', data=dict(email=e))
|
||||
self.assertIn(self.get_message('ALREADY_CONFIRMED'), r.data)
|
||||
|
||||
def test_register_sends_confirmation_email(self):
|
||||
@@ -231,7 +231,7 @@ class LoginWithoutImmediateConfirmTests(SecurityTest):
|
||||
e = 'dude@lp.com'
|
||||
p = 'password'
|
||||
data = dict(email=e, password=p, password_confirm=p)
|
||||
r = self.client.post('/register', data=data, follow_redirects=True)
|
||||
r = self._post('/register', data=data, follow_redirects=True)
|
||||
self.assertIn(e, r.data)
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ class RecoverableTests(SecurityTest):
|
||||
|
||||
def test_reset_view(self):
|
||||
with capture_reset_password_requests() as requests:
|
||||
r = self.client.post('/reset',
|
||||
r = self._post('/reset',
|
||||
data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
t = requests[0]['token']
|
||||
@@ -255,23 +255,23 @@ class RecoverableTests(SecurityTest):
|
||||
def test_forgot_post_sends_email(self):
|
||||
with capture_reset_password_requests():
|
||||
with self.app.extensions['mail'].record_messages() as outbox:
|
||||
self.client.post('/reset', data=dict(email='joe@lp.com'))
|
||||
self._post('/reset', data=dict(email='joe@lp.com'))
|
||||
self.assertEqual(len(outbox), 1)
|
||||
|
||||
def test_forgot_password_json(self):
|
||||
r = self.client.post('/reset', data='{"email": "matt@lp.com"}',
|
||||
r = self._post('/reset', data='{"email": "matt@lp.com"}',
|
||||
content_type="application/json")
|
||||
self.assertEquals(r.status_code, 200)
|
||||
|
||||
def test_forgot_password_invalid_email(self):
|
||||
r = self.client.post('/reset',
|
||||
r = self._post('/reset',
|
||||
data=dict(email='larry@lp.com'),
|
||||
follow_redirects=True)
|
||||
self.assertIn("Specified user does not exist", r.data)
|
||||
|
||||
def test_reset_password_with_valid_token(self):
|
||||
with capture_reset_password_requests() as requests:
|
||||
r = self.client.post('/reset',
|
||||
r = self._post('/reset',
|
||||
data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
t = requests[0]['token']
|
||||
@@ -303,14 +303,13 @@ class ExpiredResetPasswordTest(SecurityTest):
|
||||
|
||||
def test_reset_password_with_expired_token(self):
|
||||
with capture_reset_password_requests() as requests:
|
||||
r = self.client.post('/reset',
|
||||
data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
r = self._post('/reset', data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
t = requests[0]['token']
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
r = self.client.post('/reset/' + t, data={
|
||||
r = self._post('/reset/' + t, data={
|
||||
'password': 'newpassword',
|
||||
'password_confirm': 'newpassword'
|
||||
}, follow_redirects=True)
|
||||
@@ -348,20 +347,19 @@ class PasswordlessTests(SecurityTest):
|
||||
|
||||
def test_login_request_for_inactive_user(self):
|
||||
msg = self.app.config['SECURITY_MSG_DISABLED_ACCOUNT'][0]
|
||||
r = self.client.post('/login',
|
||||
data=dict(email='tiya@lp.com'),
|
||||
follow_redirects=True)
|
||||
r = self._post('/login', data=dict(email='tiya@lp.com'),
|
||||
follow_redirects=True)
|
||||
self.assertIn(msg, r.data)
|
||||
|
||||
def test_request_login_token_with_json_and_valid_email(self):
|
||||
data = '{"email": "matt@lp.com", "password": "password"}'
|
||||
r = self.client.post('/login', data=data, content_type='application/json')
|
||||
data = '{"email": "matt@lp.com", "password": "password", "csrf_token":"%s"}' % self.csrf_token
|
||||
r = self._post('/login', data=data, content_type='application/json')
|
||||
self.assertEquals(r.status_code, 200)
|
||||
self.assertNotIn('error', r.data)
|
||||
|
||||
def test_request_login_token_with_json_and_invalid_email(self):
|
||||
data = '{"email": "nobody@lp.com", "password": "password"}'
|
||||
r = self.client.post('/login', data=data, content_type='application/json')
|
||||
r = self._post('/login', data=data, content_type='application/json')
|
||||
self.assertIn('errors', r.data)
|
||||
|
||||
def test_request_login_token_sends_email_and_can_login(self):
|
||||
@@ -370,9 +368,8 @@ class PasswordlessTests(SecurityTest):
|
||||
|
||||
with capture_passwordless_login_requests() as requests:
|
||||
with self.app.extensions['mail'].record_messages() as outbox:
|
||||
r = self.client.post('/login',
|
||||
data=dict(email=e),
|
||||
follow_redirects=True)
|
||||
r = self._post('/login', data=dict(email=e),
|
||||
follow_redirects=True)
|
||||
|
||||
self.assertEqual(len(outbox), 1)
|
||||
|
||||
@@ -401,9 +398,8 @@ class PasswordlessTests(SecurityTest):
|
||||
|
||||
def test_token_login_when_already_authenticated(self):
|
||||
with capture_passwordless_login_requests() as requests:
|
||||
self.client.post('/login',
|
||||
data=dict(email='matt@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/login', data=dict(email='matt@lp.com'),
|
||||
follow_redirects=True)
|
||||
token = requests[0]['login_token']
|
||||
|
||||
r = self.client.get('/login/' + token, follow_redirects=True)
|
||||
@@ -431,9 +427,7 @@ class ExpiredLoginTokenTests(SecurityTest):
|
||||
e = 'matt@lp.com'
|
||||
|
||||
with capture_passwordless_login_requests() as requests:
|
||||
self.client.post('/login',
|
||||
data=dict(email=e),
|
||||
follow_redirects=True)
|
||||
self._post('/login', data=dict(email=e), follow_redirects=True)
|
||||
token = requests[0]['login_token']
|
||||
|
||||
time.sleep(1.25)
|
||||
@@ -467,7 +461,7 @@ class AsyncMailTaskTests(SecurityTest):
|
||||
def send_email(msg):
|
||||
self.mail_sent = True
|
||||
|
||||
self.client.post('/reset', data=dict(email='matt@lp.com'))
|
||||
self._post('/reset', data=dict(email='matt@lp.com'))
|
||||
self.assertTrue(self.mail_sent)
|
||||
|
||||
|
||||
@@ -542,9 +536,8 @@ class RecoverableExtendFormsTest(SecurityTest):
|
||||
|
||||
def test_reset_password(self):
|
||||
with capture_reset_password_requests() as requests:
|
||||
self.client.post('/reset',
|
||||
data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/reset', data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
token = requests[0]['token']
|
||||
r = self._get('/reset/' + token)
|
||||
self.assertIn("My Reset Password Submit Field", r.data)
|
||||
|
||||
+17
-28
@@ -32,7 +32,7 @@ class RegisterableSignalsTests(SecurityTest):
|
||||
self.assertIn('confirm_token', args[0])
|
||||
self.assertEqual(kwargs['app'], self.app)
|
||||
|
||||
def test_register(self):
|
||||
def test_register_without_password(self):
|
||||
e = 'dude@lp.com'
|
||||
with capture_signals() as mocks:
|
||||
self.register(e, password='')
|
||||
@@ -111,9 +111,8 @@ class RecoverableSignalsTests(SecurityTest):
|
||||
|
||||
def test_reset_password_request(self):
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/reset',
|
||||
data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/reset', data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
self.assertEqual(mocks.signals_sent(), set([reset_password_instructions_sent]))
|
||||
user = self.app.security.datastore.find_user(email='joe@lp.com')
|
||||
calls = mocks[reset_password_instructions_sent]
|
||||
@@ -125,15 +124,12 @@ class RecoverableSignalsTests(SecurityTest):
|
||||
|
||||
def test_reset_password(self):
|
||||
with capture_reset_password_requests() as requests:
|
||||
self.client.post('/reset',
|
||||
data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/reset', data=dict(email='joe@lp.com'),
|
||||
follow_redirects=True)
|
||||
token = requests[0]['token']
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/reset/' + token,
|
||||
data=dict(password='newpassword',
|
||||
password_confirm='newpassword'),
|
||||
follow_redirects=True)
|
||||
data = dict(password='newpassword', password_confirm='newpassword')
|
||||
self._post('/reset/' + token, data, follow_redirects=True)
|
||||
self.assertEqual(mocks.signals_sent(), set([password_reset]))
|
||||
user = self.app.security.datastore.find_user(email='joe@lp.com')
|
||||
calls = mocks[password_reset]
|
||||
@@ -144,17 +140,14 @@ class RecoverableSignalsTests(SecurityTest):
|
||||
|
||||
def test_reset_password_invalid_emails(self):
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/reset',
|
||||
data=dict(email='nobody@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/reset', data=dict(email='nobody@lp.com'),
|
||||
follow_redirects=True)
|
||||
self.assertEqual(mocks.signals_sent(), set())
|
||||
|
||||
def test_reset_password_invalid_token(self):
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/reset/bogus',
|
||||
data=dict(password='newpassword',
|
||||
password_confirm='newpassword'),
|
||||
follow_redirects=True)
|
||||
data = dict(password='newpassword', password_confirm='newpassword')
|
||||
self._post('/reset/bogus', data, follow_redirects=True)
|
||||
self.assertEqual(mocks.signals_sent(), set())
|
||||
|
||||
|
||||
@@ -166,25 +159,22 @@ class PasswordlessTests(SecurityTest):
|
||||
|
||||
def test_login_request_for_inactive_user(self):
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/login',
|
||||
data=dict(email='tiya@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/login', data=dict(email='tiya@lp.com'),
|
||||
follow_redirects=True)
|
||||
self.assertEqual(mocks.signals_sent(), set())
|
||||
|
||||
def test_login_request_for_invalid_email(self):
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/login',
|
||||
data=dict(email='nobody@lp.com'),
|
||||
follow_redirects=True)
|
||||
self._post('/login', data=dict(email='nobody@lp.com'),
|
||||
follow_redirects=True)
|
||||
self.assertEqual(mocks.signals_sent(), set())
|
||||
|
||||
def test_request_login_token_sends_email_and_can_login(self):
|
||||
e = 'matt@lp.com'
|
||||
|
||||
with capture_signals() as mocks:
|
||||
self.client.post('/login',
|
||||
data=dict(email=e),
|
||||
follow_redirects=True)
|
||||
self._post('/login', data=dict(email=e), follow_redirects=True)
|
||||
|
||||
self.assertEqual(mocks.signals_sent(), set([login_instructions_sent]))
|
||||
user = self.app.security.datastore.find_user(email='matt@lp.com')
|
||||
calls = mocks[login_instructions_sent]
|
||||
@@ -193,4 +183,3 @@ class PasswordlessTests(SecurityTest):
|
||||
self.assertTrue(compare_user(args[0]['user'], user))
|
||||
self.assertIn('login_token', args[0])
|
||||
self.assertEqual(kwargs['app'], self.app)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user