IPC refactor part 3+4: New server HIPC message processor (#4188)

* IPC refactor part 3 + 4: New server HIPC message processor with source generator based serialization

* Make types match on calls to AlignUp/AlignDown

* Formatting

* Address some PR feedback

* Move BitfieldExtensions to Ryujinx.Common.Utilities and consolidate implementations

* Rename Reader/Writer to SpanReader/SpanWriter and move to Ryujinx.Common.Memory

* Implement EventType

* Address more PR feedback

* Log request processing errors since they are not normal

* Rename waitable to multiwait and add missing lock

* PR feedback

* Ac_K PR feedback
This commit is contained in:
gdkchan
2023-01-04 19:15:45 -03:00
committed by GitHub
parent c6a139a6e7
commit 08831eecf7
213 changed files with 9762 additions and 1010 deletions

View File

@ -24,10 +24,10 @@ namespace Ryujinx.Horizon.Generators
IncreaseIndentation();
}
public void LeaveScope()
public void LeaveScope(string suffix = "")
{
DecreaseIndentation();
AppendLine("}");
AppendLine($"}}{suffix}");
}
public void IncreaseIndentation()

View File

@ -0,0 +1,18 @@
namespace Ryujinx.Horizon.Generators.Hipc
{
enum CommandArgType : byte
{
Invalid,
Buffer,
InArgument,
InCopyHandle,
InMoveHandle,
InObject,
OutArgument,
OutCopyHandle,
OutMoveHandle,
OutObject,
ProcessId
}
}

View File

@ -0,0 +1,17 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
namespace Ryujinx.Horizon.Generators.Hipc
{
class CommandInterface
{
public ClassDeclarationSyntax ClassDeclarationSyntax { get; }
public List<MethodDeclarationSyntax> CommandImplementations { get; }
public CommandInterface(ClassDeclarationSyntax classDeclarationSyntax)
{
ClassDeclarationSyntax = classDeclarationSyntax;
CommandImplementations = new List<MethodDeclarationSyntax>();
}
}
}

View File

@ -0,0 +1,749 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Linq;
namespace Ryujinx.Horizon.Generators.Hipc
{
[Generator]
class HipcGenerator : ISourceGenerator
{
private const string ArgVariablePrefix = "arg";
private const string ResultVariableName = "result";
private const string IsBufferMapAliasVariableName = "isBufferMapAlias";
private const string InObjectsVariableName = "inObjects";
private const string OutObjectsVariableName = "outObjects";
private const string ResponseVariableName = "response";
private const string OutRawDataVariableName = "outRawData";
private const string TypeSystemReadOnlySpan = "System.ReadOnlySpan";
private const string TypeSystemSpan = "System.Span";
private const string TypeStructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute";
public const string CommandAttributeName = "CmifCommandAttribute";
private const string TypeResult = "Ryujinx.Horizon.Common.Result";
private const string TypeBufferAttribute = "Ryujinx.Horizon.Sdk.Sf.BufferAttribute";
private const string TypeCopyHandleAttribute = "Ryujinx.Horizon.Sdk.Sf.CopyHandleAttribute";
private const string TypeMoveHandleAttribute = "Ryujinx.Horizon.Sdk.Sf.MoveHandleAttribute";
private const string TypeClientProcessIdAttribute = "Ryujinx.Horizon.Sdk.Sf.ClientProcessIdAttribute";
private const string TypeCommandAttribute = "Ryujinx.Horizon.Sdk.Sf." + CommandAttributeName;
private const string TypeIServiceObject = "Ryujinx.Horizon.Sdk.Sf.IServiceObject";
private enum Modifier
{
None,
Ref,
Out,
In
}
private struct OutParameter
{
public readonly string Name;
public readonly string TypeName;
public readonly int Index;
public readonly CommandArgType Type;
public OutParameter(string name, string typeName, int index, CommandArgType type)
{
Name = name;
TypeName = typeName;
Index = index;
Type = type;
}
}
public void Execute(GeneratorExecutionContext context)
{
HipcSyntaxReceiver syntaxReceiver = (HipcSyntaxReceiver)context.SyntaxReceiver;
foreach (var commandInterface in syntaxReceiver.CommandInterfaces)
{
if (!NeedsIServiceObjectImplementation(context.Compilation, commandInterface.ClassDeclarationSyntax))
{
continue;
}
CodeGenerator generator = new CodeGenerator();
string className = commandInterface.ClassDeclarationSyntax.Identifier.ToString();
generator.AppendLine("using Ryujinx.Horizon.Common;");
generator.AppendLine("using Ryujinx.Horizon.Sdk.Sf;");
generator.AppendLine("using Ryujinx.Horizon.Sdk.Sf.Cmif;");
generator.AppendLine("using Ryujinx.Horizon.Sdk.Sf.Hipc;");
generator.AppendLine("using System;");
generator.AppendLine("using System.Collections.Generic;");
generator.AppendLine("using System.Runtime.CompilerServices;");
generator.AppendLine("using System.Runtime.InteropServices;");
generator.AppendLine();
generator.EnterScope($"namespace {GetNamespaceName(commandInterface.ClassDeclarationSyntax)}");
generator.EnterScope($"partial class {className}");
GenerateMethodTable(generator, context.Compilation, commandInterface);
foreach (var method in commandInterface.CommandImplementations)
{
generator.AppendLine();
GenerateMethod(generator, context.Compilation, method);
}
generator.LeaveScope();
generator.LeaveScope();
context.AddSource($"{className}.g.cs", generator.ToString());
}
}
private static string GetNamespaceName(SyntaxNode syntaxNode)
{
while (syntaxNode != null && !(syntaxNode is NamespaceDeclarationSyntax))
{
syntaxNode = syntaxNode.Parent;
}
if (syntaxNode == null)
{
return string.Empty;
}
return ((NamespaceDeclarationSyntax)syntaxNode).Name.ToString();
}
private static void GenerateMethodTable(CodeGenerator generator, Compilation compilation, CommandInterface commandInterface)
{
generator.EnterScope($"public IReadOnlyDictionary<int, CommandHandler> GetCommandHandlers()");
generator.EnterScope($"return new Dictionary<int, CommandHandler>()");
foreach (var method in commandInterface.CommandImplementations)
{
foreach (var commandId in GetAttributeAguments(compilation, method, TypeCommandAttribute, 0))
{
string[] args = new string[method.ParameterList.Parameters.Count];
int index = 0;
foreach (var parameter in method.ParameterList.Parameters)
{
string canonicalTypeName = GetCanonicalTypeNameWithGenericArguments(compilation, parameter.Type);
CommandArgType argType = GetCommandArgType(compilation, parameter);
string arg;
if (argType == CommandArgType.Buffer)
{
string bufferFlags = GetFirstAttributeAgument(compilation, parameter, TypeBufferAttribute, 0);
string bufferFixedSize = GetFirstAttributeAgument(compilation, parameter, TypeBufferAttribute, 1);
if (bufferFixedSize != null)
{
arg = $"new CommandArg({bufferFlags}, {bufferFixedSize})";
}
else
{
arg = $"new CommandArg({bufferFlags})";
}
}
else if (argType == CommandArgType.InArgument || argType == CommandArgType.OutArgument)
{
string alignment = GetTypeAlignmentExpression(compilation, parameter.Type);
arg = $"new CommandArg(CommandArgType.{argType}, Unsafe.SizeOf<{canonicalTypeName}>(), {alignment})";
}
else
{
arg = $"new CommandArg(CommandArgType.{argType})";
}
args[index++] = arg;
}
generator.AppendLine($"{{ {commandId}, new CommandHandler({method.Identifier.Text}, {string.Join(", ", args)}) }},");
}
}
generator.LeaveScope(";");
generator.LeaveScope();
}
private static IEnumerable<string> GetAttributeAguments(Compilation compilation, SyntaxNode syntaxNode, string attributeName, int argIndex)
{
ISymbol symbol = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetDeclaredSymbol(syntaxNode);
foreach (var attribute in symbol.GetAttributes())
{
if (attribute.AttributeClass.ToDisplayString() == attributeName && (uint)argIndex < (uint)attribute.ConstructorArguments.Length)
{
yield return attribute.ConstructorArguments[argIndex].ToCSharpString();
}
}
}
private static string GetFirstAttributeAgument(Compilation compilation, SyntaxNode syntaxNode, string attributeName, int argIndex)
{
return GetAttributeAguments(compilation, syntaxNode, attributeName, argIndex).FirstOrDefault();
}
private static void GenerateMethod(CodeGenerator generator, Compilation compilation, MethodDeclarationSyntax method)
{
int inObjectsCount = 0;
int outObjectsCount = 0;
int buffersCount = 0;
foreach (var parameter in method.ParameterList.Parameters)
{
if (IsObject(compilation, parameter))
{
if (IsIn(parameter))
{
inObjectsCount++;
}
else
{
outObjectsCount++;
}
}
else if (IsBuffer(compilation, parameter))
{
buffersCount++;
}
}
generator.EnterScope($"private Result {method.Identifier.Text}(" +
"ref ServiceDispatchContext context, " +
"HipcCommandProcessor processor, " +
"ServerMessageRuntimeMetadata runtimeMetadata, " +
"ReadOnlySpan<byte> inRawData, " +
"ref Span<CmifOutHeader> outHeader)");
bool returnsResult = method.ReturnType != null && GetCanonicalTypeName(compilation, method.ReturnType) == TypeResult;
if (returnsResult || buffersCount != 0 || inObjectsCount != 0)
{
generator.AppendLine($"Result {ResultVariableName};");
if (buffersCount != 0)
{
generator.AppendLine($"bool[] {IsBufferMapAliasVariableName} = new bool[{method.ParameterList.Parameters.Count}];");
generator.AppendLine();
generator.AppendLine($"{ResultVariableName} = processor.ProcessBuffers(ref context, {IsBufferMapAliasVariableName}, runtimeMetadata);");
generator.EnterScope($"if ({ResultVariableName}.IsFailure)");
generator.AppendLine($"return {ResultVariableName};");
generator.LeaveScope();
}
generator.AppendLine();
}
List<OutParameter> outParameters = new List<OutParameter>();
string[] args = new string[method.ParameterList.Parameters.Count];
if (inObjectsCount != 0)
{
generator.AppendLine($"var {InObjectsVariableName} = new IServiceObject[{inObjectsCount}];");
generator.AppendLine();
generator.AppendLine($"{ResultVariableName} = processor.GetInObjects(context.Processor, {InObjectsVariableName});");
generator.EnterScope($"if ({ResultVariableName}.IsFailure)");
generator.AppendLine($"return {ResultVariableName};");
generator.LeaveScope();
generator.AppendLine();
}
if (outObjectsCount != 0)
{
generator.AppendLine($"var {OutObjectsVariableName} = new IServiceObject[{outObjectsCount}];");
}
int index = 0;
int inCopyHandleIndex = 0;
int inMoveHandleIndex = 0;
int inObjectIndex = 0;
foreach (var parameter in method.ParameterList.Parameters)
{
string name = parameter.Identifier.Text;
string argName = GetPrefixedArgName(name);
string canonicalTypeName = GetCanonicalTypeNameWithGenericArguments(compilation, parameter.Type);
CommandArgType argType = GetCommandArgType(compilation, parameter);
Modifier modifier = GetModifier(parameter);
bool isNonSpanBuffer = false;
if (modifier == Modifier.Out)
{
if (IsNonSpanOutBuffer(compilation, parameter))
{
generator.AppendLine($"using var {argName} = CommandSerialization.GetWritableRegion(processor.GetBufferRange({index}));");
argName = $"out {GenerateSpanCastElement0(canonicalTypeName, $"{argName}.Memory.Span")}";
}
else
{
outParameters.Add(new OutParameter(argName, canonicalTypeName, index, argType));
argName = $"out {canonicalTypeName} {argName}";
}
}
else
{
string value = $"default({canonicalTypeName})";
switch (argType)
{
case CommandArgType.InArgument:
value = $"CommandSerialization.DeserializeArg<{canonicalTypeName}>(inRawData, processor.GetInArgOffset({index}))";
break;
case CommandArgType.InCopyHandle:
value = $"CommandSerialization.DeserializeCopyHandle(ref context, {inCopyHandleIndex++})";
break;
case CommandArgType.InMoveHandle:
value = $"CommandSerialization.DeserializeMoveHandle(ref context, {inMoveHandleIndex++})";
break;
case CommandArgType.ProcessId:
value = "CommandSerialization.DeserializeClientProcessId(ref context)";
break;
case CommandArgType.InObject:
value = $"{InObjectsVariableName}[{inObjectIndex++}]";
break;
case CommandArgType.Buffer:
if (IsReadOnlySpan(compilation, parameter))
{
string spanGenericTypeName = GetCanonicalTypeNameOfGenericArgument(compilation, parameter.Type, 0);
value = GenerateSpanCast(spanGenericTypeName, $"CommandSerialization.GetReadOnlySpan(processor.GetBufferRange({index}))");
}
else if (IsSpan(compilation, parameter))
{
value = $"CommandSerialization.GetWritableRegion(processor.GetBufferRange({index}))";
}
else
{
value = $"CommandSerialization.GetRef<{canonicalTypeName}>(processor.GetBufferRange({index}))";
isNonSpanBuffer = true;
}
break;
}
if (IsSpan(compilation, parameter))
{
generator.AppendLine($"using var {argName} = {value};");
string spanGenericTypeName = GetCanonicalTypeNameOfGenericArgument(compilation, parameter.Type, 0);
argName = GenerateSpanCast(spanGenericTypeName, $"{argName}.Memory.Span"); ;
}
else if (isNonSpanBuffer)
{
generator.AppendLine($"ref var {argName} = ref {value};");
}
else if (argType == CommandArgType.InObject)
{
generator.EnterScope($"if (!({value} is {canonicalTypeName} {argName}))");
generator.AppendLine("return SfResult.InvalidInObject;");
generator.LeaveScope();
}
else
{
generator.AppendLine($"var {argName} = {value};");
}
}
if (modifier == Modifier.Ref)
{
argName = $"ref {argName}";
}
else if (modifier == Modifier.In)
{
argName = $"in {argName}";
}
args[index++] = argName;
}
if (args.Length - outParameters.Count > 0)
{
generator.AppendLine();
}
if (returnsResult)
{
generator.AppendLine($"{ResultVariableName} = {method.Identifier.Text}({string.Join(", ", args)});");
generator.AppendLine();
generator.AppendLine($"Span<byte> {OutRawDataVariableName};");
generator.AppendLine();
generator.EnterScope($"if ({ResultVariableName}.IsFailure)");
generator.AppendLine($"context.Processor.PrepareForErrorReply(ref context, out {OutRawDataVariableName}, runtimeMetadata);");
generator.AppendLine($"CommandHandler.GetCmifOutHeaderPointer(ref outHeader, ref {OutRawDataVariableName});");
generator.AppendLine($"return {ResultVariableName};");
generator.LeaveScope();
}
else
{
generator.AppendLine($"{method.Identifier.Text}({string.Join(", ", args)});");
generator.AppendLine();
generator.AppendLine($"Span<byte> {OutRawDataVariableName};");
}
generator.AppendLine();
generator.AppendLine($"var {ResponseVariableName} = context.Processor.PrepareForReply(ref context, out {OutRawDataVariableName}, runtimeMetadata);");
generator.AppendLine($"CommandHandler.GetCmifOutHeaderPointer(ref outHeader, ref {OutRawDataVariableName});");
generator.AppendLine();
generator.EnterScope($"if ({OutRawDataVariableName}.Length < processor.OutRawDataSize)");
generator.AppendLine("return SfResult.InvalidOutRawSize;");
generator.LeaveScope();
if (outParameters.Count != 0)
{
generator.AppendLine();
int outCopyHandleIndex = 0;
int outMoveHandleIndex = outObjectsCount;
int outObjectIndex = 0;
for (int outIndex = 0; outIndex < outParameters.Count; outIndex++)
{
OutParameter outParameter = outParameters[outIndex];
switch (outParameter.Type)
{
case CommandArgType.OutArgument:
generator.AppendLine($"CommandSerialization.SerializeArg<{outParameter.TypeName}>({OutRawDataVariableName}, processor.GetOutArgOffset({outParameter.Index}), {outParameter.Name});");
break;
case CommandArgType.OutCopyHandle:
generator.AppendLine($"CommandSerialization.SerializeCopyHandle({ResponseVariableName}, {outCopyHandleIndex++}, {outParameter.Name});");
break;
case CommandArgType.OutMoveHandle:
generator.AppendLine($"CommandSerialization.SerializeMoveHandle({ResponseVariableName}, {outMoveHandleIndex++}, {outParameter.Name});");
break;
case CommandArgType.OutObject:
generator.AppendLine($"{OutObjectsVariableName}[{outObjectIndex++}] = {outParameter.Name};");
break;
}
}
}
generator.AppendLine();
if (outObjectsCount != 0 || buffersCount != 0)
{
if (outObjectsCount != 0)
{
generator.AppendLine($"processor.SetOutObjects(ref context, {ResponseVariableName}, {OutObjectsVariableName});");
}
if (buffersCount != 0)
{
generator.AppendLine($"processor.SetOutBuffers({ResponseVariableName}, {IsBufferMapAliasVariableName});");
}
generator.AppendLine();
}
generator.AppendLine("return Result.Success;");
generator.LeaveScope();
}
private static string GetPrefixedArgName(string name)
{
return ArgVariablePrefix + name[0].ToString().ToUpperInvariant() + name.Substring(1);
}
private static string GetCanonicalTypeNameOfGenericArgument(Compilation compilation, SyntaxNode syntaxNode, int argIndex)
{
if (syntaxNode is GenericNameSyntax genericNameSyntax)
{
if ((uint)argIndex < (uint)genericNameSyntax.TypeArgumentList.Arguments.Count)
{
return GetCanonicalTypeNameWithGenericArguments(compilation, genericNameSyntax.TypeArgumentList.Arguments[argIndex]);
}
}
return GetCanonicalTypeName(compilation, syntaxNode);
}
private static string GetCanonicalTypeNameWithGenericArguments(Compilation compilation, SyntaxNode syntaxNode)
{
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
return typeInfo.Type.ToDisplayString();
}
private static string GetCanonicalTypeName(Compilation compilation, SyntaxNode syntaxNode)
{
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
string typeName = typeInfo.Type.ToDisplayString();
int genericArgsStartIndex = typeName.IndexOf('<');
if (genericArgsStartIndex >= 0)
{
return typeName.Substring(0, genericArgsStartIndex);
}
return typeName;
}
private static SpecialType GetSpecialTypeName(Compilation compilation, SyntaxNode syntaxNode)
{
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
return typeInfo.Type.SpecialType;
}
private static string GetTypeAlignmentExpression(Compilation compilation, SyntaxNode syntaxNode)
{
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
// Since there's no way to get the alignment for a arbitrary type here, let's assume that all
// "special" types are primitive types aligned to their own length.
// Otherwise, assume that the type is a custom struct, that either defines an explicit alignment
// or has an alignment of 1 which is the lowest possible value.
if (typeInfo.Type.SpecialType == SpecialType.None)
{
string pack = GetTypeFirstNamedAttributeAgument(compilation, syntaxNode, TypeStructLayoutAttribute, "Pack");
return pack ?? "1";
}
else
{
return $"Unsafe.SizeOf<{typeInfo.Type.ToDisplayString()}>()";
}
}
private static string GetTypeFirstNamedAttributeAgument(Compilation compilation, SyntaxNode syntaxNode, string attributeName, string argName)
{
ISymbol symbol = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode).Type;
foreach (var attribute in symbol.GetAttributes())
{
if (attribute.AttributeClass.ToDisplayString() == attributeName)
{
foreach (var kv in attribute.NamedArguments)
{
if (kv.Key == argName)
{
return kv.Value.ToCSharpString();
}
}
}
}
return null;
}
private static CommandArgType GetCommandArgType(Compilation compilation, ParameterSyntax parameter)
{
CommandArgType type = CommandArgType.Invalid;
if (IsIn(parameter))
{
if (IsArgument(compilation, parameter))
{
type = CommandArgType.InArgument;
}
else if (IsBuffer(compilation, parameter))
{
type = CommandArgType.Buffer;
}
else if (IsCopyHandle(compilation, parameter))
{
type = CommandArgType.InCopyHandle;
}
else if (IsMoveHandle(compilation, parameter))
{
type = CommandArgType.InMoveHandle;
}
else if (IsObject(compilation, parameter))
{
type = CommandArgType.InObject;
}
else if (IsProcessId(compilation, parameter))
{
type = CommandArgType.ProcessId;
}
}
else if (IsOut(parameter))
{
if (IsArgument(compilation, parameter))
{
type = CommandArgType.OutArgument;
}
else if (IsNonSpanOutBuffer(compilation, parameter))
{
type = CommandArgType.Buffer;
}
else if (IsCopyHandle(compilation, parameter))
{
type = CommandArgType.OutCopyHandle;
}
else if (IsMoveHandle(compilation, parameter))
{
type = CommandArgType.OutMoveHandle;
}
else if (IsObject(compilation, parameter))
{
type = CommandArgType.OutObject;
}
}
return type;
}
private static bool IsArgument(Compilation compilation,ParameterSyntax parameter)
{
return !IsBuffer(compilation, parameter) &&
!IsHandle(compilation, parameter) &&
!IsObject(compilation, parameter) &&
!IsProcessId(compilation, parameter) &&
IsUnmanagedType(compilation, parameter.Type);
}
private static bool IsBuffer(Compilation compilation, ParameterSyntax parameter)
{
return HasAttribute(compilation, parameter, TypeBufferAttribute) &&
IsValidTypeForBuffer(compilation, parameter);
}
private static bool IsNonSpanOutBuffer(Compilation compilation, ParameterSyntax parameter)
{
return HasAttribute(compilation, parameter, TypeBufferAttribute) &&
IsUnmanagedType(compilation, parameter.Type);
}
private static bool IsValidTypeForBuffer(Compilation compilation, ParameterSyntax parameter)
{
return IsReadOnlySpan(compilation, parameter) ||
IsSpan(compilation, parameter) ||
IsUnmanagedType(compilation, parameter.Type);
}
private static bool IsUnmanagedType(Compilation compilation, SyntaxNode syntaxNode)
{
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
return typeInfo.Type.IsUnmanagedType;
}
private static bool IsReadOnlySpan(Compilation compilation, ParameterSyntax parameter)
{
return GetCanonicalTypeName(compilation, parameter.Type) == TypeSystemReadOnlySpan;
}
private static bool IsSpan(Compilation compilation, ParameterSyntax parameter)
{
return GetCanonicalTypeName(compilation, parameter.Type) == TypeSystemSpan;
}
private static bool IsHandle(Compilation compilation, ParameterSyntax parameter)
{
return IsCopyHandle(compilation, parameter) || IsMoveHandle(compilation, parameter);
}
private static bool IsCopyHandle(Compilation compilation, ParameterSyntax parameter)
{
return HasAttribute(compilation, parameter, TypeCopyHandleAttribute) &&
GetSpecialTypeName(compilation, parameter.Type) == SpecialType.System_Int32;
}
private static bool IsMoveHandle(Compilation compilation, ParameterSyntax parameter)
{
return HasAttribute(compilation, parameter, TypeMoveHandleAttribute) &&
GetSpecialTypeName(compilation, parameter.Type) == SpecialType.System_Int32;
}
private static bool IsObject(Compilation compilation, ParameterSyntax parameter)
{
SyntaxNode syntaxNode = parameter.Type;
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
return typeInfo.Type.ToDisplayString() == TypeIServiceObject ||
typeInfo.Type.AllInterfaces.Any(x => x.ToDisplayString() == TypeIServiceObject);
}
private static bool IsProcessId(Compilation compilation, ParameterSyntax parameter)
{
return HasAttribute(compilation, parameter, TypeClientProcessIdAttribute) &&
GetSpecialTypeName(compilation, parameter.Type) == SpecialType.System_UInt64;
}
private static bool IsIn(ParameterSyntax parameter)
{
return !IsOut(parameter);
}
private static bool IsOut(ParameterSyntax parameter)
{
return parameter.Modifiers.Any(SyntaxKind.OutKeyword);
}
private static Modifier GetModifier(ParameterSyntax parameter)
{
foreach (SyntaxToken syntaxToken in parameter.Modifiers)
{
if (syntaxToken.IsKind(SyntaxKind.RefKeyword))
{
return Modifier.Ref;
}
else if (syntaxToken.IsKind(SyntaxKind.OutKeyword))
{
return Modifier.Out;
}
else if (syntaxToken.IsKind(SyntaxKind.InKeyword))
{
return Modifier.In;
}
}
return Modifier.None;
}
private static string GenerateSpanCastElement0(string targetType, string input)
{
return $"{GenerateSpanCast(targetType, input)}[0]";
}
private static string GenerateSpanCast(string targetType, string input)
{
return $"MemoryMarshal.Cast<byte, {targetType}>({input})";
}
private static bool HasAttribute(Compilation compilation, ParameterSyntax parameterSyntax, string fullAttributeName)
{
foreach (var attributeList in parameterSyntax.AttributeLists)
{
foreach (var attribute in attributeList.Attributes)
{
if (GetCanonicalTypeName(compilation, attribute) == fullAttributeName)
{
return true;
}
}
}
return false;
}
private static bool NeedsIServiceObjectImplementation(Compilation compilation, ClassDeclarationSyntax classDeclarationSyntax)
{
ITypeSymbol type = compilation.GetSemanticModel(classDeclarationSyntax.SyntaxTree).GetDeclaredSymbol(classDeclarationSyntax);
var serviceObjectInterface = type.AllInterfaces.FirstOrDefault(x => x.ToDisplayString() == TypeIServiceObject);
var interfaceMember = serviceObjectInterface?.GetMembers().FirstOrDefault(x => x.Name == "GetCommandHandlers");
// Return true only if the class implements IServiceObject but does not actually implement the method
// that the interface defines, since this is the only case we want to handle, if the method already exists
// we have nothing to do.
return serviceObjectInterface != null && type.FindImplementationForInterfaceMember(interfaceMember) == null;
}
public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForSyntaxNotifications(() => new HipcSyntaxReceiver());
}
}
}

View File

@ -0,0 +1,58 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Linq;
namespace Ryujinx.Horizon.Generators.Hipc
{
class HipcSyntaxReceiver : ISyntaxReceiver
{
public List<CommandInterface> CommandInterfaces { get; }
public HipcSyntaxReceiver()
{
CommandInterfaces = new List<CommandInterface>();
}
public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
if (syntaxNode is ClassDeclarationSyntax classDeclaration)
{
if (!classDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword) || classDeclaration.BaseList == null)
{
return;
}
CommandInterface commandInterface = new CommandInterface(classDeclaration);
foreach (var memberDeclaration in classDeclaration.Members)
{
if (memberDeclaration is MethodDeclarationSyntax methodDeclaration)
{
VisitMethod(commandInterface, methodDeclaration);
}
}
CommandInterfaces.Add(commandInterface);
}
}
private void VisitMethod(CommandInterface commandInterface, MethodDeclarationSyntax methodDeclaration)
{
string attributeName = HipcGenerator.CommandAttributeName.Replace("Attribute", string.Empty);
if (methodDeclaration.AttributeLists.Count != 0)
{
foreach (var attributeList in methodDeclaration.AttributeLists)
{
if (attributeList.Attributes.Any(x => x.Name.ToString().Contains(attributeName)))
{
commandInterface.CommandImplementations.Add(methodDeclaration);
break;
}
}
}
}
}
}

View File

@ -1,524 +0,0 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Diagnostics;
namespace Ryujinx.Horizon.Generators.Kernel
{
[Generator]
class SyscallGenerator : ISourceGenerator
{
private const string ClassNamespace = "Ryujinx.HLE.HOS.Kernel.SupervisorCall";
private const string ClassName = "SyscallDispatch";
private const string A32Suffix = "32";
private const string A64Suffix = "64";
private const string ResultVariableName = "result";
private const string ArgVariablePrefix = "arg";
private const string ResultCheckHelperName = "LogResultAsTrace";
private const string TypeSystemBoolean = "System.Boolean";
private const string TypeSystemInt32 = "System.Int32";
private const string TypeSystemInt64 = "System.Int64";
private const string TypeSystemUInt32 = "System.UInt32";
private const string TypeSystemUInt64 = "System.UInt64";
private const string NamespaceKernel = "Ryujinx.HLE.HOS.Kernel";
private const string TypeSvcAttribute = NamespaceKernel + ".SupervisorCall.SvcAttribute";
private const string TypePointerSizedAttribute = NamespaceKernel + ".SupervisorCall.PointerSizedAttribute";
private const string TypeKernelResultName = "KernelResult";
private const string TypeKernelResult = NamespaceKernel + ".Common." + TypeKernelResultName;
private const string TypeExecutionContext = "IExecutionContext";
private static readonly string[] _expectedResults = new string[]
{
$"{TypeKernelResultName}.Success",
$"{TypeKernelResultName}.TimedOut",
$"{TypeKernelResultName}.Cancelled",
$"{TypeKernelResultName}.PortRemoteClosed",
$"{TypeKernelResultName}.InvalidState"
};
private readonly struct OutParameter
{
public readonly string Identifier;
public readonly bool NeedsSplit;
public OutParameter(string identifier, bool needsSplit = false)
{
Identifier = identifier;
NeedsSplit = needsSplit;
}
}
private struct RegisterAllocatorA32
{
private uint _useSet;
private int _linearIndex;
public int AllocateSingle()
{
return Allocate();
}
public (int, int) AllocatePair()
{
_linearIndex += _linearIndex & 1;
return (Allocate(), Allocate());
}
private int Allocate()
{
int regIndex;
if (_linearIndex < 4)
{
regIndex = _linearIndex++;
}
else
{
regIndex = -1;
for (int i = 0; i < 32; i++)
{
if ((_useSet & (1 << i)) == 0)
{
regIndex = i;
break;
}
}
Debug.Assert(regIndex != -1);
}
_useSet |= 1u << regIndex;
return regIndex;
}
public void AdvanceLinearIndex()
{
_linearIndex++;
}
}
private readonly struct SyscallIdAndName : IComparable<SyscallIdAndName>
{
public readonly int Id;
public readonly string Name;
public SyscallIdAndName(int id, string name)
{
Id = id;
Name = name;
}
public int CompareTo(SyscallIdAndName other)
{
return Id.CompareTo(other.Id);
}
}
public void Execute(GeneratorExecutionContext context)
{
SyscallSyntaxReceiver syntaxReceiver = (SyscallSyntaxReceiver)context.SyntaxReceiver;
CodeGenerator generator = new CodeGenerator();
generator.AppendLine("using Ryujinx.Common.Logging;");
generator.AppendLine("using Ryujinx.Cpu;");
generator.AppendLine($"using {NamespaceKernel}.Common;");
generator.AppendLine($"using {NamespaceKernel}.Memory;");
generator.AppendLine($"using {NamespaceKernel}.Process;");
generator.AppendLine($"using {NamespaceKernel}.Threading;");
generator.AppendLine("using System;");
generator.AppendLine();
generator.EnterScope($"namespace {ClassNamespace}");
generator.EnterScope($"static class {ClassName}");
GenerateResultCheckHelper(generator);
generator.AppendLine();
List<SyscallIdAndName> syscalls = new List<SyscallIdAndName>();
foreach (var method in syntaxReceiver.SvcImplementations)
{
GenerateMethod32(generator, context.Compilation, method);
GenerateMethod64(generator, context.Compilation, method);
foreach (var attributeList in method.AttributeLists)
{
foreach (var attribute in attributeList.Attributes)
{
if (GetCanonicalTypeName(context.Compilation, attribute) != TypeSvcAttribute)
{
continue;
}
foreach (var attributeArg in attribute.ArgumentList.Arguments)
{
if (attributeArg.Expression.Kind() == SyntaxKind.NumericLiteralExpression)
{
LiteralExpressionSyntax numericLiteral = (LiteralExpressionSyntax)attributeArg.Expression;
syscalls.Add(new SyscallIdAndName((int)numericLiteral.Token.Value, method.Identifier.Text));
}
}
}
}
}
syscalls.Sort();
GenerateDispatch(generator, syscalls, A32Suffix);
generator.AppendLine();
GenerateDispatch(generator, syscalls, A64Suffix);
generator.LeaveScope();
generator.LeaveScope();
context.AddSource($"{ClassName}.g.cs", generator.ToString());
}
private static void GenerateResultCheckHelper(CodeGenerator generator)
{
generator.EnterScope($"private static bool {ResultCheckHelperName}({TypeKernelResultName} {ResultVariableName})");
string[] expectedChecks = new string[_expectedResults.Length];
for (int i = 0; i < expectedChecks.Length; i++)
{
expectedChecks[i] = $"{ResultVariableName} == {_expectedResults[i]}";
}
string checks = string.Join(" || ", expectedChecks);
generator.AppendLine($"return {checks};");
generator.LeaveScope();
}
private static void GenerateMethod32(CodeGenerator generator, Compilation compilation, MethodDeclarationSyntax method)
{
generator.EnterScope($"private static void {method.Identifier.Text}{A32Suffix}(Syscall syscall, {TypeExecutionContext} context)");
string[] args = new string[method.ParameterList.Parameters.Count];
int index = 0;
RegisterAllocatorA32 regAlloc = new RegisterAllocatorA32();
List<OutParameter> outParameters = new List<OutParameter>();
List<string> logInArgs = new List<string>();
List<string> logOutArgs = new List<string>();
foreach (var methodParameter in method.ParameterList.Parameters)
{
string name = methodParameter.Identifier.Text;
string argName = GetPrefixedArgName(name);
string typeName = methodParameter.Type.ToString();
string canonicalTypeName = GetCanonicalTypeName(compilation, methodParameter.Type);
if (methodParameter.Modifiers.Any(SyntaxKind.OutKeyword))
{
bool needsSplit = Is64BitInteger(canonicalTypeName) && !IsPointerSized(compilation, methodParameter);
outParameters.Add(new OutParameter(argName, needsSplit));
logOutArgs.Add($"{name}: {GetFormattedLogValue(argName, canonicalTypeName)}");
argName = $"out {typeName} {argName}";
regAlloc.AdvanceLinearIndex();
}
else
{
if (Is64BitInteger(canonicalTypeName))
{
if (IsPointerSized(compilation, methodParameter))
{
int registerIndex = regAlloc.AllocateSingle();
generator.AppendLine($"var {argName} = (uint)context.GetX({registerIndex});");
}
else
{
(int registerIndex, int registerIndex2) = regAlloc.AllocatePair();
string valueLow = $"(ulong)(uint)context.GetX({registerIndex})";
string valueHigh = $"(ulong)(uint)context.GetX({registerIndex2})";
string value = $"{valueLow} | ({valueHigh} << 32)";
generator.AppendLine($"var {argName} = ({typeName})({value});");
}
}
else
{
int registerIndex = regAlloc.AllocateSingle();
string value = GenerateCastFromUInt64($"context.GetX({registerIndex})", canonicalTypeName, typeName);
generator.AppendLine($"var {argName} = {value};");
}
logInArgs.Add($"{name}: {GetFormattedLogValue(argName, canonicalTypeName)}");
}
args[index++] = argName;
}
GenerateLogPrintBeforeCall(generator, method.Identifier.Text, logInArgs);
string returnTypeName = method.ReturnType.ToString();
string argsList = string.Join(", ", args);
int returnRegisterIndex = 0;
string result = null;
string canonicalReturnTypeName = null;
if (returnTypeName != "void")
{
generator.AppendLine($"var {ResultVariableName} = syscall.{method.Identifier.Text}({argsList});");
generator.AppendLine($"context.SetX({returnRegisterIndex++}, (uint){ResultVariableName});");
canonicalReturnTypeName = GetCanonicalTypeName(compilation, method.ReturnType);
if (Is64BitInteger(canonicalReturnTypeName))
{
generator.AppendLine($"context.SetX({returnRegisterIndex++}, (uint)({ResultVariableName} >> 32));");
}
result = GetFormattedLogValue(ResultVariableName, canonicalReturnTypeName);
}
else
{
generator.AppendLine($"syscall.{method.Identifier.Text}({argsList});");
}
foreach (OutParameter outParameter in outParameters)
{
generator.AppendLine($"context.SetX({returnRegisterIndex++}, (uint){outParameter.Identifier});");
if (outParameter.NeedsSplit)
{
generator.AppendLine($"context.SetX({returnRegisterIndex++}, (uint)({outParameter.Identifier} >> 32));");
}
}
while (returnRegisterIndex < 4)
{
generator.AppendLine($"context.SetX({returnRegisterIndex++}, 0);");
}
GenerateLogPrintAfterCall(generator, method.Identifier.Text, logOutArgs, result, canonicalReturnTypeName);
generator.LeaveScope();
generator.AppendLine();
}
private static void GenerateMethod64(CodeGenerator generator, Compilation compilation, MethodDeclarationSyntax method)
{
generator.EnterScope($"private static void {method.Identifier.Text}{A64Suffix}(Syscall syscall, {TypeExecutionContext} context)");
string[] args = new string[method.ParameterList.Parameters.Count];
int registerIndex = 0;
int index = 0;
List<OutParameter> outParameters = new List<OutParameter>();
List<string> logInArgs = new List<string>();
List<string> logOutArgs = new List<string>();
foreach (var methodParameter in method.ParameterList.Parameters)
{
string name = methodParameter.Identifier.Text;
string argName = GetPrefixedArgName(name);
string typeName = methodParameter.Type.ToString();
string canonicalTypeName = GetCanonicalTypeName(compilation, methodParameter.Type);
if (methodParameter.Modifiers.Any(SyntaxKind.OutKeyword))
{
outParameters.Add(new OutParameter(argName));
logOutArgs.Add($"{name}: {GetFormattedLogValue(argName, canonicalTypeName)}");
argName = $"out {typeName} {argName}";
registerIndex++;
}
else
{
string value = GenerateCastFromUInt64($"context.GetX({registerIndex++})", canonicalTypeName, typeName);
generator.AppendLine($"var {argName} = {value};");
logInArgs.Add($"{name}: {GetFormattedLogValue(argName, canonicalTypeName)}");
}
args[index++] = argName;
}
GenerateLogPrintBeforeCall(generator, method.Identifier.Text, logInArgs);
string argsList = string.Join(", ", args);
int returnRegisterIndex = 0;
string result = null;
string canonicalReturnTypeName = null;
if (method.ReturnType.ToString() != "void")
{
generator.AppendLine($"var {ResultVariableName} = syscall.{method.Identifier.Text}({argsList});");
generator.AppendLine($"context.SetX({returnRegisterIndex++}, (ulong){ResultVariableName});");
canonicalReturnTypeName = GetCanonicalTypeName(compilation, method.ReturnType);
result = GetFormattedLogValue(ResultVariableName, canonicalReturnTypeName);
}
else
{
generator.AppendLine($"syscall.{method.Identifier.Text}({argsList});");
}
foreach (OutParameter outParameter in outParameters)
{
generator.AppendLine($"context.SetX({returnRegisterIndex++}, (ulong){outParameter.Identifier});");
}
while (returnRegisterIndex < 8)
{
generator.AppendLine($"context.SetX({returnRegisterIndex++}, 0);");
}
GenerateLogPrintAfterCall(generator, method.Identifier.Text, logOutArgs, result, canonicalReturnTypeName);
generator.LeaveScope();
generator.AppendLine();
}
private static string GetFormattedLogValue(string value, string canonicalTypeName)
{
if (Is32BitInteger(canonicalTypeName))
{
return $"0x{{{value}:X8}}";
}
else if (Is64BitInteger(canonicalTypeName))
{
return $"0x{{{value}:X16}}";
}
return $"{{{value}}}";
}
private static string GetPrefixedArgName(string name)
{
return ArgVariablePrefix + name[0].ToString().ToUpperInvariant() + name.Substring(1);
}
private static string GetCanonicalTypeName(Compilation compilation, SyntaxNode syntaxNode)
{
TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode);
if (typeInfo.Type.ContainingNamespace == null)
{
return typeInfo.Type.Name;
}
return $"{typeInfo.Type.ContainingNamespace.ToDisplayString()}.{typeInfo.Type.Name}";
}
private static void GenerateLogPrintBeforeCall(CodeGenerator generator, string methodName, List<string> argList)
{
string log = $"{methodName}({string.Join(", ", argList)})";
GenerateLogPrint(generator, "Trace", "KernelSvc", log);
}
private static void GenerateLogPrintAfterCall(
CodeGenerator generator,
string methodName,
List<string> argList,
string result,
string canonicalResultTypeName)
{
string log = $"{methodName}({string.Join(", ", argList)})";
if (result != null)
{
log += $" = {result}";
}
if (canonicalResultTypeName == TypeKernelResult)
{
generator.EnterScope($"if ({ResultCheckHelperName}({ResultVariableName}))");
GenerateLogPrint(generator, "Trace", "KernelSvc", log);
generator.LeaveScope();
generator.EnterScope("else");
GenerateLogPrint(generator, "Warning", "KernelSvc", log);
generator.LeaveScope();
}
else
{
GenerateLogPrint(generator, "Trace", "KernelSvc", log);
}
}
private static void GenerateLogPrint(CodeGenerator generator, string logLevel, string logClass, string log)
{
generator.AppendLine($"Logger.{logLevel}?.PrintMsg(LogClass.{logClass}, $\"{log}\");");
}
private static void GenerateDispatch(CodeGenerator generator, List<SyscallIdAndName> syscalls, string suffix)
{
generator.EnterScope($"public static void Dispatch{suffix}(Syscall syscall, {TypeExecutionContext} context, int id)");
generator.EnterScope("switch (id)");
foreach (var syscall in syscalls)
{
generator.AppendLine($"case {syscall.Id}:");
generator.IncreaseIndentation();
generator.AppendLine($"{syscall.Name}{suffix}(syscall, context);");
generator.AppendLine("break;");
generator.DecreaseIndentation();
}
generator.AppendLine($"default:");
generator.IncreaseIndentation();
generator.AppendLine("throw new NotImplementedException($\"SVC 0x{id:X4} is not implemented.\");");
generator.DecreaseIndentation();
generator.LeaveScope();
generator.LeaveScope();
}
private static bool Is32BitInteger(string canonicalTypeName)
{
return canonicalTypeName == TypeSystemInt32 || canonicalTypeName == TypeSystemUInt32;
}
private static bool Is64BitInteger(string canonicalTypeName)
{
return canonicalTypeName == TypeSystemInt64 || canonicalTypeName == TypeSystemUInt64;
}
private static string GenerateCastFromUInt64(string value, string canonicalTargetTypeName, string targetTypeName)
{
if (canonicalTargetTypeName == TypeSystemBoolean)
{
return $"({value} & 1) != 0";
}
return $"({targetTypeName}){value}";
}
private static bool IsPointerSized(Compilation compilation, ParameterSyntax parameterSyntax)
{
foreach (var attributeList in parameterSyntax.AttributeLists)
{
foreach (var attribute in attributeList.Attributes)
{
if (GetCanonicalTypeName(compilation, attribute) == TypePointerSizedAttribute)
{
return true;
}
}
}
return false;
}
public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForSyntaxNotifications(() => new SyscallSyntaxReceiver());
}
}
}

View File

@ -1,54 +0,0 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Linq;
namespace Ryujinx.Horizon.Generators.Kernel
{
class SyscallSyntaxReceiver : ISyntaxReceiver
{
public List<MethodDeclarationSyntax> SvcImplementations { get; }
public SyscallSyntaxReceiver()
{
SvcImplementations = new List<MethodDeclarationSyntax>();
}
public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
if (syntaxNode is ClassDeclarationSyntax classDeclaration && classDeclaration.AttributeLists.Count != 0)
{
foreach (var attributeList in classDeclaration.AttributeLists)
{
if (attributeList.Attributes.Any(x => x.Name.GetText().ToString() == "SvcImpl"))
{
foreach (var memberDeclaration in classDeclaration.Members)
{
if (memberDeclaration is MethodDeclarationSyntax methodDeclaration)
{
VisitMethod(methodDeclaration);
}
}
break;
}
}
}
}
private void VisitMethod(MethodDeclarationSyntax methodDeclaration)
{
if (methodDeclaration.AttributeLists.Count != 0)
{
foreach (var attributeList in methodDeclaration.AttributeLists)
{
if (attributeList.Attributes.Any(x => x.Name.GetText().ToString() == "Svc"))
{
SvcImplementations.Add(methodDeclaration);
break;
}
}
}
}
}
}