Path: blob/trunk/dotnet/src/webdriver/DevTools/WebSocketConnection.cs
2885 views
// <copyright file="WebSocketConnection.cs" company="Selenium Committers">
// Licensed to the Software Freedom Conservancy (SFC) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The SFC licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// </copyright>

using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace OpenQA.Selenium.DevTools;

/// <summary>
/// Represents a connection to a WebDriver Bidi remote end.
/// </summary>
public class WebSocketConnection
{
    private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(10);
    private readonly CancellationTokenSource clientTokenSource = new CancellationTokenSource();
    private readonly TimeSpan startupTimeout;
    private readonly TimeSpan shutdownTimeout;
    private Task? dataReceiveTask;
    private ClientWebSocket client = new ClientWebSocket();

    // Serializes SendAsync calls; a WebSocket does not allow overlapping send operations.
    private readonly SemaphoreSlim sendMethodSemaphore = new SemaphoreSlim(1, 1);

    /// <summary>
    /// Initializes a new instance of the <see cref="WebSocketConnection" /> class.
    /// </summary>
    public WebSocketConnection()
        : this(DefaultTimeout)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="WebSocketConnection" /> class with a given startup timeout.
    /// </summary>
    /// <param name="startupTimeout">The timeout before throwing an error when starting up the connection.</param>
    public WebSocketConnection(TimeSpan startupTimeout)
        : this(startupTimeout, DefaultTimeout)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="WebSocketConnection" /> class with a given startup and shutdown timeout.
    /// </summary>
    /// <param name="startupTimeout">The timeout before throwing an error when starting up the connection.</param>
    /// <param name="shutdownTimeout">The timeout before throwing an error when shutting down the connection.</param>
    public WebSocketConnection(TimeSpan startupTimeout, TimeSpan shutdownTimeout)
    {
        this.startupTimeout = startupTimeout;
        this.shutdownTimeout = shutdownTimeout;
    }

    /// <summary>
    /// Occurs when data is received from this connection.
    /// </summary>
    public event EventHandler<WebSocketConnectionDataReceivedEventArgs>? DataReceived;

    /// <summary>
    /// Occurs when a log message is emitted from this connection.
    /// </summary>
    public event EventHandler<DevToolsSessionLogMessageEventArgs>? LogMessage;

    /// <summary>
    /// Gets a value indicating whether this connection is active.
    /// </summary>
    public bool IsActive { get; private set; } = false;

    /// <summary>
    /// Gets the buffer size for communication used by this connection.
    /// </summary>
    public int BufferSize { get; } = 4096;

    /// <summary>
    /// Asynchronously starts communication with the remote end of this connection.
    /// </summary>
    /// <param name="url">The URL used to connect to the remote end.</param>
    /// <returns>The task object representing the asynchronous operation.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="url"/> is <see langword="null"/>.</exception>
    /// <exception cref="TimeoutException">Thrown when the connection is not established within the startup timeout.</exception>
    /// <remarks>
    /// A connection attempt that fails with a <see cref="WebSocketException"/> is retried after a 500 millisecond
    /// pause, using a fresh socket, until the startup timeout has elapsed.
    /// </remarks>
    public virtual async Task Start(string url)
    {
        if (url is null)
        {
            throw new ArgumentNullException(nameof(url));
        }

        this.Log($"Opening connection to URL {url}", DevToolsSessionLogLevel.Trace);

        bool connected = false;

        // UTC keeps the deadline immune to local clock shifts such as daylight-saving transitions.
        DateTime timeout = DateTime.UtcNow.Add(this.startupTimeout);
        while (!connected && DateTime.UtcNow <= timeout)
        {
            try
            {
                await this.client.ConnectAsync(new Uri(url), this.clientTokenSource.Token).ConfigureAwait(false);
                connected = true;
            }
            catch (WebSocketException)
            {
                // If the server-side socket is not yet ready, it leaves the client socket in a closed state,
                // which sees the object as disposed, so we must create a new one to try again.
                // The failed instance is disposed first so its resources are not leaked on every retry.
                await Task.Delay(TimeSpan.FromMilliseconds(500)).ConfigureAwait(false);
                this.client.Dispose();
                this.client = new ClientWebSocket();
            }
        }

        if (!connected)
        {
            throw new TimeoutException($"Could not connect to browser within {this.startupTimeout.TotalSeconds} seconds");
        }

        // Mark the connection active BEFORE starting the receive loop. The loop resets IsActive to false
        // when it ends; setting it afterwards could leave a stale 'true' if the loop finished immediately.
        this.IsActive = true;
        this.dataReceiveTask = Task.Run(() => this.ReceiveData());

        this.Log($"Connection opened", DevToolsSessionLogLevel.Trace);
    }

    /// <summary>
    /// Asynchronously stops communication with the remote end of this connection.
    /// </summary>
    /// <returns>The task object representing the asynchronous operation.</returns>
    public virtual async Task Stop()
    {
        this.Log($"Closing connection", DevToolsSessionLogLevel.Trace);
        if (this.client.State != WebSocketState.Open)
        {
            this.Log($"Socket already closed (Socket state: {this.client.State})");
        }
        else
        {
            await this.CloseClientWebSocket().ConfigureAwait(false);
        }

        // Whether we closed the socket or timed out, we cancel the token causing ReceiveAsync to abort the socket.
        // Once the processing loop has finished, the ClientWebSocket object is disposed here, even if
        // awaiting the loop surfaces an exception.
        this.clientTokenSource.Cancel();
        try
        {
            if (this.dataReceiveTask != null)
            {
                await this.dataReceiveTask.ConfigureAwait(false);
            }
        }
        finally
        {
            this.client.Dispose();
        }
    }

    /// <summary>
    /// Asynchronously sends data to the remote end of this connection.
    /// </summary>
    /// <param name="data">The data to be sent to the remote end of this connection.</param>
    /// <returns>The task object representing the asynchronous operation.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="data"/> is <see langword="null"/>.</exception>
    public virtual async Task SendData(string data)
    {
        if (data is null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        ArraySegment<byte> messageBuffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(data));
        this.Log($"SEND >>> {data}");

        // Only one send may be in flight at a time.
        await sendMethodSemaphore.WaitAsync().ConfigureAwait(false);
        try
        {
            await this.client.SendAsync(messageBuffer, WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None).ConfigureAwait(false);
        }
        finally
        {
            sendMethodSemaphore.Release();
        }
    }

    /// <summary>
    /// Asynchronously closes the client WebSocket.
    /// </summary>
    /// <returns>The task object representing the asynchronous operation.</returns>
    /// <remarks>
    /// Sends a normal-closure frame, then polls until the socket reaches the Closed or Aborted state or the
    /// shutdown timeout elapses. Cancellation and <see cref="WebSocketException"/> are handled internally.
    /// </remarks>
    protected virtual async Task CloseClientWebSocket()
    {
        // Close the socket first, because ReceiveAsync leaves an invalid socket (state = aborted) when the token is cancelled.
        // The source owns a timer, so it is disposed when this method returns.
        using CancellationTokenSource timeout = new CancellationTokenSource(this.shutdownTimeout);
        try
        {
            // After this, the socket state will change to CloseSent
            await this.client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Closing", timeout.Token).ConfigureAwait(false);

            // Now we wait for the server response, which will close the socket
            while (this.client.State != WebSocketState.Closed && this.client.State != WebSocketState.Aborted && !timeout.Token.IsCancellationRequested)
            {
                // The loop may be too tight for the cancellation token to get triggered, so add a small delay
                await Task.Delay(TimeSpan.FromMilliseconds(10)).ConfigureAwait(false);
            }

            this.Log($"Client state is {this.client.State}", DevToolsSessionLogLevel.Trace);
        }
        catch (OperationCanceledException)
        {
            // An OperationCanceledException is normal upon task/token cancellation, so disregard it
        }
        catch (WebSocketException e)
        {
            this.Log($"Unexpected error during attempt at close: {e.Message}", DevToolsSessionLogLevel.Error);
        }
    }

    /// <summary>
    /// Raises the DataReceived event.
    /// </summary>
    /// <param name="e">The event args used when raising the event.</param>
    protected virtual void OnDataReceived(WebSocketConnectionDataReceivedEventArgs e)
    {
        // Null-conditional invoke reads the delegate once, so a concurrent unsubscribe cannot cause a NullReferenceException.
        this.DataReceived?.Invoke(this, e);
    }

    /// <summary>
    /// Raises the LogMessage event.
    /// </summary>
    /// <param name="e">The event args used when raising the event.</param>
    protected virtual void OnLogMessage(DevToolsSessionLogMessageEventArgs e)
    {
        // Null-conditional invoke reads the delegate once, so a concurrent unsubscribe cannot cause a NullReferenceException.
        this.LogMessage?.Invoke(this, e);
    }

    /// <summary>
    /// Receives frames until the socket closes or the client token is cancelled. Each complete non-empty
    /// message is decoded as UTF-8 and raised through <see cref="OnDataReceived"/>. Sets
    /// <see cref="IsActive"/> to <see langword="false"/> when the loop ends.
    /// </summary>
    private async Task ReceiveData()
    {
        CancellationToken cancellationToken = this.clientTokenSource.Token;
        try
        {
            StringBuilder messageBuilder = new StringBuilder();
            ArraySegment<byte> buffer = WebSocket.CreateClientBuffer(this.BufferSize, this.BufferSize);

            // A message may arrive in several chunks, and a multi-byte UTF-8 sequence can straddle a chunk
            // boundary. A stateful decoder carries incomplete sequences over to the next chunk instead of
            // emitting replacement characters, which decoding each chunk independently would do.
            Decoder utf8Decoder = Encoding.UTF8.GetDecoder();
            char[] charBuffer = new char[Encoding.UTF8.GetMaxCharCount(buffer.Count)];

            while (this.client.State != WebSocketState.Closed && !cancellationToken.IsCancellationRequested)
            {
                WebSocketReceiveResult receiveResult = await this.client.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);

                // If the token is cancelled while ReceiveAsync is blocking, the socket state changes to aborted and it can't be used
                if (!cancellationToken.IsCancellationRequested)
                {
                    // The server is notifying us that the connection will close, and we did
                    // not initiate the close; send acknowledgement
                    if (receiveResult.MessageType == WebSocketMessageType.Close && this.client.State != WebSocketState.Closed && this.client.State != WebSocketState.CloseSent)
                    {
                        this.Log($"Acknowledging Close frame received from server (client state: {this.client.State})", DevToolsSessionLogLevel.Trace);
                        await this.client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Acknowledge Close frame", CancellationToken.None).ConfigureAwait(false);
                    }

                    // Display text or binary data
                    if (this.client.State == WebSocketState.Open && receiveResult.MessageType != WebSocketMessageType.Close)
                    {
                        // Flushing at end-of-message resets the decoder so no partial sequence leaks into the next message.
                        int charCount = utf8Decoder.GetChars(buffer.Array!, buffer.Offset, receiveResult.Count, charBuffer, 0, receiveResult.EndOfMessage);
                        messageBuilder.Append(charBuffer, 0, charCount);
                        if (receiveResult.EndOfMessage)
                        {
                            string message = messageBuilder.ToString();
                            messageBuilder.Clear();
                            if (message.Length > 0)
                            {
                                this.Log($"RECV <<< {message}");
                                this.OnDataReceived(new WebSocketConnectionDataReceivedEventArgs(message));
                            }
                        }
                    }
                }
            }

            this.Log($"Ending processing loop in state {this.client.State}", DevToolsSessionLogLevel.Trace);
        }
        catch (OperationCanceledException)
        {
            // An OperationCanceledException is normal upon task/token cancellation, so disregard it
        }
        catch (WebSocketException e)
        {
            this.Log($"Unexpected error during receive of data: {e.Message}", DevToolsSessionLogLevel.Error);
        }
        finally
        {
            this.IsActive = false;
        }
    }

    private void Log(string message)
    {
        this.Log(message, DevToolsSessionLogLevel.Trace);
    }

    private void Log(string message, DevToolsSessionLogLevel level)
    {
        this.OnLogMessage(new DevToolsSessionLogMessageEventArgs(level, "[{0}] {1}", "Connection", message));
    }
}