// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Collections.Specialized;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;

namespace System.DirectoryServices.ActiveDirectory
{
    /// <summary>
    /// Describes a forest trust relationship together with its forest trust information records:
    /// top-level names, excluded top-level names, domain information, and records of any other
    /// type, which are carried as opaque binary data.
    /// </summary>
    /// <remarks>
    /// The records are loaded from the policy server lazily, the first time one of the collection
    /// properties is read. They are loaded again after a successful <see cref="Save"/>.
    /// </remarks>
    public class ForestTrustRelationshipInformation : TrustRelationshipInformation
    {
        private TopLevelNameCollection _topLevelNames = new TopLevelNameCollection();
        private StringCollection _excludedNames = new StringCollection();
        private ForestTrustDomainInfoCollection _domainInfo = new ForestTrustDomainInfoCollection();

        // Records of any other type are round-tripped as-is. Payload bytes (byte[]), record type (boxed int)
        // and timestamp are kept in parallel lists at the same index.
        private ArrayList _binaryData = new ArrayList();
        private ArrayList _binaryRecordType = new ArrayList();

        // Timestamp of every excluded top-level name read from the server, keyed by the name string.
        private Hashtable _excludedNameTime = new Hashtable();
        private ArrayList _binaryDataTime = new ArrayList();

        // True once the records have been loaded from the server; reset to false by a successful Save().
        internal bool retrieved;

        /// <summary>
        /// Builds the trust description from a DS_DOMAIN_TRUSTS entry.
        /// </summary>
        /// <param name="context">Security context used for later server calls.</param>
        /// <param name="source">Name of the trusting (source) side.</param>
        /// <param name="unmanagedTrust">Trust entry; its DNS name is preferred for the target, with the NetBIOS name as fallback.</param>
        /// <param name="type">Trust type to record.</param>
        internal ForestTrustRelationshipInformation(DirectoryContext context, string source, Interop.Netapi32.DS_DOMAIN_TRUSTS unmanagedTrust, TrustType type)
        {
            string? tmpDNSName = null;
            string? tmpNetBIOSName = null;

            // security context
            this.context = context;

            // source
            this.source = source;

            // target: DNS name if present, otherwise NetBIOS name
            if (unmanagedTrust.DnsDomainName != (IntPtr)0)
            {
                tmpDNSName = Marshal.PtrToStringUni(unmanagedTrust.DnsDomainName);
            }

            if (unmanagedTrust.NetbiosDomainName != (IntPtr)0)
            {
                tmpNetBIOSName = Marshal.PtrToStringUni(unmanagedTrust.NetbiosDomainName);
            }

            this.target = tmpDNSName ?? tmpNetBIOSName;

            // direction: derived from the direct inbound/outbound flags; left untouched when neither is set
            if ((unmanagedTrust.Flags & Interop.Netapi32.DS_DOMAINTRUST_FLAG.DS_DOMAIN_DIRECT_OUTBOUND) != 0 &&
                (unmanagedTrust.Flags & Interop.Netapi32.DS_DOMAINTRUST_FLAG.DS_DOMAIN_DIRECT_INBOUND) != 0)
            {
                direction = TrustDirection.Bidirectional;
            }
            else if ((unmanagedTrust.Flags & Interop.Netapi32.DS_DOMAINTRUST_FLAG.DS_DOMAIN_DIRECT_OUTBOUND) != 0)
            {
                direction = TrustDirection.Outbound;
            }
            else if ((unmanagedTrust.Flags & Interop.Netapi32.DS_DOMAINTRUST_FLAG.DS_DOMAIN_DIRECT_INBOUND) != 0)
            {
                direction = TrustDirection.Inbound;
            }

            // type
            this.type = type;
        }

        /// <summary>
        /// Gets the top-level name records of the forest trust, loading them from the server on first access.
        /// </summary>
        public TopLevelNameCollection TopLevelNames
        {
            get
            {
                if (!retrieved)
                {
                    GetForestTrustInfoHelper();
                }

                return _topLevelNames;
            }
        }

        /// <summary>
        /// Gets the excluded top-level names of the forest trust, loading them from the server on first access.
        /// </summary>
        public StringCollection ExcludedTopLevelNames
        {
            get
            {
                if (!retrieved)
                {
                    GetForestTrustInfoHelper();
                }

                return _excludedNames;
            }
        }

        /// <summary>
        /// Gets the domain information records of the forest trust, loading them from the server on first access.
        /// </summary>
        public ForestTrustDomainInfoCollection TrustedDomainInformation
        {
            get
            {
                if (!retrieved)
                {
                    GetForestTrustInfoHelper();
                }

                return _domainInfo;
            }
        }

        /// <summary>
        /// Writes the current top-level names, excluded top-level names, domain information and opaque
        /// binary records back to the trust identified by <c>TargetName</c>.
        /// </summary>
        /// <remarks>
        /// The records are first submitted with the check-only argument set to <c>true</c>. If LSA reports a
        /// collision, a forest trust collision exception is thrown and nothing is committed. Otherwise the
        /// same records are submitted again with <c>false</c> to commit them. After a successful save the
        /// cached records are marked stale, so the next property read goes back to the server.
        /// </remarks>
        public unsafe void Save()
        {
            int count = 0;
            IntPtr records = (IntPtr)0;
            int currentCount = 0;
            IntPtr tmpPtr = (IntPtr)0;
            IntPtr forestInfo = (IntPtr)0;
            SafeLsaPolicyHandle? handle = null;
            IntPtr collisionInfo = (IntPtr)0;
            // HGlobal allocations, released with FreeHGlobal in the finally block
            ArrayList ptrList = new ArrayList();
            // SIDs from ConvertStringSidToSid, released with LocalFree in the finally block
            ArrayList sidList = new ArrayList();
            bool impersonated = false;
            IntPtr target = (IntPtr)0;
            string? serverName = null;
            IntPtr fileTime = (IntPtr)0;

            // first get the count of all the records (reading the properties loads them if needed)
            int toplevelNamesCount = TopLevelNames.Count;
            int excludedNamesCount = ExcludedTopLevelNames.Count;
            int trustedDomainCount = TrustedDomainInformation.Count;
            int binaryDataCount = _binaryData.Count;

            checked
            {
                count += toplevelNamesCount;
                count += excludedNamesCount;
                count += trustedDomainCount;
                count += binaryDataCount;

                // allocate the pointer array for all the records
                records = Marshal.AllocHGlobal(count * IntPtr.Size);
            }

            try
            {
                IntPtr ptr = 0;
                fileTime = Marshal.AllocHGlobal(Marshal.SizeOf<FileTime>());
                Interop.Kernel32.GetSystemTimeAsFileTime(fileTime);

                // current time, used for excluded names that have no server-supplied timestamp
                FileTime currentTime = new FileTime();
                Marshal.PtrToStructure(fileTime, currentTime);

                for (int i = 0; i < toplevelNamesCount; i++)
                {
                    // construct a top-level name record
                    LSA_FOREST_TRUST_RECORD record = new LSA_FOREST_TRUST_RECORD();
                    record.Flags = (int)_topLevelNames[i].Status;
                    record.ForestTrustType = LSA_FOREST_TRUST_RECORD_TYPE.ForestTrustTopLevelName;
                    TopLevelName TLN = _topLevelNames[i];
                    record.Time = TLN.time;
                    ptr = Marshal.StringToHGlobalUni(TLN.Name);
                    ptrList.Add(ptr);
                    Interop.NtDll.RtlInitUnicodeString(out record.TopLevelName, ptr);

                    tmpPtr = Marshal.AllocHGlobal(Marshal.SizeOf<LSA_FOREST_TRUST_RECORD>());
                    ptrList.Add(tmpPtr);
                    Marshal.StructureToPtr(record, tmpPtr, false);

                    Marshal.WriteIntPtr(records, IntPtr.Size * currentCount, tmpPtr);

                    currentCount++;
                }

                for (int i = 0; i < excludedNamesCount; i++)
                {
                    // construct an excluded top-level name record
                    string? excludedName = _excludedNames[i];
                    LSA_FOREST_TRUST_RECORD record = new LSA_FOREST_TRUST_RECORD();
                    record.Flags = 0;
                    record.ForestTrustType = LSA_FOREST_TRUST_RECORD_TYPE.ForestTrustTopLevelNameEx;

                    // Names read from the server keep their original timestamp. The table is keyed by the
                    // name itself (see GetForestTrustInfoHelper), not by its position in the collection.
                    if (_excludedNameTime[excludedName!] is LARGE_INTEGER storedTime)
                    {
                        record.Time = storedTime;
                    }
                    else
                    {
                        record.Time = new LARGE_INTEGER();
                        record.Time.lowPart = currentTime.lower;
                        record.Time.highPart = currentTime.higher;
                    }

                    ptr = Marshal.StringToHGlobalUni(excludedName);
                    ptrList.Add(ptr);
                    Interop.NtDll.RtlInitUnicodeString(out record.TopLevelName, ptr);

                    tmpPtr = Marshal.AllocHGlobal(Marshal.SizeOf<LSA_FOREST_TRUST_RECORD>());
                    ptrList.Add(tmpPtr);
                    Marshal.StructureToPtr(record, tmpPtr, false);

                    Marshal.WriteIntPtr(records, IntPtr.Size * currentCount, tmpPtr);

                    currentCount++;
                }

                for (int i = 0; i < trustedDomainCount; i++)
                {
                    // construct a domain info record
                    LSA_FOREST_TRUST_RECORD record = new LSA_FOREST_TRUST_RECORD();
                    record.Flags = (int)_domainInfo[i].Status;
                    record.ForestTrustType = LSA_FOREST_TRUST_RECORD_TYPE.ForestTrustDomainInfo;
                    ForestTrustDomainInformation tmp = _domainInfo[i];
                    record.Time = tmp.time;

                    void* pSid = null;
                    global::Interop.BOOL result = global::Interop.Advapi32.ConvertStringSidToSid(tmp.DomainSid, out pSid);
                    if (result == global::Interop.BOOL.FALSE)
                    {
                        throw ExceptionHelper.GetExceptionFromErrorCode(Marshal.GetLastPInvokeError());
                    }

                    record.DomainInfo.sid = (IntPtr)pSid;
                    sidList.Add((IntPtr)pSid);

                    record.DomainInfo.DNSNameBuffer = Marshal.StringToHGlobalUni(tmp.DnsName);
                    ptrList.Add(record.DomainInfo.DNSNameBuffer);
                    record.DomainInfo.DNSNameLength = (short)(tmp.DnsName == null ? 0 : tmp.DnsName.Length * 2); // sizeof(WCHAR)
                    record.DomainInfo.DNSNameMaximumLength = (short)(tmp.DnsName == null ? 0 : tmp.DnsName.Length * 2);

                    record.DomainInfo.NetBIOSNameBuffer = Marshal.StringToHGlobalUni(tmp.NetBiosName);
                    ptrList.Add(record.DomainInfo.NetBIOSNameBuffer);
                    record.DomainInfo.NetBIOSNameLength = (short)(tmp.NetBiosName == null ? 0 : tmp.NetBiosName.Length * 2);
                    record.DomainInfo.NetBIOSNameMaximumLength = (short)(tmp.NetBiosName == null ? 0 : tmp.NetBiosName.Length * 2);

                    tmpPtr = Marshal.AllocHGlobal(Marshal.SizeOf<LSA_FOREST_TRUST_RECORD>());
                    ptrList.Add(tmpPtr);
                    Marshal.StructureToPtr(record, tmpPtr, false);

                    Marshal.WriteIntPtr(records, IntPtr.Size * currentCount, tmpPtr);

                    currentCount++;
                }

                for (int i = 0; i < binaryDataCount; i++)
                {
                    // write back records of any other type exactly as they were read
                    LSA_FOREST_TRUST_RECORD record = new LSA_FOREST_TRUST_RECORD();
                    record.Flags = 0;
                    record.Time = (LARGE_INTEGER)_binaryDataTime[i]!;
                    record.Data.Length = ((byte[])_binaryData[i]!).Length;
                    record.ForestTrustType = (LSA_FOREST_TRUST_RECORD_TYPE)_binaryRecordType[i]!;
                    if (record.Data.Length == 0)
                    {
                        record.Data.Buffer = (IntPtr)0;
                    }
                    else
                    {
                        record.Data.Buffer = Marshal.AllocHGlobal(record.Data.Length);
                        ptrList.Add(record.Data.Buffer);
                        Marshal.Copy((byte[])_binaryData[i]!, 0, record.Data.Buffer, record.Data.Length);
                    }

                    tmpPtr = Marshal.AllocHGlobal(Marshal.SizeOf<LSA_FOREST_TRUST_RECORD>());
                    ptrList.Add(tmpPtr);
                    Marshal.StructureToPtr(record, tmpPtr, false);

                    Marshal.WriteIntPtr(records, IntPtr.Size * currentCount, tmpPtr);

                    currentCount++;
                }

                // finally construct the LSA_FOREST_TRUST_INFORMATION
                LSA_FOREST_TRUST_INFORMATION trustInformation = new LSA_FOREST_TRUST_INFORMATION();
                trustInformation.RecordCount = count;
                trustInformation.Entries = records;
                forestInfo = Marshal.AllocHGlobal(Marshal.SizeOf<LSA_FOREST_TRUST_INFORMATION>());
                Marshal.StructureToPtr(trustInformation, forestInfo, false);

                // get policy server name
                serverName = Utils.GetPolicyServerName(context, true, true, SourceName);

                // do impersonation first
                impersonated = Utils.Impersonate(context);

                // get the policy handle (owned by this method, released in finally)
                handle = Utils.GetPolicyHandle(serverName);

                // get the target name
                global::Interop.UNICODE_STRING trustedDomainName;
                target = Marshal.StringToHGlobalUni(TargetName);
                Interop.NtDll.RtlInitUnicodeString(out trustedDomainName, target);

                // check-only pass: detects collisions without committing anything
                uint error = Interop.Advapi32.LsaSetForestTrustInformation(handle, trustedDomainName, forestInfo, true, out collisionInfo);
                if (error != 0)
                {
                    throw ExceptionHelper.GetExceptionFromErrorCode((int)global::Interop.Advapi32.LsaNtStatusToWinError(error), serverName);
                }

                // there is a collision, throw a proper exception so the user can deal with it
                if (collisionInfo != (IntPtr)0)
                {
                    throw ExceptionHelper.CreateForestTrustCollisionException(collisionInfo);
                }

                // commit the changes; the status is an NTSTATUS, converted the same way as the check-only call
                error = Interop.Advapi32.LsaSetForestTrustInformation(handle, trustedDomainName, forestInfo, false, out collisionInfo);
                if (error != 0)
                {
                    throw ExceptionHelper.GetExceptionFromErrorCode((int)global::Interop.Advapi32.LsaNtStatusToWinError(error), serverName);
                }

                // the next time a property is read, go back to the server
                retrieved = false;
            }
            finally
            {
                handle?.Dispose();

                if (impersonated)
                {
                    Utils.Revert();
                }

                // release the memory
                for (int i = 0; i < ptrList.Count; i++)
                {
                    Marshal.FreeHGlobal((IntPtr)ptrList[i]!);
                }

                for (int i = 0; i < sidList.Count; i++)
                {
                    global::Interop.Kernel32.LocalFree((void*)(nint)sidList[i]!);
                }

                if (records != (IntPtr)0)
                {
                    Marshal.FreeHGlobal(records);
                }

                if (forestInfo != (IntPtr)0)
                {
                    Marshal.FreeHGlobal(forestInfo);
                }

                if (collisionInfo != (IntPtr)0)
                {
                    global::Interop.Advapi32.LsaFreeMemory(collisionInfo);
                }

                if (target != (IntPtr)0)
                {
                    Marshal.FreeHGlobal(target);
                }

                if (fileTime != (IntPtr)0)
                {
                    Marshal.FreeHGlobal(fileTime);
                }
            }
        }

        /// <summary>
        /// Queries the forest trust information for <c>TargetName</c> from the policy server and replaces
        /// the cached collections with the result.
        /// </summary>
        /// <remarks>
        /// Records are parsed into temporary collections first. The cached state is replaced, and
        /// <see cref="retrieved"/> set, only after every record has been parsed, so a failure leaves the
        /// previous cache intact.
        /// </remarks>
        private unsafe void GetForestTrustInfoHelper()
        {
            IntPtr forestTrustInfo = (IntPtr)0;
            SafeLsaPolicyHandle? handle = null;
            bool impersonated = false;
            IntPtr targetPtr = (IntPtr)0;
            string? serverName = null;

            TopLevelNameCollection tmpTLNs = new TopLevelNameCollection();
            StringCollection tmpExcludedTLNs = new StringCollection();
            ForestTrustDomainInfoCollection tmpDomainInformation = new ForestTrustDomainInfoCollection();

            // internal members
            ArrayList tmpBinaryData = new ArrayList();
            Hashtable tmpExcludedNameTime = new Hashtable();
            ArrayList tmpBinaryDataTime = new ArrayList();
            ArrayList tmpBinaryRecordType = new ArrayList();

            try
            {
                // get the target name
                global::Interop.UNICODE_STRING tmpName;
                targetPtr = Marshal.StringToHGlobalUni(TargetName);
                Interop.NtDll.RtlInitUnicodeString(out tmpName, targetPtr);

                serverName = Utils.GetPolicyServerName(context, true, false, source);

                // do impersonation
                impersonated = Utils.Impersonate(context);

                // get the policy handle (owned by this method, released in finally)
                handle = Utils.GetPolicyHandle(serverName);

                uint result = Interop.Advapi32.LsaQueryForestTrustInformation(handle, tmpName, ref forestTrustInfo);
                // check the result; a non-zero status that maps to no Win32 error is not treated as a failure
                if (result != 0)
                {
                    uint win32Error = global::Interop.Advapi32.LsaNtStatusToWinError(result);
                    if (win32Error != 0)
                    {
                        throw ExceptionHelper.GetExceptionFromErrorCode((int)win32Error, serverName);
                    }
                }

                try
                {
                    if (forestTrustInfo != (IntPtr)0)
                    {
                        LSA_FOREST_TRUST_INFORMATION trustInfo = new LSA_FOREST_TRUST_INFORMATION();
                        Marshal.PtrToStructure(forestTrustInfo, trustInfo);

                        int count = trustInfo.RecordCount;
                        IntPtr addr = (IntPtr)0;
                        for (int i = 0; i < count; i++)
                        {
                            addr = Marshal.ReadIntPtr(trustInfo.Entries, i * IntPtr.Size);
                            LSA_FOREST_TRUST_RECORD record = new LSA_FOREST_TRUST_RECORD();
                            Marshal.PtrToStructure(addr, record);

                            // The type-dependent union member is read straight from native memory at byte
                            // offset 16. NOTE(review): this assumes the native record layout Flags (4) +
                            // ForestTrustType (4) + Time (8) precedes the union — confirm against
                            // LSA_FOREST_TRUST_RECORD.
                            if (record.ForestTrustType == LSA_FOREST_TRUST_RECORD_TYPE.ForestTrustTopLevelName)
                            {
                                IntPtr myPtr = IntPtr.Add(addr, 16);
                                record.TopLevelName = *(global::Interop.UNICODE_STRING*)myPtr;
                                TopLevelName TLN = new TopLevelName(record.Flags, record.TopLevelName, record.Time);
                                tmpTLNs.Add(TLN);
                            }
                            else if (record.ForestTrustType == LSA_FOREST_TRUST_RECORD_TYPE.ForestTrustTopLevelNameEx)
                            {
                                // get the excluded TLN and put it in our collection, remembering its timestamp by name
                                IntPtr myPtr = IntPtr.Add(addr, 16);
                                record.TopLevelName = *(global::Interop.UNICODE_STRING*)myPtr;
                                string excludedName = Marshal.PtrToStringUni(record.TopLevelName.Buffer, record.TopLevelName.Length / 2);
                                tmpExcludedTLNs.Add(excludedName);
                                tmpExcludedNameTime.Add(excludedName, record.Time);
                            }
                            else if (record.ForestTrustType == LSA_FOREST_TRUST_RECORD_TYPE.ForestTrustDomainInfo)
                            {
                                IntPtr myPtr = IntPtr.Add(addr, 16);
                                record.DomainInfo = *(LSA_FOREST_TRUST_DOMAIN_INFO*)myPtr;
                                ForestTrustDomainInformation dom = new ForestTrustDomainInformation(record.Flags, record.DomainInfo!, record.Time);
                                tmpDomainInformation.Add(dom);
                            }
                            else
                            {
                                // any other record type: keep the raw bytes so Save() can write them back unchanged
                                IntPtr myPtr = IntPtr.Add(addr, 16);
                                record.Data = *(LSA_FOREST_TRUST_BINARY_DATA*)myPtr;
                                int length = record.Data.Length;
                                byte[] byteArray = new byte[length];
                                if ((record.Data.Buffer != (IntPtr)0) && (length != 0))
                                {
                                    Marshal.Copy(record.Data.Buffer, byteArray, 0, length);
                                }

                                tmpBinaryData.Add(byteArray);
                                tmpBinaryDataTime.Add(record.Time);
                                tmpBinaryRecordType.Add((int)record.ForestTrustType);
                            }
                        }
                    }
                }
                finally
                {
                    if (forestTrustInfo != (IntPtr)0)
                    {
                        global::Interop.Advapi32.LsaFreeMemory(forestTrustInfo);
                    }
                }

                _topLevelNames = tmpTLNs;
                _excludedNames = tmpExcludedTLNs;
                _domainInfo = tmpDomainInformation;

                _binaryData = tmpBinaryData;
                _excludedNameTime = tmpExcludedNameTime;
                _binaryDataTime = tmpBinaryDataTime;
                _binaryRecordType = tmpBinaryRecordType;

                // mark it as retrieved
                retrieved = true;
            }
            finally
            {
                handle?.Dispose();

                if (impersonated)
                {
                    Utils.Revert();
                }

                if (targetPtr != (IntPtr)0)
                {
                    Marshal.FreeHGlobal(targetPtr);
                }
            }
        }
    }
}