Separate guest/host tracking + unaligned protection (#6486)

* WIP: Separate guest/host tracking + unaligned protection

Allow memory manager to define support for single byte guest tracking

* Formatting

* Improve docs

* Properly handle cases where the address space bits are too low

* Address feedback
This commit is contained in:
riperiperi
2024-03-14 22:38:27 +00:00
committed by GitHub
parent ce607db944
commit fdd3263e31
18 changed files with 774 additions and 763 deletions

View File

@ -33,6 +33,8 @@ namespace Ryujinx.Cpu.Jit
private readonly MemoryBlock _pageTable;
private readonly ManagedPageFlags _pages;
/// <summary>
/// Page table base pointer.
/// </summary>
@ -70,6 +72,8 @@ namespace Ryujinx.Cpu.Jit
AddressSpaceSize = asSize;
_pageTable = new MemoryBlock((asSize / PageSize) * PteSize);
_pages = new ManagedPageFlags(AddressSpaceBits);
Tracking = new MemoryTracking(this, PageSize);
}
@ -89,6 +93,7 @@ namespace Ryujinx.Cpu.Jit
remainingSize -= PageSize;
}
_pages.AddMapping(oVa, size);
Tracking.Map(oVa, size);
}
@ -111,6 +116,7 @@ namespace Ryujinx.Cpu.Jit
UnmapEvent?.Invoke(va, size);
Tracking.Unmap(va, size);
_pages.RemoveMapping(va, size);
ulong remainingSize = size;
while (remainingSize != 0)
@ -148,6 +154,26 @@ namespace Ryujinx.Cpu.Jit
}
}
/// <inheritdoc/>
public T ReadGuest<T>(ulong va) where T : unmanaged
{
try
{
SignalMemoryTrackingImpl(va, (ulong)Unsafe.SizeOf<T>(), false, true);
return Read<T>(va);
}
catch (InvalidMemoryRegionException)
{
if (_invalidAccessHandler == null || !_invalidAccessHandler(va))
{
throw;
}
return default;
}
}
/// <inheritdoc/>
public override void Read(ulong va, Span<byte> data)
{
@ -183,6 +209,16 @@ namespace Ryujinx.Cpu.Jit
WriteImpl(va, data);
}
/// <inheritdoc/>
public void WriteGuest<T>(ulong va, T value) where T : unmanaged
{
Span<byte> data = MemoryMarshal.Cast<T, byte>(MemoryMarshal.CreateSpan(ref value, 1));
SignalMemoryTrackingImpl(va, (ulong)data.Length, true, true);
WriteImpl(va, data);
}
/// <inheritdoc/>
public void WriteUntracked(ulong va, ReadOnlySpan<byte> data)
{
@ -520,50 +556,57 @@ namespace Ryujinx.Cpu.Jit
}
/// <inheritdoc/>
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection)
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest)
{
AssertValidAddressAndSize(va, size);
// Protection is inverted on software pages, since the default value is 0.
protection = (~protection) & MemoryPermission.ReadAndWrite;
long tag = protection switch
if (guest)
{
MemoryPermission.None => 0L,
MemoryPermission.Write => 2L << PointerTagBit,
_ => 3L << PointerTagBit,
};
// Protection is inverted on software pages, since the default value is 0.
protection = (~protection) & MemoryPermission.ReadAndWrite;
int pages = GetPagesCount(va, (uint)size, out va);
ulong pageStart = va >> PageBits;
long invTagMask = ~(0xffffL << 48);
for (int page = 0; page < pages; page++)
{
ref long pageRef = ref _pageTable.GetRef<long>(pageStart * PteSize);
long pte;
do
long tag = protection switch
{
pte = Volatile.Read(ref pageRef);
}
while (pte != 0 && Interlocked.CompareExchange(ref pageRef, (pte & invTagMask) | tag, pte) != pte);
MemoryPermission.None => 0L,
MemoryPermission.Write => 2L << PointerTagBit,
_ => 3L << PointerTagBit,
};
pageStart++;
int pages = GetPagesCount(va, (uint)size, out va);
ulong pageStart = va >> PageBits;
long invTagMask = ~(0xffffL << 48);
for (int page = 0; page < pages; page++)
{
ref long pageRef = ref _pageTable.GetRef<long>(pageStart * PteSize);
long pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while (pte != 0 && Interlocked.CompareExchange(ref pageRef, (pte & invTagMask) | tag, pte) != pte);
pageStart++;
}
}
else
{
_pages.TrackingReprotect(va, size, protection);
}
}
/// <inheritdoc/>
public RegionHandle BeginTracking(ulong address, ulong size, int id)
public RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None)
{
return Tracking.BeginTracking(address, size, id);
return Tracking.BeginTracking(address, size, id, flags);
}
/// <inheritdoc/>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id)
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None)
{
return Tracking.BeginGranularTracking(address, size, handles, granularity, id);
return Tracking.BeginGranularTracking(address, size, handles, granularity, id, flags);
}
/// <inheritdoc/>
@ -572,8 +615,7 @@ namespace Ryujinx.Cpu.Jit
return Tracking.BeginSmartGranularTracking(address, size, granularity, id);
}
/// <inheritdoc/>
public void SignalMemoryTracking(ulong va, ulong size, bool write, bool precise = false, int? exemptId = null)
private void SignalMemoryTrackingImpl(ulong va, ulong size, bool write, bool guest, bool precise = false, int? exemptId = null)
{
AssertValidAddressAndSize(va, size);
@ -583,31 +625,47 @@ namespace Ryujinx.Cpu.Jit
return;
}
// We emulate guard pages for software memory access. This makes for an easy transition to
// tracking using host guard pages in future, but also supporting platforms where this is not possible.
// If the memory tracking is coming from the guest, use the tag bits in the page table entry.
// Otherwise, use the managed page flags.
// Write tag includes read protection, since we don't have any read actions that aren't performed before write too.
long tag = (write ? 3L : 1L) << PointerTagBit;
int pages = GetPagesCount(va, (uint)size, out _);
ulong pageStart = va >> PageBits;
for (int page = 0; page < pages; page++)
if (guest)
{
ref long pageRef = ref _pageTable.GetRef<long>(pageStart * PteSize);
// We emulate guard pages for software memory access. This makes for an easy transition to
// tracking using host guard pages in future, but also supporting platforms where this is not possible.
long pte;
// Write tag includes read protection, since we don't have any read actions that aren't performed before write too.
long tag = (write ? 3L : 1L) << PointerTagBit;
pte = Volatile.Read(ref pageRef);
int pages = GetPagesCount(va, (uint)size, out _);
ulong pageStart = va >> PageBits;
if ((pte & tag) != 0)
for (int page = 0; page < pages; page++)
{
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
break;
}
ref long pageRef = ref _pageTable.GetRef<long>(pageStart * PteSize);
pageStart++;
long pte;
pte = Volatile.Read(ref pageRef);
if ((pte & tag) != 0)
{
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId, true);
break;
}
pageStart++;
}
}
else
{
_pages.SignalMemoryTracking(Tracking, va, size, write, exemptId);
}
}
/// <inheritdoc/>
public void SignalMemoryTracking(ulong va, ulong size, bool write, bool precise = false, int? exemptId = null)
{
SignalMemoryTrackingImpl(va, size, write, false, precise, exemptId);
}
private ulong PaToPte(ulong pa)