honk_rpc/
honk_rpc.rs

1// standard
2use std::collections::VecDeque;
3use std::fmt::Debug;
4use std::io::{Cursor, ErrorKind};
5#[cfg(test)]
6use std::net::{SocketAddr, TcpListener, TcpStream};
7use std::option::Option;
8
9// extern crates
10use bson::doc;
11use bson::document::ValueAccessError;
12
13use crate::byte_counter::ByteCounter;
14
15/// Represents various error codes that can be present in a Honk-RPC `error_section`
16#[derive(Debug, Eq, PartialEq)]
17pub enum ErrorCode {
18    /// Failure to parse a received BSON document.
19    BsonParseFailed,
20    /// Received message document was too big; the default maximum message size
21    /// is 4096 bytes, but can be adjusted.
22    MessageTooBig,
23    /// Received message document missing required fields.
24    MessageParseFailed,
25    /// Received message contained version the receiver cannot handle.
26    MessageVersionIncompatible,
27    /// Section in received message contains unknown id.
28    SectionIdUnknown,
29    /// Section in received message missing required field, or provided
30    /// field is wrong datatype.
31    SectionParseFailed,
32    /// Provided request cookie is already in use.
33    RequestCookieInvalid,
34    /// Provided request namespace does not exist.
35    RequestNamespaceInvalid,
36    /// Provided request function does not exist within the provided namespace.
37    RequestFunctionInvalid,
38    /// Provided request version does not exist.
39    RequestVersionInvalid,
40    /// Provided response cookie is not recognized.
41    ResponseCookieInvalid,
42    /// Provided response state is not valid.
43    ResponseStateInvalid,
44    /// Represents an application-specific runtime error with a specific error code.
45    Runtime(i32),
46    /// Represents an unknown error with a specific error code.
47    Unknown(i32),
48}
49
50/// The error type for the `Session` type.
51#[derive(thiserror::Error, Debug)]
52pub enum Error {
53    /// Failed to read data from read stream due to `std::io::Error`
54    #[error("failed to read data from read stream")]
55    ReaderReadFailed(#[source] std::io::Error),
56
57    /// Bson documents need to be at least 4 bytes long
58    #[error("received invalid bson document size header value of {0}, must be at least 4")]
59    BsonDocumentSizeTooSmall(i32),
60
61    /// Received Bson document header is larger than Session supports
62    #[error("received invalid bson document size header value of {0}, must be less than {1}")]
63    BsonDocumentSizeTooLarge(i32, i32),
64
65    /// Too much time has elapsed without receiving a message
66    #[error("waited longer than {} seconds for read", .0.as_secs_f32())]
67    MessageReadTimedOut(std::time::Duration),
68
69    /// Failed to parse bson message
70    #[error("failed to parse bson Message document")]
71    BsonDocumentParseFailed(#[source] bson::de::Error),
72
73    /// Failed to convert bson document to Honk-RPC message
74    #[error("failed to convert bson document to Message")]
75    MessageConversionFailed(#[source] crate::honk_rpc::ErrorCode),
76
77    /// Failed to serialise bson document
78    #[error("failed to serialize bson document")]
79    BsonWriteFailed(#[source] bson::ser::Error),
80
81    /// Failed to write data to write stream due to `std::io::Error`
82    #[error("failed to write data to write stream")]
83    WriterWriteFailed(#[source] std::io::Error),
84
85    /// Failed to flush data to write stream due to `std::io::Error`
86    #[error("failed to flush message to write stream")]
87    WriterFlushFailed(#[source] std::io::Error),
88
89    /// Received a Honk-RPC `error_section` without an associated request cookie
90    #[error("recieved error section without cookie")]
91    UnknownErrorSectionReceived(#[source] crate::honk_rpc::ErrorCode),
92
93    /// Attempted to define invalid maximum message size
94    #[error(
95        "tried to set invalid max message size; must be >=5 bytes and <= i32::MAX (2147483647)"
96    )]
97    InvalidMaxMesageSize(),
98
99    /// Attempted to send a Honk-RPC `section` that is too large to fit in a message
100    #[error("queued message section is too large to write; calculated size is {0} but must be less than {1}")]
101    SectionTooLarge(usize, usize),
102}
103
104impl From<i32> for ErrorCode {
105    fn from(value: i32) -> ErrorCode {
106        match value {
107            -1i32 => ErrorCode::BsonParseFailed,
108            -2i32 => ErrorCode::MessageTooBig,
109            -3i32 => ErrorCode::MessageParseFailed,
110            -4i32 => ErrorCode::MessageVersionIncompatible,
111            -5i32 => ErrorCode::SectionIdUnknown,
112            -6i32 => ErrorCode::SectionParseFailed,
113            -7i32 => ErrorCode::RequestCookieInvalid,
114            -8i32 => ErrorCode::RequestNamespaceInvalid,
115            -9i32 => ErrorCode::RequestFunctionInvalid,
116            -10i32 => ErrorCode::RequestVersionInvalid,
117            -11i32 => ErrorCode::ResponseCookieInvalid,
118            -12i32 => ErrorCode::ResponseStateInvalid,
119            value => {
120                if value > 0 {
121                    ErrorCode::Runtime(value)
122                } else {
123                    ErrorCode::Unknown(value)
124                }
125            }
126        }
127    }
128}
129
130impl From<ErrorCode> for i32 {
131    fn from(err: ErrorCode) -> Self {
132        match err {
133            ErrorCode::BsonParseFailed => -1i32,
134            ErrorCode::MessageTooBig => -2i32,
135            ErrorCode::MessageParseFailed => -3i32,
136            ErrorCode::MessageVersionIncompatible => -4i32,
137            ErrorCode::SectionIdUnknown => -5i32,
138            ErrorCode::SectionParseFailed => -6i32,
139            ErrorCode::RequestCookieInvalid => -7i32,
140            ErrorCode::RequestNamespaceInvalid => -8i32,
141            ErrorCode::RequestFunctionInvalid => -9i32,
142            ErrorCode::RequestVersionInvalid => -10i32,
143            ErrorCode::ResponseCookieInvalid => -11i32,
144            ErrorCode::ResponseStateInvalid => -12i32,
145            ErrorCode::Runtime(val) => val,
146            ErrorCode::Unknown(val) => val,
147        }
148    }
149}
150
151impl std::fmt::Display for ErrorCode {
152    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
153        match self {
154            ErrorCode::BsonParseFailed => write!(f, "ProtocolError: failed to parse BSON object"),
155            ErrorCode::MessageTooBig => write!(f, "ProtocolError: received document too large"),
156            ErrorCode::MessageParseFailed => {
157                write!(f, "ProtocolError: received message has invalid schema")
158            }
159            ErrorCode::MessageVersionIncompatible => write!(
160                f,
161                "ProtocolError: received message has incompatible version"
162            ),
163            ErrorCode::SectionIdUnknown => write!(
164                f,
165                "ProtocolError: received message contains section of unknown type"
166            ),
167            ErrorCode::SectionParseFailed => write!(
168                f,
169                "ProtocolError: received message contains section with invalid schema"
170            ),
171            ErrorCode::RequestCookieInvalid => {
172                write!(f, "ProtocolError: request cookie already in use")
173            }
174            ErrorCode::RequestNamespaceInvalid => write!(
175                f,
176                "ProtocolError: request function does not exist in requested namespace"
177            ),
178            ErrorCode::RequestFunctionInvalid => {
179                write!(f, "ProtocolError: request function does not exist")
180            }
181            ErrorCode::RequestVersionInvalid => {
182                write!(f, "ProtocolError: request function version does not exist")
183            }
184            ErrorCode::ResponseCookieInvalid => {
185                write!(f, "ProtocolError: response cookie is not recognized")
186            }
187            ErrorCode::ResponseStateInvalid => write!(f, "ProtocolError: response state not valid"),
188            ErrorCode::Runtime(code) => write!(f, "RuntimeError: runtime error {}", code),
189            ErrorCode::Unknown(code) => write!(f, "UnknownError: unknown error code {}", code),
190        }
191    }
192}
193
194impl std::error::Error for ErrorCode {}
195
196// Honk-RPC semver is packed into an i32
197const fn semver_to_i32(major: u8, minor: u8, patch: u8) -> i32 {
198    let major = major as i32;
199    let minor = minor as i32;
200    let patch = patch as i32;
201    (major << 16) | (minor << 8) | patch
202}
203
204const fn i32_to_semver(ver: i32) -> Option<(u8, u8, u8)> {
205    if ver >= 0 && ver <= 0xffffff {
206        let major = (ver & 0xff0000) >> 16;
207        let minor = (ver & 0xff00) >> 8;
208        let patch = ver & 0xff;
209        Some((major as u8, minor as u8, patch as u8))
210    } else {
211        None
212    }
213}
214
215// Honk-RPC version 0.1.0
216const HONK_RPC_VERSION: i32 = semver_to_i32(0, 1, 0);
217
218struct Message {
219    honk_rpc: i32,
220    sections: Vec<Section>,
221}
222
223impl TryFrom<bson::document::Document> for Message {
224    type Error = ErrorCode;
225
226    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
227        let mut value = value;
228
229        // verify version
230        let honk_rpc = match value.get_i32("honk_rpc") {
231            Ok(HONK_RPC_VERSION) => HONK_RPC_VERSION,
232            Ok(honk_rpc) => {
233                return if let Some(_version) = i32_to_semver(honk_rpc) {
234                    // some other semver we cannot handle
235                    Err(ErrorCode::MessageVersionIncompatible)
236                } else {
237                    // an invalid semver
238                    Err(ErrorCode::MessageParseFailed)
239                };
240            }
241            Err(_err) => return Err(ErrorCode::MessageParseFailed),
242        };
243
244        if let Ok(sections) = value.get_array_mut("sections") {
245            // messages must have at least one section
246            if sections.is_empty() {
247                return Err(ErrorCode::MessageParseFailed);
248            }
249
250            let mut message = Message {
251                honk_rpc,
252                sections: Default::default(),
253            };
254
255            for section in sections.iter_mut() {
256                if let bson::Bson::Document(section) = std::mem::take(section) {
257                    message.sections.push(Section::try_from(section)?);
258                } else {
259                    return Err(ErrorCode::SectionParseFailed);
260                }
261            }
262            Ok(message)
263        } else {
264            Err(ErrorCode::MessageParseFailed)
265        }
266    }
267}
268
269impl From<Message> for bson::document::Document {
270    fn from(value: Message) -> bson::document::Document {
271        let mut value = value;
272        let mut message = bson::document::Document::new();
273        message.insert("honk_rpc", value.honk_rpc);
274
275        let mut sections = bson::Array::new();
276        for section in value.sections.drain(0..) {
277            sections.push(bson::Bson::Document(bson::document::Document::from(
278                section,
279            )));
280        }
281        message.insert("sections", sections);
282
283        message
284    }
285}
286
287/// A type alias for the cookie used to track client requests.
288pub type RequestCookie = i64;
289
290const ERROR_SECTION_ID: i32 = 0i32;
291const REQUEST_SECTION_ID: i32 = 1i32;
292const RESPONSE_SECTION_ID: i32 = 2i32;
293
294enum Section {
295    Error(ErrorSection),
296    Request(RequestSection),
297    Response(ResponseSection),
298}
299
300struct ErrorSection {
301    cookie: Option<RequestCookie>,
302    code: ErrorCode,
303    message: Option<String>,
304    data: Option<bson::Bson>,
305}
306
307struct RequestSection {
308    cookie: Option<RequestCookie>,
309    namespace: String,
310    function: String,
311    version: i32,
312    arguments: bson::document::Document,
313}
314
315#[repr(i32)]
316#[derive(Debug, PartialEq)]
317enum RequestState {
318    Pending = 0i32,
319    Complete = 1i32,
320}
321
322struct ResponseSection {
323    cookie: RequestCookie,
324    state: RequestState,
325    result: Option<bson::Bson>,
326}
327
328impl TryFrom<bson::document::Document> for Section {
329    type Error = ErrorCode;
330
331    fn try_from(
332        value: bson::document::Document,
333    ) -> Result<Self, <Self as TryFrom<bson::document::Document>>::Error> {
334        match value.get_i32("id") {
335            Ok(ERROR_SECTION_ID) => Ok(Section::Error(ErrorSection::try_from(value)?)),
336            Ok(REQUEST_SECTION_ID) => Ok(Section::Request(RequestSection::try_from(value)?)),
337            Ok(RESPONSE_SECTION_ID) => Ok(Section::Response(ResponseSection::try_from(value)?)),
338            Ok(_) => Err(ErrorCode::SectionIdUnknown),
339            Err(_) => Err(ErrorCode::SectionParseFailed),
340        }
341    }
342}
343
344impl From<Section> for bson::document::Document {
345    fn from(value: Section) -> bson::document::Document {
346        match value {
347            Section::Error(section) => bson::document::Document::from(section),
348            Section::Request(section) => bson::document::Document::from(section),
349            Section::Response(section) => bson::document::Document::from(section),
350        }
351    }
352}
353
354impl TryFrom<bson::document::Document> for ErrorSection {
355    type Error = ErrorCode;
356
357    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
358        let mut value = value;
359
360        let cookie = match value.get_i64("cookie") {
361            Ok(cookie) => Some(cookie),
362            Err(ValueAccessError::NotPresent) => None,
363            Err(_) => return Err(ErrorCode::SectionParseFailed),
364        };
365
366        let code = match value.get_i32("code") {
367            Ok(code) => ErrorCode::from(code),
368            Err(_) => return Err(ErrorCode::SectionParseFailed),
369        };
370
371        let message = match value.get_str("message") {
372            Ok(message) => Some(message.to_string()),
373            Err(ValueAccessError::NotPresent) => None,
374            Err(_) => return Err(ErrorCode::SectionParseFailed),
375        };
376
377        let data = value.get_mut("data").map(std::mem::take);
378
379        Ok(ErrorSection {
380            cookie,
381            code,
382            message,
383            data,
384        })
385    }
386}
387
388impl From<ErrorSection> for bson::document::Document {
389    fn from(value: ErrorSection) -> bson::document::Document {
390        let mut error_section = bson::document::Document::new();
391        error_section.insert("id", ERROR_SECTION_ID);
392
393        if let Some(cookie) = value.cookie {
394            error_section.insert("cookie", cookie);
395        }
396
397        error_section.insert("code", Into::<i32>::into(value.code));
398
399        if let Some(message) = value.message {
400            error_section.insert("message", message);
401        }
402
403        if let Some(data) = value.data {
404            error_section.insert("data", data);
405        }
406
407        error_section
408    }
409}
410
411impl TryFrom<bson::document::Document> for RequestSection {
412    type Error = ErrorCode;
413
414    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
415        let mut value = value;
416
417        let cookie = match value.get_i64("cookie") {
418            Ok(cookie) => Some(cookie),
419            Err(ValueAccessError::NotPresent) => None,
420            Err(_) => return Err(ErrorCode::SectionParseFailed),
421        };
422
423        let namespace = match value.get_str("namespace") {
424            Ok(namespace) => namespace.to_string(),
425            Err(ValueAccessError::NotPresent) => String::default(),
426            Err(_) => return Err(ErrorCode::SectionParseFailed),
427        };
428
429        let function = match value.get_str("function") {
430            Ok(function) => {
431                if function.is_empty() {
432                    return Err(ErrorCode::RequestFunctionInvalid);
433                } else {
434                    function.to_string()
435                }
436            }
437            Err(_) => return Err(ErrorCode::SectionParseFailed),
438        };
439
440        let version = match value.get_i32("version") {
441            Ok(version) => version,
442            Err(ValueAccessError::NotPresent) => 0i32,
443            Err(_) => return Err(ErrorCode::SectionParseFailed),
444        };
445
446        let arguments = match value.get_document_mut("arguments") {
447            Ok(arguments) => std::mem::take(arguments),
448            Err(ValueAccessError::NotPresent) => bson::document::Document::new(),
449            Err(_) => return Err(ErrorCode::SectionParseFailed),
450        };
451
452        Ok(RequestSection {
453            cookie,
454            namespace,
455            function,
456            version,
457            arguments,
458        })
459    }
460}
461
462impl From<RequestSection> for bson::document::Document {
463    fn from(value: RequestSection) -> bson::document::Document {
464        let mut request_section = bson::document::Document::new();
465        request_section.insert("id", REQUEST_SECTION_ID);
466
467        if let Some(cookie) = value.cookie {
468            request_section.insert("cookie", cookie);
469        }
470
471        if !value.namespace.is_empty() {
472            request_section.insert("namespace", value.namespace);
473        }
474
475        request_section.insert("function", value.function);
476
477        if value.version != 0i32 {
478            request_section.insert("version", value.version);
479        }
480
481        request_section.insert("arguments", value.arguments);
482
483        request_section
484    }
485}
486
487impl TryFrom<bson::document::Document> for ResponseSection {
488    type Error = ErrorCode;
489
490    fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
491        let mut value = value;
492        let cookie = match value.get_i64("cookie") {
493            Ok(cookie) => cookie,
494            Err(_) => return Err(ErrorCode::SectionParseFailed),
495        };
496
497        let state = match value.get_i32("state") {
498            Ok(0i32) => RequestState::Pending,
499            Ok(1i32) => RequestState::Complete,
500            Ok(_) => return Err(ErrorCode::ResponseStateInvalid),
501            Err(_) => return Err(ErrorCode::SectionParseFailed),
502        };
503
504        let result = value.get_mut("result").map(std::mem::take);
505
506        // if pending there should be no result
507        if state == RequestState::Pending && result.is_some() {
508            return Err(ErrorCode::SectionParseFailed);
509        }
510
511        Ok(ResponseSection {
512            cookie,
513            state,
514            result,
515        })
516    }
517}
518
519impl From<ResponseSection> for bson::document::Document {
520    fn from(value: ResponseSection) -> bson::document::Document {
521        let mut response_section = bson::document::Document::new();
522        response_section.insert("id", RESPONSE_SECTION_ID);
523
524        response_section.insert("cookie", value.cookie);
525        response_section.insert("state", value.state as i32);
526
527        if let Some(result) = value.result {
528            response_section.insert("result", result);
529        }
530
531        response_section
532    }
533}
534
535/// The `ApiSet` trait represents a set of APIs that can be remotely invoked by a connecting Honk-RPC client.
536/// # Example
537/// This exampe `ApiSet` implements two methods, `example::println()` and `example::async_println()`. The
538/// `println()` method immediatley prints, whereas `async_println()` queues request and
539/// prints the messagge at a later date via `update()`
540///
541/// ```rust
542/// # use honk_rpc::honk_rpc::*;
543/// # use std::collections::VecDeque;
544///
545/// const RUNTIME_ERROR_INVALID_ARG: ErrorCode = ErrorCode::Runtime(1i32);
546///
547/// struct PrintlnApiSet {
548///     // queued print reuests
549///     async_println_work: Vec<(Option<RequestCookie>, String)>,
550///     // successful async requests
551///     async_println_cookies: VecDeque<RequestCookie>,
552/// }
553///
554/// impl PrintlnApiSet {
555///   // prints message immediately
556///   fn println_0(
557///       &mut self,
558///       mut args: bson::document::Document,
559///   ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
560///     if let Some(bson::Bson::String(val)) = args.get_mut("val") {
561///         println!("example::echo_0(val): '{}'", val);
562///         Some(Ok(Some(bson::Bson::String(std::mem::take(val)))))
563///     } else {
564///         Some(Err(RUNTIME_ERROR_INVALID_ARG))
565///     }
566///   }
567///
568///   // queues message up for printing later
569///   fn async_println_0(
570///       &mut self,
571///       request_cookie: Option<RequestCookie>,
572///       mut args: bson::document::Document,
573///   ) -> Option<Result<Option<bson::Bson>, ErrorCode>>{
574///     if let Some(bson::Bson::String(val)) = args.get_mut("val") {
575///         self.async_println_work.push((request_cookie, std::mem::take(val)));
576///         None
577///     } else {
578///         Some(Err(RUNTIME_ERROR_INVALID_ARG))
579///     }
580///   }
581/// }
582///
583/// impl ApiSet for PrintlnApiSet {
584///     fn namespace(&self) -> &str {
585///         "example"
586///     }
587///
588///     // handles and routes requests for `println` and `async_println`
589///     fn exec_function(
590///         &mut self,
591///         name: &str,
592///         version: i32,
593///         args: bson::document::Document,
594///         request_cookie: Option<RequestCookie>,
595///     ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
596///         match (name, version) {
597///             ("println", 0) => self.println_0(args),
598///             ("async_println", 0) => self.async_println_0(request_cookie, args),
599///             (name, version) => {
600///                 println!("received {{ name: '{}', version: {} }}", name, version);
601///                 Some(Err(ErrorCode::RequestFunctionInvalid))
602///             }
603///         }
604///     }
605///
606///     // handles queued `async_println` requests
607///     fn update(&mut self) {
608///         for ((cookie, val)) in self.async_println_work.drain(..) {
609///             println!("{}", val);
610///             if let Some(cookie) = cookie {
611///                 self.async_println_cookies.push_back(cookie);
612///             }
613///         }
614///     }
615///
616///     // finally return queued async results
617///     fn next_result(&mut self) -> Option<(RequestCookie, Result<Option<bson::Bson>, ErrorCode>)> {
618///         if let Some(cookie) = self.async_println_cookies.pop_front() {
619///             Some((cookie, Ok(None)))
620///         } else {
621///             None
622///         }
623///     }
624/// }
625///```
626
627pub trait ApiSet {
628    /// Returns the namespace of this `ApiSet`.
629    fn namespace(&self) -> &str;
630
631    /// Schedules the execution of the requested remote procedure call. Calls to this
632    /// function map directly to a received Honk-RPC request. Each request has the
633    /// following parameters:
634    /// - `name`: The name of the function to execute.
635    /// - `version`: The version of the function to execute.
636    /// - `args`: The arguments to pass to the function.
637    /// - `request_cookie`: An optional cookie to track the request.
638    ///
639    /// This function handles both synchronous and asynchronous requests. The possible
640    /// return values for each are:
641    /// - Synchronous requests may execute and signal success by returning `Some(Ok(..))`.
642    /// - Synchronous requests may execute and signal failure by returning `Some(Err(..))`.
643    /// - Asynchronous requests must defer execution by returning `None`.
644    fn exec_function(
645        &mut self,
646        name: &str,
647        version: i32,
648        args: bson::document::Document,
649        request_cookie: Option<RequestCookie>,
650    ) -> Option<Result<Option<bson::Bson>, ErrorCode>>;
651
652    /// Updates any internal state required to make forward progress on any requested
653    /// remote procedure calls. Implementation of this method is optional and not needed
654    /// if the implementor does not have any async functions. If left unimplemented, this
655    /// function is a no-op.
656    fn update(&mut self) {}
657
658    /// Returns the result of any in-flight asynchronous requests.
659    /// - Asynchronous requests may signal success by returning `Some((cookie, Ok(..)))`
660    /// - Asynchronous requests may signal failure by returning `Some((cookie, Err(..)))`
661    /// - returns None if no asynchronous results are available
662    ///
663    /// This method is optional and not needed if the implementor does not have any async
664    /// functions, in which case the default implementation will return `None`.
665    fn next_result(&mut self) -> Option<(RequestCookie, Result<Option<bson::Bson>, ErrorCode>)> {
666        None
667    }
668}
669
670/// Represents the response to a client request.
671pub enum Response {
672    /// A pending response, indicating that the request is still being processed.
673    Pending {
674        /// The cookie associated with the request.
675        cookie: RequestCookie,
676    },
677    /// A successful response, containing the result of the request.
678    Success {
679        /// The cookie associated with the request.
680        cookie: RequestCookie,
681        /// The result of the request.
682        result: Option<bson::Bson>,
683    },
684    /// An error response, containing the error code.
685    Error {
686        /// The cookie associated with the request.
687        cookie: RequestCookie,
688        /// The error code indicating the type of error that occurred.
689        error_code: ErrorCode,
690    },
691}
692
693// 4 kilobytes per specification
694/// The default maximum allowed Honk-RPC message (4096 bytes)
695pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024;
696/// The default maximum allowed duration between Honk-RPC (60 seconds)
697pub const DEFAULT_MAX_WAIT_TIME: std::time::Duration = std::time::Duration::from_secs(60);
698
699// Base Message Bson Format
700// document size             4 (sizeof i32 )
701const HEADER_SIZE: usize = 4usize;
702// "honk_rpc" : i32          1 (0x10) + 8 (strlen "honk_rpc") + 1 (null) + 4 (sizeof i32)
703const HONK_RPC_SIZE: usize = 14usize;
704// "sections" : {"0": Null}  1 (0x04) + 8 (strlen "sections") + 1 (null) + 4 (sizeof i32) + 1 (0x0a) + 1 (strlen "0") + 1 (null) + 1 (0x00)
705const SECTIONS_SIZE: usize = 18usize;
706// footer                    1 (0x00)
707const FOOTER_SIZE: usize = 1usize;
708
709// The honk-rpc message overhead before the content of a single section is added
710const MIN_MESSAGE_SIZE: usize = HEADER_SIZE + HONK_RPC_SIZE + SECTIONS_SIZE + FOOTER_SIZE;
711
712/// Computes the overhead of the Honk-RPC message type. This method in conjunction with
713/// the other `get_*_section_size(..)` functions can be used to compute the size of a
714/// Honk-RPC message with exactly one section.
715pub fn get_message_overhead() -> Result<usize, Error> {
716    // construct an example empty message; the size of a real message with
717    // one section can be calculated as the sizeof(message) + sizeof(section)
718    let message = doc! {
719        "honk_rpc" : HONK_RPC_VERSION,
720        "sections" : [
721            bson::Bson::Null
722        ]
723    };
724
725    let mut counter: ByteCounter = Default::default();
726    message
727        .to_writer(&mut counter)
728        .map_err(Error::BsonWriteFailed)?;
729
730    Ok(counter.bytes())
731}
732
733/// Computes the required size of a Honk-RPC error section in bytes.
734///
735/// Returns the size of the BSON-encoded error section. If BSON encoding fails,
736/// an `Error::BsonWriteFailed` is returned.
737pub fn get_error_section_size(
738    cookie: Option<RequestCookie>,
739    message: Option<String>,
740    data: Option<bson::Bson>,
741) -> Result<usize, Error> {
742    let mut error_section = doc! {
743        "id": ERROR_SECTION_ID,
744        "code": Into::<i32>::into(ErrorCode::Unknown(0)),
745    };
746
747    if let Some(cookie) = cookie {
748        error_section.insert("cookie", bson::Bson::Int64(cookie));
749    }
750
751    if let Some(message) = message {
752        error_section.insert("message", bson::Bson::String(message));
753    }
754
755    if let Some(data) = data {
756        error_section.insert("data", data);
757    }
758
759    let mut counter: ByteCounter = Default::default();
760    error_section
761        .to_writer(&mut counter)
762        .map_err(Error::BsonWriteFailed)?;
763
764    Ok(counter.bytes())
765}
766
767/// Computes the required size of a Honk-RPC requests section in bytes.
768///
769/// Returns the size of the BSON-encoded request section. If BSON encoding fails,
770/// an `Error::BsonWriteFailed` is returned.
771pub fn get_request_section_size(
772    cookie: Option<RequestCookie>,
773    namespace: Option<String>,
774    function: String,
775    version: Option<i32>,
776    arguments: Option<bson::Document>,
777) -> Result<usize, Error> {
778    let mut request_section = doc! {
779        "id": REQUEST_SECTION_ID,
780        "function": bson::Bson::String(function),
781    };
782
783    if let Some(cookie) = cookie {
784        request_section.insert("cookie", bson::Bson::Int64(cookie));
785    }
786
787    if let Some(namespace) = namespace {
788        request_section.insert("namespace", bson::Bson::String(namespace));
789    }
790
791    if let Some(version) = version {
792        request_section.insert("version", bson::Bson::Int32(version));
793    }
794
795    if let Some(arguments) = arguments {
796        request_section.insert("arguments", arguments);
797    }
798
799    let mut counter: ByteCounter = Default::default();
800    request_section
801        .to_writer(&mut counter)
802        .map_err(Error::BsonWriteFailed)?;
803
804    Ok(counter.bytes())
805}
806
807/// Computes the required size of a Honk-RPC response section in bytes.
808///
809/// Returns the size of the BSON-encoded response section. If BSON encoding fails,
810/// an `Error::BsonWriteFailed` is returned.
811pub fn get_response_section_size(result: Option<bson::Bson>) -> Result<usize, Error> {
812    let mut response_section = doc! {
813        "id": RESPONSE_SECTION_ID,
814        "cookie": bson::Bson::Int64(0),
815        "state": bson::Bson::Int32(0),
816    };
817
818    if let Some(result) = result {
819        response_section.insert("result", result);
820    }
821
822    let mut counter: ByteCounter = Default::default();
823    response_section
824        .to_writer(&mut counter)
825        .map_err(Error::BsonWriteFailed)?;
826
827    Ok(counter.bytes())
828}
829
830/// The object that handles the communication between two endpoints  using the
831/// Honk-RPC protocol. Provides methods for setting and getting configuration
832/// parameters, reading and processing message documents, and handling API
833/// requests and responses.
834pub struct Session<RW> {
835    // read-write stream
836    stream: RW,
837    // we write outgoing data to an intermediate buffer to handle writer blocking
838    message_write_buffer: VecDeque<u8>,
839
840    // message read data
841
842    // remaining number of bytes to read for current message
843    // if None, no message read is in progress
844    remaining_byte_count: Option<usize>,
845    // data we've read but not yet a full Message object
846    message_read_buffer: Vec<u8>,
847    // received sections to be handled
848    pending_sections: VecDeque<Section>,
849    // remote client's inbound remote procedure calls to local server
850    inbound_requests: Vec<RequestSection>,
851    // remote server's responses to local client's remote procedure calls
852    inbound_responses: VecDeque<Response>,
853
854    // message write data
855
856    // we serialize outgoing messages to this buffer first to verify size limitations
857    message_serialization_buffer: VecDeque<u8>,
858    // the next request cookie to use when making a remote prodedure call
859    next_cookie: RequestCookie,
860    // sections to be sent to the remote server
861    outbound_sections: Vec<bson::Document>,
862
863    // the maximum size of a message we've agreed to allow in the session
864    max_message_size: usize,
865    // the maximum amount of time the session is willing to wait to receive a message
866    // before terminating the session
867    max_wait_time: std::time::Duration,
868    // last time a new message read began
869    read_timestamp: std::time::Instant,
870}
871
872#[allow(dead_code)]
873impl<RW> Session<RW>
874where
875    RW: std::io::Read + std::io::Write + Send,
876{
877    /// Sets the maximum message size this `Session` is willing to read from from the underlying `RW`. Attempted reads  will abort if the next bson document's `i32` size field is greater than the `max_message_size` defined in this function.
878    pub fn set_max_message_size(&mut self, max_message_size: i32) -> Result<(), Error> {
879        if max_message_size < MIN_MESSAGE_SIZE as i32 {
880            // base size of a honk-rpc mssage
881            Err(Error::InvalidMaxMesageSize())
882        } else {
883            self.max_message_size = max_message_size as usize;
884            Ok(())
885        }
886    }
887
888    /// Gets the maximum allowed message size this `Session` is willing to read from the underlying `RW`. The default value is 4096 bytes.
889    pub fn get_max_message_size(&self) -> usize {
890        self.max_message_size
891    }
892
893    /// Sets the maximum amount of time this `Session` is willing to wait for a new Honk-RPC message on the underlying `RW`. `Session` updates will fil after `max_wait_time` has elapsed without receiving any new Honk-RPC message documents.
894    pub fn set_max_wait_time(&mut self, max_wait_time: std::time::Duration) {
895        self.max_wait_time = max_wait_time;
896    }
897
898    /// Gets the maximum amount this `Session` is willing to wait for a new Honk-RPC message. The default value is 60 seconds.
899    pub fn get_max_wait_time(&self) -> std::time::Duration {
900        self.max_wait_time
901    }
902
903    /// Creates a new `Session` using the given `stream`.
904    pub fn new(stream: RW) -> Self {
905        let mut message_write_buffer: VecDeque<u8> = Default::default();
906        message_write_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
907
908        let mut message_serialization_buffer: VecDeque<u8> = Default::default();
909        message_serialization_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
910
911        Session {
912            stream,
913            message_write_buffer,
914            remaining_byte_count: None,
915            message_read_buffer: Default::default(),
916            pending_sections: Default::default(),
917            inbound_requests: Default::default(),
918            inbound_responses: Default::default(),
919            message_serialization_buffer,
920            next_cookie: Default::default(),
921            outbound_sections: Default::default(),
922            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
923            max_wait_time: DEFAULT_MAX_WAIT_TIME,
924            read_timestamp: std::time::Instant::now(),
925        }
926    }
927
928    /// Consumes the `Session` and returns the underlying stream.
929    pub fn into_stream(self) -> RW {
930        self.stream
931    }
932
933    // read a block of bytes from the undelrying stream
934    fn stream_read(&mut self, buffer: &mut [u8]) -> Result<usize, Error> {
935        match self.stream.read(buffer) {
936            Err(err) => {
937                if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut {
938                    // abort if we've gone too long without a new message
939                    if std::time::Instant::now().duration_since(self.read_timestamp)
940                        > self.max_wait_time
941                    {
942                        Err(Error::MessageReadTimedOut(self.max_wait_time))
943                    } else {
944                        Ok(0)
945                    }
946                } else {
947                    Err(Error::ReaderReadFailed(err))
948                }
949            }
950            Ok(0) => Err(Error::ReaderReadFailed(std::io::Error::from(
951                ErrorKind::UnexpectedEof,
952            ))),
953            Ok(count) => {
954                // update read_timestamp
955                self.read_timestamp = std::time::Instant::now();
956                Ok(count)
957            }
958        }
959    }
960
961    // read the next block of bytes as a bson document size header
962    fn read_message_size(&mut self) -> Result<(), Error> {
963        match self.remaining_byte_count {
964            // we've already read the size header
965            Some(_remaining) => Ok(()),
966            // still need to read the size header
967            None => {
968                // may have been partially read already so ensure it's the right size
969                assert!(self.message_read_buffer.len() < std::mem::size_of::<i32>());
970                let bytes_needed = std::mem::size_of::<i32>() - self.message_read_buffer.len();
971                // ensure we have enough space for an entire int32
972                let mut buffer = [0u8; std::mem::size_of::<i32>()];
973                // but shrink view down to number of bytes remaining
974                let buffer = &mut buffer[0..bytes_needed];
975                match self.stream_read(buffer) {
976                    Err(err) => Err(err),
977                    Ok(0) => Ok(()),
978                    Ok(count) => {
979                        #[cfg(test)]
980                        println!("<<< read {} bytes for message header", count);
981                        self.message_read_buffer
982                            .extend_from_slice(&buffer[0..count]);
983
984                        // all bytes required for i32 message size have been read
985                        if self.message_read_buffer.len() == std::mem::size_of::<i32>() {
986                            let size = &self.message_read_buffer.as_slice();
987                            let size: i32 = (size[0] as i32)
988                                | (size[1] as i32) << 8
989                                | (size[2] as i32) << 16
990                                | (size[3] as i32) << 24;
991                            // size should be at least larger than the bytes required for size header
992                            if size <= std::mem::size_of::<i32>() as i32 {
993                                return Err(Error::BsonDocumentSizeTooSmall(size));
994                            }
995                            // convert to usize type now that we know it's not negative
996                            if size as usize > self.max_message_size {
997                                return Err(Error::BsonDocumentSizeTooLarge(
998                                    size,
999                                    self.max_message_size as i32,
1000                                ));
1001                            }
1002
1003                            // deduct size of i32 header and save
1004                            let size = size as usize - std::mem::size_of::<i32>();
1005
1006                            self.remaining_byte_count = Some(size);
1007                        }
1008                        Ok(())
1009                    }
1010                }
1011            }
1012        }
1013    }
1014
1015    // read the remainder of a bson message
1016    fn read_message(&mut self) -> Result<Option<Message>, Error> {
1017        // update remaining bytes to read for message
1018        self.read_message_size()?;
1019        // read the message bytes
1020        if let Some(remaining) = self.remaining_byte_count {
1021            #[cfg(test)]
1022            println!("--- message requires {} more bytes", remaining);
1023
1024            let mut buffer = vec![0u8; remaining];
1025            match self.stream_read(&mut buffer) {
1026                Err(err) => Err(err),
1027                Ok(0) => Ok(None),
1028                Ok(count) => {
1029                    #[cfg(test)]
1030                    println!("<<< read {} bytes", count);
1031                    // append read bytes
1032                    self.message_read_buffer
1033                        .extend_from_slice(&buffer[0..count]);
1034                    if remaining == count {
1035                        self.remaining_byte_count = None;
1036
1037                        let mut cursor = Cursor::new(std::mem::take(&mut self.message_read_buffer));
1038                        let bson = bson::document::Document::from_reader(&mut cursor)
1039                            .map_err(Error::BsonDocumentParseFailed)?;
1040
1041                        // take back our allocated vec and clear it
1042                        self.message_read_buffer = cursor.into_inner();
1043                        self.message_read_buffer.clear();
1044
1045                        #[cfg(test)]
1046                        println!("<<< read message: {}", bson);
1047
1048                        Ok(Some(
1049                            Message::try_from(bson).map_err(Error::MessageConversionFailed)?,
1050                        ))
1051                    } else {
1052                        // update the remaining byte count
1053                        self.remaining_byte_count = Some(remaining - count);
1054                        Ok(None)
1055                    }
1056                }
1057            }
1058        } else {
1059            Ok(None)
1060        }
1061    }
1062
1063    // read and save of available sections
1064    fn read_sections(&mut self) -> Result<(), Error> {
1065        loop {
1066            match self.read_message() {
1067                Ok(Some(mut message)) => {
1068                    self.pending_sections.extend(message.sections.drain(..));
1069                }
1070                Ok(None) => return Ok(()),
1071                Err(err) => {
1072                    match err {
1073                        // in the event of timeouts and IO errors we finish any remaining work
1074                        Error::MessageReadTimedOut(_) | Error::ReaderReadFailed(_) => {
1075                            // ensure no pending items to handle
1076                            if self.pending_sections.is_empty() && self.inbound_responses.is_empty()
1077                            {
1078                                return Err(err);
1079                            }
1080                            return Ok(());
1081                        }
1082                        // all other errors we terminate
1083                        _ => return Err(err),
1084                    }
1085                }
1086            }
1087        }
1088    }
1089
1090    // route read sections to client and server buffers
1091    fn process_sections(&mut self) -> Result<(), Error> {
1092        while let Some(section) = self.pending_sections.pop_front() {
1093            match section {
1094                Section::Error(error) => {
1095                    if let Some(cookie) = error.cookie {
1096                        // error in response to a request
1097                        self.inbound_responses.push_back(Response::Error {
1098                            cookie,
1099                            error_code: error.code,
1100                        });
1101                    } else {
1102                        return Err(Error::UnknownErrorSectionReceived(error.code));
1103                    }
1104                }
1105                Section::Request(request) => {
1106                    // request to route to our apisets
1107                    self.inbound_requests.push(request);
1108                }
1109                Section::Response(response) => {
1110                    // response to our client
1111
1112                    match (response.cookie, response.state, response.result) {
1113                        (cookie, RequestState::Complete, result) => {
1114                            self.inbound_responses
1115                                .push_back(Response::Success { cookie, result });
1116                        }
1117                        (cookie, RequestState::Pending, _) => {
1118                            self.inbound_responses
1119                                .push_back(Response::Pending { cookie });
1120                        }
1121                    }
1122                }
1123            }
1124        }
1125
1126        Ok(())
1127    }
1128
1129    // queue outbound section for packaging into a Honk-RPC message
1130    fn push_outbound_section(&mut self, section: Section) -> Result<(), Error> {
1131        let max_section_size = self.max_message_size - MIN_MESSAGE_SIZE;
1132
1133        let mut counter: ByteCounter = Default::default();
1134        let section: bson::Document = section.into();
1135        section
1136            .to_writer(&mut counter)
1137            .map_err(Error::BsonWriteFailed)?;
1138        let section_size = counter.bytes();
1139
1140        if section_size <= max_section_size {
1141            self.outbound_sections.push(section);
1142            Ok(())
1143        } else {
1144            Err(Error::SectionTooLarge(section_size, max_section_size))
1145        }
1146    }
1147
1148    // package outbound sections into a message, and serialize message to the message_write_buffer
1149    fn serialize_messages(&mut self) -> Result<(), Error> {
1150        // if no pending sections there is nothing to do
1151        if self.outbound_sections.is_empty() {
1152            return Ok(());
1153        }
1154
1155        // build message and convert to bson to send
1156        let message = Message {
1157            honk_rpc: HONK_RPC_VERSION,
1158            sections: Default::default(),
1159        };
1160        let mut message = bson::document::Document::from(message);
1161        message.insert("sections", std::mem::take(&mut self.outbound_sections));
1162        self.serialize_messages_impl(message)
1163    }
1164
1165    // pack sections into messages and serialise them to buffer
1166    fn serialize_messages_impl(
1167        &mut self,
1168        mut message: bson::document::Document,
1169    ) -> Result<(), Error> {
1170        self.message_serialization_buffer.clear();
1171        message
1172            .to_writer(&mut self.message_serialization_buffer)
1173            .map_err(Error::BsonWriteFailed)?;
1174
1175        if self.message_serialization_buffer.len() > self.max_message_size {
1176            // if we can't split a message anymore then we have a problem
1177            let sections = message.get_array_mut("sections").unwrap();
1178            assert!(sections.len() > 1);
1179
1180            let right = doc! {
1181                "honk_rpc" : HONK_RPC_VERSION,
1182                "sections" : sections.split_off(sections.len() / 2),
1183            };
1184            let left = message;
1185
1186            self.serialize_messages_impl(left)?;
1187            self.serialize_messages_impl(right)?;
1188        } else {
1189            #[cfg(test)]
1190            println!(">>> write message: {:?}", message);
1191            // copy the serialized message into the pending write buffer
1192            self.message_write_buffer
1193                .append(&mut self.message_serialization_buffer);
1194        }
1195
1196        Ok(())
1197    }
1198
1199    // write data to stream and remove from write buffer
1200    fn write_pending_data(&mut self) -> Result<(), Error> {
1201        let bytes_written = self.write_pending_data_impl()?;
1202        self.stream.flush().map_err(Error::WriterWriteFailed)?;
1203        // removes the written bytes
1204        self.message_write_buffer.drain(0..bytes_written);
1205        // and shuffles the data so it is contiguous
1206        self.message_write_buffer.make_contiguous();
1207
1208        Ok(())
1209    }
1210
1211    fn write_pending_data_impl(&mut self) -> Result<usize, Error> {
1212        // write pending data
1213        let (mut pending_data, empty): (&[u8], &[u8]) = self.message_write_buffer.as_slices();
1214        assert!(empty.is_empty());
1215        let pending_bytes: usize = pending_data.len();
1216        let mut bytes_written: usize = 0usize;
1217
1218        while bytes_written != pending_bytes {
1219            match self.stream.write(pending_data) {
1220                Err(err) => {
1221                    let kind = err.kind();
1222                    if kind == ErrorKind::WouldBlock || kind == ErrorKind::TimedOut {
1223                        // no *additional* bytes written so return bytes written so far
1224                        return Ok(bytes_written);
1225                    } else {
1226                        return Err(Error::WriterWriteFailed(err));
1227                    }
1228                }
1229                Ok(count) => {
1230                    bytes_written += count;
1231                    #[cfg(test)]
1232                    println!(">>> sent {} of {} bytes", bytes_written, pending_bytes);
1233                    pending_data = &pending_data[count..];
1234                }
1235            }
1236        }
1237
1238        Ok(bytes_written)
1239    }
1240
1241    /// Read and process Honk-RPC message documents from connected peer, handle any new incoming Honk-RPC requests, update any in-progress async requests and write pending reponses, errors and requests to peer. This function must be called regularly for the `Session` to make forward progress.
1242    pub fn update(&mut self, apisets: Option<&mut [&mut dyn ApiSet]>) -> Result<(), Error> {
1243        // read sections from remote
1244        self.read_sections()?;
1245        // route sections to buffers
1246        self.process_sections()?;
1247
1248        // handle incoming api calls
1249        let apisets = apisets.unwrap_or(&mut []);
1250        self.handle_requests(apisets)?;
1251
1252        // serialize pending responses
1253        self.serialize_messages()?;
1254
1255        // write pendng data to writer
1256        self.write_pending_data()?;
1257
1258        Ok(())
1259    }
1260
1261    // apisets : a slice of mutable ApiSet references sorted by their namespaces
1262    fn handle_requests(&mut self, apisets: &mut [&mut dyn ApiSet]) -> Result<(), Error> {
1263        // first handle all of our inbound requests
1264        let mut inbound_requests = std::mem::take(&mut self.inbound_requests);
1265        for mut request in inbound_requests.drain(..) {
1266            if let Ok(idx) =
1267                apisets.binary_search_by(|probe| probe.namespace().cmp(&request.namespace))
1268            {
1269                let apiset = match apisets.get_mut(idx) {
1270                    Some(apiset) => apiset,
1271                    None => unreachable!(),
1272                };
1273                match apiset.exec_function(
1274                    &request.function,
1275                    request.version,
1276                    std::mem::take(&mut request.arguments),
1277                    request.cookie,
1278                ) {
1279                    // func found, invoked and succeeded
1280                    Some(Ok(result)) => {
1281                        if let Some(cookie) = request.cookie {
1282                            self.push_outbound_section(Section::Response(ResponseSection {
1283                                cookie,
1284                                state: RequestState::Complete,
1285                                result,
1286                            }))?;
1287                        }
1288                    }
1289                    // func found, invoked and failed
1290                    Some(Err(error_code)) => {
1291                        self.push_outbound_section(Section::Error(ErrorSection {
1292                            cookie: request.cookie,
1293                            code: error_code,
1294                            message: None,
1295                            data: None,
1296                        }))?;
1297                    }
1298                    // func found, called, and result is pending
1299                    None => {
1300                        if let Some(cookie) = request.cookie {
1301                            self.push_outbound_section(Section::Response(ResponseSection {
1302                                cookie,
1303                                state: RequestState::Pending,
1304                                result: None,
1305                            }))?;
1306                        }
1307                    }
1308                }
1309            } else {
1310                // invalid namespace
1311                self.push_outbound_section(Section::Error(ErrorSection {
1312                    cookie: request.cookie,
1313                    code: ErrorCode::RequestNamespaceInvalid,
1314                    message: None,
1315                    data: None,
1316                }))?;
1317            }
1318        }
1319
1320        // next send out async responses from apisets
1321        for apiset in apisets.iter_mut() {
1322            // allow apiset to do any required repetitive work
1323            apiset.update();
1324            // put pending results in our message
1325            while let Some((cookie, result)) = apiset.next_result() {
1326                match (cookie, result) {
1327                    // function completed successfully
1328                    (cookie, Ok(result)) => {
1329                        self.push_outbound_section(Section::Response(ResponseSection {
1330                            cookie,
1331                            state: RequestState::Complete,
1332                            result,
1333                        }))?;
1334                    }
1335                    // function completed with failure
1336                    (cookie, Err(error_code)) => {
1337                        self.push_outbound_section(Section::Error(ErrorSection {
1338                            cookie: Some(cookie),
1339                            code: error_code,
1340                            message: None,
1341                            data: None,
1342                        }))?;
1343                    }
1344                }
1345            }
1346        }
1347        Ok(())
1348    }
1349
1350    /// Performs a client call to a remote function. Returns a `RequestCookie` to associate this client call with a future `Response`.
1351    pub fn client_call(
1352        &mut self,
1353        namespace: &str,
1354        function: &str,
1355        version: i32,
1356        arguments: bson::document::Document,
1357    ) -> Result<RequestCookie, Error> {
1358        // always make sure we have a new cookie
1359        let cookie = self.next_cookie;
1360        self.next_cookie += 1;
1361
1362        // add request to outgoing buffer
1363        self.push_outbound_section(Section::Request(RequestSection {
1364            cookie: Some(cookie),
1365            namespace: namespace.to_string(),
1366            function: function.to_string(),
1367            version,
1368            arguments,
1369        }))?;
1370
1371        Ok(cookie)
1372    }
1373
1374    /// Drains all `Response` objects resulting from prevoius invocations of `Session::client_call()`
1375    pub fn client_drain_responses(&mut self) -> std::collections::vec_deque::Drain<Response> {
1376        self.inbound_responses.drain(..)
1377    }
1378
1379    /// Retrieves the next `Response` object from previous invocations of `Session::client_call()`
1380    pub fn client_next_response(&mut self) -> Option<Response> {
1381        self.inbound_responses.pop_front()
1382    }
1383}
1384
// Round-trips single- and multi-section messages between two sessions over a
// pair of non-blocking loopback sockets.
#[test]
fn test_honk_client_read_write() -> anyhow::Result<()> {
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0u16)))?;
    let listen_addr = listener.local_addr()?;

    let alice_stream = TcpStream::connect(listen_addr)?;
    alice_stream.set_nonblocking(true)?;
    let (pat_stream, _peer_addr) = listener.accept()?;
    pat_stream.set_nonblocking(true)?;

    let mut alice = Session::new(alice_stream);
    let mut pat = Session::new(pat_stream);

    // nothing has been written yet, so there is nothing to read
    println!("--- pat reads message, but none has been sent");
    assert!(pat.read_message()?.is_none());

    // with no queued sections, serializing must not produce a message
    println!("--- alice sends no message, but no pending sections so no message sent");
    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads message, but none has been sent");
    match pat.read_message() {
        Ok(None) => (),
        Ok(Some(msg)) => panic!(
            "message should not have been sent: {}",
            bson::document::Document::from(msg)
        ),
        Err(err) => panic!("{:?}", err),
    }

    println!("--- pat sends an error message");

    const CUSTOM_ERROR: &str = "Custom Error!";

    pat.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(1),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;
    pat.serialize_messages()?;
    pat.write_pending_data()?;

    println!("--- alice reads and verifies message");

    // poll until the complete message has arrived
    let mut msg = loop {
        if let Some(msg) = alice.read_message()? {
            break msg;
        }
    };
    assert_eq!(msg.sections.len(), 1);
    let section = match msg.sections.pop() {
        Some(Section::Error(section)) => section,
        Some(_) => panic!("was expecting an Error section"),
        None => panic!("we should have a message"),
    };
    match (section.cookie, section.code, section.message) {
        (Some(42069), ErrorCode::Runtime(1), Some(message)) => assert_eq!(message, CUSTOM_ERROR),
        (cookie, code, message) => panic!(
            "unexpected error section: cookie: {:?}, code: {:?}, message: {:?}",
            cookie, code, message
        ),
    }

    println!("--- alice sends multi-section message");

    let outbound_sections = vec![
        Section::Error(ErrorSection {
            cookie: Some(42069),
            code: ErrorCode::Runtime(2),
            message: Some(CUSTOM_ERROR.to_string()),
            data: None,
        }),
        Section::Request(RequestSection {
            cookie: None,
            namespace: "std".to_string(),
            function: "print".to_string(),
            version: 0,
            arguments: doc! {"message": "hello!"},
        }),
        Section::Response(ResponseSection {
            cookie: 123456,
            state: RequestState::Pending,
            result: None,
        }),
    ];
    for section in outbound_sections {
        alice.push_outbound_section(section)?;
    }

    // all three sections should be packed into a single message
    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads and verifies multi-section message");

    let msg = loop {
        if let Some(msg) = pat.read_message()? {
            break msg;
        }
    };
    assert_eq!(msg.sections.len(), 3);
    for section in &msg.sections {
        match section {
            Section::Error(error) => {
                assert_eq!(error.cookie, Some(42069));
                assert_eq!(error.code, ErrorCode::Runtime(2));
                assert_eq!(error.message, Some(CUSTOM_ERROR.to_string()));
                assert_eq!(error.data, None);
            }
            Section::Request(request) => {
                assert_eq!(request.cookie, None);
                assert_eq!(request.namespace, "std");
                assert_eq!(request.function, "print");
                assert_eq!(request.version, 0i32);
            }
            Section::Response(response) => {
                assert_eq!(response.cookie, 123456);
                assert_eq!(response.state, RequestState::Pending);
                assert_eq!(response.result, None);
            }
        }
    }

    Ok(())
}
1522
/// Minimal `ApiSet` used by the tests. It serves the `namespace` namespace and
/// counts how many times `function` (version 0) has been executed.
#[cfg(test)]
struct TestApiSet {
    // number of executions of `namespace::function` version 0
    call_count: usize,
}

#[cfg(test)]
impl ApiSet for TestApiSet {
    fn namespace(&self) -> &str {
        "namespace"
    }

    // Every call completes immediately with no result; only the known
    // function/version pair bumps the counter.
    fn exec_function(
        &mut self,
        name: &str,
        version: i32,
        _args: bson::document::Document,
        _request_section: Option<RequestCookie>,
    ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
        if name == "function" && version == 0 {
            println!("--- namespace::function_0() called");
            self.call_count += 1;
        }
        Some(Ok(None))
    }
}
1551
// Exercises `max_wait_time`: alice must stay alive while pat keeps sending her
// requests within the limit, then time out after a longer silence.
#[test]
fn test_honk_timeout() -> anyhow::Result<()> {
    // connect a pair of non-blocking, no-delay loopback sockets
    let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0u16));
    let listener = TcpListener::bind(socket_addr)?;
    let socket_addr = listener.local_addr()?;

    let alice_stream = TcpStream::connect(socket_addr)?;
    alice_stream.set_nonblocking(true)?;
    alice_stream.set_nodelay(true)?;
    println!("--- alice peer_addr: {}", alice_stream.peer_addr()?);
    let (pat_stream, _socket_addr) = listener.accept()?;
    pat_stream.set_nonblocking(true)?;
    pat_stream.set_nodelay(true)?;

    let mut alice = Session::new(alice_stream);
    let mut alice_apiset = TestApiSet { call_count: 0usize };
    let mut pat = Session::new(pat_stream);

    let start = std::time::Instant::now();

    println!(
        "--- {:?} alice set max_wait_time to 3 seconds",
        std::time::Instant::now().duration_since(start)
    );
    alice.update(None)?;
    alice.set_max_wait_time(std::time::Duration::from_secs(3));
    alice.update(None)?;

    // stay under alice's 3 second limit; the call below gives alice data to
    // read, which resets her timeout
    println!(
        "--- {:?} sleep 2 seconds",
        std::time::Instant::now().duration_since(start)
    );
    std::thread::sleep(std::time::Duration::from_secs(2));

    println!(
        "--- {:?} pat calls namespace::function_0()",
        std::time::Instant::now().duration_since(start)
    );
    pat.client_call("namespace", "function", 0, doc! {})?;
    // pump both sessions until alice's apiset has executed the call
    while alice_apiset.call_count != 1 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // alice just read, so her timeout was reset; sleeping 2 seconds keeps her
    // under the 3 second limit
    println!(
        "--- {:?} sleep 2 seconds",
        std::time::Instant::now().duration_since(start)
    );
    std::thread::sleep(std::time::Duration::from_secs(2));
    pat.update(None)?;
    alice.update(None)?;

    println!(
        "--- {:?} pat calls namespace::function_0()",
        std::time::Instant::now().duration_since(start)
    );
    pat.client_call("namespace", "function", 0, doc! {})?;
    // pump both sessions until alice's apiset has executed the second call
    while alice_apiset.call_count != 2 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // no reads occur during this sleep, which exceeds alice's 3 second limit,
    // so alice should time out
    println!(
        "--- {:?} sleep 4 seconds",
        std::time::Instant::now().duration_since(start)
    );
    std::thread::sleep(std::time::Duration::from_secs(4));

    println!(
        "--- {:?} pat+alice update",
        std::time::Instant::now().duration_since(start)
    );
    pat.update(None)?;
    match alice.update(None) {
        Ok(()) => panic!("should have timed out"),
        Err(Error::MessageReadTimedOut(duration)) => {
            println!("--- expected time out after {:?}", duration)
        }
        Err(err) => panic!("unexpected error: {:?}", err),
    }
    Ok(())
}