c# - Entity Framework - Repository check (big text)

Tags: c# .net asp.net entity-framework repository

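I'm building a complete repository in C#/ASP.NET with Entity Framework, but I'm worried that I'm overlooking things such as disposing my ObjectContexts. Below is my complete repository (or at least the parts you need to understand my question). I hope someone can take a careful look and tell me whether I've made any mistakes.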

This project is very important to me, but I'm new to the repository pattern and to EF.

Global.asax

public class Global : System.Web.HttpApplication
{
    private WebObjectContextStorage _storage;

    public override void Init()
    {
        base.Init();
        _storage = new WebObjectContextStorage(this);
    }

    protected void Application_Start(object sender, EventArgs e)
    {

    }

    protected void Session_Start(object sender, EventArgs e)
    {

    }

    protected void Application_BeginRequest(object sender, EventArgs e)
    {
        ObjectContextInitializer.Instance().InitializeObjectContextOnce(() =>
        {
            ObjectContextManager.InitStorage(_storage);
        });
    }

    protected void Application_EndRequest(object sender, EventArgs e)
    {

    }

    protected void Application_AuthenticateRequest(object sender, EventArgs e)
    {

    }

    protected void Application_Error(object sender, EventArgs e)
    {

    }

    protected void Session_End(object sender, EventArgs e)
    {

    }

    protected void Application_End(object sender, EventArgs e)
    {

    }
}

ObjectContextManager

public static class ObjectContextManager
{
    public static void InitStorage(IObjectContextStorage storage)
    {
        if (storage == null) 
        {
            throw new ArgumentNullException("storage");
        }
        if ((Storage != null) && (Storage != storage))
        {
            throw new ApplicationException("A storage mechanism has already been configured for this application");
        }            
        Storage = storage;
    }

    /// <summary>
    /// The default connection string name used if only one database is being communicated with.
    /// </summary>
    public static readonly string DefaultConnectionStringName = "TraceConnection";        

    /// <summary>
    /// Used to get the current object context session if you're communicating with a single database.
    /// When communicating with multiple databases, invoke <see cref="CurrentFor(string)" /> instead.
    /// </summary>
    public static ObjectContext Current
    {
        get
        {
            return CurrentFor(DefaultConnectionStringName);
        }
    }

    /// <summary>
    /// Used to get the current ObjectContext associated with a key; i.e., the key 
    /// associated with an object context for a specific database.
    /// 
    /// If you're only communicating with one database, you should call <see cref="Current" /> instead,
    /// although you're certainly welcome to call this if you have the key available.
    /// </summary>
    public static ObjectContext CurrentFor(string key)
    {
        if (string.IsNullOrEmpty(key))
        {
            throw new ArgumentNullException("key");
        }

        if (Storage == null)
        {
            throw new ApplicationException("An IObjectContextStorage has not been initialized");
        }

        ObjectContext context = null;
        lock (_syncLock)
        {
            context = Storage.GetObjectContextForKey(key);

            if (context == null)
            {
                context = ObjectContextFactory.GetTraceContext(key);
                Storage.SetObjectContextForKey(key, context);
            }
        }

        return context;
    }

    /// <summary>
    /// This method is used by application-specific object context storage implementations
    /// and unit tests. Its job is to walk thru existing cached object context(s) and Close() each one.
    /// </summary>
    public static void CloseAllObjectContexts()
    {
        foreach (ObjectContext ctx in Storage.GetAllObjectContexts())
        {
            if (ctx.Connection.State == System.Data.ConnectionState.Open)
                ctx.Connection.Close();
        }
    }      

    /// <summary>
    /// An application-specific implementation of IObjectContextStorage must be set up through
    /// <see cref="InitStorage" />.
    /// </summary>
    private static IObjectContextStorage Storage { get; set; }

    private static object _syncLock = new object();
}

ObjectContextInitializer

public class ObjectContextInitializer
{
    private static readonly object syncLock = new object();
    private static ObjectContextInitializer instance;

    protected ObjectContextInitializer() { }

    private bool isInitialized = false;

    public static ObjectContextInitializer Instance()
    {
        if (instance == null)
        {
            lock (syncLock)
            {
                if (instance == null)
                {
                    instance = new ObjectContextInitializer();
                }
            }
        }

        return instance;
    }

    /// <summary>
    /// This is the method which should be given the call to initialize the ObjectContext; e.g.,
    /// ObjectContextInitializer.Instance().InitializeObjectContextOnce(() => InitializeObjectContext());
    /// where InitializeObjectContext() is a method which calls ObjectContextManager.InitStorage()
    /// </summary>
    /// <param name="initMethod">The initialization callback to run once.</param>
    public void InitializeObjectContextOnce(Action initMethod)
    {
        lock (syncLock)
        {
            if (!isInitialized)
            {
                initMethod();
                isInitialized = true;
            }
        }
    }

}

ObjectContextFactory

public static class ObjectContextFactory
{
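    /// <summary>
    /// Gets the TraceContext
    /// </summary>
    /// <param name="configName">Name of the connection string entry in the configuration file</param>
    /// <returns>The TraceContext</returns>
    public static TraceContext GetTraceContext(string configName)
    {
        ConnectionStringSettings settings = ConfigurationManager.ConnectionStrings[configName];
        if (settings == null)
        {
            // Fail with a clear message instead of a NullReferenceException.
            throw new ConfigurationErrorsException("Connection string '" + configName + "' was not found in the configuration file.");
        }
        return new TraceContext(settings.ConnectionString);
    }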
}

WebObjectContextStorage

public class WebObjectContextStorage : IObjectContextStorage
{   
    public WebObjectContextStorage(HttpApplication app)
    { 
        app.EndRequest += (sender, args) =>
                              {
                                  ObjectContextManager.CloseAllObjectContexts();
                                  HttpContext.Current.Items.Remove(HttpContextObjectContextStorageKey);
                              };
    }        

    public ObjectContext GetObjectContextForKey(string key)
    {
        ObjectContextStorage storage = GetObjectContextStorage();
        return storage.GetObjectContextForKey(key);
    }

    public void SetObjectContextForKey(string factoryKey, ObjectContext session)
    {
        ObjectContextStorage storage = GetObjectContextStorage();
        storage.SetObjectContextForKey(factoryKey, session);
    }

    public IEnumerable<ObjectContext> GetAllObjectContexts()
    {
        ObjectContextStorage storage = GetObjectContextStorage();
        return storage.GetAllObjectContexts();
    }

    private ObjectContextStorage GetObjectContextStorage()
    {
        HttpContext context = HttpContext.Current;
        ObjectContextStorage storage = context.Items[HttpContextObjectContextStorageKey] as ObjectContextStorage;
        if (storage == null)
        {
            storage = new ObjectContextStorage();
            context.Items[HttpContextObjectContextStorageKey] = storage;
        }
        return storage;
    }       

    private static readonly string HttpContextObjectContextStorageKey = "HttpContextObjectContextStorageKey";       
}

ObjectContextStorage

public class ObjectContextStorage : IObjectContextStorage
{
    private Dictionary<string, ObjectContext> storage = new Dictionary<string, ObjectContext>();

    /// <summary>
    /// Initializes a new instance of the <see cref="ObjectContextStorage"/> class.
    /// </summary>
    public ObjectContextStorage() { }

    /// <summary>
    /// Returns the object context associated with the specified key or
    /// null if the specified key is not found.
    /// </summary>
    /// <param name="key">The key.</param>
    /// <returns></returns>
    public ObjectContext GetObjectContextForKey(string key)
    {
        ObjectContext context;
        if (!this.storage.TryGetValue(key, out context))
            return null;
        return context;
    }


    /// <summary>
    /// Stores the object context into a dictionary using the specified key.
    /// If an object context already exists by the specified key, 
    /// it gets overwritten by the new object context passed in.
    /// </summary>
    /// <param name="key">The key.</param>
    /// <param name="objectContext">The object context.</param>
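    public void SetObjectContextForKey(string key, ObjectContext objectContext)
    {
        // The indexer overwrites an existing entry, as documented above;
        // Dictionary.Add would throw an ArgumentException for a duplicate key.
        this.storage[key] = objectContext;
    }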

    /// <summary>
    /// Returns all the values of the internal dictionary of object contexts.
    /// </summary>
    /// <returns></returns>
    public IEnumerable<ObjectContext> GetAllObjectContexts()
    {
        return this.storage.Values;
    }
}
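The doc comment on SetObjectContextForKey promises that an existing entry is overwritten, but Dictionary<TKey,TValue>.Add throws an ArgumentException when the key is already present; only the indexer setter overwrites. The following is a minimal sketch of the difference. It uses plain strings in place of ObjectContext so it runs without Entity Framework, and the KeyedStore class is hypothetical.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for ObjectContextStorage, using strings instead of
// ObjectContext so the example runs without Entity Framework.
public class KeyedStore
{
    private readonly Dictionary<string, string> storage = new Dictionary<string, string>();

    // Mirrors the original code: Add throws ArgumentException if the key already exists.
    public void AddOnly(string key, string value)
    {
        storage.Add(key, value);
    }

    // Matches the documented behaviour: the indexer inserts or overwrites.
    public void Set(string key, string value)
    {
        storage[key] = value;
    }

    // Returns null for a missing key, like GetObjectContextForKey.
    public string Get(string key)
    {
        string value;
        return storage.TryGetValue(key, out value) ? value : null;
    }
}
```

With Add, a second registration under "TraceConnection" fails; with the indexer, the second value simply replaces the first.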

通用存储库

public class GenericRepository : IRepository
{
    private readonly string _connectionStringName;
    private ObjectContext _objectContext;
    private readonly PluralizationService _pluralizer = PluralizationService.CreateService(CultureInfo.GetCultureInfo("en"));
    private bool _usePlurazation;

    /// <summary>
    /// Initializes a new instance of the <see cref="GenericRepository"/> class.
    /// </summary>
    public GenericRepository()
        : this(string.Empty, false)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="GenericRepository"/> class.
    /// </summary>
    /// <param name="connectionStringName">Name of the connection string.</param>
    /// <param name="usePlurazation">Whether entity set names are the pluralized entity type names.</param>
    public GenericRepository(string connectionStringName, bool usePlurazation)
    {
        this._connectionStringName = connectionStringName;
        this._usePlurazation = usePlurazation;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="GenericRepository"/> class.
    /// </summary>
    /// <param name="objectContext">The object context.</param>
    /// <param name="usePlurazation">Whether entity set names are the pluralized entity type names.</param>
    public GenericRepository(ObjectContext objectContext, bool usePlurazation)
    {
        if (objectContext == null)
            throw new ArgumentNullException("objectContext");
        this._objectContext = objectContext;
        this._usePlurazation = usePlurazation;
    }

    public TEntity GetByKey<TEntity>(object keyValue) where TEntity : class
    {
        EntityKey key = GetEntityKey<TEntity>(keyValue);

        object originalItem;
        if (ObjectContext.TryGetObjectByKey(key, out originalItem))
        {
            return (TEntity)originalItem;
        }
        return default(TEntity);
    }

    public IQueryable<TEntity> GetQuery<TEntity>() where TEntity : class
    {
        var entityName = GetEntityName<TEntity>();
        return ObjectContext.CreateQuery<TEntity>(entityName).OfType<TEntity>();
    }

    public IQueryable<TEntity> GetQuery<TEntity>(Expression<Func<TEntity, bool>> predicate) where TEntity : class
    {
        return GetQuery<TEntity>().Where(predicate);
    }

    public IQueryable<TEntity> GetQuery<TEntity>(ISpecification<TEntity> specification) where TEntity : class
    {
        return specification.SatisfyingEntitiesFrom(GetQuery<TEntity>());
    }

    // pageIndex is a zero-based page number, so whole pages (pageIndex * pageSize rows) are skipped.
    public IEnumerable<TEntity> Get<TEntity>(Expression<Func<TEntity, string>> orderBy, int pageIndex, int pageSize, SortOrder sortOrder = SortOrder.Ascending) where TEntity : class
    {
        if (sortOrder == SortOrder.Ascending)
        {
            return GetQuery<TEntity>().OrderBy(orderBy).Skip(pageIndex * pageSize).Take(pageSize).AsEnumerable();
        }
        return GetQuery<TEntity>().OrderByDescending(orderBy).Skip(pageIndex * pageSize).Take(pageSize).AsEnumerable();
    }

    public IEnumerable<TEntity> Get<TEntity>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, string>> orderBy, int pageIndex, int pageSize, SortOrder sortOrder = SortOrder.Ascending) where TEntity : class
    {
        if (sortOrder == SortOrder.Ascending)
        {
            return GetQuery<TEntity>().Where(predicate).OrderBy(orderBy).Skip(pageIndex * pageSize).Take(pageSize).AsEnumerable();
        }
        return GetQuery<TEntity>().Where(predicate).OrderByDescending(orderBy).Skip(pageIndex * pageSize).Take(pageSize).AsEnumerable();
    }

    public IEnumerable<TEntity> Get<TEntity>(ISpecification<TEntity> specification, Expression<Func<TEntity, string>> orderBy, int pageIndex, int pageSize, SortOrder sortOrder = SortOrder.Ascending) where TEntity : class
    {
        if (sortOrder == SortOrder.Ascending)
        {
            return specification.SatisfyingEntitiesFrom(GetQuery<TEntity>()).OrderBy(orderBy).Skip(pageIndex * pageSize).Take(pageSize).AsEnumerable();
        }
        return specification.SatisfyingEntitiesFrom(GetQuery<TEntity>()).OrderByDescending(orderBy).Skip(pageIndex * pageSize).Take(pageSize).AsEnumerable();
    }

    public TEntity Single<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().SingleOrDefault<TEntity>(criteria);
    }

    public TEntity Single<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntityFrom(GetQuery<TEntity>());
    }

    public TEntity First<TEntity>(Expression<Func<TEntity, bool>> predicate) where TEntity : class
    {
        return GetQuery<TEntity>().FirstOrDefault(predicate);
    }

    public TEntity First<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntitiesFrom(GetQuery<TEntity>()).FirstOrDefault();
    }

    public void Add<TEntity>(TEntity entity) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }
        ObjectContext.AddObject(GetEntityName<TEntity>(), entity);
    }

    public void Attach<TEntity>(TEntity entity) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }

        ObjectContext.AttachTo(GetEntityName<TEntity>(), entity);
    }

    public void Delete<TEntity>(TEntity entity) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }
        ObjectContext.DeleteObject(entity);
    }

    public void Delete<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        IEnumerable<TEntity> records = Find<TEntity>(criteria);

        foreach (TEntity record in records)
        {
            Delete<TEntity>(record);
        }
    }

    public void Delete<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        IEnumerable<TEntity> records = Find<TEntity>(criteria);
        foreach (TEntity record in records)
        {
            Delete<TEntity>(record);
        }
    }

    public IEnumerable<TEntity> GetAll<TEntity>() where TEntity : class
    {
        return GetQuery<TEntity>().AsEnumerable();
    }

    public void Update<TEntity>(TEntity entity) where TEntity : class
    {
        var fqen = GetEntityName<TEntity>();

        object originalItem;
        EntityKey key = ObjectContext.CreateEntityKey(fqen, entity);
        if (ObjectContext.TryGetObjectByKey(key, out originalItem))
        {
            ObjectContext.ApplyCurrentValues(key.EntitySetName, entity);
        }
    }

    public IEnumerable<TEntity> Find<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().Where(criteria);
    }

    public TEntity FindOne<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().Where(criteria).FirstOrDefault();
    }

    public TEntity FindOne<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntityFrom(GetQuery<TEntity>());
    }

    public IEnumerable<TEntity> Find<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntitiesFrom(GetQuery<TEntity>());
    }

    public int Count<TEntity>() where TEntity : class
    {
        return GetQuery<TEntity>().Count();
    }

    public int Count<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().Count(criteria);
    }

    public int Count<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntitiesFrom(GetQuery<TEntity>()).Count();
    }

    public IUnitOfWork UnitOfWork
    {
        get
        {
            if (unitOfWork == null)
            {
                unitOfWork = new UnitOfWork(this.ObjectContext);
            }
            return unitOfWork;
        }
    }

    private ObjectContext ObjectContext
    {
        get
        {
            if (this._objectContext == null)
            {
                if (string.IsNullOrEmpty(this._connectionStringName))
                {
                    this._objectContext = ObjectContextManager.Current;
                }
                else
                {
                    this._objectContext = ObjectContextManager.CurrentFor(this._connectionStringName);
                }
            }
            return this._objectContext;
        }
    }

    private EntityKey GetEntityKey<TEntity>(object keyValue) where TEntity : class
    {
        var entitySetName = GetEntityName<TEntity>();
        var objectSet = ObjectContext.CreateObjectSet<TEntity>();
        var keyPropertyName = objectSet.EntitySet.ElementType.KeyMembers[0].ToString();
        var entityKey = new EntityKey(entitySetName, new[] { new EntityKeyMember(keyPropertyName, keyValue) });
        return entityKey;
    }

    private string GetEntityName<TEntity>() where TEntity : class
    {
        // WARNING: this assumes the entity set is named after the CLR type (optionally pluralized).
        // That does not hold for derived types in an inheritance hierarchy, which share the base type's entity set.
        if (_usePlurazation)
        {
            return string.Format("{0}.{1}", ObjectContext.DefaultContainerName, _pluralizer.Pluralize(typeof(TEntity).Name));
        }
        else
        {
            return string.Format("{0}.{1}", ObjectContext.DefaultContainerName, typeof(TEntity).Name);
        }
    }

    private IUnitOfWork unitOfWork;
}
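The Get overloads page with Skip/Take. Assuming pageIndex is a zero-based page number (the name suggests so, but the question does not say), the number of rows to skip is pageIndex * pageSize. Passing pageIndex straight to Skip only advances by one row per page. Here is a minimal sketch against LINQ to Objects; the Paging.Page helper is hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class Paging
{
    // Returns one page of an already ordered sequence.
    // pageIndex is assumed to be zero-based: page 0 is the first page.
    public static List<T> Page<T>(IEnumerable<T> orderedSource, int pageIndex, int pageSize)
    {
        if (pageIndex < 0) throw new ArgumentOutOfRangeException("pageIndex");
        if (pageSize <= 0) throw new ArgumentOutOfRangeException("pageSize");

        // Skip whole pages, not single rows.
        return orderedSource.Skip(pageIndex * pageSize).Take(pageSize).ToList();
    }
}
```

For the numbers 1 to 10 with a page size of 3, page 1 is 4, 5, 6 and page 3 holds only 10. With LINQ to Entities, Skip additionally requires an ordered query, which the Get overloads already provide through OrderBy/OrderByDescending.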

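I know it takes some time to read through the code, but it would help me a lot if someone looked at it and gave me tips on how to do it better, or on where I'm not disposing objects.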

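Also, a small question: I would like to put a business layer on top of this repository. I guess that would leave things like global.asax unchanged, but would it need static classes (right?) like a BookProvider that gives me all the data about my book entities?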

Thanks in advance!

Best answer

The only concrete comment I can give is about disposing the contexts:

foreach (ObjectContext ctx in Storage.GetAllObjectContexts())
{
    if (ctx.Connection.State == System.Data.ConnectionState.Open)
        ctx.Connection.Close();
}
ObjectContext implements IDisposable, so I think the standard way would be:
foreach (ObjectContext ctx in Storage.GetAllObjectContexts())
    ctx.Dispose();

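As far as I know, ObjectContext.Dispose() just closes the connection, so it does the same as what you are doing. But I consider that an internal implementation detail which might change between EF versions.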
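A sketch of the recommended disposal pattern applied to per-request storage follows. It relies on nothing EF-specific: any IDisposable works, so a hypothetical FakeContext stands in for ObjectContext, and RequestContextStorage is a hypothetical simplification of ObjectContextStorage. Each context is disposed even if an earlier one throws, and the dictionary is cleared so a disposed context is never handed out again.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for ObjectContext; records whether Dispose was called.
public class FakeContext : IDisposable
{
    public bool IsDisposed { get; private set; }

    public void Dispose()
    {
        IsDisposed = true;
    }
}

// Hypothetical per-request storage: one context per connection-string key.
public class RequestContextStorage
{
    private readonly Dictionary<string, IDisposable> contexts = new Dictionary<string, IDisposable>();

    public void Set(string key, IDisposable context)
    {
        contexts[key] = context;
    }

    public int Count
    {
        get { return contexts.Count; }
    }

    // Intended to be called once at the end of the request.
    public void DisposeAll()
    {
        List<Exception> errors = new List<Exception>();
        foreach (IDisposable context in contexts.Values)
        {
            try
            {
                context.Dispose();
            }
            catch (Exception ex)
            {
                // Keep going so one failing context does not leak the others.
                errors.Add(ex);
            }
        }
        contexts.Clear();

        if (errors.Count > 0)
            throw new AggregateException(errors);
    }
}
```

In the question's code, a method like DisposeAll would take the place of the connection-closing loop in CloseAllObjectContexts, which the EndRequest handler in WebObjectContextStorage already calls.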

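Your generic repository is one of many of its kind. A few points came to mind while looking at the methods:
  • Since you expose IQueryable (public IQueryable<TEntity> GetQuery<TEntity>(...)), why do you need most of the other methods, such as Single, First, Count, and so on? (And why not Any, etc.?) You get all of that from your IQueryable.
  • Your Update method only works for scalar properties. But this is a common problem with generic repositories: there is no easy solution, or no solution at all, for updating entities in a generic way.
  • What do you want to achieve by using the repository pattern? If you have unit testability with an in-memory data store in mind, you cannot expose IQueryable, because LINQ to Entities and LINQ to Objects are not the same. To test whether your IQueryables work, you need integration tests against the real database your application is supposed to use in production. But if you don't expose IQueryable, your repository needs many business-specific methods that return results as POCOs, POCO collections, or DTOs of projected/selected properties, and that hide the internal query specifications, so that you can mock these methods with in-memory data to test your business logic. This is the point where a generic repository is no longer sufficient. (For example, how would you write a LINQ Join that involves more than one entity/object set in a repository that has only one entity type as its generic parameter?)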

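  • If you ask ten people how their repositories are architected, you'll get ten different answers. None of them is really wrong or the best, because it depends on the application you are going to build with this repository. I believe nobody can tell you what your repository is really worth; you will see that in practice when you start writing the application. For some applications it might be over-architected (which I consider the most dangerous case, because managing and controlling a pointless architecture is expensive and wastes time you could spend writing the real application). For other requirements you will probably have to extend the repository. For example:
    • How do you handle explicit loading or querying of an entity's navigation properties (with CreateSourceQuery, or with DbEntityEntry.Collection/Reference in EF 4.1)? If your application never needs explicit loading: fine. If it does, you need to extend your repo.
    • How do you control eager loading? Sometimes you may want only a parent entity. Sometimes you want to Include the children1 collection, and sometimes also the child2 reference.
    • How do you set the state of an entity manually? Maybe you never need to, but in the next application it might be very helpful.
    • How do you detach an entity from the context manually?
    • How do you control loading behaviour (with or without tracking of entities by the context)?
    • How do you control lazy loading behaviour and the creation of change tracking proxies manually?
    • How do you create an entity proxy manually? You might need it in some situations when you work with lazy loading or change tracking proxies.
    • How do you load entities into the context without building a result collection? Yet another repository method, maybe... or maybe not. Who knows in advance what your application logic will need.

    • Etc., etc...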
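To illustrate the first point: once GetQuery<TEntity>() returns an IQueryable<TEntity>, callers can compose Count, Any, SingleOrDefault and so on themselves. The sketch below is hypothetical (Book, IQueryRepository, InMemoryRepository and BookQueries are illustrative names) and runs against LINQ to Objects via AsQueryable(). With EF, the same operators would instead be translated to SQL.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity used only for illustration.
public class Book
{
    public int Id { get; set; }
    public string Title { get; set; }
    public string Author { get; set; }
}

// Hypothetical minimal repository: only the IQueryable entry point.
public interface IQueryRepository
{
    IQueryable<TEntity> GetQuery<TEntity>() where TEntity : class;
}

// In-memory implementation backed by LINQ to Objects.
public class InMemoryRepository : IQueryRepository
{
    private readonly Dictionary<Type, object> sets = new Dictionary<Type, object>();

    public void Seed<TEntity>(params TEntity[] items) where TEntity : class
    {
        sets[typeof(TEntity)] = new List<TEntity>(items);
    }

    public IQueryable<TEntity> GetQuery<TEntity>() where TEntity : class
    {
        object set;
        if (!sets.TryGetValue(typeof(TEntity), out set))
            return Enumerable.Empty<TEntity>().AsQueryable();
        return ((List<TEntity>)set).AsQueryable();
    }
}

// Callers compose the standard operators; no extra repository methods are needed.
public static class BookQueries
{
    public static int CountByAuthor(IQueryRepository repo, string author)
    {
        return repo.GetQuery<Book>().Count(b => b.Author == author);
    }

    public static bool AnyByAuthor(IQueryRepository repo, string author)
    {
        return repo.GetQuery<Book>().Any(b => b.Author == author);
    }

    public static Book FindById(IQueryRepository repo, int id)
    {
        return repo.GetQuery<Book>().SingleOrDefault(b => b.Id == id);
    }
}
```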
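On the testability point, here is a small example of why an in-memory IQueryable can pass tests that the real one would fail. The predicate calls an ordinary C# method (IsBargain, hypothetical). LINQ to Objects simply invokes it, whereas LINQ to Entities cannot translate an arbitrary .NET method into SQL and throws a NotSupportedException when the query executes. Only the in-memory half is runnable here, since the EF half needs a database. String equality is a second difference: == on .NET strings is ordinal and case-sensitive, while against SQL Server the outcome depends on the column collation, and commonly used defaults are case-insensitive.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity used only for illustration.
public class Book
{
    public string Title { get; set; }
    public decimal Price { get; set; }
}

public static class ProviderDifferences
{
    // Ordinary .NET method: fine for LINQ to Objects, but LINQ to Entities
    // cannot translate it to SQL and throws NotSupportedException at execution time.
    public static bool IsBargain(Book book)
    {
        return book.Price < 10m;
    }

    public static List<string> BargainTitles(IQueryable<Book> books)
    {
        return books.Where(b => IsBargain(b)).Select(b => b.Title).ToList();
    }

    // Ordinal and case-sensitive in LINQ to Objects; against SQL Server the
    // result depends on the column collation (often case-insensitive).
    public static int CountTitled(IQueryable<Book> books, string title)
    {
        return books.Count(b => b.Title == title);
    }
}
```

A unit test that passes List<Book>.AsQueryable() into BargainTitles would succeed even though the same call against an ObjectQuery fails, which is why the answer recommends integration tests for exposed IQueryables.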
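The alternative described in the testability point, business-specific methods that return POCOs/DTOs and hide the query, might look like the following. All names (BookSummary, IBookRepository, InMemoryBookRepository, BookService) are hypothetical. The EF-backed implementation of IBookRepository is omitted and would still need integration tests against the real database, but the business logic in BookService can be unit-tested against the in-memory fake.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical DTO returned by the repository instead of an IQueryable.
public class BookSummary
{
    public string Title { get; set; }
    public string Author { get; set; }
    public int Year { get; set; }
}

// Hypothetical business-specific repository: the query lives behind the method.
public interface IBookRepository
{
    IList<BookSummary> GetBooksByAuthor(string author);
}

// In-memory fake for unit tests; an EF-backed implementation would run the real query.
public class InMemoryBookRepository : IBookRepository
{
    private readonly List<BookSummary> books;

    public InMemoryBookRepository(IEnumerable<BookSummary> books)
    {
        this.books = new List<BookSummary>(books);
    }

    public IList<BookSummary> GetBooksByAuthor(string author)
    {
        return books.Where(b => b.Author == author).ToList();
    }
}

// Business logic depends only on the interface, so it can be tested without a database.
public class BookService
{
    private readonly IBookRepository repository;

    public BookService(IBookRepository repository)
    {
        if (repository == null) throw new ArgumentNullException("repository");
        this.repository = repository;
    }

    // Example rule: the author's most recent book, or null if there is none.
    public BookSummary LatestBookBy(string author)
    {
        return repository.GetBooksByAuthor(author)
                         .OrderByDescending(b => b.Year)
                         .FirstOrDefault();
    }
}
```

In this sketch BookService is an instance class that receives the repository through its constructor rather than a static class, which is what lets a test pass in the fake.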

Source: c# - Entity Framework - Repository check (big text), on Stack Overflow: https://stackoverflow.com/questions/7075107/
