mirror of
https://github.com/wassname/flask-security.git
synced 2026-06-29 16:30:04 +08:00
Refactor views a bit to keep things cleaner and fix up tests
This commit is contained in:
+49
-66
@@ -16,12 +16,13 @@ from importlib import import_module
|
||||
from flask import current_app, Blueprint, flash, redirect, request, \
|
||||
session, url_for
|
||||
from flask.ext.login import AnonymousUser as AnonymousUserBase, \
|
||||
UserMixin as BaseUserMixin, LoginManager, login_required, login_user, \
|
||||
logout_user, current_user, login_url
|
||||
from flask.ext.principal import Identity, Principal, RoleNeed, UserNeed, \
|
||||
Permission, AnonymousIdentity, identity_changed, identity_loaded
|
||||
UserMixin as BaseUserMixin, LoginManager, login_required, \
|
||||
current_user, login_url
|
||||
from flask.ext.principal import Principal, RoleNeed, UserNeed, \
|
||||
Permission, identity_loaded
|
||||
from flask.ext.wtf import Form, TextField, PasswordField, SubmitField, \
|
||||
HiddenField, Required, BooleanField
|
||||
from flask.ext.security import views
|
||||
from passlib.context import CryptContext
|
||||
from werkzeug.datastructures import ImmutableList
|
||||
|
||||
@@ -36,9 +37,11 @@ _default_config = {
|
||||
'SECURITY_LOGIN_FORM': 'flask.ext.security::LoginForm',
|
||||
'SECURITY_AUTH_URL': '/auth',
|
||||
'SECURITY_LOGOUT_URL': '/logout',
|
||||
'SECURITY_RESET_URL': '/reset',
|
||||
'SECURITY_LOGIN_VIEW': '/login',
|
||||
'SECURITY_POST_LOGIN_VIEW': '/',
|
||||
'SECURITY_POST_LOGOUT_VIEW': '/',
|
||||
'SECURITY_RESET_PASSWORD_WITHIN': 10
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +90,7 @@ class RoleCreationError(Exception):
|
||||
"""
|
||||
|
||||
|
||||
def roles_required(*args):
|
||||
def roles_required(*roles):
|
||||
"""View decorator which specifies that a user must have all the specified
|
||||
roles. Example::
|
||||
|
||||
@@ -101,7 +104,6 @@ def roles_required(*args):
|
||||
|
||||
:param args: The required roles.
|
||||
"""
|
||||
roles = args
|
||||
perm = Permission(*[RoleNeed(role) for role in roles])
|
||||
|
||||
def wrapper(fn):
|
||||
@@ -121,7 +123,7 @@ def roles_required(*args):
|
||||
return wrapper
|
||||
|
||||
|
||||
def roles_accepted(*args):
|
||||
def roles_accepted(*roles):
|
||||
"""View decorator which specifies that a user must have at least one of the
|
||||
specified roles. Example::
|
||||
|
||||
@@ -135,7 +137,6 @@ def roles_accepted(*args):
|
||||
|
||||
:param args: The possible roles.
|
||||
"""
|
||||
roles = args
|
||||
perms = [Permission(RoleNeed(role)) for role in roles]
|
||||
|
||||
def wrapper(fn):
|
||||
@@ -149,8 +150,9 @@ def roles_accepted(*args):
|
||||
if perm.can():
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
current_app.logger.debug('Identity does not provide at least one '
|
||||
'role: %s' % [r for r in roles])
|
||||
current_app.logger.debug('Current user does not provide a required '
|
||||
'role. Accepted: %s Provided: %s' % ([r for r in roles],
|
||||
[r.name for r in current_user.roles]))
|
||||
|
||||
_do_flash('You do not have permission to view this resource',
|
||||
'error')
|
||||
@@ -203,6 +205,24 @@ class AnonymousUser(AnonymousUserBase):
|
||||
return False
|
||||
|
||||
|
||||
def load_user(user_id):
|
||||
try:
|
||||
return current_app.security.datastore.with_id(user_id)
|
||||
except Exception, e:
|
||||
current_app.logger.error('Error getting user: %s' % e)
|
||||
return None
|
||||
|
||||
|
||||
def on_identity_loaded(sender, identity):
|
||||
if hasattr(current_user, 'id'):
|
||||
identity.provides.add(UserNeed(current_user.id))
|
||||
|
||||
for role in current_user.roles:
|
||||
identity.provides.add(RoleNeed(role.name))
|
||||
|
||||
identity.user = current_user
|
||||
|
||||
|
||||
class Security(object):
|
||||
"""The :class:`Security` class initializes the Flask-Security extension.
|
||||
|
||||
@@ -212,7 +232,7 @@ class Security(object):
|
||||
def __init__(self, app=None, datastore=None):
|
||||
self.init_app(app, datastore)
|
||||
|
||||
def init_app(self, app, datastore):
|
||||
def init_app(self, app, datastore, recoverable=False):
|
||||
"""Initializes the Flask-Security extension for the specified
|
||||
application and datastore implentation.
|
||||
|
||||
@@ -231,6 +251,7 @@ class Security(object):
|
||||
login_manager.setup_app(app)
|
||||
|
||||
Provider = _get_class_from_string(app, 'AUTH_PROVIDER')
|
||||
Form = _get_class_from_string(app, 'LOGIN_FORM')
|
||||
pw_hash = _config_value(app, 'PASSWORD_HASH')
|
||||
|
||||
self.login_manager = login_manager
|
||||
@@ -238,67 +259,31 @@ class Security(object):
|
||||
self.auth_provider = Provider(Form)
|
||||
self.principal = Principal(app)
|
||||
self.datastore = datastore
|
||||
self.form_class = _get_class_from_string(app, 'LOGIN_FORM')
|
||||
self.form_class = Form
|
||||
self.auth_url = _config_value(app, 'AUTH_URL')
|
||||
self.logout_url = _config_value(app, 'LOGOUT_URL')
|
||||
self.reset_url = _config_value(app, 'RESET_URL')
|
||||
self.post_login_view = _config_value(app, 'POST_LOGIN_VIEW')
|
||||
self.post_logout_view = _config_value(app, 'POST_LOGOUT_VIEW')
|
||||
self.reset_password_within = _config_value(app, 'RESET_PASSWORD_WITHIN')
|
||||
|
||||
@identity_loaded.connect_via(app)
|
||||
def on_identity_loaded(sender, identity):
|
||||
if hasattr(current_user, 'id'):
|
||||
identity.provides.add(UserNeed(current_user.id))
|
||||
identity_loaded.connect_via(app)(on_identity_loaded)
|
||||
|
||||
for role in current_user.roles:
|
||||
identity.provides.add(RoleNeed(role.name))
|
||||
|
||||
identity.user = current_user
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
try:
|
||||
return app.security.datastore.with_id(user_id)
|
||||
except Exception, e:
|
||||
app.logger.error('Error getting user: %s' % e)
|
||||
return None
|
||||
login_manager.user_loader(load_user)
|
||||
|
||||
bp = Blueprint('auth', __name__)
|
||||
|
||||
@bp.route(self.auth_url, methods=['POST'], endpoint='authenticate')
|
||||
def authenticate():
|
||||
try:
|
||||
form = current_app.security.form_class()
|
||||
user = current_app.security.auth_provider.authenticate(form)
|
||||
bp.route(self.auth_url,
|
||||
methods=['POST'],
|
||||
endpoint='authenticate')(views.authenticate)
|
||||
|
||||
if login_user(user, remember=form.remember.data):
|
||||
redirect_url = _get_post_login_redirect()
|
||||
identity_changed.send(app, identity=Identity(user.id))
|
||||
app.logger.debug('User %s logged in. Redirecting to: '
|
||||
'%s' % (user, redirect_url))
|
||||
return redirect(redirect_url)
|
||||
bp.route(self.logout_url,
|
||||
endpoint='logout')(login_required(views.logout))
|
||||
|
||||
raise BadCredentialsError('Inactive user')
|
||||
|
||||
except BadCredentialsError, e:
|
||||
message = '%s' % e
|
||||
_do_flash(message, 'error')
|
||||
redirect_url = request.referrer or login_manager.login_view
|
||||
app.logger.error('Unsuccessful authentication attempt: %s. '
|
||||
'Redirect to: %s' % (message, redirect_url))
|
||||
return redirect(redirect_url)
|
||||
|
||||
@bp.route(self.logout_url, endpoint='logout')
|
||||
@login_required
|
||||
def logout():
|
||||
for value in ('identity.name', 'identity.auth_type'):
|
||||
session.pop(value, None)
|
||||
|
||||
identity_changed.send(app, identity=AnonymousIdentity())
|
||||
logout_user()
|
||||
|
||||
redirect_url = _find_redirect('SECURITY_POST_LOGOUT_VIEW')
|
||||
app.logger.debug('User logged out. Redirect to: %s' % redirect_url)
|
||||
return redirect(redirect_url)
|
||||
if recoverable:
|
||||
bp.route(self.reset_url,
|
||||
methods=['POST'],
|
||||
endpoint='reset')(views.reset)
|
||||
|
||||
app.register_blueprint(bp, url_prefix=_config_value(app, 'URL_PREFIX'))
|
||||
app.security = self
|
||||
@@ -361,11 +346,9 @@ class AuthenticationProvider(object):
|
||||
try:
|
||||
user = current_app.security.datastore.find_user(user_identifier)
|
||||
except AttributeError, e:
|
||||
self.auth_error("Could not find user service: %s" % e)
|
||||
self.auth_error("Could not find user datastore: %s" % e)
|
||||
except UserNotFoundError, e:
|
||||
raise BadCredentialsError("Specified user does not exist")
|
||||
except AttributeError, e:
|
||||
self.auth_error('Invalid user service: %s' % e)
|
||||
except Exception, e:
|
||||
self.auth_error('Unexpected authentication error: %s' % e)
|
||||
|
||||
@@ -424,5 +407,5 @@ def _find_redirect(key):
|
||||
return result
|
||||
|
||||
|
||||
def _config_value(app, key):
|
||||
return app.config['SECURITY_' + key.upper()]
|
||||
def _config_value(app, key, default=None):
|
||||
return app.config.get('SECURITY_' + key.upper(), default)
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from flask import current_app, redirect, request, session
|
||||
from flask.ext.login import login_user, logout_user
|
||||
from flask.ext.principal import Identity, AnonymousIdentity, identity_changed
|
||||
from flask.ext import security
|
||||
|
||||
|
||||
def authenticate():
|
||||
try:
|
||||
form = current_app.security.form_class()
|
||||
user = current_app.security.auth_provider.authenticate(form)
|
||||
|
||||
if login_user(user, remember=form.remember.data):
|
||||
redirect_url = security._get_post_login_redirect()
|
||||
identity_changed.send(current_app._get_current_object(),
|
||||
identity=Identity(user.id))
|
||||
current_app.logger.debug('User %s logged in. Redirecting to: '
|
||||
'%s' % (user, redirect_url))
|
||||
return redirect(redirect_url)
|
||||
|
||||
raise security.BadCredentialsError('Inactive user')
|
||||
|
||||
except security.BadCredentialsError, e:
|
||||
message = '%s' % e
|
||||
security._do_flash(message, 'error')
|
||||
redirect_url = request.referrer or \
|
||||
current_app.security.login_manager.login_view
|
||||
current_app.logger.error('Unsuccessful authentication attempt: %s. '
|
||||
'Redirect to: %s' % (message, redirect_url))
|
||||
return redirect(redirect_url)
|
||||
|
||||
|
||||
def logout():
|
||||
for value in ('identity.name', 'identity.auth_type'):
|
||||
session.pop(value, None)
|
||||
|
||||
identity_changed.send(current_app._get_current_object(),
|
||||
identity=AnonymousIdentity())
|
||||
logout_user()
|
||||
|
||||
redirect_url = security._find_redirect('SECURITY_POST_LOGOUT_VIEW')
|
||||
current_app.logger.debug('User logged out. Redirect to: %s' % redirect_url)
|
||||
return redirect(redirect_url)
|
||||
|
||||
|
||||
def reset():
|
||||
# user = something
|
||||
# if reset_password_period_valid_for_user(user):
|
||||
# user.reset_password_sent_at = datetime.utcnow()
|
||||
# user.reset_password_token = token
|
||||
# current_app.security.datastore._save_model(user)
|
||||
pass
|
||||
+15
-15
@@ -40,55 +40,55 @@ class DefaultSecurityTests(SecurityTest):
|
||||
|
||||
def test_login_view(self):
|
||||
r = self._get('/login')
|
||||
assert 'Login Page' in r.data
|
||||
self.assertIn('Login Page', r.data)
|
||||
|
||||
def test_authenticate(self):
|
||||
r = self.authenticate("matt", "password")
|
||||
assert 'Home Page' in r.data
|
||||
self.assertIn('Home Page', r.data)
|
||||
|
||||
def test_unprovided_username(self):
|
||||
r = self.authenticate("", "password")
|
||||
assert "Username not provided" in r.data
|
||||
self.assertIn("Username not provided", r.data)
|
||||
|
||||
def test_unprovided_password(self):
|
||||
r = self.authenticate("matt", "")
|
||||
assert "Password not provided" in r.data
|
||||
self.assertIn("Password not provided", r.data)
|
||||
|
||||
def test_invalid_user(self):
|
||||
r = self.authenticate("bogus", "password")
|
||||
assert "Specified user does not exist" in r.data
|
||||
self.assertIn("Specified user does not exist", r.data)
|
||||
|
||||
def test_bad_password(self):
|
||||
r = self.authenticate("matt", "bogus")
|
||||
assert "Password does not match" in r.data
|
||||
self.assertIn("Password does not match", r.data)
|
||||
|
||||
def test_inactive_user(self):
|
||||
r = self.authenticate("tiya", "password")
|
||||
assert "Inactive user" in r.data
|
||||
self.assertIn("Inactive user", r.data)
|
||||
|
||||
def test_logout(self):
|
||||
self.authenticate("matt", "password")
|
||||
r = self.logout()
|
||||
assert 'Home Page' in r.data
|
||||
self.assertIn('Home Page', r.data)
|
||||
|
||||
def test_unauthorized_access(self):
|
||||
r = self._get('/profile', follow_redirects=True)
|
||||
assert 'Please log in to access this page' in r.data
|
||||
self.assertIn('Please log in to access this page', r.data)
|
||||
|
||||
def test_authorized_access(self):
|
||||
self.authenticate("matt", "password")
|
||||
r = self._get("/profile")
|
||||
assert 'profile' in r.data
|
||||
self.assertIn('profile', r.data)
|
||||
|
||||
def test_valid_admin_role(self):
|
||||
self.authenticate("matt", "password")
|
||||
r = self._get("/admin")
|
||||
assert 'Admin Page' in r.data
|
||||
self.assertIn('Admin Page', r.data)
|
||||
|
||||
def test_invalid_admin_role(self):
|
||||
self.authenticate("joe", "password")
|
||||
r = self._get("/admin", follow_redirects=True)
|
||||
assert 'Home Page' in r.data
|
||||
self.assertIn('Home Page', r.data)
|
||||
|
||||
def test_roles_accepted(self):
|
||||
for user in ("matt", "joe"):
|
||||
@@ -120,16 +120,16 @@ class ConfiguredSecurityTests(SecurityTest):
|
||||
|
||||
def test_login_view(self):
|
||||
r = self._get('/custom_login')
|
||||
assert "Custom Login Page" in r.data
|
||||
self.assertIn("Custom Login Page", r.data)
|
||||
|
||||
def test_authenticate(self):
|
||||
r = self.authenticate("matt", "password", endpoint="/custom_auth")
|
||||
assert 'Post Login' in r.data
|
||||
self.assertIn('Post Login', r.data)
|
||||
|
||||
def test_logout(self):
|
||||
self.authenticate("matt", "password", endpoint="/custom_auth")
|
||||
r = self.logout(endpoint="/custom_logout")
|
||||
assert 'Post Logout' in r.data
|
||||
self.assertIn('Post Logout', r.data)
|
||||
|
||||
|
||||
class MongoEngineSecurityTests(DefaultSecurityTests):
|
||||
|
||||
Reference in New Issue
Block a user