mirror of
https://github.com/wassname/flask-security.git
synced 2026-06-27 16:10:11 +08:00
Work in progress
This commit is contained in:
@@ -19,7 +19,7 @@ from passlib.context import CryptContext
|
||||
from werkzeug.datastructures import ImmutableList
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from .utils import config_value as cv, get_config, md5, url_for_security
|
||||
from .utils import config_value as cv, get_config, md5, url_for_security, string_types
|
||||
from .views import create_blueprint
|
||||
from .forms import LoginForm, ConfirmRegisterForm, RegisterForm, \
|
||||
ForgotPasswordForm, ChangePasswordForm, ResetPasswordForm, \
|
||||
@@ -249,6 +249,7 @@ def _context_processor():
|
||||
|
||||
class RoleMixin(object):
|
||||
"""Mixin for `Role` model definitions"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name == other or
|
||||
self.name == getattr(other, 'name', None))
|
||||
@@ -273,7 +274,7 @@ class UserMixin(BaseUserMixin):
|
||||
"""Returns `True` if the user identifies with the specified role.
|
||||
|
||||
:param role: A role name or `Role` instance"""
|
||||
if isinstance(role, basestring):
|
||||
if isinstance(role, string_types):
|
||||
return role in (role.name for role in self.roles)
|
||||
else:
|
||||
return role in self.roles
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
:license: MIT, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from .utils import get_identity_attributes
|
||||
from .utils import get_identity_attributes, string_types
|
||||
|
||||
|
||||
class Datastore(object):
|
||||
@@ -68,9 +68,9 @@ class UserDatastore(object):
|
||||
self.role_model = role_model
|
||||
|
||||
def _prepare_role_modify_args(self, user, role):
|
||||
if isinstance(user, basestring):
|
||||
if isinstance(user, string_types):
|
||||
user = self.find_user(email=user)
|
||||
if isinstance(role, basestring):
|
||||
if isinstance(role, string_types):
|
||||
role = self.find_role(role)
|
||||
return user, role
|
||||
|
||||
@@ -105,6 +105,7 @@ class UserDatastore(object):
|
||||
user, role = self._prepare_role_modify_args(user, role)
|
||||
if role not in user.roles:
|
||||
user.roles.append(role)
|
||||
self.put(user)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -161,8 +162,8 @@ class UserDatastore(object):
|
||||
|
||||
def create_user(self, **kwargs):
|
||||
"""Creates and returns a new user from the given parameters."""
|
||||
|
||||
user = self.user_model(**self._prepare_create_user_args(**kwargs))
|
||||
kwargs = self._prepare_create_user_args(**kwargs)
|
||||
user = self.user_model(**kwargs)
|
||||
return self.put(user)
|
||||
|
||||
def delete_user(self, user):
|
||||
|
||||
@@ -10,9 +10,10 @@
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import urlparse
|
||||
|
||||
import flask_wtf as wtf
|
||||
try:
|
||||
from urlparse import urlsplit
|
||||
except ImportError:
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from flask import request, current_app
|
||||
from flask_wtf import Form as BaseForm
|
||||
@@ -22,7 +23,7 @@ from flask_login import current_user
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from .confirmable import requires_confirmation
|
||||
from .utils import verify_and_update_password, get_message, encrypt_password, config_value
|
||||
from .utils import verify_and_update_password, get_message, config_value
|
||||
|
||||
# Convenient reference
|
||||
_datastore = LocalProxy(lambda: current_app.extensions['security'].datastore)
|
||||
@@ -137,8 +138,8 @@ class NextFormMixin():
|
||||
|
||||
def validate_next(self, field):
|
||||
if field.data:
|
||||
url_next = urlparse.urlsplit(field.data)
|
||||
url_base = urlparse.urlsplit(request.host_url)
|
||||
url_next = urlsplit(field.data)
|
||||
url_base = urlsplit(request.host_url)
|
||||
if url_next.netloc and url_next.netloc != url_base.netloc:
|
||||
field.data = ''
|
||||
raise ValidationError(get_message('INVALID_REDIRECT')[0])
|
||||
|
||||
+21
-11
@@ -14,6 +14,7 @@ import blinker
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import sys
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
@@ -37,6 +38,15 @@ _datastore = LocalProxy(lambda: _security.datastore)
|
||||
|
||||
_pwd_context = LocalProxy(lambda: _security.pwd_context)
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
string_types = str,
|
||||
text_type = str
|
||||
else:
|
||||
string_types = basestring,
|
||||
text_type = unicode
|
||||
|
||||
|
||||
def login_user(user, remember=None):
|
||||
"""Performs the login routine.
|
||||
@@ -85,16 +95,13 @@ def get_hmac(password):
|
||||
|
||||
:param password: The password to sign
|
||||
"""
|
||||
if _security.password_hash == 'plaintext':
|
||||
return password
|
||||
|
||||
if _security.password_salt is None:
|
||||
raise RuntimeError(
|
||||
'The configuration value `SECURITY_PASSWORD_SALT` must '
|
||||
'not be None when the value of `SECURITY_PASSWORD_HASH` is '
|
||||
'set to "%s"' % _security.password_hash)
|
||||
|
||||
h = hmac.new(_security.password_salt, password.encode('utf-8'), hashlib.sha512)
|
||||
h = hmac.new(_security.password_salt.encode('utf-8'), password.encode('utf-8'), hashlib.sha512)
|
||||
return base64.b64encode(h.digest())
|
||||
|
||||
|
||||
@@ -104,7 +111,7 @@ def verify_password(password, password_hash):
|
||||
:param password: A plaintext password to verify
|
||||
:param password_hash: The expected hash value of the password (usually form your database)
|
||||
"""
|
||||
return _pwd_context.verify(get_hmac(password), password_hash)
|
||||
return _pwd_context.verify(encrypt_password(password), password_hash)
|
||||
|
||||
|
||||
def verify_and_update_password(password, user):
|
||||
@@ -114,7 +121,7 @@ def verify_and_update_password(password, user):
|
||||
:param password: A plaintext password to verify
|
||||
:param user: The user to verify against
|
||||
"""
|
||||
verified, new_password = _pwd_context.verify_and_update(get_hmac(password), user.password)
|
||||
verified, new_password = _pwd_context.verify_and_update(encrypt_password(password), user.password)
|
||||
if verified and new_password:
|
||||
user.password = new_password
|
||||
_datastore.put(user)
|
||||
@@ -126,11 +133,14 @@ def encrypt_password(password):
|
||||
|
||||
:param password: The plaintext passwrod to encrypt
|
||||
"""
|
||||
return _pwd_context.encrypt(get_hmac(password))
|
||||
if _security.password_hash == 'plaintext':
|
||||
return password
|
||||
signed = get_hmac(password)
|
||||
return _pwd_context.encrypt(signed.decode('ascii'))
|
||||
|
||||
|
||||
def md5(data):
|
||||
return hashlib.md5(data).hexdigest()
|
||||
return hashlib.md5(data.encode('ascii')).hexdigest()
|
||||
|
||||
|
||||
def do_flash(message, category=None):
|
||||
@@ -408,19 +418,19 @@ class CaptureSignals(object):
|
||||
self._records[signal].append((args, kwargs))
|
||||
|
||||
def __enter__(self):
|
||||
for signal, receiver in self._receivers.iteritems():
|
||||
for signal, receiver in self._receivers.items():
|
||||
signal.connect(receiver)
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
for signal, receiver in self._receivers.iteritems():
|
||||
for signal, receiver in self._receivers.items():
|
||||
signal.disconnect(receiver)
|
||||
|
||||
def signals_sent(self):
|
||||
"""Return a set of the signals sent.
|
||||
:rtype: list of blinker `NamedSignals`.
|
||||
"""
|
||||
return set([signal for signal, _ in self._records.iteritems() if self._records[signal]])
|
||||
return set([signal for signal, _ in self._records.items() if self._records[signal]])
|
||||
|
||||
|
||||
def capture_signals():
|
||||
|
||||
+1
-1
@@ -61,7 +61,7 @@ class SecurityTest(TestCase):
|
||||
return self._get(endpoint or '/logout', follow_redirects=True)
|
||||
|
||||
def assertIsHomePage(self, data):
|
||||
self.assertIn('Home Page', data)
|
||||
self.assertIn(b'Home Page', data)
|
||||
|
||||
def assertIn(self, member, container, msg=None):
|
||||
if hasattr(TestCase, 'assertIn'):
|
||||
|
||||
+47
-47
@@ -27,7 +27,7 @@ class ConfiguredPasswordHashSecurityTests(SecurityTest):
|
||||
|
||||
def test_authenticate(self):
|
||||
r = self.authenticate(endpoint="/login")
|
||||
self.assertIn('Home Page', r.data)
|
||||
self.assertIn(b'Home Page', r.data)
|
||||
|
||||
|
||||
class ConfiguredSecurityTests(SecurityTest):
|
||||
@@ -49,16 +49,16 @@ class ConfiguredSecurityTests(SecurityTest):
|
||||
|
||||
def test_authenticate(self):
|
||||
r = self.authenticate(endpoint="/custom_login")
|
||||
self.assertIn('Post Login', r.data)
|
||||
self.assertIn(b'Post Login', r.data)
|
||||
|
||||
def test_logout(self):
|
||||
self.authenticate(endpoint="/custom_login")
|
||||
r = self.logout(endpoint="/custom_logout")
|
||||
self.assertIn('Post Logout', r.data)
|
||||
self.assertIn(b'Post Logout', r.data)
|
||||
|
||||
def test_register_view(self):
|
||||
r = self._get('/register')
|
||||
self.assertIn('<h1>Register</h1>', r.data)
|
||||
self.assertIn(b'<h1>Register</h1>', r.data)
|
||||
|
||||
def test_register(self):
|
||||
data = dict(email='dude@lp.com',
|
||||
@@ -66,7 +66,7 @@ class ConfiguredSecurityTests(SecurityTest):
|
||||
password_confirm='password')
|
||||
|
||||
r = self._post('/register', data=data, follow_redirects=True)
|
||||
self.assertIn('Post Register', r.data)
|
||||
self.assertIn(b'Post Register', r.data)
|
||||
|
||||
def test_register_with_next_querystring_argument(self):
|
||||
data = dict(email='dude@lp.com',
|
||||
@@ -74,7 +74,7 @@ class ConfiguredSecurityTests(SecurityTest):
|
||||
password_confirm='password')
|
||||
|
||||
r = self._post('/register?next=/page1', data=data, follow_redirects=True)
|
||||
self.assertIn('Page 1', r.data)
|
||||
self.assertIn(b'Page 1', r.data)
|
||||
|
||||
def test_register_json(self):
|
||||
data = '{ "email": "dude@lp.com", "password": "password"}'
|
||||
@@ -98,9 +98,9 @@ class ConfiguredSecurityTests(SecurityTest):
|
||||
|
||||
def test_default_http_auth_realm(self):
|
||||
r = self._get('/http', headers={
|
||||
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:bogus")
|
||||
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:bogus")
|
||||
})
|
||||
self.assertIn('<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn('WWW-Authenticate', r.headers)
|
||||
self.assertEquals('Basic realm="Custom Realm"',
|
||||
r.headers['WWW-Authenticate'])
|
||||
@@ -125,7 +125,7 @@ class DefaultTemplatePathTests(SecurityTest):
|
||||
def test_login_user_template(self):
|
||||
r = self._get('/login')
|
||||
|
||||
self.assertIn('CUSTOM LOGIN USER', r.data)
|
||||
self.assertIn(b'CUSTOM LOGIN USER', r.data)
|
||||
|
||||
|
||||
class RegisterableTemplatePathTests(SecurityTest):
|
||||
@@ -137,7 +137,7 @@ class RegisterableTemplatePathTests(SecurityTest):
|
||||
def test_register_user_template(self):
|
||||
r = self._get('/register')
|
||||
|
||||
self.assertIn('CUSTOM REGISTER USER', r.data)
|
||||
self.assertIn(b'CUSTOM REGISTER USER', r.data)
|
||||
|
||||
|
||||
class RecoverableTemplatePathTests(SecurityTest):
|
||||
@@ -150,7 +150,7 @@ class RecoverableTemplatePathTests(SecurityTest):
|
||||
def test_forgot_password_template(self):
|
||||
r = self._get('/reset')
|
||||
|
||||
self.assertIn('CUSTOM FORGOT PASSWORD', r.data)
|
||||
self.assertIn(b'CUSTOM FORGOT PASSWORD', r.data)
|
||||
|
||||
def test_reset_password_template(self):
|
||||
with capture_reset_password_requests() as requests:
|
||||
@@ -161,7 +161,7 @@ class RecoverableTemplatePathTests(SecurityTest):
|
||||
|
||||
r = self._get('/reset/' + t)
|
||||
|
||||
self.assertIn('CUSTOM RESET PASSWORD', r.data)
|
||||
self.assertIn(b'CUSTOM RESET PASSWORD', r.data)
|
||||
|
||||
|
||||
class ConfirmableTemplatePathTests(SecurityTest):
|
||||
@@ -173,7 +173,7 @@ class ConfirmableTemplatePathTests(SecurityTest):
|
||||
def test_send_confirmation_template(self):
|
||||
r = self._get('/confirm')
|
||||
|
||||
self.assertIn('CUSTOM SEND CONFIRMATION', r.data)
|
||||
self.assertIn(b'CUSTOM SEND CONFIRMATION', r.data)
|
||||
|
||||
|
||||
class PasswordlessTemplatePathTests(SecurityTest):
|
||||
@@ -185,7 +185,7 @@ class PasswordlessTemplatePathTests(SecurityTest):
|
||||
def test_send_login_template(self):
|
||||
r = self._get('/login')
|
||||
|
||||
self.assertIn('CUSTOM SEND LOGIN', r.data)
|
||||
self.assertIn(b'CUSTOM SEND LOGIN', r.data)
|
||||
|
||||
|
||||
class RegisterableTests(SecurityTest):
|
||||
@@ -200,7 +200,7 @@ class RegisterableTests(SecurityTest):
|
||||
password_confirm='password')
|
||||
self._post('/register', data=data, follow_redirects=True)
|
||||
r = self.authenticate('dude@lp.com')
|
||||
self.assertIn('Hello dude@lp.com', r.data)
|
||||
self.assertIn(b'Hello dude@lp.com', r.data)
|
||||
|
||||
|
||||
class ConfirmableTests(SecurityTest):
|
||||
@@ -215,7 +215,7 @@ class ConfirmableTests(SecurityTest):
|
||||
e = 'dude@lp.com'
|
||||
self.register(e)
|
||||
r = self.authenticate(email=e)
|
||||
self.assertIn(self.get_message('CONFIRMATION_REQUIRED'), r.data)
|
||||
self.assertIn(self.get_message('CONFIRMATION_REQUIRED').encode('utf-8'), r.data)
|
||||
|
||||
def test_send_confirmation_of_already_confirmed_account(self):
|
||||
e = 'dude@lp.com'
|
||||
@@ -227,7 +227,7 @@ class ConfirmableTests(SecurityTest):
|
||||
self.client.get('/confirm/' + token, follow_redirects=True)
|
||||
self.logout()
|
||||
r = self._post('/confirm', data=dict(email=e))
|
||||
self.assertIn(self.get_message('ALREADY_CONFIRMED'), r.data)
|
||||
self.assertIn(self.get_message('ALREADY_CONFIRMED').encode('utf-8'), r.data)
|
||||
|
||||
def test_register_sends_confirmation_email(self):
|
||||
e = 'dude@lp.com'
|
||||
@@ -269,7 +269,7 @@ class ConfirmableTests(SecurityTest):
|
||||
self.register(e)
|
||||
r = self._post('/confirm', data={'email': e})
|
||||
|
||||
msg = self.get_message('CONFIRMATION_REQUEST', email=e)
|
||||
msg = self.get_message('CONFIRMATION_REQUEST', email=e).encode('utf-8')
|
||||
self.assertIn(msg, r.data)
|
||||
|
||||
def test_user_deleted_before_confirmation(self):
|
||||
@@ -350,7 +350,7 @@ class LoginWithoutImmediateConfirmTests(SecurityTest):
|
||||
r = self.client.get('/confirm/' + token2, follow_redirects=True)
|
||||
msg = self.app.config['SECURITY_MSG_EMAIL_CONFIRMED'][0]
|
||||
self.assertIn(msg, r.data)
|
||||
self.assertIn('Hello %s' % e2, r.data)
|
||||
self.assertIn(b'Hello %s' % e2, r.data)
|
||||
|
||||
def test_login_unconfirmed_user_when_login_without_confirmation_is_true(self):
|
||||
e = 'dude@lp.com'
|
||||
@@ -377,7 +377,7 @@ class RecoverableTests(SecurityTest):
|
||||
follow_redirects=True)
|
||||
t = requests[0]['token']
|
||||
r = self._get('/reset/' + t)
|
||||
self.assertIn('<h1>Reset password</h1>', r.data)
|
||||
self.assertIn(b'<h1>Reset password</h1>', r.data)
|
||||
|
||||
def test_forgot_post_sends_email(self):
|
||||
with capture_reset_password_requests():
|
||||
@@ -408,7 +408,7 @@ class RecoverableTests(SecurityTest):
|
||||
|
||||
r = self.logout()
|
||||
r = self.authenticate('joe@lp.com', 'newpassword')
|
||||
self.assertIn('Hello joe@lp.com', r.data)
|
||||
self.assertIn(b'Hello joe@lp.com', r.data)
|
||||
|
||||
def test_reset_password_with_invalid_token(self):
|
||||
r = self._post('/reset/bogus', data={
|
||||
@@ -416,7 +416,7 @@ class RecoverableTests(SecurityTest):
|
||||
'password_confirm': 'newpassword'
|
||||
}, follow_redirects=True)
|
||||
|
||||
self.assertIn(self.get_message('INVALID_RESET_PASSWORD_TOKEN'), r.data)
|
||||
self.assertIn(self.get_message('INVALID_RESET_PASSWORD_TOKEN').encode('utf-8'), r.data)
|
||||
|
||||
|
||||
class ExpiredResetPasswordTest(SecurityTest):
|
||||
@@ -439,7 +439,7 @@ class ExpiredResetPasswordTest(SecurityTest):
|
||||
'password_confirm': 'newpassword'
|
||||
}, follow_redirects=True)
|
||||
|
||||
self.assertIn('You did not reset your password within', r.data)
|
||||
self.assertIn(b'You did not reset your password within', r.data)
|
||||
|
||||
|
||||
class ChangePasswordTest(SecurityTest):
|
||||
@@ -452,7 +452,7 @@ class ChangePasswordTest(SecurityTest):
|
||||
def test_change_password(self):
|
||||
self.authenticate()
|
||||
r = self.client.get('/change', follow_redirects=True)
|
||||
self.assertIn('Change password', r.data)
|
||||
self.assertIn(b'Change password', r.data)
|
||||
|
||||
def test_change_password_invalid(self):
|
||||
self.authenticate()
|
||||
@@ -461,8 +461,8 @@ class ChangePasswordTest(SecurityTest):
|
||||
'new_password': 'newpassword',
|
||||
'new_password_confirm': 'newpassword'
|
||||
}, follow_redirects=True)
|
||||
self.assertNotIn('You successfully changed your password', r.data)
|
||||
self.assertIn('Invalid password', r.data)
|
||||
self.assertNotIn(b'You successfully changed your password', r.data)
|
||||
self.assertIn(b'Invalid password', r.data)
|
||||
|
||||
def test_change_password_mismatch(self):
|
||||
self.authenticate()
|
||||
@@ -471,8 +471,8 @@ class ChangePasswordTest(SecurityTest):
|
||||
'new_password': 'newpassword',
|
||||
'new_password_confirm': 'notnewpassword'
|
||||
}, follow_redirects=True)
|
||||
self.assertNotIn('You successfully changed your password', r.data)
|
||||
self.assertIn('Passwords do not match', r.data)
|
||||
self.assertNotIn(b'You successfully changed your password', r.data)
|
||||
self.assertIn(b'Passwords do not match', r.data)
|
||||
|
||||
def test_change_password_bad_password(self):
|
||||
self.authenticate()
|
||||
@@ -481,8 +481,8 @@ class ChangePasswordTest(SecurityTest):
|
||||
'new_password': 'a',
|
||||
'new_password_confirm': 'a'
|
||||
}, follow_redirects=True)
|
||||
self.assertNotIn('You successfully changed your password', r.data)
|
||||
self.assertIn('Password must be at least 6 characters', r.data)
|
||||
self.assertNotIn(b'You successfully changed your password', r.data)
|
||||
self.assertIn(b'Password must be at least 6 characters', r.data)
|
||||
|
||||
def test_change_password_same_as_previous(self):
|
||||
self.authenticate()
|
||||
@@ -491,8 +491,8 @@ class ChangePasswordTest(SecurityTest):
|
||||
'new_password': 'password',
|
||||
'new_password_confirm': 'password'
|
||||
}, follow_redirects=True)
|
||||
self.assertNotIn('You successfully changed your password', r.data)
|
||||
self.assertIn('Your new password must be different than your previous password.', r.data)
|
||||
self.assertNotIn(b'You successfully changed your password', r.data)
|
||||
self.assertIn(b'Your new password must be different than your previous password.', r.data)
|
||||
|
||||
def test_change_password_success(self):
|
||||
data = {
|
||||
@@ -505,8 +505,8 @@ class ChangePasswordTest(SecurityTest):
|
||||
with self.app.extensions['mail'].record_messages() as outbox:
|
||||
r = self._post('/change', data=data, follow_redirects=True)
|
||||
|
||||
self.assertIn('You successfully changed your password', r.data)
|
||||
self.assertIn('Home Page', r.data)
|
||||
self.assertIn(b'You successfully changed your password', r.data)
|
||||
self.assertIn(b'Home Page', r.data)
|
||||
|
||||
self.assertEqual(len(outbox), 1)
|
||||
self.assertIn("Your password has been changed", outbox[0].html)
|
||||
@@ -552,7 +552,7 @@ class ChangePasswordPostViewTest(SecurityTest):
|
||||
self.authenticate()
|
||||
r = self._post('/change', data=data, follow_redirects=True)
|
||||
|
||||
self.assertIn('Profile Page', r.data)
|
||||
self.assertIn(b'Profile Page', r.data)
|
||||
|
||||
|
||||
class ChangePasswordDisabledTest(SecurityTest):
|
||||
@@ -605,12 +605,12 @@ class PasswordlessTests(SecurityTest):
|
||||
data = '{"email": "matt@lp.com", "password": "password"}'
|
||||
r = self._post('/login', data=data, content_type='application/json')
|
||||
self.assertEquals(r.status_code, 200)
|
||||
self.assertNotIn('error', r.data)
|
||||
self.assertNotIn(b'error', r.data)
|
||||
|
||||
def test_request_login_token_with_json_and_invalid_email(self):
|
||||
data = '{"email": "nobody@lp.com", "password": "password"}'
|
||||
r = self._post('/login', data=data, content_type='application/json')
|
||||
self.assertIn('errors', r.data)
|
||||
self.assertIn(b'errors', r.data)
|
||||
|
||||
def test_request_login_token_sends_email_and_can_login(self):
|
||||
e = 'matt@lp.com'
|
||||
@@ -624,8 +624,8 @@ class PasswordlessTests(SecurityTest):
|
||||
self.assertEqual(len(outbox), 1)
|
||||
|
||||
self.assertEquals(1, len(requests))
|
||||
self.assertIn('user', requests[0])
|
||||
self.assertIn('login_token', requests[0])
|
||||
self.assertIn(b'user', requests[0])
|
||||
self.assertIn(b'login_token', requests[0])
|
||||
|
||||
user = requests[0]['user']
|
||||
token = requests[0]['login_token']
|
||||
@@ -635,11 +635,11 @@ class PasswordlessTests(SecurityTest):
|
||||
self.assertIn(msg, r.data)
|
||||
|
||||
r = self.client.get('/login/' + token, follow_redirects=True)
|
||||
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL')
|
||||
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL').encode('utf-8')
|
||||
self.assertIn(msg, r.data)
|
||||
|
||||
r = self.client.get('/profile')
|
||||
self.assertIn('Profile Page', r.data)
|
||||
self.assertIn(b'Profile Page', r.data)
|
||||
|
||||
def test_invalid_login_token(self):
|
||||
msg = self.app.config['SECURITY_MSG_INVALID_LOGIN_TOKEN'][0]
|
||||
@@ -653,16 +653,16 @@ class PasswordlessTests(SecurityTest):
|
||||
token = requests[0]['login_token']
|
||||
|
||||
r = self.client.get('/login/' + token, follow_redirects=True)
|
||||
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL')
|
||||
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL').encode('utf-8')
|
||||
self.assertIn(msg, r.data)
|
||||
|
||||
r = self.client.get('/login/' + token, follow_redirects=True)
|
||||
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL')
|
||||
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL').encode('utf-8')
|
||||
self.assertNotIn(msg, r.data)
|
||||
|
||||
def test_send_login_with_invalid_email(self):
|
||||
r = self._post('/login', data=dict(email='bogus@bogus.com'))
|
||||
self.assertIn('Specified user does not exist', r.data)
|
||||
self.assertIn(b'Specified user does not exist', r.data)
|
||||
|
||||
|
||||
class ExpiredLoginTokenTests(SecurityTest):
|
||||
@@ -730,9 +730,9 @@ class NoBlueprintTests(SecurityTest):
|
||||
self.assertEqual(404, r.status_code)
|
||||
|
||||
def test_http_auth_without_blueprint(self):
|
||||
auth = 'Basic ' + base64.b64encode("matt@lp.com:password")
|
||||
auth = 'Basic %s' % base64.b64encode(b"matt@lp.com:password")
|
||||
r = self._get('/http', headers={'Authorization': auth})
|
||||
self.assertIn('HTTP Authentication', r.data)
|
||||
self.assertIn(b'HTTP Authentication', r.data)
|
||||
|
||||
|
||||
class ExtendFormsTest(SecurityTest):
|
||||
@@ -846,4 +846,4 @@ class AdditionalUserIdentityAttributes(SecurityTest):
|
||||
|
||||
def test_authenticate(self):
|
||||
r = self.authenticate(email='matt')
|
||||
self.assertIn('Hello matt@lp.com', r.data)
|
||||
self.assertIn(b'Hello matt@lp.com', r.data)
|
||||
|
||||
+47
-41
@@ -4,7 +4,11 @@ from __future__ import with_statement
|
||||
|
||||
import base64
|
||||
import simplejson as json
|
||||
from cookielib import Cookie
|
||||
|
||||
try:
|
||||
from cookielib import Cookie
|
||||
except ImportError:
|
||||
from http.cookiejar import Cookie
|
||||
|
||||
from werkzeug.utils import parse_cookie
|
||||
|
||||
@@ -27,35 +31,35 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
def test_login_view(self):
|
||||
r = self._get('/login')
|
||||
self.assertIn('<h1>Login</h1>', r.data)
|
||||
self.assertIn(b'<h1>Login</h1>', r.data)
|
||||
|
||||
def test_authenticate(self):
|
||||
r = self.authenticate()
|
||||
self.assertIn('Hello matt@lp.com', r.data)
|
||||
self.assertIn(b'Hello matt@lp.com', r.data)
|
||||
|
||||
def test_authenticate_case_insensitive_email(self):
|
||||
r = self.authenticate(email='MATT@lp.com')
|
||||
self.assertIn('Hello matt@lp.com', r.data)
|
||||
self.assertIn(b'Hello matt@lp.com', r.data)
|
||||
|
||||
def test_unprovided_username(self):
|
||||
r = self.authenticate("")
|
||||
self.assertIn(self.get_message('EMAIL_NOT_PROVIDED'), r.data)
|
||||
self.assertIn(self.get_message('EMAIL_NOT_PROVIDED').encode('utf-8'), r.data)
|
||||
|
||||
def test_unprovided_password(self):
|
||||
r = self.authenticate(password="")
|
||||
self.assertIn(self.get_message('PASSWORD_NOT_PROVIDED'), r.data)
|
||||
self.assertIn(self.get_message('PASSWORD_NOT_PROVIDED').encode('utf-8'), r.data)
|
||||
|
||||
def test_invalid_user(self):
|
||||
r = self.authenticate(email="bogus@bogus.com")
|
||||
self.assertIn(self.get_message('USER_DOES_NOT_EXIST'), r.data)
|
||||
self.assertIn(self.get_message('USER_DOES_NOT_EXIST').encode('utf-8'), r.data)
|
||||
|
||||
def test_bad_password(self):
|
||||
r = self.authenticate(password="bogus")
|
||||
self.assertIn(self.get_message('INVALID_PASSWORD'), r.data)
|
||||
self.assertIn(self.get_message('INVALID_PASSWORD').encode('utf-8'), r.data)
|
||||
|
||||
def test_inactive_user(self):
|
||||
r = self.authenticate("tiya@lp.com", "password")
|
||||
self.assertIn(self.get_message('DISABLED_ACCOUNT'), r.data)
|
||||
self.assertIn(self.get_message('DISABLED_ACCOUNT').encode('utf-8'), r.data)
|
||||
|
||||
def test_logout(self):
|
||||
self.authenticate()
|
||||
@@ -65,17 +69,17 @@ class DefaultSecurityTests(SecurityTest):
|
||||
def test_unauthorized_access(self):
|
||||
self.logout()
|
||||
r = self._get('/profile', follow_redirects=True)
|
||||
self.assertIn('<li class="info">Please log in to access this page.</li>', r.data)
|
||||
self.assertIn(b'<li class="info">Please log in to access this page.</li>', r.data)
|
||||
|
||||
def test_authorized_access(self):
|
||||
self.authenticate()
|
||||
r = self._get("/profile")
|
||||
self.assertIn('profile', r.data)
|
||||
self.assertIn(b'profile', r.data)
|
||||
|
||||
def test_valid_admin_role(self):
|
||||
self.authenticate()
|
||||
r = self._get("/admin")
|
||||
self.assertIn('Admin Page', r.data)
|
||||
self.assertIn(b'Admin Page', r.data)
|
||||
|
||||
def test_invalid_admin_role(self):
|
||||
self.authenticate("joe@lp.com")
|
||||
@@ -86,7 +90,7 @@ class DefaultSecurityTests(SecurityTest):
|
||||
for user in ("matt@lp.com", "joe@lp.com"):
|
||||
self.authenticate(user)
|
||||
r = self._get("/admin_or_editor")
|
||||
self.assertIn('Admin or Editor Page', r.data)
|
||||
self.assertIn(b'Admin or Editor Page', r.data)
|
||||
self.logout()
|
||||
|
||||
self.authenticate("jill@lp.com")
|
||||
@@ -95,7 +99,7 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
def test_unauthenticated_role_required(self):
|
||||
r = self._get('/admin', follow_redirects=True)
|
||||
self.assertIn(self.get_message('UNAUTHORIZED'), r.data)
|
||||
self.assertIn(self.get_message('UNAUTHORIZED').encode('utf-8'), r.data)
|
||||
|
||||
def test_multiple_role_required(self):
|
||||
for user in ("matt@lp.com", "joe@lp.com"):
|
||||
@@ -106,7 +110,7 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
self.authenticate('dave@lp.com')
|
||||
r = self._get("/admin_and_editor", follow_redirects=True)
|
||||
self.assertIn('Admin and Editor Page', r.data)
|
||||
self.assertIn(b'Admin and Editor Page', r.data)
|
||||
|
||||
def test_ok_json_auth(self):
|
||||
r = self.json_authenticate()
|
||||
@@ -116,14 +120,14 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
def test_invalid_json_auth(self):
|
||||
r = self.json_authenticate(password='junk')
|
||||
self.assertIn('"code": 400', r.data)
|
||||
self.assertIn(b'"code": 400', r.data)
|
||||
|
||||
def test_token_auth_via_querystring_valid_token(self):
|
||||
r = self.json_authenticate()
|
||||
data = json.loads(r.data)
|
||||
token = data['response']['user']['authentication_token']
|
||||
r = self._get('/token?auth_token=' + token)
|
||||
self.assertIn('Token Authentication', r.data)
|
||||
self.assertIn(b'Token Authentication', r.data)
|
||||
|
||||
def test_token_auth_via_header_valid_token(self):
|
||||
r = self.json_authenticate()
|
||||
@@ -131,7 +135,7 @@ class DefaultSecurityTests(SecurityTest):
|
||||
token = data['response']['user']['authentication_token']
|
||||
headers = {"Authentication-Token": token}
|
||||
r = self._get('/token', headers=headers)
|
||||
self.assertIn('Token Authentication', r.data)
|
||||
self.assertIn(b'Token Authentication', r.data)
|
||||
|
||||
def test_token_auth_via_querystring_invalid_token(self):
|
||||
r = self._get('/token?auth_token=X')
|
||||
@@ -143,61 +147,63 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
def test_http_auth(self):
|
||||
r = self._get('/http', headers={
|
||||
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:password")
|
||||
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:password")
|
||||
})
|
||||
self.assertIn('HTTP Authentication', r.data)
|
||||
self.assertIn(b'HTTP Authentication', r.data)
|
||||
|
||||
def test_http_auth_no_authorization(self):
|
||||
r = self._get('/http', headers={})
|
||||
self.assertIn('<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn('WWW-Authenticate', r.headers)
|
||||
self.assertEquals('Basic realm="Login Required"',
|
||||
r.headers['WWW-Authenticate'])
|
||||
|
||||
def test_invalid_http_auth_invalid_username(self):
|
||||
r = self._get('/http', headers={
|
||||
'Authorization': 'Basic ' + base64.b64encode("bogus:bogus")
|
||||
'Authorization': 'Basic %s' % base64.b64encode(b"bogus:bogus")
|
||||
})
|
||||
self.assertIn('<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn('WWW-Authenticate', r.headers)
|
||||
self.assertEquals('Basic realm="Login Required"',
|
||||
r.headers['WWW-Authenticate'])
|
||||
|
||||
def test_invalid_http_auth_bad_password(self):
|
||||
r = self._get('/http', headers={
|
||||
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:bogus")
|
||||
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:bogus")
|
||||
})
|
||||
self.assertIn('<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn('WWW-Authenticate', r.headers)
|
||||
self.assertEquals('Basic realm="Login Required"',
|
||||
r.headers['WWW-Authenticate'])
|
||||
|
||||
def test_custom_http_auth_realm(self):
|
||||
r = self._get('/http_custom_realm', headers={
|
||||
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:bogus")
|
||||
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:bogus")
|
||||
})
|
||||
self.assertIn('<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
|
||||
self.assertIn('WWW-Authenticate', r.headers)
|
||||
self.assertEquals('Basic realm="My Realm"',
|
||||
r.headers['WWW-Authenticate'])
|
||||
|
||||
def test_multi_auth_basic(self):
|
||||
r = self._get('/multi_auth', headers={
|
||||
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:password")
|
||||
})
|
||||
self.assertIn('Basic', r.data)
|
||||
h = {
|
||||
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:password")
|
||||
}
|
||||
print(h)
|
||||
r = self._get('/multi_auth', headers=h)
|
||||
self.assertIn(b'Basic', r.data)
|
||||
|
||||
def test_multi_auth_token(self):
|
||||
r = self.json_authenticate()
|
||||
data = json.loads(r.data)
|
||||
token = data['response']['user']['authentication_token']
|
||||
r = self._get('/multi_auth?auth_token=' + token)
|
||||
self.assertIn('Token', r.data)
|
||||
self.assertIn(b'Token', r.data)
|
||||
|
||||
def test_multi_auth_session(self):
|
||||
self.authenticate()
|
||||
r = self._get('/multi_auth')
|
||||
self.assertIn('Session', r.data)
|
||||
self.assertIn(b'Session', r.data)
|
||||
|
||||
def test_user_deleted_during_session_reverts_to_anonymous_user(self):
|
||||
self.authenticate()
|
||||
@@ -208,13 +214,13 @@ class DefaultSecurityTests(SecurityTest):
|
||||
self.app.security.datastore.commit()
|
||||
|
||||
r = self._get('/')
|
||||
self.assertNotIn('Hello matt@lp.com', r.data)
|
||||
self.assertNotIn(b'Hello matt@lp.com', r.data)
|
||||
|
||||
def test_remember_token(self):
|
||||
r = self.authenticate(follow_redirects=False)
|
||||
self.client.cookie_jar.clear_session_cookies()
|
||||
r = self._get('/profile')
|
||||
self.assertIn('profile', r.data)
|
||||
self.assertIn(b'profile', r.data)
|
||||
|
||||
def test_token_loader_does_not_fail_with_invalid_token(self):
|
||||
c = Cookie(version=0, name='remember_token', value='None', port=None,
|
||||
@@ -226,7 +232,7 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
self.client.cookie_jar.set_cookie(c)
|
||||
r = self._get('/')
|
||||
self.assertNotIn('BadSignature', r.data)
|
||||
self.assertNotIn(b'BadSignature', r.data)
|
||||
|
||||
|
||||
class MongoEngineSecurityTests(DefaultSecurityTests):
|
||||
@@ -247,23 +253,23 @@ class DefaultDatastoreTests(SecurityTest):
|
||||
|
||||
def test_add_role_to_user(self):
|
||||
r = self._get('/coverage/add_role_to_user')
|
||||
self.assertIn('success', r.data)
|
||||
self.assertIn(b'success', r.data)
|
||||
|
||||
def test_remove_role_from_user(self):
|
||||
r = self._get('/coverage/remove_role_from_user')
|
||||
self.assertIn('success', r.data)
|
||||
self.assertIn(b'success', r.data)
|
||||
|
||||
def test_activate_user(self):
|
||||
r = self._get('/coverage/activate_user')
|
||||
self.assertIn('success', r.data)
|
||||
self.assertIn(b'success', r.data)
|
||||
|
||||
def test_deactivate_user(self):
|
||||
r = self._get('/coverage/deactivate_user')
|
||||
self.assertIn('success', r.data)
|
||||
self.assertIn(b'success', r.data)
|
||||
|
||||
def test_invalid_role(self):
|
||||
r = self._get('/coverage/invalid_role')
|
||||
self.assertIn('success', r.data)
|
||||
self.assertIn(b'success', r.data)
|
||||
|
||||
|
||||
class MongoEngineDatastoreTests(DefaultDatastoreTests):
|
||||
|
||||
@@ -137,9 +137,13 @@ def create_users(count=None):
|
||||
|
||||
for u in users[:count]:
|
||||
pw = encrypt_password(u[2])
|
||||
ds.create_user(email=u[0], username=u[1], password=pw,
|
||||
roles=u[3], active=u[4])
|
||||
ds.commit()
|
||||
roles = [ds.find_or_create_role(rn) for rn in u[3]]
|
||||
ds.commit()
|
||||
user = ds.create_user(email=u[0], username=u[1], password=pw, active=u[4])
|
||||
ds.commit()
|
||||
for role in roles:
|
||||
ds.add_role_to_user(user, role)
|
||||
ds.commit()
|
||||
|
||||
|
||||
def populate_data(user_count=None):
|
||||
|
||||
Reference in New Issue
Block a user