Commit 50606a85 authored by Johnny's avatar Johnny

fix(auth): resolve token refresh and persistence issues

- Fix cookie expiration timezone to use GMT (RFC 6265 compliance)
- Use Connect RPC client for token refresh instead of fetch
- Fix error code checking (numeric Code.Unauthenticated instead of string)
- Prevent infinite redirect loop when already on /auth page
- Fix protobuf Timestamp conversion using timestampDate helper
- Store access token in sessionStorage to avoid unnecessary refreshes on page reload
- Add refresh token cookie fallback for attachment authentication
- Improve error handling with proper type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: 's avatarClaude Sonnet 4.5 <noreply@anthropic.com>
parent 7932f6d0
...@@ -339,7 +339,9 @@ func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken s ...@@ -339,7 +339,9 @@ func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken s
if expireTime.IsZero() { if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT") attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else { } else {
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123)) // RFC 6265 requires cookie expiration dates to use GMT timezone
// Convert to UTC and format with explicit "GMT" to ensure browser compatibility
attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT"))
} }
// Try to determine if the request is HTTPS by checking the origin header // Try to determine if the request is HTTPS by checking the origin header
......
...@@ -287,10 +287,10 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech ...@@ -287,10 +287,10 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech
} }
// getCurrentUser retrieves the current authenticated user from the Echo context. // getCurrentUser retrieves the current authenticated user from the Echo context.
// It checks Bearer tokens for authentication (Access Token V2 or PAT). // Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
// Uses the shared Authenticator for consistent authentication logic. // Uses the shared Authenticator for consistent authentication logic.
func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context) (*store.User, error) { func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context) (*store.User, error) {
// Try Bearer token authentication // Try Bearer token authentication first
authHeader := c.Request().Header.Get("Authorization") authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" { if authHeader != "" {
token := auth.ExtractBearerToken(authHeader) token := auth.ExtractBearerToken(authHeader)
...@@ -317,6 +317,20 @@ func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context) ...@@ -317,6 +317,20 @@ func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context)
} }
} }
// Fallback: Try refresh token cookie authentication
// This allows protected attachments to load even when access token has expired,
// as long as the user has a valid refresh token cookie.
cookieHeader := c.Request().Header.Get("Cookie")
if cookieHeader != "" {
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken != "" {
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
if err == nil && user != nil {
return user, nil
}
}
}
// No valid authentication found // No valid authentication found
return nil, nil return nil, nil
} }
......
// In-memory storage for access token (not persisted for security) // Access token storage using sessionStorage for persistence across page refreshes
// sessionStorage is cleared when the tab/window is closed, providing reasonable security
// while avoiding unnecessary token refreshes on page reload
let accessToken: string | null = null; let accessToken: string | null = null;
let tokenExpiresAt: Date | null = null; let tokenExpiresAt: Date | null = null;
export const getAccessToken = (): string | null => accessToken; const SESSION_TOKEN_KEY = "memos_access_token";
const SESSION_EXPIRES_KEY = "memos_token_expires_at";
export const getAccessToken = (): string | null => {
// If not in memory, try to restore from sessionStorage
if (!accessToken) {
try {
const storedToken = sessionStorage.getItem(SESSION_TOKEN_KEY);
const storedExpires = sessionStorage.getItem(SESSION_EXPIRES_KEY);
if (storedToken && storedExpires) {
const expiresAt = new Date(storedExpires);
// Only restore if token hasn't expired
if (expiresAt > new Date()) {
accessToken = storedToken;
tokenExpiresAt = expiresAt;
} else {
// Token expired, clean up sessionStorage
sessionStorage.removeItem(SESSION_TOKEN_KEY);
sessionStorage.removeItem(SESSION_EXPIRES_KEY);
}
}
} catch (e) {
// sessionStorage might not be available (e.g., in some privacy modes)
console.warn("Failed to access sessionStorage:", e);
}
}
return accessToken;
};
export const setAccessToken = (token: string | null, expiresAt?: Date): void => { export const setAccessToken = (token: string | null, expiresAt?: Date): void => {
accessToken = token; accessToken = token;
tokenExpiresAt = expiresAt || null; tokenExpiresAt = expiresAt || null;
try {
if (token && expiresAt) {
// Store in sessionStorage for persistence across page refreshes
sessionStorage.setItem(SESSION_TOKEN_KEY, token);
sessionStorage.setItem(SESSION_EXPIRES_KEY, expiresAt.toISOString());
} else {
// Clear sessionStorage if token is being cleared
sessionStorage.removeItem(SESSION_TOKEN_KEY);
sessionStorage.removeItem(SESSION_EXPIRES_KEY);
}
} catch (e) {
// sessionStorage might not be available (e.g., in some privacy modes)
console.warn("Failed to write to sessionStorage:", e);
}
}; };
export const isTokenExpired = (): boolean => { export const isTokenExpired = (): boolean => {
...@@ -18,4 +63,11 @@ export const isTokenExpired = (): boolean => { ...@@ -18,4 +63,11 @@ export const isTokenExpired = (): boolean => {
export const clearAccessToken = (): void => { export const clearAccessToken = (): void => {
accessToken = null; accessToken = null;
tokenExpiresAt = null; tokenExpiresAt = null;
try {
sessionStorage.removeItem(SESSION_TOKEN_KEY);
sessionStorage.removeItem(SESSION_EXPIRES_KEY);
} catch (e) {
console.warn("Failed to clear sessionStorage:", e);
}
}; };
import { timestampDate } from "@bufbuild/protobuf/wkt"; import { timestampDate } from "@bufbuild/protobuf/wkt";
import { ConnectError } from "@connectrpc/connect";
import { LoaderIcon } from "lucide-react"; import { LoaderIcon } from "lucide-react";
import { observer } from "mobx-react-lite"; import { observer } from "mobx-react-lite";
import { useState } from "react"; import { useState } from "react";
...@@ -59,9 +58,10 @@ const PasswordSignInForm = observer(() => { ...@@ -59,9 +58,10 @@ const PasswordSignInForm = observer(() => {
} }
await initialUserStore(); await initialUserStore();
navigateTo("/"); navigateTo("/");
} catch (error: any) { } catch (error: unknown) {
console.error(error); console.error(error);
toast.error((error as ConnectError).message || "Failed to sign in."); const message = error instanceof Error ? error.message : "Failed to sign in.";
toast.error(message);
} }
actionBtnLoadingState.setFinish(); actionBtnLoadingState.setFinish();
}; };
...@@ -92,7 +92,7 @@ const PasswordSignInForm = observer(() => { ...@@ -92,7 +92,7 @@ const PasswordSignInForm = observer(() => {
readOnly={actionBtnLoadingState.isLoading} readOnly={actionBtnLoadingState.isLoading}
placeholder={t("common.password")} placeholder={t("common.password")}
value={password} value={password}
autoComplete="password" autoComplete="current-password"
autoCapitalize="off" autoCapitalize="off"
spellCheck={false} spellCheck={false}
onChange={handlePasswordInputChanged} onChange={handlePasswordInputChanged}
......
import { createClient, Interceptor } from "@connectrpc/connect"; import { timestampDate } from "@bufbuild/protobuf/wkt";
import { Code, ConnectError, createClient, type Interceptor } from "@connectrpc/connect";
import { createConnectTransport } from "@connectrpc/connect-web"; import { createConnectTransport } from "@connectrpc/connect-web";
import { getAccessToken, setAccessToken } from "./auth-state"; import { getAccessToken, setAccessToken } from "./auth-state";
import { ActivityService } from "./types/proto/api/v1/activity_service_pb"; import { ActivityService } from "./types/proto/api/v1/activity_service_pb";
...@@ -13,7 +14,13 @@ import { UserService } from "./types/proto/api/v1/user_service_pb"; ...@@ -13,7 +14,13 @@ import { UserService } from "./types/proto/api/v1/user_service_pb";
let isRefreshing = false; let isRefreshing = false;
let refreshPromise: Promise<void> | null = null; let refreshPromise: Promise<void> | null = null;
// Auth interceptor that attaches access token and handles 401 errors by refreshing /**
* Authentication interceptor that:
* 1. Attaches access token to outgoing requests
* 2. Handles 401 Unauthenticated errors by refreshing the token
* 3. Retries the original request with the new token
* 4. Redirects to login if refresh fails
*/
const authInterceptor: Interceptor = (next) => async (req) => { const authInterceptor: Interceptor = (next) => async (req) => {
// Add access token to request if available // Add access token to request if available
const token = getAccessToken(); const token = getAccessToken();
...@@ -23,9 +30,9 @@ const authInterceptor: Interceptor = (next) => async (req) => { ...@@ -23,9 +30,9 @@ const authInterceptor: Interceptor = (next) => async (req) => {
try { try {
return await next(req); return await next(req);
} catch (error: any) { } catch (error) {
// Handle unauthenticated error - try to refresh token // Only handle ConnectError with Unauthenticated code
if (error.code === "unauthenticated" && !req.header.get("X-Retry")) { if (error instanceof ConnectError && error.code === Code.Unauthenticated && !req.header.get("X-Retry")) {
// Prevent concurrent refresh attempts // Prevent concurrent refresh attempts
if (!isRefreshing) { if (!isRefreshing) {
isRefreshing = true; isRefreshing = true;
...@@ -47,8 +54,10 @@ const authInterceptor: Interceptor = (next) => async (req) => { ...@@ -47,8 +54,10 @@ const authInterceptor: Interceptor = (next) => async (req) => {
} catch (refreshError) { } catch (refreshError) {
isRefreshing = false; isRefreshing = false;
refreshPromise = null; refreshPromise = null;
// Refresh failed - redirect to login // Refresh failed - redirect to login (only if not already there)
if (!window.location.pathname.startsWith("/auth")) {
window.location.href = "/auth"; window.location.href = "/auth";
}
throw refreshError; throw refreshError;
} }
} }
...@@ -56,24 +65,48 @@ const authInterceptor: Interceptor = (next) => async (req) => { ...@@ -56,24 +65,48 @@ const authInterceptor: Interceptor = (next) => async (req) => {
} }
}; };
async function refreshAccessToken(): Promise<void> { /**
const response = await fetch("/api/v1/auth/refresh", { * Custom fetch that includes credentials for cookie handling.
method: "POST", * Required for HttpOnly refresh token cookie to be sent/received.
credentials: "include", // Include HttpOnly cookies with refresh token */
const fetchWithCredentials: typeof globalThis.fetch = (input, init) => {
return globalThis.fetch(input, {
...init,
credentials: "include",
}); });
};
if (!response.ok) { /**
throw new Error("Failed to refresh token"); * Separate transport for refresh token operations.
} * Uses no auth interceptor to avoid circular dependency when the main
* interceptor triggers a refresh.
*/
const refreshTransport = createConnectTransport({
baseUrl: window.location.origin,
useBinaryFormat: true,
fetch: fetchWithCredentials,
interceptors: [], // No interceptors to avoid recursion
});
const data = await response.json(); // Dedicated auth client for refresh operations only
setAccessToken(data.accessToken, new Date(data.expiresAt)); const refreshAuthClient = createClient(AuthService, refreshTransport);
/**
* Refreshes the access token using the HttpOnly refresh token cookie.
* Called automatically by the auth interceptor when requests fail with 401.
*/
async function refreshAccessToken(): Promise<void> {
const response = await refreshAuthClient.refreshToken({});
setAccessToken(response.accessToken, response.expiresAt ? timestampDate(response.expiresAt) : undefined);
} }
/**
* Main transport for all API requests.
*/
const transport = createConnectTransport({ const transport = createConnectTransport({
baseUrl: window.location.origin, baseUrl: window.location.origin,
// Use binary protobuf format for better performance (smaller payloads, faster serialization)
useBinaryFormat: true, useBinaryFormat: true,
fetch: fetchWithCredentials,
interceptors: [authInterceptor], interceptors: [authInterceptor],
}); });
......
import { timestampDate } from "@bufbuild/protobuf/wkt"; import { timestampDate } from "@bufbuild/protobuf/wkt";
import { ConnectError } from "@connectrpc/connect";
import { LoaderIcon } from "lucide-react"; import { LoaderIcon } from "lucide-react";
import { observer } from "mobx-react-lite"; import { observer } from "mobx-react-lite";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
...@@ -95,15 +94,16 @@ const AuthCallback = observer(() => { ...@@ -95,15 +94,16 @@ const AuthCallback = observer(() => {
await initialUserStore(); await initialUserStore();
// Redirect to return URL if specified, otherwise home // Redirect to return URL if specified, otherwise home
navigateTo(returnUrl || "/"); navigateTo(returnUrl || "/");
} catch (error: any) { } catch (error: unknown) {
console.error(error); console.error(error);
const message = error instanceof Error ? error.message : "Failed to authenticate.";
setState({ setState({
loading: false, loading: false,
errorMessage: (error as ConnectError).message, errorMessage: message,
}); });
} }
})(); })();
}, [searchParams]); }, [searchParams, navigateTo]);
return ( return (
<div className="p-4 py-24 w-full h-full flex justify-center items-center"> <div className="p-4 py-24 w-full h-full flex justify-center items-center">
......
import { create } from "@bufbuild/protobuf"; import { create } from "@bufbuild/protobuf";
import { timestampDate } from "@bufbuild/protobuf/wkt"; import { timestampDate } from "@bufbuild/protobuf/wkt";
import { ConnectError } from "@connectrpc/connect";
import { LoaderIcon } from "lucide-react"; import { LoaderIcon } from "lucide-react";
import { observer } from "mobx-react-lite"; import { observer } from "mobx-react-lite";
import { useState } from "react"; import { useState } from "react";
...@@ -15,7 +14,7 @@ import useLoading from "@/hooks/useLoading"; ...@@ -15,7 +14,7 @@ import useLoading from "@/hooks/useLoading";
import useNavigateTo from "@/hooks/useNavigateTo"; import useNavigateTo from "@/hooks/useNavigateTo";
import { instanceStore } from "@/store"; import { instanceStore } from "@/store";
import { initialUserStore } from "@/store/user"; import { initialUserStore } from "@/store/user";
import { User, User_Role, UserSchema } from "@/types/proto/api/v1/user_service_pb"; import { User_Role, UserSchema } from "@/types/proto/api/v1/user_service_pb";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
const SignUp = observer(() => { const SignUp = observer(() => {
...@@ -70,9 +69,10 @@ const SignUp = observer(() => { ...@@ -70,9 +69,10 @@ const SignUp = observer(() => {
} }
await initialUserStore(); await initialUserStore();
navigateTo("/"); navigateTo("/");
} catch (error: any) { } catch (error: unknown) {
console.error(error); console.error(error);
toast.error((error as ConnectError).message || "Sign up failed"); const message = error instanceof Error ? error.message : "Sign up failed";
toast.error(message);
} }
actionBtnLoadingState.setFinish(); actionBtnLoadingState.setFinish();
}; };
...@@ -112,7 +112,7 @@ const SignUp = observer(() => { ...@@ -112,7 +112,7 @@ const SignUp = observer(() => {
readOnly={actionBtnLoadingState.isLoading} readOnly={actionBtnLoadingState.isLoading}
placeholder={t("common.password")} placeholder={t("common.password")}
value={password} value={password}
autoComplete="password" autoComplete="new-password"
autoCapitalize="off" autoCapitalize="off"
spellCheck={false} spellCheck={false}
onChange={handlePasswordInputChanged} onChange={handlePasswordInputChanged}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment