Work in progress

This commit is contained in:
Matt Wright
2013-12-19 16:12:29 -05:00
parent d95a8c9364
commit f1447b2adc
9 changed files with 141 additions and 118 deletions
+3 -2
View File
@@ -19,7 +19,7 @@ from passlib.context import CryptContext
from werkzeug.datastructures import ImmutableList
from werkzeug.local import LocalProxy
from .utils import config_value as cv, get_config, md5, url_for_security
from .utils import config_value as cv, get_config, md5, url_for_security, string_types
from .views import create_blueprint
from .forms import LoginForm, ConfirmRegisterForm, RegisterForm, \
ForgotPasswordForm, ChangePasswordForm, ResetPasswordForm, \
@@ -249,6 +249,7 @@ def _context_processor():
class RoleMixin(object):
"""Mixin for `Role` model definitions"""
def __eq__(self, other):
return (self.name == other or
self.name == getattr(other, 'name', None))
@@ -273,7 +274,7 @@ class UserMixin(BaseUserMixin):
"""Returns `True` if the user identifies with the specified role.
:param role: A role name or `Role` instance"""
if isinstance(role, basestring):
if isinstance(role, string_types):
return role in (role.name for role in self.roles)
else:
return role in self.roles
+6 -5
View File
@@ -9,7 +9,7 @@
:license: MIT, see LICENSE for more details.
"""
from .utils import get_identity_attributes
from .utils import get_identity_attributes, string_types
class Datastore(object):
@@ -68,9 +68,9 @@ class UserDatastore(object):
self.role_model = role_model
def _prepare_role_modify_args(self, user, role):
if isinstance(user, basestring):
if isinstance(user, string_types):
user = self.find_user(email=user)
if isinstance(role, basestring):
if isinstance(role, string_types):
role = self.find_role(role)
return user, role
@@ -105,6 +105,7 @@ class UserDatastore(object):
user, role = self._prepare_role_modify_args(user, role)
if role not in user.roles:
user.roles.append(role)
self.put(user)
return True
return False
@@ -161,8 +162,8 @@ class UserDatastore(object):
def create_user(self, **kwargs):
"""Creates and returns a new user from the given parameters."""
user = self.user_model(**self._prepare_create_user_args(**kwargs))
kwargs = self._prepare_create_user_args(**kwargs)
user = self.user_model(**kwargs)
return self.put(user)
def delete_user(self, user):
+7 -6
View File
@@ -10,9 +10,10 @@
"""
import inspect
import urlparse
import flask_wtf as wtf
try:
from urlparse import urlsplit
except ImportError:
from urllib.parse import urlsplit
from flask import request, current_app
from flask_wtf import Form as BaseForm
@@ -22,7 +23,7 @@ from flask_login import current_user
from werkzeug.local import LocalProxy
from .confirmable import requires_confirmation
from .utils import verify_and_update_password, get_message, encrypt_password, config_value
from .utils import verify_and_update_password, get_message, config_value
# Convenient reference
_datastore = LocalProxy(lambda: current_app.extensions['security'].datastore)
@@ -137,8 +138,8 @@ class NextFormMixin():
def validate_next(self, field):
if field.data:
url_next = urlparse.urlsplit(field.data)
url_base = urlparse.urlsplit(request.host_url)
url_next = urlsplit(field.data)
url_base = urlsplit(request.host_url)
if url_next.netloc and url_next.netloc != url_base.netloc:
field.data = ''
raise ValidationError(get_message('INVALID_REDIRECT')[0])
+21 -11
View File
@@ -14,6 +14,7 @@ import blinker
import functools
import hashlib
import hmac
import sys
from contextlib import contextmanager
from datetime import datetime, timedelta
@@ -37,6 +38,15 @@ _datastore = LocalProxy(lambda: _security.datastore)
_pwd_context = LocalProxy(lambda: _security.pwd_context)
PY3 = sys.version_info[0] == 3
if PY3:
string_types = str,
text_type = str
else:
string_types = basestring,
text_type = unicode
def login_user(user, remember=None):
"""Performs the login routine.
@@ -85,16 +95,13 @@ def get_hmac(password):
:param password: The password to sign
"""
if _security.password_hash == 'plaintext':
return password
if _security.password_salt is None:
raise RuntimeError(
'The configuration value `SECURITY_PASSWORD_SALT` must '
'not be None when the value of `SECURITY_PASSWORD_HASH` is '
'set to "%s"' % _security.password_hash)
h = hmac.new(_security.password_salt, password.encode('utf-8'), hashlib.sha512)
h = hmac.new(_security.password_salt.encode('utf-8'), password.encode('utf-8'), hashlib.sha512)
return base64.b64encode(h.digest())
@@ -104,7 +111,7 @@ def verify_password(password, password_hash):
:param password: A plaintext password to verify
:param password_hash: The expected hash value of the password (usually form your database)
"""
return _pwd_context.verify(get_hmac(password), password_hash)
return _pwd_context.verify(encrypt_password(password), password_hash)
def verify_and_update_password(password, user):
@@ -114,7 +121,7 @@ def verify_and_update_password(password, user):
:param password: A plaintext password to verify
:param user: The user to verify against
"""
verified, new_password = _pwd_context.verify_and_update(get_hmac(password), user.password)
verified, new_password = _pwd_context.verify_and_update(encrypt_password(password), user.password)
if verified and new_password:
user.password = new_password
_datastore.put(user)
@@ -126,11 +133,14 @@ def encrypt_password(password):
:param password: The plaintext passwrod to encrypt
"""
return _pwd_context.encrypt(get_hmac(password))
if _security.password_hash == 'plaintext':
return password
signed = get_hmac(password)
return _pwd_context.encrypt(signed.decode('ascii'))
def md5(data):
return hashlib.md5(data).hexdigest()
return hashlib.md5(data.encode('ascii')).hexdigest()
def do_flash(message, category=None):
@@ -408,19 +418,19 @@ class CaptureSignals(object):
self._records[signal].append((args, kwargs))
def __enter__(self):
for signal, receiver in self._receivers.iteritems():
for signal, receiver in self._receivers.items():
signal.connect(receiver)
return self
def __exit__(self, type, value, traceback):
for signal, receiver in self._receivers.iteritems():
for signal, receiver in self._receivers.items():
signal.disconnect(receiver)
def signals_sent(self):
"""Return a set of the signals sent.
:rtype: list of blinker `NamedSignals`.
"""
return set([signal for signal, _ in self._records.iteritems() if self._records[signal]])
return set([signal for signal, _ in self._records.items() if self._records[signal]])
def capture_signals():
+1 -1
View File
@@ -61,7 +61,7 @@ class SecurityTest(TestCase):
return self._get(endpoint or '/logout', follow_redirects=True)
def assertIsHomePage(self, data):
self.assertIn('Home Page', data)
self.assertIn(b'Home Page', data)
def assertIn(self, member, container, msg=None):
if hasattr(TestCase, 'assertIn'):
+47 -47
View File
@@ -27,7 +27,7 @@ class ConfiguredPasswordHashSecurityTests(SecurityTest):
def test_authenticate(self):
r = self.authenticate(endpoint="/login")
self.assertIn('Home Page', r.data)
self.assertIn(b'Home Page', r.data)
class ConfiguredSecurityTests(SecurityTest):
@@ -49,16 +49,16 @@ class ConfiguredSecurityTests(SecurityTest):
def test_authenticate(self):
r = self.authenticate(endpoint="/custom_login")
self.assertIn('Post Login', r.data)
self.assertIn(b'Post Login', r.data)
def test_logout(self):
self.authenticate(endpoint="/custom_login")
r = self.logout(endpoint="/custom_logout")
self.assertIn('Post Logout', r.data)
self.assertIn(b'Post Logout', r.data)
def test_register_view(self):
r = self._get('/register')
self.assertIn('<h1>Register</h1>', r.data)
self.assertIn(b'<h1>Register</h1>', r.data)
def test_register(self):
data = dict(email='dude@lp.com',
@@ -66,7 +66,7 @@ class ConfiguredSecurityTests(SecurityTest):
password_confirm='password')
r = self._post('/register', data=data, follow_redirects=True)
self.assertIn('Post Register', r.data)
self.assertIn(b'Post Register', r.data)
def test_register_with_next_querystring_argument(self):
data = dict(email='dude@lp.com',
@@ -74,7 +74,7 @@ class ConfiguredSecurityTests(SecurityTest):
password_confirm='password')
r = self._post('/register?next=/page1', data=data, follow_redirects=True)
self.assertIn('Page 1', r.data)
self.assertIn(b'Page 1', r.data)
def test_register_json(self):
data = '{ "email": "dude@lp.com", "password": "password"}'
@@ -98,9 +98,9 @@ class ConfiguredSecurityTests(SecurityTest):
def test_default_http_auth_realm(self):
r = self._get('/http', headers={
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:bogus")
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:bogus")
})
self.assertIn('<h1>Unauthorized</h1>', r.data)
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
self.assertIn('WWW-Authenticate', r.headers)
self.assertEquals('Basic realm="Custom Realm"',
r.headers['WWW-Authenticate'])
@@ -125,7 +125,7 @@ class DefaultTemplatePathTests(SecurityTest):
def test_login_user_template(self):
r = self._get('/login')
self.assertIn('CUSTOM LOGIN USER', r.data)
self.assertIn(b'CUSTOM LOGIN USER', r.data)
class RegisterableTemplatePathTests(SecurityTest):
@@ -137,7 +137,7 @@ class RegisterableTemplatePathTests(SecurityTest):
def test_register_user_template(self):
r = self._get('/register')
self.assertIn('CUSTOM REGISTER USER', r.data)
self.assertIn(b'CUSTOM REGISTER USER', r.data)
class RecoverableTemplatePathTests(SecurityTest):
@@ -150,7 +150,7 @@ class RecoverableTemplatePathTests(SecurityTest):
def test_forgot_password_template(self):
r = self._get('/reset')
self.assertIn('CUSTOM FORGOT PASSWORD', r.data)
self.assertIn(b'CUSTOM FORGOT PASSWORD', r.data)
def test_reset_password_template(self):
with capture_reset_password_requests() as requests:
@@ -161,7 +161,7 @@ class RecoverableTemplatePathTests(SecurityTest):
r = self._get('/reset/' + t)
self.assertIn('CUSTOM RESET PASSWORD', r.data)
self.assertIn(b'CUSTOM RESET PASSWORD', r.data)
class ConfirmableTemplatePathTests(SecurityTest):
@@ -173,7 +173,7 @@ class ConfirmableTemplatePathTests(SecurityTest):
def test_send_confirmation_template(self):
r = self._get('/confirm')
self.assertIn('CUSTOM SEND CONFIRMATION', r.data)
self.assertIn(b'CUSTOM SEND CONFIRMATION', r.data)
class PasswordlessTemplatePathTests(SecurityTest):
@@ -185,7 +185,7 @@ class PasswordlessTemplatePathTests(SecurityTest):
def test_send_login_template(self):
r = self._get('/login')
self.assertIn('CUSTOM SEND LOGIN', r.data)
self.assertIn(b'CUSTOM SEND LOGIN', r.data)
class RegisterableTests(SecurityTest):
@@ -200,7 +200,7 @@ class RegisterableTests(SecurityTest):
password_confirm='password')
self._post('/register', data=data, follow_redirects=True)
r = self.authenticate('dude@lp.com')
self.assertIn('Hello dude@lp.com', r.data)
self.assertIn(b'Hello dude@lp.com', r.data)
class ConfirmableTests(SecurityTest):
@@ -215,7 +215,7 @@ class ConfirmableTests(SecurityTest):
e = 'dude@lp.com'
self.register(e)
r = self.authenticate(email=e)
self.assertIn(self.get_message('CONFIRMATION_REQUIRED'), r.data)
self.assertIn(self.get_message('CONFIRMATION_REQUIRED').encode('utf-8'), r.data)
def test_send_confirmation_of_already_confirmed_account(self):
e = 'dude@lp.com'
@@ -227,7 +227,7 @@ class ConfirmableTests(SecurityTest):
self.client.get('/confirm/' + token, follow_redirects=True)
self.logout()
r = self._post('/confirm', data=dict(email=e))
self.assertIn(self.get_message('ALREADY_CONFIRMED'), r.data)
self.assertIn(self.get_message('ALREADY_CONFIRMED').encode('utf-8'), r.data)
def test_register_sends_confirmation_email(self):
e = 'dude@lp.com'
@@ -269,7 +269,7 @@ class ConfirmableTests(SecurityTest):
self.register(e)
r = self._post('/confirm', data={'email': e})
msg = self.get_message('CONFIRMATION_REQUEST', email=e)
msg = self.get_message('CONFIRMATION_REQUEST', email=e).encode('utf-8')
self.assertIn(msg, r.data)
def test_user_deleted_before_confirmation(self):
@@ -350,7 +350,7 @@ class LoginWithoutImmediateConfirmTests(SecurityTest):
r = self.client.get('/confirm/' + token2, follow_redirects=True)
msg = self.app.config['SECURITY_MSG_EMAIL_CONFIRMED'][0]
self.assertIn(msg, r.data)
self.assertIn('Hello %s' % e2, r.data)
self.assertIn(b'Hello %s' % e2, r.data)
def test_login_unconfirmed_user_when_login_without_confirmation_is_true(self):
e = 'dude@lp.com'
@@ -377,7 +377,7 @@ class RecoverableTests(SecurityTest):
follow_redirects=True)
t = requests[0]['token']
r = self._get('/reset/' + t)
self.assertIn('<h1>Reset password</h1>', r.data)
self.assertIn(b'<h1>Reset password</h1>', r.data)
def test_forgot_post_sends_email(self):
with capture_reset_password_requests():
@@ -408,7 +408,7 @@ class RecoverableTests(SecurityTest):
r = self.logout()
r = self.authenticate('joe@lp.com', 'newpassword')
self.assertIn('Hello joe@lp.com', r.data)
self.assertIn(b'Hello joe@lp.com', r.data)
def test_reset_password_with_invalid_token(self):
r = self._post('/reset/bogus', data={
@@ -416,7 +416,7 @@ class RecoverableTests(SecurityTest):
'password_confirm': 'newpassword'
}, follow_redirects=True)
self.assertIn(self.get_message('INVALID_RESET_PASSWORD_TOKEN'), r.data)
self.assertIn(self.get_message('INVALID_RESET_PASSWORD_TOKEN').encode('utf-8'), r.data)
class ExpiredResetPasswordTest(SecurityTest):
@@ -439,7 +439,7 @@ class ExpiredResetPasswordTest(SecurityTest):
'password_confirm': 'newpassword'
}, follow_redirects=True)
self.assertIn('You did not reset your password within', r.data)
self.assertIn(b'You did not reset your password within', r.data)
class ChangePasswordTest(SecurityTest):
@@ -452,7 +452,7 @@ class ChangePasswordTest(SecurityTest):
def test_change_password(self):
self.authenticate()
r = self.client.get('/change', follow_redirects=True)
self.assertIn('Change password', r.data)
self.assertIn(b'Change password', r.data)
def test_change_password_invalid(self):
self.authenticate()
@@ -461,8 +461,8 @@ class ChangePasswordTest(SecurityTest):
'new_password': 'newpassword',
'new_password_confirm': 'newpassword'
}, follow_redirects=True)
self.assertNotIn('You successfully changed your password', r.data)
self.assertIn('Invalid password', r.data)
self.assertNotIn(b'You successfully changed your password', r.data)
self.assertIn(b'Invalid password', r.data)
def test_change_password_mismatch(self):
self.authenticate()
@@ -471,8 +471,8 @@ class ChangePasswordTest(SecurityTest):
'new_password': 'newpassword',
'new_password_confirm': 'notnewpassword'
}, follow_redirects=True)
self.assertNotIn('You successfully changed your password', r.data)
self.assertIn('Passwords do not match', r.data)
self.assertNotIn(b'You successfully changed your password', r.data)
self.assertIn(b'Passwords do not match', r.data)
def test_change_password_bad_password(self):
self.authenticate()
@@ -481,8 +481,8 @@ class ChangePasswordTest(SecurityTest):
'new_password': 'a',
'new_password_confirm': 'a'
}, follow_redirects=True)
self.assertNotIn('You successfully changed your password', r.data)
self.assertIn('Password must be at least 6 characters', r.data)
self.assertNotIn(b'You successfully changed your password', r.data)
self.assertIn(b'Password must be at least 6 characters', r.data)
def test_change_password_same_as_previous(self):
self.authenticate()
@@ -491,8 +491,8 @@ class ChangePasswordTest(SecurityTest):
'new_password': 'password',
'new_password_confirm': 'password'
}, follow_redirects=True)
self.assertNotIn('You successfully changed your password', r.data)
self.assertIn('Your new password must be different than your previous password.', r.data)
self.assertNotIn(b'You successfully changed your password', r.data)
self.assertIn(b'Your new password must be different than your previous password.', r.data)
def test_change_password_success(self):
data = {
@@ -505,8 +505,8 @@ class ChangePasswordTest(SecurityTest):
with self.app.extensions['mail'].record_messages() as outbox:
r = self._post('/change', data=data, follow_redirects=True)
self.assertIn('You successfully changed your password', r.data)
self.assertIn('Home Page', r.data)
self.assertIn(b'You successfully changed your password', r.data)
self.assertIn(b'Home Page', r.data)
self.assertEqual(len(outbox), 1)
self.assertIn("Your password has been changed", outbox[0].html)
@@ -552,7 +552,7 @@ class ChangePasswordPostViewTest(SecurityTest):
self.authenticate()
r = self._post('/change', data=data, follow_redirects=True)
self.assertIn('Profile Page', r.data)
self.assertIn(b'Profile Page', r.data)
class ChangePasswordDisabledTest(SecurityTest):
@@ -605,12 +605,12 @@ class PasswordlessTests(SecurityTest):
data = '{"email": "matt@lp.com", "password": "password"}'
r = self._post('/login', data=data, content_type='application/json')
self.assertEquals(r.status_code, 200)
self.assertNotIn('error', r.data)
self.assertNotIn(b'error', r.data)
def test_request_login_token_with_json_and_invalid_email(self):
data = '{"email": "nobody@lp.com", "password": "password"}'
r = self._post('/login', data=data, content_type='application/json')
self.assertIn('errors', r.data)
self.assertIn(b'errors', r.data)
def test_request_login_token_sends_email_and_can_login(self):
e = 'matt@lp.com'
@@ -624,8 +624,8 @@ class PasswordlessTests(SecurityTest):
self.assertEqual(len(outbox), 1)
self.assertEquals(1, len(requests))
self.assertIn('user', requests[0])
self.assertIn('login_token', requests[0])
self.assertIn(b'user', requests[0])
self.assertIn(b'login_token', requests[0])
user = requests[0]['user']
token = requests[0]['login_token']
@@ -635,11 +635,11 @@ class PasswordlessTests(SecurityTest):
self.assertIn(msg, r.data)
r = self.client.get('/login/' + token, follow_redirects=True)
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL')
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL').encode('utf-8')
self.assertIn(msg, r.data)
r = self.client.get('/profile')
self.assertIn('Profile Page', r.data)
self.assertIn(b'Profile Page', r.data)
def test_invalid_login_token(self):
msg = self.app.config['SECURITY_MSG_INVALID_LOGIN_TOKEN'][0]
@@ -653,16 +653,16 @@ class PasswordlessTests(SecurityTest):
token = requests[0]['login_token']
r = self.client.get('/login/' + token, follow_redirects=True)
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL')
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL').encode('utf-8')
self.assertIn(msg, r.data)
r = self.client.get('/login/' + token, follow_redirects=True)
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL')
msg = self.get_message('PASSWORDLESS_LOGIN_SUCCESSFUL').encode('utf-8')
self.assertNotIn(msg, r.data)
def test_send_login_with_invalid_email(self):
r = self._post('/login', data=dict(email='bogus@bogus.com'))
self.assertIn('Specified user does not exist', r.data)
self.assertIn(b'Specified user does not exist', r.data)
class ExpiredLoginTokenTests(SecurityTest):
@@ -730,9 +730,9 @@ class NoBlueprintTests(SecurityTest):
self.assertEqual(404, r.status_code)
def test_http_auth_without_blueprint(self):
auth = 'Basic ' + base64.b64encode("matt@lp.com:password")
auth = 'Basic %s' % base64.b64encode(b"matt@lp.com:password")
r = self._get('/http', headers={'Authorization': auth})
self.assertIn('HTTP Authentication', r.data)
self.assertIn(b'HTTP Authentication', r.data)
class ExtendFormsTest(SecurityTest):
@@ -846,4 +846,4 @@ class AdditionalUserIdentityAttributes(SecurityTest):
def test_authenticate(self):
r = self.authenticate(email='matt')
self.assertIn('Hello matt@lp.com', r.data)
self.assertIn(b'Hello matt@lp.com', r.data)
+47 -41
View File
@@ -4,7 +4,11 @@ from __future__ import with_statement
import base64
import simplejson as json
from cookielib import Cookie
try:
from cookielib import Cookie
except ImportError:
from http.cookiejar import Cookie
from werkzeug.utils import parse_cookie
@@ -27,35 +31,35 @@ class DefaultSecurityTests(SecurityTest):
def test_login_view(self):
r = self._get('/login')
self.assertIn('<h1>Login</h1>', r.data)
self.assertIn(b'<h1>Login</h1>', r.data)
def test_authenticate(self):
r = self.authenticate()
self.assertIn('Hello matt@lp.com', r.data)
self.assertIn(b'Hello matt@lp.com', r.data)
def test_authenticate_case_insensitive_email(self):
r = self.authenticate(email='MATT@lp.com')
self.assertIn('Hello matt@lp.com', r.data)
self.assertIn(b'Hello matt@lp.com', r.data)
def test_unprovided_username(self):
r = self.authenticate("")
self.assertIn(self.get_message('EMAIL_NOT_PROVIDED'), r.data)
self.assertIn(self.get_message('EMAIL_NOT_PROVIDED').encode('utf-8'), r.data)
def test_unprovided_password(self):
r = self.authenticate(password="")
self.assertIn(self.get_message('PASSWORD_NOT_PROVIDED'), r.data)
self.assertIn(self.get_message('PASSWORD_NOT_PROVIDED').encode('utf-8'), r.data)
def test_invalid_user(self):
r = self.authenticate(email="bogus@bogus.com")
self.assertIn(self.get_message('USER_DOES_NOT_EXIST'), r.data)
self.assertIn(self.get_message('USER_DOES_NOT_EXIST').encode('utf-8'), r.data)
def test_bad_password(self):
r = self.authenticate(password="bogus")
self.assertIn(self.get_message('INVALID_PASSWORD'), r.data)
self.assertIn(self.get_message('INVALID_PASSWORD').encode('utf-8'), r.data)
def test_inactive_user(self):
r = self.authenticate("tiya@lp.com", "password")
self.assertIn(self.get_message('DISABLED_ACCOUNT'), r.data)
self.assertIn(self.get_message('DISABLED_ACCOUNT').encode('utf-8'), r.data)
def test_logout(self):
self.authenticate()
@@ -65,17 +69,17 @@ class DefaultSecurityTests(SecurityTest):
def test_unauthorized_access(self):
self.logout()
r = self._get('/profile', follow_redirects=True)
self.assertIn('<li class="info">Please log in to access this page.</li>', r.data)
self.assertIn(b'<li class="info">Please log in to access this page.</li>', r.data)
def test_authorized_access(self):
self.authenticate()
r = self._get("/profile")
self.assertIn('profile', r.data)
self.assertIn(b'profile', r.data)
def test_valid_admin_role(self):
self.authenticate()
r = self._get("/admin")
self.assertIn('Admin Page', r.data)
self.assertIn(b'Admin Page', r.data)
def test_invalid_admin_role(self):
self.authenticate("joe@lp.com")
@@ -86,7 +90,7 @@ class DefaultSecurityTests(SecurityTest):
for user in ("matt@lp.com", "joe@lp.com"):
self.authenticate(user)
r = self._get("/admin_or_editor")
self.assertIn('Admin or Editor Page', r.data)
self.assertIn(b'Admin or Editor Page', r.data)
self.logout()
self.authenticate("jill@lp.com")
@@ -95,7 +99,7 @@ class DefaultSecurityTests(SecurityTest):
def test_unauthenticated_role_required(self):
r = self._get('/admin', follow_redirects=True)
self.assertIn(self.get_message('UNAUTHORIZED'), r.data)
self.assertIn(self.get_message('UNAUTHORIZED').encode('utf-8'), r.data)
def test_multiple_role_required(self):
for user in ("matt@lp.com", "joe@lp.com"):
@@ -106,7 +110,7 @@ class DefaultSecurityTests(SecurityTest):
self.authenticate('dave@lp.com')
r = self._get("/admin_and_editor", follow_redirects=True)
self.assertIn('Admin and Editor Page', r.data)
self.assertIn(b'Admin and Editor Page', r.data)
def test_ok_json_auth(self):
r = self.json_authenticate()
@@ -116,14 +120,14 @@ class DefaultSecurityTests(SecurityTest):
def test_invalid_json_auth(self):
r = self.json_authenticate(password='junk')
self.assertIn('"code": 400', r.data)
self.assertIn(b'"code": 400', r.data)
def test_token_auth_via_querystring_valid_token(self):
r = self.json_authenticate()
data = json.loads(r.data)
token = data['response']['user']['authentication_token']
r = self._get('/token?auth_token=' + token)
self.assertIn('Token Authentication', r.data)
self.assertIn(b'Token Authentication', r.data)
def test_token_auth_via_header_valid_token(self):
r = self.json_authenticate()
@@ -131,7 +135,7 @@ class DefaultSecurityTests(SecurityTest):
token = data['response']['user']['authentication_token']
headers = {"Authentication-Token": token}
r = self._get('/token', headers=headers)
self.assertIn('Token Authentication', r.data)
self.assertIn(b'Token Authentication', r.data)
def test_token_auth_via_querystring_invalid_token(self):
r = self._get('/token?auth_token=X')
@@ -143,61 +147,63 @@ class DefaultSecurityTests(SecurityTest):
def test_http_auth(self):
r = self._get('/http', headers={
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:password")
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:password")
})
self.assertIn('HTTP Authentication', r.data)
self.assertIn(b'HTTP Authentication', r.data)
def test_http_auth_no_authorization(self):
r = self._get('/http', headers={})
self.assertIn('<h1>Unauthorized</h1>', r.data)
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
self.assertIn('WWW-Authenticate', r.headers)
self.assertEquals('Basic realm="Login Required"',
r.headers['WWW-Authenticate'])
def test_invalid_http_auth_invalid_username(self):
r = self._get('/http', headers={
'Authorization': 'Basic ' + base64.b64encode("bogus:bogus")
'Authorization': 'Basic %s' % base64.b64encode(b"bogus:bogus")
})
self.assertIn('<h1>Unauthorized</h1>', r.data)
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
self.assertIn('WWW-Authenticate', r.headers)
self.assertEquals('Basic realm="Login Required"',
r.headers['WWW-Authenticate'])
def test_invalid_http_auth_bad_password(self):
r = self._get('/http', headers={
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:bogus")
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:bogus")
})
self.assertIn('<h1>Unauthorized</h1>', r.data)
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
self.assertIn('WWW-Authenticate', r.headers)
self.assertEquals('Basic realm="Login Required"',
r.headers['WWW-Authenticate'])
def test_custom_http_auth_realm(self):
r = self._get('/http_custom_realm', headers={
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:bogus")
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:bogus")
})
self.assertIn('<h1>Unauthorized</h1>', r.data)
self.assertIn(b'<h1>Unauthorized</h1>', r.data)
self.assertIn('WWW-Authenticate', r.headers)
self.assertEquals('Basic realm="My Realm"',
r.headers['WWW-Authenticate'])
def test_multi_auth_basic(self):
r = self._get('/multi_auth', headers={
'Authorization': 'Basic ' + base64.b64encode("joe@lp.com:password")
})
self.assertIn('Basic', r.data)
h = {
'Authorization': 'Basic %s' % base64.b64encode(b"joe@lp.com:password")
}
print(h)
r = self._get('/multi_auth', headers=h)
self.assertIn(b'Basic', r.data)
def test_multi_auth_token(self):
r = self.json_authenticate()
data = json.loads(r.data)
token = data['response']['user']['authentication_token']
r = self._get('/multi_auth?auth_token=' + token)
self.assertIn('Token', r.data)
self.assertIn(b'Token', r.data)
def test_multi_auth_session(self):
self.authenticate()
r = self._get('/multi_auth')
self.assertIn('Session', r.data)
self.assertIn(b'Session', r.data)
def test_user_deleted_during_session_reverts_to_anonymous_user(self):
self.authenticate()
@@ -208,13 +214,13 @@ class DefaultSecurityTests(SecurityTest):
self.app.security.datastore.commit()
r = self._get('/')
self.assertNotIn('Hello matt@lp.com', r.data)
self.assertNotIn(b'Hello matt@lp.com', r.data)
def test_remember_token(self):
r = self.authenticate(follow_redirects=False)
self.client.cookie_jar.clear_session_cookies()
r = self._get('/profile')
self.assertIn('profile', r.data)
self.assertIn(b'profile', r.data)
def test_token_loader_does_not_fail_with_invalid_token(self):
c = Cookie(version=0, name='remember_token', value='None', port=None,
@@ -226,7 +232,7 @@ class DefaultSecurityTests(SecurityTest):
self.client.cookie_jar.set_cookie(c)
r = self._get('/')
self.assertNotIn('BadSignature', r.data)
self.assertNotIn(b'BadSignature', r.data)
class MongoEngineSecurityTests(DefaultSecurityTests):
@@ -247,23 +253,23 @@ class DefaultDatastoreTests(SecurityTest):
def test_add_role_to_user(self):
r = self._get('/coverage/add_role_to_user')
self.assertIn('success', r.data)
self.assertIn(b'success', r.data)
def test_remove_role_from_user(self):
r = self._get('/coverage/remove_role_from_user')
self.assertIn('success', r.data)
self.assertIn(b'success', r.data)
def test_activate_user(self):
r = self._get('/coverage/activate_user')
self.assertIn('success', r.data)
self.assertIn(b'success', r.data)
def test_deactivate_user(self):
r = self._get('/coverage/deactivate_user')
self.assertIn('success', r.data)
self.assertIn(b'success', r.data)
def test_invalid_role(self):
r = self._get('/coverage/invalid_role')
self.assertIn('success', r.data)
self.assertIn(b'success', r.data)
class MongoEngineDatastoreTests(DefaultDatastoreTests):
+7 -3
View File
@@ -137,9 +137,13 @@ def create_users(count=None):
for u in users[:count]:
pw = encrypt_password(u[2])
ds.create_user(email=u[0], username=u[1], password=pw,
roles=u[3], active=u[4])
ds.commit()
roles = [ds.find_or_create_role(rn) for rn in u[3]]
ds.commit()
user = ds.create_user(email=u[0], username=u[1], password=pw, active=u[4])
ds.commit()
for role in roles:
ds.add_role_to_user(user, role)
ds.commit()
def populate_data(user_count=None):
+2 -2
View File
@@ -1,5 +1,5 @@
[tox]
envlist = py26, py27
envlist = py26, py27, py33
[testenv]
deps =
@@ -10,4 +10,4 @@ deps =
Flask-Peewee
bcrypt
commands = nosetests []
commands = nosetests -xs []