更新(EF Core 5.0):由于元数据接口的变更以及跳过导航(skip navigation)的引入,对部分代码进行了更新。注意,该代码不处理多对多关系的跳过导航属性。
// process navigations (EF Core 5.0 metadata API)
foreach (var navEntry in dbEntry.Navigations)
{
    // Many-to-many skip navigations (ISkipNavigation) are not INavigation and are not handled here.
    if (navEntry.Metadata is not INavigation navigation) continue;
    var foreignKey = navigation.ForeignKey;
    // The relationship is already being processed further up the current path
    // (e.g. this is the inverse navigation back to the parent) - skip it to stop the recursion.
    if (!visited.Add(foreignKey)) continue;
    try
    {
        await navEntry.LoadAsync();
        if (!navigation.IsCollection)
        {
            // reference type navigation property
            var refValue = navigation.GetGetter().GetClrValue(entity);
            // Resolve the database counterpart by the key of the *incoming* reference instead of
            // reusing the currently loaded one: the reference may now point to a different entity,
            // and copying its values onto the old one would overwrite that entity (including its key).
            // FindEntityAsync (Find) returns the already tracked instance when the key is unchanged.
            navEntry.CurrentValue = refValue == null ? null :
                await context.UpdateGraphAsync(await context.FindEntityAsync(refValue), refValue, visited);
        }
        else
        {
            // collection type navigation property
            var accessor = navigation.GetCollectionAccessor();
            var items = (IEnumerable<object>)accessor.GetOrCreate(entity, false);
            var dbItems = (IEnumerable<object>)accessor.GetOrCreate(dbEntity, false);
            var itemType = navigation.TargetEntityType;
            var keyProperties = itemType.FindPrimaryKey().Properties
                .Select((p, i) => (Index: i, Getter: p.GetGetter(), Comparer: p.GetKeyValueComparer()))
                .ToList();
            // Scratch buffer holding the primary key of the item currently being searched for.
            var keyValues = new object[keyProperties.Count];
            void GetKeyValues(object sourceItem)
            {
                foreach (var p in keyProperties)
                    keyValues[p.Index] = p.Getter.GetClrValue(sourceItem);
            }
            // Returns the item of targetCollection whose primary key equals sourceItem's, or null.
            object FindItem(IEnumerable<object> targetCollection, object sourceItem)
            {
                GetKeyValues(sourceItem);
                foreach (var targetItem in targetCollection)
                {
                    bool keyMatch = true;
                    foreach (var p in keyProperties)
                    {
                        (var keyA, var keyB) = (p.Getter.GetClrValue(targetItem), keyValues[p.Index]);
                        keyMatch = p.Comparer?.Equals(keyA, keyB) ?? object.Equals(keyA, keyB);
                        if (!keyMatch) break;
                    }
                    if (keyMatch) return targetItem;
                }
                return null;
            }
            // Remove db items missing in the current list (snapshot first: Remove mutates dbItems)
            foreach (var dbItem in dbItems.ToList())
                if (FindItem(items, dbItem) == null) accessor.Remove(dbEntity, dbItem);
            // Add current items missing in the db list, update others
            var existingItems = dbItems.ToList();
            foreach (var item in items)
            {
                var dbItem = FindItem(existingItems, item);
                if (dbItem == null)
                    accessor.Add(dbEntity, item, false);
                // A null dbItem makes the recursive call treat the item as new (Added).
                await context.UpdateGraphAsync(dbItem, item, visited);
            }
        }
    }
    finally
    {
        // Unblock the relationship when leaving this path, so sibling entities (e.g. the other
        // items of the same collection) still get their own navigations processed.
        visited.Remove(foreignKey);
    }
}
更新:
评论中还提出了其他一些问题:如何处理引用导航属性;如果相关实体没有实现这种泛型接口该怎么办;以及在使用这种泛型方法签名时,编译器无法推断泛型类型参数。
经过一番思考,我得出的结论是:根本不需要基类/接口(甚至不需要泛型实体类型),因为 EF Core 元数据已经包含了处理主键(PK)所需的全部信息(例如 `Find` / `FindAsync` 方法和更改跟踪器使用的就是这些信息)。
下面的方法仅使用 EF Core 的元数据信息/服务,以递归方式应用断开连接(disconnected)的实体图中的修改。如有需要,可以对其进行修改,使其接收一个“排除”过滤器,以便跳过某些实体/集合:
/// <summary>
/// Applies the modifications of a disconnected entity graph to the tracked/database state using only
/// EF Core metadata (primary keys and navigations). <c>SaveChangesAsync</c> is not called.
/// </summary>
public static class EntityGraphUpdateHelper
{
    /// <summary>
    /// Recursively applies the disconnected <paramref name="entity"/> graph to its database counterpart
    /// (looked up by primary key via <see cref="FindEntityAsync"/>); when no counterpart exists the
    /// entity is marked as Added.
    /// </summary>
    /// <returns>The tracked instance that now holds the applied values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public static async ValueTask<object> UpdateGraphAsync(this DbContext context, object entity)
    {
        if (entity == null) throw new ArgumentNullException(nameof(entity));
        return await context.UpdateGraphAsync(await context.FindEntityAsync(entity), entity, new HashSet<IForeignKey>());
    }

    // Core recursion. dbEntity == null means "entity is new". visited holds the relationships
    // (foreign keys) being processed on the current path, to avoid walking back over inverse navigations.
    private static async ValueTask<object> UpdateGraphAsync(this DbContext context, object dbEntity, object entity, HashSet<IForeignKey> visited)
    {
        bool isNew = dbEntity == null;
        if (isNew) dbEntity = entity;
        var dbEntry = context.Entry(dbEntity);
        if (isNew)
            dbEntry.State = EntityState.Added;
        else
        {
            // ensure is attached (tracked)
            if (dbEntry.State == EntityState.Detached)
                dbEntry.State = EntityState.Unchanged;
            // update primitive values
            dbEntry.CurrentValues.SetValues(entity);
        }
        // process navigations
        foreach (var navEntry in dbEntry.Navigations)
        {
            var foreignKey = navEntry.Metadata.ForeignKey;
            // Already being processed further up the current path (e.g. inverse navigation to the parent).
            if (!visited.Add(foreignKey)) continue;
            try
            {
                await navEntry.LoadAsync();
                if (!navEntry.Metadata.IsCollection())
                {
                    // reference type navigation property
                    var refValue = navEntry.Metadata.GetGetter().GetClrValue(entity);
                    // Resolve the counterpart by the key of the *incoming* reference rather than reusing
                    // the loaded one: the reference may now point to a different entity, and copying its
                    // values onto the old one would overwrite that entity (including its key).
                    // Find returns the already tracked instance when the key is unchanged.
                    navEntry.CurrentValue = refValue == null ? null :
                        await context.UpdateGraphAsync(await context.FindEntityAsync(refValue), refValue, visited);
                }
                else
                {
                    // collection type navigation property
                    var accessor = navEntry.Metadata.GetCollectionAccessor();
                    var items = (IEnumerable<object>)accessor.GetOrCreate(entity, false);
                    var dbItems = (IEnumerable<object>)accessor.GetOrCreate(dbEntity, false);
                    var itemType = navEntry.Metadata.GetTargetType();
                    var keyProperties = itemType.FindPrimaryKey().Properties
                        .Select((p, i) => (Index: i, Getter: p.GetGetter(), Comparer: p.GetKeyValueComparer()))
                        .ToList();
                    // Scratch buffer holding the primary key of the item currently being searched for.
                    var keyValues = new object[keyProperties.Count];
                    void GetKeyValues(object sourceItem)
                    {
                        foreach (var p in keyProperties)
                            keyValues[p.Index] = p.Getter.GetClrValue(sourceItem);
                    }
                    // Returns the item of targetCollection whose primary key equals sourceItem's, or null.
                    object FindItem(IEnumerable<object> targetCollection, object sourceItem)
                    {
                        GetKeyValues(sourceItem);
                        foreach (var targetItem in targetCollection)
                        {
                            bool keyMatch = true;
                            foreach (var p in keyProperties)
                            {
                                (var keyA, var keyB) = (p.Getter.GetClrValue(targetItem), keyValues[p.Index]);
                                keyMatch = p.Comparer?.Equals(keyA, keyB) ?? object.Equals(keyA, keyB);
                                if (!keyMatch) break;
                            }
                            if (keyMatch) return targetItem;
                        }
                        return null;
                    }
                    // Remove db items missing in the current list (snapshot first: Remove mutates dbItems)
                    foreach (var dbItem in dbItems.ToList())
                        if (FindItem(items, dbItem) == null) accessor.Remove(dbEntity, dbItem);
                    // Add current items missing in the db list, update others
                    var existingItems = dbItems.ToList();
                    foreach (var item in items)
                    {
                        var dbItem = FindItem(existingItems, item);
                        if (dbItem == null)
                            accessor.Add(dbEntity, item, false);
                        // A null dbItem makes the recursive call treat the item as new (Added).
                        await context.UpdateGraphAsync(dbItem, item, visited);
                    }
                }
            }
            finally
            {
                // Only block the relationship along the current path, so sibling entities
                // (e.g. the other items of the same collection) still get their navigations processed.
                visited.Remove(foreignKey);
            }
        }
        return dbEntity;
    }

    /// <summary>
    /// Finds the tracked or stored counterpart of <paramref name="entity"/> by its primary key values
    /// (via <c>DbContext.FindAsync</c>); the result is null when none exists.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    /// <exception cref="ArgumentException">The entity's CLR type is not mapped in the context model.</exception>
    /// <exception cref="InvalidOperationException">The entity type has no primary key.</exception>
    public static ValueTask<object> FindEntityAsync(this DbContext context, object entity)
    {
        if (entity == null) throw new ArgumentNullException(nameof(entity));
        var entityType = context.Model.FindRuntimeEntityType(entity.GetType())
            ?? throw new ArgumentException($"Type '{entity.GetType()}' is not an entity type of the context model.", nameof(entity));
        var keyProperties = entityType.FindPrimaryKey()?.Properties
            ?? throw new InvalidOperationException($"Entity type '{entityType.DisplayName()}' has no primary key.");
        var keyValues = new object[keyProperties.Count];
        for (int i = 0; i < keyValues.Length; i++)
            keyValues[i] = keyProperties[i].GetGetter().GetClrValue(entity);
        return context.FindAsync(entityType.ClrType, keyValues);
    }
}
请注意,与 EF Core 自身的方法一样,上述方法不包含 `SaveChangesAsync` 调用,需要在之后单独调用。
原始答案:
处理实现这种泛型接口的实体集合需要稍微不同的方法,因为没有可用于提取 `Id` 的非泛型基类/接口。
一种可能的解决方案是将集合处理代码移至单独的泛型方法中,并通过 `dynamic` 或反射来调用它。
例如(必要的 `using` 指令可借助 VS 确定):
/// <summary>
/// Updates an entity and selected child collections whose items implement <c>IEntity&lt;TId&gt;</c>.
/// </summary>
public static class EntityUpdateHelper
{
    /// <summary>
    /// Loads the stored counterpart of <paramref name="entity"/> by <c>Id</c>, copies its scalar values,
    /// synchronizes each collection selected by <paramref name="navigations"/> via
    /// <see cref="UpdateCollection{TEntity, TId}"/>, then saves the changes.
    /// </summary>
    /// <exception cref="InvalidOperationException">No stored entity with the given Id exists.</exception>
    public static async Task UpdateEntityAsync<TEntity, TId>(this DbContext context, TEntity entity, params Expression<Func<TEntity, object>>[] navigations)
        where TEntity : class, IEntity<TId>
    {
        var dbEntity = await context.FindAsync<TEntity>(entity.Id);
        if (dbEntity == null)
            throw new InvalidOperationException($"{typeof(TEntity).Name} with Id '{entity.Id}' was not found.");
        var dbEntry = context.Entry(dbEntity);
        dbEntry.CurrentValues.SetValues(entity);
        // Open generic definition, closed below for each collection's item/Id types.
        var updateCollectionDefinition = typeof(EntityUpdateHelper).GetMethod(nameof(UpdateCollection));
        foreach (var property in navigations)
        {
            var propertyName = property.GetPropertyAccess().Name;
            var dbItemsEntry = dbEntry.Collection(propertyName);
            // FindAsync does not load navigations; without loading, the db collection would be
            // empty (or null) and every incoming item would be treated as new.
            if (!dbItemsEntry.IsLoaded)
                await dbItemsEntry.LoadAsync();
            var dbItems = dbItemsEntry.CurrentValue;
            var items = dbItemsEntry.Metadata.GetGetter().GetClrValue(entity);
            // Determine TEntity and TId, and call UpdateCollection<TEntity, TId>
            // via reflection
            var itemType = dbItemsEntry.Metadata.GetTargetType().ClrType;
            var idType = itemType.GetInterfaces()
                .Single(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEntity<>))
                .GetGenericArguments().Single();
            var updateMethod = updateCollectionDefinition.MakeGenericMethod(itemType, idType);
            // UpdateCollection is an extension method, i.e. a static method whose first
            // parameter is the DbContext, so the context must be passed explicitly.
            updateMethod.Invoke(null, new object[] { context, dbItems, items });
        }
        await context.SaveChangesAsync();
    }

    /// <summary>
    /// Synchronizes <paramref name="dbItems"/> with <paramref name="items"/> by <c>Id</c>: adds new items,
    /// copies values into matching items, and removes items no longer present.
    /// </summary>
    public static void UpdateCollection<TEntity, TId>(this DbContext context, ICollection<TEntity> dbItems, ICollection<TEntity> items)
        where TEntity : class, IEntity<TId>
    {
        var dbItemsMap = dbItems.ToDictionary(e => e.Id);
        foreach (var item in items)
        {
            if (!dbItemsMap.TryGetValue(item.Id, out var oldItem))
                dbItems.Add(item);
            else
            {
                context.Entry(oldItem).CurrentValues.SetValues(item);
                // Matched items are taken out of the map; what remains afterwards was deleted.
                dbItemsMap.Remove(item.Id);
            }
        }
        foreach (var oldItem in dbItemsMap.Values)
            dbItems.Remove(oldItem);
    }
}
然后在 `Customer` 存储库中这样调用:
// UpdateEntityAsync returns a non-generic Task (nothing to return), and TId cannot be inferred from
// the arguments, so both type arguments are given explicitly.
// NOTE(review): assumes Customer implements IEntity<int> - TODO confirm the actual Id type.
await _context.UpdateEntityAsync<Customer, int>(entity, e => e.Addresses);
对于通用存储库(没有导航参数)且所有子集合实体都实现了该接口的情况,只需改为遍历 `dbEntry.Collections` 属性,例如:
// Generic-repository variant: instead of the caller-supplied navigation expressions,
// iterate every collection navigation entry of the tracked entity.
//foreach (var property in navigations)
foreach (var dbItemsEntry in dbEntry.Collections)
{
// The entry comes directly from dbEntry.Collections, so the name-based lookup is no longer needed.
//var propertyName = property.GetPropertyAccess().Name;
//var dbItemsEntry = dbEntry.Collection(propertyName);
// NOTE(review): CurrentValue is read here without loading the collection; unless it was loaded
// beforehand (e.g. LoadAsync or lazy loading) the db collection may be empty - confirm.
var dbItems = dbItemsEntry.CurrentValue;
var items = dbItemsEntry.Metadata.GetGetter().GetClrValue(entity);
// ... (remainder of the loop body as in UpdateEntityAsync: resolve item/Id types, invoke UpdateCollection)
}