honk_rpc/
honk_rpc.rs

1// standard
2use std::collections::VecDeque;
3use std::fmt::Debug;
4use std::io::{Cursor, ErrorKind};
5#[cfg(test)]
6use std::net::{SocketAddr, TcpListener, TcpStream};
7use std::option::Option;
8
9// extern crates
10use bson::doc;
11use bson::document::ValueAccessError;
12
13use crate::byte_counter::ByteCounter;
14
/// Represents various error codes that can be present in a Honk-RPC `error_section`.
///
/// Every variant maps to a fixed `i32` wire value (see the `From<i32>` and
/// `From<ErrorCode> for i32` conversions). The type only carries unit variants and
/// `i32` payloads, so it is cheap to copy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorCode {
    /// Failure to parse a received BSON document.
    BsonParseFailed,
    /// Received message document was too big; the default maximum message size
    /// is 4096 bytes, but can be adjusted.
    MessageTooBig,
    /// Received message document missing required fields.
    MessageParseFailed,
    /// Received message contained version the receiver cannot handle.
    MessageVersionIncompatible,
    /// Section in received message contains unknown id.
    SectionIdUnknown,
    /// Section in received message missing required field, or provided
    /// field is wrong datatype.
    SectionParseFailed,
    /// Provided request cookie is already in use.
    RequestCookieInvalid,
    /// Provided request namespace does not exist.
    RequestNamespaceInvalid,
    /// Provided request function does not exist within the provided namespace.
    RequestFunctionInvalid,
    /// Provided request version does not exist.
    RequestVersionInvalid,
    /// Provided response cookie is not recognized.
    ResponseCookieInvalid,
    /// Provided response state is not valid.
    ResponseStateInvalid,
    /// Represents an application-specific runtime error with a specific error code.
    /// Positive wire values decode to this variant.
    Runtime(i32),
    /// Represents an unknown error with a specific error code.
    /// Non-positive wire values that are not protocol-defined decode to this variant.
    Unknown(i32),
}
49
50/// The error type for the `Session` type.
51#[derive(thiserror::Error, Debug)]
52pub enum Error {
53    /// Failed to read data from read stream due to `std::io::Error`
54    #[error("failed to read data from read stream")]
55    ReaderReadFailed(#[source] std::io::Error),
56
57    /// Bson documents need to be at least 4 bytes long
58    #[error("received invalid bson document size header value of {0}, must be at least 4")]
59    BsonDocumentSizeTooSmall(i32),
60
61    /// Received Bson document header is larger than Session supports
62    #[error("received invalid bson document size header value of {0}, must be less than {1}")]
63    BsonDocumentSizeTooLarge(i32, i32),
64
65    /// Too much time has elapsed without receiving a message
66    #[error("waited longer than {} seconds for read", .0.as_secs_f32())]
67    MessageReadTimedOut(std::time::Duration),
68
69    /// Failed to parse bson message
70    #[error("failed to parse bson Message document")]
71    BsonDocumentParseFailed(#[source] bson::de::Error),
72
73    /// Failed to convert bson document to Honk-RPC message
74    #[error("failed to convert bson document to Message")]
75    MessageConversionFailed(#[source] crate::honk_rpc::ErrorCode),
76
77    /// Failed to serialise bson document
78    #[error("failed to serialize bson document")]
79    BsonWriteFailed(#[source] bson::ser::Error),
80
81    /// Failed to write data to write stream due to `std::io::Error`
82    #[error("failed to write data to write stream")]
83    WriterWriteFailed(#[source] std::io::Error),
84
85    /// Failed to flush data to write stream due to `std::io::Error`
86    #[error("failed to flush message to write stream")]
87    WriterFlushFailed(#[source] std::io::Error),
88
89    /// Received a Honk-RPC `error_section` without an associated request cookie
90    #[error("recieved error section without cookie")]
91    UnknownErrorSectionReceived(#[source] crate::honk_rpc::ErrorCode),
92
93    /// Attempted to define invalid maximum message size
94    #[error(
95        "tried to set invalid max message size; must be >=5 bytes and <= i32::MAX (2147483647)"
96    )]
97    InvalidMaxMesageSize(),
98
99    /// Attempted to send a Honk-RPC `section` that is too large to fit in a message
100    #[error("queued message section is too large to write; calculated size is {0} but must be less than {1}")]
101    SectionTooLarge(usize, usize),
102}
103
104impl From<i32> for ErrorCode {
105    fn from(value: i32) -> ErrorCode {
106        match value {
107            -1i32 => ErrorCode::BsonParseFailed,
108            -2i32 => ErrorCode::MessageTooBig,
109            -3i32 => ErrorCode::MessageParseFailed,
110            -4i32 => ErrorCode::MessageVersionIncompatible,
111            -5i32 => ErrorCode::SectionIdUnknown,
112            -6i32 => ErrorCode::SectionParseFailed,
113            -7i32 => ErrorCode::RequestCookieInvalid,
114            -8i32 => ErrorCode::RequestNamespaceInvalid,
115            -9i32 => ErrorCode::RequestFunctionInvalid,
116            -10i32 => ErrorCode::RequestVersionInvalid,
117            -11i32 => ErrorCode::ResponseCookieInvalid,
118            -12i32 => ErrorCode::ResponseStateInvalid,
119            value => {
120                if value > 0 {
121                    ErrorCode::Runtime(value)
122                } else {
123                    ErrorCode::Unknown(value)
124                }
125            }
126        }
127    }
128}
129
130impl From<ErrorCode> for i32 {
131    fn from(err: ErrorCode) -> Self {
132        match err {
133            ErrorCode::BsonParseFailed => -1i32,
134            ErrorCode::MessageTooBig => -2i32,
135            ErrorCode::MessageParseFailed => -3i32,
136            ErrorCode::MessageVersionIncompatible => -4i32,
137            ErrorCode::SectionIdUnknown => -5i32,
138            ErrorCode::SectionParseFailed => -6i32,
139            ErrorCode::RequestCookieInvalid => -7i32,
140            ErrorCode::RequestNamespaceInvalid => -8i32,
141            ErrorCode::RequestFunctionInvalid => -9i32,
142            ErrorCode::RequestVersionInvalid => -10i32,
143            ErrorCode::ResponseCookieInvalid => -11i32,
144            ErrorCode::ResponseStateInvalid => -12i32,
145            ErrorCode::Runtime(val) => val,
146            ErrorCode::Unknown(val) => val,
147        }
148    }
149}
150
151impl std::fmt::Display for ErrorCode {
152    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
153        match self {
154            ErrorCode::BsonParseFailed => write!(f, "ProtocolError: failed to parse BSON object"),
155            ErrorCode::MessageTooBig => write!(f, "ProtocolError: received document too large"),
156            ErrorCode::MessageParseFailed => {
157                write!(f, "ProtocolError: received message has invalid schema")
158            }
159            ErrorCode::MessageVersionIncompatible => write!(
160                f,
161                "ProtocolError: received message has incompatible version"
162            ),
163            ErrorCode::SectionIdUnknown => write!(
164                f,
165                "ProtocolError: received message contains section of unknown type"
166            ),
167            ErrorCode::SectionParseFailed => write!(
168                f,
169                "ProtocolError: received message contains section with invalid schema"
170            ),
171            ErrorCode::RequestCookieInvalid => {
172                write!(f, "ProtocolError: request cookie already in use")
173            }
174            ErrorCode::RequestNamespaceInvalid => write!(
175                f,
176                "ProtocolError: request function does not exist in requested namespace"
177            ),
178            ErrorCode::RequestFunctionInvalid => {
179                write!(f, "ProtocolError: request function does not exist")
180            }
181            ErrorCode::RequestVersionInvalid => {
182                write!(f, "ProtocolError: request function version does not exist")
183            }
184            ErrorCode::ResponseCookieInvalid => {
185                write!(f, "ProtocolError: response cookie is not recognized")
186            }
187            ErrorCode::ResponseStateInvalid => write!(f, "ProtocolError: response state not valid"),
188            ErrorCode::Runtime(code) => write!(f, "RuntimeError: runtime error {}", code),
189            ErrorCode::Unknown(code) => write!(f, "UnknownError: unknown error code {}", code),
190        }
191    }
192}
193
194impl std::error::Error for ErrorCode {}
195
// Honk-RPC packs a semantic version into one i32 laid out as 0x00MMmmpp:
// major in bits 16..24, minor in bits 8..16, patch in bits 0..8.
const fn semver_to_i32(major: u8, minor: u8, patch: u8) -> i32 {
    ((major as i32) << 16) | ((minor as i32) << 8) | (patch as i32)
}

// Inverse of `semver_to_i32`; any value outside 0..=0xffffff is not a packed semver.
const fn i32_to_semver(ver: i32) -> Option<(u8, u8, u8)> {
    if ver < 0 || ver > 0x00ff_ffff {
        return None;
    }
    // `as u8` keeps only the low 8 bits, isolating each component after the shift
    Some(((ver >> 16) as u8, (ver >> 8) as u8, ver as u8))
}

// Honk-RPC version 0.1.0
const HONK_RPC_VERSION: i32 = semver_to_i32(0, 1, 0);
217
/// A decoded Honk-RPC message: the protocol version field plus its sections.
/// A message parsed from BSON always has at least one section.
struct Message {
    // packed semver of the protocol (the `honk_rpc` document field)
    honk_rpc: i32,
    // the message's sections, in document order
    sections: Vec<Section>,
}
222
223impl TryFrom<bson::document::Document> for Message {
224    type Error = ErrorCode;
225
226    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
227        let mut value = value;
228
229        // verify version
230        let honk_rpc = match value.get_i32("honk_rpc") {
231            Ok(HONK_RPC_VERSION) => HONK_RPC_VERSION,
232            Ok(honk_rpc) => {
233                return if let Some(_version) = i32_to_semver(honk_rpc) {
234                    // some other semver we cannot handle
235                    Err(ErrorCode::MessageVersionIncompatible)
236                } else {
237                    // an invalid semver
238                    Err(ErrorCode::MessageParseFailed)
239                };
240            }
241            Err(_err) => return Err(ErrorCode::MessageParseFailed),
242        };
243
244        if let Ok(sections) = value.get_array_mut("sections") {
245            // messages must have at least one section
246            if sections.is_empty() {
247                return Err(ErrorCode::MessageParseFailed);
248            }
249
250            let mut message = Message {
251                honk_rpc,
252                sections: Default::default(),
253            };
254
255            for section in sections.iter_mut() {
256                if let bson::Bson::Document(section) = std::mem::take(section) {
257                    message.sections.push(Section::try_from(section)?);
258                } else {
259                    return Err(ErrorCode::SectionParseFailed);
260                }
261            }
262            Ok(message)
263        } else {
264            Err(ErrorCode::MessageParseFailed)
265        }
266    }
267}
268
269impl From<Message> for bson::document::Document {
270    fn from(value: Message) -> bson::document::Document {
271        let mut value = value;
272        let mut message = bson::document::Document::new();
273        message.insert("honk_rpc", value.honk_rpc);
274
275        let mut sections = bson::Array::new();
276        for section in value.sections.drain(0..) {
277            sections.push(bson::Bson::Document(bson::document::Document::from(
278                section,
279            )));
280        }
281        message.insert("sections", sections);
282
283        message
284    }
285}
286
/// A type alias for the cookie used to track client requests; a response
/// carries the cookie of the request it answers.
pub type RequestCookie = i64;

// Values of a section's `id` field, one per section kind.
const ERROR_SECTION_ID: i32 = 0;
const REQUEST_SECTION_ID: i32 = 1;
const RESPONSE_SECTION_ID: i32 = 2;
293
/// A single Honk-RPC section. The variant corresponds to the section's `id` field
/// (`ERROR_SECTION_ID`, `REQUEST_SECTION_ID` or `RESPONSE_SECTION_ID`).
enum Section {
    Error(ErrorSection),
    Request(RequestSection),
    Response(ResponseSection),
}

/// Decoded form of an error section.
struct ErrorSection {
    // cookie of the request this error refers to; optional on the wire
    cookie: Option<RequestCookie>,
    // decoded from the section's `code` field
    code: ErrorCode,
    // optional human-readable description
    message: Option<String>,
    // optional payload; may be any BSON value
    data: Option<bson::Bson>,
}

/// Decoded form of a request section: a call to `function` (at `version`) in
/// `namespace`, with `arguments`.
struct RequestSection {
    // optional cookie identifying this request
    cookie: Option<RequestCookie>,
    // empty when the section did not specify a namespace
    namespace: String,
    // never empty for a successfully parsed section
    function: String,
    // 0 when the section did not specify a version
    version: i32,
    // empty document when the section did not specify arguments
    arguments: bson::document::Document,
}
314
/// Progress of a request as reported by a response section.
///
/// The `repr(i32)` discriminant is the value of the section's `state` field.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RequestState {
    /// The request has not finished; a pending response carries no result.
    Pending = 0i32,
    /// The request has finished; a result may accompany it.
    Complete = 1i32,
}
321
/// Decoded form of a response section, reporting on the request identified by `cookie`.
struct ResponseSection {
    // cookie of the request being answered
    cookie: RequestCookie,
    // whether the request is still pending or complete
    state: RequestState,
    // the call's result, if any; always `None` for a parsed `Pending` response
    result: Option<bson::Bson>,
}
327
328impl TryFrom<bson::document::Document> for Section {
329    type Error = ErrorCode;
330
331    fn try_from(
332        value: bson::document::Document,
333    ) -> Result<Self, <Self as TryFrom<bson::document::Document>>::Error> {
334        match value.get_i32("id") {
335            Ok(ERROR_SECTION_ID) => Ok(Section::Error(ErrorSection::try_from(value)?)),
336            Ok(REQUEST_SECTION_ID) => Ok(Section::Request(RequestSection::try_from(value)?)),
337            Ok(RESPONSE_SECTION_ID) => Ok(Section::Response(ResponseSection::try_from(value)?)),
338            Ok(_) => Err(ErrorCode::SectionIdUnknown),
339            Err(_) => Err(ErrorCode::SectionParseFailed),
340        }
341    }
342}
343
344impl From<Section> for bson::document::Document {
345    fn from(value: Section) -> bson::document::Document {
346        match value {
347            Section::Error(section) => bson::document::Document::from(section),
348            Section::Request(section) => bson::document::Document::from(section),
349            Section::Response(section) => bson::document::Document::from(section),
350        }
351    }
352}
353
354impl TryFrom<bson::document::Document> for ErrorSection {
355    type Error = ErrorCode;
356
357    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
358        let mut value = value;
359
360        let cookie = match value.get_i64("cookie") {
361            Ok(cookie) => Some(cookie),
362            Err(ValueAccessError::NotPresent) => None,
363            Err(_) => return Err(ErrorCode::SectionParseFailed),
364        };
365
366        let code = match value.get_i32("code") {
367            Ok(code) => ErrorCode::from(code),
368            Err(_) => return Err(ErrorCode::SectionParseFailed),
369        };
370
371        let message = match value.get_str("message") {
372            Ok(message) => Some(message.to_string()),
373            Err(ValueAccessError::NotPresent) => None,
374            Err(_) => return Err(ErrorCode::SectionParseFailed),
375        };
376
377        let data = value.get_mut("data").map(std::mem::take);
378
379        Ok(ErrorSection {
380            cookie,
381            code,
382            message,
383            data,
384        })
385    }
386}
387
388impl From<ErrorSection> for bson::document::Document {
389    fn from(value: ErrorSection) -> bson::document::Document {
390        let mut error_section = bson::document::Document::new();
391        error_section.insert("id", ERROR_SECTION_ID);
392
393        if let Some(cookie) = value.cookie {
394            error_section.insert("cookie", cookie);
395        }
396
397        error_section.insert("code", Into::<i32>::into(value.code));
398
399        if let Some(message) = value.message {
400            error_section.insert("message", message);
401        }
402
403        if let Some(data) = value.data {
404            error_section.insert("data", data);
405        }
406
407        error_section
408    }
409}
410
411impl TryFrom<bson::document::Document> for RequestSection {
412    type Error = ErrorCode;
413
414    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
415        let mut value = value;
416
417        let cookie = match value.get_i64("cookie") {
418            Ok(cookie) => Some(cookie),
419            Err(ValueAccessError::NotPresent) => None,
420            Err(_) => return Err(ErrorCode::SectionParseFailed),
421        };
422
423        let namespace = match value.get_str("namespace") {
424            Ok(namespace) => namespace.to_string(),
425            Err(ValueAccessError::NotPresent) => String::default(),
426            Err(_) => return Err(ErrorCode::SectionParseFailed),
427        };
428
429        let function = match value.get_str("function") {
430            Ok(function) => {
431                if function.is_empty() {
432                    return Err(ErrorCode::RequestFunctionInvalid);
433                } else {
434                    function.to_string()
435                }
436            }
437            Err(_) => return Err(ErrorCode::SectionParseFailed),
438        };
439
440        let version = match value.get_i32("version") {
441            Ok(version) => version,
442            Err(ValueAccessError::NotPresent) => 0i32,
443            Err(_) => return Err(ErrorCode::SectionParseFailed),
444        };
445
446        let arguments = match value.get_document_mut("arguments") {
447            Ok(arguments) => std::mem::take(arguments),
448            Err(ValueAccessError::NotPresent) => bson::document::Document::new(),
449            Err(_) => return Err(ErrorCode::SectionParseFailed),
450        };
451
452        Ok(RequestSection {
453            cookie,
454            namespace,
455            function,
456            version,
457            arguments,
458        })
459    }
460}
461
462impl From<RequestSection> for bson::document::Document {
463    fn from(value: RequestSection) -> bson::document::Document {
464        let mut request_section = bson::document::Document::new();
465        request_section.insert("id", REQUEST_SECTION_ID);
466
467        if let Some(cookie) = value.cookie {
468            request_section.insert("cookie", cookie);
469        }
470
471        if !value.namespace.is_empty() {
472            request_section.insert("namespace", value.namespace);
473        }
474
475        request_section.insert("function", value.function);
476
477        if value.version != 0i32 {
478            request_section.insert("version", value.version);
479        }
480
481        request_section.insert("arguments", value.arguments);
482
483        request_section
484    }
485}
486
487impl TryFrom<bson::document::Document> for ResponseSection {
488    type Error = ErrorCode;
489
490    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
491        let mut value = value;
492        let cookie = match value.get_i64("cookie") {
493            Ok(cookie) => cookie,
494            Err(_) => return Err(ErrorCode::SectionParseFailed),
495        };
496
497        let state = match value.get_i32("state") {
498            Ok(0i32) => RequestState::Pending,
499            Ok(1i32) => RequestState::Complete,
500            Ok(_) => return Err(ErrorCode::ResponseStateInvalid),
501            Err(_) => return Err(ErrorCode::SectionParseFailed),
502        };
503
504        let result = value.get_mut("result").map(std::mem::take);
505
506        // if pending there should be no result
507        if state == RequestState::Pending && result.is_some() {
508            return Err(ErrorCode::SectionParseFailed);
509        }
510
511        Ok(ResponseSection {
512            cookie,
513            state,
514            result,
515        })
516    }
517}
518
519impl From<ResponseSection> for bson::document::Document {
520    fn from(value: ResponseSection) -> bson::document::Document {
521        let mut response_section = bson::document::Document::new();
522        response_section.insert("id", RESPONSE_SECTION_ID);
523
524        response_section.insert("cookie", value.cookie);
525        response_section.insert("state", value.state as i32);
526
527        if let Some(result) = value.result {
528            response_section.insert("result", result);
529        }
530
531        response_section
532    }
533}
534
535/// The `ApiSet` trait represents a set of APIs that can be remotely invoked by a connecting Honk-RPC client.
536/// # Example
537/// This exampe `ApiSet` implements two methods, `example::println()` and `example::async_println()`. The
538/// `println()` method immediatley prints, whereas `async_println()` queues request and
539/// prints the messagge at a later date via `update()`
540///
541/// ```rust
542/// # use honk_rpc::honk_rpc::*;
543/// # use std::collections::VecDeque;
544///
545/// const RUNTIME_ERROR_INVALID_ARG: ErrorCode = ErrorCode::Runtime(1i32);
546///
547/// struct PrintlnApiSet {
548///     // queued print reuests
549///     async_println_work: Vec<(Option<RequestCookie>, String)>,
550///     // successful async requests
551///     async_println_cookies: VecDeque<RequestCookie>,
552/// }
553///
554/// impl PrintlnApiSet {
555///   // prints message immediately
556///   fn println_0(
557///       &mut self,
558///       mut args: bson::document::Document,
559///   ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
560///     if let Some(bson::Bson::String(val)) = args.get_mut("val") {
561///         println!("example::echo_0(val): '{}'", val);
562///         Some(Ok(Some(bson::Bson::String(std::mem::take(val)))))
563///     } else {
564///         Some(Err(RUNTIME_ERROR_INVALID_ARG))
565///     }
566///   }
567///
568///   // queues message up for printing later
569///   fn async_println_0(
570///       &mut self,
571///       request_cookie: Option<RequestCookie>,
572///       mut args: bson::document::Document,
573///   ) -> Option<Result<Option<bson::Bson>, ErrorCode>>{
574///     if let Some(bson::Bson::String(val)) = args.get_mut("val") {
575///         self.async_println_work.push((request_cookie, std::mem::take(val)));
576///         None
577///     } else {
578///         Some(Err(RUNTIME_ERROR_INVALID_ARG))
579///     }
580///   }
581/// }
582///
583/// impl ApiSet for PrintlnApiSet {
584///     fn namespace(&self) -> &str {
585///         "example"
586///     }
587///
588///     // handles and routes requests for `println` and `async_println`
589///     fn exec_function(
590///         &mut self,
591///         name: &str,
592///         version: i32,
593///         args: bson::document::Document,
594///         request_cookie: Option<RequestCookie>,
595///     ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
596///         match (name, version) {
597///             ("println", 0) => self.println_0(args),
598///             ("async_println", 0) => self.async_println_0(request_cookie, args),
599///             (name, version) => {
600///                 println!("received {{ name: '{}', version: {} }}", name, version);
601///                 Some(Err(ErrorCode::RequestFunctionInvalid))
602///             }
603///         }
604///     }
605///
606///     // handles queued `async_println` requests
607///     fn update(&mut self) {
608///         for ((cookie, val)) in self.async_println_work.drain(..) {
609///             println!("{}", val);
610///             if let Some(cookie) = cookie {
611///                 self.async_println_cookies.push_back(cookie);
612///             }
613///         }
614///     }
615///
616///     // finally return queued async results
617///     fn next_result(&mut self) -> Option<(RequestCookie, Result<Option<bson::Bson>, ErrorCode>)> {
618///         if let Some(cookie) = self.async_println_cookies.pop_front() {
619///             Some((cookie, Ok(None)))
620///         } else {
621///             None
622///         }
623///     }
624/// }
625///```
626pub trait ApiSet {
627    /// Returns the namespace of this `ApiSet`.
628    fn namespace(&self) -> &str;
629
630    /// Schedules the execution of the requested remote procedure call. Calls to this
631    /// function map directly to a received Honk-RPC request. Each request has the
632    /// following parameters:
633    /// - `name`: The name of the function to execute.
634    /// - `version`: The version of the function to execute.
635    /// - `args`: The arguments to pass to the function.
636    /// - `request_cookie`: An optional cookie to track the request.
637    ///
638    /// This function handles both synchronous and asynchronous requests. The possible
639    /// return values for each are:
640    /// - Synchronous requests may execute and signal success by returning `Some(Ok(..))`.
641    /// - Synchronous requests may execute and signal failure by returning `Some(Err(..))`.
642    /// - Asynchronous requests must defer execution by returning `None`.
643    fn exec_function(
644        &mut self,
645        name: &str,
646        version: i32,
647        args: bson::document::Document,
648        request_cookie: Option<RequestCookie>,
649    ) -> Option<Result<Option<bson::Bson>, ErrorCode>>;
650
651    /// Updates any internal state required to make forward progress on any requested
652    /// remote procedure calls. Implementation of this method is optional and not needed
653    /// if the implementor does not have any async functions. If left unimplemented, this
654    /// function is a no-op.
655    fn update(&mut self) {}
656
657    /// Returns the result of any in-flight asynchronous requests.
658    /// - Asynchronous requests may signal success by returning `Some((cookie, Ok(..)))`
659    /// - Asynchronous requests may signal failure by returning `Some((cookie, Err(..)))`
660    /// - returns None if no asynchronous results are available
661    ///
662    /// This method is optional and not needed if the implementor does not have any async
663    /// functions, in which case the default implementation will return `None`.
664    fn next_result(&mut self) -> Option<(RequestCookie, Result<Option<bson::Bson>, ErrorCode>)> {
665        None
666    }
667}
668
669/// Represents the response to a client request.
670pub enum Response {
671    /// A pending response, indicating that the request is still being processed.
672    Pending {
673        /// The cookie associated with the request.
674        cookie: RequestCookie,
675    },
676    /// A successful response, containing the result of the request.
677    Success {
678        /// The cookie associated with the request.
679        cookie: RequestCookie,
680        /// The result of the request.
681        result: Option<bson::Bson>,
682    },
683    /// An error response, containing the error code.
684    Error {
685        /// The cookie associated with the request.
686        cookie: RequestCookie,
687        /// The error code indicating the type of error that occurred.
688        error_code: ErrorCode,
689    },
690}
691
/// The default maximum allowed Honk-RPC message size: 4096 bytes (4 KiB, as set by
/// the specification).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4096;
/// The default maximum time to wait for an incoming Honk-RPC message (60 seconds).
pub const DEFAULT_MAX_WAIT_TIME: std::time::Duration = std::time::Duration::from_secs(60);
697
// Byte layout of the smallest Honk-RPC message, { "honk_rpc": <i32>, "sections": [null] }:

// document length prefix: sizeof(i32)
const HEADER_SIZE: usize = 4;
// `honk_rpc` element: type tag 0x10 + "honk_rpc" + NUL + sizeof(i32)
const HONK_RPC_SIZE: usize = 1 + 8 + 1 + 4;
// `sections` element: type tag 0x04 + "sections" + NUL, then the embedded array
// document: length prefix (4) + null element (tag 0x0a + key "0" + NUL) + terminator 0x00
const SECTIONS_SIZE: usize = 1 + 8 + 1 + 4 + 1 + 1 + 1 + 1;
// document terminator 0x00
const FOOTER_SIZE: usize = 1;

// The honk-rpc message overhead before the content of a single section is added
const MIN_MESSAGE_SIZE: usize = HEADER_SIZE + HONK_RPC_SIZE + SECTIONS_SIZE + FOOTER_SIZE;
710
711/// Computes the overhead of the Honk-RPC message type. This method in conjunction with
712/// the other `get_*_section_size(..)` functions can be used to compute the size of a
713/// Honk-RPC message with exactly one section.
714pub fn get_message_overhead() -> Result<usize, Error> {
715    // construct an example empty message; the size of a real message with
716    // one section can be calculated as the sizeof(message) + sizeof(section)
717    let message = doc! {
718        "honk_rpc" : HONK_RPC_VERSION,
719        "sections" : [
720            bson::Bson::Null
721        ]
722    };
723
724    let mut counter: ByteCounter = Default::default();
725    message
726        .to_writer(&mut counter)
727        .map_err(Error::BsonWriteFailed)?;
728
729    Ok(counter.bytes())
730}
731
732/// Computes the required size of a Honk-RPC error section in bytes.
733///
734/// Returns the size of the BSON-encoded error section. If BSON encoding fails,
735/// an `Error::BsonWriteFailed` is returned.
736pub fn get_error_section_size(
737    cookie: Option<RequestCookie>,
738    message: Option<String>,
739    data: Option<bson::Bson>,
740) -> Result<usize, Error> {
741    let mut error_section = doc! {
742        "id": ERROR_SECTION_ID,
743        "code": Into::<i32>::into(ErrorCode::Unknown(0)),
744    };
745
746    if let Some(cookie) = cookie {
747        error_section.insert("cookie", bson::Bson::Int64(cookie));
748    }
749
750    if let Some(message) = message {
751        error_section.insert("message", bson::Bson::String(message));
752    }
753
754    if let Some(data) = data {
755        error_section.insert("data", data);
756    }
757
758    let mut counter: ByteCounter = Default::default();
759    error_section
760        .to_writer(&mut counter)
761        .map_err(Error::BsonWriteFailed)?;
762
763    Ok(counter.bytes())
764}
765
766/// Computes the required size of a Honk-RPC requests section in bytes.
767///
768/// Returns the size of the BSON-encoded request section. If BSON encoding fails,
769/// an `Error::BsonWriteFailed` is returned.
770pub fn get_request_section_size(
771    cookie: Option<RequestCookie>,
772    namespace: Option<String>,
773    function: String,
774    version: Option<i32>,
775    arguments: Option<bson::Document>,
776) -> Result<usize, Error> {
777    let mut request_section = doc! {
778        "id": REQUEST_SECTION_ID,
779        "function": bson::Bson::String(function),
780    };
781
782    if let Some(cookie) = cookie {
783        request_section.insert("cookie", bson::Bson::Int64(cookie));
784    }
785
786    if let Some(namespace) = namespace {
787        request_section.insert("namespace", bson::Bson::String(namespace));
788    }
789
790    if let Some(version) = version {
791        request_section.insert("version", bson::Bson::Int32(version));
792    }
793
794    if let Some(arguments) = arguments {
795        request_section.insert("arguments", arguments);
796    }
797
798    let mut counter: ByteCounter = Default::default();
799    request_section
800        .to_writer(&mut counter)
801        .map_err(Error::BsonWriteFailed)?;
802
803    Ok(counter.bytes())
804}
805
806/// Computes the required size of a Honk-RPC response section in bytes.
807///
808/// Returns the size of the BSON-encoded response section. If BSON encoding fails,
809/// an `Error::BsonWriteFailed` is returned.
810pub fn get_response_section_size(result: Option<bson::Bson>) -> Result<usize, Error> {
811    let mut response_section = doc! {
812        "id": RESPONSE_SECTION_ID,
813        "cookie": bson::Bson::Int64(0),
814        "state": bson::Bson::Int32(0),
815    };
816
817    if let Some(result) = result {
818        response_section.insert("result", result);
819    }
820
821    let mut counter: ByteCounter = Default::default();
822    response_section
823        .to_writer(&mut counter)
824        .map_err(Error::BsonWriteFailed)?;
825
826    Ok(counter.bytes())
827}
828
/// The object that handles the communication between two endpoints using the
/// Honk-RPC protocol. Provides methods for setting and getting configuration
/// parameters, reading and processing message documents, and handling API
/// requests and responses.
///
/// `RW` is the bidirectional byte stream that messages are read from and written to.
pub struct Session<RW> {
    // read-write stream
    stream: RW,
    // we write outgoing data to an intermediate buffer to handle writer blocking
    message_write_buffer: VecDeque<u8>,

    // message read data

    // remaining number of bytes to read for current message
    // if None, no message read is in progress
    remaining_byte_count: Option<usize>,
    // data we've read but not yet a full Message object
    message_read_buffer: Vec<u8>,
    // received sections to be handled
    pending_sections: VecDeque<Section>,
    // remote client's inbound remote procedure calls to local server
    inbound_requests: Vec<RequestSection>,
    // remote server's responses to local client's remote procedure calls
    inbound_responses: VecDeque<Response>,

    // message write data

    // we serialize outgoing messages to this buffer first to verify size limitations
    message_serialization_buffer: VecDeque<u8>,
    // the next request cookie to use when making a remote procedure call
    next_cookie: RequestCookie,
    // sections to be sent to the remote server
    outbound_sections: Vec<bson::Document>,

    // the maximum size of a message we've agreed to allow in the session
    max_message_size: usize,
    // the maximum amount of time the session is willing to wait to receive a message
    // before terminating the session
    max_wait_time: std::time::Duration,
    // last time a new message read began
    read_timestamp: std::time::Instant,
}
870
871#[allow(dead_code)]
872impl<RW> Session<RW>
873where
874    RW: std::io::Read + std::io::Write + Send,
875{
876    /// Sets the maximum message size this `Session` is willing to read from from the underlying `RW`. Attempted reads  will abort if the next bson document's `i32` size field is greater than the `max_message_size` defined in this function.
877    pub fn set_max_message_size(&mut self, max_message_size: i32) -> Result<(), Error> {
878        if max_message_size < MIN_MESSAGE_SIZE as i32 {
879            // base size of a honk-rpc mssage
880            Err(Error::InvalidMaxMesageSize())
881        } else {
882            self.max_message_size = max_message_size as usize;
883            Ok(())
884        }
885    }
886
887    /// Gets the maximum allowed message size this `Session` is willing to read from the underlying `RW`. The default value is 4096 bytes.
888    pub fn get_max_message_size(&self) -> usize {
889        self.max_message_size
890    }
891
892    /// Sets the maximum amount of time this `Session` is willing to wait for a new Honk-RPC message on the underlying `RW`. `Session` updates will fil after `max_wait_time` has elapsed without receiving any new Honk-RPC message documents.
893    pub fn set_max_wait_time(&mut self, max_wait_time: std::time::Duration) {
894        self.max_wait_time = max_wait_time;
895    }
896
897    /// Gets the maximum amount this `Session` is willing to wait for a new Honk-RPC message. The default value is 60 seconds.
898    pub fn get_max_wait_time(&self) -> std::time::Duration {
899        self.max_wait_time
900    }
901
902    /// Creates a new `Session` using the given `stream`.
903    pub fn new(stream: RW) -> Self {
904        let mut message_write_buffer: VecDeque<u8> = Default::default();
905        message_write_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
906
907        let mut message_serialization_buffer: VecDeque<u8> = Default::default();
908        message_serialization_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
909
910        Session {
911            stream,
912            message_write_buffer,
913            remaining_byte_count: None,
914            message_read_buffer: Default::default(),
915            pending_sections: Default::default(),
916            inbound_requests: Default::default(),
917            inbound_responses: Default::default(),
918            message_serialization_buffer,
919            next_cookie: Default::default(),
920            outbound_sections: Default::default(),
921            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
922            max_wait_time: DEFAULT_MAX_WAIT_TIME,
923            read_timestamp: std::time::Instant::now(),
924        }
925    }
926
927    /// Consumes the `Session` and returns the underlying stream.
928    pub fn into_stream(self) -> RW {
929        self.stream
930    }
931
932    // read a block of bytes from the undelrying stream
933    fn stream_read(&mut self, buffer: &mut [u8]) -> Result<usize, Error> {
934        match self.stream.read(buffer) {
935            Err(err) => {
936                if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut {
937                    // abort if we've gone too long without a new message
938                    if std::time::Instant::now().duration_since(self.read_timestamp)
939                        > self.max_wait_time
940                    {
941                        Err(Error::MessageReadTimedOut(self.max_wait_time))
942                    } else {
943                        Ok(0)
944                    }
945                } else {
946                    Err(Error::ReaderReadFailed(err))
947                }
948            }
949            Ok(0) => Err(Error::ReaderReadFailed(std::io::Error::from(
950                ErrorKind::UnexpectedEof,
951            ))),
952            Ok(count) => {
953                // update read_timestamp
954                self.read_timestamp = std::time::Instant::now();
955                Ok(count)
956            }
957        }
958    }
959
960    // read the next block of bytes as a bson document size header
961    fn read_message_size(&mut self) -> Result<(), Error> {
962        match self.remaining_byte_count {
963            // we've already read the size header
964            Some(_remaining) => Ok(()),
965            // still need to read the size header
966            None => {
967                // may have been partially read already so ensure it's the right size
968                assert!(self.message_read_buffer.len() < std::mem::size_of::<i32>());
969                let bytes_needed = std::mem::size_of::<i32>() - self.message_read_buffer.len();
970                // ensure we have enough space for an entire int32
971                let mut buffer = [0u8; std::mem::size_of::<i32>()];
972                // but shrink view down to number of bytes remaining
973                let buffer = &mut buffer[0..bytes_needed];
974                match self.stream_read(buffer) {
975                    Err(err) => Err(err),
976                    Ok(0) => Ok(()),
977                    Ok(count) => {
978                        #[cfg(test)]
979                        println!("<<< read {} bytes for message header", count);
980                        self.message_read_buffer
981                            .extend_from_slice(&buffer[0..count]);
982
983                        // all bytes required for i32 message size have been read
984                        if self.message_read_buffer.len() == std::mem::size_of::<i32>() {
985                            let size = &self.message_read_buffer.as_slice();
986                            let size: i32 = (size[0] as i32)
987                                | (size[1] as i32) << 8
988                                | (size[2] as i32) << 16
989                                | (size[3] as i32) << 24;
990                            // size should be at least larger than the bytes required for size header
991                            if size <= std::mem::size_of::<i32>() as i32 {
992                                return Err(Error::BsonDocumentSizeTooSmall(size));
993                            }
994                            // convert to usize type now that we know it's not negative
995                            if size as usize > self.max_message_size {
996                                return Err(Error::BsonDocumentSizeTooLarge(
997                                    size,
998                                    self.max_message_size as i32,
999                                ));
1000                            }
1001
1002                            // deduct size of i32 header and save
1003                            let size = size as usize - std::mem::size_of::<i32>();
1004
1005                            self.remaining_byte_count = Some(size);
1006                        }
1007                        Ok(())
1008                    }
1009                }
1010            }
1011        }
1012    }
1013
1014    // read the remainder of a bson message
1015    fn read_message(&mut self) -> Result<Option<Message>, Error> {
1016        // update remaining bytes to read for message
1017        self.read_message_size()?;
1018        // read the message bytes
1019        if let Some(remaining) = self.remaining_byte_count {
1020            #[cfg(test)]
1021            println!("--- message requires {} more bytes", remaining);
1022
1023            let mut buffer = vec![0u8; remaining];
1024            match self.stream_read(&mut buffer) {
1025                Err(err) => Err(err),
1026                Ok(0) => Ok(None),
1027                Ok(count) => {
1028                    #[cfg(test)]
1029                    println!("<<< read {} bytes", count);
1030                    // append read bytes
1031                    self.message_read_buffer
1032                        .extend_from_slice(&buffer[0..count]);
1033                    if remaining == count {
1034                        self.remaining_byte_count = None;
1035
1036                        let mut cursor = Cursor::new(std::mem::take(&mut self.message_read_buffer));
1037                        let bson = bson::document::Document::from_reader(&mut cursor)
1038                            .map_err(Error::BsonDocumentParseFailed)?;
1039
1040                        // take back our allocated vec and clear it
1041                        self.message_read_buffer = cursor.into_inner();
1042                        self.message_read_buffer.clear();
1043
1044                        #[cfg(test)]
1045                        println!("<<< read message: {}", bson);
1046
1047                        Ok(Some(
1048                            Message::try_from(bson).map_err(Error::MessageConversionFailed)?,
1049                        ))
1050                    } else {
1051                        // update the remaining byte count
1052                        self.remaining_byte_count = Some(remaining - count);
1053                        Ok(None)
1054                    }
1055                }
1056            }
1057        } else {
1058            Ok(None)
1059        }
1060    }
1061
1062    // read and save of available sections
1063    fn read_sections(&mut self) -> Result<(), Error> {
1064        loop {
1065            match self.read_message() {
1066                Ok(Some(mut message)) => {
1067                    self.pending_sections.extend(message.sections.drain(..));
1068                }
1069                Ok(None) => return Ok(()),
1070                Err(err) => {
1071                    match err {
1072                        // in the event of timeouts and IO errors we finish any remaining work
1073                        Error::MessageReadTimedOut(_) | Error::ReaderReadFailed(_) => {
1074                            // ensure no pending items to handle
1075                            if self.pending_sections.is_empty() && self.inbound_responses.is_empty()
1076                            {
1077                                return Err(err);
1078                            }
1079                            return Ok(());
1080                        }
1081                        // all other errors we terminate
1082                        _ => return Err(err),
1083                    }
1084                }
1085            }
1086        }
1087    }
1088
1089    // route read sections to client and server buffers
1090    fn process_sections(&mut self) -> Result<(), Error> {
1091        while let Some(section) = self.pending_sections.pop_front() {
1092            match section {
1093                Section::Error(error) => {
1094                    if let Some(cookie) = error.cookie {
1095                        // error in response to a request
1096                        self.inbound_responses.push_back(Response::Error {
1097                            cookie,
1098                            error_code: error.code,
1099                        });
1100                    } else {
1101                        return Err(Error::UnknownErrorSectionReceived(error.code));
1102                    }
1103                }
1104                Section::Request(request) => {
1105                    // request to route to our apisets
1106                    self.inbound_requests.push(request);
1107                }
1108                Section::Response(response) => {
1109                    // response to our client
1110
1111                    match (response.cookie, response.state, response.result) {
1112                        (cookie, RequestState::Complete, result) => {
1113                            self.inbound_responses
1114                                .push_back(Response::Success { cookie, result });
1115                        }
1116                        (cookie, RequestState::Pending, _) => {
1117                            self.inbound_responses
1118                                .push_back(Response::Pending { cookie });
1119                        }
1120                    }
1121                }
1122            }
1123        }
1124
1125        Ok(())
1126    }
1127
1128    // queue outbound section for packaging into a Honk-RPC message
1129    fn push_outbound_section(&mut self, section: Section) -> Result<(), Error> {
1130        let max_section_size = self.max_message_size - MIN_MESSAGE_SIZE;
1131
1132        let mut counter: ByteCounter = Default::default();
1133        let section: bson::Document = section.into();
1134        section
1135            .to_writer(&mut counter)
1136            .map_err(Error::BsonWriteFailed)?;
1137        let section_size = counter.bytes();
1138
1139        if section_size <= max_section_size {
1140            self.outbound_sections.push(section);
1141            Ok(())
1142        } else {
1143            Err(Error::SectionTooLarge(section_size, max_section_size))
1144        }
1145    }
1146
1147    // package outbound sections into a message, and serialize message to the message_write_buffer
1148    fn serialize_messages(&mut self) -> Result<(), Error> {
1149        // if no pending sections there is nothing to do
1150        if self.outbound_sections.is_empty() {
1151            return Ok(());
1152        }
1153
1154        // build message and convert to bson to send
1155        let message = Message {
1156            honk_rpc: HONK_RPC_VERSION,
1157            sections: Default::default(),
1158        };
1159        let mut message = bson::document::Document::from(message);
1160        message.insert("sections", std::mem::take(&mut self.outbound_sections));
1161        self.serialize_messages_impl(message)
1162    }
1163
1164    // pack sections into messages and serialise them to buffer
1165    fn serialize_messages_impl(
1166        &mut self,
1167        mut message: bson::document::Document,
1168    ) -> Result<(), Error> {
1169        self.message_serialization_buffer.clear();
1170        message
1171            .to_writer(&mut self.message_serialization_buffer)
1172            .map_err(Error::BsonWriteFailed)?;
1173
1174        if self.message_serialization_buffer.len() > self.max_message_size {
1175            // if we can't split a message anymore then we have a problem
1176            let sections = message.get_array_mut("sections").unwrap();
1177            assert!(sections.len() > 1);
1178
1179            let right = doc! {
1180                "honk_rpc" : HONK_RPC_VERSION,
1181                "sections" : sections.split_off(sections.len() / 2),
1182            };
1183            let left = message;
1184
1185            self.serialize_messages_impl(left)?;
1186            self.serialize_messages_impl(right)?;
1187        } else {
1188            #[cfg(test)]
1189            println!(">>> write message: {:?}", message);
1190            // copy the serialized message into the pending write buffer
1191            self.message_write_buffer
1192                .append(&mut self.message_serialization_buffer);
1193        }
1194
1195        Ok(())
1196    }
1197
1198    // write data to stream and remove from write buffer
1199    fn write_pending_data(&mut self) -> Result<(), Error> {
1200        let bytes_written = self.write_pending_data_impl()?;
1201        self.stream.flush().map_err(Error::WriterWriteFailed)?;
1202        // removes the written bytes
1203        self.message_write_buffer.drain(0..bytes_written);
1204        // and shuffles the data so it is contiguous
1205        self.message_write_buffer.make_contiguous();
1206
1207        Ok(())
1208    }
1209
1210    fn write_pending_data_impl(&mut self) -> Result<usize, Error> {
1211        // write pending data
1212        let (mut pending_data, empty): (&[u8], &[u8]) = self.message_write_buffer.as_slices();
1213        assert!(empty.is_empty());
1214        let pending_bytes: usize = pending_data.len();
1215        let mut bytes_written: usize = 0usize;
1216
1217        while bytes_written != pending_bytes {
1218            match self.stream.write(pending_data) {
1219                Err(err) => {
1220                    let kind = err.kind();
1221                    if kind == ErrorKind::WouldBlock || kind == ErrorKind::TimedOut {
1222                        // no *additional* bytes written so return bytes written so far
1223                        return Ok(bytes_written);
1224                    } else {
1225                        return Err(Error::WriterWriteFailed(err));
1226                    }
1227                }
1228                Ok(count) => {
1229                    bytes_written += count;
1230                    #[cfg(test)]
1231                    println!(">>> sent {} of {} bytes", bytes_written, pending_bytes);
1232                    pending_data = &pending_data[count..];
1233                }
1234            }
1235        }
1236
1237        Ok(bytes_written)
1238    }
1239
1240    /// Read and process Honk-RPC message documents from connected peer, handle any new incoming Honk-RPC requests, update any in-progress async requests and write pending reponses, errors and requests to peer. This function must be called regularly for the `Session` to make forward progress.
1241    pub fn update(&mut self, apisets: Option<&mut [&mut dyn ApiSet]>) -> Result<(), Error> {
1242        // read sections from remote
1243        self.read_sections()?;
1244        // route sections to buffers
1245        self.process_sections()?;
1246
1247        // handle incoming api calls
1248        let apisets = apisets.unwrap_or(&mut []);
1249        self.handle_requests(apisets)?;
1250
1251        // serialize pending responses
1252        self.serialize_messages()?;
1253
1254        // write pendng data to writer
1255        self.write_pending_data()?;
1256
1257        Ok(())
1258    }
1259
1260    // apisets : a slice of mutable ApiSet references sorted by their namespaces
1261    fn handle_requests(&mut self, apisets: &mut [&mut dyn ApiSet]) -> Result<(), Error> {
1262        // first handle all of our inbound requests
1263        let mut inbound_requests = std::mem::take(&mut self.inbound_requests);
1264        for mut request in inbound_requests.drain(..) {
1265            if let Ok(idx) =
1266                apisets.binary_search_by(|probe| probe.namespace().cmp(&request.namespace))
1267            {
1268                let apiset = match apisets.get_mut(idx) {
1269                    Some(apiset) => apiset,
1270                    None => unreachable!(),
1271                };
1272                match apiset.exec_function(
1273                    &request.function,
1274                    request.version,
1275                    std::mem::take(&mut request.arguments),
1276                    request.cookie,
1277                ) {
1278                    // func found, invoked and succeeded
1279                    Some(Ok(result)) => {
1280                        if let Some(cookie) = request.cookie {
1281                            self.push_outbound_section(Section::Response(ResponseSection {
1282                                cookie,
1283                                state: RequestState::Complete,
1284                                result,
1285                            }))?;
1286                        }
1287                    }
1288                    // func found, invoked and failed
1289                    Some(Err(error_code)) => {
1290                        self.push_outbound_section(Section::Error(ErrorSection {
1291                            cookie: request.cookie,
1292                            code: error_code,
1293                            message: None,
1294                            data: None,
1295                        }))?;
1296                    }
1297                    // func found, called, and result is pending
1298                    None => {
1299                        if let Some(cookie) = request.cookie {
1300                            self.push_outbound_section(Section::Response(ResponseSection {
1301                                cookie,
1302                                state: RequestState::Pending,
1303                                result: None,
1304                            }))?;
1305                        }
1306                    }
1307                }
1308            } else {
1309                // invalid namespace
1310                self.push_outbound_section(Section::Error(ErrorSection {
1311                    cookie: request.cookie,
1312                    code: ErrorCode::RequestNamespaceInvalid,
1313                    message: None,
1314                    data: None,
1315                }))?;
1316            }
1317        }
1318
1319        // next send out async responses from apisets
1320        for apiset in apisets.iter_mut() {
1321            // allow apiset to do any required repetitive work
1322            apiset.update();
1323            // put pending results in our message
1324            while let Some((cookie, result)) = apiset.next_result() {
1325                match (cookie, result) {
1326                    // function completed successfully
1327                    (cookie, Ok(result)) => {
1328                        self.push_outbound_section(Section::Response(ResponseSection {
1329                            cookie,
1330                            state: RequestState::Complete,
1331                            result,
1332                        }))?;
1333                    }
1334                    // function completed with failure
1335                    (cookie, Err(error_code)) => {
1336                        self.push_outbound_section(Section::Error(ErrorSection {
1337                            cookie: Some(cookie),
1338                            code: error_code,
1339                            message: None,
1340                            data: None,
1341                        }))?;
1342                    }
1343                }
1344            }
1345        }
1346        Ok(())
1347    }
1348
1349    /// Performs a client call to a remote function. Returns a `RequestCookie` to associate this client call with a future `Response`.
1350    pub fn client_call(
1351        &mut self,
1352        namespace: &str,
1353        function: &str,
1354        version: i32,
1355        arguments: bson::document::Document,
1356    ) -> Result<RequestCookie, Error> {
1357        // always make sure we have a new cookie
1358        let cookie = self.next_cookie;
1359        self.next_cookie += 1;
1360
1361        // add request to outgoing buffer
1362        self.push_outbound_section(Section::Request(RequestSection {
1363            cookie: Some(cookie),
1364            namespace: namespace.to_string(),
1365            function: function.to_string(),
1366            version,
1367            arguments,
1368        }))?;
1369
1370        Ok(cookie)
1371    }
1372
1373    /// Drains all `Response` objects resulting from prevoius invocations of `Session::client_call()`
1374    pub fn client_drain_responses(&mut self) -> std::collections::vec_deque::Drain<Response> {
1375        self.inbound_responses.drain(..)
1376    }
1377
1378    /// Retrieves the next `Response` object from previous invocations of `Session::client_call()`
1379    pub fn client_next_response(&mut self) -> Option<Response> {
1380        self.inbound_responses.pop_front()
1381    }
1382}
1383
#[test]
fn test_honk_client_read_write() -> anyhow::Result<()> {
    let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0u16));
    let listener = TcpListener::bind(socket_addr)?;
    let socket_addr = listener.local_addr()?;

    let stream1 = TcpStream::connect(socket_addr)?;
    stream1.set_nonblocking(true)?;
    let (stream2, _socket_addr) = listener.accept()?;
    stream2.set_nonblocking(true)?;

    let mut alice = Session::new(stream1);
    let mut pat = Session::new(stream2);

    println!("--- pat reads message, but none has been sent");

    // nothing sent yet, so a non-blocking read yields no message
    assert!(pat.read_message()?.is_none());

    println!("--- alice sends no message, but no pending sections so no message sent");

    // with nothing queued, serializing and writing must not put anything on the wire
    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads message, but none has been sent");

    // ensure no message was actually sent
    match pat.read_message() {
        Ok(Some(msg)) => panic!(
            "message should not have been sent: {}",
            bson::document::Document::from(msg)
        ),
        Ok(None) => {}
        Err(err) => panic!("{:?}", err),
    }

    println!("--- pat sends an error message");

    const CUSTOM_ERROR: &str = "Custom Error!";

    pat.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(1),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;

    pat.serialize_messages()?;
    pat.write_pending_data()?;

    println!("--- alice reads and verifies message");

    // wait for alice to receive message
    let mut alice_read_message: bool = false;
    while !alice_read_message {
        if let Some(mut msg) = alice.read_message()? {
            assert_eq!(msg.sections.len(), 1);
            match msg.sections.pop() {
                Some(Section::Error(section)) => {
                    match (section.cookie, section.code, section.message) {
                        (Some(42069), ErrorCode::Runtime(1), Some(message)) => {
                            assert_eq!(message, CUSTOM_ERROR);
                            alice_read_message = true;
                        }
                        (cookie, code, message) => panic!(
                            "unexpected error section: cookie: {:?}, code: {:?}, message: {:?}",
                            cookie, code, message
                        ),
                    };
                }
                Some(_) => panic!("was expecting an Error section"),
                None => panic!("we should have a message"),
            }
        }
    }

    println!("--- alice sends multi-section message");

    alice.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(2),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;
    alice.push_outbound_section(Section::Request(RequestSection {
        cookie: None,
        namespace: "std".to_string(),
        function: "print".to_string(),
        version: 0,
        arguments: doc! {"message": "hello!"},
    }))?;
    alice.push_outbound_section(Section::Response(ResponseSection {
        cookie: 123456,
        state: RequestState::Pending,
        result: None,
    }))?;

    // send a multi-section message
    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads and verifies multi-section message");

    // read sections sent to pat
    let mut pat_read_message: bool = false;
    while !pat_read_message {
        if let Some(msg) = pat.read_message()? {
            assert_eq!(msg.sections.len(), 3);
            // each of the three queued sections must arrive exactly once
            let (mut error_count, mut request_count, mut response_count) =
                (0usize, 0usize, 0usize);
            for section in msg.sections.iter() {
                match section {
                    Section::Error(section) => {
                        error_count += 1;
                        assert_eq!(section.cookie, Some(42069));
                        assert_eq!(section.code, ErrorCode::Runtime(2));
                        assert_eq!(section.message, Some(CUSTOM_ERROR.to_string()));
                        assert_eq!(section.data, None);
                    }
                    Section::Request(section) => {
                        request_count += 1;
                        assert_eq!(section.cookie, None);
                        assert_eq!(section.namespace, "std");
                        assert_eq!(section.function, "print");
                        assert_eq!(section.version, 0i32);
                        assert_eq!(section.arguments, doc! {"message": "hello!"});
                    }
                    Section::Response(section) => {
                        response_count += 1;
                        assert_eq!(section.cookie, 123456);
                        assert_eq!(section.state, RequestState::Pending);
                        assert_eq!(section.result, None);
                    }
                }
            }
            assert_eq!((error_count, request_count, response_count), (1, 1, 1));
            pat_read_message = true;
        }
    }

    Ok(())
}
1521
/// Minimal `ApiSet` used by the tests below.
#[cfg(test)]
struct TestApiSet {
    /// How many times `function` version 0 has been executed.
    call_count: usize,
}
1526
#[cfg(test)]
impl ApiSet for TestApiSet {
    fn namespace(&self) -> &str {
        "namespace"
    }

    // Every call succeeds immediately with no result value. Only `function` version 0
    // is counted, so tests can observe when a remote call actually arrived.
    fn exec_function(
        &mut self,
        name: &str,
        version: i32,
        _args: bson::document::Document,
        _request_cookie: Option<RequestCookie>,
    ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
        if name == "function" && version == 0 {
            println!("--- namespace::function_0() called");
            self.call_count += 1;
        }
        Some(Ok(None))
    }
}
1550
#[test]
fn test_honk_timeout() -> anyhow::Result<()> {
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0u16)))?;
    let listen_addr = listener.local_addr()?;

    let alice_stream = TcpStream::connect(listen_addr)?;
    alice_stream.set_nonblocking(true)?;
    alice_stream.set_nodelay(true)?;
    println!("--- alice peer_addr: {}", alice_stream.peer_addr()?);
    let (pat_stream, _) = listener.accept()?;
    pat_stream.set_nonblocking(true)?;
    pat_stream.set_nodelay(true)?;

    let mut alice = Session::new(alice_stream);
    let mut alice_apiset = TestApiSet { call_count: 0usize };
    let mut pat = Session::new(pat_stream);

    let start = std::time::Instant::now();
    // time since the test began, used as a prefix for progress logs
    let elapsed = || std::time::Instant::now().duration_since(start);

    println!("--- {:?} alice set max_wait_time to 3 seconds", elapsed());
    alice.update(None)?;
    alice.set_max_wait_time(std::time::Duration::from_secs(3));
    alice.update(None)?;

    // stay below alice's 3 second limit; the upcoming call gives her data to read
    println!("--- {:?} sleep 2 seconds", elapsed());
    std::thread::sleep(std::time::Duration::from_secs(2));

    println!("--- {:?} pat calls namespace::function_0()", elapsed());
    pat.client_call("namespace", "function", 0, doc! {})?;
    while alice_apiset.call_count != 1 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // alice just read the call, so her wait clock restarted; 2 seconds is still fine
    println!("--- {:?} sleep 2 seconds", elapsed());
    std::thread::sleep(std::time::Duration::from_secs(2));
    pat.update(None)?;
    alice.update(None)?;

    println!("--- {:?} pat calls namespace::function_0()", elapsed());
    pat.client_call("namespace", "function", 0, doc! {})?;
    while alice_apiset.call_count != 2 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // no reads occur during this sleep, so alice should time out
    println!("--- {:?} sleep 4 seconds", elapsed());
    std::thread::sleep(std::time::Duration::from_secs(4));

    println!("--- {:?} pat+alice update", elapsed());
    pat.update(None)?;
    match alice.update(None) {
        Ok(()) => panic!("should have timed out"),
        Err(Error::MessageReadTimedOut(duration)) => {
            println!("--- expected time out after {:?}", duration)
        }
        Err(err) => panic!("unexpected error: {:?}", err),
    }
    Ok(())
}